Merge pull request #45456 from ROCmSoftwarePlatform:google_upstream_rocm_remove_rocm_root

PiperOrigin-RevId: 348623263
Change-Id: I62a0648c15584864cbd045a530972ec3158aec9c
diff --git a/.bazelrc b/.bazelrc
index b81b153..640cfa8 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -145,10 +145,6 @@
 # opts in to modular op registration support by default.
 build --define framework_shared_object=true
 
-# Flags for open source build, always set to be true.
-build --define open_source_build=true
-test --define open_source_build=true
-
 # For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
 build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain
 build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
diff --git a/CODEOWNERS b/CODEOWNERS
index 9de1922..3b0565b 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -3,6 +3,8 @@
 /tensorflow/c/eager @qqfish @kkimdev
 /tensorflow/core/common_runtime/eager @qqfish @kkimdev
 /tenosrflow/core/debug @caisq
+/tensorflow/core/kernels/mkl/ @penpornk
+/tensorflow/core/kernels/sparse/ @penpornk
 /tensorflow/core/nccl/ @azaks2 @chsigg
 /tensorflow/core/platform/windows/ @mihaimaruseac
 /tensorflow/lite/experimental/micro @petewarden @advaitjain
diff --git a/README.md b/README.md
index ed2686b..34738ac 100644
--- a/README.md
+++ b/README.md
@@ -132,8 +132,8 @@
 **Linux ppc64le CPU** Stable Release                                                | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)                                                                                                                                                                                                                             | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
 **Linux ppc64le GPU** Nightly                                                       | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/)                                                                                                                                                                                                                                             | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
 **Linux ppc64le GPU** Stable Release                                                | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)                                                                                                                                                                                                                             | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
-**Linux aarch64 CPU** Nightly (Linaro)<br> Python 3.8                               | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/)                                                                                                                                                                                                                                                   | [Nightly](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
-**Linux aarch64 CPU** Stable Release (Linaro)                                       | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/)                                                                                                                                                                                                                                                   | Release [1.x & 2.x](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
+**Linux aarch64 CPU** Nightly (Linaro)                                              | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow-nightly)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow-nightly/)                                                                                                                                                                                                                             | [Nightly](http://snapshots.linaro.org/ldcg/python/tensorflow-nightly/latest/)
+**Linux aarch64 CPU** Stable Release (Linaro)                                       | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow/)                                                                                                                                                                                                                                             | Release [1.x & 2.x](http://snapshots.linaro.org/ldcg/python/tensorflow/latest/)
 **Linux aarch64 CPU** Nightly (OpenLab)<br> Python 3.6                              | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)                                                                                                                                                                              | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
 **Linux aarch64 CPU** Stable Release (OpenLab)                                      | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
 **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly        | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)                                                                                                                                                                                                                           | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
diff --git a/RELEASE.md b/RELEASE.md
index 6340db0..b1847b7 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -32,6 +32,8 @@
 *   `tf.keras`:
     *   Improvements to Keras preprocessing layers:
         *   Discretization combiner implemented, with additional arg `epsilon`.
+    *   Improvements to model saving/loading:
+        *   `model.load_weights` now accepts paths to saved models.
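+
+        A hedged sketch of the new behavior (the toy model and paths are
+        illustrative, not from the release notes):
+
+        ```
+        import tensorflow as tf
+
+        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
+        model.save("/tmp/my_model")          # SavedModel format
+        clone = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
+        clone.load_weights("/tmp/my_model")  # previously required a checkpoint prefix or .h5 file
+        ```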
 
 *   `tf.data`:
     *   Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
@@ -58,9 +60,11 @@
                 directly.
     *  16 bits quantization
         *   Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
-    *   Added support for saved model's session initializer through
+    *  Added support for saved model's session initializer through
          `TFLiteConverter.from_saved_model`.
-    *   Added dynamic range quantization support for the BatchMatMul op.
+    *  Added DEPTH_TO_SPACE support in post-training quantization.
+    *  Added dynamic range quantization support for the BatchMatMul op.
+        * Both symmetric and asymmetric quantized input tensors are supported.
     *  Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
        only supports float32 input.
     *  TFLite Supports SingatureDef:
@@ -95,6 +99,11 @@
         value of `is_dynamic_op` is not True. We didn't use the value for
         `max_batch_size` for building TensorRT engines.
     *   Issue a warning when function get_tensorrt_rewriter_config is used.
+*   Other:
+    *   Add new enum value `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` to
+        `tf.config.experimental.mlir_bridge_rollout` to enable a "safe" mode.
+        This runs the MLIR bridge only when an analysis of the graph
+        determines that it is safe to run.
 
 ## Thanks to our Contributors
 
@@ -104,459 +113,448 @@
 
 # Release 2.4.0
 
-<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
+## Major Features and Improvements
+
+* `tf.distribute` introduces experimental support for asynchronous training of
+  models via the
+  [`tf.distribute.experimental.ParameterServerStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/ParameterServerStrategy)
+  API. Please see the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)
+  to learn more.
+
+* [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy)
+  is now a stable API and is no longer considered experimental. Some of the
+  major improvements involve handling peer failure and many bug fixes. Please
+  check out the detailed tutorial on
+  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
+
+* Introduces experimental support for a new module named
+  [`tf.experimental.numpy`](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy),
+  which is a NumPy-compatible API for writing TF programs. See the
+  [detailed guide](https://www.tensorflow.org/guide/tf_numpy) to learn more.
+  Additional details below, and a short usage sketch follows this list.
+
+* Adds support for
+  [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
+  on Ampere based GPUs. TensorFloat-32, or TF32 for short, is a math mode for
+  NVIDIA Ampere based GPUs and is enabled by default.
+
+* A major refactoring of the internals of the Keras Functional API has been
+  completed, that should improve the reliability, stability, and performance of
+  constructing Functional models.
+
+* Keras mixed precision API
+  [`tf.keras.mixed_precision`](https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision?version=nightly)
+  is no longer experimental and allows the use of 16-bit floating point formats
+  during training, improving performance by up to 3x on GPUs and 60% on TPUs.
+  Please see below for additional details.
+
+* TensorFlow Profiler now supports profiling `MultiWorkerMirroredStrategy` and
+  tracing multiple workers using the
+  [sampling mode API](https://www.tensorflow.org/guide/profiler#profiling_apis).
+
+* TFLite Profiler for Android is available. See the detailed
+  [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android)
+  to learn more.
+
+* TensorFlow pip packages are now built with CUDA 11 and cuDNN 8.0.2.
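+
+A quick, hedged sketch of the `tf.experimental.numpy` module mentioned above
+(values are illustrative):
+
+```
+import tensorflow as tf
+import tensorflow.experimental.numpy as tnp
+
+a = tnp.asarray([1.0, 2.0, 3.0])  # tnp.ndarray wrapping an immutable tf.Tensor
+b = tnp.add(a, 1.0)               # NumPy-style function
+c = tf.reduce_sum(b)              # interoperates with regular TF ops
+```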
 
 ## Breaking Changes
 
-* <DOCUMENT BREAKING CHANGES HERE>
-* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
-* Certain float32 ops run in lower precsion on Ampere based GPUs, including 
-  matmuls and convolutions, due to the use of
-  [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
+* TF Core:
+  * Certain float32 ops run in lower precision on Ampere based GPUs, including
+  matmuls and convolutions, due to the use of
+  [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
   Specifically, inputs to such ops are rounded from 23 bits of precision to 10
-  bits of precision. This is unlikely to cause issues in practice for deep
-  learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
-  TensorFloat-32 can be disabled by running
-  `config.experimental.enable_tensor_float_32_execution(False)`. The "Major
-  Features and Improvements" section has more details.
-* The byte layout for string tensors across the C-API has been updated to match
+  bits of precision. This is unlikely to cause issues in practice for deep learning
+  models. In some cases, TensorFloat-32 is also used for complex64 ops.
+  TensorFloat-32 can be disabled by running `tf.config.experimental.enable_tensor_float_32_execution(False)`;
+  a short sketch follows this list.
+  * The byte layout for string tensors across the C-API has been updated to match
   TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
-* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
-  `TF_StringEncodedSize` are no longer relevant and have been removed; see
-  core/platform/ctstring.h for string access/modification in C.
-* Removed `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2.
-* `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are
-    now hidden. These modules are not part of TensorFlow public API.
-* A major refactoring of the internals of the Keras Functional API may affect code that is relying on certain internal details:
-    * Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`.
-    * Code that is overly dependent on the exact names attached to symbolic tensors (e.g. assumes there will be ":0" at the end of the inputs, treats names as unique identifiers instead of using `tensor.ref()`, etc.)
-    * Code that uses `get_concrete_function` to trace Keras symbolic inputs directly should switch to building matching `tf.TensorSpec`s directly and tracing the `TensorSpec` objects.
-    * Code that relies on the exact number and names of the op layers that TensorFlow operations were converted into. These may have changed.
-    * Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
-    * Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
-    * Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
-    * Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead.
-    * Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
-    * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
-    * Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
-* Start enforcing input shape assumptions when calling Functional API Keras
+  * C-API functions `TF_StringDecode`, `TF_StringEncode`, and `TF_StringEncodedSize`
+  are no longer relevant and have been removed; see `core/platform/ctstring.h` for
+  string access/modification in C.
+  * `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are
+  now hidden. These modules are not part of TensorFlow public API.
+  * `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
+  `tf.complex64` or `tf.complex128`, because the behavior of these ops is not
+  well defined for complex types.
+  * XLA:CPU and XLA:GPU devices are no longer registered by default. Use
+  `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them, but this
+  flag will eventually be removed in subsequent releases.
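+
+  As a minimal illustration of the TensorFloat-32 opt-out mentioned above:
+
+  ```
+  import tensorflow as tf
+
+  # TF32 is enabled by default on Ampere GPUs; opt out to keep full float32 precision.
+  tf.config.experimental.enable_tensor_float_32_execution(False)
+  ```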
+
+* `tf.keras`:
+  * The `steps_per_execution` argument in `model.compile()` is no longer experimental;
+  if you were passing `experimental_steps_per_execution`, rename it to
+  `steps_per_execution` in your code. This argument controls the number of batches
+  to run during each `tf.function` call when calling `model.fit()`. Running multiple
+  batches inside a single `tf.function` call can greatly improve performance on
+  TPUs or small models with a large Python overhead (see the sketch after this list).
+  * A **major refactoring** of the internals of the Keras Functional API may affect code that
+  is relying on certain internal details:
+    * Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when
+  checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`.
+    * Code that is overly dependent on the exact names attached to symbolic tensors
+  (e.g. assumes there will be ":0" at the end of the inputs, treats names as
+  unique identifiers instead of using `tensor.ref()`, etc.) may break.
+    * Code that uses `get_concrete_function` to trace Keras symbolic inputs
+  directly should switch to building matching `tf.TensorSpec`s directly and
+  tracing the `TensorSpec` objects.
+    * Code that relies on the exact number and names of the op layers that
+  TensorFlow operations were converted into; these may have changed.
+    * Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers
+  and happens to work before TF 2.4. These will explicitly be unsupported now.
+  Converting these ops to Functional API op layers was unreliable before TF 2.4,
+  and prone to erroring incomprehensibly or being silently buggy.
+    * Code that directly asserts on a Keras symbolic value in cases where ops
+  like `tf.rank` used to return a static or symbolic value depending on whether
+  the input had a fully static shape or not. Now these ops always return symbolic values.
+    * Code already susceptible to leaking tensors outside of graphs becomes slightly
+  more likely to do so now.
+    * Code that tries directly getting gradients with respect to symbolic Keras
+  inputs/outputs. Use `GradientTape` on the actual Tensors passed to the already-constructed
+  model instead.
+    * Code that requires very tricky shape manipulation via converted op layers
+  in order to work, where the Keras symbolic shape inference proves insufficient.
+    * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes
+  layers only ever have one positional argument. This assumption doesn't hold
+  true before TF 2.4 either, but is more likely to cause issues now.
+    * Code that manually enters `keras.backend.get_graph()` before building a
+  functional model is no longer needed.
+    * Start enforcing input shape assumptions when calling Functional API Keras
   models. This may potentially break some users, in case there is a mismatch
   between the shape used when creating `Input` objects in a Functional model,
   and the shape of the data passed to that model. You can fix this mismatch by
-  either calling the model with correctly-shaped data, or by relaxing `Input`
-  shape assumptions (note that you can pass shapes with `None` entries for axes
-  that are meant to be dynamic). You can also disable the input checking
-  entirely by setting `model.input_spec = None`.
-* TF pip packages now use CUDA11 and cuDNN 8.0.2.
-* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
-  `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
-  removed).
-* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
-  `tf.complex64` or `tf.complex128`, because the behavior of these ops is not
-  well defined for complex types.
-* `tf.data.experimental.service.DispatchServer` now takes a config tuple
+  either calling the model with correctly-shaped data, or by relaxing `Input` shape
+  assumptions (note that you can pass shapes with `None` entries for axes that
+  are meant to be dynamic). You can also disable the input checking entirely by
+  setting `model.input_spec = None`.
+  * Several changes have been made to `tf.keras.mixed_precision.experimental`.
+  Note that it is now recommended to use the non-experimental
+  `tf.keras.mixed_precision` API.
+   * `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
+  dtype it will be casted to.
+   * When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
+  float16 or bfloat16 tensor instead of a float32 tensor.
+   * The property `tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale`
+  is now a tensor, not a `LossScale` object. This means to get a loss scale
+  of a `LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead of `opt.loss_scale()`.
+   * The property `should_cast_variables` has been removed from `tf.keras.mixed_precision.experimental.Policy`.
+   * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to `tf.keras.mixed_precision.experimental.LossScaleOptimizer`,
+  the `DynamicLossScale`'s multiplier must be 2.
+   * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
+  `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
+  the `DynamicLossScale` are copied into the `LossScaleOptimizer` instead of being reused.
+  This means modifying the weights of the `DynamicLossScale` will no longer affect the weights of the `LossScaleOptimizer`, and vice versa.
+   * The global policy can no longer be set to a non-floating point policy in `tf.keras.mixed_precision.experimental.set_policy`.
+   * In `Layer.call`, `AutoCastVariable`s will no longer be casted within
+  `MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a thread local
+  variable is used to determine whether `AutoCastVariable`s are casted, and those
+  two functions run with a different thread. Note this only applies if one of
+  these two functions is called within `Layer.call`; if one of those two functions calls `Layer.call`, `AutoCastVariable`s will still be casted.
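+
+  A short sketch of the renamed `steps_per_execution` argument (the toy model is
+  illustrative):
+
+  ```
+  import tensorflow as tf
+
+  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
+  # Run 10 batches per tf.function call; formerly `experimental_steps_per_execution`.
+  model.compile(optimizer="sgd", loss="mse", steps_per_execution=10)
+  ```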
+
+* `tf.data`:
+  * `tf.data.experimental.service.DispatchServer` now takes a config tuple
   instead of individual arguments. Usages should be updated to
   `tf.data.experimental.service.DispatchServer(dispatcher_config)`.
-* `tf.data.experimental.service.WorkerServer` now takes a config tuple
-  instead of individual arguments. Usages should be updated to
-  `tf.data.experimental.service.WorkerServer(worker_config)`.
-* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
-  updates the gradient definition for quantization which is outside the range
-  to be 0. To simulate the V1 the behavior of
-  tf.quantization.quantize_and_dequantize(...) use
-  tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
-* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
-  use `tf.data.Dataset.from_tensor_slices` instead.
-* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
-  `tf.distribute.StrategyExtended.batch_reduce_to`,
-  `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
-  `tf.distribute.experimental.CollectiveHints` is renamed
-  `tf.distribute.experimental.CommunicationOptions`.
-  `tf.distribute.experimental.CollectiveCommunication` is renamed
-  `tf.distribute.experimental.CommunicationImplementation`.
-* `tf.keras.mixed_precision.experimental`:
-  * `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
-    dtype it will be casted to.
-  * When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
-    float16 or bfloat16 tensor instead of a float32 tensor.
-  * The property
-    `tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now
-    a tensor, not a `LossScale` object. This means to get a loss scale of a
-    `LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead
-    of `opt.loss_scale()`.
-  * The property `should_cast_variables` has been removed from
-    `tf.keras.mixed_precision.experimental.Policy`
-  * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
-    `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the
-    `DynamicLossScale`'s multiplier must be 2.
-  * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
-    `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
-    the `DynanmicLossScale` are copied into the `LossScaleOptimizer` instead of
-    being reused. This means modifying the weights of the `DynamicLossScale`
-    will no longer affect the weights of the LossScaleOptimizer, and vice versa.
-  * The global policy can no longer be set to a non-floating point policy in
-    `tf.keras.mixed_precision.experimental.set_policy`
-  * In `Layer.call`, `AutoCastVariable`s will no longer be casted within
-    `MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a
-    thread local variable is used to determine whether `AutoCastVariable`s are
-    casted, and those two functions run with a different thread. Note this only
-    applies if one of these two functions is called within `Layer.call`; if one
-    of those two functions calls `Layer.call`, `AutoCastVariable`s will still be
-    casted.
-
-## Known Caveats
-
-* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
-
-## Major Features and Improvements
-
-* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
-* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
-* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
-* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
-* Support for
-  [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
-  on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a
-  math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as
-  matrix multiplications and convolutions, to run much faster on Ampere GPUs but
-  with reduced precision. This reduced precision has not been found to effect
-  convergence quality of deep learning models in practice. TensorFloat-32 is
-  enabled by default, but can be disabled with
-  `tf.config.experimental.enable_tensor_float_32_execution`.
+  * `tf.data.experimental.service.WorkerServer` now takes a config tuple instead
+  of individual arguments. Usages should be updated to `tf.data.experimental.service.WorkerServer(worker_config)`.
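+
+  A hedged sketch of the config-based constructors (the port and address are
+  illustrative):
+
+  ```
+  import tensorflow as tf
+
+  dispatcher = tf.data.experimental.service.DispatchServer(
+      tf.data.experimental.service.DispatcherConfig(port=5050))
+  worker = tf.data.experimental.service.WorkerServer(
+      tf.data.experimental.service.WorkerConfig(
+          dispatcher_address=dispatcher.target.split("://")[1]))
+  ```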
 
 * `tf.distribute`:
-  * `MultiWorkerMirroredStrategy` is graduated out of experimental.
-    * Peer failure will no longer cause the cluster to hang.
-    * Major issues with saving are fixed.
-    * See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
-  * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
-* The `tf.keras.mixed_precision` API has been made non-experimental. The major
-  changes to the new non-experimental API are:
-  * `tf.keras.mixed_precision.Policy` no longer takes in a
-    `tf.mixed_precision.experimental.LossScale` in the constructor, and no
-    longer has a `LossScale` associated with it. Instead, `Model.compile` will
-    automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic
-    loss scaling if `Policy.name` is "mixed_float16".
-  * `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in
-    different arguments. In particular, it no longer takes in a `LossScale`, and
-    there is no longer a `LossScale` associated with the `LossScaleOptimizer`.
-    Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss
-    scaling. See the documentation of
-    `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on
-    the differences between the experimental `LossScaleOptimizer` and the new
-    non-experimental `LossScaleOptimizer`.
-  * `tf.mixed_precision.experimental.LossScale` and its subclasses are
-    deprecated, as all of its functionality now exists within
-    `tf.keras.mixed_precision.LossScaleOptimizer`
+  * Removes `tf.distribute.Strategy.experimental_make_numpy_dataset`. Please use
+  `tf.data.Dataset.from_tensor_slices` instead.
+  * Renames `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
+  `tf.distribute.StrategyExtended.batch_reduce_to`, `tf.distribute.ReplicaContext.all_reduce`
+  to `options`.
+  * Renames `tf.distribute.experimental.CollectiveHints` to `tf.distribute.experimental.CommunicationOptions`.
+  * Renames `tf.distribute.experimental.CollectiveCommunication` to `tf.distribute.experimental.CommunicationImplementation`.
+  * Renames `tf.distribute.Strategy.experimental_distribute_datasets_from_function` to `distribute_datasets_from_function` as it is no longer experimental.
+  * Removes `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2.
+
+* `tf.lite`:
+  * `tf.quantization.quantize_and_dequantize_v2` has been introduced, which updates the gradient definition
+  so that inputs outside the quantization range get a gradient of 0. To simulate the V1 behavior of
+  `tf.quantization.quantize_and_dequantize(...)` use
+  `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.
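+
+  A minimal sketch of the workaround above (input values and range are
+  illustrative):
+
+  ```
+  import tensorflow as tf
+
+  x = tf.constant([0.3, 0.7, 1.4])
+  # V1-style gradients: pass gradients through the quantization op unchanged.
+  y = tf.grad_pass_through(
+      lambda t: tf.quantization.quantize_and_dequantize_v2(
+          t, input_min=0.0, input_max=1.0))(x)
+  ```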
+
+* Building TensorFlow:
+  * Windows platform builds: TensorFlow on Windows under MSVC is now built with
+  `--copt=/experimental:preprocessor --host_copt=/experimental:preprocessor`
+  (see `.bazelrc` for more details). Builds including TensorFlow may fail with
+  unexpected syntax errors if these flags are absent. See also
+  [this thread on SIG Build](https://groups.google.com/a/tensorflow.org/g/build/c/LbAw8RILvTg/m/ttnuhYU2BgAJ).
+
+## Known Caveats
+  * `tf.keras.mixed_precision`
+    * When using mixed precision, calling `RMSprop.apply_gradients` or
+  `Nadam.apply_gradients` outside a `tf.function` does not work and will raise
+  the AttributeError "Tensor.op is meaningless when eager execution is enabled".
+  See this [issue](https://github.com/tensorflow/tensorflow/issues/45536) for details and a workaround.
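+
+    A hedged sketch of one way to sidestep the error, by keeping the update step
+    inside a `tf.function` (the variable and loss are illustrative; see the linked
+    issue for the recommended workaround):
+
+    ```
+    import tensorflow as tf
+
+    opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.RMSprop())
+    var = tf.Variable(1.0)
+
+    @tf.function  # calling apply_gradients eagerly triggers the AttributeError above
+    def apply_step():
+      with tf.GradientTape() as tape:
+        loss = opt.get_scaled_loss(var * var)
+      grads = opt.get_unscaled_gradients(tape.gradient(loss, [var]))
+      opt.apply_gradients(zip(grads, [var]))
+    ```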
 
 ## Bug Fixes and Other Changes
 
-*   <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
-*   <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
-*   <NOTES SHOULD BE GROUPED PER AREA>
-*   Security:
-    *   Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
-        ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
-    *   Fixes three vulnerabilities in conversion to DLPack format
-        ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
-        [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
-        [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
-    *   Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
-        ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
-        [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
-    *   Fixes several vulnerabilities in `RaggedCountSparseOutput` and
-        `SparseCountSparseOutput` operations
-        ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
-        [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
-        [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
-        [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
-        [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
-        [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
-    *   Fixes an integer truncation vulnerability in code using the work sharder
-        API
-        ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
-    *   Fixes a format string vulnerability in `tf.strings.as_string`
-        ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
-    *   Fixes segfault raised by calling session-only ops in eager mode
-        ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
-    *   Fixes data leak and potential ASLR violation from
-        `tf.raw_ops.StringNGrams`
-        ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
-    *   Fixes segfaults caused by incomplete `SavedModel` validation
-        ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
-    *   Fixes a data corruption due to a bug in negative indexing support in
-        TFLite
-        ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
-    *   Fixes a data corruption due to dimension mismatch in TFLite
-        ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
-    *   Fixes several vulnerabilities in TFLite saved model format
-        ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
-        [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
-        [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
-    *   Fixes several vulnerabilities in TFLite implementation of segment sum
-        ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
-        [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
-        [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
-    *   Fixes a segfault in `tf.quantization.quantize_and_dequantize`
-        ([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
-    *   Fixes an undefined behavior float cast causing a crash
-        ([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
-*   TF Core:
-    *   `tf.types.experimental.TensorLike` is a new `Union` type that can be
-        used as type annotation for variables representing a Tensor or a value
-        that can be converted to Tensor by `tf.convert_to_tensor`.
-    *   Calling ops with a python constants or numpy values is now consistent
-        with tf.convert_to_tensor behavior. This avoids operations like
-        tf.reshape truncating inputs such as from int64 to int32.
-    *   Added `tf.sparse.map_values` to apply a function to the `.value`s of
-        `SparseTensor` arguments.
-    *   The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
-        `__xor__` and `__invert__` now support non-`bool` arguments and apply
-        the corresponding bitwise ops. `bool` arguments continue to be supported
-        and dispatch to logical ops. This brings them more in line with Python
-        and NumPy behavior.
-    *   Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
-        with the same sparsity pattern, but with new provided values. It is
-        similar to the `with_values` function of `RaggedTensor`.
-    *   Added `StatelessCase` op, and uses it if none of case branches has
-        stateful ops.
-    *   Added `tf.config.experimental.get_memory_usage` to return total memory
-        usage of the device.
-    *   Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
-    *   Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
-*   `tf.data`:
-    *   tf.data service:
-    *   Added new `tf.data.experimental.service.register_dataset` and
-        `tf.data.experimental.service.from_dataset_id` APIs to enable one
-        process to register a dataset with the tf.data service, and another
-        process to consume data from the dataset.
-    *   Added support for dispatcher fault tolerance. To enable fault tolerance,
-        configure a `work_dir` when running your dispatcher server and set
-        `dispatcher_fault_tolerance=True`. The dispatcher will store its state
-        to `work_dir`, so that on restart it can continue from its previous
-        state after restart.
-    *   Added support for sharing dataset graphs via shared filesystem instead
-        of over RPC. This reduces load on the dispatcher, improving performance
-        of distributing datasets. For this to work, the dispatcher's `work_dir`
-        must be accessible from workers. If the worker fails to read from the
-        `work_dir`, it falls back to using RPC for dataset graph transfer.
-    *   Added support for a new "distributed_epoch" processing mode. This
-        processing mode distributes a dataset across all tf.data workers,
-        instead of having each worker process the full dataset. See
-        [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
-        to learn more.
-    *   Added optional `exclude_cols` parameter to CsvDataset. This parameter is
-        the complement of `select_cols`; at most one of these should be
-        specified.
-    *   We have implemented an optimization which reorders data-discarding
-        transformations such as `take` and `shard` to happen earlier in the
-        dataset when it is safe to do so. The optimization can be disabled via
-        the `experimental_optimization.reorder_data_discarding_ops` dataset
-        option.
-    *   `tf.data.Options` were previously immutable and can now be overridden.
-    *   `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
-        with a new `output_signature` argument, which allows `from_generator` to
-        produce any type describable by a `tf.TypeSpec`.
-    *   `tf.data.experimental.AUTOTUNE` is now available in the core API as
-        `tf.data.AUTOTUNE`.
-*   `tf.image`:
-    *   Added deterministic `tf.image.stateless_random_*` functions for each
-        `tf.image.random_*` function. Added a new op
-        `stateless_sample_distorted_bounding_box` which is a deterministic
-        version of `sample_distorted_bounding_box` op. Given the same seed,
-        these stateless functions/ops produce the same results independent of
-        how many times the function is called, and independent of global seed
-        settings.
-*   `tf.distribute`:
-    *   (Experimental) Parameter server training:
-        *   Replaced the existing
-            `tf.distribute.experimental.ParameterServerStrategy` symbol with
-            a new class that is for parameter server training in TF2. Usage with
-            the old symbol, usually with Estimator, should be replaced with
-            `tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
-        *   Added `tf.distribute.experimental.coordinator.*` namespace,
-            including the main API `ClusterCoordinator` for coordinating the
-            training cluster, the related data structure `RemoteValue`
-            and `PerWorkerValue`.
-*   `tf.keras`:
-    *   Improvements from the functional API refactoring:
-        *   Functional model construction does not need to maintain a global
-            workspace graph, removing memory leaks especially when building many
-            models or very large models.
-        *   Functional model construction should be ~8-10% faster on average.
-        *   Functional models can now contain non-symbolic values in their call
-            inputs inside of the first positional argument.
-        *   Several classes of TF ops that were not reliably converted to Keras
-            layers during functional API construction should now work, e.g.
-            `tf.image.ssim_multiscale`
-        *   Error messages when Functional API construction goes wrong (and when
-            ops cannot be converted to Keras layers automatically) should be
-            clearer and easier to understand.
-    *   `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
-        as an alternative to accepting a `callable` loss.
-    *   Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
-        to match FTRL paper
-        (https://research.google.com/pubs/archive/41159.pdf).
-    *   Added `mobilenet_v3` to keras application model.
-    *   `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
-        customization of how gradients are aggregated across devices, as well as
-        `gradients_transformers` to allow for custom gradient transformations
-        (such as gradient clipping).
-    *   The `steps_per_execution` argument in `compile()` is no longer
-        experimental; if you were passing `experimental_steps_per_execution`,
-        rename it to `steps_per_execution` in your code. This argument controls
-        the number of batches to run during each `tf.function` call when calling
-        `fit()`. Running multiple batches inside a single `tf.function` call can
-        greatly improve performance on TPUs or small models with a large Python
-        overhead.
-    *   Improvements to Keras preprocessing layers:
-        *   TextVectorization can now accept a vocabulary list or file as an
-            init arg.
-        *   TextVectorization, StringLookup, and IntegerLookup can now accept a
-            vocabulary file via the `set_vocab_from_file` method.
-        *   Normalization can now accept mean and variance values as init args.
-    *   In `Attention` and `AdditiveAttention` layers, the `call()` method now
-        accepts a `return_attention_scores` argument. When set to
-        True, the layer returns the attention scores as an additional output
-        argument.
-    *   Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
-        with the same implementation as their `tf.losses` equivalent.
-    *   For Keras model, the individual call of `Model.evaluate` uses no cached
-        data for evaluation, while `Model.fit` uses cached data when
-        `validation_data` arg is provided for better performance.
-    *   Added a `save_traces` argument to `model.save`/
-        `tf.keras.models.save_model` which determines whether the SavedModel
-        format stores the Keras model/layer call functions. The traced functions
-        allow Keras to revive custom models and layers without the original
-        class definition, but if this isn't required the tracing can be
-        disabled with the added option.
-*   `tf.function` / AutoGraph:
-    *   Added `experimental_follow_type_hints` argument for `tf.function`. When
-        True, the function may use type annotations to optimize the tracing
-        performance.
-    *   Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
-    *   AutoGraph now allows creating new symbols inside a TensorFlow loop, if
-        the values of these symbols at an iteration does not depend on the
-        previous iteration. These types of loops must run at least one
-        iteration, and will raise a runtime error otherwise.
-    *   Variables contained in `tf.Module`s that are set as attributes of
-        custom Keras `Layer`s and `Model`s are now tracked in
-        the properties `layer.trainable_variables` and
-        `layer.non_trainable_variables`.
+### TF Core:
+  * Introduces experimental support for a new module named
+  [`tf.experimental.numpy`](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy),
+  which is a NumPy-compatible API for writing TF programs. This module provides class
+  `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable
+  `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are
+  provided. Their inter-operation with TF facilities is seamless in most cases.
+  See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md)
+  for details of what operations are supported and what the differences from NumPy are.
+  * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
+  type annotation for variables representing a Tensor or a value
+    that can be converted to Tensor by `tf.convert_to_tensor`.
+  * Calling ops with Python constants or NumPy values is now consistent with
+  `tf.convert_to_tensor` behavior. This avoids operations like `tf.reshape`
+  truncating inputs such as from int64 to int32.
+  * Adds `tf.sparse.map_values` to apply a function to the `.value`s of
+  `SparseTensor` arguments.
+  * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__` and `__invert__`) now support non-`bool`
+  arguments and apply the corresponding bitwise ops. `bool` arguments continue
+  to be supported and dispatch to logical ops. This brings them more in line with
+  Python and NumPy behavior.
+  * Adds `tf.SparseTensor.with_values`. This returns a new SparseTensor with the same sparsity pattern, but with new provided values. It is
+    similar to the `with_values` function of `RaggedTensor`.
+  * Adds `StatelessCase` op, and uses it if none of case branches has stateful ops.
+  * Adds `tf.config.experimental.get_memory_usage` to return total memory usage of the device.
+  * Adds gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
+  * Improve shape inference of nested function calls by supporting constant
+  folding across Arg nodes which makes more static values available to shape
+  inference functions.
+* `tf.debugging`:
+  * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (Fixes [#36268](https://github.com/tensorflow/tensorflow/issues/36268)).
+* GPU:
+  * Adds support for [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
+  on Ampere based GPUs. TensorFloat-32, or TF32 for short, is a math mode for
+  NVIDIA Ampere based GPUs which causes certain float32 ops, such as matrix
+  multiplications and convolutions, to run much faster on Ampere GPUs but with
+  reduced precision. This reduced precision has not been found to affect
+  convergence quality of deep learning models in practice. TensorFloat-32 is
+  enabled by default, but can be disabled with `tf.config.experimental.enable_tensor_float_32_execution`.
+* `tf.math`:
+  * Adds `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
+* `tf.nn`:
+  *   `tf.nn.max_pool2d` now supports explicit padding.
+* `tf.image`:
+  * Adds deterministic `tf.image.stateless_random_*` functions for each
+  `tf.image.random_*` function. Added a new op `stateless_sample_distorted_bounding_box`
+  which is a deterministic version of `sample_distorted_bounding_box` op.
+  Given the same seed, these stateless functions/ops produce the same results
+  independent of how many times the function is called, and independent of global seed settings.
+  * Adds deterministic `tf.image.resize` backprop CUDA kernels for
+  `method=ResizeMethod.BILINEAR` (the default method). Enable by setting the environment
+  variable `TF_DETERMINISTIC_OPS` to `"true"` or `"1"`.
+* `tf.print`:
+  * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
+  didn't have the keys sorted, the keys and values were not being printed
+    in accordance with their correct mapping.
+* `tf.train.Checkpoint`:
+  * Now accepts a `root` argument in the initialization, which generates a
+  checkpoint with a root object. This allows users to create a `Checkpoint`
+  object that is compatible with Keras `model.save_weights()` and
+  `model.load_weights()`. The checkpoint is also compatible with the checkpoint
+  saved in the `variables/` folder in the SavedModel.
+  * When restoring, `save_path` can be a path to a SavedModel. The function will
+  automatically find the checkpoint in the SavedModel.
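+
+  A hedged sketch of the `root` argument described above (the model and paths
+  are illustrative):
+
+  ```
+  import tensorflow as tf
+
+  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
+  ckpt = tf.train.Checkpoint(root=model)   # new `root` argument
+  path = ckpt.save("/tmp/ckpt/model")
+  model.load_weights(path)                 # interoperable with the Keras weight APIs
+  ```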
 
-    Example:
+### `tf.data`:
+  * Adds new `tf.data.experimental.service.register_dataset` and
+  `tf.data.experimental.service.from_dataset_id` APIs to enable one process to
+  register a dataset with the tf.data service, and another process to consume
+  data from the dataset.
+  * Adds support for dispatcher fault tolerance. To enable fault tolerance,
+  configure a `work_dir` when running your dispatcher server and set
+  `dispatcher_fault_tolerance=True`. The dispatcher will store its state to
+  `work_dir`, so that on restart it can continue from its previous state.
+  * Adds support for sharing dataset graphs via shared filesystem instead of
+  over RPC. This reduces load on the dispatcher, improving performance
+    of distributing datasets. For this to work, the dispatcher's `work_dir`
+  must be accessible from workers. If the worker fails to read from the `work_dir`,
+  it falls back to using RPC for dataset graph transfer.
+  * Adds support for a new "distributed_epoch" processing mode.
+  This processing mode distributes a dataset across all tf.data workers,
+    instead of having each worker process the full dataset. See
+  [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
+  to learn more.
+  * Adds optional `exclude_cols` parameter to CsvDataset. This parameter is the
+  complement of `select_cols`; at most one of these should be specified.
+  * We have implemented an optimization which reorders data-discarding
+  transformations such as `take` and `shard` to happen earlier in the dataset
+  when it is safe to do so. The optimization can be disabled via the
+  `experimental_optimization.reorder_data_discarding_ops` dataset option.
+  * `tf.data.Options` were previously immutable and can now be overridden.
+  * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors with
+  a new `output_signature` argument, which allows `from_generator` to produce any
+  type describable by a `tf.TypeSpec`.
+  * `tf.data.experimental.AUTOTUNE` is now available in the core API as `tf.data.AUTOTUNE`.
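+
+  A short sketch of `from_generator` with the new `output_signature` argument
+  (the generator and spec are illustrative):
+
+  ```
+  import tensorflow as tf
+
+  def gen():
+    yield tf.ragged.constant([[1, 2], [3]])
+
+  ds = tf.data.Dataset.from_generator(
+      gen, output_signature=tf.RaggedTensorSpec(shape=[2, None], dtype=tf.int32))
+  for element in ds:
+    print(element)
+  ```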
 
-    ```
-    for batch in data:
-      outputs = train_step(batch)
-    tf.print('final outputs', outputs)
-    ```
+### `tf.distribute`:
+  * Introduces experimental support for asynchronous training of models via
+  `tf.distribute.experimental.ParameterServerStrategy`:
+    * Replaces the existing `tf.distribute.experimental.ParameterServerStrategy`
+  symbol with a new class that is for parameter server training in TF2. Usage of
+  the old symbol, usually with Estimator API, should be **replaced** with
+  `tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
+    * Added `tf.distribute.experimental.coordinator.*` namespace, including the
+  main API `ClusterCoordinator` for coordinating the training cluster, the
+  related data structures `RemoteValue` and `PerWorkerValue`.
+  * [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy)
+  is now a stable API and is no longer considered experimental. Some of the major
+  improvements involve handling peer failure and many bug fixes. Please check out
+  the detailed tutorial on [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
+  * Adds `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather`
+  APIs to support gathering dense distributed values.
+  * Fixes various issues with saving a distributed model.
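+
+  A minimal sketch of the new `Strategy.gather` API mentioned above (the strategy
+  and tensor shapes are illustrative):
+
+  ```
+  import tensorflow as tf
+
+  strategy = tf.distribute.MirroredStrategy()
+  per_replica = strategy.run(lambda: tf.random.uniform([2, 3]))
+  gathered = strategy.gather(per_replica, axis=0)  # shape [2 * num_replicas, 3]
+  ```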
 
-    See tensorflow/python/autograph/g3doc/reference/limitations.md for more
-    info.
+### `tf.keras`:
+  * Improvements from the Functional API refactoring:
+    * Functional model construction does not need to maintain a global workspace
+  graph, removing memory leaks especially when building many models or very large models.
+    * Functional model construction should be ~8-10% faster on average.
+    * Functional models can now contain non-symbolic values in their call inputs
+  inside of the first positional argument.
+    * Several classes of TF ops that were not reliably converted to Keras layers
+  during functional API construction should now work, e.g. `tf.image.ssim_multiscale`.
+    * Error messages when Functional API construction goes wrong (and when ops
+  cannot be converted to Keras layers automatically) should be clearer and easier to understand.
+  * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
+  as an alternative to accepting a `callable` loss.
+  * Adds `beta` hyperparameter to [FTRL](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl)
+  optimizer classes (Keras and others) to match [FTRL paper](https://research.google.com/pubs/archive/41159.pdf).
+  * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for customization
+  of how gradients are aggregated across devices, as well as `gradients_transformers`
+  to allow for custom gradient transformations (such as gradient clipping).
+  * Improvements to Keras preprocessing layers:
+    * TextVectorization can now accept a vocabulary list or file as an init arg.
+    * Normalization can now accept mean and variance values as init args.
+  * In `Attention` and `AdditiveAttention` layers, the `call()` method now accepts a `return_attention_scores` argument. When set to
+    True, the layer returns the attention scores as an additional output argument.
+  * Adds `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints with the
+  same implementation as their `tf.losses` equivalent.
+  * For Keras model, the individual call of `Model.evaluate` uses no cached data
+  for evaluation, while `Model.fit` uses cached data when `validation_data` arg
+  is provided for better performance.
+  * Adds a `save_traces` argument to `model.save`/ `tf.keras.models.save_model`
+  which determines whether the SavedModel format stores the Keras model/layer call
+  functions. The traced functions allow Keras to revive custom models and layers
+  without the original class definition, but if this isn't required the tracing
+  can be disabled with the added option.
+  * The `tf.keras.mixed_precision` API is now non-experimental.
+  The non-experimental API differs from the experimental API in several ways.
+    * `tf.keras.mixed_precision.Policy` no longer takes in a `tf.mixed_precision.
+  experimental.LossScale` in the constructor, and no longer has a `LossScale`
+  associated with it. Instead, `Model.compile` will automatically wrap the optimizer
+  with a `LossScaleOptimizer` using dynamic loss scaling if `Policy.name`
+  is "mixed_float16".
+    * `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in different
+  arguments. In particular, it no longer takes in a `LossScale`, and there is
+  no longer a `LossScale` associated with the `LossScaleOptimizer`. Instead,
+  `LossScaleOptimizer` directly implements fixed or dynamic loss scaling. See the
+  documentation of [`tf.keras.mixed_precision.experimental.LossScaleOptimizer`]
+  (https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/experimental/LossScaleOptimizer?version=nightly)
+  for details on the differences between the experimental `LossScaleOptimizer`
+  and the new non-experimental `LossScaleOptimizer`.
+    * `tf.mixed_precision.experimental.LossScale` and its subclasses are
+  deprecated, as all of their functionality now exists within `tf.keras.mixed_precision.LossScaleOptimizer`.
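+    A minimal sketch of the non-experimental API (two independent illustrations):
+    ```python
+    import tensorflow as tf
+
+    # With the global policy set to "mixed_float16", Model.compile automatically
+    # wraps the optimizer in a LossScaleOptimizer with dynamic loss scaling.
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
+
+    # The optimizer can also be wrapped explicitly; dynamic scaling is the default.
+    opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
+    ```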
 
-*   `tf.lite`:
+### `tf.lite`:
+  * `TFLiteConverter`:
+    * Support optional flags `inference_input_type` and `inference_output_type`
+  for full integer quantized models. This allows users to modify the model input
+  and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting
+  to float type (`tf.float32`).
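+    A sketch of a full-integer quantization flow using these flags
+    (`saved_model_dir` and `representative_dataset` are placeholders):
+    ```python
+    import tensorflow as tf
+
+    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)  # placeholder path
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.representative_dataset = representative_dataset  # placeholder generator
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    converter.inference_input_type = tf.int8   # or tf.uint8
+    converter.inference_output_type = tf.int8  # or tf.uint8
+    tflite_quant_model = converter.convert()
+    ```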
+  * NNAPI
+    * Adds NNAPI Delegation support for requantization use cases by converting
+  the operation into a dequantize-quantize pair.
+    * Removes deprecated `Interpreter.setUseNNAPI(boolean)` Java API. Use
+  `Interpreter.Options.setUseNNAPI` instead.
+    * Deprecates `Interpreter::UseNNAPI(bool)` C++ API. Use `NnApiDelegate()`
+  and related delegate configuration methods directly.
+    * Deprecates `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API.
+  Prefer controlling this via delegate options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16`
+  or `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
+  * GPU
+    * GPU acceleration now supports quantized models by default.
+  * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty.
+  * Adds support for cumulative sum (cumsum), both as a builtin op and as an MLIR conversion.
 
-    *   `TFLiteConverter`:
-        *   Support optional flags `inference_input_type` and
-            `inference_output_type` for full integer quantized models. This
-            allows users to modify the model input and output type to integer
-            types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
-            (`tf.float32`).
-    *   TFLite Profiler for Android is available. See the detailed
-        [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
-    * NNAPI
-        *   Added NNAPI Delegation support for requantization use cases by
-            converting the operation into a dequantize-quantize pair.
-        *   Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
-            *   Use `Interpreter.Options.setUseNNAPI` instead.
-        *   Deprecate `Interpreter::UseNNAPI(bool)` C++ API.
-            *   Use `NnApiDelegate()` and related delegate configuration methods
-                directly.
-        *   Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API
-            *   Prefer controlling this via delegate options, e.g.
-                `tflite::StatefulNnApiDelegate::Options::allow_fp16' or
-                `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
-    *   `DynamicBuffer::AddJoinedString()` will now add a separator if the first
-        string to be joined is empty.
-    *  Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
-    *   <ADD RELEASE NOTES HERE>
+### `TensorRT`
+  * Issues a warning when the `session_config` parameter for the TF1 converter
+  is used or the `rewrite_config_template` field in the TF2 converter parameter
+  object is used.
 
-*   `tf.random`:
+### TPU Enhancements:
+  * Adds support for the `beta` parameter of the FTRL optimizer for TPU
+  embeddings. Users of other TensorFlow platforms can implement equivalent
+  behavior by adjusting the `l2` parameter.
 
-    *   <ADD RELEASE NOTES HERE>
+### XLA Support:
+  * `xla.experimental.compile` is deprecated; use `tf.function(experimental_compile=True)` instead.
+  * Adds `tf.function.experimental_get_compiler_ir`, which returns the compiler IR
+  (currently 'hlo' and 'optimized_hlo') for a given input to a given function.
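+    A sketch covering both additions:
+    ```python
+    import tensorflow as tf
+
+    @tf.function(experimental_compile=True)  # replaces xla.experimental.compile
+    def f(x):
+        return x * x + 1.0
+
+    x = tf.constant([1.0, 2.0])
+    print(f(x))
+    # Inspect the compiler IR generated for these inputs.
+    print(f.experimental_get_compiler_ir(x)(stage="hlo"))
+    ```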
 
-*   Math and Linear Algebra:
+### Security:
+  * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`,
+  ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
+  * Fixes three vulnerabilities in conversion to DLPack format
+    * [CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
+    * [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
+    * [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)
+  * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
+    * [CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
+    * [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)
+  * Fixes several vulnerabilities in `RaggedCountSparseOutput` and `SparseCountSparseOutput` operations
+    * [CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
+    * [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
+    * [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
+    * [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
+    * [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
+    * [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)
+  * Fixes an integer truncation vulnerability in code using the work sharder API,
+  ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
+  * Fixes a format string vulnerability in `tf.strings.as_string`,
+  ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
+  * Fixes segfault raised by calling session-only ops in eager mode,
+  ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
+  * Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`,
+  ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
+  * Fixes segfaults caused by incomplete `SavedModel` validation,
+  ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
+  * Fixes a data corruption due to a bug in negative indexing support in TFLite,
+  ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
+  * Fixes a data corruption due to dimension mismatch in TFLite,
+  ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
+  * Fixes several vulnerabilities in TFLite saved model format
+    * [CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
+    * [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
+    * [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)
+  * Fixes several vulnerabilities in TFLite implementation of segment sum
+    * [CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
+    * [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
+    * [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)
+  * Fixes a segfault in `tf.quantization.quantize_and_dequantize`,
+  ([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
+  * Fixes an undefined behavior float cast causing a crash,
+  ([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
+  * Fixes a lack of validation in `tf.raw_ops.DataFormatVecPermute` and
+  `tf.raw_ops.DataFormatDimMap` which can cause uninitialized memory access,
+  reads outside the bounds of arrays, data corruption, and segmentation faults
+  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
+  * Fixes a crash caused by writing to a read-only memory region
+  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
+  * Fixes a heap out-of-bounds access in the filesystem globbing implementation
+  ([CVE-2020-26269](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26269))
 
-    * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
-
-*   TPU Enhancements:
-
-    *   Added support for the `beta` parameter of the FTRL optimizer for TPU
-        embeddings. Users of other TensorFlow platforms can implement equivalent
-        behavior by adjusting the `l2` parameter.
-    *   <ADD RELEASE NOTES HERE>
-
-*   XLA Support:
-
-    *   xla.experimental.compile is deprecated, use
-        `tf.function(experimental_compile=True)` instead
-    *   Added `tf.function.experimental_get_compiler_ir` which returns compiler
-        IR (currently 'hlo' and 'optimized_hlo') for given input for given
-        function.
-    *   <ADD RELEASE NOTES HERE>
-
-*   Tracing and Debugging:
-
-    *   <ADD RELEASE NOTES HERE>
-
-*   `tf.train.Checkpoint`:
-
-    *   Now accepts a `root` argument in the initialization, which generates a
-        checkpoint with a root object. This allows users to create a
-        `Checkpoint` object that is compatible with Keras `model.save_weights()`
-        and `model.load_weights`. The checkpoint is also compatible with the
-        checkpoint saved in the `variables/` folder in the SavedModel.
-    *   When restoring, `save_path` can be a path to a SavedModel. The function
-        will automatically find the checkpoint in the SavedModel.
-
-*   `tf.nn`:
-
-    *   `tf.nn.max_pool2d` now supports explicit padding.
-
-*   `tf.debugging`:
-
-    *   `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
-
-*   `tf.print`:
-
-    *   Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
-        didn't have the keys sorted, the keys and values were not being printed
-        in accordance with their correct mapping.
-
-*    `TensorRT`
-
-    *   We now issue a warning when the `session_config` parameter for the TF1
-        converter is used or the `rewrite_config_template` field in the TF2
-        converter parameter object is used.
-
-*   Other:
-
-    *   We have replaced uses of "whitelist" and "blacklist" with "allowlist"
-        and "denylist" where possible. Please see
-        https://developers.google.com/style/word-list#blacklist for more
-        context.
-    *   Add `tf.config.experimental.mlir_bridge_rollout` which will help us
-        rollout the new MLIR TPU bridge.
-    *   Added `tf.experimental.register_filesystem_plugin` to load modular
-        filesystem plugins from Python
-    *   <ADD RELEASE NOTES HERE>
+### Other:
+  * We have replaced uses of "whitelist" and "blacklist" with "allowlist" and
+  "denylist" where possible. Please see [this list](https://developers.google.com/style/word-list#blacklist) for more context.
+  * Adds `tf.config.experimental.mlir_bridge_rollout`, which will help us roll out the new MLIR TPU bridge.
+  * Adds `tf.experimental.register_filesystem_plugin` to load modular filesystem plugins from Python.
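+    For example (the plugin path below is hypothetical):
+    ```python
+    import tensorflow as tf
+
+    # Register a compiled filesystem plugin shared object (hypothetical path).
+    tf.experimental.register_filesystem_plugin("/path/to/libexample_filesystem_plugin.so")
+    ```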
 
 ## Thanks to our Contributors
 
-This release contains contributions from many people at Google, as well as:
+This release contains contributions from many people at Google, as well as the following external contributors:
 
-stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
+8bitmp3, aaa.jq, Abhineet Choudhary, Abolfazl Shahbazi, acxz, Adam Hillier, Adrian Garcia Badaracco, Ag Ramesh, ahmedsabie, Alan Anderson, Alexander Grund, Alexandre Lissy, Alexey Ivanov, Amedeo Cavallo, anencore94, Aniket Kumar Singh, Anthony Platanios, Ashwin Phadke, Balint Cristian, Basit Ayantunde, bbbboom, Ben Barsdell, Benjamin Chetioui, Benjamin Peterson, bhack, Bhanu Prakash Bandaru Venkata, Biagio Montaruli, Brent M. Spell, bubblebooy, bzhao, cfRod, Cheng Chen, Cheng(Kit) Chen, Chris Tessum, Christian, chuanqiw, codeadmin_peritiae, COTASPAR, CuiYifeng, danielknobe, danielyou0230, dannyfriar, daria, DarrenZhang01, Denisa Roberts, dependabot[bot], Deven Desai, Dmitry Volodin, Dmitry Zakharov, drebain, Duncan Riach, Eduard Feicho, Ehsan Toosi, Elena Zhelezina, emlaprise2358, Eugene Kuznetsov, Evaderan-Lab, Evgeniy Polyakov, Fausto Morales, Felix Johnny, fo40225, Frederic Bastien, Fredrik Knutsson, fsx950223, Gaurav Singh, Gauri1 Deshpande, George Grzegorz Pawelczak, gerbauz, Gianluca Baratti, Giorgio Arena, Gmc2, Guozhong Zhuang, Hannes Achleitner, Harirai, HarisWang, Harsh188, hedgehog91, Hemal Mamtora, Hideto Ueno, Hugh Ku, Ian Beauregard, Ilya Persky, jacco, Jakub Beránek, Jan Jongboom, Javier Montalt Tordera, Jens Elofsson, Jerry Shih, jerryyin, jgehw, Jinjing Zhou, jma, jmsmdy, Johan Nordström, John Poole, Jonah Kohn, Jonathan Dekhtiar, jpodivin, Jung Daun, Kai Katsumata, Kaixi Hou, Kamil Rakoczy, Kaustubh Maske Patil, Kazuaki Ishizaki, Kedar Sovani, Koan-Sin Tan, Koki Ibukuro, Krzysztof Laskowski, Kushagra Sharma, Kushan Ahmadian, Lakshay Tokas, Leicong Li, levinxo, Lukas Geiger, Maderator, Mahmoud Abuzaina, Mao Yunfei, Marius Brehler, markf, Martin Hwasser, Martin Kubovčík, Matt Conley, Matthias, mazharul, mdfaijul, Michael137, MichelBr, Mikhail Startsev, Milan Straka, Ml-0, Myung-Hyun Kim, Måns Nilsson, Nathan Luehr, ngc92, nikochiko, Niranjan Hasabnis, nyagato_00, Oceania2018, Oleg Guba, Ongun Kanat, OscarVanL, Patrik Laurell, Paul Tanger, Peter Sobot, Phil Pearl, PlusPlusUltra, Poedator, Prasad Nikam, Rahul-Kamat, Rajeshwar Reddy T, redwrasse, Rickard, Robert Szczepanski, Rohan Lekhwani, Sam Holt, Sami Kama, Samuel Holt, Sandeep Giri, sboshin, Sean Settle, settle, Sharada Shiddibhavi, Shawn Presser, ShengYang1, Shi,Guangyong, Shuxiang Gao, Sicong Li, Sidong-Wei, Srihari Humbarwadi, Srinivasan Narayanamoorthy, Steenu Johnson, Steven Clarkson, stjohnso98, Tamas Bela Feher, Tamas Nyiri, Tarandeep Singh, Teng Lu, Thibaut Goetghebuer-Planchon, Tim Bradley, Tomasz Strejczek, Tongzhou Wang, Torsten Rudolf, Trent Lo, Ty Mick, Tzu-Wei Sung, Varghese, Jojimon, Vignesh Kothapalli, Vishakha Agrawal, Vividha, Vladimir Menshakov, Vladimir Silyaev, VoVAllen, Võ Văn Nghĩa, wondertx, xiaohong1031, Xiaoming (Jason) Cui, Xinan Jiang, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yimei Sun, Yiwen Li, Yixing, Yoav Ramon, Yong Tang, Yong Wu, yuanbopeng, Yunmo Koo, Zhangqiang, Zhou Peng, ZhuBaohe, zilinzhu, zmx
 
 
 # Release 2.3.1
diff --git a/configure.py b/configure.py
index 1207986..25c3769 100644
--- a/configure.py
+++ b/configure.py
@@ -55,8 +55,7 @@
 
 # List of files to configure when building Bazel on Apple platforms.
 APPLE_BAZEL_FILES = [
-    'tensorflow/lite/ios/BUILD',
-    'tensorflow/lite/objc/BUILD',
+    'tensorflow/lite/ios/BUILD', 'tensorflow/lite/objc/BUILD',
     'tensorflow/lite/swift/BUILD',
     'tensorflow/lite/tools/benchmark/experimental/ios/BUILD'
 ]
@@ -1174,7 +1173,9 @@
   # First available in VS 16.4. Speeds up Windows compile times by a lot. See
   # https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
   # pylint: disable=line-too-long
-  write_to_bazelrc('build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions')
+  write_to_bazelrc(
+      'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
+  )
 
   if get_var(
       environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index d0ae973..52eb4e3 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -3,7 +3,7 @@
 # learning applications.
 
 load("@bazel_skylib//lib:selects.bzl", "selects")
-load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
+load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
 load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
 load(
     "//tensorflow/core/platform:build_config.bzl",
@@ -401,13 +401,20 @@
     define_values = {"using_cuda_clang": "true"},
 )
 
-# Flag to indicate open source build, .bazelrc always has it set to be true
+# Config setting to use in select()s to distinguish open source build from
+# google internal build on configurable attributes.
 config_setting(
     name = "oss",
-    define_values = {"open_source_build": "true"},
+    flag_values = {":oss_setting": "True"},
     visibility = ["//visibility:public"],
 )
 
+# Fixed setting to indicate open source build.
+bool_setting(
+    name = "oss_setting",
+    build_setting_default = True,
+)
+
 config_setting(
     name = "using_cuda_clang_with_dynamic_build",
     define_values = {
@@ -416,12 +423,12 @@
     },
 )
 
-config_setting(
+selects.config_setting_group(
     name = "build_oss_using_cuda_clang",
-    define_values = {
-        "using_cuda_clang": "true",
-        "open_source_build": "true",
-    },
+    match_all = [
+        ":using_cuda_clang",
+        ":oss",
+    ],
 )
 
 # Setting to use when loading kernels dynamically
@@ -447,12 +454,12 @@
     },
 )
 
-config_setting(
+selects.config_setting_group(
     name = "build_oss_using_cuda_nvcc",
-    define_values = {
-        "using_cuda_nvcc": "true",
-        "open_source_build": "true",
-    },
+    match_all = [
+        ":using_cuda_nvcc",
+        ":oss",
+    ],
 )
 
 config_setting(
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 99a278a..eea81d0 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -116,7 +116,8 @@
 
 # Get sitepackages directories for the python installation.
 _site_packages_dirs = []
-_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
+if _site.ENABLE_USER_SITE and _site.USER_SITE is not None:
+  _site_packages_dirs += [_site.USER_SITE]
 _site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
 if 'getsitepackages' in dir(_site):
   _site_packages_dirs += _site.getsitepackages()
@@ -145,6 +146,8 @@
     _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
     if _os.path.exists(_plugin_dir):
       _ll.load_library(_plugin_dir)
+      # Load Pluggable Device Library
+      _ll.load_pluggable_device_library(_plugin_dir)
 
 # Add module aliases
 if hasattr(_current_module, 'keras'):
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
index ae82f7b..e69287a 100644
--- a/tensorflow/api_template_v1.__init__.py
+++ b/tensorflow/api_template_v1.__init__.py
@@ -155,6 +155,8 @@
     _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
     if _os.path.exists(_plugin_dir):
       _ll.load_library(_plugin_dir)
+      # Load Pluggable Device Library
+      _ll.load_pluggable_device_library(_plugin_dir)
 
 # Delete modules that should be hidden from dir().
 # Don't fail if these modules are not available.
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 2ce9e9a..294f369 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -78,7 +78,7 @@
     ],
     visibility = [
         "//tensorflow/core:__pkg__",
-        "//tensorflow/python:__pkg__",
+        "//tensorflow/python:__subpackages__",
     ],
 )
 
@@ -684,7 +684,10 @@
     name = "c_api_experimental_test",
     size = "medium",
     srcs = ["c_api_experimental_test.cc"],
-    data = ["testdata/tf_record"],
+    data = [
+        "testdata/tf_record",
+        "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
+    ],
     linkopts = select({
         "//tensorflow:macos": ["-headerpad_max_install_names"],
         "//conditions:default": [],
@@ -704,6 +707,7 @@
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:resource_loader",
         "@com_google_absl//absl/types:optional",
     ],
 )
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 0d188aa..e973442 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -37,7 +37,9 @@
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/net.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/strcat.h"
@@ -630,6 +632,9 @@
 
 namespace tensorflow {
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
+
+// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
+Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
 }  // namespace tensorflow
 
 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@@ -743,3 +748,45 @@
     TF_ImportGraphDefOptions* opts, unsigned char enable) {
   opts->opts.validate_colocation_constraints = enable;
 }
+
+// Load a Pluggable Device library.
+// On success, returns the handle to the library in `result` and returns OK from
+// the function. Otherwise, returns nullptr in `result` and an error Status from
+// the function.
+//
+// If `library_filename` has already been loaded, we return a cached handle.
+// Device and Kernels/Ops are registered as globals when a library is loaded
+// for the first time.
+TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
+                                          TF_Status* status) {
+#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
+  status->status = tensorflow::errors::Unimplemented(
+      "PluggableDevice plugin functionality is not supported on mobile");
+  return nullptr;
+#else
+  TF_Library* lib_handle = new TF_Library;
+  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+  static std::unordered_map<std::string, void*>* loaded_libs =
+      new std::unordered_map<std::string, void*>();
+  tensorflow::Env* env = tensorflow::Env::Default();
+  {
+    tensorflow::mutex_lock lock(mu);
+    auto it = loaded_libs->find(library_filename);
+    if (it != loaded_libs->end()) {
+      lib_handle->lib_handle = it->second;
+    } else {
+      status->status =
+          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
+      if (!status->status.ok()) {
+        delete lib_handle;
+        return nullptr;
+      }
+    }
+    return lib_handle;
+  }
+#endif
+}
+
+void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
+  delete lib_handle;
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index e877c77..d413215 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -304,6 +304,27 @@
 TF_ImportGraphDefOptionsSetValidateColocationConstraints(
     TF_ImportGraphDefOptions* opts, unsigned char enable);
 
+// Load the library specified by library_filename and register the pluggable
+// device and related kernels present in that library. This function is not
+// supported on mobile and embedded platforms and will fail if called.
+//
+// Pass "library_filename" to a platform-specific mechanism for dynamically
+// loading a library. The rules for determining the exact location of the
+// library are platform-specific and are not documented here.
+//
+// On success, returns the newly created library handle and places OK in status.
+// The caller owns the library handle.
+//
+// On failure, returns nullptr and places an error status in status.
+TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
+    const char* library_filename, TF_Status* status);
+
+// Frees the memory associated with the library handle.
+// Does NOT unload the library.
+TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
+    TF_Library* lib_handle);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index cfeba34..4c319d0 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
 
@@ -234,5 +235,22 @@
   TF_DeleteTensor(tensor_1X6);
 }
 
+TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
+#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+  // Load the library.
+  TF_Status* status = TF_NewStatus();
+  string lib_path =
+      tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
+          "tensorflow", "c", "experimental", "stream_executor", "test",
+          "test_pluggable_device.so"));
+  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
+  TF_Code code = TF_GetCode(status);
+  string status_msg(TF_Message(status));
+  TF_DeleteStatus(status);
+  ASSERT_EQ(TF_OK, code) << status_msg;
+  TF_DeletePluggableDeviceLibraryHandle(lib);
+#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 3bae78d..d37cb96 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -213,7 +213,11 @@
     TF_DeleteFunction(tf_function);
     return nullptr;
   }
-  tf_function->graph_with_debug_info = &fn_body->graph;
+
+  for (const Node* n : fn_body->graph.nodes()) {
+    tf_function->stack_traces[n->name()] = n->GetStackTrace();
+  }
+
   return tf_function;
 }
 
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index d45aa9a..b5ab775 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -157,9 +157,7 @@
 
 struct TF_Function {
   tensorflow::FunctionDef fdef;
-
-  // Graph with nodes with debug stack traces.
-  const tensorflow::Graph* graph_with_debug_info = nullptr;
+  tensorflow::StackTracesMap stack_traces;
 };
 
 struct TF_ApiDefMap {
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 4ec7371..09d5e65 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -248,6 +248,7 @@
         ":c_api_unified_internal",
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c:tf_tensor",
         "//tensorflow/core:framework",
         "//tensorflow/core/lib/llvm_rtti",
         "//tensorflow/core/platform:errors",
@@ -388,6 +389,7 @@
 
 cc_library(
     name = "gradient_checker",
+    testonly = 1,
     srcs = [
         "gradient_checker.cc",
     ],
@@ -399,27 +401,11 @@
     ],
     deps = [
         ":abstract_tensor_handle",
-        ":c_api_experimental",
-        ":c_api_unified_internal",
-        ":gradients_internal",
-        ":gradients_util",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/gradients:math_grad",
-        "//tensorflow/c/experimental/gradients:nn_grad",
-        "//tensorflow/c/experimental/ops:array_ops",
+        ":unified_api_testutil",
+        "//tensorflow/c:tf_tensor_internal",
         "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/cc/profiler",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/lib/llvm_rtti",
-    ] + if_libtpu(
-        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
-        if_true = [],
-    ),
+        "@com_google_absl//absl/types:span",
+    ],
 )
 
 tf_cuda_cc_test(
@@ -430,33 +416,19 @@
     ],
     args = ["--heap_check=local"],
     linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags() + ["nomac"],
+    tags = tf_cuda_tests_tags() + [
+        "no_cuda_asan",  # b/175330074
+    ],
     deps = [
         ":abstract_tensor_handle",
         ":c_api_experimental",
-        ":c_api_test_util",
-        ":c_api_unified_internal",
         ":gradient_checker",
-        ":gradients_internal",
-        ":gradients_util",
-        ":mnist_gradients_testutil",
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:c_test_util",
+        ":unified_api_testutil",
         "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/gradients:math_grad",
-        "//tensorflow/c/experimental/gradients:nn_grad",
-        "//tensorflow/c/experimental/ops:array_ops",
-        "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/cc/profiler",
-        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/c:tf_tensor_internal",
+        "//tensorflow/c/experimental/ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "//tensorflow/core/lib/llvm_rtti",
-        "//tensorflow/core/platform:tensor_float_32_utils",
-        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -503,6 +475,7 @@
 
 cc_library(
     name = "abstract_tensor_handle",
+    srcs = ["abstract_tensor_handle.cc"],
     hdrs = ["abstract_tensor_handle.h"],
     visibility = [
         "//tensorflow:internal",
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/c/eager/abstract_tensor_handle.cc
similarity index 61%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/c/eager/abstract_tensor_handle.cc
index 2dd4a8d..a30063a 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/c/eager/abstract_tensor_handle.cc
@@ -13,13 +13,21 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+std::string AbstractTensorHandle::DebugString() const {
+  PartialTensorShape shape;
+  Status s = Shape(&shape);
+  std::string shape_string;
+  if (!s.ok()) {
+    shape_string = "<error computing shape>";
+  } else {
+    shape_string = shape.DebugString();
+  }
+  return absl::StrCat("TensorHandle(shape=", shape_string,
+                      ", dtype=", DataType_Name(DataType()), ")");
+}
 
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h
index 1ca4a9a..8d7e211 100644
--- a/tensorflow/c/eager/abstract_tensor_handle.h
+++ b/tensorflow/c/eager/abstract_tensor_handle.h
@@ -27,7 +27,7 @@
 // execution mode.
 class AbstractTensorHandle : public core::RefCounted {
  protected:
-  enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
+  enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice };
   explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
   virtual ~AbstractTensorHandle() {}
 
@@ -38,6 +38,10 @@
   virtual tensorflow::Status Shape(
       tensorflow::PartialTensorShape* shape) const = 0;
 
+  // The default debug string includes a shape and dtype. Implementations are
+  // free to override it with something more informative.
+  virtual std::string DebugString() const;
+
   AbstractTensorHandleKind getKind() const { return kind_; }
 
  private:
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 7738558..6c86ab7 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -43,6 +43,7 @@
 #include "tensorflow/core/common_runtime/eager/execute.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -76,6 +77,15 @@
   return (d == nullptr) ? "cpu:0" : d->name();
 }
 
+// Annotate eager runtime construction context to the given `function_def` as
+// an attribute.
+void AnnotateEagerRuntimeConstructionContext(
+    tensorflow::FunctionDef& function_def) {
+  tensorflow::AttrValue value;
+  SetAttrValue("kEagerRuntime", &value);
+  (*function_def.mutable_attr())["_construction_context"] = value;
+}
+
 }  // namespace
 
 extern "C" {
@@ -744,13 +754,16 @@
         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
     return;
   }
+
+  AnnotateEagerRuntimeConstructionContext(function_def);
   status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
 }
 
 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                             TF_Status* status) {
-  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
-      function->fdef, function->graph_with_debug_info);
+  AnnotateEagerRuntimeConstructionContext(function->fdef);
+  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
+      function->fdef, function->stack_traces);
 }
 
 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
diff --git a/tensorflow/c/eager/c_api_remote_function_test.cc b/tensorflow/c/eager/c_api_remote_function_test.cc
index a9bbd5b..45e8302 100644
--- a/tensorflow/c/eager/c_api_remote_function_test.cc
+++ b/tensorflow/c/eager/c_api_remote_function_test.cc
@@ -20,10 +20,11 @@
 
 void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
                                        bool heavy_load_on_streaming_rpc,
-                                       bool remote_func_outputs = false) {
+                                       bool remote_func_outputs = false,
+                                       bool has_packed_input = false) {
   return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
                                        heavy_load_on_streaming_rpc,
-                                       remote_func_outputs);
+                                       remote_func_outputs, has_packed_input);
 }
 
 TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
@@ -60,5 +61,14 @@
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/true);
 }
+TEST(CAPI, RemoteExecuteSilentCopiesRemoteAsyncPackedInputFuncOrdering) {
+  // A remote input (packed) may not be ready when we start running a function.
+  // Test that the function execution should wait until the remote input is
+  // ready.
+  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
+                                    /*heavy_load_on_streaming_rpc=*/true,
+                                    /*remote_func_outputs=*/true,
+                                    /*has_packed_input=*/true);
+}
 
 }  // namespace
diff --git a/tensorflow/c/eager/c_api_remote_test_util.cc b/tensorflow/c/eager/c_api_remote_test_util.cc
index 159fa44..beb1baf 100644
--- a/tensorflow/c/eager/c_api_remote_test_util.cc
+++ b/tensorflow/c/eager/c_api_remote_test_util.cc
@@ -68,7 +68,9 @@
 
 void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
                                    bool heavy_load_on_streaming_rpc,
-                                   bool remote_func_outputs) {
+                                   bool remote_func_outputs,
+                                   bool has_packed_input) {
+  CHECK(!has_packed_input || func);
   tensorflow::ServerDef server_def = GetServerDef(3);
 
   // This server def has the task index set to 0.
@@ -123,6 +125,15 @@
       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
 
+  TFE_TensorHandle* packed_handle = nullptr;
+  if (has_packed_input) {
+    int num_replicas = 1;
+    std::vector<TFE_TensorHandle*> packed_handles = {h1_task2};
+    packed_handle = TFE_CreatePackedTensorHandle(ctx, packed_handles.data(),
+                                                 &num_replicas, status);
+    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  }
+
   TFE_Op* matmul = nullptr;
   if (func) {
     const string matmul_device = remote_func_outputs ? task2_name : "";
@@ -135,7 +146,7 @@
     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
     TFE_OpAddInput(matmul, h0_task0, status);
     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
-    TFE_OpAddInput(matmul, h1_task2, status);
+    TFE_OpAddInput(matmul, has_packed_input ? packed_handle : h1_task2, status);
     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
   } else {
     // Handles are on task0 (local), and task2, but op is on task1.
@@ -194,6 +205,9 @@
 
   TFE_DeleteTensorHandle(h0_task0);
   TFE_DeleteTensorHandle(h1_task0);
+  if (packed_handle) {
+    TFE_DeleteTensorHandle(packed_handle);
+  }
   TFE_DeleteTensorHandle(h1_task2);
   TFE_DeleteTensorHandle(retvals[0]);
   for (auto* h : handles_task0) {
diff --git a/tensorflow/c/eager/c_api_remote_test_util.h b/tensorflow/c/eager/c_api_remote_test_util.h
index 266ca5a..6d9edb6 100644
--- a/tensorflow/c/eager/c_api_remote_test_util.h
+++ b/tensorflow/c/eager/c_api_remote_test_util.h
@@ -21,6 +21,7 @@
 // is not ready when we start running an op or a function.
 void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
                                    bool heavy_load_on_streaming_rpc,
-                                   bool remote_func_outputs = false);
+                                   bool remote_func_outputs = false,
+                                   bool has_packed_input = false);
 
 #endif  // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc
index 50c67d7..d555c59 100644
--- a/tensorflow/c/eager/gradient_checker.cc
+++ b/tensorflow/c/eager/gradient_checker.cc
@@ -18,18 +18,8 @@
 
 #include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/eager/gradients.h"
-#include "tensorflow/c/eager/gradients_internal.h"
-#include "tensorflow/c/experimental/gradients/math_grad.h"
-#include "tensorflow/c/experimental/gradients/nn_grad.h"
-#include "tensorflow/c/experimental/ops/array_ops.h"
-#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/tf_tensor.h"
-#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
-#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace gradients {
@@ -45,16 +35,6 @@
   }
 }
 
-// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
-AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
-  vector<int> vals(n);
-  int64_t vals_shape[] = {n};
-  Range(&vals, 0, n);
-  AbstractTensorHandlePtr r =
-      GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
-  return r;
-}
-
 // Fills out_dims with the dimensions of the given tensor.
 void GetDims(const TF_Tensor* t, int64_t* out_dims) {
   int num_dims = TF_NumDims(t);
@@ -69,39 +49,41 @@
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       bool use_function) {
-  GradientRegistry registry;
   std::vector<AbstractTensorHandle*> model_outputs(1);
 
   // Run the model.
   TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
-                              absl::MakeSpan(model_outputs), use_function,
-                              registry));
-  AbstractTensorHandle* model_out = model_outputs[0];
+                              absl::MakeSpan(model_outputs), use_function));
+  AbstractTensorHandlePtr model_out(model_outputs[0]);
 
   TF_Tensor* model_out_tensor;
-  TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor));
+  TF_RETURN_IF_ERROR(GetValue(model_out.get(), &model_out_tensor));
   int num_dims_out = TF_NumDims(model_out_tensor);
+  TF_DeleteTensor(model_out_tensor);
 
   // If the output is a scalar, then return the scalar output
   if (num_dims_out == 0) {
-    outputs[0] = model_out;
+    outputs[0] = model_out.release();
     return Status::OK();
   }
 
   // Else, reduce sum the output to get a scalar
 
   // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
-  AbstractTensorHandlePtr sum_dims =
-      GetRangeTensorHandleUtil(ctx, num_dims_out);
+  AbstractTensorHandlePtr sum_dims;
+  {
+    vector<int> vals(num_dims_out);
+    int64_t vals_shape[] = {num_dims_out};
+    Range(&vals, 0, num_dims_out);
+    AbstractTensorHandle* sum_dims_raw = nullptr;
+    TF_RETURN_IF_ERROR(TestTensorHandleWithDimsInt(ctx, vals.data(), vals_shape,
+                                                   1, &sum_dims_raw));
+    sum_dims.reset(sum_dims_raw);
+  }
 
   // Reduce sum the output on all dimensions.
-  std::vector<AbstractTensorHandle*> sum_inputs(2);
-  sum_inputs[0] = model_out;
-  sum_inputs[1] = sum_dims.get();
-
   TF_RETURN_IF_ERROR(
-      ops::Sum(ctx, sum_inputs, absl::MakeSpan(model_outputs), "sum_output"));
-  outputs[0] = model_outputs[0];
+      ops::Sum(ctx, {model_out.get(), sum_dims.get()}, outputs, "sum_output"));
   return Status::OK();
 }
 // ========================= End Helper Functions==============================
@@ -144,61 +126,77 @@
   // Numerical Grad Check
   for (int i = 0; i < num_elems; i++) {
     // Get relative epsilon value
-    float epsilon =
-        std::abs(theta_data[i] * 1e-4 + 1e-4);  // add 1e-4 to prevent div by 0
-    AbstractTensorHandlePtr two_eps =
-        GetScalarTensorHandleUtil(ctx, 2 * epsilon);
+    float epsilon = theta_data[i] == 0 ? 1e-4 : std::abs(theta_data[i] * 1e-4);
+    AbstractTensorHandlePtr two_eps;
+    {
+      AbstractTensorHandle* two_eps_raw = nullptr;
+      TF_RETURN_IF_ERROR(
+          TestScalarTensorHandle(ctx, 2 * epsilon, &two_eps_raw));
+      two_eps.reset(two_eps_raw);
+    }
 
     // Initialize theta[i] + epsilon.
     memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
            TF_TensorByteSize(theta_tensor));
     thetaPlus_data[i] += epsilon;
-    AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
-        ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
+    AbstractTensorHandlePtr thetaPlus;
+    {
+      AbstractTensorHandle* thetaPlus_raw = nullptr;
+      TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+          ctx, thetaPlus_data.data(), theta_dims.data(), num_dims,
+          &thetaPlus_raw));
+      thetaPlus.reset(thetaPlus_raw);
+    }
 
     // Initialize theta[i] - epsilon.
     memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
            TF_TensorByteSize(theta_tensor));
     thetaMinus_data[i] -= epsilon;
-    AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
-        ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
+    AbstractTensorHandlePtr thetaMinus;
+    {
+      AbstractTensorHandle* thetaMinus_raw = nullptr;
+      TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+          ctx, thetaMinus_data.data(), theta_dims.data(), num_dims,
+          &thetaMinus_raw));
+      thetaMinus.reset(thetaMinus_raw);
+    }
 
     // Get f(theta + eps):
     theta_inputs[input_index] = thetaPlus.get();
     TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
                                       absl::MakeSpan(f_outputs), use_function));
-    AbstractTensorHandle* fPlus = f_outputs[0];
+    AbstractTensorHandlePtr fPlus(f_outputs[0]);
 
     // Get f(theta - eps):
     theta_inputs[input_index] = thetaMinus.get();
     TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
                                       absl::MakeSpan(f_outputs), use_function));
-    AbstractTensorHandle* fMinus = f_outputs[0];
+    AbstractTensorHandlePtr fMinus(f_outputs[0]);
 
     // Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
-    TF_RETURN_IF_ERROR(
-        ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
-    AbstractTensorHandle* fDiff = f_outputs[0];
+    TF_RETURN_IF_ERROR(ops::Sub(ctx, {fPlus.get(), fMinus.get()},
+                                absl::MakeSpan(f_outputs), "sub_top"));
+    AbstractTensorHandlePtr fDiff(f_outputs[0]);
 
     // Calculate using the difference quotient definition:
     // (f(theta + eps) - f(theta - eps)) / (2 * eps).
-    TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()},
-                                     absl::MakeSpan(f_outputs),
-                                     "diff_quotient"));
-    AbstractTensorHandle* diff_quotient = f_outputs[0];
+    TF_RETURN_IF_ERROR(ops::Div(ctx, {fDiff.get(), two_eps.get()},
+                                absl::MakeSpan(f_outputs), "diff_quotient"));
+    AbstractTensorHandlePtr diff_quotient(f_outputs[0]);
 
     TF_Tensor* grad_tensor;
-    TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
+    TF_RETURN_IF_ERROR(GetValue(diff_quotient.get(), &grad_tensor));
     float grad_data[1];
     memcpy(&grad_data[0], TF_TensorData(grad_tensor),
            TF_TensorByteSize(grad_tensor));
-
+    TF_DeleteTensor(grad_tensor);
     dtheta_approx[i] = grad_data[0];
   }
 
   // Populate *numerical_grad with the data from dtheta_approx.
-  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
+  TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
       ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
+  TF_DeleteTensor(theta_tensor);
   return Status::OK();
 }
 
diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h
index 705318b..c167148 100644
--- a/tensorflow/c/eager/gradient_checker.h
+++ b/tensorflow/c/eager/gradient_checker.h
@@ -12,23 +12,14 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
+#define TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
+
 #include <memory>
 
 #include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/eager/gradients.h"
-#include "tensorflow/c/eager/gradients_internal.h"
-#include "tensorflow/c/eager/gradients_util.h"
-#include "tensorflow/c/experimental/gradients/math_grad.h"
-#include "tensorflow/c/experimental/gradients/nn_grad.h"
-#include "tensorflow/c/experimental/ops/array_ops.h"
-#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/c/tf_tensor.h"
-#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
 
 namespace tensorflow {
 namespace gradients {
@@ -51,3 +42,5 @@
 
 }  // namespace gradients
 }  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc
index 393ad2c..0680515 100644
--- a/tensorflow/c/eager/gradient_checker_test.cc
+++ b/tensorflow/c/eager/gradient_checker_test.cc
@@ -15,21 +15,11 @@
 
 #include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/eager/gradients.h"
-#include "tensorflow/c/eager/gradients_internal.h"
-#include "tensorflow/c/eager/gradients_util.h"
-#include "tensorflow/c/eager/mnist_gradients_testutil.h"
-#include "tensorflow/c/experimental/gradients/math_grad.h"
-#include "tensorflow/c/experimental/gradients/nn_grad.h"
-#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/c/tf_tensor.h"
-#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -37,6 +27,59 @@
 namespace internal {
 namespace {
 
+using tensorflow::TF_StatusPtr;
+
+void CompareNumericalAndManualGradients(
+    Model model, AbstractContext* ctx,
+    absl::Span<AbstractTensorHandle* const> inputs, int input_index,
+    float* expected_grad, int num_grad, bool use_function,
+    double abs_error = 1e-2) {
+  Status s;
+  AbstractTensorHandlePtr numerical_grad;
+  {
+    AbstractTensorHandle* numerical_grad_raw;
+    s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function,
+                          &numerical_grad_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    numerical_grad.reset(numerical_grad_raw);
+  }
+
+  TF_Tensor* numerical_tensor;
+  s = GetValue(numerical_grad.get(), &numerical_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
+  ASSERT_EQ(num_elem_numerical, num_grad);
+
+  float* dnumerical = new float[num_elem_numerical]{0};
+  memcpy(&dnumerical[0], TF_TensorData(numerical_tensor),
+         TF_TensorByteSize(numerical_tensor));
+
+  for (int j = 0; j < num_grad; j++) {
+    ASSERT_NEAR(dnumerical[j], expected_grad[j], abs_error);
+  }
+  delete[] dnumerical;
+  TF_DeleteTensor(numerical_tensor);
+}
+
+Status MatMulModel(AbstractContext* ctx,
+                   absl::Span<AbstractTensorHandle* const> inputs,
+                   absl::Span<AbstractTensorHandle*> outputs) {
+  return ops::MatMul(ctx, inputs, outputs, "MatMul",
+                     /*transpose_a=*/false,
+                     /*transpose_b=*/false);
+}
+
+Status MulModel(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs) {
+  return ops::Mul(ctx, inputs, outputs, "Mul");
+}
+
+// TODO(vnvo2409): Add more tests from `python/ops/gradient_checker_v2_test.py`.
+// These tests should not be confused with `[*]_grad_test` which compare the
+// result of `gradient_checker` and `[*]_grad`. The tests here test the
+// functionality of `gradient_checker` by comparing the result with expected
+// manual user-provided gradients.
 class GradientCheckerTest
     : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
  protected:
@@ -45,84 +88,56 @@
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
     CHECK_EQ(errors::OK, s.code()) << s.error_message();
+
+    {
+      AbstractContext* ctx_raw = nullptr;
+      Status s =
+          BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+      ctx_.reset(ctx_raw);
+    }
   }
+
+  AbstractContextPtr ctx_;
+
+ public:
+  bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
+  bool UseFunction() const { return std::get<2>(GetParam()); }
 };
 
-Status RegisterGradients(GradientRegistry* registry) {
-  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
-  TF_RETURN_IF_ERROR(
-      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
-                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
-  return Status::OK();
-}
-
-TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
-  // Computing numerical gradients with TensorFloat-32 is numerically unstable
-  enable_tensor_float_32_execution(false);
-
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
+TEST_P(GradientCheckerTest, TestMatMul) {
   float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
   int64_t A_dims[] = {2, 2};
+  AbstractTensorHandlePtr A;
+  {
+    AbstractTensorHandle* A_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    A.reset(A_raw);
+  }
   float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
   int64_t B_dims[] = {2, 2};
-  int num_dims = 2;
-
-  AbstractTensorHandlePtr A =
-      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
-  AbstractTensorHandlePtr B =
-      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
-
-  std::vector<AbstractTensorHandle*> inputs;
-  inputs.push_back(A.get());
-  inputs.push_back(B.get());
-
-  AbstractTensorHandle* grad_approx;
-  Status s = CalcNumericalGrad(
-      ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
-      /*use_function=*/!std::get<2>(GetParam()), &grad_approx);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* gt;
-  s = GetValue(grad_approx, &gt);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-  float result_data[4] = {0};
-  memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
+  AbstractTensorHandlePtr B;
+  {
+    AbstractTensorHandle* B_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    B.reset(B_raw);
+  }
 
   float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
-  float tolerance = 1e-2;
-  for (int j = 0; j < 4; j++) {
-    ASSERT_NEAR(expected_dA[j], result_data[j], tolerance);
-  }
-  TF_DeleteTensor(gt);
+  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
+      MatMulModel, ctx_.get(), {A.get(), B.get()}, 0, expected_dA, 4,
+      UseFunction()));
 }
 
-TEST_P(GradientCheckerTest, TestGradCheckMul) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
+TEST_P(GradientCheckerTest, TestMul) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
+    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
     ASSERT_EQ(errors::OK, s.code()) << s.error_message();
     x.reset(x_raw);
   }
@@ -130,124 +145,15 @@
   AbstractTensorHandlePtr y;
   {
     AbstractTensorHandle* y_raw = nullptr;
-    Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
+    Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw);
     ASSERT_EQ(errors::OK, s.code()) << s.error_message();
     y.reset(y_raw);
   }
 
-  // Will perform z = x*y.
-  // dz/dx = y
-
-  std::vector<AbstractTensorHandle*> inputs;
-  inputs.push_back(x.get());
-  inputs.push_back(y.get());
-  AbstractTensorHandle* g;
-
-  Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
-                               /*input_index=*/0,
-                               /*use_function=*/!std::get<2>(GetParam()), &g);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* gt;
-  s = GetValue(g, &gt);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-  float result_data[1] = {0};
-  memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
-
-  ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
-  TF_DeleteTensor(gt);
-}
-
-TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
-  bool use_function = !std::get<2>(GetParam());
-  if (use_function) {
-    // TODO(b/168850692): Enable this.
-    GTEST_SKIP() << "Can't take gradient of "
-                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
-  }
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-
-  /** Test to show how to use this API with analytical gradients:
-   *
-   *  We have `SoftmaxLossGradModel`, which is a wrapper for the
-   *  Softmax analytical gradient found in c/experimental/nn_grads.
-   *
-   *  We will use the GradientChecker by applying finite differences
-   *  to the forward pass wrapped in `SoftmaxModel` and verify that
-   *  both the analytical and numerical gradients are relatively
-   *  close.
-   *
-   */
-
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
-  // X = scores
-  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f};
-  int64_t X_dims[] = {3, 3};
-  int num_dims = 2;
-  AbstractTensorHandlePtr X =
-      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
-
-  // y = labels
-  int y_vals[] = {1, 0, 1};
-  int64_t y_dims[] = {3};
-  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
-  AbstractTensorHandlePtr y =
-      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
-
-  GradientRegistry registry;
-  Status s = RegisterGradients(&registry);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  std::vector<AbstractTensorHandle*> inputs;
-  inputs.push_back(X.get());
-  inputs.push_back(y.get());
-
-  // Run analytical gradient and get its data.
-  std::vector<AbstractTensorHandle*> outputs(2);
-  s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs),
-               absl::MakeSpan(outputs),
-               /*use_function=*/!std::get<2>(GetParam()), registry);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* dX_tensor;
-  s = GetValue(outputs[0], &dX_tensor);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  float danalytical[9] = {0};  // Contains data from analytical gradient.
-  memcpy(&danalytical[0], TF_TensorData(dX_tensor),
-         TF_TensorByteSize(dX_tensor));
-
-  // Run numerical gradient approximation using the GradientChecker API.
-  AbstractTensorHandle* g;  // Will contain numerical approximation data.
-  s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
-                        /*input_index=*/0,
-                        /*use_function=*/!std::get<2>(GetParam()), &g);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* gt;
-  s = GetValue(g, &gt);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-  float dnumerical[9] = {0};
-  memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));
-
-  // Now compare the two implementations:
-  for (int j = 0; j < 9; j++) {
-    ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
-  }
-
-  // Only Unref() first output as 2nd is nullptr grad for labels
-  outputs[0]->Unref();
-  TF_DeleteTensor(dX_tensor);
-  TF_DeleteTensor(gt);
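+  // z = x * y, so dz/dx = y = 7.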
+  float expected_dx[1] = {7.0f};
+  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
+      MulModel, ctx_.get(), {x.get(), y.get()}, 0, expected_dx, 1,
+      UseFunction()));
 }
 
 #ifdef PLATFORM_GOOGLE
@@ -255,13 +161,13 @@
     UnifiedCAPI, GradientCheckerTest,
     ::testing::Combine(::testing::Values("graphdef"),
                        /*tfrt*/ ::testing::Values(false),
-                       /*executing_eagerly*/ ::testing::Values(true, false)));
+                       /*use_function*/ ::testing::Values(true, false)));
 #else
 INSTANTIATE_TEST_SUITE_P(
     UnifiedCAPI, GradientCheckerTest,
     ::testing::Combine(::testing::Values("graphdef"),
                        /*tfrt*/ ::testing::Values(false),
-                       /*executing_eagerly*/ ::testing::Values(true, false)));
+                       /*use_function*/ ::testing::Values(true, false)));
 #endif
 }  // namespace
 }  // namespace internal
diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h
index 88696fc..e557753 100644
--- a/tensorflow/c/eager/immediate_execution_context.h
+++ b/tensorflow/c/eager/immediate_execution_context.h
@@ -111,11 +111,11 @@
   // already exists.
   virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
 
-  // Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
-  // which has nodes containing stack traces for the nodes in `fdef`. Assumes
-  // `graph` is alive while the function is alive.
-  virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
-                                             const Graph* graph) = 0;
+  // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
+  // the key of the function definition name (to be retrieved during function
+  // instantiation).
+  virtual Status AddFunctionDefWithStackTraces(
+      const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
 
   // Find and return an added function by its name.
   virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc
index 16cb011..65c7ed8 100644
--- a/tensorflow/c/eager/mnist_gradients_test.cc
+++ b/tensorflow/c/eager/mnist_gradients_test.cc
@@ -395,80 +395,6 @@
   TF_DeleteTensor(dX_tensor);
 }
 
-TEST_P(CppGradients, TestSoftmaxLossGrad) {
-  bool use_function = !std::get<2>(GetParam());
-  if (use_function) {
-    // TODO(b/168850692): Enable this.
-    GTEST_SKIP() << "Can't take gradient of "
-                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
-  }
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
-  // X = scores
-  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
-  int64_t X_dims[] = {3, 3};
-  int num_dims = 2;
-  AbstractTensorHandlePtr X =
-      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
-
-  // y = labels
-  int y_vals[] = {1, 0, 1};
-  int64_t y_dims[] = {3};
-  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
-  AbstractTensorHandlePtr y =
-      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
-
-  GradientRegistry registry;
-  Status s = RegisterGradients(&registry);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  /* Pseudo-code:
-   *
-   * tape.watch(X)
-   * tape.watch(labels)
-   * loss = SoftmaxLoss(X, labels)
-   * outputs = tape.gradient(loss, [X, labels])
-   *
-   *
-   */
-
-  std::vector<AbstractTensorHandle*> outputs(2);
-  s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()},
-               absl::MakeSpan(outputs),
-               /*use_function=*/!std::get<2>(GetParam()), registry);
-
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* dX_tensor;
-  s = GetValue(outputs[0], &dX_tensor);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  float result_data[9] = {0};
-  memcpy(&result_data[0], TF_TensorData(dX_tensor),
-         TF_TensorByteSize(dX_tensor));
-
-  float expected_dX[9] = {0.090f,  -0.7553f, 0.6652f,  -0.9099f, 0.2447f,
-                          0.6652f, 0.8437f,  -0.8858f, 0.0420f};
-  float tolerance = 1e-3;
-  for (int j = 0; j < 9; j++) {
-    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
-  }
-
-  // Only Unref() first output as 2nd is nullptr grad for labels
-  outputs[0]->Unref();
-  TF_DeleteTensor(dX_tensor);
-}
-
 TEST_P(CppGradients, TestMNISTGrad) {
   bool use_function = !std::get<2>(GetParam());
   if (use_function) {
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc
index eb9b917..b555b6c 100644
--- a/tensorflow/c/eager/mnist_gradients_testutil.cc
+++ b/tensorflow/c/eager/mnist_gradients_testutil.cc
@@ -31,7 +31,6 @@
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 
-
 namespace tensorflow {
 namespace gradients {
 namespace internal {
@@ -184,27 +183,6 @@
   return Status::OK();
 }
 
-Status SoftmaxLossGradModel(AbstractContext* ctx,
-                            absl::Span<AbstractTensorHandle* const> inputs,
-                            absl::Span<AbstractTensorHandle*> outputs,
-                            const GradientRegistry& registry) {
-  auto tape = new Tape(/*persistent=*/false);
-  tape->Watch(inputs[0]);  // Watch scores.
-  tape->Watch(inputs[1]);  // Watch labels.
-  vector<AbstractTensorHandle*> sm_outputs(2);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
-  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
-      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
-
-  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
-                                           /*targets=*/sm_outputs,
-                                           /*sources=*/inputs,
-                                           /*output_gradients=*/{}, outputs));
-
-  delete tape;
-  return Status::OK();
-}
-
 Status MNISTGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
@@ -283,14 +261,6 @@
                   "mul0");  // Compute x*y
 }
 
-Status SoftmaxModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs,
-                    const GradientRegistry& registry) {
-  return ops::SparseSoftmaxCrossEntropyWithLogits(ctx, inputs, outputs,
-                                                  "sm_loss");
-}
-
 // ============================= End Models ================================
 
 }  // namespace internal
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h
index b173446..508e5c8 100644
--- a/tensorflow/c/eager/mnist_gradients_testutil.h
+++ b/tensorflow/c/eager/mnist_gradients_testutil.h
@@ -29,7 +29,6 @@
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/status.h"
 
-
 namespace tensorflow {
 namespace gradients {
 namespace internal {
@@ -68,12 +67,6 @@
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);
 
-// Test Model to verify SoftmaxGrad functionality
-Status SoftmaxLossGradModel(AbstractContext* ctx,
-                            absl::Span<AbstractTensorHandle* const> inputs,
-                            absl::Span<AbstractTensorHandle*> outputs,
-                            const GradientRegistry& registry);
-
 // Test Model to verify Multi-grad functionality for MNIST
 Status MNISTGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
@@ -96,11 +89,6 @@
                 absl::Span<AbstractTensorHandle*> outputs,
                 const GradientRegistry& registry);
 
-Status SoftmaxModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs,
-                    const GradientRegistry& registry);
-
 }  // namespace internal
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/unified_api_test.cc b/tensorflow/c/eager/unified_api_test.cc
index 52de726..a3ae37c 100644
--- a/tensorflow/c/eager/unified_api_test.cc
+++ b/tensorflow/c/eager/unified_api_test.cc
@@ -119,7 +119,7 @@
   {
     AbstractTensorHandle* x_raw = nullptr;
     float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
-    int64 dim_sizes[] = {2, 4};
+    int64_t dim_sizes[] = {2, 4};
     Status s =
         TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
     ASSERT_EQ(errors::OK, s.code()) << s.error_message();
diff --git a/tensorflow/c/eager/unified_api_testutil.cc b/tensorflow/c/eager/unified_api_testutil.cc
index 9e8683d..9907fc2 100644
--- a/tensorflow/c/eager/unified_api_testutil.cc
+++ b/tensorflow/c/eager/unified_api_testutil.cc
@@ -144,18 +144,43 @@
 }
 
 Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
-                                     int64* dims, int num_dims,
+                                     int64_t* dims, int num_dims,
                                      AbstractTensorHandle** tensor) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_Context* eager_ctx =
       TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
   TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
-  TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
-      eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
+  TFE_TensorHandle* input_eager =
+      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
   *tensor =
       unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
   return Status::OK();
 }
 
+Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
+                                   int64_t* dims, int num_dims,
+                                   AbstractTensorHandle** tensor) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_Context* eager_ctx =
+      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  TFE_TensorHandle* input_eager =
+      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
+  *tensor =
+      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
+  return Status::OK();
+}
+
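+// Resolves the eager tensor behind `t` into a newly allocated TF_Tensor; the
+// caller owns the result and releases it with TF_DeleteTensor.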
+Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_TensorHandle* result_t =
+      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
+  return StatusFromTF_Status(status.get());
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/unified_api_testutil.h b/tensorflow/c/eager/unified_api_testutil.h
index eb8d0ff..39bf553 100644
--- a/tensorflow/c/eager/unified_api_testutil.h
+++ b/tensorflow/c/eager/unified_api_testutil.h
@@ -17,6 +17,7 @@
 
 #include "tensorflow/c/eager/abstract_context.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
@@ -54,8 +55,16 @@
 
 // Get a Matrix TensorHandle with given float values and dimensions.
 Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
-                                     int64* dims, int num_dims,
+                                     int64_t* dims, int num_dims,
                                      AbstractTensorHandle** tensor);
+
+// Get a TensorHandle with given int values and dimensions.
+Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
+                                   int64_t* dims, int num_dims,
+                                   AbstractTensorHandle** tensor);
+
+// Places data from `t` into *result_tensor.
+Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
index 8cd8ad7..b6c0a40 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
@@ -81,7 +81,7 @@
     return;
   }
 
-  size_t bucket_end = fname.find("/", scheme_end + 1);
+  size_t bucket_end = fname.find('/', scheme_end + 1);
   if (bucket_end == std::string::npos) {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  "GCS path doesn't contain a bucket name.");
diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc
index 50a9f54..67eaa23 100644
--- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc
@@ -38,7 +38,7 @@
   size_t scheme_end = fname.find("://") + 2;
   // We don't want `://` in scheme.
   *scheme = fname.substr(0, scheme_end - 2);
-  size_t nn_end = fname.find("/", scheme_end + 1);
+  size_t nn_end = fname.find('/', scheme_end + 1);
   if (nn_end == std::string::npos) {
     *namenode = fname.substr(scheme_end + 1);
     *path = "";
diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD
index b060f4d..11fceda 100644
--- a/tensorflow/c/experimental/gradients/BUILD
+++ b/tensorflow/c/experimental/gradients/BUILD
@@ -4,6 +4,7 @@
 # buildifier: disable=same-origin-load
 load(
     "//tensorflow:tensorflow.bzl",
+    "if_libtpu",
     "tf_cuda_cc_test",
 )
 load(
@@ -165,7 +166,7 @@
     ],
     deps = [
         "//tensorflow/c/eager:gradient_checker",
-        "//tensorflow/c/eager:gradients_util",
+        "//tensorflow/c/eager:unified_api_testutil",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
@@ -183,9 +184,14 @@
     deps = [
         ":grad_test_helper",
         ":nn_grad",
+        "//tensorflow/c:tf_status_helper",
         "//tensorflow/c/eager:c_api_test_util",
         "//tensorflow/c/experimental/gradients/tape:tape_context",
+        "//tensorflow/c/experimental/ops:nn_ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-    ],
+    ] + if_libtpu(
+        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
+        if_true = [],
+    ),
 )
diff --git a/tensorflow/c/experimental/gradients/custom_gradient_test.cc b/tensorflow/c/experimental/gradients/custom_gradient_test.cc
index 9ca0187..16fb339 100644
--- a/tensorflow/c/experimental/gradients/custom_gradient_test.cc
+++ b/tensorflow/c/experimental/gradients/custom_gradient_test.cc
@@ -86,16 +86,6 @@
   return Status::OK();
 }
 
-Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  TFE_TensorHandle* result_t =
-      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
-  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
-  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
-  return Status::OK();
-}
-
 TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -128,7 +118,7 @@
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
   TF_Tensor* result_tensor;
-  s = getValue(outputs[0], &result_tensor);
+  s = GetValue(outputs[0], &result_tensor);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
   auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
   EXPECT_EQ(*result_value, 1.0);
diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc
index 4031f8c..e7e9471 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.cc
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc
@@ -24,24 +24,28 @@
 void CompareNumericalAndAutodiffGradients(
     Model model, Model grad_model, AbstractContext* ctx,
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
-    const GradientRegistry& registry, double abs_error) {
+    double abs_error) {
   auto num_inputs = inputs.size();
   std::vector<AbstractTensorHandle*> outputs(num_inputs);
   auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
-                    /*use_function=*/use_function, registry);
+                    /*use_function=*/use_function);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
   for (int i = 0; i < num_inputs; ++i) {
     if (!outputs[i]) continue;
 
-    AbstractTensorHandle* g;  // Will contain numerical approximation data.
-    s = CalcNumericalGrad(ctx, model, inputs,
-                          /*input_index=*/i,
-                          /*use_function=*/use_function, &g);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    AbstractTensorHandlePtr numerical_grad;
+    {
+      AbstractTensorHandle* numerical_grad_raw;
+      s = CalcNumericalGrad(ctx, model, inputs,
+                            /*input_index=*/i, use_function,
+                            &numerical_grad_raw);
+      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+      numerical_grad.reset(numerical_grad_raw);
+    }
 
     TF_Tensor* numerical_tensor;
-    s = GetValue(g, &numerical_tensor);
+    s = GetValue(numerical_grad.get(), &numerical_tensor);
     ASSERT_EQ(errors::OK, s.code()) << s.error_message();
     auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
 
diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h
index 78b2d5b..f4902a0 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.h
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
 #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
 
-#include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
 
 namespace tensorflow {
 namespace gradients {
@@ -24,7 +24,7 @@
 void CompareNumericalAndAutodiffGradients(
     Model model, Model grad_model, AbstractContext* ctx,
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
-    const GradientRegistry& registry, double abs_error = 1e-2);
+    double abs_error = 1e-2);
 
 }  // namespace internal
 }  // namespace gradients
diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc
index 578229b..e4b002d 100644
--- a/tensorflow/c/experimental/gradients/nn_grad_test.cc
+++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc
@@ -15,8 +15,11 @@
 #include "tensorflow/c/experimental/gradients/nn_grad.h"
 
 #include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
 #include "tensorflow/c/experimental/gradients/grad_test_helper.h"
 #include "tensorflow/c/experimental/gradients/tape/tape_context.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -26,17 +29,60 @@
 
 using tensorflow::TF_StatusPtr;
 
+Status SparseSoftmaxCrossEntropyWithLogitsModel(
+    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<AbstractTensorHandle*> outputs) {
+  std::vector<AbstractTensorHandle*> temp_outputs(2);
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      ctx, inputs, absl::MakeSpan(temp_outputs),
+      "SparseSoftmaxCrossEntropyWithLogits"));
+  // `gradient_checker` only works with models that return a single tensor.
+  // Although `ops::SparseSoftmaxCrossEntropyWithLogits` returns 2 tensors, the
+  // second tensor isn't needed for computing the gradient, so we can safely
+  // drop it.
+  outputs[0] = temp_outputs[0];
+  temp_outputs[1]->Unref();
+  return Status::OK();
+}
+
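+// Computes the analytical gradient of `SparseSoftmaxCrossEntropyWithLogits`
+// w.r.t. its inputs: the forward op is recorded on a `Tape`, and the gradient
+// is obtained via `ComputeGradient` using the gradient function registered
+// for that op.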
+Status SparseSoftmaxCrossEntropyWithLogitsGradModel(
+    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<AbstractTensorHandle*> outputs) {
+  GradientRegistry registry;
+  TF_RETURN_IF_ERROR(
+      registry.Register("SparseSoftmaxCrossEntropyWithLogits",
+                        SparseSoftmaxCrossEntropyWithLogitsRegisterer));
+
+  Tape tape(/*persistent=*/false);
+  tape.Watch(inputs[0]);  // Watch score.
+  tape.Watch(inputs[1]);  // Watch label.
+  std::vector<AbstractTensorHandle*> temp_outputs(2);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs),
+      "SparseSoftmaxCrossEntropyWithLogitsGrad"));
+
+  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
+                                          /*sources=*/inputs,
+                                          /*output_gradients=*/{}, outputs));
+  for (auto temp_output : temp_outputs) {
+    temp_output->Unref();
+  }
+  return Status::OK();
+}
+
 Status BiasAddModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs,
-                    const GradientRegistry& registry) {
+                    absl::Span<AbstractTensorHandle*> outputs) {
   return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd");
 }
 
 Status BiasAddGradModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
-                        absl::Span<AbstractTensorHandle*> outputs,
-                        const GradientRegistry& registry) {
+                        absl::Span<AbstractTensorHandle*> outputs) {
+  GradientRegistry registry;
+  TF_RETURN_IF_ERROR(registry.Register("BiasAdd", BiasAddRegisterer));
+
   Tape tape(/*persistent=*/false);
   tape.Watch(inputs[0]);  // Watch A.
   tape.Watch(inputs[1]);  // Watch Bias.
@@ -54,11 +100,6 @@
   return Status::OK();
 }
 
-Status RegisterGradients(GradientRegistry* registry) {
-  TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer));
-  return Status::OK();
-}
-
 class CppGradients
     : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
  protected:
@@ -66,7 +107,7 @@
     TF_StatusPtr status(TF_NewStatus());
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
-    CHECK_EQ(errors::OK, s.code()) << s.error_message();
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
     {
       AbstractContext* ctx_raw = nullptr;
@@ -75,12 +116,8 @@
       ASSERT_EQ(errors::OK, s.code()) << s.error_message();
       ctx_.reset(ctx_raw);
     }
-
-    s = RegisterGradients(&registry_);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
   }
 
-  GradientRegistry registry_;
   AbstractContextPtr ctx_;
 
  public:
@@ -88,6 +125,43 @@
   bool UseFunction() const { return std::get<2>(GetParam()); }
 };
 
+TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
+  if (UseFunction()) {
+    // TODO(b/168850692): Enable this.
+    GTEST_SKIP() << "Can't take gradient of "
+                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
+  }
+
+  // Score
+  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
+  int64_t X_dims[] = {3, 3};
+  AbstractTensorHandlePtr X;
+  {
+    AbstractTensorHandle* X_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    X.reset(X_raw);
+  }
+  // Label
+  int Y_vals[] = {1, 0, 1};
+  int64_t Y_dims[] = {3};
+  AbstractTensorHandlePtr Y;
+  {
+    AbstractTensorHandle* Y_raw;
+    Status s =
+        TestTensorHandleWithDimsInt(ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    Y.reset(Y_raw);
+  }
+
+  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
+      SparseSoftmaxCrossEntropyWithLogitsModel,
+      SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
+      {X.get(), Y.get()},
+      /*use_function=*/UseFunction()));
+}
+
 TEST_P(CppGradients, TestBiasAddGrad) {
   if (UseFunction() && UseMlir()) {
     GTEST_SKIP() << "SetAttrString has not been implemented yet.\n";
@@ -96,19 +170,29 @@
   // A
   float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
   int64_t A_dims[] = {2, 2};
-  AbstractTensorHandlePtr A =
-      GetTensorHandleUtilFloat(ctx_.get(), A_vals, A_dims, 2);
+  AbstractTensorHandlePtr A;
+  {
+    AbstractTensorHandle* A_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    A.reset(A_raw);
+  }
   // Bias
   float Bias_vals[] = {2.0f, 3.0f};
   int64_t Bias_dims[] = {2};
-  AbstractTensorHandlePtr Bias =
-      GetTensorHandleUtilFloat(ctx_.get(), Bias_vals, Bias_dims, 1);
-
-  std::vector<AbstractTensorHandle*> inputs{A.get(), Bias.get()};
+  AbstractTensorHandlePtr Bias;
+  {
+    AbstractTensorHandle* Bias_raw;
+    Status s = TestTensorHandleWithDimsFloat(ctx_.get(), Bias_vals, Bias_dims,
+                                             1, &Bias_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    Bias.reset(Bias_raw);
+  }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
-      /*use_function=*/UseFunction(), registry_));
+      /*use_function=*/UseFunction()));
 }
 
 #ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD
new file mode 100644
index 0000000..ca8bdaf
--- /dev/null
+++ b/tensorflow/c/experimental/stream_executor/test/BUILD
@@ -0,0 +1,17 @@
+# Description:
+#   Tests for stream_executor.
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_shared_object",
+)
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+tf_cc_shared_object(
+    name = "test_pluggable_device.so",
+    srcs = ["test_pluggable_device.cc"],
+    visibility = ["//tensorflow/c:__subpackages__"],
+    deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
+)
diff --git a/tensorflow/lite/java/src/main/native/op_resolver.h b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc
similarity index 67%
copy from tensorflow/lite/java/src/main/native/op_resolver.h
copy to tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc
index 08ff0ce..d985f3c 100644
--- a/tensorflow/lite/java/src/main/native/op_resolver.h
+++ b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc
@@ -12,17 +12,12 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
-#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
 
-#include <memory>
+#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
 
-#include "tensorflow/lite/op_resolver.h"
-
-namespace tflite {
-
-std::unique_ptr<OpResolver> CreateOpResolver();
-
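+// Minimal SE_InitPlugin for the test plugin: it only fills in the platform
+// struct size, name, and type.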
+void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
+                   TF_Status* const status) {
+  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
+  params->platform->name = "GPU";
+  params->platform->type = "XGPU";
 }
-
-#endif  // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index 27f98be..d33a91b 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -32,6 +32,7 @@
 #include "tensorflow/stream_executor/stream.h"
 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
 
+using tensorflow::errors::InvalidArgument;
 // This file forms the basis of a stable ABI for third-party kernel
 // implementations. It is crucial that changes to this file are made cautiously
 // and with a focus on maintaining both source and binary compatibility.
@@ -87,9 +88,25 @@
   TF_SetStatus(status, TF_OK, "");
 }
 #undef CASE
+
 }  // namespace
 }  // namespace tensorflow
 
+namespace {
+const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
+                                          const char* attr_name,
+                                          TF_Status* status) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  const tensorflow::AttrValue* attr =
+      ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
+  if (attr == nullptr) {
+    status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
+                                     "' has no attr named '", attr_name, "'.");
+  }
+  return attr;
+}
+}  // namespace
+
 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                      const char* attr_name,
                                      const TF_DataType type,
@@ -257,7 +274,81 @@
   cc_ctx->CtxFailure(s);
 }
 
-#define DEFINE_TF_GETATTR(func, c_type, cc_type)                               \
+void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
+                                         const char* attr_name,
+                                         int32_t* list_size,
+                                         int32_t* total_size,
+                                         TF_Status* status) {
+  const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
+  if (!status->status.ok()) {
+    *list_size = -1;
+    *total_size = -1;
+    return;
+  }
+  switch (attr->value_case()) {
+#define SINGLE_CASE(kK, attr_type, size_expr) \
+  case tensorflow::AttrValue::kK:             \
+    *list_size = -1;                          \
+    *total_size = size_expr;                  \
+    break;
+
+    SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
+    SINGLE_CASE(kI, TF_ATTR_INT, -1);
+    SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
+    SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
+    SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
+    SINGLE_CASE(kShape, TF_ATTR_SHAPE,
+                attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
+    SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
+#undef SINGLE_CASE
+
+    case tensorflow::AttrValue::kList:
+      *list_size = 0;
+      *total_size = -1;
+#define LIST_CASE(field, attr_type, ...)      \
+  if (attr->list().field##_size() > 0) {      \
+    *list_size = attr->list().field##_size(); \
+    __VA_ARGS__;                              \
+    break;                                    \
+  }
+
+      LIST_CASE(
+          s, TF_ATTR_STRING, *total_size = 0;
+          for (int i = 0; i < attr->list().s_size();
+               ++i) { *total_size += attr->list().s(i).size(); });
+      LIST_CASE(i, TF_ATTR_INT);
+      LIST_CASE(f, TF_ATTR_FLOAT);
+      LIST_CASE(b, TF_ATTR_BOOL);
+      LIST_CASE(type, TF_ATTR_TYPE);
+      LIST_CASE(
+          shape, TF_ATTR_SHAPE, *total_size = 0;
+          for (int i = 0; i < attr->list().shape_size(); ++i) {
+            const auto& s = attr->list().shape(i);
+            *total_size += s.unknown_rank() ? 0 : s.dim_size();
+          });
+      LIST_CASE(tensor, TF_ATTR_TENSOR);
+      LIST_CASE(tensor, TF_ATTR_FUNC);
+#undef LIST_CASE
+      break;
+
+    case tensorflow::AttrValue::kPlaceholder:
+      *list_size = -1;
+      *total_size = -1;
+      break;
+
+    case tensorflow::AttrValue::kFunc:
+      *list_size = -1;
+      *total_size = -1;
+      break;
+
+    case tensorflow::AttrValue::VALUE_NOT_SET:
+      status->status =
+          InvalidArgument("Attribute '", attr_name, "' has no value set");
+      break;
+  }
+}
+
+#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field)        \
   void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
                                              const char* attr_name,            \
                                              c_type* val, TF_Status* status) { \
@@ -269,10 +360,84 @@
     if (s.ok()) {                                                              \
       *val = static_cast<c_type>(v);                                           \
     }                                                                          \
+  }                                                                            \
+  void TF_OpKernelConstruction_GetAttr##func##List(                            \
+      TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals,       \
+      int max_vals, TF_Status* status) {                                       \
+    TF_SetStatus(status, TF_OK, "");                                           \
+    const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);  \
+    if (!status->status.ok()) return;                                          \
+    if (attr->value_case() != tensorflow::AttrValue::kList) {                  \
+      status->status =                                                         \
+          InvalidArgument("Value for '", attr_name, "' is not a list.");       \
+      return;                                                                  \
+    }                                                                          \
+    status->status =                                                           \
+        tensorflow::AttrValueHasType(*attr, "list(" attr_type ")");            \
+    if (!status->status.ok()) return;                                          \
+    const auto len = std::min(max_vals, attr->list().list_field##_size());     \
+    for (int i = 0; i < len; ++i) {                                            \
+      vals[i] = static_cast<c_type>(attr->list().list_field(i));               \
+    }                                                                          \
   }
 
-DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
-DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
+DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
+DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
+DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
+DEFINE_TF_GETATTR(Float, float, float, "float", f)
+DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
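+// Each DEFINE_TF_GETATTR invocation above defines both the scalar getter and
+// its list counterpart, e.g. TF_OpKernelConstruction_GetAttrFloat and
+// TF_OpKernelConstruction_GetAttrFloatList.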
+
+void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
+                                           const char* attr_name, char* value,
+                                           size_t max_length,
+                                           TF_Status* status) {
+  std::string v;
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+
+  if (!status->status.ok()) return;
+
+  if (max_length <= 0) {
+    return;
+  }
+  std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
+}
+
+void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
+                                               const char* attr_name,
+                                               char** values, size_t* lengths,
+                                               int max_values, void* storage,
+                                               size_t storage_size,
+                                               TF_Status* status) {
+  std::vector<std::string> v;
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+
+  if (!status->status.ok()) return;
+
+  const auto len = std::min(max_values, static_cast<int>(v.size()));
+  char* p = static_cast<char*>(storage);
+  for (int i = 0; i < len; ++i) {
+    const std::string& s = v[i];
+    values[i] = p;
+    lengths[i] = s.size();
+    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
+      status->status = InvalidArgument(
+          "Not enough storage to hold the requested list of strings");
+      return;
+    }
+    memcpy(values[i], s.data(), s.size());
+    p += s.size();
+  }
+}
+
+bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
+                                     const char* attr_name, TF_Status* status) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  return cc_ctx->HasAttr(attr_name);
+}
 
 TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
   auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
index 34848a1..508d59b 100644
--- a/tensorflow/c/kernels.h
+++ b/tensorflow/c/kernels.h
@@ -184,6 +184,24 @@
 // Returns the step ID of the given context.
 TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
 
+// Get the list_size and total_size of the attribute `attr_name` of `ctx`.
+// list_size - the length of the list (-1 if the attribute is not a list).
+// total_size - total size of the list.
+//   (1) If attr_type == TF_ATTR_STRING
+//       then total_size is the cumulative byte size
+//       of all the strings in the list (or of the single string).
+//   (2) If attr_type == TF_ATTR_SHAPE and the attribute is a single shape
+//       then total_size is the number of dimensions
+//       of the shape valued attribute, or -1
+//       if its rank is unknown.
+//   (3) If attr_type == TF_ATTR_SHAPE and the attribute is a list of shapes
+//       then total_size is the cumulative number
+//       of dimensions of all shapes in the list.
+//   (4) Otherwise, total_size is undefined.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
+    int32_t* total_size, TF_Status* status);
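+//
+// Illustrative usage only (the attribute name "Attr" is hypothetical; `ctx`
+// and `status` come from the surrounding kernel-construction callback):
+//   int32_t list_size = 0, total_size = 0;
+//   TF_OpKernelConstruction_GetAttrSize(ctx, "Attr", &list_size, &total_size,
+//                                       status);
+//   // For a scalar string attribute, list_size is -1 and total_size is the
+//   // string's byte length.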
+
 // Interprets the named kernel construction attribute as a TF_DataType and
 // places it into *val. *status is set to TF_OK.
 //
@@ -202,6 +220,112 @@
     TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
     TF_Status* status);
 
+// Interprets the named kernel construction attribute as int64_t and
+// places it into *val. *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// int64, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
+    TF_Status* status);
+
+// Interprets the named kernel construction attribute as float and
+// places it into *val. *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// float, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
+    TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
+    TF_Status* status);
+
+// Interprets the named kernel construction attribute as bool and
+// places it into *val. *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// bool, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
+    TF_Status* status);
+
+// Interprets the named kernel construction attribute as string and
+// places it into *val. `val` must
+// point to an array of length at least `max_length` (ideally set to
+// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
+// attr_name, list_size, total_size)). *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// string, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
+    TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
+    size_t max_length, TF_Status* status);
+
+// Interprets the named kernel construction attribute as a TF_DataType array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as int32_t array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as int64_t array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as float array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as bool array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as string array and fills
+// in `vals` and `lengths`, each of which must point to an array of length at
+// least `max_values`. *status is set to TF_OK. The elements of `vals` will
+// point to addresses in `storage` which must be at least `storage_size` bytes
+// in length. Ideally, max_values would be set to list_size and `storage` would
+// be at least total_size, obtained from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
+    size_t* lengths, int max_values, void* storage, size_t storage_size,
+    TF_Status* status);
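+//
+// A minimal illustrative sketch (the attribute name "Attr" and the fixed
+// bounds are hypothetical):
+//   char* vals[8];
+//   size_t lengths[8];
+//   char storage[256];
+//   TF_OpKernelConstruction_GetAttrStringList(ctx, "Attr", vals, lengths,
+//                                             /*max_values=*/8, storage,
+//                                             sizeof(storage), status);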
+
+// Returns true if the kernel construction has an attribute named `attr_name`.
+TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
+
 // Returns the unique operation name for this OpKernel.
 TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
     TF_OpKernelConstruction* ctx);
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
index 49a168a..3fa4cd8 100644
--- a/tensorflow/c/kernels_test.cc
+++ b/tensorflow/c/kernels_test.cc
@@ -161,6 +161,336 @@
   ASSERT_TRUE(delete_called);
 }
 
+// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
+// Registers two ops, each with a single attribute called 'Attr'.
+// The attribute in one op will have the given type; in the other op it will
+// be a list of that type.
+#define ATTR_TEST_REGISTER_OP(name, type)                     \
+  REGISTER_OP("TestKernelAttr" #name)                         \
+      .Attr("Attr: " #type)                                   \
+      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
+  REGISTER_OP("TestKernelAttr" #name "List")                  \
+      .Attr("Attr: list(" #type ")")                          \
+      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
+ATTR_TEST_REGISTER_OP(String, string);
+ATTR_TEST_REGISTER_OP(Int, int);
+ATTR_TEST_REGISTER_OP(Float, float);
+ATTR_TEST_REGISTER_OP(Bool, bool);
+ATTR_TEST_REGISTER_OP(Type, type);
+#undef ATTR_TEST_REGISTER_OP
+
+// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
+#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
+  do {                                                                     \
+    int32_t list_size, total_size;                                         \
+    TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,        \
+                                        &total_size, status);              \
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);            \
+    EXPECT_EQ(expected_list_size, list_size);                              \
+    EXPECT_EQ(expected_total_size, total_size);                            \
+  } while (0)
+
+typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
+class TestKernelAttr : public ::testing::Test {
+ public:
+  TestKernelAttr() {}
+  ~TestKernelAttr() override {}
+
+  std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
+                                                  AttrValue v, Status* status) {
+    NodeDef def;
+    def.set_op(op_name);
+    def.set_name("FakeNode");
+    def.set_device("FakeDevice");
+    (*def.mutable_attr())["Attr"] = v;
+    return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
+                          status);
+  }
+
+  void CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,
+                                   const char* op_name, AttrValue& v) {
+    TF_KernelBuilder* builder = TF_NewKernelBuilder(
+        op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
+    {
+      TF_Status* status = TF_NewStatus();
+      TF_RegisterKernelBuilder("FakeNode", builder, status);
+      EXPECT_EQ(TF_OK, TF_GetCode(status));
+      TF_DeleteStatus(status);
+    }
+    Status status;
+    std::unique_ptr<OpKernel> kernel =
+        GetFakeKernelWithAttr(op_name, v, &status);
+    TF_EXPECT_OK(status);
+    ASSERT_NE(nullptr, kernel.get());
+    kernel->Compute(nullptr);
+
+    ASSERT_TRUE(delete_called);
+  }
+};
+
+TEST_F(TestKernelAttr, String) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    std::unique_ptr<char[]> val(new char[5]);
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ 5);
+    TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
+                                          /*max_length*/ 5, status);
+
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_s("bunny");
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrString", v);
+}
+
+TEST_F(TestKernelAttr, StringList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    std::vector<string> list = {"bugs", "bunny", "duck"};
+    int list_total_size = 0;
+    for (const auto& s : list) {
+      list_total_size += s.size();
+    }
+
+    TF_Status* status = TF_NewStatus();
+    std::unique_ptr<char*[]> values(new char*[list.size()]);
+    std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
+    std::unique_ptr<char[]> storage(new char[list_total_size]);
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
+                   /*expected_total_size*/ list_total_size);
+    TF_OpKernelConstruction_GetAttrStringList(
+        ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
+        list_total_size, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+    for (size_t i = 0; i < list.size(); ++i) {
+      EXPECT_EQ(list[i].size(), lens[i]) << i;
+      EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
+          << i;
+    }
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  std::string attr_in[] = {"bugs", "bunny", "duck"};
+  SetAttrValue(gtl::ArraySlice<std::string>(attr_in, 3), &v);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrStringList", v);
+}
+
+TEST_F(TestKernelAttr, Int) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    int64_t val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ(1234, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_i(1234);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrInt", v);
+}
+
+TEST_F(TestKernelAttr, IntList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const int64_t list[] = {1, 2, 3, 4};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    int64_t values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
+                                             status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  int64 attr_in[] = {1, 2, 3, 4};
+  SetAttrValue(gtl::ArraySlice<int64>(attr_in, 4), &v);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrIntList", v);
+}
+
+TEST_F(TestKernelAttr, Float) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    float val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_FLOAT_EQ(2.718, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_f(2.718);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloat", v);
+}
+
+TEST_F(TestKernelAttr, FloatList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const float list[] = {1.414, 2.718, 3.1415};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    float values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
+                                             status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  float attr_in[] = {1.414, 2.718, 3.1415};
+  SetAttrValue(gtl::ArraySlice<float>(attr_in, 3), &v);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloatList", v);
+}
+
+TEST_F(TestKernelAttr, Bool) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    unsigned char val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ(1, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_b(true);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBool", v);
+}
+
+TEST_F(TestKernelAttr, BoolList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const unsigned char list[] = {1, 0, 1, 0};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    unsigned char values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
+                                            status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  bool attr_in[] = {true, false, true, false};
+  SetAttrValue(gtl::ArraySlice<bool>(attr_in, 4), &v);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBoolList", v);
+}
+
+TEST_F(TestKernelAttr, Type) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    TF_DataType val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ(TF_FLOAT, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_type(DT_FLOAT);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrType", v);
+}
+
+TEST_F(TestKernelAttr, TypeList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    TF_DataType values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
+                                            status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  DataType attr_in[] = {DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128};
+  SetAttrValue(gtl::ArraySlice<DataType>(attr_in, 4), &v);
+  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTypeList", v);
+}
+#undef EXPECT_TF_SIZE
+
 class DummyDevice : public DeviceBase {
  public:
   explicit DummyDevice(Env* env) : DeviceBase(env) {}
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index 13e666d..8cb1a8a 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -60,7 +60,7 @@
   if (result.size() > sizeof("external/") &&
       result.compare(0, sizeof("external/") - 1, "external/") == 0) {
     result = result.substr(sizeof("external/") - 1);
-    pos = result.find("/");
+    pos = result.find('/');
     if (pos != string::npos) {
       result = result.substr(pos + 1);
     }
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 51c93a7..179c787 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -184,6 +184,7 @@
     "//tensorflow/compiler/tf2xla:tf2xla_util",
     "//tensorflow/compiler/tf2xla:xla_compiler",
     "//tensorflow/compiler/tf2xla:xla_op_registry",
+    "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
     "//tensorflow/compiler/tf2xla/kernels:xla_ops",
     "//tensorflow/compiler/xla:util",
     "//tensorflow/compiler/xla/client:client_library",
@@ -364,13 +365,9 @@
         ":flags",
         ":xla_activity_listener",
         ":xla_activity_proto_cc",
-        "@com_google_absl//absl/base",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
+        "//tensorflow/compiler/mlir:array_container_utils",
         "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
+        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla:xla_context",
@@ -385,13 +382,13 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:logging",
-    ] + if_libtpu(
-        if_false = [
-            "//tensorflow/compiler/mlir:array_container_utils",
-            "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
-        ],
-        if_true = [],
-    ),
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
 )
 
 tf_cc_test(
diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 87b06c2..fb4c187 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -151,10 +151,12 @@
   // not considered uncompilable.
   if (node_stack_trace != nullptr) {
     for (const auto& frame : *node_stack_trace) {
-      stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
+      stack_trace.emplace_back(
+          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
     }
   }
-  stack_trace.emplace_back(StackFrameView{node.name(), ""});
+  stack_trace.emplace_back(
+      StackFrameView{node.name(), "", node.GetStackTrace()});
 
   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
   IsCompilableNode(node, lib_runtime, &stack_trace,
@@ -173,10 +175,11 @@
   std::vector<StackFrameView> stack_trace;
   if (node_stack_trace != nullptr) {
     for (const auto& frame : *node_stack_trace) {
-      stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
+      stack_trace.emplace_back(
+          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
     }
   }
-  stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
+  stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr});
 
   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
   IsCompilableCall(call_def, lib_runtime, &stack_trace,
@@ -194,12 +197,11 @@
         "SymbolicGradient should be handled by IsCompilableCall().";
     return false;
   }
+
   if (node.type_string() == "Const") {
-    // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
-    // registered Const KernelDef says that it does, to support no-op Assert for
-    // tfcompile.
     const AttrValue* attr = node.attrs().Find("dtype");
-    if (attr != nullptr && attr->type() == DT_STRING) {
+    if (!op_filter_.allow_string_consts && attr != nullptr &&
+        attr->type() == DT_STRING) {
       *uncompilable_reason =
           "Const op with type DT_STRING is not supported by XLA.";
       return false;
@@ -359,7 +361,8 @@
   const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
   bool is_compilable = true;
   for (const Node* node : fbody->graph->op_nodes()) {
-    stack_trace->emplace_back(StackFrameView{node->name(), function.name()});
+    stack_trace->emplace_back(
+        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
     is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                       &function, uncompilable_nodes);
     stack_trace->pop_back();
@@ -583,7 +586,8 @@
                     [](const StackFrameView& stack_element) {
                       return StackFrame{
                           std::string(stack_element.name),
-                          std::string(stack_element.function_name)};
+                          std::string(stack_element.function_name),
+                          stack_element.stack_trace};
                     });
 
   node_info.name = std::string(stack_trace.back().name);
diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h
index 224beda..0272536 100644
--- a/tensorflow/compiler/jit/compilability_check_util.h
+++ b/tensorflow/compiler/jit/compilability_check_util.h
@@ -62,6 +62,7 @@
   struct StackFrame {
     std::string name;
     std::string function_name;
+    std::shared_ptr<AbstractStackTrace> stack_trace;
   };
 
   // Contains information about uncompilable node inside a function body.
@@ -128,6 +129,9 @@
     // Require the function to be always compilable, regardless whether some
     // control flow branches might be dead for a given input.
     bool require_always_compilable = false;
+
+    // Whether string constants are compilable.
+    bool allow_string_consts = true;
   };
 
   RecursiveCompilabilityChecker(OperationFilter op_filter,
@@ -193,6 +197,7 @@
   struct StackFrameView {
     absl::string_view name;
     absl::string_view function_name;
+    std::shared_ptr<AbstractStackTrace> stack_trace;
   };
 
   bool IsCompilableNode(
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 4964799..15bd534 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -177,6 +177,7 @@
   // bridge, on a per-graph basis).
   bool enable_mlir_bridge = false;
   bool enable_mlir_bridge_is_explicit = false;
+  bool mlir_bridge_safe_mode = false;
 
   auto setter_for_jitter_tensor_names = [](string sequence) {
     jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
@@ -227,7 +228,13 @@
 
        Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
             "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
-            &enable_mlir_bridge_is_explicit)});
+            &enable_mlir_bridge_is_explicit),
+       Flag(
+           "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
+           "When tf_mlir_enable_mlir_bridge is true, this field can enable "
+           "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
+           "it only runs for graphs that use features MLIR bridge currently "
+           "supports.")});
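+  // Illustrative note: setting TF_XLA_FLAGS="--tf_mlir_enable_mlir_bridge
+  // --tf_mlir_bridge_safe_mode" enables the bridge in safe mode; the rollout
+  // state below is then set to MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED.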
 
   AppendMarkForCompilationPassFlagsInternal(flag_list);
   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@@ -238,7 +245,9 @@
         ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
   } else if (enable_mlir_bridge) {
     mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
+        (mlir_bridge_safe_mode)
+            ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
+            : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
   } else {
     mlir_flags->tf_mlir_enable_mlir_bridge =
         ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 0ffea77..4d2c9f4 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1199,6 +1199,7 @@
     RecursiveCompilabilityChecker::OperationFilter filter =
         CreateOperationFilter(*registration);
     filter.require_always_compilable = true;
+    filter.allow_string_consts = false;
 
     RecursiveCompilabilityChecker checker(
         filter, DeviceType{registration->compilation_device_name});
@@ -1207,6 +1208,15 @@
       continue;
     }
 
+    if (node->type_string() == "Const") {
+      // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
+      // support it.
+      const AttrValue* attr = node->attrs().Find("dtype");
+      if (attr != nullptr && attr->type() == DT_STRING) {
+        continue;
+      }
+    }
+
     if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
       VLOG(1) << "Rejecting TF operation " << node->def().op()
               << " as it is not listed in --tf_xla_ops_to_cluster.";
@@ -2035,6 +2045,7 @@
                                      "TensorScatterUpdate",
                                      "TridiagonalSolve",
                                      "TruncatedNormal",
+                                     "Unique",
                                      "UpperBound",
                                      "UnsortedSegmentMax",
                                      "UnsortedSegmentMin",
@@ -2071,6 +2082,7 @@
                                      "XlaSpmdShardToFullShape",
                                      "XlaSvd",
                                      "XlaVariadicReduce",
+                                     "XlaVariadicSort",
                                      "XlaWhile",
                                      "Zeta",
                                      "_Arg",
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 005332a..61ff6bc 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -24,6 +24,8 @@
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -48,11 +50,6 @@
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/dump_graph.h"
 
-#if !defined(LIBTPU_ON_GCE)
-#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
-#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
-#endif
-
 namespace tensorflow {
 
 constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;
@@ -292,13 +289,6 @@
                     GetMlirBridgeRolloutPolicy(*graph, *config) ==
                         MlirBridgeRolloutPolicy::kEnabledByUser &&
                     node_def.op() != "VarIsInitializedOp";
-#ifdef LIBTPU_ON_GCE
-    if (use_mlir) {
-      LOG(WARNING) << "MLIR is not supported in this environment.";
-    }
-    return compiler->CompileGraph(compile_options, node_def.name(),
-                                  std::move(graph), args, result);
-#else
     if (!use_mlir) {
       return compiler->CompileGraph(compile_options, node_def.name(),
                                     std::move(graph), args, result);
@@ -314,7 +304,6 @@
         *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
         options.device_type.type_string(), compile_options.use_tuple_arg,
         *options.flib_def, debug_info, options.shape_representation_fn, result);
-#endif
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index dd1ddb6..c4edd86 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -38,7 +38,7 @@
 Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
   XlaDeviceFlags* flags = GetXlaDeviceFlags();
   if (!flags->tf_xla_enable_xla_devices) {
-    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
     return Status::OK();
   }
 
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 99ba565..209ea4a 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -96,7 +96,7 @@
     std::vector<std::unique_ptr<Device>>* devices) {
   XlaDeviceFlags* flags = GetXlaDeviceFlags();
   if (!flags->tf_xla_enable_xla_devices) {
-    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
     return Status::OK();
   }
 
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index baca8b9..602c2d2 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -115,17 +115,24 @@
           uncompilable_node_info.emplace_back(info);
         }
       }
-      string message = absl::StrCat(
+      std::string message = absl::StrCat(
           "Function invoked by the following node is not compilable: ",
           SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
-      absl::StrAppend(&message, "Uncompilable nodes:");
+      absl::StrAppend(&message, "Uncompilable operations:");
       for (const auto& node_info : uncompilable_node_info) {
-        string node_message = absl::StrCat("\n", node_info.name, ": ",
-                                           node_info.uncompilable_reason, "\n",
-                                           "\tStacktrace:\n");
-        for (const auto& stack_frame : node_info.stack_trace) {
-          absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
-                                stack_frame.name, stack_frame.function_name);
+        std::string node_message = absl::StrCat(
+            "\n", node_info.name, ": ", node_info.uncompilable_reason, "\n",
+            "The op is created at:\n");
+        if (node_info.stack_trace.back().stack_trace) {
+          AbstractStackTrace::TracePrintingOptions opts;
+          opts.show_line_contents = true;
+          opts.filter_common_prefix = true;
+          opts.drop_internal_frames = true;
+          absl::StrAppend(
+              &node_message,
+              node_info.stack_trace.back().stack_trace->ToString(opts));
+        } else {
+          absl::StrAppend(&node_message, "<Unavailable>\n");
         }
         absl::StrAppend(&message, node_message);
       }
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index f839acd..77db4eb 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -79,6 +79,7 @@
         "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
         "//tensorflow/core:lib",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
@@ -112,8 +113,8 @@
         "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
         "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
-        "//tensorflow/compiler/mlir/tosa:tf_tosa_passes",
-        "//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
+        "//tensorflow/compiler/mlir/tosa:tf_passes",
+        "//tensorflow/compiler/mlir/tosa:tfl_passes",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
index a32cfd5..a344655 100644
--- a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
+++ b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
@@ -1,7 +1,273 @@
 <!-- Autogenerated by mlir-tblgen; don't manually edit -->
+### `-tf-device-constant-sinking`: Sinks constants implicitly captured in a tf_device.cluster region.
+This pass sinks implicitly captured constants (`tf.Const` ops) that are used by
+a `tf_device.cluster` region into that region. Performing this prior to
+outlining reduces the number of arguments of the outlined function.
+
+For example, the following:
+
+```mlir
+func @cluster() -> tensor<i32> {
+  %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %cluster = "tf_device.cluster"() ( {
+    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %identity : tensor<i32>
+  }) : () -> (tensor<i32>)
+  return %cluster : tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @cluster() -> tensor<i32> {
+  %cluster = "tf_device.cluster"() ( {
+    %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %identity : tensor<i32>
+  }) : () -> (tensor<i32>)
+  return %cluster : tensor<i32>
+}
+```
+### `-tf-executor-graph-pruning`: Prunes unreachable ops in a tf_executor.graph
+This pass removes ops from a `tf_executor.graph` that are not transitively, via
+data or control dependencies, connected to the associated `tf_executor.fetch`
+op. The order of ops will be preserved. Functions named `main` with no
+`tf.entry_function` attribute will not be pruned, as such graphs/functions may
+have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are
+not provided at certain stages of IR transformation (e.g. pre-placement).
+
+For example, the following:
+
+```mlir
+func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %graph = tf_executor.graph {
+    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
+    %unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
+    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
+    %unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> tensor<i32>
+    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
+  }
+  return %graph : tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %graph = tf_executor.graph {
+    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
+    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
+    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
+    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
+  }
+  return %graph : tensor<i32>
+}
+```
+### `-tf-executor-to-functional-conversion`: Lifts tf_executor.island inner ops from a tf_executor.graph
+This pass converts a tf_executor.graph consisting only of tf_executor.islands and
+a tf_executor.fetch into a sea of TensorFlow Dialect ops by lifting those ops
+out of the tf_executor.graph's tf_executor.islands. If V1 control flow ops are
+present in the tf_executor.graph, an error is returned.
+
+For example, the following:
+
+```mlir
+func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %graph_results:2 = tf_executor.graph {
+    %island_0_result, %island_0_control = tf_executor.island {
+      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %identity : tensor<i32>
+    }
+    %island_1_result, %island_1_control = tf_executor.island {
+      %identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
+      tf_executor.yield %identity_n#0
+    }
+    tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
+  }
+  return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+  %identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
+  return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
+}
+```
+### `-tf-mark-ops-for-outside-compilation`: Marks ops in device cluster for outside compilation if they are unsupported on device.
+This pass marks unsupported ops in a device cluster with the
+`_xla_outside_compilation` attribute so that those operations run on the host
+instead of the device. Unsupported ops are ops that cannot be code generated to
+run on the cluster's device, including:
+
+1. String operations on TPUs.
+2. Operations that don't have a kernel defined for the device.
+
+This pass is conservative in that it marks for outside compilation all ops that
+cannot be compiled for the device. Exceptions are made for ops that will be
+rewritten or decomposed before compiling on device.
+
+For example, a tf_device.cluster op with an unsupported op, tf.UnsupportedOp:
+
+```mlir
+func @unsupported_op() -> tensor<i32> {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.UnsupportedOp"() : () -> tensor<i32>
+    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %2 : tensor<i32>
+  }) {allow_soft_placement = true, num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+```
+
+will mark tf.UnsupportedOp with the `_xla_outside_compilation` attribute:
+
+```mlir
+func @unsupported_op() -> tensor<i32> {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
+    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %2 : tensor<i32>
+  }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+```
 ### `-tf-shape-inference`: Simple Shape Inference on TensorFlow Dialect
 
 #### Options
 ```
 -max-iterations : Maximum shape inference iterations
 ```
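+
+As an illustrative sketch (not part of the autogenerated pass documentation;
+the function name here is hypothetical), shape inference refines unranked or
+partially known result types, such as `tensor<*xf32>`, from the operand types:
+
+```mlir
+// Hypothetical input: the result of tf.AddV2 is unranked.
+func @refine_shapes(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<*xf32> {
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+```
+
+After running the pass, the result type (and the function return type) would be
+refined to `tensor<2xf32>`.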
+### `-tf-tpu-cluster-formation`: Forms clusters from operations assigned to the same TPU computation
+TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata`
+op, a subgraph of ops (TensorFlow Dialect) each with a matching `_tpu_replicate`
+attribute relative to the associated `tf.TPUReplicateMetadata` op, and
+optionally `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops feeding in
+inputs and outputs to and from a replicated TPU computation. The number of times
+a TPU computation is replicated is defined in the `tf.TPUReplicateMetadata` op
+(`num_replicas` attribute) and operand and result sizes of
+`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` respectively must match,
+excluding packed tensors. It is also assumed that no op outside of the TPU
+computation is both an input to and an output of that same TPU computation.
+
+This pass takes the TPU computation subgraph, moves it into a
+`tf_device.cluster`, and copies over attributes from the associated
+`tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the
+computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is
+not copied over but instead the `tf_device.cluster` is further wrapped with a
+`tf_device.replicate`, and the associated `tf.TPUReplicatedInput` and
+`tf.TPUReplicatedOutput` ops become the `tf_device.replicate` operands and
+results. Otherwise, the single operands and results of the associated
+`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to
+the `tf_device.cluster`.
+
+For example, the following non-replicated computation:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
+  // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
+  // with topology `<topology>`.
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
+  %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
+  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
+  %replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
+  return %replicated_output : tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
+  %cluster = "tf_device.cluster"() ( {
+    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %identity : tensor<i32>
+  }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
+  return %cluster : tensor<i32>
+}
+```
+
+The following replicated computation:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
+  %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
+  %replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
+  return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
+    %cluster = "tf_device.cluster"() ( {
+      %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
+      tf_device.return %identity : tensor<i32>
+    }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
+    tf_device.return %cluster : tensor<i32>
+  }
+  return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
+}
+```
+### `-tf-tpu-extract-outside-compilation`: Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.
+This pass extracts, from a TPU cluster, the CPU computation clusters annotated
+with `_xla_outside_compilation`, which denotes ops that should run on the
+CPU/host. Each outside compilation cluster is moved to a
+tf_device.parallel_execute region, and the TPU cluster is also moved to a
+tf_device.parallel_execute region. Communication ops between device and host are
+added to pass inputs/outputs to/from the outside compiled region.
+
+For example, the following tf_device.cluster with an op marked for `_xla_outside_compilation`:
+
+```mlir
+func @outside_compilation() -> tensor<f32> {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
+    %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
+    %3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
+    tf_device.return %3 : tensor<f32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+```
+
+will become a tf_device.parallel_execute op with a CPU/host region and
+a tf_device.cluster with communication ops to send data to/from device/host:
+
+```mlir
+func @outside_compilation() -> tensor<f32> {
+  %0 = "tf_device.parallel_execute"() ( {
+    "tf_device.launch"() ( {
+      %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
+      %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
+      %3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+      "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
+      tf_device.return
+    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+    tf_device.return
+  },  {
+    %1 = "tf_device.cluster"() ( {
+      %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+      %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
+      %4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      tf_device.return %4 : tensor<f32>
+    }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
+    tf_device.return %1 : tensor<f32>
+  }) : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+```
diff --git a/tensorflow/compiler/mlir/hlo/.bazelrc b/tensorflow/compiler/mlir/hlo/.bazelrc
new file mode 100644
index 0000000..840949a
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/.bazelrc
@@ -0,0 +1,15 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+build --cxxopt=-std=c++14
+build --host_cxxopt=-std=c++14
diff --git a/tensorflow/compiler/mlir/hlo/.gitignore b/tensorflow/compiler/mlir/hlo/.gitignore
index cc1696bf..53e8335 100644
--- a/tensorflow/compiler/mlir/hlo/.gitignore
+++ b/tensorflow/compiler/mlir/hlo/.gitignore
@@ -1,4 +1,4 @@
 build
 llvm-project
 llvm-build
-
+bazel-*
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 219d391..5411c90 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -570,7 +570,7 @@
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:TensorDialect",
     ],
 )
 
@@ -583,6 +583,7 @@
         ":map_hlo_to_lhlo_op",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
     ],
 )
@@ -654,6 +655,7 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
     ],
@@ -740,6 +742,7 @@
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardOpsTransforms",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
     alwayslink = 1,
@@ -809,11 +812,11 @@
     deps = [
         ":hlo",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/mlir/hlo/WORKSPACE b/tensorflow/compiler/mlir/hlo/WORKSPACE
new file mode 100644
index 0000000..563df21
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/WORKSPACE
@@ -0,0 +1,57 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Workspace for MLIR HLO."""
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+LLVM_COMMIT = "<LLVM_COMMIT>"
+
+LLVM_SHA256 = "<LLVM_SHA256>"
+
+LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)
+
+http_archive(
+    name = "llvm-bazel",
+    strip_prefix = "llvm-bazel-{tag}/llvm-bazel".format(tag = LLVM_BAZEL_TAG),
+    url = "https://github.com/google/llvm-bazel/archive/{tag}.tar.gz".format(tag = LLVM_BAZEL_TAG),
+)
+
+load("@llvm-bazel//:terminfo.bzl", "llvm_terminfo_disable")
+load("@llvm-bazel//:zlib.bzl", "llvm_zlib_disable")
+load("@llvm-bazel//:configure.bzl", "llvm_configure")
+
+http_archive(
+    name = "llvm-project-raw",
+    build_file_content = "#empty",
+    sha256 = LLVM_SHA256,
+    strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
+    urls = [
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+        "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+    ],
+)
+
+llvm_terminfo_disable(
+    name = "llvm_terminfo",
+)
+
+llvm_zlib_disable(
+    name = "llvm_zlib",
+)
+
+llvm_configure(
+    name = "llvm-project",
+    src_path = ".",
+    src_workspace = "@llvm-project-raw//:WORKSPACE",
+)
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
index 05b2277..c1d7ffc 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
@@ -18,12 +18,12 @@
 
 #include "llvm/ADT/StringRef.h"
 #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
index 9bba3d3..2b1a18f 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
@@ -21,13 +21,13 @@
 #include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index ba4749e..cbdc3b6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -146,10 +146,9 @@
 
 // Abs supports complex to real, so element type is not guaranteed to match.
 def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
-    [NoSideEffect, SameOperandsAndResultShape],
+    [NoSideEffect, SameOperandsAndResultShape,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>],
      TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
-  let builders = [
-    OpBuilderDAG<(ins "Value":$operand)>];
 }
 
 def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
@@ -902,6 +901,7 @@
     ConvolutionAttributes.attributes);
 
   let results = (outs HLO_Tensor);
+  let hasCustomHLOConverter = 1;
 }
 
 def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 9b1b126..57fdfb6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -958,6 +958,17 @@
     OptionalAttr<
           TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
 
+def BoolElementsAttr :
+    ElementsAttrBase<
+      And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
+           CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
+      "constant boolean vector/tensor attribute"> {
+  let storageType = [{ ::mlir::DenseElementsAttr }];
+  let returnType = [{ ::mlir::DenseElementsAttr }];
+
+  let convertFromStorage = "$_self";
+}
+
 def ConvolutionAttributes {
   dag attributes = (ins
     // Default value: one for each of the spatial dimension.
@@ -968,6 +979,8 @@
     OptionalAttr<I64ElementsAttr>:$lhs_dilation,
     // Default value: one for each of the spatial dimension.
     OptionalAttr<I64ElementsAttr>:$rhs_dilation,
+    // Default value: false for each of the spatial dimensions.
+    OptionalAttr<BoolElementsAttr>:$window_reversal,
     ConvDimensionNumbers:$dimension_numbers,
     I64Attr:$feature_group_count,
     I64Attr:$batch_group_count,
@@ -983,6 +996,14 @@
 
     See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
   }];
+
+  code extraClassDeclaration = [{
+    bool hasWindowReversal() {
+      auto reversal = window_reversalAttr();
+      return reversal && llvm::any_of(reversal.getBoolValues(),
+                                      [](bool v) { return v; });
+    }
+  }];
 }
 
 class BASE_HLO_CopyOp {
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h
index 64c2f8f..70247d7 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h
@@ -19,8 +19,8 @@
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
 
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Identifier.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 
 // Order matters, this .inc header is not self-contained, and relies on the
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h
index 00de117..e26bf08 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h
@@ -16,8 +16,8 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
 
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h
index 92b7f63..3214ec6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h
@@ -25,12 +25,12 @@
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h
index 70f6f17..f34dccf 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h
@@ -19,8 +19,8 @@
 #define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
 
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Identifier.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 
 // Order matters, this .inc header is not self-contained, and relies on the
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td
index 2bf93f7..da7d179 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td
@@ -21,7 +21,17 @@
 def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
                                           LHLO_GPU_Dialect, [
    StructFieldAttr<"algorithm", I64Attr>,
-   StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
+   StructFieldAttr<"tensor_ops_enabled", BoolAttr>,
+   // The following 3 attributes describe the layout as an array of integers
+   // that list the dimensions in minor-to-major order similar to XLA's layout
+   // representation. operand_0_layout and operand_1_layout describe the layout
+   // of the first 2 operands of the convolution, and result_layout describes
+   // the layout of the primary output operand of the convolution.
+   // Note: Not using names like input_layout or filter_layout as `input` may be
+   // an input operand (for ConvForward) but an output for ConvBackward.
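+   // For example (illustrative only), a logically NCHW operand stored in
+   // standard row-major order would be described here as [3, 2, 1, 0], i.e.
+   // W is the minor-most dimension and N the major-most.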
+   StructFieldAttr<"operand_0_layout", I64ArrayAttr>,
+   StructFieldAttr<"operand_1_layout", I64ArrayAttr>,
+   StructFieldAttr<"result_layout", I64ArrayAttr>]> {
    let description = "GPU Convolution backend configuration";
 }
 
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h
index 78e9c7e..7dfbfd6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h
@@ -22,12 +22,12 @@
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
index ac67619..ef36f41 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
@@ -65,12 +65,17 @@
 MAP_HLO_TO_LHLO(MulOp);
 MAP_HLO_TO_LHLO(NegOp);
 MAP_HLO_TO_LHLO(NotOp);
+MAP_HLO_TO_LHLO(OrOp);
+MAP_HLO_TO_LHLO(PowOp);
 MAP_HLO_TO_LHLO(RealOp);
 MAP_HLO_TO_LHLO(ReduceOp);
 MAP_HLO_TO_LHLO(ReshapeOp);
 MAP_HLO_TO_LHLO(RemOp);
 MAP_HLO_TO_LHLO(RsqrtOp);
 MAP_HLO_TO_LHLO(SelectOp);
+MAP_HLO_TO_LHLO(ShiftLeftOp);
+MAP_HLO_TO_LHLO(ShiftRightArithmeticOp);
+MAP_HLO_TO_LHLO(ShiftRightLogicalOp);
 MAP_HLO_TO_LHLO(SignOp);
 MAP_HLO_TO_LHLO(SinOp);
 MAP_HLO_TO_LHLO(SliceOp);
@@ -78,6 +83,7 @@
 MAP_HLO_TO_LHLO(SubOp);
 MAP_HLO_TO_LHLO(TanhOp);
 MAP_HLO_TO_LHLO(TransposeOp);
+MAP_HLO_TO_LHLO(XorOp);
 
 #undef MAP_HLO_TO_LHLO
 
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
index d59dfd4..eadc32c 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
@@ -16,12 +16,16 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/iterator_range.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 
 namespace mlir {
@@ -37,6 +41,7 @@
 struct LhloToScalarOp<lmhlo::AddOp> {
   using FOp = ::mlir::AddFOp;
   using IOp = ::mlir::AddIOp;
+  using COp = ::mlir::AddCFOp;
 };
 template <>
 struct LhloToScalarOp<lmhlo::CompareOp> {
@@ -62,20 +67,18 @@
 struct LhloToScalarOp<lmhlo::SubOp> {
   using FOp = ::mlir::SubFOp;
   using IOp = ::mlir::SubIOp;
-};
-
-template <typename LhloBinaryOpTy>
-struct ScalarOp {
-  using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
-  using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
+  using COp = ::mlir::SubCFOp;
 };
 
 // Alias for the map from LHLO binary op type to STD floating-point op type.
 template <typename LhloOp>
-using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
+using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
 // Alias for the map from LHLO binary op type to STD integer op type.
 template <typename LhloOp>
-using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
+using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
+// Alias for the map from LHLO binary op type to STD complex op type.
+template <typename LhloOp>
+using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
 
 template <typename... Args>
 struct MapLhloOpToStdScalarOpImpl {
@@ -143,6 +146,16 @@
   }
   return nullptr;
 }
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
+                                    FloatType, ScalarFOp<lmhlo::AddOp>,
+                                    ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
+      loc, result_types, args, b);
+}
 
 template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
@@ -172,7 +185,7 @@
     StringRef comparison_direction) {
   return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
       .Case("EQ", CmpFPredicate::OEQ)
-      .Case("NE", CmpFPredicate::ONE)
+      .Case("NE", CmpFPredicate::UNE)
       .Case("GE", CmpFPredicate::OGE)
       .Case("GT", CmpFPredicate::OGT)
       .Case("LE", CmpFPredicate::OLE)
@@ -482,6 +495,15 @@
 }
 
 template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
+                                                 ArrayRef<Type> result_types,
+                                                 ArrayRef<Value> args,
+                                                 OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
+      loc, result_types, args, b);
+}
+
+template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
                                                     ArrayRef<Type> result_types,
                                                     ArrayRef<Value> args,
@@ -491,6 +513,40 @@
 }
 
 template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  lmhlo::PowOp::Adaptor adaptor(args);
+  // Floating-point types can be lowered directly to the standard PowFOp.
+  auto result_type = result_types.front();
+  if (result_type.isa<::mlir::FloatType>())
+    return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
+                                                        b);
+
+  assert(result_type.isa<::mlir::IntegerType>() &&
+         "only float and integer `pow` is supported right now");
+
+  // There is no powi, so lower to a simple product. Note that HLO does not
+  // define semantics of negative exponents.
+  Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
+
+  Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
+  Value upperBound =
+      b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
+  Value step = b->create<ConstantIndexOp>(loc, 1);
+  return b
+      ->create<scf::ForOp>(
+          loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
+          [&](OpBuilder& b, Location l, Value v, ValueRange iters) {
+            Value prod =
+                b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
+            b.create<scf::YieldOp>(l, prod);
+          })
+      .getResult(0);
+}
+
+template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
     Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
     OpBuilder* b) {
@@ -499,6 +555,30 @@
 }
 
 template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
+    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
+    OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
+      loc, result_types, args, b);
+}
+
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
+    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
+    OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
+      loc, result_types, args, b);
+}
+
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
+    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
+    OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
+      loc, result_types, args, b);
+}
+
+template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
@@ -506,14 +586,22 @@
   Type element_type = getElementTypeOrSelf(args.front().getType());
   if (auto float_type = element_type.dyn_cast<FloatType>()) {
     bool ignored;
-    APFloat one_apfloat(1.0f);
-    one_apfloat.convert(float_type.getFloatSemantics(),
-                        APFloat::rmNearestTiesToEven, &ignored);
-    Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
+    APFloat zero_apfloat(0.0f);
+    zero_apfloat.convert(float_type.getFloatSemantics(),
+                         APFloat::rmNearestTiesToEven, &ignored);
+    Value zero =
+        b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type);
     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
-      one = b->create<::mlir::SplatOp>(loc, vec_type, one);
+      zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
     }
-    return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
+    Value ne0_i1 =
+        b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero);
+    Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType());
+    Value copy_sign =
+        b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]);
+    auto is_nan =
+        b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]);
+    return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign);
   } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
     // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
     Value zero =
@@ -548,6 +636,17 @@
 }
 
 template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
+                                    FloatType, ScalarFOp<lmhlo::SubOp>,
+                                    ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
+      loc, result_types, args, b);
+}
+
+template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
@@ -556,6 +655,15 @@
       loc, result_types, args, b);
 }
 
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
+      loc, result_types, args, b);
+}
+
 }  // namespace impl
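
As a reading aid for the sign(x) change in the hunk above (copysign of a constant 1.0 replaced by a NaN-propagating select): a hedged C++ sketch of the new float behavior, using a hypothetical helper name:

    #include <cmath>
    float SignOf(float x) {
      if (std::isnan(x)) return x;                  // UNO compare + select: NaN propagates
      float magnitude = (x != 0.0f) ? 1.0f : 0.0f;  // ONE compare + uitofp: zero maps to 0.0
      return std::copysign(magnitude, x);           // copysign keeps the sign of x, incl. -0.0
    }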
 
 struct HloOpToStdScalarOp {
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index a2066df8..1fb9ba6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -52,6 +52,11 @@
 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                                MLIRContext *ctx);
 
+// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
+void populateDynamicHLOToLHLOConversionPattern(
+    MLIRContext *context, BufferizeTypeConverter *converter,
+    OwningRewritePatternList *patterns, bool insert_copy = true);
+
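
A hedged usage sketch for the new entry point; the surrounding pass setup (and the in-scope MLIRContext* context) is assumed, not taken from this patch:

    // Register only the dynamic-shape lowerings. insert_copy=false reuses the
    // memref cast directly instead of allocating a result buffer and copying.
    OwningRewritePatternList patterns;
    BufferizeTypeConverter converter;
    populateDynamicHLOToLHLOConversionPattern(context, &converter, &patterns,
                                              /*insert_copy=*/false);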
 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 void populateHLOToLHLOConversionPattern(MLIRContext *context,
                                         BufferizeTypeConverter *converter,
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h
index 1c57073..7059d95 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h
@@ -21,8 +21,8 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LLVM.h"
 
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h
index 39e0acf..5e41fd0 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h
@@ -17,7 +17,7 @@
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
 
 #include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
 
 namespace mlir {
 namespace hlo {
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h
index 74ea9c9..602ca96 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h
@@ -18,8 +18,8 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 
 namespace mlir {
@@ -83,6 +83,11 @@
 // Requires `ty` to be either FloatType or IntegerType.
 DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
 
+// Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
+// Returns an empty string if no such op exists.
+std::string LmhloToMhloOpName(llvm::StringRef op_name,
+                              mlir::MLIRContext* context);
+
 }  // namespace hlo
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
index 7ea42c6..9761e6a 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -19,9 +19,9 @@
 #include "mlir-hlo/utils/broadcast_utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 
 namespace mlir {
@@ -202,7 +202,7 @@
     MLIRContext* context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
-  Type element_type = IntegerType::get(1, context);
+  Type element_type = IntegerType::get(context, 1);
   return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                     attributes, element_type,
                                                     inferedReturnShapes);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 6b7b235..cec1ad7 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -42,6 +42,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
@@ -51,7 +52,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
@@ -454,18 +454,23 @@
 // AbsOp
 //===----------------------------------------------------------------------===//
 
-void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) {
-  auto shaped_type = operand.getType().cast<ShapedType>();
-  Type new_type;
-  if (!shaped_type.getElementType().isa<ComplexType>()) {
-    new_type = operand.getType();
-  } else if (shaped_type.hasRank()) {
-    new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
-  } else {
-    new_type = UnrankedTensorType::get(operand.getType());
+LogicalResult AbsOp::inferReturnTypes(
+    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
+    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
+  auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
+  Type element_ty = operand_ty.getElementType();
+  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
+    element_ty = complex_ty.getElementType();
   }
 
-  return AbsOp::build(builder, result, new_type, operand);
+  Type result_ty;
+  if (operand_ty.hasRank()) {
+    result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
+  } else {
+    result_ty = UnrankedTensorType::get(element_ty);
+  }
+  inferredReturnTypes.push_back(result_ty);
+  return success();
 }
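
In effect, abs on a complex operand now infers a real-valued result: a tensor<4xcomplex<f32>> operand yields tensor<4xf32>, non-complex operands keep their element type, and unranked operands produce an unranked tensor of that element type.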
 
 //===----------------------------------------------------------------------===//
@@ -616,7 +621,7 @@
 static LogicalResult Verify(TupleOp op) {
   SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
                                        op.operand_type_end()};
-  auto expectedType = TupleType::get(operandTypes, op.getContext());
+  auto expectedType = TupleType::get(op.getContext(), operandTypes);
   if (op.getType() != expectedType) {
     return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                         op.getType(), expectedType));
@@ -1879,28 +1884,29 @@
                        llvm::ArrayRef<int64_t> shape) {
     for (int64_t i = index.size() - 1; i >= 0; --i) {
       ++index[i];
-      if (index[i] < shape[i]) return true;
+      if (index[i] < shape[i]) return;
       index[i] = 0;
     }
-    return false;
   };
 
   // Iterate over all elements of the input tensor and copy it to the correct
   // location in the output tensor.
   llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
-  do {
-    uint64_t linear_index = 0;
-    uint64_t linear_index_multiplyer = 1;
+  uint64_t num_elements = input.getNumElements();
+  for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
+    uint64_t result_idx = 0;
+    uint64_t idx_multiplier = 1;
     for (int64_t i = index.size() - 1; i >= 0; --i) {
-      linear_index +=
+      result_idx +=
           (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
            index[i] *
                (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
-          linear_index_multiplyer;
-      linear_index_multiplyer *= return_type.getShape()[i];
+          idx_multiplier;
+      idx_multiplier *= return_type.getDimSize(i);
     }
-    result[linear_index] = input.getValue(index);
-  } while (next_index(index, input.getType().getShape()));
+    result[result_idx] = input.getValue(index);
+    next_index(index, input.getType().getShape());
+  }
   return DenseElementsAttr::get(return_type, result);
 }
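
A worked example of the index mapping in the rewritten fold above, assuming a 1-D input of three elements: with edge_padding_low = 1 and interior_padding = 2, input element k lands at result_idx = 1 + 3*k, so the inputs map to result positions 1, 4 and 7, and every other result position keeps the padding value.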
 
@@ -1961,7 +1967,7 @@
     MLIRContext* context, Optional<Location>, ValueRange operands,
     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
   inferredReturnTypes.push_back(RankedTensorType::get(
-      /*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
+      /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
   return success();
 }
 
@@ -2332,6 +2338,12 @@
 
   auto shape = result_type.getShape();
   int64_t count = result_type.getNumElements();
+  if (count == 0) {
+    return DenseElementsAttr::get<E>(
+        op->getResult().getType().cast<ShapedType>(),
+        /*list=*/{});
+  }
+
   // Compute the striding for each dimension.
   llvm::SmallVector<int64_t, 6> sizes;
   sizes.reserve(shape.size());
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
index 10c5c0c..572cc43 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
@@ -39,7 +40,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 126eda0..f4ca3a1 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -32,6 +32,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
@@ -40,7 +41,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
index eebdcf4..1f5cf27 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -112,6 +112,7 @@
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
+  MLIRTensor
 )
 
 add_mlir_library(MhloLhloToLinalg
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index 2144a59..a9102cbc 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -30,10 +30,10 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
index 92a02ea..a4f425e 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
@@ -23,10 +23,18 @@
 // Unary op patterns.
 //===----------------------------------------------------------------------===//
 
+def NonComplexElementType : Type<
+  CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
+  "Non complex element type">;
+
 // Expand acos to MHLO dialect as follows:
 //   acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x))  if x != -1
 //           = pi                                 if x == -1
-def : Pat<(HLOClient_AcosOp $input),
+//
+// TODO(hinsu): Support operands with complex element types separately using
+// the following formula.
+//   acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
+def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
   (HLO_SelectOp
     (HLO_CompareOp
       $input,
@@ -68,7 +76,9 @@
 // Express `sinh` as
 //   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
 //           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
-def : Pat<(HLOClient_SinhOp $input),
+// TODO(hinsu): Support operands with complex element types by always using the
+// second formula. The compare op below is not legal for complex numbers.
+def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
   (HLO_SelectOp
     (HLO_CompareOp
       (HLO_AbsOp $input),
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index 76b91f7..59c2a22 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -24,16 +24,17 @@
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Bufferize.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -62,7 +63,7 @@
     if (shape_element.value() != ShapedType::kDynamicSize) continue;
     Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
     Value alloc_operand =
-        rewriter->create<ExtractElementOp>(loc, shape_operand, index);
+        rewriter->create<tensor::ExtractOp>(loc, shape_operand, index);
     if (!alloc_operand.getType().isIndex()) {
       alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                     rewriter->getIndexType());
@@ -184,32 +185,64 @@
     // for args and outputs.
     const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
                                  static_cast<int32_t>(op->getNumResults())};
-    lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
-                   rewriter.getI32VectorAttr(segments));
+    lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
+                    rewriter.getI32VectorAttr(segments));
 
     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
     return success();
   }
 };
 
-struct HloToLhloDynamicBroadcastInDimOpConverter
+// TODO(pifon): Consider inserting lhlo.copy as in
+// HloToLhloDynamicBroadcastInDimOpConverter.
+struct HloToLhloDynamicReshapeConverter
+    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
+ public:
+  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
+
+  LogicalResult matchAndRewrite(
+      mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    Type result_type;
+    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
+      result_type =
+          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
+    } else if (auto unranked_type =
+                   op.getType().dyn_cast<UnrankedTensorType>()) {
+      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
+    } else {
+      return failure();
+    }
+    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
+        op, result_type, adaptor.operand(), adaptor.output_shape());
+    return success();
+  }
+};
+
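
Put differently, the converter moved here maps mhlo.dynamic_reshape onto memref_reshape after bufferization: a ranked result such as tensor<?x?xf32> becomes memref<?x?xf32>, an unranked result becomes an unranked memref in the default address space, and any other result type makes the pattern bail out with failure().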
+class HloToLhloDynamicBroadcastInDimOpConverter
     : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
  public:
-  using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
+  HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
+                                            MLIRContext* ctx,
+                                            bool insert_copy = true)
+      : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
+        insert_copy_(insert_copy) {}
 
   LogicalResult matchAndRewrite(
       mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = op.getLoc();
-    Value resultBuffer = InsertDynamicAllocAndDealloc(
-        loc, op.getResult(), op.output_dimensions(), &rewriter);
+    Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
 
-    Value transformed_operand =
-        InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
-    rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
+    if (insert_copy_) {
+      auto loc = op.getLoc();
+      Value result_buffer = InsertDynamicAllocAndDealloc(
+          loc, op.getResult(), op.output_dimensions(), &rewriter);
 
-    rewriter.replaceOp(op, {resultBuffer});
-
+      rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
+      result = result_buffer;
+    }
+    rewriter.replaceOp(op, {result});
     return success();
   }
 
@@ -260,7 +293,7 @@
     for (int i = 0; i < result_rank; ++i) {
       Value i_val = b->create<ConstantIndexOp>(loc, i);
       Value result_dim_size =
-          b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
+          b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
       if (!result_dim_size.getType().isIndex()) {
         result_dim_size =
             b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
@@ -307,31 +340,10 @@
         static_strides, llvm::None, sizes, strides);
     return transformed_operand;
   }
-};
 
-struct HloToLhloDynamicReshapeConverter
-    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
- public:
-  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
-
-  LogicalResult matchAndRewrite(
-      mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter& rewriter) const final {
-    Type result_type;
-    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
-      result_type =
-          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
-    } else if (auto unranked_type =
-                   op.getType().dyn_cast<UnrankedTensorType>()) {
-      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
-    } else {
-      return failure();
-    }
-    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
-    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
-        op, result_type, adaptor.operand(), adaptor.output_shape());
-    return success();
-  }
+  // Whether to keep copy semantics: allocate a buffer for the result of the
+  // memref cast and copy the broadcast result into it.
+  bool insert_copy_;
 };
 
 struct HloToLhloDotGeneralOpConverter
@@ -428,7 +440,7 @@
       mhlo::ReturnOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = op.getLoc();
-    auto& entry_block = op.getParentRegion()->front();
+    auto& entry_block = op->getParentRegion()->front();
     auto num_arguments = entry_block.getNumArguments();
     if (operands.size() > num_arguments) {
       return op.emitError(
@@ -556,6 +568,7 @@
     ConversionTarget target(context);
     target.addLegalDialect<lmhlo::LmhloDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<tensor::TensorDialect>();
     target.addIllegalOp<mlir::TensorLoadOp>();
     target.addIllegalOp<mlir::TensorStoreOp>();
     target.addIllegalDialect<mhlo::MhloDialect>();
@@ -593,15 +606,22 @@
 };
 }  // namespace
 
+void populateDynamicHLOToLHLOConversionPattern(
+    MLIRContext* context, BufferizeTypeConverter* converter,
+    OwningRewritePatternList* patterns, bool insert_copy) {
+  patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
+      *converter, context, insert_copy);
+  patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
+}
+
 void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         BufferizeTypeConverter* converter,
                                         OwningRewritePatternList* patterns) {
+  populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
   // clang-format off
   patterns->insert<
       HloToLhloCustomCallOpConverter,
       HloToLhloDotGeneralOpConverter,
-      HloToLhloDynamicBroadcastInDimOpConverter,
-      HloToLhloDynamicReshapeConverter,
       HloToLhloOpConverter<mhlo::AbsOp>,
       HloToLhloOpConverter<mhlo::AddOp>,
       HloToLhloOpConverter<mhlo::AndOp>,
@@ -629,11 +649,16 @@
       HloToLhloOpConverter<mhlo::MulOp>,
       HloToLhloOpConverter<mhlo::NegOp>,
       HloToLhloOpConverter<mhlo::NotOp>,
+      HloToLhloOpConverter<mhlo::OrOp>,
+      HloToLhloOpConverter<mhlo::PowOp>,
       HloToLhloOpConverter<mhlo::RealOp>,
       HloToLhloOpConverter<mhlo::RemOp>,
       HloToLhloOpConverter<mhlo::RsqrtOp>,
       HloToLhloOpConverter<mhlo::ReshapeOp>,
       HloToLhloOpConverter<mhlo::SelectOp>,
+      HloToLhloOpConverter<mhlo::ShiftLeftOp>,
+      HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
+      HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
       HloToLhloOpConverter<mhlo::SignOp>,
       HloToLhloOpConverter<mhlo::SinOp>,
       HloToLhloOpConverter<mhlo::SliceOp>,
@@ -641,6 +666,7 @@
       HloToLhloOpConverter<mhlo::SubOp>,
       HloToLhloOpConverter<mhlo::TanhOp>,
       HloToLhloOpConverter<mhlo::TransposeOp>,
+      HloToLhloOpConverter<mhlo::XorOp>,
       HloToLhloReduceOpConverter,
       HloToLhloReturnOpConverter,
       HloToLhloTensorLoadOpConverter,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
index 4591b93..3f876b8 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
@@ -21,12 +21,13 @@
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -83,7 +84,7 @@
 
   // Extract the predicate for checking branching, then branch to the true and
   // false regions appropriately.
-  auto cond_value = builder.create<mlir::ExtractElementOp>(loc, if_op.pred());
+  auto cond_value = builder.create<mlir::tensor::ExtractOp>(loc, if_op.pred());
   builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
                                      if_op.true_arg(), false_block,
                                      if_op.false_arg());
@@ -142,7 +143,7 @@
   builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
 
   // Updates the inlined condition blocks by replacing the return op with an
-  // extract_element and conditional branch. This changes the block below:
+  // tensor.extract and conditional branch. This changes the block below:
   //   ^cond(%0):
   //     <inlined conditional region>
   //    "mhlo".return(%1)
@@ -150,7 +151,7 @@
   //  Into:
   //   ^cond(%0):
   //     <inlined conditional region>
-  //     %2 = extract_element %1[] : tensor<i1> // Extract the condition value.
+  //     %2 = tensor.extract %1[] : tensor<i1> // Extract the condition value.
   //     cond_br %2, ^body(%0), ^tail(%0) // Branch.
   builder.setInsertionPointToStart(cond_block);
 
@@ -166,7 +167,8 @@
     builder.setInsertionPointToEnd(new_block);
 
     auto return_value = return_op.getOperand(0);
-    auto cond_value = builder.create<mlir::ExtractElementOp>(loc, return_value);
+    auto cond_value =
+        builder.create<mlir::tensor::ExtractOp>(loc, return_value);
 
     // Get the body block arguments.
     llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index c9eeefd..1a153dd 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -25,17 +25,18 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -49,17 +50,17 @@
 }
 
 template <bool isLHLO = true>
-Value getResultValue(Operation* op) {
+Value GetResultValue(Operation* op) {
   return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
 }
 
 template <bool isLHLO = true>
-ShapedType getHloOpResultType(Operation* op) {
-  return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
+ShapedType GetHloOpResultType(Operation* op) {
+  return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
 }
 
 template <bool isLHLO = true>
-bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
+bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
   auto verify_type = [&](Value val) -> bool {
     return (isLHLO && val.getType().isa<MemRefType>()) ||
            (!isLHLO && val.getType().isa<RankedTensorType>());
@@ -131,6 +132,7 @@
     SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
                                             common_indexing_map);
 
+    bool failed = false;
     auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc, op_result_types, inputs, output_buffers,
         /*initTensors=*/ValueRange{}, indexing_maps,
@@ -141,8 +143,13 @@
           Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
               op, body_result_types,
               llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          nested_builder.create<linalg::YieldOp>(loc, op_result);
+          if (op_result == nullptr) {
+            failed = true;
+          } else {
+            nested_builder.create<linalg::YieldOp>(loc, op_result);
+          }
         });
+    if (failed) return failure();
     rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
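
The null check matters because HloOpToStdScalarOp::map returns a null Value when it cannot map the op for the given element types; reporting failure() here lets the dialect-conversion driver roll back the partially built linalg.generic instead of leaving its region without a linalg.yield terminator.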
@@ -243,7 +250,8 @@
     }
 
     // TODO: LHS dilation for deconvolution not supported yet.
-    if (op.lhs_dilation()) {
+    // TODO(jurahul): Window reversal is not supported yet.
+    if (op.lhs_dilation() || op.hasWindowReversal()) {
       return failure();
     }
 
@@ -292,8 +300,8 @@
   LogicalResult matchAndRewrite(
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
-    auto result_type = getHloOpResultType<isLHLO>(op);
+    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
+    auto result_type = GetHloOpResultType<isLHLO>(op);
 
     SmallVector<AffineMap, 2> indexing_maps =
         Derived::getIndexingMaps(op, &rewriter);
@@ -330,7 +338,7 @@
     ShapedType input_type =
         broadcast_op.operand().getType().template cast<ShapedType>();
     unsigned input_rank = input_type.getRank();
-    unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
+    unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();
 
     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
     // the input's dimensions.
@@ -364,7 +372,7 @@
 
   static SmallVector<AffineMap, 2> getIndexingMaps(
       mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
-    auto result_type = getHloOpResultType<false>(broadcast_op);
+    auto result_type = GetHloOpResultType<false>(broadcast_op);
     auto operand_type =
         broadcast_op.operand().getType().template cast<ShapedType>();
     unsigned nloops = result_type.getRank();
@@ -562,7 +570,7 @@
                                 isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
     auto result_type =
-        getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
+        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto nloops = result_type.getRank();
     SmallVector<AffineExpr, 2> input_exprs;
     input_exprs.resize(result_type.getRank());
@@ -586,11 +594,11 @@
   LogicalResult matchAndRewrite(
       OpTy reshape_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
+    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
       return failure();
     ShapedType operand_type =
         reshape_op.operand().getType().template cast<ShapedType>();
-    ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
+    ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
 
     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
@@ -695,7 +703,7 @@
   LogicalResult matchAndRewrite(
       OpTy iota_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
+    ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
     if (!result_shaped_type) return failure();
 
     auto result_element_type = result_shaped_type.getElementType();
@@ -733,23 +741,37 @@
   }
 };
 
-class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
+template <typename OpTy>
+class ConstConverter : public OpConversionPattern<OpTy> {
  public:
-  using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::ConstOp const_op, ArrayRef<Value> args,
+      OpTy const_op, ArrayRef<Value> /*args*/,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = const_op.getLoc();
-    auto value_attr = const_op.value().cast<DenseElementsAttr>();
+    Location loc = const_op.getLoc();
+    auto value_attr = const_op.value().template cast<DenseElementsAttr>();
     if (value_attr.getType().getRank() != 0) return failure();
-    auto std_const_op =
-        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
-    rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
-                                         const_op.getOperand(), ValueRange());
-    rewriter.eraseOp(const_op);
+    ReplaceConstOp(loc, const_op, value_attr, rewriter);
     return success();
   }
+
+ private:
+  void ReplaceConstOp(Location loc, mhlo::ConstOp op,
+                      DenseElementsAttr value_attr,
+                      ConversionPatternRewriter& rewriter) const {
+    Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
+    rewriter.replaceOp(op, {std_tensor_const});
+  }
+  void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
+                      DenseElementsAttr value_attr,
+                      ConversionPatternRewriter& rewriter) const {
+    Value std_scalar_const =
+        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
+    rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
+                                         llvm::None);
+    rewriter.eraseOp(op);
+  }
 };
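
The two private overloads keep the existing behaviors separate: for lmhlo::ConstOp the scalar value is still materialized as a std.constant and written into the output buffer with affine.store, while for mhlo::ConstOp the op is simply replaced by a std.constant carrying the dense attribute.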
 
 class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
@@ -798,7 +820,8 @@
         loc, /*resultTensorTypes=*/ArrayRef<Type>{},
         /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
         /*initTensors=*/ValueRange{}, maps, types);
-    linalg_op.region().takeBody(reduce_op.body());
+    rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
+                                linalg_op.region().end());
     {
       OpBuilder::InsertionGuard region_guard(rewriter);
       Block* block = linalg_op.getBody();
@@ -852,7 +875,7 @@
                                 isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
     auto result_type =
-        getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
+        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto nloops = result_type.getRank();
     SmallVector<AffineExpr, 2> input_exprs;
     input_exprs.reserve(nloops);
@@ -908,7 +931,7 @@
                                            OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
-                   ConstConverter,
+                   ConstConverter<lmhlo::ConstOp>,
                    ConvToLinalgConverter,
                    IotaConverter<lmhlo::IotaOp>,
                    LhloBroadcastInDimConverter,
@@ -927,22 +950,28 @@
                    PointwiseToLinalgConverter<lmhlo::ExpOp>,
                    PointwiseToLinalgConverter<lmhlo::FloorOp>,
                    PointwiseToLinalgConverter<lmhlo::ImagOp>,
+                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                    PointwiseToLinalgConverter<lmhlo::LogOp>,
                    PointwiseToLinalgConverter<lmhlo::MaxOp>,
                    PointwiseToLinalgConverter<lmhlo::MinOp>,
                    PointwiseToLinalgConverter<lmhlo::MulOp>,
                    PointwiseToLinalgConverter<lmhlo::NegOp>,
                    PointwiseToLinalgConverter<lmhlo::NotOp>,
+                   PointwiseToLinalgConverter<lmhlo::OrOp>,
+                   PointwiseToLinalgConverter<lmhlo::PowOp>,
                    PointwiseToLinalgConverter<lmhlo::RealOp>,
                    PointwiseToLinalgConverter<lmhlo::RemOp>,
                    PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
                    PointwiseToLinalgConverter<lmhlo::SelectOp>,
+                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
+                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
+                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
                    PointwiseToLinalgConverter<lmhlo::SignOp>,
                    PointwiseToLinalgConverter<lmhlo::SinOp>,
                    PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                    PointwiseToLinalgConverter<lmhlo::SubOp>,
                    PointwiseToLinalgConverter<lmhlo::TanhOp>,
-                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
+                   PointwiseToLinalgConverter<lmhlo::XorOp>,
                    ReduceConverter,
                    ReshapeOpConverter<lmhlo::ReshapeOp>,
                    ReverseConverter<lmhlo::ReverseOp>,
@@ -994,13 +1023,14 @@
 struct HloLegalizeToLinalgPass
     : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
   void getDependentDialects(DialectRegistry& registry) const override {
-    registry.insert<linalg::LinalgDialect>();
+    registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
   }
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
-    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
+                           scf::SCFDialect>();
 
     auto func = getFunction();
     mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
@@ -1024,7 +1054,8 @@
                                           OwningRewritePatternList* patterns) {
   patterns
       ->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
-               HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
+               ConstConverter<mhlo::ConstOp>, HloBroadcastInDimConverter,
+               IotaConverter<mhlo::IotaOp, false>,
                PointwiseToLinalgConverter<mhlo::AbsOp, false>,
                PointwiseToLinalgConverter<mhlo::AddOp, false>,
                PointwiseToLinalgConverter<mhlo::AndOp, false>,
@@ -1039,21 +1070,28 @@
                PointwiseToLinalgConverter<mhlo::ExpOp, false>,
                PointwiseToLinalgConverter<mhlo::FloorOp, false>,
                PointwiseToLinalgConverter<mhlo::ImagOp, false>,
+               PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
                PointwiseToLinalgConverter<mhlo::LogOp, false>,
                PointwiseToLinalgConverter<mhlo::MaxOp, false>,
                PointwiseToLinalgConverter<mhlo::MinOp, false>,
                PointwiseToLinalgConverter<mhlo::MulOp, false>,
                PointwiseToLinalgConverter<mhlo::NegOp, false>,
                PointwiseToLinalgConverter<mhlo::NotOp, false>,
+               PointwiseToLinalgConverter<mhlo::OrOp, false>,
+               PointwiseToLinalgConverter<mhlo::PowOp, false>,
                PointwiseToLinalgConverter<mhlo::RealOp, false>,
                PointwiseToLinalgConverter<mhlo::RemOp, false>,
                PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
                PointwiseToLinalgConverter<mhlo::SelectOp, false>,
+               PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
+               PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
+               PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
+               PointwiseToLinalgConverter<mhlo::SignOp, false>,
                PointwiseToLinalgConverter<mhlo::SinOp, false>,
                PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
                PointwiseToLinalgConverter<mhlo::SubOp, false>,
                PointwiseToLinalgConverter<mhlo::TanhOp, false>,
-               PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
+               PointwiseToLinalgConverter<mhlo::XorOp, false>,
                ReshapeOpConverter<mhlo::ReshapeOp, false>,
                ReverseConverter<mhlo::ReverseOp, false>,
                TransposeConverter<mhlo::TransposeOp, false>>(context);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
index 63bbd44..454ec18 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
@@ -145,7 +145,7 @@
 
     auto int_shape_type = RankedTensorType::get(
         output_type.getShape(),
-        IntegerType::get(bitwidth, rewriter.getContext()));
+        IntegerType::get(rewriter.getContext(), bitwidth));
     auto loc = op.getLoc();
     auto integer_const = rewriter.create<mlir::ConstantOp>(
         loc, DenseIntElementsAttr::get(int_shape_type, values));
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
index 5bd27f0..72fd422 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
@@ -19,8 +19,8 @@
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
index 501627e..1637eef 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
@@ -30,11 +30,11 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
index 78d681b..4bf73ae 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
@@ -20,7 +20,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -437,12 +437,14 @@
         loc, operand_type.getElementType(), mapped_ivs.in_bounds,
         /*withElseRegion=*/true);
 
-    OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
+    OpBuilder then_builder =
+        elem_or_init.getThenBodyBuilder(rewriter->getListener());
     Value elem = then_builder.create<mlir::LoadOp>(
         loc, reduce_window_op.operand(), mapped_ivs.ivs);
     then_builder.create<scf::YieldOp>(loc, elem);
 
-    OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
+    OpBuilder else_builder =
+        elem_or_init.getElseBodyBuilder(rewriter->getListener());
     else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
 
     return rewriter->create<scf::ReduceOp>(loc,
@@ -617,7 +619,8 @@
 
     // Case when we are inside boundaries of 'arg' and not in the pad area.
     {
-      OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
+      OpBuilder in_bounds_then_b =
+          if_in_bounds.getThenBodyBuilder(b->getListener());
       auto select_or_init_results = SelectOrInitialize(
           s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
       in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
@@ -625,7 +628,8 @@
 
     // Case when we are in the pad.
     {
-      OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
+      OpBuilder in_bounds_else_b =
+          if_in_bounds.getElseBodyBuilder(b->getListener());
       in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
     }
 
@@ -651,7 +655,7 @@
     // element in boundaries of the operand. Select function has to be computed
     // here.
     {
-      OpBuilder if_init_then_b = if_init.getThenBodyBuilder();
+      OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());
 
       auto& lhlo_select = s_and_s_op.select().front();
       Value pred =
@@ -664,14 +668,14 @@
       // Pred == true, therefore pack newly selected ivs, val and init flag back
       // to iter_args and return.
       {
-        OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
+        OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
         if_pred_then_b.create<scf::YieldOp>(
             loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
       }
 
       // Pred == false, therefore return old iter_args.
       {
-        OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
+        OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
         if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
       }
 
@@ -680,7 +684,7 @@
     // Init == false, i.e. only pad was visited before and this is the first
     // element in the boundaries of the operand.
     {
-      OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
+      OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());
 
       if_init_else_b.create<scf::YieldOp>(
           loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc
index cf63616..ea8783f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc
@@ -23,9 +23,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
index dba3cab..907fd76 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
@@ -18,9 +18,10 @@
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -119,7 +120,7 @@
   auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
   auto getAsIndex = [&](Value val) {
     auto loc = whileOp.getLoc();
-    return b.create<ExtractElementOp>(
+    return b.create<tensor::ExtractOp>(
         loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
   };
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 1788b28..b2fef91 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -22,10 +22,10 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -42,11 +42,11 @@
                       sep fn(SqrtOp) sep fn(TanhOp)
 
 // TODO(herhut): Generate these out of op definitions.
-#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                           \
-  fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
-      sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp)             \
-          sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)              \
-              sep fn(ShiftRightLogicalOp) sep fn(SubOp)
+#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                            \
+  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)  \
+      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
+          sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
+              sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
 
 // TODO(herhut): Generate these out of op definitions.
 #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                         \
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc
index 9d07248..0639589 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc
@@ -18,9 +18,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Transforms/DialectConversion.h"
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc
index 71b1a4e..bdd66a1 100644
--- a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc
@@ -21,8 +21,8 @@
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
 namespace hlo {
diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc
index a29f0a6..f7177ec 100644
--- a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc
@@ -18,7 +18,7 @@
 #include "mlir-hlo/utils/convert_op_folder.h"
 
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc
index 0bbd91e..8ff1ce3 100644
--- a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc
@@ -132,5 +132,13 @@
   llvm_unreachable("unsupported type");
 }
 
+std::string LmhloToMhloOpName(llvm::StringRef op_name,
+                              mlir::MLIRContext *context) {
+  assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
+  std::string mhlo_op_name(op_name.drop_front(1));
+  if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
+  return "";
+}
+
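
A hedged usage sketch (illustrative values; dropping the leading 'l' is what turns the "lmhlo." prefix into "mhlo."), assuming an MLIRContext* context is in scope:

    // "lmhlo.add" -> "mhlo.add" when the MHLO op is registered in this
    // context, otherwise the empty string.
    std::string mhlo_name = mlir::hlo::LmhloToMhloOpName("lmhlo.add", context);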
 }  // namespace hlo
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
index 7f27252..8e17895 100644
--- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
@@ -327,6 +327,15 @@
   return %1 : tensor<4x1xi64>
 }
 
+// CHECK-LABEL: slice_zero_elements
+func @slice_zero_elements() -> tensor<0xi64> {
+  %0 = mhlo.constant dense<> : tensor<0xi64>
+  // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64>
+  %1 = "mhlo.slice"(%0) { limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<0xi64>) -> (tensor<0xi64>)
+  // CHECK: return %[[CONST]] : tensor<0xi64>
+  return %1 : tensor<0xi64>
+}
+
 // CHECK-LABEL: slice_unknown_shape
 func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
@@ -1506,6 +1515,14 @@
   // CHECK-SAME: ]> : tensor<4x5xi32>
 }
 
+func @pad_fold_zero_elements() -> tensor<3xi32> {
+  %0 = mhlo.constant dense<> : tensor<0xi32>
+  %1 = mhlo.constant dense<7> : tensor<i32>
+  %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor<i32>) -> tensor<3xi32>
+  return %2 : tensor<3xi32>
+  // CHECK: mhlo.constant dense<7> : tensor<3xi32>
+}
+
 // CHECK-LABEL: @identity_broadcast_reshape
 func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
   %0 = "mhlo.broadcast"(%arg0) {
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
index cb88019..8cfffb3 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
@@ -170,24 +170,31 @@
   return %rank : index
 }
 // CHECK: %[[SHAPE:.*]] = tensor_from_elements
+
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
-// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
-// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
-// CHECK: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
-// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
-// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 // CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
 // CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
 // CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
+
+// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
+// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
+// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
+
+// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
 // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
 // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
+
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
+// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
 // CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
 // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
+
 // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
+
+// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
+
 // CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
 // CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 
@@ -316,6 +323,20 @@
 
 // -----
 
+// CHECK-LABEL: func @and
+func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
+          %result: memref<2x2xi32>) {
+  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
+  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
+  %tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @ceil
 func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -389,6 +410,20 @@
 
 // -----
 
+// CHECK-LABEL: func @or
+func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
+         %result: memref<2x2xi32>) {
+  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
+  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
+  %tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt
 func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -425,6 +460,48 @@
 
 // -----
 
+// CHECK-LABEL: func @shift_left
+func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
+                 %result: memref<2x2xi32>) {
+  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
+  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
+  %tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @shift_right_arithmetic
+func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
+                             %result: memref<2x2xi32>) {
+  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
+  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
+  %tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @shift_right_logical
+func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
+                          %result: memref<2x2xi32>) {
+  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
+  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
+  %tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @tanh
 func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -438,7 +515,8 @@
 // -----
 
 // CHECK-LABEL: func @remainder
-func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
+                %result: memref<2x2xf32>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
   %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
   %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
@@ -450,6 +528,20 @@
 
 // -----
 
+// CHECK-LABEL: func @xor
+func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
+          %result: memref<2x2xi32>) {
+  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
+  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
+  %tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // Dynamic shape binary element-wise operation.
 // CHECK-LABEL: func @add_dyn
 func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
@@ -462,9 +554,9 @@
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
   // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
   // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
   // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
   // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
   // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
   // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -485,9 +577,9 @@
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
   // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
   // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
   // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
   // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
   // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
   // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index c4413ed..71a8b79 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -29,6 +29,18 @@
 
 // -----
 
+// CHECK-LABEL: complex_add
+func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
+                  %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
+  // CHECK: linalg.generic
+  // CHECK: addcf
+  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
+      tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
+  return %0 : tensor<2x2xcomplex<f32>>
+}
+
+// -----
+
 // CHECK-LABEL: func @float_mul
 func @float_mul(%lhs: tensor<2x2xf32>,
                 %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
@@ -112,6 +124,18 @@
 
 // -----
 
+// CHECK-LABEL: complex_sub
+func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
+                  %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
+  // CHECK: linalg.generic
+  // CHECK: subcf
+  %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
+      tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
+  return %0 : tensor<2x2xcomplex<f32>>
+}
+
+// -----
+
 // CHECK-LABEL: func @float_abs
 func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   // CHECK: linalg.generic
@@ -194,6 +218,30 @@
 
 // -----
 
+// CHECK-LABEL: func @integer_or
+func @integer_or(%lhs: tensor<2x2xi32>,
+                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK: linalg.generic
+  // CHECK: or
+  %0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
+                                    tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @integer_xor
+func @integer_xor(%lhs: tensor<2x2xi32>,
+                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK: linalg.generic
+  // CHECK: xor
+  %0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
+                                    tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @float_cmp
 func @float_cmp(%lhs: tensor<2x2xf32>,
                 %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
@@ -208,6 +256,20 @@
 
 // -----
 
+// CHECK-LABEL: func @float_cmp_ne
+func @float_cmp_ne(%lhs: tensor<2x2xf32>,
+                %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
+  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
+          : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
+  return %0 : tensor<2x2xi1>
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
+// CHECK-NEXT:   %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i1
+
+// -----
+
 // CHECK-LABEL: func @int_cmp
 func @int_cmp(%lhs: tensor<2x2xi32>,
               %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
@@ -630,3 +692,95 @@
 // CHECK-NEXT:   %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
 // CHECK-NEXT:   %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT:   linalg.yield %[[FLOAT_CAST]] : f32
+
+// -----
+
+func @shift_left(%lhs: tensor<2x2xi32>,
+                 %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %result = "mhlo.shift_left"(%lhs, %rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %result : tensor<2x2xi32>
+}
+// CHECK-LABEL: func @shift_left
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
+func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
+                             %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %result : tensor<2x2xi32>
+}
+// CHECK-LABEL: func @shift_right_arithmetic
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
+func @shift_right_logical(%lhs: tensor<2x2xi32>,
+                          %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %result = "mhlo.shift_right_logical"(%lhs, %rhs)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %result : tensor<2x2xi32>
+}
+// CHECK-LABEL: func @shift_right_logical
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
+// CHECK-LABEL: func @constant
+func @constant() {
+  %result = "mhlo.constant"() {
+    value = dense<10> : tensor<i32>
+  } : () -> (tensor<i32>)
+  return
+}
+// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @float_pow
+func @float_pow(%lhs: tensor<2x2xf32>,
+                %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-z0-9_]*}}
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
+  // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
+                                   tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @integer_pow
+func @integer_pow(%lhs: tensor<2x2xi32>,
+                %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-z0-9_]*}}
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+  // CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
+  // CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
+  // CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
+  // CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
+  //   CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
+  //   CHECK: scf.yield %[[ACCUM]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
+                                   tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir
index 274792e..8e5e18a 100644
--- a/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir
@@ -5,7 +5,7 @@
   //CHECK:   br ^bb1(%arg0 : tensor<i64>)
   //CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
   //CHECK:   [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
-  //CHECK:   [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
+  //CHECK:   [[VAL2:%.+]] = tensor.extract [[VAL1]][] : tensor<i1>
   //CHECK:   cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
   //CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
   //CHECK:   [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
@@ -33,7 +33,7 @@
   // CHECK:   [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
   %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
 
-  // CHECK:   [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
+  // CHECK:   [[VAL1:%.+]] = tensor.extract [[VAL0]][] : tensor<i1>
   // CHECK:   cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
   %1 = "mhlo.if"(%0, %arg0, %arg0) ( {
 
@@ -63,7 +63,7 @@
   // CHECK:   br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
   // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
   // CHECK:   %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-  // CHECK:   %2 = extract_element %1[] : tensor<i1>
+  // CHECK:   %2 = tensor.extract %1[] : tensor<i1>
   // CHECK:   cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
   // CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
   // CHECK:   br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
@@ -95,7 +95,7 @@
   // CHECK:   br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
   // CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
   // CHECK:   %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-  // CHECK:   %3 = extract_element %2[] : tensor<i1>
+  // CHECK:   %3 = tensor.extract %2[] : tensor<i1>
   // CHECK:   cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
   // CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
   // CHECK:   br ^[[COND_ENTRY]](%4 : tensor<i64>)
@@ -118,7 +118,7 @@
 
 // CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
 func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
-  // CHECK:   %0 = extract_element %arg2[] : tensor<i1>
+  // CHECK:   %0 = tensor.extract %arg2[] : tensor<i1>
   // CHECK:   cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor<f32>), ^[[ELSE_ENTRY:.+]](%arg1 : tensor<f32>)
   // CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
   // CHECK:   br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir
index 9c887a7..101800d 100644
--- a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir
@@ -30,9 +30,9 @@
 // CHECK:  %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
 // CHECK:  %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
 // CHECK:  %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
-// CHECK:  %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
+// CHECK:  %[[VAL_15:.*]] = tensor.extract %[[VAL_14]][] : tensor<index>
 // CHECK:  %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
-// CHECK:  %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
+// CHECK:  %[[VAL_17:.*]] = tensor.extract %[[VAL_16]][] : tensor<index>
 // CHECK:  %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
-// CHECK:  %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
+// CHECK:  %[[VAL_19:.*]] = tensor.extract %[[VAL_18]][] : tensor<index>
 // CHECK:  scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index 5bfde29..716e00a 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -4,6 +4,21 @@
 // CHECK-LABEL: func @element_wise
 func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
                    %result: memref<2x2xf32>) {
+  "lmhlo.power"(%lhs, %rhs, %result)
+      : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
+  return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
+// CHECK-NEXT:   %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @element_wise
+func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
+                   %result: memref<2x2xf32>) {
   "lmhlo.add"(%lhs, %rhs, %result)
       : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
   return
@@ -594,8 +609,12 @@
 }
 // CHECK: linalg.generic
 // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT:   %[[CST:.*]] = constant 1.000000e+00 : f32
-// CHECK-NEXT:   %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32
+// CHECK-NEXT:   %[[CST_0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : f32
+// CHECK-NEXT:   %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to f32
+// CHECK-NEXT:   %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : f32
+// CHECK-NEXT:   %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : f32
+// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 
 // -----
@@ -607,8 +626,12 @@
 }
 // CHECK: linalg.generic
 // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT:   %[[CST:.*]] = constant 1.000000e+00 : bf16
-// CHECK-NEXT:   %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16
+// CHECK-NEXT:   %[[CST_0:.*]] = constant 0.000000e+00 : bf16
+// CHECK-NEXT:   %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : bf16
+// CHECK-NEXT:   %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to bf16
+// CHECK-NEXT:   %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : bf16
+// CHECK-NEXT:   %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : bf16
+// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : bf16
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : bf16
 
 // -----
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
index 35bf59b..82c455c 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
@@ -50,8 +50,11 @@
       feature_group_count = 1,
       batch_group_count = 1,
       result_scale = 1.0,
-      backend_config = {algorithm=0, tensor_ops_enabled = true }
-    }
+      backend_config = {algorithm=0,
+                        operand_0_layout = [3,2,1,0],
+                        operand_1_layout = [3,2,1,0],
+                        result_layout = [3,2,1,0],
+                        tensor_ops_enabled = true}}
     : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
   return
 }
@@ -60,7 +63,11 @@
 func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
   %scratch = alloc() : memref<23328xui8>
   "lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
-    { backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
+    { backend_config = {algorithm = 1 : i64,
+                        operand_0_layout = [3,2,1,0],
+                        operand_1_layout = [3,2,1,0],
+                        result_layout = [3,2,1,0],
+                        tensor_ops_enabled = false},
       batch_group_count = 1 : i64,
       dimension_numbers = {input_batch_dimension = 0 : i64,
                            input_feature_dimension = 3 : i64,
@@ -86,7 +93,11 @@
 func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
   %scratch = alloc() : memref<32xui8>
   "lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
-  { backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
+    { backend_config = {algorithm = 1 : i64,
+                        operand_0_layout = [3,2,1,0],
+                        operand_1_layout = [3,2,1,0],
+                        result_layout = [3,2,1,0],
+                        tensor_ops_enabled = false},
     batch_group_count = 1 : i64,
     dimension_numbers = {input_batch_dimension = 0 : i64,
                          input_feature_dimension = 1 : i64,
@@ -103,7 +114,8 @@
     precision_config = [],
     result_scale = 1.000000e+00 : f64,
     rhs_dilation = dense<1> : tensor<2xi64>,
-    window_strides = dense<1> : tensor<2xi64>}
+    window_strides = dense<1> : tensor<2xi64>,
+    window_reversal = dense<true>: tensor<2xi1>}
   : (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
   return
 }
@@ -113,7 +125,11 @@
   %scratch = alloc() : memref<32xui8>
   "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
     {activation_mode = "Relu",
-     backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
+     backend_config = {algorithm = 1 : i64,
+                       operand_0_layout = [3,2,1,0],
+                       operand_1_layout = [3,2,1,0],
+                       result_layout = [3,2,1,0],
+                       tensor_ops_enabled = false},
      batch_group_count = 1 : i64,
      dimension_numbers = {input_batch_dimension = 0 : i64,
        input_feature_dimension = 1 : i64,
@@ -140,7 +156,11 @@
   %scratch = alloc() : memref<0xui8>
   "lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
     {activation_mode = "Relu",
-     backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
+     backend_config = {algorithm = 1 : i64,
+                       operand_0_layout = [3,2,1,0],
+                       operand_1_layout = [3,2,1,0],
+                       result_layout = [3,2,1,0],
+                       tensor_ops_enabled = false},
      batch_group_count = 1 : i64,
      dimension_numbers = {input_batch_dimension = 0 : i64,
        input_feature_dimension = 1 : i64,
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index b7f6de7..4b3ed9f 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -25,6 +25,7 @@
     packages = [
         "//learning/brain/mlir/...",
         "//tensorflow/compiler/mlir/...",
+        "//third_party/iree/...",
     ],
 )
 
@@ -353,6 +354,25 @@
 )
 
 cc_library(
+    name = "perception_ops_utils",
+    srcs = [
+        "utils/perception_ops_utils.cc",
+    ],
+    hdrs = [
+        "utils/perception_ops_utils.h",
+    ],
+    copts = ["-std=c++14"],
+    deps = [
+        ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
+        "//tensorflow/lite/c:common",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+cc_library(
     name = "stateful_ops_utils",
     srcs = [
         "utils/stateful_ops_utils.cc",
@@ -384,6 +404,23 @@
     ],
 )
 
+tf_cc_test(
+    name = "perception_ops_utils_test",
+    size = "small",
+    srcs = ["utils/perception_ops_utils_test.cc"],
+    deps = [
+        ":perception_ops_utils",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
+    ],
+)
+
 cc_library(
     name = "tensorflow_lite_legalize_tf",
     srcs = [
@@ -413,6 +450,7 @@
         ":constant_utils",
         ":lstm_utils",
         ":nms_utils",
+        ":perception_ops_utils",
         ":stateful_ops_utils",
         ":tensorflow_lite",
         ":tftext_utils",
diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h b/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h
index 782714f..59bb7ae 100644
--- a/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h
+++ b/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h
@@ -15,8 +15,8 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
 
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 
 // For add/mul/div/sub and other broadcastable ops.
 class ArithmeticCountUtilHelper {
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index fdf27cb..5dc9238 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -47,10 +47,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -1244,7 +1244,7 @@
 }
 
 void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
-  auto dict_attr = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+  auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
   if (!dict_attr) return;
 
   llvm::SmallVector<llvm::StringRef, 2> input_names;
@@ -1481,7 +1481,7 @@
 
 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
 Translator::CreateMetadataVector() {
-  auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
+  auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
   std::vector<BufferOffset<tflite::Metadata>> metadata;
   if (dict_attr) {
     for (const auto& named_attr : dict_attr) {
@@ -1559,7 +1559,7 @@
 
   // Fetch function inputs and outputs tensor names.
   auto dict_attr =
-      main_op.getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
+      main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
   if (!dict_attr) return {};
 
   // Get Input and output tensor names from attribute.
@@ -1592,7 +1592,7 @@
   }
   // Exported method name.
   auto exported_name =
-      main_op.getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
+      main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
   if (exported_name.empty()) {
     main_op.emitError("Empty exported names for main Function");
     return {};
@@ -1658,7 +1658,7 @@
   int entry_func_count = 0;
   FuncOp entry_func = nullptr;
   for (auto fn : module.getOps<FuncOp>()) {
-    auto attrs = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+    auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
     if (attrs && !attrs.empty()) {
       entry_func_count++;
       entry_func = fn;
@@ -1759,13 +1759,14 @@
     std::string err;
     if (!failed_flex_ops_.empty())
       err +=
-          "Ops that can be supported by the flex runtime (enabled via setting "
-          "the -emit-select-tf-ops flag):\n" +
+          "Some ops are not supported by the native TFLite runtime, you can "
+          "enable TF kernels fallback using TF Select. See instructions: "
+          "https://www.tensorflow.org/lite/guide/ops_select" +
           failed_flex_ops_summary;
     if (!failed_custom_ops_.empty())
       err +=
-          "Ops that need custom implementation (enabled via setting the "
-          "-emit-custom-ops flag):\n" +
+          "Some ops in the model are custom ops, See instructions to implement "
+          "custom ops: https://www.tensorflow.org/lite/guide/ops_custom" +
           failed_custom_ops_summary;
 
     auto& failed_region = named_regions[first_failed_func];
@@ -1776,7 +1777,7 @@
   }
 
   std::string model_description;
-  if (auto attr = module_.getAttrOfType<StringAttr>("tfl.description")) {
+  if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
     model_description = attr.getValue().str();
   } else {
     model_description = "MLIR Converted.";
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index a4339a1..5950602 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -51,12 +51,12 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -509,7 +509,7 @@
     return op.getOperation();
   }
   auto op = builder.create<tfl::ConstOp>(loc, value);
-  if (!tensor.quantization->min.empty()) {
+  if (tensor.quantization && !tensor.quantization->min.empty()) {
     if (auto stats_op =
             ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
       return stats_op;
@@ -1012,7 +1012,7 @@
       attributes.push_back(BuildTFEntryFunctionAttribute(
           subgraph, &builder, "outputs", func_outputs));
     }
-    func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
+    func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
   } else {
     func.setPrivate();
   }
@@ -1170,11 +1170,11 @@
   auto module = mlir::ModuleOp::create(base_loc);
   // We currently don't use this to make decisions, but we could
   // use it in exports or if there are breaking changes
-  module.setAttr("tfl.schema_version",
-                 builder.getI32IntegerAttr(model->version));
+  module->setAttr("tfl.schema_version",
+                  builder.getI32IntegerAttr(model->version));
   if (!model->description.empty()) {
-    module.setAttr("tfl.description",
-                   builder.getStringAttr(model->description));
+    module->setAttr("tfl.description",
+                    builder.getStringAttr(model->description));
   }
 
   for (auto e : llvm::enumerate(model->subgraphs)) {
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
index 60fd116..df9ddaf 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
@@ -23,7 +23,7 @@
 #include "llvm/ADT/StringSwitch.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index bcd3243..901199e 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -23,10 +23,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 6f45636..9d40f4c 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -31,11 +31,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -1990,14 +1990,24 @@
     return nullptr;
   }
 
-  const bool is_input_unsigned = operand_element_type.isUnsigned();
+  const bool is_unsigned = operand_element_type.isUnsigned();
+  const bool involves_bool = operand_element_type.getWidth() == 1 ||
+                             result_element_type.getWidth() == 1;
   const int output_bitwidth = result_element_type.getWidth();
   // The integer cast op is the same as C integer cast. Depends on the operand
   // type's signedness, we will determine whether or not sign extension is
   // needed.
   auto cast = [&](APInt value) {
-    return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
-                             : value.sextOrTrunc(output_bitwidth);
+    if (involves_bool) {
+      // Handle casts that involve a boolean explicitly, since they do not
+      // behave like plain sign extension or truncation: a true input should
+      // be cast to 1 rather than the -1 that sign extension would produce
+      // for a signed output, and any non-zero input should be cast to true,
+      // whereas truncating an even number to one bit would yield false.
+      return APInt(result_element_type.getWidth(), value != 0);
+    }
+    return is_unsigned ? value.zextOrTrunc(output_bitwidth)
+                       : value.sextOrTrunc(output_bitwidth);
   };
 
   return elements_attr.mapValues(result_element_type, cast);
@@ -2447,6 +2457,8 @@
   if (value.isa<OpaqueElementsAttr>() ||
       (value.isa<ElementsAttr>() && value.getType() != type))
     return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
+  if (ConstantOp::isBuildableWith(value, type))
+    return builder.create<ConstantOp>(loc, type, value);
   return nullptr;
 }
 
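For context, the boolean-aware cast rule in the tfl_ops.cc hunk above can be exercised in isolation with LLVM's APInt. The following is a standalone sketch, not part of this change, with values chosen only for illustration:

// Shows why an i1-involving cast cannot reuse plain sign extension or
// truncation: true would fold to -1, and 2 would fold to false.
#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  llvm::APInt true_val(/*numBits=*/1, 1);
  // Sign-extending the 1-bit value 1 to 32 bits yields all ones, i.e. -1.
  assert(true_val.sextOrTrunc(32).getSExtValue() == -1);
  // The boolean-aware rule instead folds a true input to 1.
  assert(llvm::APInt(32, true_val != 0).getSExtValue() == 1);

  llvm::APInt two(/*numBits=*/32, 2);
  // Plain truncation of 2 to one bit drops the set bit and gives 0 ...
  assert(two.trunc(1).isNullValue());
  // ... while the boolean-aware rule maps any non-zero input to true.
  assert(llvm::APInt(1, two != 0).getBoolValue());
  return 0;
}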
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h
index 589f18d..74fb98a 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h
@@ -22,9 +22,9 @@
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 4ca957c..c3930fc 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -354,9 +354,9 @@
               TFL_TCresVTEtIsSameAsOp<0, 0>>;
 
 // This is a quantization-aware version of TCopVTEtAreSameAt
-class TFL_TCopVTEtAreSameAt<int i, int j> : Or<[
+class TFL_TCopVTEtAreSameAt<int i, int j, int num=8> : Or<[
   TCopVTEtAreSameAt<[i, j]>,
-  TFL_TFOperandTypesWithSameBits<i, j, 8>,
+  TFL_TFOperandTypesWithSameBits<i, j, num>,
   And<[
     SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
       quant_QuantizedType.predicate>,
@@ -369,11 +369,8 @@
 // TFL op common constraints.
 //===----------------------------------------------------------------------===//
 
-// This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
-// Binary ops lhs & rhs should have the same value type, and is capable to
-// compare quantization types as well.
-def BinaryOpSameElementTypeConstraint :
-  PredOpTrait<"operands have same element type",
+class OperandsSameElementTypeConstraintBase<string op> :
+  PredOpTrait<op # " operands have same element type",
     Or<[
       TCopVTEtIsSameAs<0, 1>,
       // Two operands' values are both quantized and their type have the same
@@ -386,6 +383,18 @@
               "quant::QuantizedType::castToStorageType("
                   "getElementTypeOrSelf($_op.getOperand(1)))">]>]>>;
 
+// This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
+// A binary op's lhs & rhs should have the same value type; the constraint
+// can compare quantized element types as well.
+def BinaryOpSameElementTypeConstraint :
+  OperandsSameElementTypeConstraintBase<"binary op">;
+
+// This is a constraint for most of the comparison ops, e.g., equal, not_equal,
+// greater, greater_equal, less, etc. A comparison op's lhs & rhs should have
+// the same value type; the constraint also handles quantized element types.
+def ComparisonOpSameElementTypeConstraint :
+  OperandsSameElementTypeConstraintBase<"comparison op">;
+
 //===----------------------------------------------------------------------===//
 // TFL common builders.
 //===----------------------------------------------------------------------===//
@@ -1100,7 +1109,7 @@
 // Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
 def TFL_LessEqualOp : TFL_Op<"less_equal", [
     ResultsBroadcastableShape,
-    BinaryOpSameElementTypeConstraint,
+    ComparisonOpSameElementTypeConstraint,
     TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     NoSideEffect]> {
   let summary = "Less_equal operator";
@@ -1164,6 +1173,7 @@
 def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
     TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     ResultsBroadcastableShape,
+    ComparisonOpSameElementTypeConstraint,
     NoSideEffect]> {
   let summary = "Greater_equal operator";
 
@@ -1355,7 +1365,7 @@
 
 def TFL_NotEqualOp : TFL_Op<"not_equal", [
     TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
-    BinaryOpSameElementTypeConstraint,
+    ComparisonOpSameElementTypeConstraint,
     ResultsBroadcastableShape,
     Commutative,
     NoSideEffect,
@@ -1459,9 +1469,10 @@
 
 def TFL_EqualOp: TFL_Op<"equal", [
     Commutative,
+    NoSideEffect,
     ResultsBroadcastableShape,
     TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
-    PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
+    ComparisonOpSameElementTypeConstraint]> {
   let summary = "Equal operator";
 
   let description = [{
@@ -1668,7 +1679,7 @@
 
 def TFL_GreaterOp : TFL_Op<"greater", [
     ResultsBroadcastableShape,
-    BinaryOpSameElementTypeConstraint,
+    ComparisonOpSameElementTypeConstraint,
     TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     NoSideEffect]> {
   let summary = "Greater operator";
@@ -1767,7 +1778,7 @@
 
 def TFL_LessOp : TFL_Op<"less", [
     ResultsBroadcastableShape,
-    BinaryOpSameElementTypeConstraint,
+    ComparisonOpSameElementTypeConstraint,
     TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     NoSideEffect]> {
   let summary = "Less operator";
@@ -3676,7 +3687,8 @@
 
 def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
   "the optional peephole weights should all be specified or none",
-  TCopVTEtAreSameAt<[9, 10, 11]>>;
+  And<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
+       TFL_TCopVTEtAreSameAt<9, 11, 16>]>>;
 
 def LstmProjectionWeightBiasConstraint : PredOpTrait<
   "either projection weight must be specified or both projection weight and "
@@ -3982,6 +3994,7 @@
                        30, 31, 32, 35, 36, 37, 38]>,
     Neg<TypeIsPred<"input", F32>>]>>;
 
+// TODO(b/172517537): support quantized types
 def BidiLstmOptionalPeepholeWeightConstraint : PredOpTrait<
   "the optional peephole weights should all be specified or none",
   TCopVTEtAreSameAt<[9, 10, 11, 26, 27, 28]>>;
diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD
index caa5605..dd7f68f 100644
--- a/tensorflow/compiler/mlir/lite/python/BUILD
+++ b/tensorflow/compiler/mlir/lite/python/BUILD
@@ -22,6 +22,7 @@
         "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index 1fa3f04..28cb34d 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -21,8 +21,8 @@
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
@@ -51,7 +51,7 @@
   mlir::FuncOp entry_function = nullptr;
   for (auto func : module->get().getOps<mlir::FuncOp>()) {
     if (auto tf_attrs =
-            func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
+            func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
       // TODO(jaesung): There could be multiple entry functions. Let's handle
       // such cases if there are any needs for that.
       if (entry_function != nullptr) {
@@ -67,7 +67,7 @@
 
   // Get the list of input Op names from the function attribute.
   mlir::DictionaryAttr tf_attrs =
-      entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+      entry_function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
   llvm::SmallVector<llvm::StringRef, 4> function_input_names;
   function_input_names.reserve(model_flags.input_arrays().size());
   auto input_attr = tf_attrs.get("inputs");
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index ecc1d2e..c5562c1 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -30,6 +30,7 @@
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -302,6 +303,7 @@
 
   mlir::PassManager pm(module->getContext(),
                        mlir::OpPassManager::Nesting::Implicit);
+  ::tensorflow::SetCrashReproducer(pm);
 
   tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
   // Convert back to outlined while format for export back to flatbuffer.
diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc
index 6b5c894..09dae97 100644
--- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc
@@ -37,10 +37,10 @@
 
 DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
   f32_ = FloatType::getF32(ctx_);
-  i8_ = IntegerType::get(k8Bits, ctx_);
+  i8_ = IntegerType::get(ctx_, k8Bits);
   i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
   i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
-  i32_ = IntegerType::get(k32Bits, ctx_);
+  i32_ = IntegerType::get(ctx_, k32Bits);
   i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
   i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
   any_ = AnyQuantizedType();
@@ -131,8 +131,8 @@
   output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));
 
   // output ranges
-  auto min = rop.getAttrOfType<FloatAttr>("min");
-  auto max = rop.getAttrOfType<FloatAttr>("max");
+  auto min = rop->getAttrOfType<FloatAttr>("min");
+  auto max = rop->getAttrOfType<FloatAttr>("max");
   output_ranges->push_back(quant::CalculateQuantizedRange(
       o_spec.getScale(), o_spec.getZeroPoint(),
       (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
@@ -166,8 +166,8 @@
   if (!o_spec) return failure();
 
   // output ranges
-  auto min = rop.getAttrOfType<FloatAttr>("min");
-  auto max = rop.getAttrOfType<FloatAttr>("max");
+  auto min = rop->getAttrOfType<FloatAttr>("min");
+  auto max = rop->getAttrOfType<FloatAttr>("max");
   output_ranges->push_back(quant::CalculateQuantizedRange(
       o_spec.getScale(), o_spec.getZeroPoint(),
       (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h
index 8ed4315..936fe6b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/device_target.h
+++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h
@@ -29,8 +29,8 @@
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 7e7d467..1c86d1a 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -29,9 +29,9 @@
 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc
index 6b226fa..5c78a53 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc
@@ -49,7 +49,7 @@
                                           dq.arg());
       dq.getResult().replaceAllUsesWith(dcast);
       if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
-        dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
+        dcast->setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
       }
       dq.erase();
     } else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
@@ -58,7 +58,7 @@
                                         TypeAttr::get(out_type));
       q.getResult().replaceAllUsesWith(qcast);
       if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
-        qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
+        qcast->setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
       }
       q.erase();
     }
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
index fc0e763..6078a20 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
@@ -28,10 +28,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -135,7 +135,7 @@
         input_specs.push_back(TypeAttr::get(state.params));
       }
     }
-    op.setAttr("input_specs", ArrayAttr::get(input_specs, context));
+    op->setAttr("input_specs", ArrayAttr::get(input_specs, context));
 
     llvm::SmallVector<Attribute, 4> output_specs;
     auto original_output_specs = op.output_specs().getValue();
@@ -150,7 +150,7 @@
         output_specs.push_back(TypeAttr::get(state.params));
       }
     }
-    op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
+    op->setAttr("output_specs", ArrayAttr::get(output_specs, context));
   });
   return success();
 }
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index d847a7d..8330da8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -30,10 +30,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
@@ -532,7 +532,7 @@
   // quantization pass. These ops can be removed without losing original
   // program accuracy.
   // TODO(fengliuai): make the attribute being part of op definition.
-  quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
+  quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
 
   // `original_result` has a use to `quantize`, so this will replace that use
   // by the result of `dequantize`. Remember to reset that use afterwards
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
index 9991d10..c8ac2c2 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
@@ -29,8 +29,8 @@
 #include "mlir/Dialect/Quant/QuantizeUtils.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 8e2787d..12e1dc6 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -33,10 +33,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
@@ -268,7 +268,17 @@
       OperationState new_state(quantized_op->getLoc(),
                                quantized_op->getName().getStringRef(), inputs,
                                output_types, quantized_op->getAttrs());
+      for (int i = 0; i < quantized_op->getNumRegions(); ++i) {
+        new_state.addRegion();
+      }
       Operation* new_op = rewriter.createOperation(new_state);
+      if (quantized_op->getNumRegions() != 0) {
+        for (auto indexed_regions :
+             llvm::enumerate(quantized_op->getRegions())) {
+          new_op->getRegion(indexed_regions.index())
+              .takeBody(indexed_regions.value());
+        }
+      }
       for (auto output : outputs_replaced) {
         output.getFirst().replaceAllUsesWith(
             new_op->getResult(output.getSecond()));
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
index d9d4d44..ce22419 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
@@ -68,7 +68,7 @@
 }
 
 // CHECK-LABEL: fakeQuantWithConv2D
-func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf32>) {
 ^bb0(%arg: tensor<256x32x32x3xf32>) :
   %in = constant dense<0.0> : tensor<3x3x3x16xf32>
   %min = constant dense<0.0> : tensor<f32>
@@ -76,8 +76,8 @@
   %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
   %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
   %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
-  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %rst : tensor<256x30x30x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  return %rst : tensor<256x8x7x16xf32>
 
 // CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
 // CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
@@ -87,7 +87,7 @@
 }
 
 // CHECK-LABEL: perChannelFakeQuantWithConv2D
-func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf32>) {
 ^bb0(%arg: tensor<256x32x32x3xf32>) :
   %in = constant dense<0.0> : tensor<3x3x3x16xf32>
   %min = constant dense<0.0> : tensor<16xf32>
@@ -95,8 +95,8 @@
   %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
   %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
   %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
-  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %rst : tensor<256x30x30x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  return %rst : tensor<256x8x7x16xf32>
 
 // CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
 // CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
@@ -104,7 +104,7 @@
 // CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
 // CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
-// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
+// CHECK: return %[[CONV]] : tensor<256x8x7x16xf32>
 }
 
 // CHECK-LABEL: fakeQuantWithDepthwiseConv2D
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index ff92f64..89992e8 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -687,3 +687,42 @@
   // CHECK: return %arg0 : tensor<7xf32>
 }
 
+// CHECK-LABEL: @cast_i1_to_i8
+func @cast_i1_to_i8() -> tensor<2xi8> {
+  %cst = constant dense<[false, true]> : tensor<2xi1>
+  %0 = "tfl.cast"(%cst) : (tensor<2xi1>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, 1]> : tensor<2xi8>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_i1_to_ui8
+func @cast_i1_to_ui8() -> tensor<2xui8> {
+  %cst = constant dense<[false, true]> : tensor<2xi1>
+  %0 = "tfl.cast"(%cst) : (tensor<2xi1>) -> tensor<2xui8>
+  return %0 : tensor<2xui8>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, 1]> : tensor<2xui8>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_i8_to_i1
+func @cast_i8_to_i1() -> tensor<4xi1> {
+  %cst = constant dense<[0, 1, 2, -1]> : tensor<4xi8>
+  %0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi1>
+  return %0 : tensor<4xi1>
+
+// CHECK: %[[CST:.*]] = constant dense<[false, true, true, true]> : tensor<4xi1>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_ui8_to_i1
+func @cast_ui8_to_i1() -> tensor<4xi1> {
+  %cst = constant dense<[0, 127, 128, 255]> : tensor<4xui8>
+  %0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi1>
+  return %0 : tensor<4xi1>
+
+// CHECK: %[[CST:.*]] = constant dense<[false, true, true, true]> : tensor<4xi1>
+// CHECK:  return %[[CST]]
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
index 93a1535..9c3543e 100644
--- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
@@ -1,30 +1,30 @@
 // RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s
 
-func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
   %cst_0 = constant dense<4> : tensor<2x2xi32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
   %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-  %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
-  return %2 : tensor<1x128x128x8xf32>
+  %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32>
+  return %2 : tensor<1x120x120x8xf32>
 
   // CHECK-LABEL: testDilatedConv
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
-  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
 }
 
-func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
   %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
-  return %2 : tensor<1x128x128x8xf32>
+  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32>
+  return %2 : tensor<1x120x120x8xf32>
 
   // CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
-  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
 }
 
 func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
@@ -47,8 +47,8 @@
   %cst_0 = constant dense<4> : tensor<2x2xi32>
   %cst_1 = constant dense<0> : tensor<2x2xi32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
-  %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-  %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+  %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x60x60x8xf32>
+  %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x60x60x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
   return %2 : tensor<1x128x128x8xf32>
 
   // CHECK-LABEL: testDilatedConvWithNonTrivialDilations
@@ -245,7 +245,7 @@
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
   %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
   %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
-  %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
+  %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
   %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   return %6 : tensor<1x128x128xf32>
 
@@ -253,7 +253,7 @@
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
   // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
@@ -358,13 +358,13 @@
   // CHECK-NEXT: return [[RESULT]]
 }
 
-func @testNoDilatedConvWhenGivenInputIsNonFloatType(%arg0: tensor<1x128x128x3xi32>, %arg1: tensor<5x5x3x8xi32>) -> tensor<1x128x128x8xi32> {
+func @testNoDilatedConvWhenGivenInputIsNonFloatType(%arg0: tensor<1x128x128x3xi32>, %arg1: tensor<5x5x3x8xi32>) -> tensor<1x120x120x8xi32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
   %cst_0 = constant dense<4> : tensor<2x2xi32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xi32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xi32>
   %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xi32>, tensor<5x5x3x8xi32>) -> tensor<4x64x64x8xi32>
-  %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xi32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xi32>
-  return %2 : tensor<1x128x128x8xi32>
+  %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xi32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xi32>
+  return %2 : tensor<1x120x120x8xi32>
 
   // CHECK-LABEL: testNoDilatedConvWhenGivenInputIsNonFloatType
   // CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
diff --git a/tensorflow/compiler/mlir/lite/tests/inlining.mlir b/tensorflow/compiler/mlir/lite/tests/inlining.mlir
index c494b8c..c0fa4b1 100644
--- a/tensorflow/compiler/mlir/lite/tests/inlining.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/inlining.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s
+// RUN: tf-opt %s -inline='default-pipeline=''' | FileCheck %s
 
 // Inline a function that contains only tfl ops.
 func @func_with_tfl_ops(%arg0 : tensor<2xi32>) -> tensor<2xi32> {
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir
index c5bf39c..89dc807 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir
@@ -1,5 +1,5 @@
 // RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s
-// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline="disable-simplify" | FileCheck %s --check-prefix=INLINE
+// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline='default-pipeline=''' | FileCheck %s --check-prefix=INLINE
 // RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --check-prefix=CANON
 
 func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir
index 1be7db1..913e128 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_custom.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: error: 'tf.MyCustomOp' op is neither a custom op nor a flex op
 // CHECK: error: failed while converting: 'main'
-// CHECK: Ops that need custom implementation (enabled via setting the -emit-custom-ops flag):
+// CHECK: Some ops in the model are custom ops, See instructions to implement
 // CHECK: tf.MyCustomOp {name = "MyCustomOp"}
 
 func @main(tensor<4xf32>) -> tensor<4xf32> {
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir
index e767dc0..8e36c52 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: error: 'tf.Div' op is neither a custom op nor a flex op
 // CHECK: error: failed while converting: 'main'
-// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag):
+// CHECK: Some ops are not supported by the native TFLite runtime
 // CHECK: tf.Div {name = "div"}
 
 func @main(tensor<4xf32>) -> tensor<4xf32> {
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 0719395..18ba99f 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -357,7 +357,7 @@
 
 func @testMulInvalidOperands(tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32> {
 ^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x i32>):
-  // expected-error @+1 {{failed to verify that operands have same element type}}
+  // expected-error @+1 {{failed to verify that binary op operands have same element type}}
   %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32>
   return %0#0 : tensor<? x i32>
 }
@@ -366,7 +366,7 @@
 
 func @testMulInvalidQuantizedOperands(tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>> {
 ^bb0(%arg0: tensor<* x !quant.any<i16:f32>>, %arg1: tensor<* x !quant.any<i8:f32>>):
-  // expected-error @+1 {{failed to verify that operands have same element type}}
+  // expected-error @+1 {{failed to verify that binary op operands have same element type}}
   %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>>
   return %0#0 : tensor<* x !quant.any<i16:f32>>
 }
@@ -412,7 +412,7 @@
 // -----
 
 func @testFloorDivF32(%arg0: tensor<2 x f32>, %arg1: tensor<2 x i32>) -> tensor<2 x f32> {
-  // expected-error @+1 {{failed to verify that operands have same element type}}
+  // expected-error @+1 {{failed to verify that binary op operands have same element type}}
   %0 = "tfl.floor_div"(%arg0, %arg1) : (tensor<2 x f32>, tensor<2 x i32>) -> tensor<2 x f32>
   return %0#0 : tensor<2 x f32>
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
index 477efbb..6e845af 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
@@ -583,3 +583,68 @@
   return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
 }
 }
+
+// -----
+
+module {
+func @max_unpooling_2d(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>} {
+  %0 = "tf.Const"() {value = dense<[4, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<0> : tensor<1x1x2x1xi32>} : () -> tensor<1x1x2x1xi32>
+  %3 = "tf.Const"() {value = dense<[1, 2, 4, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %4 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
+  %5 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %6 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %7 = "tf.FloorDiv"(%arg1, %5) {device = ""} : (tensor<1x1x2x1xi32>, tensor<i32>) -> tensor<1x1x2x1xi32>
+  %8 = "tf.FloorMod"(%7, %4) {device = ""} : (tensor<1x1x2x1xi32>, tensor<i32>) -> tensor<1x1x2x1xi32>
+  %9 = "tf.FloorDiv"(%arg1, %4) {device = ""} : (tensor<1x1x2x1xi32>, tensor<i32>) -> tensor<1x1x2x1xi32>
+  %10 = "tf.Pack"(%2, %9, %8, %2) {axis = 0 : i64, device = ""} : (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) -> tensor<4x1x1x2x1xi32>
+  %11 = "tf.Reshape"(%10, %0) {device = ""} : (tensor<4x1x1x2x1xi32>, tensor<2xi32>) -> tensor<4x2xi32>
+  %12 = "tf.Transpose"(%11, %6) {device = ""} : (tensor<4x2xi32>, tensor<2xi32>) -> tensor<2x4xi32>
+  %13 = "tf.Reshape"(%arg0, %1) {device = ""} : (tensor<1x1x2x1xf32>, tensor<1xi32>) -> tensor<2xf32>
+  %14 = "tf.ScatterNd"(%12, %13, %3) {device = ""} : (tensor<2x4xi32>, tensor<2xf32>, tensor<4xi32>) -> tensor<1x2x4x1xf32>
+  %15 = "tf.Identity"(%14) {device = ""} : (tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %15 : tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: func @max_unpooling_2d(
+// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x1x2x1xf32>,
+// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = "MaxUnpooling2D"} {
+// CHECK-NEXT:    %[[VAL_2:.*]] = "tfl.custom"(%[[VAL_0]], %[[VAL_1]]) {custom_code = "MaxUnpooling2D", custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>} : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32>
+// CHECK-NEXT:    return %[[VAL_2]] : tensor<1x2x4x1xf32>
+// CHECK-NEXT:  }
+}
+
+// -----
+
+module {
+// expected-error @+1 {{Invalid number of results from MaxUnpooling2D}}
+func private @max_unpooling_2d_invalid_results(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> (tensor<1x2x4x1xf32>, tensor<1x2x4x1xi32>) attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>}
+
+// expected-error @+1 {{Invalid number of arguments to MaxUnpooling2D}}
+func private @max_unpooling_2d_invalid_args(%arg0: tensor<1x1x2x1xf32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>}
+
+// expected-error @+1 {{Padding for MaxUnpooling2D must be 'SAME' or 'VALID'}}
+func private @max_unpooling_2d_wrong_padding(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "NO", pool_size = [2, 2], strides = [2, 2]}>}
+
+// expected-error @+1 {{'pool_size' attribute for MaxUnpooling2D must be set and has size of 2}}
+func private @max_unpooling_2d_wrong_filter(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2], strides = [2, 2]}>}
+
+// expected-error @+1 {{'strides' attribute for MaxUnpooling2D must be set and has size of 2}}
+func private @max_unpooling_2d_wrong_strides(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2, 2]}>}
+
+// expected-error @+1 {{'padding' attribute for MaxUnpooling2D is not set or not a string}}
+func private @max_unpooling_2d_no_padding(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {pool_size = [2, 2], strides = [2, 2]}>}
+
+// expected-error @+1 {{'pool_size' attribute for MaxUnpooling2D must be set and has size of 2}}
+func private @max_unpooling_2d_no_filter(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", strides = [2, 2]}>}
+
+// expected-error @+1 {{'strides' attribute for MaxUnpooling2D must be set and has size of 2}}
+func private @max_unpooling_2d_no_strides(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2]}>}
+
+// expected-error @+1 {{'pool_size' attribute for MaxUnpooling2D does not contain integer values}}
+func private @max_unpooling_2d_filter_wrong_type(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = ["a", "b"], strides = [2, 2]}>}
+
+// expected-error @+1 {{'strides' attribute for MaxUnpooling2D does not contain integer values}}
+func private @max_unpooling_2d_strides_wrong_type(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = ["2", "2"]}>}
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
index 5c06bb4..cf8712b 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
@@ -1,13 +1,14 @@
-// RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s
+// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-signed -tfl-test-post-training-quantize | FileCheck %s
 
 // CHECK-LABEL: QuantizeLstmCellInput
 func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
-    %cst_1 = constant dense<1.0> : tensor<1x20xf32>
     %cst_2 = constant unit
     %cst_3 = constant dense<1.0> : tensor<20x20xf32>
     %cst_7 = constant dense<1.0> : tensor<20xf32>
     %cst_11 = constant dense<1.0> : tensor<20x28xf32>
-    %cell_input = constant dense<0.0> : tensor<1x20xf32>
+    %recurrent_input = constant dense<1.0> : tensor<1x20xf32>
+    %recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
+    %cell_input = constant dense<1.0> : tensor<1x20xf32>
     %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
     %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
       %cst_11, %cst_11, %cst_11, %cst_11,
@@ -15,7 +16,7 @@
       %cst_2, %cst_2, %cst_2,
       %cst_7, %cst_7, %cst_7, %cst_7,
       %cst_2, %cst_2,
-      %cst_1, %cell_stats,
+      %recurrent_stats, %cell_stats,
       %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
     : ( tensor<1x28x28xf32>,
         tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
@@ -25,17 +26,20 @@
         none, none,
         tensor<1x20xf32>, tensor<1x20xf32>,
         none, none, none, none) -> tensor<1x28x20xf32>
-    return %0 : tensor<1x28x20xf32>
+    %1 = "quant.stats"(%0) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<1x28x20xf32>) -> tensor<1x28x20xf32>
+    return %1 : tensor<1x28x20xf32>
 // CHECK: %[[none:.*]] = constant unit
-// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
+// CHECK: %[[cell_input:.*]] = constant dense<1.000000e+00> : tensor<1x20xf32>
 // CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
 // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
 // Checks if input 19 is correctly passed from a dequantize op.
 // CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
 }
 
-// CHECK-LABEL: QuantizeIntermediates
-func @QuantizeIntermediates(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+// CHECK-LABEL: QuantizeWithoutNorm
+func @QuantizeWithoutNorm(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %none = constant unit
+  %input = "quant.stats"(%arg0) {layerStats = dense<[-1.2, 1.5]> : tensor<2xf32>} : (tensor<1x5xf32>) -> tensor<1x5xf32>
   %0 = "tfl.pseudo_const"() {value = dense<[[1.31760073, -0.78338623, 0.287265539, -0.383972764, -0.00321021513], [0.104248755, 1.07823908, 0.138089031, 0.76123321, -1.4124943]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
   %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
   %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
@@ -53,19 +57,100 @@
   %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
   %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
   %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
-  %17 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
-  %18 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
   %19 = "tfl.pseudo_const"() {value = dense<[0.928654432, -0.393729329]> : tensor<2xf32>} : () -> tensor<2xf32>
   %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
   %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
   %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
-  %23 = "tfl.unidirectional_sequence_lstm"(%arg0,
+  %23 = "tfl.unidirectional_sequence_lstm"(%input,
     %0, %1, %2, %3,
     %4, %5, %6, %7,
     %8, %9, %10,
     %11, %12, %13, %14,
     %15, %16,
-    %17, %18,
+    %recurrent_stats, %cell_stats,
+    %none, %none, %none, %none) {cell_clip = 5.000000e+01 : f32,
+      effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
+      fused_activation_function = "TANH",
+      proj_clip = 0.000000e+00 : f32, time_major = false} : (
+        tensor<1x5xf32>,
+        tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
+        tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<4x2xf32>, tensor<4xf32>,
+        tensor<1x4xf32>, tensor<1x2xf32>,
+        none, none, none, none) -> tensor<*xf32>
+  %24 = "quant.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
+  return %24 : tensor<*xf32>
+
+// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform<i8:f32, 0.010588235481112611:-15>>) -> tensor<1x5xf32>
+// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011122002376346137>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011122002376346137>>
+// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>
+// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>
+// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>
+// CHECK-DAG: %[[input_5:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.013025231248750461>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.013025231248750461>>
+// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>
+// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>
+// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>
+// CHECK-DAG: %[[input_9:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.3700684138124656E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.3700684138124656E-5>>
+// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>
+// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>
+// CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.1776238018224695E-4>>) -> tensor<2x!quant.uniform<i32:f32, 1.1776238018224695E-4>>
+// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.9420648637759368E-4>>) -> tensor<2x!quant.uniform<i32:f32, 1.9420648637759368E-4>>
+// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.1827185827747504E-4>>) -> tensor<2x!quant.uniform<i32:f32, 1.1827185827747504E-4>>
+// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.8229184289320815E-4>>) -> tensor<2x!quant.uniform<i32:f32, 1.8229184289320815E-4>>
+// CHECK-DAG: %[[input_16:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>) -> tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>
+// CHECK-DAG: %[[input_17:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>) -> tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>
+// CHECK-DAG: %[[input_18:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i8:f32, 0.015686274509803921:-1>>) -> tensor<1x4xf32>
+// CHECK-DAG: %[[input_19:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x2xf32>
+
+// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]], %[[input_5]], %[[input_6]], %[[input_7]], %[[input_8]],
+// CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]]
+// CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>>
+
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
+}
+
+
+// CHECK-LABEL: QuantizeUnidirectionalLstmFull
+func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %input = "quant.stats"(%arg0) {layerStats = dense<[-1.2, 1.5]> : tensor<2xf32>} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+  %0 = "tfl.pseudo_const"() {value = dense<[[1.31760073, -0.78338623, 0.287265539, -0.383972764, -0.00321021513], [0.104248755, 1.07823908, 0.138089031, 0.76123321, -1.4124943]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %4 = "tfl.pseudo_const"() {value = dense<[[-1.65420437, 0.19633314, 0.828249216, -0.546153665], [-1.49073172, 1.6467551, 0.904948651, 1.1367631]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %8 = "tfl.pseudo_const"() {value = dense<[1.43194032, -0.553496838]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %11 = "tfl.pseudo_const"() {value = dense<[-0.323118627, 1.77580559]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
+  %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  %19 = "tfl.pseudo_const"() {value = dense<[0.928654432, -0.393729329]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %23 = "tfl.unidirectional_sequence_lstm"(%input,
+    %0, %1, %2, %3,
+    %4, %5, %6, %7,
+    %8, %9, %10,
+    %11, %12, %13, %14,
+    %15, %16,
+    %recurrent_stats, %cell_stats,
     %19, %20, %21, %22) {cell_clip = 5.000000e+01 : f32,
       effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
       fused_activation_function = "TANH",
@@ -82,10 +167,135 @@
         tensor<4x2xf32>, tensor<4xf32>,
         tensor<1x4xf32>, tensor<1x2xf32>,
         tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
-  return %23 : tensor<*xf32>
-// CHECK: effective_hidden_scale_intermediate = tensor<!quant.uniform<u8:f32, 0.0039215686274509803:128>>
-// CHECK: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
-// CHECK: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
-// CHECK: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
-// CHECK: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>,
+  %24 = "quant.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
+  return %24 : tensor<*xf32>
+
+// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform<i8:f32, 0.010588235481112611:-15>>) -> tensor<1x5xf32>
+// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011122002376346137>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011122002376346137>>
+// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>
+// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>
+// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>
+// CHECK-DAG: %[[input_5:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.013025231248750461>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.013025231248750461>>
+// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>
+// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>
+// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>
+// CHECK-DAG: %[[input_9:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.3700684138124656E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.3700684138124656E-5>>
+// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>
+// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>
+// CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.7676903410132078E-8>>) -> tensor<2x!quant.uniform<i32:f32, 2.7676903410132078E-8>>
+// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6601474818224132E-8>>) -> tensor<2x!quant.uniform<i32:f32, 2.6601474818224132E-8>>
+// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 5.0222583101003261E-8>>) -> tensor<2x!quant.uniform<i32:f32, 5.0222583101003261E-8>>
+// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6725777405118232E-8>>) -> tensor<2x!quant.uniform<i32:f32, 2.6725777405118232E-8>>
+// CHECK-DAG: %[[input_16:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>) -> tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>
+// CHECK-DAG: %[[input_17:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>) -> tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>
+// CHECK-DAG: %[[input_18:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i8:f32, 0.015686274509803921:-1>>) -> tensor<1x4xf32>
+// CHECK-DAG: %[[input_19:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x2xf32>
+// CHECK-DAG: %[[input_20:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.8341149091975248E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.8341149091975248E-5>>
+// CHECK-DAG: %[[input_21:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>
+// CHECK-DAG: %[[input_22:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>
+// CHECK-DAG: %[[input_23:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>
+
+// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]], %[[input_5]], %[[input_6]], %[[input_7]], %[[input_8]],
+// CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
+// CHECK-SAME: %[[input_20]], %[[input_21]], %[[input_22]], %[[input_23]])
+// CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>>
+// CHECK-SAME: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
+// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
+// CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
+// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
+
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
 }
+
+// CHECK-LABEL: QuantizeLstmFull
+func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %input = "quant.stats"(%arg0) {layerStats = dense<[-1.2, 1.5]> : tensor<2xf32>} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+  %0 = "tfl.pseudo_const"() {value = dense<[[1.31760073, -0.78338623, 0.287265539, -0.383972764, -0.00321021513], [0.104248755, 1.07823908, 0.138089031, 0.76123321, -1.4124943]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %4 = "tfl.pseudo_const"() {value = dense<[[-1.65420437, 0.19633314, 0.828249216, -0.546153665], [-1.49073172, 1.6467551, 0.904948651, 1.1367631]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %8 = "tfl.pseudo_const"() {value = dense<[1.43194032, -0.553496838]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %11 = "tfl.pseudo_const"() {value = dense<[-0.323118627, 1.77580559]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
+  %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  %19 = "tfl.pseudo_const"() {value = dense<[0.928654432, -0.393729329]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %23 = "tfl.lstm"(%input,
+    %0, %1, %2, %3,
+    %4, %5, %6, %7,
+    %8, %9, %10,
+    %11, %12, %13, %14,
+    %15, %16,
+    %recurrent_stats, %cell_stats,
+    %19, %20, %21, %22) ({}) {
+      cell_clip = 5.000000e+01 : f32,
+      effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
+      fused_activation_function = "TANH",
+      input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>,
+      input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>,
+      input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>,
+      input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>,
+      proj_clip = 0.000000e+00 : f32, time_major = false} : (
+        tensor<1x5xf32>,
+        tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
+        tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<4x2xf32>, tensor<4xf32>,
+        tensor<1x4xf32>, tensor<1x2xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
+  %24 = "quant.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
+  return %24 : tensor<*xf32>
+
+// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform<i8:f32, 0.010588235481112611:-15>>) -> tensor<1x5xf32>
+// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011122002376346137>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011122002376346137>>
+// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>
+// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>
+// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>) -> tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>
+// CHECK-DAG: %[[input_5:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.013025231248750461>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.013025231248750461>>
+// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>
+// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>
+// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>) -> tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>
+// CHECK-DAG: %[[input_9:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.3700684138124656E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.3700684138124656E-5>>
+// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>
+// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>
+// CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.7676903410132078E-8>>) -> tensor<2x!quant.uniform<i32:f32, 2.7676903410132078E-8>>
+// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6601474818224132E-8>>) -> tensor<2x!quant.uniform<i32:f32, 2.6601474818224132E-8>>
+// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 5.0222583101003261E-8>>) -> tensor<2x!quant.uniform<i32:f32, 5.0222583101003261E-8>>
+// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6725777405118232E-8>>) -> tensor<2x!quant.uniform<i32:f32, 2.6725777405118232E-8>>
+// CHECK-DAG: %[[input_16:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>) -> tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>
+// CHECK-DAG: %[[input_17:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>) -> tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>
+// CHECK-DAG: %[[input_18:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i8:f32, 0.015686274509803921:-1>>) -> tensor<1x4xf32>
+// CHECK-DAG: %[[input_19:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x2xf32>
+// CHECK-DAG: %[[input_20:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.8341149091975248E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.8341149091975248E-5>>
+// CHECK-DAG: %[[input_21:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>
+// CHECK-DAG: %[[input_22:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>
+// CHECK-DAG: %[[input_23:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>) -> tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>
+
+// CHECK: %[[lstm:.*]] = "tfl.lstm"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]], %[[input_5]], %[[input_6]], %[[input_7]], %[[input_8]],
+// CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
+// CHECK-SAME: %[[input_20]], %[[input_21]], %[[input_22]], %[[input_23]])
+// CHECK-NEXT: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>>
+// CHECK-SAME: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
+// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
+// CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
+// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
+
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
+}
+
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index 03c320c..c134192 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -588,12 +588,12 @@
 
   return %6 : tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
 
-// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
-// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<32x!quant.uniform<i32:f32, 1.000000e+00>>)
 // CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
 // CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
 // CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]]) : (tensor<32x!quant.uniform<i32:f32, 2.000000e+00>>)
+// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
+// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<32x!quant.uniform<i32:f32, 1.000000e+00>>)
 // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
 // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]])
 }
@@ -672,14 +672,15 @@
   %7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
   return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>
 
-// CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
-// CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
-// CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
-// CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
-// CHECK: %[[output1:.*]] = "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
-// CHECK: %[[output2:.*]] = "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
-// CHECK: %[[output3:.*]] = "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
-// CHECK: %[[output4:.*]] = "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK-DAG: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
+// CHECK-DAG: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
+// CHECK-DAG: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
+// CHECK-DAG: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
+// CHECK-NOT: BLOCK_DAG
+// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
 }
 
 // Make sure quantization parameters are scanned from weight, but not from bias.
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-with-allowing-bf16-type-legalization.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-with-allowing-bf16-type-legalization.mlir
index 6b67bb8..8beb319 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-with-allowing-bf16-type-legalization.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-with-allowing-bf16-type-legalization.mlir
@@ -3,9 +3,9 @@
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
 
 // CHECK-LABEL: conv_2d_bf16
-func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16> {
-  %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16>
-  return %0 : tensor<256x30x30x16xbf16>
+func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x8x7x16xbf16> {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x8x7x16xbf16>
+  return %0 : tensor<256x8x7x16xbf16>
   // CHECK: "tfl.conv_2d"
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index e54b596..f1b4678 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -3,29 +3,29 @@
 
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
 
-func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
+func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x8x7x16xf32>, tensor<256x16x32x32xf32>, tensor<256x8x6x16xf32>, tensor<256x32x32x16xf32>, tensor<256x32x32x16xf32>) {
 ^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<256x3x32x32xf32>) :
    // OK
-   %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+   %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
    // Unsupported data format
-   %1 = "tf.Conv2D"(%arg2, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32>
+   %1 = "tf.Conv2D"(%arg2, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x32x32xf32>
    // OK
-   %2 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC",                           padding = "VALID", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+   %2 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC",                           padding = "VALID", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x6x16xf32>
    // Unsupported padding
-   %3 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "EXPLICIT", strides = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+   %3 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "EXPLICIT", strides = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
    // Unsupported strides
-   %4 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+   %4 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
 
-  return %0, %1, %2, %3, %4 : tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
+  return %0, %1, %2, %3, %4 : tensor<256x8x7x16xf32>, tensor<256x16x32x32xf32>, tensor<256x8x6x16xf32>, tensor<256x32x32x16xf32>, tensor<256x32x32x16xf32>
 
 // CHECK-LABEL: conv
 // CHECK:  %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
 // CHECK:  %[[CONSTANT0:.*]] = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
 // CHECK:  %0 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
-// CHECK:  %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+// CHECK:  %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32>
 // CHECK:  %2 = "tf.Conv2D"
 // CHECK:  %3 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
-// CHECK:  %4 = "tfl.conv_2d"(%arg0, %3, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+// CHECK:  %4 = "tfl.conv_2d"(%arg0, %3, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x8x6x16xf32>
 // CHECK:  %5 = "tf.Conv2D"
 // CHECK:  %6 = "tf.Conv2D"
 }
@@ -54,9 +54,9 @@
 // CHECK:  %5 = "tf.DepthwiseConv2dNative"
 }
 
-func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> {
-  %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32>
-  return %0 : tensor<256x16x30x30xf32>
+func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x32x32xf32> {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x32x32xf32>
+  return %0 : tensor<256x16x32x32xf32>
 
   // LAYOUT-LABEL: Conv2dNCHW
   // LAYOUT: "tfl.conv_2d"
@@ -272,7 +272,7 @@
 }
 
 // CHECK-LABEL: fakeQuantWithConv2D
-func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf32>) {
 ^bb0(%arg: tensor<256x32x32x3xf32>) :
   %in = constant dense<0.0> : tensor<3x3x3x16xf32>
   %min = constant dense<0.0> : tensor<f32>
@@ -280,8 +280,8 @@
   %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
   %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
   %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
-  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %rst : tensor<256x30x30x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  return %rst : tensor<256x8x7x16xf32>
 
 // CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
 // CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
@@ -292,7 +292,7 @@
 }
 
 // CHECK-LABEL: perChannelFakeQuantWithConv2D
-func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf32>) {
 ^bb0(%arg: tensor<256x32x32x3xf32>) :
   %in = constant dense<0.0> : tensor<3x3x3x16xf32>
   %min = constant dense<0.0> : tensor<16xf32>
@@ -300,8 +300,8 @@
   %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
   %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
   %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
-  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %rst : tensor<256x30x30x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  return %rst : tensor<256x8x7x16xf32>
 
 // CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
 // CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
@@ -309,7 +309,7 @@
 // CHECK-SAME: {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>
 // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
 // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
-// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
+// CHECK: return %[[CONV]] : tensor<256x8x7x16xf32>
 }
 
 // CHECK-LABEL: fakeQuantWithDepthwiseConv2D
@@ -740,9 +740,9 @@
 }
 
 // CHECK-LABEL: conv_2d_bf16
-func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16> {
-  %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16>
-  return %0 : tensor<256x30x30x16xbf16>
+func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x8x7x16xbf16> {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x8x7x16xbf16>
+  return %0 : tensor<256x8x7x16xbf16>
   // CHECK: "tf.Conv2D"
 }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
index 451eb61..ad5dbc6 100644
--- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
@@ -17,8 +17,8 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "absl/memory/memory.h"
diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc
index 700f062..1dc06ca 100644
--- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc
@@ -20,7 +20,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
@@ -244,8 +244,8 @@
   tflite::optimize::sparsity::FormatConverter<T> format_converter(
       shape, traversal_order, format, b_size, b_map);
   format_converter.DenseToSparse(dense_buffer);
-  auto metadata = format_converter.GetDimMetadata();
-  auto compressed_data = format_converter.GetData();
+  const auto& metadata = format_converter.GetDimMetadata();
+  const auto& compressed_data = format_converter.GetData();
   const int dim_size = metadata.size() / 2;
   std::vector<Attribute> dim_metadata(traversal_order.size());
   for (int i = 0; i < dim_size; i++) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
index 2cd1152..88fcbf0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
+++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
@@ -22,9 +22,9 @@
 
 #include "llvm/Support/Casting.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
@@ -82,12 +82,12 @@
 LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
     Conv2dOpTy op, PatternRewriter& rewriter) const {
   // Make sure Conv2D has 'VALID' padding.
-  if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
+  if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
     return failure();
   }
   // Make sure dilations are all ones if set.
   const ArrayAttr& dilations =
-      op.template getAttrOfType<ArrayAttr>("dilations");
+      op->template getAttrOfType<ArrayAttr>("dilations");
   if (dilations && !TFIntListIsAllOnes(dilations)) {
     return failure();
   }
@@ -233,14 +233,14 @@
     for (auto it1 = paddings.begin(), it2 = crops.begin();
          it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
       if ((*it1).getInt() != (*it2).getInt()) {
-        op.setAttr("padding", rewriter.getStringAttr("SAME"));
+        op->setAttr("padding", rewriter.getStringAttr("SAME"));
         break;
       }
     }
   }
 
   // Set dilations
-  op.setAttr("dilations", dilations_attr.getValue());
+  op->setAttr("dilations", dilations_attr.getValue());
 
   if (expand_op) {
     // If there is `expand_op`, we need to rewire the inputs to bypass the
diff --git a/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc b/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc
index 6f41398..5c018ad 100644
--- a/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc
@@ -51,7 +51,7 @@
 
     for (auto func : module.getOps<FuncOp>()) {
       auto dict_attr =
-          func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+          func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
       if (!dict_attr) continue;
 
       OpBuilder builder(func.getContext());
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index febbd7d..802c84a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -33,11 +33,11 @@
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -212,7 +212,7 @@
   }
 
   *attr_i32 = IntegerAttr::get(
-      IntegerType::get(/*width=*/32, attr.getContext()), value);
+      IntegerType::get(attr.getContext(), /*width=*/32), value);
   return success();
 }
 
@@ -466,12 +466,11 @@
     attributes.push_back(
         rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
 
-    auto lstm_op = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
+    Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
         op->getLoc(), result_types, inputs, attributes);
 
     // Rewire the output.
-    op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
-    rewriter.eraseOp(op);
+    rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
     return success();
   }
 };
@@ -525,12 +524,11 @@
     attributes.push_back(
         rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
 
-    auto rnn_op = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
+    Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
         op->getLoc(), result_types, inputs, attributes);
 
     // Rewire the output.
-    op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
-    rewriter.eraseOp(op);
+    rewriter.replaceOp(op, {nullptr, rnn_result});
 
     return success();
   }
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 3f35b06..43f9834 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -36,12 +36,12 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
@@ -547,8 +547,8 @@
     Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                                size_diff.getType(), size.getType()};
     Type branch_result_type[] = {result_type};
-    auto func_type = FunctionType::get(branch_args_type, branch_result_type,
-                                       rewriter.getContext());
+    auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
+                                       branch_result_type);
 
     // Constructs `then_branch`, which is executed when `if_cond` evaluates to
     // true.
@@ -565,7 +565,7 @@
     // Inserts the two blocks' names into the symbol table held by the module.
     // Using SymbolTable will ensure that the inserted symbol names are
     // unique.
-    SymbolTable manager(op.getParentOfType<ModuleOp>());
+    SymbolTable manager(op->getParentOfType<ModuleOp>());
     manager.insert(then_branch_op);
     manager.insert(else_branch_op);
 
@@ -775,8 +775,8 @@
     // Change `func`'s argument type to `unranked_argument_types`. If it
     // return types contain a `DT_VARIANT`, change it to the unranked type
     // derived from the corresponding argument.
-    func.setType(FunctionType::get(updated_argument_types, updated_result_types,
-                                   op.getContext()));
+    func.setType(FunctionType::get(op.getContext(), updated_argument_types,
+                                   updated_result_types));
 
     // Change the argument type for the first block.
     llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 704cae9..a053ff0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -37,8 +37,8 @@
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -243,7 +243,7 @@
   return mlir::DenseElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int>(shape.size())},
-          mlir::IntegerType::get(32, output_val.getContext())),
+          mlir::IntegerType::get(output_val.getContext(), 32)),
       llvm::makeArrayRef(shape));
 }
 
@@ -650,7 +650,7 @@
     ShapedType filter_type = filter_cst.getType();
 
     if (llvm::isa<AddOp, SubOp>(binary_op)) {
-      auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
+      auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
       if (padding && padding.getValue() != "VALID") return failure();
 
       // The fusion of add/sub is actually applying the following
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
index 665fae3..ce2ce2a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
@@ -20,8 +20,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -50,7 +50,7 @@
   if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;
 
   auto updated_type =
-      FunctionType::get(func_type.getInputs(), return_types, func.getContext());
+      FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
   func.setType(updated_type);
 }
 
@@ -79,7 +79,7 @@
     // and therefore one terminator op. So, that function return type can be
     // updated if operands' shapes change after inlining. Without this
     // restriction, it would require tensor cast ops.
-    FuncOp parent_op = op.getParentOfType<FuncOp>();
+    FuncOp parent_op = op->getParentOfType<FuncOp>();
     if (!llvm::hasSingleElement(parent_op)) return failure();
 
     // Find the then and else branch functions.
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 424bd85..be8b096 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -130,7 +130,7 @@
                                 PatternRewriter& rewriter) const override {
     auto input_op = op.input().getDefiningOp();
     if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
-      if (!q.getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
+      if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
 
       op.replaceAllUsesWith(q.input());
       return success();
@@ -171,6 +171,7 @@
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, patterns);
   patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
+  patterns.insert<PruneUnusedLstm<TFL::LSTMOp>>(ctx);
   patterns.insert<PruneUnusedLstm<TFL::UnidirectionalSequenceLSTMOp>>(ctx);
   applyPatternsAndFoldGreedily(func, std::move(patterns));
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
index eee0378..6450bd4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
@@ -27,11 +27,11 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
@@ -42,6 +42,7 @@
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
+#include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -61,6 +62,7 @@
 constexpr char kTFTextAPIPrefix[] = "tftext:";
 constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
+constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";
 
 using mlir::TF::FuncAttr;
 
@@ -70,8 +72,8 @@
   explicit ConvertEmbeddedLookupFunc(FuncOp func) : func_(func) {}
 
   void RewriteFunc() {
-    func_.setAttr(kTFImplements,
-                  StringAttr::get("embedding_lookup", func_.getContext()));
+    func_->setAttr(kTFImplements,
+                   StringAttr::get("embedding_lookup", func_.getContext()));
     Value lookup = func_.getArgument(1);
     Value value = func_.getArgument(0);
     auto output_type = func_.getType().getResult(0);
@@ -294,6 +296,12 @@
         failed(convert_ssd_postprocess.RewriteFunc())) {
       return signalPassFailure();
     }
+  } else if (api_name == kCustomMaxUnpooling) {
+    ConvertMaxUnpoolingFunc max_unpooling(func, attr);
+    if (failed(max_unpooling.VerifySignature()) ||
+        failed(max_unpooling.RewriteFunc())) {
+      return signalPassFailure();
+    }
   }
 }
 
@@ -326,20 +334,21 @@
     // 2) tf._implements, with proto attributes.
     // 3) tf.api_implements.
     // We need to handle them separately.
-    auto tf_implements_attr_str = func.getAttrOfType<StringAttr>(kTFImplements);
+    auto tf_implements_attr_str =
+        func->getAttrOfType<StringAttr>(kTFImplements);
     if (tf_implements_attr_str) {
       ConvertTFImplements(func, tf_implements_attr_str);
       continue;
     }
 
-    auto tf_implements_attr = func.getAttrOfType<FuncAttr>(kTFImplements);
+    auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
     if (tf_implements_attr) {
       ConvertTFImplementsWithAttributes(func, tf_implements_attr);
       continue;
     }
 
     auto tf_api_implements_attr =
-        func.getAttrOfType<StringAttr>(kTFAPIImplements);
+        func->getAttrOfType<StringAttr>(kTFAPIImplements);
     if (tf_api_implements_attr) {
       // TODO(b/147536816): Keras lstm should set up the correct attributes.
       ConvertTFAPIImplements(func, tf_api_implements_attr, module);
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index be46423..e4f2eaa 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -30,10 +30,10 @@
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
@@ -321,9 +321,10 @@
 using PrepareQuantStats =
     quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
 
-using PrepareLstmQuantStats =
-    TFL::ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
-                                quant::QuantizeCastOp, quant::DequantizeCastOp>;
+using PrepareLstmQuantStats = TFL::ConvertLstmStatsToQDQs<TFL::LSTMOp>;
+
+using PrepareUnidirectionalLstmQuantStats =
+    TFL::ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp>;
 
 void PrepareQuantizePass::runOnFunction() {
   FuncOp func = getFunction();
@@ -341,9 +342,6 @@
     }
   }
 
-  // During the legalization, unsigned quantized type is used, so we have to
-  // convert all of them to signed.
-  OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
   int bit_width = quant_specs_.GetQuantizationTypeWidth();
   // When this is true, the quantizer will try its best to extract the
@@ -357,18 +355,32 @@
   bool infer_tensor_range =
       (quant_specs_.post_training_quantization || eager_quantize) &&
       !quant_specs_.disable_infer_tensor_range;
+
+  // LSTM's restrict_scale requirement should be handled before converting stats
+  // to Q-DQ ops. The pattern is applied for non-PTQ case to make op ordering
+  // consistent. Otherwise some FileCheck tests would fail.
+  OwningRewritePatternList patterns_1;
+  if (quant_specs_.post_training_quantization) {
+    patterns_1.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
+    patterns_1.insert<PrepareUnidirectionalLstmQuantStats>(ctx, quant_specs_);
+  }
+  applyPatternsAndFoldGreedily(func, std::move(patterns_1));
+
+  // During the legalization, unsigned quantized type is used, so we have to
+  // convert all of them to signed.
+  OwningRewritePatternList patterns_2;
   if (is_signed) {
-    patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
+    patterns_2.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(
+        ctx);
     // Convert quant stats to int8 quantization parameters.
     // Currently, only activation stats are imported, so narrow_range = false.
-    patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
+    patterns_2.insert<PrepareQuantStats>(bit_width, false, true, ctx);
   } else {
     // Convert quant stats to uint8 quantization parameters.
     // Currently, only activation stats are imported, so narrow_range = false.
-    patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
+    patterns_2.insert<PrepareQuantStats>(bit_width, false, false, ctx);
   }
-  patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
-  applyPatternsAndFoldGreedily(func, std::move(patterns));
+  applyPatternsAndFoldGreedily(func, std::move(patterns_2));
 
   SanityCheckAndAdjustment(func);
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
index 4538a36..aec8869 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
@@ -23,20 +23,22 @@
 #include <string>
 #include <vector>
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
-#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/tools/optimize/operator_property.h"
@@ -52,10 +54,14 @@
   return std::pow(2, std::ceil(std::log2(value)));
 }
 
+constexpr double power_of_two_scale = 32768.0;
+
 namespace operator_property = ::tflite::optimize::operator_property;
+using Q = quant::QuantizeCastOp;
+using DQ = quant::DequantizeCastOp;
 
 // Quantize recurrent input of LSTM with 16 bits.
-template <typename SourceOp, typename Q, typename DQ>
+template <typename SourceOp>
 struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
  public:
   ConvertLstmStatsToQDQs(MLIRContext* context,
@@ -85,14 +91,95 @@
     lstm_variant.use_layer_norm =
         !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
 
-    auto lstm_property = operator_property::GetOperatorProperty(lstm_variant);
+    const auto lstm_property =
+        operator_property::GetOperatorProperty(lstm_variant);
 
-    // Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
-    const std::vector<std::string> intermediate_attributes = {
-        "input_to_input_intermediate", "input_to_forget_intermediate",
-        "input_to_cell_intermediate", "input_to_output_intermediate",
-        "effective_hidden_scale_intermediate"};
+    // Use same scale for input and output specified in restrict_scale.
+    for (const std::vector<int>& tensors : lstm_property.restrict_scale) {
+      if (tensors.empty()) {
+        continue;
+      }
+      if (tensors.size() != 2) {
+        op.emitError(
+            "Unexpected restricted_scale from operator property."
+            " Should only have a pair of indices.");
+        return failure();
+      }
+      if (failed(processRestrictScale(op, tensors[0], tensors[1], rewriter))) {
+        return failure();
+      }
+    }
 
+    if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
+        failed(processInputs(op, lstm_variant, lstm_property, rewriter))) {
+      return failure();
+    }
+
+    return success();
+  }
+
+ private:
+  QuantizationSpecs quant_specs;
+  // Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+  const std::vector<std::string> intermediate_attributes = {
+      "input_to_input_intermediate", "input_to_forget_intermediate",
+      "input_to_cell_intermediate", "input_to_output_intermediate",
+      "effective_hidden_scale_intermediate"};
+
+  QuantizedType getIntermediateType(SourceOp op, int intermediate_index) const {
+    TypeAttr attr = op->template getAttrOfType<TypeAttr>(
+        intermediate_attributes[intermediate_index]);
+    if (!attr) {
+      return nullptr;
+    }
+    return QuantizedType::getQuantizedElementType(attr.getValue());
+  }
+
+  LogicalResult getDerivedScale(
+      SourceOp op, int input_index,
+      const operator_property::TensorProperty& tensor_property,
+      double& scale) const {
+    scale = 1.0;
+    for (int tensor_index : tensor_property.derived_scale.input_tensors) {
+      auto dequantize_op = llvm::dyn_cast_or_null<DQ>(
+          op.getOperand(tensor_index).getDefiningOp());
+
+      if (!dequantize_op) {
+        return failure();  // Wait for other scales to be calculated.
+      }
+      auto quant_type = QuantizedType::getQuantizedElementType(
+          dequantize_op.getOperand().getType());
+      if (!quant_type ||
+          !quant_type.template isa<quant::UniformQuantizedType>()) {
+        dequantize_op.emitError("Expected UniformQuantizedType.");
+        return failure();
+      }
+      scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
+                   .getScale();
+    }
+    for (int tensor_index :
+         tensor_property.derived_scale.intermediate_tensors) {
+      auto quant_type = getIntermediateType(op, tensor_index);
+      if (!quant_type ||
+          !quant_type.template isa<quant::UniformQuantizedType>()) {
+        op.emitError() << "While processing derived scale for input "
+                       << input_index << ": "
+                       << intermediate_attributes[tensor_index]
+                       << " is not quantized.";
+        return failure();
+      }
+      scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
+                   .getScale();
+    }
+    for (float factor : tensor_property.derived_scale.factors) {
+      scale *= factor;
+    }
+    return success();
+  }
+
+  LogicalResult processIntermediates(
+      SourceOp op, const operator_property::OpVariant& lstm_variant,
+      const operator_property::OperatorProperty& lstm_property) const {
     for (auto& enumerated_intermediates : lstm_property.intermediates) {
       int index = enumerated_intermediates.first;
       auto& tensor_property = enumerated_intermediates.second;
@@ -100,31 +187,23 @@
       if (!lstm_variant.use_layer_norm && index != 4) {
         continue;
       }
-      // intermediate tensor 4 is only used with projection.
-      if (!lstm_variant.use_projection && index == 4) {
-        continue;
-      }
-      TypeAttr attr =
-          op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
 
-      if (!attr) {
-        op.emitError()
-            << op.getOperationName()
-            << " requires quantization values for intermediate tensor "
-            << intermediate_attributes[index];
-        return failure();
-      }
-      auto quantized_type =
-          QuantizedType::getQuantizedElementType(attr.getValue());
-      if (!quantized_type) {
+      TypeAttr attr =
+          op->template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
+      auto quant_type = getIntermediateType(op, index);
+      if (!quant_type) {
+        // intermediate tensor 4 is optional, unless the LSTM uses projection.
+        if (index == 4 && !lstm_variant.use_projection) {
+          return success();
+        }
         op.emitError() << intermediate_attributes[index]
                        << " is not quantized.";
         return failure();
       }
       auto calibrated_type =
-          quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
+          quant_type.template dyn_cast<quant::CalibratedQuantizedType>();
       if (!calibrated_type) {
-        int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
+        int num_storage_bits = quant_type.getStorageTypeIntegralWidth();
         if (tensor_property.number_of_bits != num_storage_bits) {
           op.emitError() << intermediate_attributes[index]
                          << " is expected to be quantized with "
@@ -154,34 +233,165 @@
         return failure();
       }
 
-      op.setAttr(intermediate_attributes[index],
-                 TypeAttr::get(qtype.castFromExpressedType(
-                     qtype.castToExpressedType(attr.getValue()))));
+      op->setAttr(intermediate_attributes[index],
+                  TypeAttr::get(qtype.castFromExpressedType(
+                      qtype.castToExpressedType(attr.getValue()))));
+    }
+    return success();
+  }
+
+  LogicalResult processInputs(
+      SourceOp op, const operator_property::OpVariant& lstm_variant,
+      const operator_property::OperatorProperty& lstm_property,
+      PatternRewriter& rewriter) const {
+    for (auto& enumerated_inputs : lstm_property.inputs) {
+      int index = enumerated_inputs.first;
+      auto& tensor_property = enumerated_inputs.second;
+
+      Value input = op.getOperand(index);
+
+      if (input.getDefiningOp() == nullptr) continue;
+
+      // TODO(b/172517537): make this work with non-PTQ case.
+      if (llvm::isa<ConstantOp, TFL::ConstOp>(input.getDefiningOp())) {
+        if (failed(processConstantOp(op, input.getDefiningOp(), index,
+                                     tensor_property, rewriter))) {
+          return failure();
+        }
+      } else {
+        if (auto stats_op =
+                llvm::dyn_cast<quant::StatisticsOp>(input.getDefiningOp())) {
+          if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
+                                    rewriter))) {
+            return failure();
+          }
+          // Continue if StatisticsOp is already converted to Q-DQ pair.
+        } else if (!llvm::isa<DQ>(input.getDefiningOp())) {
+          // TODO(b/172517537): make this work with non-PTQ case.
+          op.emitError() << "Input " << index
+                         << " should be from DequantizeCast "
+                            "or Statistics op.";
+          input.getDefiningOp()->emitError();
+          return failure();
+        }
+      }
+    }
+    return success();
+  }
+
+  LogicalResult processConstantOp(
+      SourceOp op, Operation* const_op, int input_index,
+      const operator_property::TensorProperty& tensor_property,
+      PatternRewriter& rewriter) const {
+    // Non-float tensors are not weights and do not require quantization.
+    auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
+    if (!type || !type.getElementType().isa<FloatType>()) return success();
+
+    DenseFPElementsAttr attr;
+    if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
+      const_op->emitError("Not a constant op.");
+      return failure();
     }
 
-    quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
-        op.input_cell_state().getDefiningOp());
-    // Recurrent input is be used within an LSTM, and thus should have one use.
-    if (!stats_op || !stats_op.getResult().hasOneUse()) {
+    UniformQuantizedType quant_type = nullptr;
+    const int64_t storage_min = llvm::minIntN(tensor_property.number_of_bits);
+    const int64_t storage_max = llvm::maxIntN(tensor_property.number_of_bits);
+    const IntegerType storage_type =
+        rewriter.getIntegerType(tensor_property.number_of_bits);
+    const Type expressed_type =
+        getElementTypeOrSelf(const_op->getResult(0).getType());
+
+    if (tensor_property.use_derived_scale) {
+      // Biases use derived scale from other tensors.
+      // input 12~15: gate biases, input 17: projection bias
+      if (tensor_property.number_of_bits != 32) {
+        op.emitError() << "Derived scale is only supported for 32-bit "
+                       << "quantization. Got " << tensor_property.number_of_bits
+                       << " bits in input index " << input_index;
+        return failure();
+      }
+      double scale;
+      if (failed(getDerivedScale(op, input_index, tensor_property, scale))) {
+        return failure();
+      }
+      quant_type = UniformQuantizedType::getChecked(
+          quant::QuantizationFlags::Signed, storage_type, expressed_type, scale,
+          /*zeroPoint=*/0, storage_min, storage_max, const_op->getLoc());
+    } else {
+      // For weights, use quantization scale directly inferred from the
+      // values.
+      //
+      // input 1~4: input to gate weights
+      // input 5~8: recurrent to gate weights
+      // input 9~11: peephole weights, input 16: projection weight
+      // input 20~23: normalization weights
+      quant_type =
+          quant::GetUniformQuantizedTypeForWeight(
+              attr, /*symmetric=*/true,
+              /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
+              /*narrow_range=*/true)
+              .template dyn_cast<quant::UniformQuantizedType>();
+    }
+
+    if (!quant_type) {
+      const_op->emitError("Failed to get quantized type");
+      return failure();
+    }
+
+    // TODO(b/172517537): duplicate the constant when the bias is shared.
+    Type cast_type =
+        quant_type.castFromExpressedType(const_op->getResult(0).getType());
+    rewriter.setInsertionPointAfter(const_op);
+    auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
+                                const_op->getResult(0));
+    auto dq = rewriter.create<DQ>(const_op->getLoc(), cast_type, q);
+    op.setOperand(input_index, dq.getResult());
+    return success();
+  }
+
+  LogicalResult replaceStatsOp(
+      SourceOp op, quant::StatisticsOp stats_op, int input_index,
+      const operator_property::TensorProperty& tensor_property,
+      PatternRewriter& rewriter) const {
+    if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
+      // TODO(b/172517537): check if other tensors should go through this
+      // check too.
+      op.emitError() << "Input tensor [" << input_index
+                     << "] is a state tensor, but has more than one use.";
       return failure();
     }
     auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
-    if (!stats) {
+    if (!stats || stats.getNumElements() != 2) {
+      stats_op.emitError("Stats should have 2 values.");
       return failure();
     }
+    quant::QuantizedType quant_type;
+    double min = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
+    double max = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
+    Type expressed = getElementTypeOrSelf(stats_op.getType());
 
-    double max = std::max(
-        std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
-        std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
-    double bound = power_of_two_bound(max);
-    Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
-    // Set flags to 1 for signed type.
-    quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
-        quant::QuantizationFlags::Signed,
-        IntegerType::get(16, expressed.getContext()), expressed,
-        /*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
-        llvm::maxIntN(16), op.getLoc());
+    if (tensor_property.extend_to_power_of_two) {
+      if (tensor_property.number_of_bits != 16) {
+        op.emitError(
+            "extended power of 2 scale is only supported for 16-bit"
+            " quantization.");
+        return failure();
+      }
 
+      double bound = power_of_two_bound(std::max(std::abs(min), std::abs(max)));
+      // Set flags to 1 for signed type.
+      quant_type = UniformQuantizedType::getChecked(
+          quant::QuantizationFlags::Signed,
+          rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
+          /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
+          /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
+          llvm::maxIntN(tensor_property.number_of_bits), op.getLoc());
+    } else {
+      quant_type = quant::fakeQuantAttrsToType(
+          op.getLoc(), tensor_property.number_of_bits, min, max,
+          /*narrowRange=*/false, expressed,
+          /*isSigned=*/true);
+    }
     rewriter.setInsertionPointAfter(stats_op);
     Type result_type = quant_type.castFromExpressedType(stats_op.getType());
     auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
@@ -189,8 +399,60 @@
     return success();
   }
 
- private:
-  QuantizationSpecs quant_specs;
+  // An LSTM's recurrent input activation and its output are quantized with
+  // the collective range of both tensors: theoretically the input activation
+  // value for the very first inference is not reflected in the output, so it
+  // would otherwise not be captured.
+  LogicalResult processRestrictScale(SourceOp op, int input_index,
+                                     int output_index,
+                                     PatternRewriter& rewriter) const {
+    assert(output_index == 0);
+    if (!op.getResult().hasOneUse()) {
+      op.emitError()
+          << "output " << output_index
+          << " should have only one use, which should be quant.stats.";
+      return failure();
+    }
+
+    llvm::SmallVector<quant::StatisticsOp, 2> stats_ops = {
+        llvm::dyn_cast_or_null<quant::StatisticsOp>(
+            op.getOperand(input_index).getDefiningOp()),
+        llvm::dyn_cast_or_null<quant::StatisticsOp>(
+            *op.getResult().getUsers().begin()),
+    };
+
+    if (!stats_ops[0] || !stats_ops[1]) {
+      return failure();  // Already converted to Q-DQ pair.
+    }
+
+    llvm::SmallVector<llvm::APFloat, 4> min_max_values;
+
+    for (auto& stats_op : stats_ops) {
+      auto values = stats_op.layerStats()
+                        .dyn_cast<DenseFPElementsAttr>()
+                        .getValues<llvm::APFloat>();
+      min_max_values.insert(min_max_values.end(), values.begin(), values.end());
+    }
+
+    // The min and max values of the two stats are already the same.
+    if (min_max_values[0] == min_max_values[2] &&
+        min_max_values[1] == min_max_values[3]) {
+      return success();
+    }
+
+    mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
+        mlir::RankedTensorType::get({2}, rewriter.getF32Type()),
+        {llvm::minimum(min_max_values[0], min_max_values[2]),
+         llvm::maximum(min_max_values[1], min_max_values[3])});
+    mlir::ElementsAttr axis_stats;
+    mlir::IntegerAttr axis;
+    for (auto& stats_op : stats_ops) {
+      rewriter.setInsertionPointAfter(stats_op);
+      rewriter.replaceOpWithNewOp<quant::StatisticsOp>(
+          stats_op, stats_op.arg(), layer_stats, axis_stats, axis);
+    }
+    return success();
+  }
 };
 
 }  // namespace TFL
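To make the two new code paths above concrete, here is a minimal numeric sketch; the values are made up, and power_of_two_bound is assumed to round its argument up to the next power of two, as its name and the 16-bit scale computation suggest.

    // extend_to_power_of_two with 16 bits: for stats [-5.2, 3.1],
    // bound = power_of_two_bound(5.2) = 8.0 (assumed round-up behavior),
    // so scale = 8.0 / 32768 = 2^-12, zero point 0, range [-32768, 32767].
    double bound = power_of_two_bound(std::max(std::abs(-5.2), std::abs(3.1)));
    double scale = bound / 32768.0;  // 0.000244140625

    // processRestrictScale: if the recurrent input carries quant.stats
    // [-1.0, 2.0] and the output carries [-3.0, 1.5], both stats ops are
    // rewritten to their union [-3.0, 2.0] so input and output share a scale.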
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 10a4a69..02affa9 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -45,8 +45,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -959,7 +959,8 @@
     if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
 
     {
-      epsilon = fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>("epsilon");
+      epsilon =
+          fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
       if (!epsilon)
         epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
 
@@ -974,7 +975,7 @@
     }
     {
       exponential_avg_factor =
-          fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>(
+          fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
               "exponential_avg_factor");
       if (!exponential_avg_factor)
         exponential_avg_factor =
@@ -982,12 +983,12 @@
     }
     {
       data_format =
-          fused_batch_norm_op.getAttrOfType<::mlir::StringAttr>("data_format");
+          fused_batch_norm_op->getAttrOfType<::mlir::StringAttr>("data_format");
       if (!data_format) data_format = rewriter.getStringAttr("NHWC");
     }
     {
       is_training =
-          fused_batch_norm_op.getAttrOfType<::mlir::BoolAttr>("is_training");
+          fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
       if (!is_training) is_training = rewriter.getBoolAttr(true);
 
       if (!((!is_training.getValue()))) {
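The accessor changes in this file (and in the files below) all have the same shape: attribute getters and setters are reached through the op's underlying Operation* via operator-> instead of the OpState helpers. A one-line before/after sketch using the epsilon attribute from this hunk:

    // before: fused_batch_norm_op.getAttrOfType<FloatAttr>("epsilon");
    // after:  fused_batch_norm_op->getAttrOfType<FloatAttr>("epsilon");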
diff --git a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc
index 40cca52..c7fa7d4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc
@@ -18,8 +18,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
index af9f21a..a58b7a3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
@@ -23,11 +23,11 @@
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
index cfa2efe..83d4ac3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
@@ -19,11 +19,11 @@
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
@@ -134,12 +134,12 @@
                                  bool passthru_extra_args) {
     FunctionType type;
     if (passthru_extra_args) {
-      type = FunctionType::get(types, types, &getContext());
+      type = FunctionType::get(&getContext(), types, types);
     } else {
       SmallVector<Type, 4> result_types;
       auto operands = region.front().getTerminator()->getOperandTypes();
       result_types.append(operands.begin(), operands.end());
-      type = FunctionType::get(types, result_types, &getContext());
+      type = FunctionType::get(&getContext(), types, result_types);
     }
 
     auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
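FunctionType::get now takes the MLIRContext as its first argument instead of its last; a minimal before/after sketch with the locals from this hunk:

    // before: FunctionType::get(types, result_types, &getContext());
    // after:
    FunctionType type = FunctionType::get(&getContext(), types, result_types);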
diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
index 2085517..41a7cd1 100644
--- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 
 namespace mlir {
 namespace TFL {
diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h
index 5c34802..0434cf7 100644
--- a/tensorflow/compiler/mlir/lite/utils/constant_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h
@@ -17,10 +17,10 @@
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc
index 56aac8f..489b0f3 100644
--- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc
+++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc
@@ -16,7 +16,7 @@
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 1a5e740..090551f 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -25,12 +25,12 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -382,8 +382,8 @@
   auto input_types = fused_func_op_.getType().getInputs();
   auto output_type = mlir::RankedTensorType::get(
       output_shape, input_.getType().cast<RankedTensorType>().getElementType());
-  fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
-                                                 fused_func_op_.getContext()));
+  fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
+                                                 input_types, output_type));
 }
 
 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
@@ -438,7 +438,7 @@
 }
 
 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
-  auto attr = fused_func_op_.getAttrOfType<StringAttr>(kTFImplements);
+  auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
   if (!attr) {
     return fused_func_op_.emitError()
            << "Invalid function attribute, expected " << kTFImplements
@@ -639,7 +639,7 @@
 
   // TFL lstm only supports time-majored inputs, so if it's not time-majored,
   // we will transpose the inputs and outputs.
-  auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
+  auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
   if (time_major_attr == nullptr) return failure();
 
   bool time_majored = time_major_attr.getValue();
@@ -654,7 +654,7 @@
 
   // Handle go_backwards:
   // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
-  auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards");
+  auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
 
   if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
     int time_dim = time_majored ? 0 : 1;
@@ -820,8 +820,8 @@
   }
 
   // Update function signatures.
-  func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(),
-                                          output_types, func_op.getContext()));
+  func_op.setType(mlir::FunctionType::get(
+      func_op.getContext(), func_op.getType().getInputs(), output_types));
 
   builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
   return success();
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h
index 449c473..6fc0119 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h
@@ -22,8 +22,8 @@
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
index 9eb767c..28e3b07 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
@@ -28,9 +28,9 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -81,7 +81,7 @@
   mlir::StringAttr attr_values =
       builder->getStringAttr(llvm::join(attributes, ","));
 
-  func.setAttr(kTFImplements, attr_values);
+  func->setAttr(kTFImplements, attr_values);
   return func;
 }
 
@@ -126,7 +126,7 @@
 
   // verify transpose
   EXPECT_EQ(
-      fused_lstm_func_.getAttrOfType<StringAttr>(kTFImplements).getValue(),
+      fused_lstm_func_->getAttrOfType<StringAttr>(kTFImplements).getValue(),
       convert.GetCompositeOpName());
   EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
   EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
@@ -199,9 +199,9 @@
 
   llvm::SmallVector<std::string, 2> attributes{kLstmCellSimple,
                                                kCoupleInputForgetGates};
-  EXPECT_EQ(
-      fused_lstm_func_cifg_.getAttrOfType<StringAttr>(kTFImplements).getValue(),
-      llvm::join(attributes, ","));
+  EXPECT_EQ(fused_lstm_func_cifg_->getAttrOfType<StringAttr>(kTFImplements)
+                .getValue(),
+            llvm::join(attributes, ","));
 
   auto it = fused_lstm_func_cifg_.getBody().back().rbegin();
   EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
@@ -224,7 +224,7 @@
   fused_ln_lstm_func_.dump();
 
   EXPECT_EQ(
-      fused_ln_lstm_func_.getAttrOfType<StringAttr>(kTFImplements).getValue(),
+      fused_ln_lstm_func_->getAttrOfType<StringAttr>(kTFImplements).getValue(),
       convert.GetCompositeOpName());
   EXPECT_EQ(fused_ln_lstm_func_.getNumArguments(), 5);
   EXPECT_EQ(fused_ln_lstm_func_.getType().getNumResults(), 1);
diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc
index 5435450..277e0a6 100644
--- a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc
@@ -40,8 +40,8 @@
 }  // namespace
 
 void ConvertNMSPaddedFunc::RewriteFunc() {
-  func_.setAttr(kTFImplements,
-                StringAttr::get(kTfNMSPadded, func_.getContext()));
+  func_->setAttr(kTFImplements,
+                 StringAttr::get(kTfNMSPadded, func_.getContext()));
   Value boxes = func_.getArgument(0);
   Value scores = func_.getArgument(1);
   Value max_output_size = func_.getArgument(2);
@@ -85,8 +85,8 @@
 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
   func_.eraseBody();
   func_.addEntryBlock();
-  func_.setAttr(kTFImplements,
-                StringAttr::get(kCustomSSDPostprocessing, func_.getContext()));
+  func_->setAttr(kTFImplements,
+                 StringAttr::get(kCustomSSDPostprocessing, func_.getContext()));
 
   OpBuilder builder(func_.getBody());
   std::string custom_option_buffer;
diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc
new file mode 100644
index 0000000..eba71da
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc
@@ -0,0 +1,147 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
+
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/lite/c/builtin_op_data.h"
+
+namespace mlir {
+namespace TFL {
+
+namespace {
+
+constexpr char kTFImplements[] = "tf._implements";
+constexpr char kMaxUnpooling[] = "MaxUnpooling2D";
+
+inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
+                                       const std::string& content) {
+  ShapedType type = RankedTensorType::get(
+      {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
+  return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
+                                 type,
+                                 StringRef(content.data(), content.size()));
+}
+
+inline LogicalResult GetIntegerArraySafe(
+    FuncOp* func, const DictionaryAttr& attrs, const std::string& attr_name,
+    llvm::SmallVectorImpl<int32_t>* results, int N) {
+  ArrayAttr array_attr = attrs.get(attr_name).dyn_cast_or_null<ArrayAttr>();
+  if (array_attr == nullptr || array_attr.size() != N) {
+    return func->emitError()
+           << "'" << attr_name << "' attribute for " << kMaxUnpooling
+           << " must be set and has size of " << N;
+  }
+  results->reserve(N);
+
+  for (Attribute integer_attr : array_attr.getValue()) {
+    IntegerAttr value = integer_attr.dyn_cast<IntegerAttr>();
+    if (!value) {
+      return func->emitError()
+             << "'" << attr_name << "' attribute for " << kMaxUnpooling
+             << " does not contain integer values";
+    }
+    results->push_back(value.getInt());
+  }
+  return success();
+}
+
+}  // namespace
+
+LogicalResult ConvertMaxUnpoolingFunc::RewriteFunc() {
+  func_.eraseBody();
+  func_.addEntryBlock();
+  func_->setAttr(kTFImplements,
+                 StringAttr::get(kMaxUnpooling, func_.getContext()));
+
+  OpBuilder builder(func_.getBody());
+  std::string custom_option_buffer;
+  if (failed(CreateCustomOptions(custom_option_buffer))) {
+    return failure();
+  }
+  auto op = builder.create<CustomOp>(
+      func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
+      kMaxUnpooling, CustomOption(&builder, custom_option_buffer));
+  builder.create<ReturnOp>(func_.getLoc(), op.getResults());
+
+  return success();
+}
+
+LogicalResult ConvertMaxUnpoolingFunc::VerifySignature() {
+  // Verify high-level function signature.
+  if (func_.getNumArguments() != 2) {
+    return func_.emitError()
+           << "Invalid number of arguments to " << kMaxUnpooling << ": "
+           << func_.getNumArguments();
+  }
+  if (func_.getType().getNumResults() != 1) {
+    return func_.emitError()
+           << "Invalid number of results from " << kMaxUnpooling << ": "
+           << func_.getType().getNumResults();
+  }
+  return success();
+}
+
+LogicalResult ConvertMaxUnpoolingFunc::CreateCustomOptions(
+    std::string& custom_option_buffer) {
+  auto attrs = attr_.GetAttrs();
+  TfLitePoolParams pool_params;
+
+  llvm::SmallVector<int32_t, 2> pool_size;
+  if (failed(GetIntegerArraySafe(&func_, attrs, "pool_size", &pool_size, 2))) {
+    return failure();
+  }
+  pool_params.filter_height = pool_size[0];
+  pool_params.filter_width = pool_size[1];
+
+  // Retrieve strides.
+  llvm::SmallVector<int32_t, 2> strides;
+  if (failed(GetIntegerArraySafe(&func_, attrs, "strides", &strides, 2))) {
+    return failure();
+  }
+  pool_params.stride_height = strides[0];
+  pool_params.stride_width = strides[1];
+
+  // Retrieve padding.
+  auto padding = attrs.get("padding").dyn_cast_or_null<StringAttr>();
+  if (!padding) {
+    return func_.emitError() << "'padding' attribute for " << kMaxUnpooling
+                             << " is not set or not a string";
+  }
+  if (padding.getValue().equals("VALID")) {
+    pool_params.padding = kTfLitePaddingValid;
+  } else if (padding.getValue().equals("SAME")) {
+    pool_params.padding = kTfLitePaddingSame;
+  } else {
+    return func_.emitError()
+           << "Padding for " << kMaxUnpooling << " must be 'SAME' or 'VALID'";
+  }
+
+  pool_params.activation = kTfLiteActNone;
+  pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
+
+  custom_option_buffer.assign(reinterpret_cast<char*>(&pool_params),
+                              sizeof(TfLitePoolParams));
+  return success();
+}
+
+}  // namespace TFL
+}  // namespace mlir
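A hedged sketch of what this converter consumes and produces, based on CreateCustomOptions above and the test further below; the attribute values are illustrative, not taken from any particular model.

    // The fused function is annotated roughly as
    //   tf._implements = "MaxUnpooling2D" with
    //   {pool_size = [2, 2], strides = [2, 2], padding = "SAME"},
    // and RewriteFunc replaces its body with one tfl.custom op whose option
    // buffer is the raw bytes of this TfLitePoolParams:
    TfLitePoolParams pool_params;
    pool_params.filter_height = 2;             // pool_size[0]
    pool_params.filter_width = 2;              // pool_size[1]
    pool_params.stride_height = 2;             // strides[0]
    pool_params.stride_width = 2;              // strides[1]
    pool_params.padding = kTfLitePaddingSame;  // from padding = "SAME"
    pool_params.activation = kTfLiteActNone;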
diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h
new file mode 100644
index 0000000..e82c779
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h
@@ -0,0 +1,47 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_
+
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+
+namespace mlir {
+namespace TFL {
+
+// Fuse MaxUnpooling2D ops annotated by tf.function into a TFLite custom op.
+class ConvertMaxUnpoolingFunc {
+ public:
+  explicit ConvertMaxUnpoolingFunc(FuncOp func, mlir::TF::FuncAttr attr)
+      : func_(func), attr_(attr) {}
+
+  LogicalResult RewriteFunc();
+
+  LogicalResult VerifySignature();
+
+ private:
+  LogicalResult CreateCustomOptions(std::string& custom_option_buffer);
+
+  FuncOp func_;
+  mlir::TF::FuncAttr attr_;
+};
+
+}  // end namespace TFL
+}  // end namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_
diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc
new file mode 100644
index 0000000..19a2b81
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc
@@ -0,0 +1,196 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
+
+#include <memory>
+#include <vector>
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace mlir {
+namespace TFL {
+namespace {
+
+template <int NInput, int NOutput>
+FuncOp createMaxUnpoolingFunc(
+    mlir::Builder* builder, const SmallVector<mlir::Type, NInput>& input_types,
+    const SmallVector<mlir::Type, NOutput>& output_types) {
+  auto func_type = builder->getFunctionType(input_types, output_types);
+  auto func =
+      FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
+                                        builder->getContext()),
+                     "fused_func", func_type, {});
+
+  func.addEntryBlock();
+  mlir::StringAttr attr_value = builder->getStringAttr("MaxUnpooling2D");
+  func.setAttr("tf._implements", attr_value);
+  return func;
+}
+
+FuncOp createMaxUnpoolingFunc(mlir::Builder* builder,
+                              const SmallVector<int64_t, 4>& input_shape,
+                              const SmallVector<int64_t, 4>& output_shape) {
+  auto input_type = RankedTensorType::get(input_shape, builder->getF32Type());
+  auto indices_type = RankedTensorType::get(input_shape, builder->getI64Type());
+  auto output_type = RankedTensorType::get(output_shape, builder->getF32Type());
+  SmallVector<mlir::Type, 2> input_types{input_type, indices_type};
+  SmallVector<mlir::Type, 1> output_types{output_type};
+  return createMaxUnpoolingFunc<2, 1>(builder, input_types, output_types);
+}
+
+template <int N>
+ArrayAttr createInt32Array(mlir::Builder* builder, mlir::MLIRContext* context,
+                           const SmallVector<int32_t, N>& values) {
+  SmallVector<Attribute, N> ret;
+  for (int32_t value : values) {
+    ret.push_back(builder->getI32IntegerAttr(value));
+  }
+  return ArrayAttr::get(ret, context);
+}
+
+template <int N>
+ArrayAttr createInt64Array(mlir::Builder* builder, mlir::MLIRContext* context,
+                           const SmallVector<int64_t, N>& values) {
+  SmallVector<Attribute, N> ret;
+  for (int64_t value : values) {
+    ret.push_back(builder->getI64IntegerAttr(value));
+  }
+  return ArrayAttr::get(ret, context);
+}
+
+mlir::TF::FuncAttr createMaxUnpoolingAttr(mlir::MLIRContext* context,
+                                          const std::string& padding,
+                                          const ArrayAttr& pool_size,
+                                          const ArrayAttr& strides) {
+  SmallVector<::mlir::NamedAttribute, 3> fields;
+
+  auto padding_id = ::mlir::Identifier::get("padding", context);
+  fields.emplace_back(padding_id, StringAttr::get(padding, context));
+
+  auto pool_size_id = ::mlir::Identifier::get("pool_size", context);
+  fields.emplace_back(pool_size_id, pool_size);
+
+  auto strides_id = ::mlir::Identifier::get("strides", context);
+  fields.emplace_back(strides_id, strides);
+
+  DictionaryAttr dict = DictionaryAttr::get(fields, context);
+  return TF::FuncAttr::get(context, "MaxUnpooling2D", dict);
+}
+
+}  // namespace
+
+class PerceptionUtilsTest : public ::testing::Test {
+ protected:
+  PerceptionUtilsTest() {}
+
+  void SetUp() override {
+    context_ = std::make_unique<mlir::MLIRContext>();
+    context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
+                          TensorFlowLiteDialect>();
+    builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
+
+    fused_max_unpooling_func_ =
+        createMaxUnpoolingFunc(builder_.get(), {2, 4, 4, 2}, {2, 2, 2, 2});
+
+    func_attr_ = createMaxUnpoolingAttr(
+        context_.get(), "SAME",
+        createInt32Array<2>(builder_.get(), context_.get(), {2, 2}),
+        createInt32Array<2>(builder_.get(), context_.get(), {2, 2}));
+  }
+
+  void TearDown() override {
+    fused_max_unpooling_func_.erase();
+    builder_.reset();
+  }
+
+  FuncOp fused_max_unpooling_func_;
+  mlir::TF::FuncAttr func_attr_;
+  std::unique_ptr<mlir::MLIRContext> context_;
+  std::unique_ptr<mlir::Builder> builder_;
+};
+
+TEST_F(PerceptionUtilsTest, VerifySignatureValid) {
+  mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_,
+                                             func_attr_);
+
+  EXPECT_FALSE(failed(convert.VerifySignature()));
+}
+
+TEST_F(PerceptionUtilsTest, VerifySignatureInvalid) {
+  auto input_type = RankedTensorType::get({1, 2, 2, 1}, builder_->getF32Type());
+  auto output_type =
+      RankedTensorType::get({1, 2, 1, 1}, builder_->getF32Type());
+  SmallVector<mlir::Type, 1> input_types{input_type};
+  SmallVector<mlir::Type, 1> output_types{output_type};
+
+  auto max_unpooling_func =
+      createMaxUnpoolingFunc<1, 1>(builder_.get(), input_types, output_types);
+  mlir::TFL::ConvertMaxUnpoolingFunc convert(max_unpooling_func, func_attr_);
+
+  EXPECT_TRUE(failed(convert.VerifySignature()));
+  max_unpooling_func->erase();
+}
+
+TEST_F(PerceptionUtilsTest, RewriteValid) {
+  mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_,
+                                             func_attr_);
+
+  EXPECT_FALSE(failed(convert.RewriteFunc()));
+}
+
+TEST_F(PerceptionUtilsTest, RewriteWrongPadding) {
+  auto func_attr = createMaxUnpoolingAttr(
+      context_.get(), "INVALID",
+      createInt32Array<2>(builder_.get(), context_.get(), {2, 2}),
+      createInt32Array<2>(builder_.get(), context_.get(), {2, 2}));
+  mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_,
+                                             func_attr);
+
+  EXPECT_TRUE(failed(convert.RewriteFunc()));
+}
+
+TEST_F(PerceptionUtilsTest, RewriteWrongFilter) {
+  auto func_attr = createMaxUnpoolingAttr(
+      context_.get(), "VALID",
+      createInt32Array<2>(builder_.get(), context_.get(), {2, 2, 2}),
+      createInt32Array<2>(builder_.get(), context_.get(), {2, 2}));
+  mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_,
+                                             func_attr);
+
+  EXPECT_TRUE(failed(convert.RewriteFunc()));
+}
+
+TEST_F(PerceptionUtilsTest, RewriteWrongStrides) {
+  auto func_attr = createMaxUnpoolingAttr(
+      context_.get(), "VALID",
+      createInt32Array<2>(builder_.get(), context_.get(), {2, 2}),
+      createInt32Array<2>(builder_.get(), context_.get(), {2, 2, 0}));
+  mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_,
+                                             func_attr);
+
+  EXPECT_TRUE(failed(convert.RewriteFunc()));
+}
+
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
index f714453..b4306de 100644
--- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
@@ -26,13 +26,13 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -133,7 +133,7 @@
                                          FuncAttr attr) {
   func.eraseBody();
   func.addEntryBlock();
-  func.setAttr(kTFImplements, attr);
+  func->setAttr(kTFImplements, attr);
   OpBuilder builder(func.getBody());
   std::string empty_option_buffer;
   auto op = builder.create<CustomOp>(
@@ -256,7 +256,7 @@
 LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
   func.eraseBody();
   func.addEntryBlock();
-  func.setAttr(kTFImplements, attr);
+  func->setAttr(kTFImplements, attr);
   OpBuilder builder(func.getBody());
   std::string custom_option_buffer;
   if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(),
@@ -336,7 +336,7 @@
   // See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
   func.eraseBody();
   func.addEntryBlock();
-  func.setAttr(kTFImplements, attr);
+  func->setAttr(kTFImplements, attr);
   OpBuilder builder(func.getBody());
   std::string custom_option_buffer;
   if (failed(CreateSgnnProjectionCustomOption(func, attr.GetAttrs(),
diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h
index 82938d5..60a954d 100644
--- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h
@@ -22,8 +22,8 @@
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h
index d7a56fa..f73a75c 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.h
+++ b/tensorflow/compiler/mlir/lite/utils/validators.h
@@ -20,7 +20,7 @@
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 
 namespace mlir {
 namespace TFL {
diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
index be2dc20..7ff50a0 100644
--- a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
+++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
@@ -33,7 +33,7 @@
       .def("getF64", &mlir::FloatType::getF64);
 
   py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
-      .def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
+      .def("get", py::overload_cast<mlir::MLIRContext*, unsigned>(
                       &mlir::IntegerType::get));
 
   py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 95d916c..8038f50 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -639,6 +639,7 @@
         ":tensorflow_tfrt_ops_inc_gen",
         ":tensorflow_traits",
         ":tensorflow_types",
+        ":tf_pass_inc_gen",
         ":tf_saved_model_inc_gen",
         "//tensorflow/compiler/mlir/lite:validators",
         "//tensorflow/core:framework",
@@ -867,8 +868,10 @@
         "transforms/collection_ops_util.cc",
         "transforms/constant_op_device_assignment.cc",
         "transforms/contraction_fusion.cc",
+        "transforms/cross_host_transfer.cc",
         "transforms/decompose_resource_ops_pass.cc",
         "transforms/device_index_selector.cc",
+        "transforms/drop_while_shape_invariant.cc",
         "transforms/einsum.cc",
         "transforms/executor_island_coarsening.cc",
         "transforms/executor_tpuv1_inline_tpu_island.cc",
@@ -889,6 +892,7 @@
         "transforms/layout_optimization.cc",
         "transforms/mark_ops_for_outside_compilation.cc",
         "transforms/materialize_mlir_passthrough_op.cc",
+        "transforms/merge_control_flow.cc",
         "transforms/optimize.cc",
         "transforms/outside_compiled_to_host_launch.cc",
         "transforms/parallel_execute_to_islands.cc",
@@ -1007,6 +1011,7 @@
         "@llvm-project//mlir:Rewrite",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
     ],
@@ -1104,7 +1109,6 @@
         ":mlir_roundtrip_flags",
         ":tensorflow",
         ":tensorflow_attributes",
-        ":tensorflow_passes",
         ":tensorflow_types",
         ":tf_saved_model_passes",
         ":translate_utils",
@@ -1464,7 +1468,6 @@
         ":tensorflow_traits",
         ":tensorflow_types",
         "//tensorflow/c:tf_status",
-        "//tensorflow/c/eager:c_api",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/stream_executor",
@@ -1669,15 +1672,16 @@
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:xla_helpers",
     "//tensorflow/compiler/tf2xla:xla_argument",
-    "//tensorflow/compiler/xla/client:xla_computation",
-    "//tensorflow/core/common_runtime:core_cpu_internal",
-    "//tensorflow/core/platform:logging",
-    "//tensorflow/core:framework",
-    "//tensorflow/core:protos_all_cc",
-    "//tensorflow/stream_executor/lib",
     "//tensorflow/compiler/xla:shape_util",
     "//tensorflow/compiler/xla:xla_data_proto_cc",
+    "//tensorflow/compiler/xla/client:xla_computation",
     "//tensorflow/compiler/xla/service:hlo",
+    "//tensorflow/core:framework",
+    "//tensorflow/core:protos_all_cc",
+    "//tensorflow/core/common_runtime:core_cpu_internal",
+    "//tensorflow/core/platform:logging",
+    "//tensorflow/core/tpu:tpu_defs",
+    "//tensorflow/stream_executor/lib",
 ]
 
 # Prefer to link 'compile_mlir_util' library that also links necessary
@@ -1740,6 +1744,7 @@
         ":translate_cl_options",
         "//tensorflow/compiler/mlir:string_container_utils",
         "//tensorflow/compiler/mlir/xla:translate_cl_options",
+        "//tensorflow/compiler/mlir/xla:type_to_shape",
         "//tensorflow/compiler/tf2xla:xla_argument",
         "//tensorflow/compiler/tf2xla:xla_helpers",
         "//tensorflow/compiler/xla/service:hlo",
@@ -1753,7 +1758,6 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Translation",
@@ -1976,6 +1980,7 @@
     hdrs = ["utils/bridge_logger.h"],
     deps = [
         ":dump_mlir_util",
+        "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc
index 4f0a258..7b0f402 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc
@@ -30,8 +30,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
index 71ab528..de7861c 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
@@ -31,10 +31,10 @@
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index 82c23d2..e4d3dc6 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -23,11 +23,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -527,7 +527,7 @@
   // In case of failure, the `diag_handler` converts MLIR errors emitted to
   // the MLIRContext into a tensorflow::Status.
   StatusScopedDiagnosticHandler diag_handler(func_.getContext());
-  LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>());
+  LogicalResult result = pm.run(func_->getParentOfType<ModuleOp>());
   (void)result;
   TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());
 
@@ -668,7 +668,7 @@
 
   auto arg_types = body.getArgumentTypes();
   auto result_types = body.getTerminator()->getOperandTypes();
-  func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
+  func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
   *f = new MlirFunction(std::move(context_), std::move(module_), func_);
   return Status::OK();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
index c00b974..d93cb3c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
@@ -20,8 +20,8 @@
 
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 
 namespace mlir {
 namespace TF {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
index ca7f7a7..9498420 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
@@ -22,18 +22,20 @@
 #include <utility>
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/SMLoc.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
@@ -512,22 +514,13 @@
 
 void BuildReplicateOp(
     Builder* builder, OperationState* state, int n,
-    const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
-        devices,
+    llvm::Optional<DictionaryAttr> devices,
     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
     ValueRange packed_inputs, TypeRange replica_output_types) {
   DCHECK_GE(n, 2);
   state->addAttribute("n", builder->getI32IntegerAttr(n));
 
-  llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
-  device_list.reserve(devices.size());
-  for (auto alias_and_devices : devices) {
-    NamedAttribute device_name_attr = builder->getNamedAttr(
-        alias_and_devices.getFirst(),
-        builder->getStrArrayAttr(alias_and_devices.getSecond()));
-    device_list.emplace_back(device_name_attr);
-  }
-  state->addAttribute("devices", builder->getDictionaryAttr(device_list));
+  if (devices.hasValue()) state->addAttribute("devices", devices.getValue());
 
   Region* region = state->addRegion();
   region->push_back(new Block);
@@ -567,6 +560,28 @@
         devices,
     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
     ValueRange packed_inputs, TypeRange replica_output_types) {
+  llvm::Optional<DictionaryAttr> devices_attr;
+  if (!devices.empty()) {
+    llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
+    device_list.reserve(devices.size());
+    for (auto alias_and_devices : devices) {
+      NamedAttribute device_name_attr = builder.getNamedAttr(
+          alias_and_devices.getFirst(),
+          builder.getStrArrayAttr(alias_and_devices.getSecond()));
+      device_list.emplace_back(device_name_attr);
+    }
+    devices_attr.emplace(builder.getDictionaryAttr(device_list));
+  }
+
+  BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
+                   packed_inputs, replica_output_types);
+}
+
+void ReplicateOp::build(
+    OpBuilder& builder, OperationState& state, int n,
+    llvm::Optional<DictionaryAttr> devices,
+    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
+    ValueRange packed_inputs, TypeRange replica_output_types) {
   BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
                    packed_inputs, replica_output_types);
 }
@@ -630,6 +645,10 @@
   return getOperand(operand_index);
 }
 
+// Checks if a tf_device.replicate wraps a single operation and the single
+// operation results are perfectly forwarded to the replicate return.
+bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
+
 //===----------------------------------------------------------------------===//
 // Canonicalization patterns
 //===----------------------------------------------------------------------===//
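A hedged usage sketch of the new ReplicateOp builder overload that accepts the devices dictionary directly; the alias string, device names, and surrounding locals are illustrative only.

    llvm::Optional<DictionaryAttr> devices = builder.getDictionaryAttr(
        {builder.getNamedAttr("TPU_REPLICATED_CORE_0",
                              builder.getStrArrayAttr(
                                  {"/device:TPU:0", "/device:TPU:1"}))});
    auto replicate = builder.create<tf_device::ReplicateOp>(
        loc, /*n=*/2, devices, replicated_inputs, packed_inputs,
        replica_output_types);
    // Passing llvm::None instead omits the "devices" attribute entirely.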
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
index e9c2a05..92d3815 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
@@ -289,6 +289,7 @@
     bool IsPackedBlockArgument(BlockArgument block_arg);
     unsigned GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg, unsigned replica);
     Value GetReplicaOperandForBlockArgument(BlockArgument block_arg, unsigned replica);
+    bool WrapsSingleOp();
   }];
 
   let builders = [
@@ -296,6 +297,9 @@
       "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&":$devices,
       "llvm::ArrayRef<std::pair<ValueRange, Type>>":$replicated_inputs,
       "ValueRange":$packed_inputs, "TypeRange":$replica_output_types)>,
+    OpBuilderDAG<(ins "int":$n, "llvm::Optional<DictionaryAttr>":$devices,
+      "llvm::ArrayRef<std::pair<ValueRange, Type>>":$replicated_inputs,
+      "ValueRange":$packed_inputs, "TypeRange":$replica_output_types)>,
   ];
 
   let parser = [{ return Parse$cppClass(&parser, &result); }];
@@ -379,4 +383,41 @@
   }];
 }
 
+def TfDevice_SendOp : TfDevice_Op<"send", []> {
+  let summary = "Send a value to a host.";
+
+  let description = [{
+    Send the value to the given host with the given rendezvous key.
+  }];
+
+  let arguments = (ins
+    AnyType:$value,
+    StrAttr:$key,
+    StrAttr:$dst_host
+  );
+
+  let results = (outs);
+
+  let assemblyFormat = [{$value $key $dst_host attr-dict `:` type($value)}];
+}
+
+def TfDevice_ReceiveOp : TfDevice_Op<"receive", []> {
+  let summary = "Rceive a value from a host.";
+
+  let description = [{
+    Receive a value from the given host with the given rendezvous key.
+  }];
+
+  let arguments = (ins
+    StrAttr:$key,
+    StrAttr:$src_host
+  );
+
+  let results = (outs
+    AnyType:$result
+  );
+
+  let assemblyFormat = [{$key $src_host attr-dict `:` type($result)}];
+}
+
 #endif // TF_DEVICE_DIALECT
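A sketch of the textual forms implied by the declared assemblyFormat strings for the new send/receive ops; the rendezvous keys, host strings, and types are hypothetical.

    // tf_device.send %value "key_0" "/job:worker/replica:0/task:1" : tensor<i32>
    // %received = tf_device.receive "key_0" "/job:worker/replica:0/task:0" : tensor<i32>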
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index ff0bfb5..e24c83e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -31,13 +31,13 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -163,7 +163,7 @@
 namespace {
 
 LogicalResult Verify(GraphOp graph) {
-  auto *executorDialect = graph.getDialect();
+  auto *executorDialect = graph->getDialect();
 
   if (graph.GetBody().empty())
     return graph.emitOpError() << "expects a non-empty body";
@@ -461,7 +461,7 @@
 namespace {
 
 LogicalResult Verify(SwitchNOp switchn) {
-  IntegerAttr num_outs = switchn.getAttrOfType<IntegerAttr>("num_outs");
+  IntegerAttr num_outs = switchn->getAttrOfType<IntegerAttr>("num_outs");
   if (!num_outs)
     return switchn.emitOpError() << "expects a `num_outs` integer attribute";
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h
index 2bc1355..4354736 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h
@@ -24,10 +24,10 @@
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index b0cc6f9..752cf0c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -870,6 +870,10 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
@@ -1907,7 +1911,7 @@
   TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>;
 }
 
-def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
+def TF_Conv2DOp : TF_Op<"Conv2D", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect, TF_LayoutSensitiveInterface]> {
   let summary = [{
 Computes a 2-D convolution given 4-D `input` and `filter` tensors.
   }];
@@ -1964,6 +1968,10 @@
     SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
     StringRef GetOptimalLayout(const RuntimeDevices& devices);
     LogicalResult UpdateDataFormat(StringRef data_format);
+    // InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
+      return ArraysAreCastCompatible(l, r);
+    }
   }];
 }
 
@@ -2037,7 +2045,7 @@
   }];
 }
 
-def TF_Conv3DOp : TF_Op<"Conv3D", [NoSideEffect]> {
+def TF_Conv3DOp : TF_Op<"Conv3D", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
   let summary = [{
 Computes a 3-D convolution given 5-D `input` and `filter` tensors.
   }];
@@ -2069,6 +2077,14 @@
   let verifier = [{
     return Verify(*this);
   }];
+
+  let extraClassDeclaration = [{
+    // InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
+      return ArraysAreCastCompatible(l, r);
+    }
+  }];
+
 }
 
 def TF_Conv3DBackpropFilterV2Op : TF_Op<"Conv3DBackpropFilterV2", [NoSideEffect]> {
@@ -3392,7 +3408,24 @@
   let hasFolder = 1;
 }
 
-def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize]> {
+def TF_EnqueueTPUEmbeddingIntegerBatchOp : TF_Op<"EnqueueTPUEmbeddingIntegerBatch", [TF_TPUEmbeddingSideEffect]> {
+  let summary = [{
+An op that enqueues a list of input batch tensors to TPUEmbedding.
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Int32Tensor>:$batch,
+    TF_StrTensor:$mode_override,
+
+    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+}
+
+def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
   let summary = "Eases the porting of code that uses tf.nn.embedding_lookup().";
 
   let description = [{
@@ -3426,7 +3459,42 @@
   TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
 }
 
-def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize]> {
+def TF_EnqueueTPUEmbeddingSparseBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
+  let summary = [{
+An op that enqueues TPUEmbedding input indices from a SparseTensor.
+  }];
+
+  let description = [{
+This Op eases the porting of code that uses embedding_lookup_sparse(),
+although some Python preprocessing of the SparseTensor arguments to
+embedding_lookup_sparse() is required to produce the arguments to this Op,
+since only a single EnqueueTPUEmbeddingSparseBatch Op is allowed per training
+step.
+
+The tensors at corresponding positions in the three input lists
+must have the same shape, i.e. rank 1 with dim_size() equal to the total
+number of lookups into the table described by the corresponding table_id.
+  }];
+
+  let arguments = (ins
+    Variadic<TF_I32OrI64Tensor>:$sample_indices,
+    Variadic<TF_I32OrI64Tensor>:$embedding_indices,
+    Variadic<TF_F32OrF64Tensor>:$aggregation_weights,
+    TF_StrTensor:$mode_override,
+
+    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
+    DefaultValuedAttr<StrArrayAttr, "{}">:$combiners
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
+  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+}
+
+def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
   let summary = [{
 Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
   }];
@@ -5024,6 +5092,26 @@
   TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_InplaceAddOp : TF_Op<"InplaceAdd", [AllTypesMatch<["x", "y"]>, NoSideEffect]> {
+  let summary = "Adds v into specified rows of x.";
+
+  let description = [{
+Computes y = x; y[i, :] += v; return y.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$x,
+    TF_Int32Tensor:$i,
+    TF_Tensor:$v
+  );
+
+  let results = (outs
+    TF_Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
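For intuition, a minimal NumPy sketch of the InplaceAdd semantics above (`y = x; y[i, :] += v; return y`), assuming distinct row indices; this is only an illustration of the documented behavior, not the kernel implementation:

```python
import numpy as np

def inplace_add_reference(x, i, v):
    """Reference semantics of InplaceAdd: copy x, then add v into rows i."""
    y = np.array(x, copy=True)
    y[np.asarray(i)] += np.asarray(v)   # y[i, :] += v for distinct indices
    return y

x = np.zeros((3, 2), dtype=np.float32)
v = np.ones((2, 2), dtype=np.float32)
print(inplace_add_reference(x, [0, 2], v))  # rows 0 and 2 become ones
```
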
 def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
   let summary = "Updates specified rows 'i' with values 'v'.";
 
@@ -5370,6 +5458,37 @@
   );
 }
 
+def TF_KthOrderStatisticOp : TF_Op<"KthOrderStatistic", [NoSideEffect]> {
+  let summary = "Computes the Kth order statistic of a data set. The current";
+
+  let description = [{
+implementation uses a binary search requiring exactly 32 passes over
+the input data. The running time is linear with respect to input
+size. The median-of-medians algorithm is probably faster, but is
+difficult to implement efficiently in XLA. The implementation imposes
+a total ordering on floats. The ordering is consistent with the usual
+partial order.  Positive NaNs are greater than positive
+infinity. Negative NaNs are less than negative infinity. NaNs with
+distinct payloads are treated as distinct. Subnormal numbers are
+preserved (not flushed to zero). Positive infinity is greater than all
+numbers. Negative infinity is less than all numbers. Positive zero is
+greater than negative zero. There are fewer than k values greater than
+the kth order statistic. There are at least k values greater than or
+equal to the kth order statistic. The semantics are not the same as
+top_k_unique.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input,
+
+    I64Attr:$k
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$output
+  );
+}
+
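To make the ordering property above concrete (fewer than k values strictly greater than the result, at least k values greater than or equal to it), here is a plain NumPy sketch of a k-th largest lookup; it deliberately ignores the NaN/subnormal total-ordering rules and the binary-search strategy described in the text, and the 1-based `k` is an assumption made only for this illustration:

```python
import numpy as np

def kth_largest_reference(values, k):
    """k-th largest element of a 1-D float array (1-based k), illustrative only."""
    ordered = np.sort(np.asarray(values, dtype=np.float32))[::-1]  # descending
    return ordered[k - 1]

x = [3.0, 1.0, 4.0, 1.5]
print(kth_largest_reference(x, k=2))  # 3.0: exactly one value (4.0) is strictly greater
```
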
 def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
   let summary = "L2 Loss.";
 
@@ -6501,6 +6620,27 @@
   let results = (outs);
 }
 
+def TF_MakeUniqueOp : TF_Op<"MakeUnique", [NoSideEffect]> {
+  let summary = [{
+Make all elements in the non-Batch dimension unique, but \"close\" to
+  }];
+
+  let description = [{
+their initial value. Never returns a sub-normal number. Never returns
+zero. The sign of each input element is always identical to the sign
+of the corresponding output element. Behavior for infinite elements is
+undefined. Behavior for subnormal elements is undefined.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$output
+  );
+}
+
 def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
   let summary = [{
 Multiply the matrix "a" by the matrix "b".
@@ -8164,7 +8304,7 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType]> {
+def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType, TF_CwiseUnary]> {
   let summary = "Computes numerical negative value element-wise.";
 
   let description = [{
@@ -9607,7 +9747,7 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", []> {
+def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", [TF_TPUEmbeddingSideEffect]> {
   let summary = "An op that receives embedding activations on the TPU.";
 
   let description = [{
@@ -9630,6 +9770,48 @@
   TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
 }
 
+def TF_ReduceJoinOp : TF_Op<"ReduceJoin", [NoSideEffect]> {
+  let summary = "Joins a string Tensor across the given dimensions.";
+
+  let description = [{
+Computes the string join across dimensions in the given string Tensor of shape
+`[\\(d_0, d_1, ..., d_{n-1}\\)]`.  Returns a new Tensor created by joining the input
+strings with the given separator (default: empty string).  Negative indices are
+counted backwards from the end, with `-1` being equivalent to `n - 1`.  If
+indices are not specified, joins across all dimensions beginning from `n - 1`
+through `0`.
+
+For example:
+
+```python
+# tensor `a` is [["a", "b"], ["c", "d"]]
+tf.reduce_join(a, 0) ==> ["ac", "bd"]
+tf.reduce_join(a, 1) ==> ["ab", "cd"]
+tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"]
+tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"]
+tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]]
+tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]]
+tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"]
+tf.reduce_join(a, [0, 1]) ==> "acbd"
+tf.reduce_join(a, [1, 0]) ==> "abcd"
+tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]]
+tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd"
+```
+  }];
+
+  let arguments = (ins
+    TF_StrTensor:$inputs,
+    TF_Int32Tensor:$reduction_indices,
+
+    DefaultValuedAttr<BoolAttr, "false">:$keep_dims,
+    StrAttr:$separator
+  );
+
+  let results = (outs
+    TF_StrTensor:$output
+  );
+}
+
 def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
   let summary = "Computes rectified linear: `max(features, 0)`.";
 
@@ -11569,6 +11751,29 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_RiscAddOp : TF_Op<"RiscAdd", [Commutative, NoSideEffect]> {
+  let summary = "Returns x + y element-wise.";
+
+  let description = [{
+*NOTE*: `RiscAdd` does not support broadcasting.
+
+Given two input tensors, the `tf.risc_add` operation computes the sum for every element in the tensor.
+
+Both input and output have a range `(-inf, inf)`.
+  }];
+
+  let arguments = (ins
+    TF_FloatTensor:$x,
+    TF_FloatTensor:$y
+  );
+
+  let results = (outs
+    TF_FloatTensor:$z
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
   let summary = "Rolls the elements of a tensor along an axis.";
 
@@ -15230,6 +15435,36 @@
   let hasFolder = 1;
 }
 
+def TF_TopKUniqueOp : TF_Op<"TopKUnique", [NoSideEffect]> {
+  let summary = "Returns the TopK unique values in the array in sorted order.";
+
+  let description = [{
+The running time is proportional to the product of K and the input
+size. Sorting the whole array is more efficient for sufficiently large
+values of K. The median-of-medians algorithm is probably faster, but
+difficult to implement efficiently in XLA. If there are fewer than K
+unique numbers (not NaNs), the results are padded with negative
+infinity. NaNs are never returned. Subnormal numbers are flushed to
+zero. If an element appears at multiple indices, the highest index is
+returned. If a TopK element never appears in the input due to padding
+values, the indices are padded with negative one. If a padding value
+appears in the input and padding is needed, the highest index of the
+padding value will be returned. The semantics are not the same as
+kth_order_statistic.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input,
+
+    I64Attr:$k
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$topk,
+    TF_Int32Tensor:$topk_indices
+  );
+}
+
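The padding and tie-breaking rules in the description above can be sketched as follows (NumPy, illustrative only; subnormal flushing and the float total ordering are omitted):

```python
import numpy as np

def top_k_unique_reference(values, k):
    """Unique values in descending order padded with -inf; for each value the
    highest index at which it occurs, padded with -1."""
    values = np.asarray(values, dtype=np.float32)
    uniq = np.unique(values[~np.isnan(values)])[::-1]      # NaNs never returned
    topk = np.full(k, -np.inf, dtype=np.float32)
    indices = np.full(k, -1, dtype=np.int32)
    for pos, val in enumerate(uniq[:k]):
        topk[pos] = val
        indices[pos] = int(np.where(values == val)[0].max())  # highest index wins
    return topk, indices

print(top_k_unique_reference([1.0, 3.0, 3.0, 2.0], k=3))
# (array([3., 2., 1.], dtype=float32), array([2, 3, 0], dtype=int32))
```
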
 def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
   let summary = [{
 Finds values and indices of the `k` largest elements for the last dimension.
@@ -15265,6 +15500,29 @@
   let verifier = [{ return Verify(*this); }];
 }
 
+def TF_TopKWithUniqueOp : TF_Op<"TopKWithUnique", [NoSideEffect]> {
+  let summary = "Returns the TopK values in the array in sorted order.";
+
+  let description = [{
+This is a combination of MakeUnique and TopKUnique. The returned top-K will
+have its lower bits replaced by iota, thus it will be close to the original
+value but not exactly the same. The running time is proportional to the product
+of K and the input size. NaNs are never returned. Subnormal numbers are flushed
+to zero.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input,
+
+    I64Attr:$k
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$topk,
+    TF_Int32Tensor:$topk_indices
+  );
+}
+
 def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
   let summary = "Shuffle dimensions of x according to a permutation.";
 
@@ -15476,6 +15734,8 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let verifier = [{ return Verify(*this); }];
+
+  let hasCanonicalizer = 1;
 }
 
 def TF_UnsortedSegmentMaxOp : TF_Op<"UnsortedSegmentMax", [NoSideEffect]> {
@@ -16344,7 +16604,7 @@
   TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_XlaSetDynamicDimensionSizeOp : TF_Op<"XlaSetDynamicDimensionSize", [NoSideEffect]> {
+def TF_XlaSetDynamicDimensionSizeOp : TF_Op<"XlaSetDynamicDimensionSize", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect, TF_NoConstantFold]> {
   let summary = "Make a static dimension into a xla bounded dynamic dimension.";
 
   let description = [{
@@ -16623,7 +16883,7 @@
   TF_DerivedResultSizeAttr N = TF_DerivedResultSizeAttr<0>;
 }
 
-def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> {
+def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", [TF_TPUEmbeddingSideEffect]> {
   let summary = "An op that receives embeddng activations on the TPU.";
 
   let description = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index a72d591..d4c05e8 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -65,6 +65,12 @@
 def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
   "TF::SameOperandsAndResultElementTypeResolveRef">;
 
+// Op has the same operand and result types after resolving reference types
+// (i.e., after converting reference types to their corresponding TensorFlow or
+// standard types).
+def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
+  "TF::SameOperandsAndResultTypeResolveRef">;
+
 // Layout agnostic operations do not depend on the operands data layout (data
 // format), as an example all element wise operations are layout agnostic.
 def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
@@ -331,8 +337,7 @@
 def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
 
 def TF_Float : AnyTypeOf<
-  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16,
-   TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref],
+  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
   "floating-point">;
 
 // Tensor types
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index f673484..08ec87e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -42,6 +42,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
@@ -51,7 +52,6 @@
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
index e5da007..ebb68d7 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
@@ -23,10 +23,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index c3214d9..4615064 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -878,9 +878,27 @@
   }];
 }
 
+def TF_EnqueueTPUEmbeddingBatchOp : TF_Op<"EnqueueTPUEmbeddingBatch", [TF_TPUEmbeddingSideEffect]> {
+  let summary = [{
+An op that enqueues a list of input batch tensors to TPUEmbedding.
+  }];
+
+  let arguments = (ins
+    Variadic<TF_StrTensor>:$batch,
+    TF_StrTensor:$mode_override,
+
+    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
+    DefaultValuedAttr<StrArrayAttr, "{}">:$combiners
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+}
+
 // Multiple variadic operands with different sizes are not supported by the
 // dialect generator, so we manually added the op.
-def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> {
+def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
   let summary = "Performs gradient updates of embedding tables.";
 
   let description = [{
@@ -911,7 +929,7 @@
 
 // Multiple variadic operands with different sizes are not supported by the
 // dialect generator, so we manually added the op.
-def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> {
+def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
   let summary = "Performs gradient updates of embedding tables.";
 
   let description = [{
@@ -2024,5 +2042,48 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [NoSideEffect]> {
+  let summary = [{
+An op that groups a list of partitioned inputs together. This op
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor>:$inputs,
+
+    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
+    OptionalAttr<StrAttr>:$_XlaSharding
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+}
+
+def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
+  let summary = [{
+An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
+  }];
+
+  let description = [{
+outputs outside the XLA computation.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$inputs,
+
+    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
+    OptionalAttr<StrAttr>:$_XlaSharding
+  );
+
+  let results = (outs
+    Variadic<TF_Tensor>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
+}
 
 #endif // TF_OPS
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 8271e30..9b8c186 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -44,6 +44,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
@@ -53,7 +54,6 @@
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -67,7 +67,9 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
+#include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace mlir {
@@ -157,19 +159,13 @@
 }
 
 //===----------------------------------------------------------------------===//
-// BatchMatMulOp
+// BatchMatMulV2Op & BatchMatMulOp
 //===----------------------------------------------------------------------===//
 
-void BatchMatMulOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<BatchMatMulToMatMul>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// BatchMatMulV2Op
-//===----------------------------------------------------------------------===//
-
-static LogicalResult Verify(BatchMatMulV2Op op) {
+template <typename OpT,
+          typename std::enable_if<llvm::is_one_of<
+              OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
+static LogicalResult Verify(OpT op) {
   if (!HasRankAtLeast(op.x(), 2)) {
     return op.emitOpError("requires lhs operand to have rank at least two");
   }
@@ -185,17 +181,34 @@
   ArrayRef<int64_t> x_shape = x_ty.getShape();
   ArrayRef<int64_t> y_shape = y_ty.getShape();
 
-  // Check broadcast compatibility if both input shapes are known.
+  llvm::SmallVector<int64_t, 4> result_batch_shape;
+  llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
+  llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
+
+  // Check compatibility of batch dimensions if both input shapes are known.
+  // BatchMatMul should have exactly the same batch dimensions and
+  // BatchMatMulV2 should have broadcastable batch dimensions.
   //
   // The last two dimensions are non-batch dimensions that don't need to
   // participate in batch dimension compatibility check.
-
-  llvm::SmallVector<int64_t, 4> result_batch_shape;
-  if (!OpTrait::util::getBroadcastedShape(
-          x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape))
-    return op.emitOpError()
-           << "found incompatible broadcast batch dimensions for lhs shape "
-           << x_ty << " and rhs shape " << y_ty;
+  if (std::is_same<OpT, BatchMatMulOp>()) {
+    for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
+      int64_t x_dim = std::get<0>(dim_pairs);
+      int64_t y_dim = std::get<1>(dim_pairs);
+      if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
+          x_dim != y_dim) {
+        return op.emitOpError()
+               << "found mismatching batch dimensions for lhs shape " << x_ty
+               << " and rhs shape " << y_ty;
+      }
+    }
+  } else {
+    if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
+                                            result_batch_shape))
+      return op.emitOpError()
+             << "found incompatible broadcast batch dimensions for lhs shape "
+             << x_ty << " and rhs shape " << y_ty;
+  }
 
   RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
   if (!output_ty) return success();
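Outside the diff, a small sketch of the two batch-dimension rules the templated `Verify` above distinguishes: `BatchMatMul` wants equal (or dynamic) batch dimensions, while `BatchMatMulV2` only needs broadcast-compatible ones. The helper names and the `-1` dynamic marker are illustrative, not the MLIR utilities:

```python
DYNAMIC = -1  # stand-in for a dynamic dimension

def batch_dims_match(x_batches, y_batches):
    """BatchMatMul rule: paired batch dims must be equal unless one is dynamic."""
    return all(xd == yd or DYNAMIC in (xd, yd)
               for xd, yd in zip(x_batches, y_batches))

def batch_dims_broadcastable(x_batches, y_batches):
    """BatchMatMulV2 rule: batch dims must be broadcast-compatible (1 broadcasts)."""
    x, y = list(x_batches)[::-1], list(y_batches)[::-1]
    for i in range(max(len(x), len(y))):
        xd = x[i] if i < len(x) else 1
        yd = y[i] if i < len(y) else 1
        if xd != yd and 1 not in (xd, yd) and DYNAMIC not in (xd, yd):
            return False
    return True

print(batch_dims_match([2, 3], [1, 3]))          # False: BatchMatMul rejects this
print(batch_dims_broadcastable([2, 3], [1, 3]))  # True:  BatchMatMulV2 broadcasts dim 1
```
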
@@ -245,6 +258,11 @@
   return success();
 }
 
+void BatchMatMulOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<BatchMatMulToV2>(context);
+}
+
 void BatchMatMulV2Op::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<BatchMatMulV2ToMatMul>(context);
@@ -594,6 +612,12 @@
                                             ArrayRef<int64_t> s1_shape,
                                             SmallVectorImpl<int64_t> &r0,
                                             SmallVectorImpl<int64_t> &r1) {
+  r0.clear();
+  r1.clear();
+
+  // No broadcasting is required if both shapes are equal.
+  if (s0_shape == s1_shape) return;
+
   for (int i = bcasted_shape.size(); i > 0; --i) {
     int idx = bcasted_shape.size() - i;
     int s0_idx = i > s0_shape.size() ? -1 : s0_shape.size() - i;
@@ -609,6 +633,15 @@
         r0.push_back(idx);
       else
         r1.push_back(idx);
+    } else if (s0_shape[s0_idx] == 1) {
+      // This op is used to compute the gradient dimensions requiring reduction
+      // to match the input dimensions. In case both the dimensions are one,
+      // reducing the dimension has no effect. We choose to reduce such
+      // dimensions to match the TensorFlow kernel behavior. However, note that
+      // the TF behavior in this case is inconsistent with the case with the
+      // same shapes.
+      r0.push_back(idx);
+      r1.push_back(idx);
     }
   }
 }
@@ -634,12 +667,14 @@
   GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
                                          r1);
 
-  RankedTensorType r0_ty = GetRankedTensorTypeForOperand(op.r0());
-  RankedTensorType r1_ty = GetRankedTensorTypeForOperand(op.r1());
-  if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getShape()[0] != r0.size())
+  // Verify that output types are of rank one and matches the computed result
+  // shape.
+  auto r0_ty = op.r0().getType().dyn_cast<RankedTensorType>();
+  auto r1_ty = op.r1().getType().dyn_cast<RankedTensorType>();
+  if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size())
     return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
                             << r0.size() << " but got " << r0_ty.getShape()[0];
-  if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getShape()[0] != r1.size())
+  if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getDimSize(0) != r1.size())
     return op.emitOpError() << "requires dimension 0 size of 'r1' to be "
                             << r1.size() << " but got " << r1_ty.getShape()[0];
 
@@ -1275,7 +1310,7 @@
   results.reserve(shapes.size());
   SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
   RankedTensorType offset_type =
-      RankedTensorType::get({num_dims}, IntegerType::get(32, getContext()));
+      RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
   for (DenseIntElementsAttr shape : shapes) {
     results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
     cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);
@@ -1350,80 +1385,38 @@
 // Conv2DOp and Conv3DOp
 //===----------------------------------------------------------------------===//
 
-template <typename OpT>
-static LogicalResult VerifyConvOpAttributes(OpT op, int num_dims) {
-  if (!IsOfRankOrUnranked(op.getResult(), num_dims))
-    return op.emitOpError()
-           << "requires result to be " << num_dims << "D tensor";
-
+static LogicalResult VerifyConvOpAttributes(
+    int num_dims, ArrayRef<Attribute> strides, ArrayRef<Attribute> dilations,
+    llvm::Optional<mlir::Location> location) {
+  int64_t strides_size = strides.size();
+  if (strides_size != num_dims)
+    return emitOptionalError(
+        location, "requires strides attribute length to be ", num_dims);
   auto is_not_positive = [](Attribute val) {
     return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
   };
+  if (llvm::any_of(strides, is_not_positive))
+    return emitOptionalError(location, "requires positive strides");
 
-  int64_t strides_size = op.strides().size();
-  if (strides_size != num_dims)
-    return op.emitOpError() << "requires strides attribute length to be "
-                            << num_dims << "; actual length " << strides_size;
-  if (llvm::any_of(op.strides().getValue(), is_not_positive))
-    return op.emitOpError("requires positive strides");
-
-  int64_t dilations_size = op.strides().size();
-  if (op.dilations().size() != num_dims)
-    return op.emitOpError() << "requires dilations attribute length to be "
-                            << num_dims << "; actual length " << dilations_size;
-  if (llvm::any_of(op.dilations().getValue(), is_not_positive))
-    return op.emitOpError("requires positive dilations");
+  int64_t dilations_size = dilations.size();
+  if (dilations_size != num_dims)
+    return emitOptionalError(
+        location, "requires dilations attribute length to be ", num_dims);
+  if (llvm::any_of(dilations, is_not_positive))
+    return emitOptionalError(location, "requires positive dilations");
 
   return success();
 }
 
 // Verifies that,
-// * Ranks of operands and result are valid
 // * Number of input channels is divisible by the number of filter input
 //   channels
-// * Length of explicit_paddings attribute is valid and has non negative
-//   elements
-// * strides and dilations attributes have positive elements
 template <typename OpT, typename std::enable_if<llvm::is_one_of<
                             OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr>
 static LogicalResult Verify(OpT op) {
   int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3;
   int num_dims = 2 + num_spatial_dims;
 
-  if (!IsOfRankOrUnranked(op.input(), num_dims) ||
-      !IsOfRankOrUnranked(op.filter(), num_dims))
-    return op.emitOpError()
-           << "requires operands to be " << num_dims << "D tensor";
-
-  // EXPLICIT padding mode and the associated attribute is limited to Conv2D.
-  // So, fetch attribute by string instead of the op.explicit_paddings()
-  // attribute getter.
-  if (op.padding() == "EXPLICIT") {
-    auto paddings = op.template getAttrOfType<ArrayAttr>("explicit_paddings");
-    if (!paddings)
-      return op.emitOpError() << "requires attribute 'explicit_paddings' with "
-                                 "'EXPLICIT' padding mode";
-
-    int64_t paddings_size = paddings.size();
-    int64_t expected_size = 2 * num_dims;
-
-    if (paddings_size != expected_size)
-      return op.emitOpError()
-             << "requires explicit_paddings attribute length to be "
-             << expected_size << "; actual length " << paddings_size;
-
-    auto is_negative = [](Attribute val) {
-      return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
-    };
-    if (llvm::any_of(paddings.getValue(), is_negative))
-      return op.emitOpError("requires non negative explicit paddings");
-  }
-
-  LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims);
-  if (failed(verify_result)) {
-    return verify_result;
-  }
-
   int64_t input_channels = -1;
   if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
     absl::string_view data_format(op.data_format().data(),
@@ -1460,13 +1453,143 @@
   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
 
   // Update convolution attributes.
-  setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
-  setAttr("strides", ShuffleArrayAttr(strides(), perm));
-  setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
+  (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
+  (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
+  (*this)->setAttr("explicit_paddings",
+                   ShuffleArrayAttr(explicit_paddings(), perm, 2));
 
   return success();
 }
 
+// Verifies the inferred return type of the given operation.
+template <typename OpT,
+          typename std::enable_if<llvm::is_one_of<
+              OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type * = nullptr>
+static LogicalResult inferConvReturnTypes(
+    OpT op, llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes,
+    llvm::Optional<mlir::Location> location,
+    ArrayRef<Attribute> explicit_padding) {
+  const int64_t num_spatial_dims = std::is_same<OpT, Conv2DOpAdaptor>() ? 2 : 3;
+  const int64_t num_dims = 2 + num_spatial_dims;
+  const Value input = op.input();
+  const Value filter = op.filter();
+  const TensorType input_ty = input.getType().template cast<TensorType>();
+  const TensorType filter_ty = filter.getType().template cast<TensorType>();
+  const StringRef paddings = op.padding().getValue();
+
+  ArrayRef<Attribute> strides = op.strides().getValue();
+  StringRef data_format = op.data_format().getValue();
+  ArrayRef<Attribute> dilations = op.dilations().getValue();
+
+  tensorflow::TensorFormat format;
+  auto data_format_is_valid = FormatFromString(data_format.str(), &format);
+  if (!data_format_is_valid) {
+    return emitOptionalError(location, "Invalid data format provided");
+  }
+  tensorflow::Padding padding;
+  auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
+  if (!padding_is_valid.ok()) {
+    return emitOptionalError(location, "Invalid padding format provided");
+  }
+  auto get_int = [](Attribute attr) {
+    return attr.template cast<IntegerAttr>().getInt();
+  };
+
+  // Necessary sanity checks.
+  // Verifies that,
+  // * Ranks of operands and result are valid
+  // * Length of explicit_paddings attribute is valid and has non negative
+  //   elements
+  // * strides and dilations attributes have positive elements
+  if (!IsOfRankOrUnranked(input, num_dims) ||
+      !IsOfRankOrUnranked(filter, num_dims))
+    return emitOptionalError(location, "requires operands to be ", num_dims,
+                             "D tensor");
+
+  if (padding == tensorflow::Padding::EXPLICIT) {
+    if (explicit_padding.size() == 0) {
+      return emitOptionalError(location,
+                               "requires attribute 'explicit_paddings' with "
+                               "'EXPLICIT' padding mode");
+    }
+    if (explicit_padding.size() != num_dims * 2) {
+      return emitOptionalError(
+          location, "requires explicit_paddings attribute length to be ",
+          num_dims * 2);
+    }
+    auto is_negative = [](Attribute val) {
+      return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
+    };
+    if (llvm::any_of(explicit_padding, is_negative))
+      return emitOptionalError(location,
+                               "requires non negative explicit paddings");
+  }
+
+  if (failed(VerifyConvOpAttributes(num_dims, strides, dilations, location))) {
+    return failure();
+  }
+
+  // For operands having dynamic shape.
+  SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
+  if (!input_ty.hasStaticShape() || !filter_ty.hasStaticShape()) {
+    inferredReturnTypes.assign(
+        {RankedTensorType::get(return_shape, input_ty.getElementType())});
+    return success();
+  }
+
+  // Checks the size of each of the output dimension.
+  for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
+    const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
+    int64_t stride = get_int(strides[dim]);
+    tensorflow::int64 expected_output_size;
+    tensorflow::int64 pad_low;
+    tensorflow::int64 pad_high;
+    // Retrieve padding, if defined explicitly.
+    if (padding == tensorflow::Padding::EXPLICIT) {
+      pad_low = get_int(explicit_padding[2 * dim]);
+      pad_high = get_int(explicit_padding[2 * dim + 1]);
+    }
+    // Calculate the expected_output_size.
+    tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
+        input_ty.getDimSize(dim), filter_ty.getDimSize(i),
+        get_int(dilations[dim]), stride, padding, &expected_output_size,
+        &pad_low, &pad_high);
+    // Return failure if expected_output_size could not be calculated.
+    if (!status.ok()) return failure();
+    return_shape[dim] = expected_output_size;
+  }
+
+  // The remaining dimensions can be obtained using utilities from
+  // tensorflow/core/util/tensor_format.h.
+  return_shape[GetTensorBatchDimIndex(num_dims, format)] =
+      input_ty.getShape()[GetTensorBatchDimIndex(num_dims, format)];
+  return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
+      filter_ty.getShape()[GetFilterTensorOutputChannelsDimIndex(
+          num_dims, tensorflow::FORMAT_HWIO)];
+
+  inferredReturnTypes.assign(
+      {RankedTensorType::get(return_shape, input_ty.getElementType())});
+  return success();
+}
+
+LogicalResult Conv2DOp::inferReturnTypes(
+    mlir::MLIRContext *context, llvm::Optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
+  Conv2DOpAdaptor op(operands, attributes);
+  ArrayRef<Attribute> explicit_padding;
+  ArrayAttr explicit_pad =
+      attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
+  if (!explicit_pad) {
+    explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
+  }
+  explicit_padding = explicit_pad.getValue();
+
+  return inferConvReturnTypes(op, inferredReturnTypes, location,
+                              explicit_padding);
+}
+
 StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
   // Keep current data format if no GPUs are available or if explicit placement
   // does not allow to use GPU for this operation.
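For reference, the windowed-output-size rule that `GetWindowedOutputSizeVerboseV2` applies per spatial dimension can be sketched as below for the SAME/VALID cases (the real helper also returns the low/high padding, and EXPLICIT padding supplies those values directly); this is a standard-formula illustration, not the TensorFlow implementation:

```python
import math

def conv_output_size(input_size, filter_size, stride, dilation, padding):
    """Spatial output size for one convolution dimension."""
    effective_filter = (filter_size - 1) * dilation + 1
    if padding == "VALID":
        return math.ceil((input_size - effective_filter + 1) / stride)
    if padding == "SAME":
        return math.ceil(input_size / stride)
    raise ValueError("EXPLICIT padding also needs pad_low/pad_high")

print(conv_output_size(224, 3, stride=2, dilation=1, padding="SAME"))   # 112
print(conv_output_size(224, 3, stride=2, dilation=1, padding="VALID"))  # 111
```
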
@@ -1534,9 +1657,10 @@
   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
 
   // Update convolution attributes.
-  setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
-  setAttr("strides", ShuffleArrayAttr(strides(), perm));
-  setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
+  (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
+  (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
+  (*this)->setAttr("explicit_paddings",
+                   ShuffleArrayAttr(explicit_paddings(), perm, 2));
 
   // Permute filter sizes operand.
   OpBuilder builder(getOperation());
@@ -1580,8 +1704,15 @@
       !IsOfRankOrUnranked(op.filter(), num_dims))
     return op.emitOpError()
            << "requires operands to be " << num_dims << "D tensor";
+  if (!IsOfRankOrUnranked(op.getResult(), num_dims))
+    return op.emitOpError()
+           << "requires result to be " << num_dims << "D tensor";
 
-  LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims);
+  llvm::Optional<mlir::Location> location = op.getLoc();
+  ArrayRef<Attribute> strides = op.strides().getValue();
+  ArrayRef<Attribute> dilations = op.dilations().getValue();
+  LogicalResult verify_result =
+      VerifyConvOpAttributes(num_dims, strides, dilations, location);
   if (failed(verify_result)) {
     return verify_result;
   }
@@ -1599,9 +1730,10 @@
   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
 
   // Update convolution attributes.
-  setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
-  setAttr("strides", ShuffleArrayAttr(strides(), perm));
-  setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
+  (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
+  (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
+  (*this)->setAttr("explicit_paddings",
+                   ShuffleArrayAttr(explicit_paddings(), perm, 2));
 
   // Permute input sizes operand.
   OpBuilder builder(getOperation());
@@ -1634,6 +1766,28 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Conv3DOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult Conv3DOp::inferReturnTypes(
+    mlir::MLIRContext *context, llvm::Optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
+  Conv3DOpAdaptor op(operands, attributes);
+  ArrayRef<Attribute> explicit_padding;
+  ArrayAttr explicit_pad =
+      attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
+  if (!explicit_pad) {
+    explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
+  }
+  explicit_padding = explicit_pad.getValue();
+
+  return inferConvReturnTypes(op, inferredReturnTypes, location,
+                              explicit_padding);
+}
+
+//===----------------------------------------------------------------------===//
 // DataFormatVecPermuteOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h
index fe788ac..90cd1c2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h
@@ -20,10 +20,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc
index 72ca50b..dddb9bc 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc
@@ -370,7 +370,7 @@
   if (perm.empty()) return failure();
 
   // Update data format attribute.
-  op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
+  (*op)->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
 
   // Update types for all layout sensitive results.
   auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
@@ -388,7 +388,7 @@
 LogicalResult FoldOperandsPermutation(
     ArrayRef<int64_t> permutation, Op *op,
     ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
-  MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
+  MLIRContext *context = (*op)->template getParentOfType<ModuleOp>().getContext();
 
   // We only support NHWC <-> NCHW permutations.
   static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
@@ -421,12 +421,12 @@
       GetDataFormatPermutation(op->data_format(), target_data_format);
   if (reverse_permutation.empty()) return failure();
 
-  op->setAttr("data_format", StringAttr::get(target_data_format, context));
+  (*op)->setAttr("data_format", StringAttr::get(target_data_format, context));
 
   for (auto pair : shuffle_attrs) {
     StringRef attr_name = pair.first;
     ArrayAttr attr_value = pair.second;
-    op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
+    (*op)->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
   }
 
   auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 95858b3..8b717ec 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -45,6 +45,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
@@ -54,7 +55,6 @@
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -400,8 +400,8 @@
 
 template <class OpClass>
 static LogicalResult VerifyPartitionedCall(OpClass op) {
-  auto module = op.template getParentOfType<ModuleOp>();
-  SymbolRefAttr func = op.getAttr("f").template cast<SymbolRefAttr>();
+  auto module = op->template getParentOfType<ModuleOp>();
+  SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
 
   auto function =
       dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
@@ -532,7 +532,11 @@
   auto ranked_type = type.dyn_cast<RankedTensorType>();
   if (!ranked_type) return {};
 
-  auto output_type = getType().cast<ShapedType>();
+  // DenseIntElementsAttr::get requires the output type be ranked with static
+  // shape.
+  auto output_type = getType().dyn_cast<RankedTensorType>();
+  if (!output_type || !output_type.hasStaticShape()) return {};
+
   int32_t rank = ranked_type.getRank();
   return DenseIntElementsAttr::get(output_type, rank);
 }
@@ -894,7 +898,7 @@
     dimensions.push_back(APInt(out_width, shape[i]));
 
   auto result_type = RankedTensorType::get(
-      {rank}, IntegerType::get(out_width, input_ty.getContext()));
+      {rank}, IntegerType::get(input_ty.getContext(), out_width));
   return DenseElementsAttr::get(result_type, dimensions);
 }
 
@@ -2375,7 +2379,7 @@
     // If the types don't match then only fold if all the operands are in the TF
     // dialect.
     for (auto user : op.getOperation()->getUsers())
-      if (user->getDialect() != op.getDialect()) return {};
+      if (user->getDialect() != op->getDialect()) return {};
   }
 
   return op.x();
@@ -2514,6 +2518,74 @@
   return success();
 }
 
+namespace {
+
+// Hoist coefficient-wise unary operation out of the Unpack op:
+//
+//   %unpacked:N = "tf.Unpack"(%0)
+//   %neg0 = "tf.Neg"(%unpacked#0)
+//   %neg1 = "tf.Neg"(%unpacked#1)
+//   ...
+//   %negN-1 = "tf.Neg"(%unpacked:N-1)
+//
+// Rewrite it to:
+//
+//   %neg = "tf.Neg"(%0)
+//   %unpacked:N = "tf.Unpack"(%neg)
+class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
+ public:
+  explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
+      : OpRewritePattern<UnpackOp>(context) {}
+  LogicalResult matchAndRewrite(UnpackOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
+    UnpackOp op, PatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+
+  // First unpack user must be coeff-wise unary operation.
+  Operation *first_user = *op->getUsers().begin();
+  if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
+
+  // All unpack users must be defined by the op of same kind.
+  bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
+    return user->getName() == first_user->getName();
+  });
+  if (!users_same_op) return failure();
+
+  // Pass unpack operand to unary operation.
+  OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
+                                    op.getOperand(), op.getOperand().getType(),
+                                    ArrayRef<NamedAttribute>());
+  Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);
+
+  // Unpack results after applying unary operation.
+  auto unpack_unary_op = rewriter.create<UnpackOp>(
+      loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
+
+  // Bypass all users of the original unpack operation and use `unpack_unary_op`
+  // results instead.
+  for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
+    OpResult old_result = std::get<0>(pair);  // result of original Unpack
+    OpResult new_result = std::get<1>(pair);  // result of transformed Unpack
+    for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
+      rewriter.replaceOp(user, ValueRange(new_result));
+  }
+
+  // Erase original unpack operation.
+  rewriter.eraseOp(op.getOperation());
+
+  return success();
+}
+
+}  // namespace
+
+void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<HoistCwiseUnaryOutOfUnpack>(context);
+}
+
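The soundness of this pattern rests on a coefficient-wise unary op commuting with Unpack: applying it to every unpacked slice is the same as applying it once to the packed tensor and unpacking afterwards. A quick NumPy check of that equivalence, illustrative only, using negation as the unary op:

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)
per_slice = [-s for s in list(x)]  # Neg applied to each Unpack result
hoisted = list(-x)                 # Neg hoisted above the Unpack
assert all(np.array_equal(a, b) for a, b in zip(per_slice, hoisted))
```
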
 //===----------------------------------------------------------------------===//
 // Unsorted segment reduction ops
 //===----------------------------------------------------------------------===//
@@ -2897,6 +2969,18 @@
   results.insert<XdivyWithSqrtDivisor>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// XlaSetDynamicDimensionSizeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({operands.front().getType()});
+  return success();
+}
+
 }  // namespace TF
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h
index 353c1c6..eef1b6c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h
@@ -20,10 +20,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc
index e5162b0..70282b5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc
@@ -42,6 +42,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
@@ -51,7 +52,6 @@
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h
index 01a93e1..62caa9c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h
@@ -20,9 +20,9 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index 55cb512..3edcbf5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -25,10 +25,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -79,7 +79,7 @@
 
 static LogicalResult Verify(SessionInitializerOp session_initializer) {
   mlir::SymbolTable symbol_table(
-      session_initializer.getParentOfType<ModuleOp>());
+      session_initializer->getParentOfType<ModuleOp>());
 
   for (auto sym_ref : session_initializer.initializers()) {
     auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
@@ -327,7 +327,7 @@
 
 LogicalResult VerifyExportedFunc(FuncOp func) {
   bool reached_bound_inputs = false;
-  auto module = func.getParentOfType<ModuleOp>();
+  auto module = func->getParentOfType<ModuleOp>();
   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
     if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
       reached_bound_inputs = true;
@@ -342,7 +342,7 @@
       continue;
     }
     if (func.getArgAttr(i, "tf.resource_name")) {
-      if (module.getAttr("tf_saved_model.under_construction")) continue;
+      if (module->getAttr("tf_saved_model.under_construction")) continue;
       return func.emitError() << "'tf.resource_name' attribute is not allowed "
                                  "unless it is being under construction";
     }
@@ -355,7 +355,7 @@
     if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
             i, "tf_saved_model.bound_input")) {
       if (!unique_bound_inputs.insert(attr.getValue()).second) {
-        if (module.getAttr("tf_saved_model.under_construction")) continue;
+        if (module->getAttr("tf_saved_model.under_construction")) continue;
         return func.emitError()
                << "duplicate 'tf_saved_model.bound_input' binding";
       }
@@ -431,7 +431,7 @@
 }
 
 bool HasTfSavedModelSemantics(ModuleOp module) {
-  return module.getAttr("tf_saved_model.semantics") != nullptr;
+  return module->getAttr("tf_saved_model.semantics") != nullptr;
 }
 
 Operation *LookupBoundInput(FuncOp func, int arg_index,
@@ -455,7 +455,7 @@
 
   LogicalResult matchAndRewrite(SessionInitializerOp op,
                                 PatternRewriter &rewriter) const override {
-    SymbolTable symbol_table(op.getParentOfType<ModuleOp>());
+    SymbolTable symbol_table(op->getParentOfType<ModuleOp>());
 
     SmallVector<FuncOp, 2> to_remove;
     SmallVector<mlir::Attribute, 2> to_keep;
@@ -483,7 +483,7 @@
     if (to_keep.empty())
       rewriter.eraseOp(op);
     else
-      op.setAttr("initializers", rewriter.getArrayAttr(to_keep));
+      op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
 
     return success();
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h
index b90bf2d..98d2a49 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h
@@ -19,14 +19,13 @@
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_
 
 #include "llvm/ADT/StringMap.h"
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
-#include "tensorflow/core/util/device_name_utils.h"
-
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h.inc"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace mlir {
 namespace TF {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
index aef3c53..db76bd5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
@@ -18,8 +18,8 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
 
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -66,6 +66,39 @@
   }
 };
 
+namespace detail {
+inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
+    Operation* op) {
+  Type element_type;
+  if (op->getNumResults() > 0) {
+    element_type =
+        mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
+  } else if (op->getNumOperands() > 0) {
+    element_type =
+        mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
+  } else {
+    // Nothing to check.
+    return success();
+  }
+  // Verify that all result element types are compatible with `element_type`.
+  for (const auto& result_type : op->getResultTypes()) {
+    if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type) {
+      return op->emitOpError(
+          "requires compatible element types for all operands and results");
+    }
+  }
+  // Verify that all operand element types are compatible with `element_type`.
+  for (const auto& operand_type : op->getOperandTypes()) {
+    if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
+        element_type) {
+      return op->emitOpError(
+          "requires compatible element types for all operands and results");
+    }
+  }
+  return success();
+}
+}  // namespace detail
+
 // Verifies that op has the same operand and result element types (or type
 // itself, if scalar) after resolving reference types (i.e., after converting
 // reference types to their corresponding TensorFlow or standard types).
@@ -75,34 +108,20 @@
                        SameOperandsAndResultElementTypeResolveRef> {
  public:
   static LogicalResult verifyTrait(Operation* op) {
-    Type element_type;
-    if (op->getNumResults() > 0) {
-      element_type =
-          mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
-    } else if (op->getNumOperands() > 0) {
-      element_type =
-          mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
-    } else {
-      // Nothing to check.
-      return success();
-    }
-    // Verify that all result element types are compatible to `element_type`.
-    for (const auto& result_type : op->getResultTypes()) {
-      if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) !=
-          element_type) {
-        return op->emitOpError(
-            "requires compatible element types for all operands and results");
-      }
-    }
-    // Verify that all operand element types are compatible to `element_type`.
-    for (const auto& operand_type : op->getOperandTypes()) {
-      if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
-          element_type) {
-        return op->emitOpError(
-            "requires compatible element types for all operands and results");
-      }
-    }
-    return success();
+    return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
+  }
+};
+
+// Verifies that op has the same operand and result types after resolving
+// reference types (i.e., after converting reference types to their
+// corresponding TensorFlow or standard types).
+template <typename ConcreteType>
+class SameOperandsAndResultTypeResolveRef
+    : public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef> {
+ public:
+  static LogicalResult verifyTrait(Operation* op) {
+    if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
+    return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
   }
 };
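The refactor above moves the shared element-type verification into detail:: so both traits reuse it; the new SameOperandsAndResultTypeResolveRef trait simply adds the standard shape check in front. A hedged sketch of the comparison both traits ultimately rely on (the helper name is illustrative, not from this change):

// Element types are compared only after stripping TF reference types, so
// ref-typed operands (e.g. float refs) are compatible with non-ref results.
static bool ElementTypesCompatible(mlir::Type a, mlir::Type b) {
  return mlir::TF::GetElementTypeOrSelfResolveRef(a) ==
         mlir::TF::GetElementTypeOrSelfResolveRef(b);
}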
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
index 86369b9..0b21b86 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
@@ -17,8 +17,8 @@
 
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 
 namespace {
@@ -155,19 +155,19 @@
   if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
   if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
   if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
-  if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
-  if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
-  if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
-  if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
-  if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
+  if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
+  if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
+  if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
+  if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
+  if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
   if (isa<Uint8RefType>())
-    return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
   if (isa<Uint16RefType>())
-    return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
   if (isa<Uint32RefType>())
-    return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
   if (isa<Uint64RefType>())
-    return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
   if (isa<Complex64RefType>())
     return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
   if (isa<Complex128RefType>())
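This hunk also tracks the upstream MLIR builder-signature change: the MLIRContext argument now comes first and the signedness, when given, last. A minimal hedged sketch (hypothetical helper, not code from this patch):

// Old form removed here:  mlir::IntegerType::get(8, IntegerType::Unsigned, ctx)
// New form used instead:  mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned)
static mlir::Type MakeUnsignedByte(mlir::MLIRContext *ctx) {
  return mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Unsigned);
}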
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
index 1d3ca0c..52021a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
@@ -18,10 +18,10 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
 
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc
index 6a6a757..69a1bf0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc
@@ -15,7 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h"
 
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
index f06b226..ffa3394 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
@@ -291,7 +291,8 @@
     %source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
     %island:2 = tf_executor.island {
       %const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
-      %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
+      %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
+
       tf_executor.yield %const : tensor<*xi32>
     }
     tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
@@ -306,7 +307,7 @@
   tf_executor.graph {
     %island:2 = tf_executor.island {
       %const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
-      %print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
+      %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi1>) -> (tensor<*xi1>)
       tf_executor.yield %const : tensor<*xi1>
     }
     %loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
@@ -321,7 +322,7 @@
   tf_executor.graph {
     %island:2 = tf_executor.island {
       %const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
-      %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
+      %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
       tf_executor.yield %const : tensor<*xi32>
     }
     %enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
@@ -336,7 +337,7 @@
   tf_executor.graph {
     %island:2 = tf_executor.island {
       %const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
-      %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
+      %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
       tf_executor.yield %const : tensor<*xi32>
     }
     %switchn:4 = tf_executor._SwitchN %island#0, %arg1 of 3: tensor<*xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index fa04ac0..e344d32 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -16,6 +16,20 @@
   return
 }
 
+// CHECK-LABEL: testBatchMatMulToV2
+func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> {
+  // CHECK: tf.BatchMatMulV2
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
+  return %0: tensor<2x3x7xf32>
+}
+
+// CHECK-LABEL: testDynamicBatchMatMulToV2
+func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<?x5x7xf32>) -> tensor<2x3x7xf32> {
+  // CHECK: tf.BatchMatMul
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<?x5x7xf32>) -> tensor<2x3x7xf32>
+  return %0: tensor<2x3x7xf32>
+}
+
 // CHECK-LABEL: testBatchMatMulToMatMul
 func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
@@ -890,6 +904,20 @@
   return %0 : tensor<i32>
 }
 
+// CHECK-LABEL: testRankOfRankedTensorUnrankedOutput
+func @testRankOfRankedTensorUnrankedOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<*xi32> {
+  // Regression test to make sure we don't crash in this case.
+  %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// CHECK-LABEL: testRankOfRankedTensorDynamicShapeOutput
+func @testRankOfRankedTensorDynamicShapeOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<?xi32> {
+  // Regression test to make sure we don't crash in this case.
+  %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
 // CHECK-LABEL: @foldFill
 func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
   %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
@@ -1313,3 +1341,16 @@
   return
 }
 
+// CHECK-LABEL: testUnpackAndCwiseUnary
+func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+
+  // CHECK: %[[NEG:.*]] = "tf.Neg"(%arg0)
+  // CHECK: %[[UNPACK:.*]]:2 = "tf.Unpack"(%[[NEG]])
+  %unpacked:2 = "tf.Unpack"(%arg0) {axis = 1 : i64, device = ""}
+                : (tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+  %0 = "tf.Neg"(%unpacked#0): (tensor<?xf32>) -> tensor<?xf32>
+  %1 = "tf.Neg"(%unpacked#1): (tensor<?xf32>) -> tensor<?xf32>
+
+  // CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
+  return %0, %1 : tensor<?xf32>, tensor<?xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir
index 84e3f52..feb0f42 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir
@@ -1,5 +1,7 @@
 // RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-return-tuple | FileCheck %s
 // RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck -check-prefix=TUPLE-ARGS %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
 
 module attributes {tf.versions = {producer = 179 : i32}} {
   func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
@@ -36,3 +38,16 @@
 // TUPLE-ARGS-NEXT:  // XlaInputShape (f32[], f32[])
 // TUPLE-ARGS-NEXT:  // XlaOutputShape (f32[])
 // TUPLE-ARGS-NEXT:  // XlaOutputDescription type=float shape=()
+
+
+// NO_RET_TUPLE-LABEL: HloModule main{{[.0-9]*}}
+// NO_RET_TUPLE:       ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
+// NO_RET_TUPLE-NEXT:    %[[ARG0]] = f32[] parameter(0)
+// NO_RET_TUPLE-NEXT:    %[[ARG1]] = f32[] parameter(1)
+// NO_RET_TUPLE-NEXT:    ROOT [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
+
+// NO_RET_TUPLE:       // InputMapping {0, 1}
+// NO_RET_TUPLE-NEXT:  // XlaInputShape f32[]
+// NO_RET_TUPLE-NEXT:  // XlaInputShape f32[]
+// NO_RET_TUPLE-NEXT:  // XlaOutputShape (f32[])
+// NO_RET_TUPLE-NEXT:  // XlaOutputDescription type=float shape=()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir
index c745fbc..37608b8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir
@@ -1,4 +1,6 @@
 // RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
 
 module attributes {tf.versions = {producer = 179 : i32}} {
   func @main() -> (tensor<0xi32>, tensor<0xi32>) {
@@ -14,3 +16,9 @@
 // CHECK:         [[CONSTANT:%.*]] = s32[0]{0} constant({})
 // CHECK:         ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
 // CHECK:       }
+
+// NO_TUPLES-LABEL: HloModule main{{[.0-9]*}}
+// NO_TUPLES:       ENTRY %main.{{[0-9]+}} () -> (s32[0], s32[0]) {
+// NO_TUPLES:         [[CONSTANT:%.*]] = s32[0]{0} constant({})
+// NO_TUPLES:         ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
+// NO_TUPLES:       }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir
index 55bdea5..5e9b652 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir
@@ -1,4 +1,5 @@
 // RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=8,16,16,64:64 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=8,16,16,64:64 | FileCheck %s
 
 module attributes {tf.versions = {producer = 179 : i32}} {
   func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir
index f9eca51..16a11af 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir
@@ -1,4 +1,6 @@
 // RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
+// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
 
 module attributes {tf.versions = {producer = 179 : i32}} {
   func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
@@ -9,3 +11,6 @@
 
 // CHECK-LABEL: HloModule main
 // CHECK:       (arg_tuple.{{[0-9]+}}: (f32[10,17], f32[17,19])) -> (f32[10,19])
+
+// NO_TUPLES-LABEL: HloModule main{{[.0-9]*}}
+// NO_TUPLES:       ({{.+}}: f32[10,17], {{.+}}: f32[17,19]) -> f32[10,19]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 528d26c..b5455bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -300,19 +300,6 @@
 // CHECK-NEXT: return %[[CST]], %[[CST1]]
 }
 
-// Tests ops that have non-local device assignment but with local device with
-// same type (CPU) are correctly evaluated.
-// CHECK-LABEL: func @testRemoteDevice() -> tensor<2x2xi32>
-func @testRemoteDevice() -> tensor<2x2xi32> {
-^bb0:
-  %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
-  %1 = constant dense<1> : tensor<2xi32>
-  %2 = "tf.Add"(%0, %1) {device = "/job:remote_worker/replica:123/task:456/CPU:0", name = "add"} : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
-  // CHECK:         [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[\[}}1, 2], {{\[}}3, 4]]> : tensor<2x2xi32>
-  // CHECK-NEXT:    return [[cst]] : tensor<2x2xi32>
-  return %2: tensor<2x2xi32>
-}
-
 // Tests ops that variable shapes are correctly evaluated on static types.
 // CHECK-LABEL: func @testVariableShape
 func @testVariableShape(%arg0: tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32> {
@@ -512,17 +499,93 @@
   return %2 : tensor<8xf32>
 }
 
-// CHECK-LABEL: func @testBroadcastGradientArgs
-func @testBroadcastGradientArgs() -> (tensor<1xi32>, tensor<0xi32>) {
+// CHECK-LABEL: func @testBroadcastGradientArgsSameShape
+func @testBroadcastGradientArgsSameShape() -> (tensor<0xi32>, tensor<0xi32>) {
+  %s0 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %s1 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<2xi32>, tensor<2xi32>) -> (tensor<0xi32>, tensor<0xi32>)
+
+  // CHECK-DAG: %[[R:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R]], %[[R]]
+
+  return %r0, %r1 : tensor<0xi32>, tensor<0xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs1
+func @testBroadcastGradientArgs1() -> (tensor<1xi32>, tensor<0xi32>) {
   %s0 = "tf.Const"() {value = dense<[4]> : tensor<1xi32>} : () -> tensor<1xi32>
   %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R1]]
+
+  return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs2
+func @testBroadcastGradientArgs2() -> (tensor<1xi32>, tensor<3xi32>) {
+  %s2 = "tf.Const"() {value = dense<[501, 1, 32, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %s3 = "tf.Const"() {value = dense<[  1, 1,  1, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %r2, %r3 = "tf.BroadcastGradientArgs"(%s2, %s3) {} : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<3xi32>)
+  // CHECK-DAG: %[[R2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R3:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R2]], %[[R3]]
+
+  return %r2, %r3 : tensor<1xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs3
+func @testBroadcastGradientArgs3() -> (tensor<3xi32>, tensor<3xi32>) {
+  %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %s5 = "tf.Const"() {value = dense<[1, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
+  // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R0]]
+
+  return %r4, %r5 : tensor<3xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs4
+func @testBroadcastGradientArgs4() -> (tensor<2xi32>, tensor<3xi32>) {
+  %s4 = "tf.Const"() {value = dense<[1, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %s5 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<3xi32>, tensor<0xi32>) -> (tensor<2xi32>, tensor<3xi32>)
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R1]]
+
+  return %r4, %r5 : tensor<2xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs5
+func @testBroadcastGradientArgs5() -> (tensor<1xi32>, tensor<1xi32>) {
+  %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %s5 = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
+  // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R0]]
+
+  return %r4, %r5 : tensor<1xi32>, tensor<1xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs6
+func @testBroadcastGradientArgs6() -> (tensor<1xi32>, tensor<0xi32>) {
+  %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %s5 = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<0xi32>)
   // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // CHECK: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
-  // CEHCK: return [[R0]], [[R1]]
+  // CHECK: return %[[R0]], %[[R1]]
 
-  return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
+  return %r4, %r5 : tensor<1xi32>, tensor<0xi32>
 }
 
 // CHECK-LABEL: func @testBroadcastGradientArgsHigherRank
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cross_host_transfer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cross_host_transfer.mlir
new file mode 100644
index 0000000..dd1437b
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cross_host_transfer.mlir
@@ -0,0 +1,67 @@
+// RUN: tf-opt --tf-cross-host-transfer %s | FileCheck %s
+
+// CHECK-LABEL: func @test_merge_send
+func @test_merge_send() {
+  // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>}
+  %0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
+
+  // CHECK-NEXT: tf_device.send %[[RESULT_0]] "key-0" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.receive "key-0" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+  %1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: %[[RESULT_3:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+  %2 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
+  return
+}
+
+// CHECK-LABEL: func @test_multiple_send
+func @test_multiple_send() -> tensor<f32> {
+  // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>}
+  %0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
+
+  // CHECK-NEXT: tf_device.send %[[RESULT_0]] "key-1" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.receive "key-1" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+  %1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: tf_device.send %[[RESULT_2]] "key-2" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_3:.*]] = tf_device.receive "key-2" "/job:worker/replica:0/task:1" {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_4:.*]] = "tf.Identity"(%[[RESULT_3]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
+  %2 = "tf.Identity"(%1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: return %[[RESULT_4]] : tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// CHECK: func @test_send_func_arg(%[[ARG_0:.*]]: tensor<f32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) {
+func @test_send_func_arg(%arg0: tensor<f32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) {
+  // CHECK-NEXT: tf_device.send %[[ARG_0]] "key-3" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:0/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_0:.*]] = tf_device.receive "key-3" "/job:worker/replica:0/task:0" {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
+  // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[RESULT_0]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
+  %0 = "tf.Identity"(%arg0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
+
+  return
+}
+
+// CHECK: func @test_not_send_while_loop_arg(%[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<*xf32>, %[[ARG_2:.*]]: tensor<i32>) {
+func @test_not_send_while_loop_arg(%arg0: tensor<i32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) {
+  // CHECK-NEXT: %[[RESULT_0:.*]]:2 = "tf.WhileRegion"(%[[ARG_0]], %[[ARG_1]]) ( {
+  %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( {
+  // CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor<i32>, %[[ARG_4:.*]]: tensor<*xf32>)
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
+    // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+    %2 = "tf.Identity"(%arg3) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<i32>) -> tensor<i32>
+    // CHECK-NEXT: tf_device.send %[[RESULT_1]] "key-4" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
+    // CHECK-NEXT: %[[RESULT_2:.*]] = tf_device.receive "key-4" "/job:worker/replica:0/task:1" {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
+    // CHECK-NEXT: %[[RESULT_3:.*]] = "tf.NotEqual"(%[[ARG_2]], %[[RESULT_2]])
+    %3 = "tf.NotEqual"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tf.Yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
+    %cst = constant dense<1> : tensor<i32>
+    %1 = "tf.Sub"(%arg3, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tf.Yield"(%1, %arg4) : (tensor<i32>, tensor<*xf32>) -> ()
+  }) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
index ec9a787..920f535 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
@@ -532,9 +532,9 @@
   // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
   // CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
   // CHECK-DAG: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
-  // CHECK-DAG: %[[GRAD_SQRT:.*]] = "tf.Sqrt"(%[[GRAD]]) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-DAG: %[[GRAD_SQ:.*]] = "tf.Square"(%[[GRAD]]) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-DAG: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
-  // CHECK-DAG: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQRT]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-DAG: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQ]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-DAG: %[[RSQRT_ACCUM:.*]] = "tf.Rsqrt"(%[[ACCUM_NEW]]) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-DAG: %[[ADAGRAD_LR:.*]] = "tf.Mul"(%[[LR]], %[[RSQRT_ACCUM]]) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-DAG: %[[DELTA:.*]] = "tf.Mul"(%[[GRAD]], %[[ADAGRAD_LR]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir b/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir
new file mode 100644
index 0000000..b20776c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir
@@ -0,0 +1,29 @@
+// RUN: tf-opt %s -tf-drop-while-shape-invariant | FileCheck %s
+
+// CHECK-LABEL: while_shape_invariant
+// CHECK-NOT: shape_invariant
+func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
+
+  %1 = "tf.WhileRegion"(%arg0) ( {
+  ^cond(%carg0: tensor<*xf32>):
+    %2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    "tf.Yield"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^body(%barg0: tensor<*xf32>):
+    %2 = "tf.SomeOp"(%barg0) : (tensor<*xf32>) -> tensor<*xf32>
+    "tf.Yield"(%2) : (tensor<*xf32>) -> ()
+  }) {is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
+
+  return %0, %1 : tensor<*xf32>, tensor<*xf32>
+}
+
+func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
+  %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
+  %0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
index 9806e79..ad70631 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
@@ -11,7 +11,7 @@
   } : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
 
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor<i1>) -> tensor<i1>
-// CHECK:   [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   cond_br [[PRED]], ^bb1, ^bb2
 // CHECK: ^bb1:
 // CHECK:   [[THEN:%.+]] = call @testIf1Then(%arg1, %arg2)
@@ -36,7 +36,7 @@
   } : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
 
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor<i1>) -> tensor<i1>
-// CHECK:   [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   cond_br [[PRED]], ^bb1, ^bb2
 // CHECK: ^bb1:
 // CHECK:   [[THEN:%.+]]:3 = call @testIf3Then(%arg1)
@@ -65,7 +65,7 @@
   } : (tensor<i1>, tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
   return %0: tensor<!tf.variant<tensor<f32>>>
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor<i1>) -> tensor<i1>
-// CHECK:   [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   cond_br [[PRED]], ^bb1, ^bb2
 // CHECK: ^bb1:
 // CHECK:   [[CAST0:%.+]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant>
@@ -93,7 +93,7 @@
 ^bb0(%arg0: tensor<4xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
 
   // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor<4xi1>) -> tensor<i1>
-  // CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+  // CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
   %1 = "tf.If"(%arg0, %arg1, %arg2) {
     then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
   } : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
@@ -118,7 +118,7 @@
 // CHECK: ^bb1([[CONDARG0:%.+]]: tensor<*xf32>, [[CONDARG1:%.+]]: tensor<*xf32>):
 // CHECK:   [[CONTINUE:%.+]] = call @testWhile2Cond(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<i1>
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor<i1>) -> tensor<i1>
-// CHECK:   [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   cond_br [[PRED]], ^bb2([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>), ^bb3([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>)
 // CHECK: ^bb2([[BODYARG0:%.+]]: tensor<*xf32>, [[BODYARG1:%.+]]: tensor<*xf32>):
 // CHECK:   [[BODYRETS:%.+]]:2 = call @testWhile2Body([[BODYARG0]], [[BODYARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
@@ -142,7 +142,7 @@
 // CHECK: ^bb1:
 // CHECK:   [[CONTINUE:%.+]] = call @testWhile0Cond() : () -> tensor<i1>
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor<i1>) -> tensor<i1>
-// CHECK:   [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   cond_br [[PRED]], ^bb2, ^bb3
 // CHECK: ^bb2:
 // CHECK:   call @testWhile0Body() : () -> ()
@@ -166,7 +166,7 @@
 // CHECK:  ^bb1([[CONDARG0:%.+]]: tensor<*xf32>, [[CONDARG1:%.+]]: tensor<*xf32>):
 // CHECK:    [[CONTINUE:%.+]] = call @testWhile2Cond([[CONDARG0]], [[CONDARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<i1>
 // CHECK:    [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor<i1>) -> tensor<i1>
-// CHECK:    [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:    [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:    cond_br [[PRED]], ^bb2([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>), ^bb3([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>)
 // CHECK:  ^bb2([[BODYARG0:%.+]]: tensor<*xf32>, [[BODYARG1:%.+]]: tensor<*xf32>):
 // CHECK:    [[BODYRETS:%.+]]:2 = call @testWhile2Body([[BODYARG0]], [[BODYARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
@@ -206,7 +206,7 @@
 // CHECK: ^bb1([[CONDARG0:%.+]]: tensor<!tf.variant>):        // 2 preds: ^bb0, ^bb2
 // CHECK:   [[CONTINUE:%.+]] = call @testWhileCond([[CONDARG0]]) : (tensor<!tf.variant>) -> tensor<i1>
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor<i1>) -> tensor<i1>
-// CHECK:   [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor<i1>
+// CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   [[CASTCONDARG0:%.+]] = "tf.Cast"([[CONDARG0]]) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<1x?xf32>>>
 // CHECK:   cond_br [[PRED]], ^bb2([[CASTCONDARG0]] : tensor<!tf.variant<tensor<1x?xf32>>>), ^bb3([[CASTCONDARG0]] : tensor<!tf.variant<tensor<1x?xf32>>>)
 // CHECK: ^bb2([[BODYARG0:%.+]]: tensor<!tf.variant<tensor<1x?xf32>>>):       // pred: ^bb1
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt
index e32136a..0ac7f46 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt
@@ -19,7 +19,6 @@
 node {
   name: "arg1"
   op: "_Arg"
-  device: "CPU:1"
   attr {
     key: "T"
     value {
@@ -111,7 +110,6 @@
   name: "ret2"
   op: "_Retval"
   input: "arg2"
-  device: "CPU:0"
   attr {
     key: "T"
     value {
@@ -152,6 +150,6 @@
 # arg/result attributes, at the right index.
 
 # CHECK:      func @main
-# CHECK-SAME: ({{%.*}}: tensor<*xf32>, {{%.*}}: tensor<*xi32> {tf._arg1_attr0 = "_arg1_attr0_value", tf._arg1_attr1 = 8.000000e+00 : f32, tf.device = "CPU:1"}, {{%.*}}: tensor<*xi1>)
-# CHECK-SAME: -> (tensor<*xf32> {tf._ret0_attr0 = 8 : i64, tf._ret0_attr1 = false}, tensor<*xi32>, tensor<*xi1> {tf._ret2_attr0 = !tf.variant, tf._ret2_attr1 = #tf.shape<128x1024>, tf.device = "CPU:0"})
+# CHECK-SAME: ({{%.*}}: tensor<*xf32>, {{%.*}}: tensor<*xi32> {tf._arg1_attr0 = "_arg1_attr0_value", tf._arg1_attr1 = 8.000000e+00 : f32}, {{%.*}}: tensor<*xi1>)
+# CHECK-SAME: -> (tensor<*xf32> {tf._ret0_attr0 = 8 : i64, tf._ret0_attr1 = false}, tensor<*xi32>, tensor<*xi1> {tf._ret2_attr0 = !tf.variant, tf._ret2_attr1 = #tf.shape<128x1024>})
 # CHECK-SAME: attributes {tf.entry_function = {control_outputs = "", inputs = "arg0,arg1,arg2", outputs = "ret0,ret1,ret2"}}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-retval-attr.pbtxt
similarity index 91%
rename from tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt
rename to tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-retval-attr.pbtxt
index bc15182..68608e3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-retval-attr.pbtxt
@@ -1,9 +1,9 @@
 # RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -o - | FileCheck %s
 
-# Verify arg devices are added as arg attributes.
+# Verify arg and ret devices are added as arg and ret attributes.
 
 # CHECK-LABEL: func @main
-# CHECK-SAME:  (%[[ARG_0:[a-z0-9]+]]: tensor<*xf32> {tf.device = "/CPU:0"}, %[[ARG_1:[a-z0-9]+]]: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xi32>)
+# CHECK-SAME:  (%[[ARG_0:[a-z0-9]+]]: tensor<*xf32> {tf.device = "/CPU:0"}, %[[ARG_1:[a-z0-9]+]]: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xi32> {tf.device = "/CPU:1"})
 
 node {
   name: "args_0"
@@ -95,6 +95,7 @@
   name: "rets_1"
   op: "_Retval"
   input: "identity:1"
+  device: "/CPU:1"
   attr {
     key: "T"
     value {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir
index 67b4691..1405ac6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s
+// RUN: tf-opt %s -inline='default-pipeline=''' | FileCheck %s
 
 // Test that simple TF operations can be inlined.
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir
index dc9b5d5..e7a857c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir
@@ -5,7 +5,7 @@
 } {
 
 // CHECK-LABEL: func @transposeConv2D_3x3_f16
-func @transposeConv2D_3x3_f16(%input: tensor<1x28x28x64xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x28x28x64xf16> {
+func @transposeConv2D_3x3_f16(%input: tensor<1x28x28x64xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x26x26x64xf16> {
   // cuDNN prefers NCHW data format for spatial convolutions in f16 before
   // compute capability 7.0 (NVIDIA Tensor Cores).
 
@@ -17,9 +17,9 @@
          padding = "VALID",
          strides = [1, 1, 1, 1]
        } : (tensor<1x28x28x64xf16>, tensor<3x3x64x64xf16>)
-        -> tensor<1x28x28x64xf16>
+        -> tensor<1x26x26x64xf16>
 
-  return %0 : tensor<1x28x28x64xf16>
+  return %0 : tensor<1x26x26x64xf16>
 }
 
 // CHECK-LABEL: func @transposeConv2DBackpropFilter_f16
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir
index 6173fa3..605fcdd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir
@@ -5,7 +5,7 @@
 } {
 
 // CHECK-LABEL: func @transposeConv2D_3x3_f32
-func @transposeConv2D_3x3_f32(%input: tensor<1x28x28x64xf32>, %filter: tensor<3x3x64x64xf32>) -> tensor<1x28x28x64xf32> {
+func @transposeConv2D_3x3_f32(%input: tensor<1x28x28x64xf32>, %filter: tensor<3x3x64x64xf32>) -> tensor<1x26x26x64xf32> {
   // cuDNN prefers NCHW data format for spatial convolutions.
   // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
   // CHECK-SAME: data_format = "NCHW"
@@ -15,9 +15,9 @@
          padding = "VALID",
          strides = [1, 1, 1, 1]
        } : (tensor<1x28x28x64xf32>, tensor<3x3x64x64xf32>)
-        -> tensor<1x28x28x64xf32>
+        -> tensor<1x26x26x64xf32>
 
-  return %0 : tensor<1x28x28x64xf32>
+  return %0 : tensor<1x26x26x64xf32>
 }
 
 // CHECK-LABEL: func @transposeConv2D_1x1_f32
@@ -48,7 +48,7 @@
 }
 
 // CHECK-LABEL: func @transposeConv2D_3x3_f16
-func @transposeConv2D_3x3_f16(%input: tensor<1x64x28x28xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x64x28x28xf16> {
+func @transposeConv2D_3x3_f16(%input: tensor<1x64x28x28xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x64x26x26xf16> {
   // To use Tensor Cores for f16 data type, input must be in NHWC data format.
   // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
   // CHECK-SAME: data_format = "NHWC"
@@ -58,9 +58,9 @@
          padding = "VALID",
          strides = [1, 1, 1, 1]
        } : (tensor<1x64x28x28xf16>, tensor<3x3x64x64xf16>)
-        -> tensor<1x64x28x28xf16>
+        -> tensor<1x64x26x26xf16>
 
-  return %0 : tensor<1x64x28x28xf16>
+  return %0 : tensor<1x64x26x26xf16>
 }
 
 // CHECK-LABEL: func @transposeConv2DBackpropFilter_f32
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
index 9bb05a7..f64fff2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
@@ -5,7 +5,7 @@
 // that changing convolution data layout will update all the attributes.
 
 // CHECK-LABEL: func @transposeConv2D
-func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> {
+func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x7x7x8xf32> {
 
   // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
@@ -16,7 +16,7 @@
   // CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6]
   // CHECK-SAME: padding = "EXPLICIT"
   // CHECK-SAME: strides = [5, 8, 6, 7]
-  // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
+  // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x7x7xf32>
 
   // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
@@ -29,13 +29,13 @@
          explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
          padding = "EXPLICIT",
          strides = [5, 6, 7, 8]
-       } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
+       } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x7x7x8xf32>
 
-  return %0 : tensor<1x32x32x8xf32>
+  return %0 : tensor<1x7x7x8xf32>
 }
 
 // CHECK-LABEL: func @transposeConv2DWithDefaultAttr
-func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<*xf32>
+func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<?x?x?x?xf32>
 {
 
   // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
@@ -47,7 +47,7 @@
   // CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6]
   // CHECK-SAME: padding = "EXPLICIT"
   // CHECK-SAME: strides = [5, 8, 6, 7]
-  // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32>
+  // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<?x?x?x?xf32>
 
   // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
@@ -61,9 +61,9 @@
          explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
          padding = "EXPLICIT",
          strides = [5, 6, 7, 8]
-       } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32>
+       } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<?x?x?x?xf32>
 
-  return %0 : tensor<*xf32>
+  return %0 : tensor<?x?x?x?xf32>
 }
 
 // CHECK-LABEL: func @transposeConv2DBackpropFilter
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
index 342804f..a0e09de 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
@@ -5,7 +5,7 @@
 // layout will update all the attributes.
 
 // CHECK-LABEL: func @transposeConv2D
-func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> {
+func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x7x6xf32> {
 
   // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
@@ -16,7 +16,7 @@
   // CHECK-SAME: explicit_paddings = [1, 2, 5, 6, 7, 8, 3, 4]
   // CHECK-SAME: padding = "EXPLICIT"
   // CHECK-SAME: strides = [5, 7, 8, 6]
-  // CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
+  // CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x7x6x8xf32>
 
   // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
@@ -29,9 +29,9 @@
          explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
          padding = "EXPLICIT",
          strides = [5, 6, 7, 8]
-       } : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
+       } : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x7x6xf32>
 
-  return %0 : tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x7x6xf32>
 }
 
 // CHECK-LABEL: func @transposeFusedBatchNormV3
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir
index ae3592b..ba46de7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir
@@ -1,7 +1,7 @@
 // RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: func @transposeConv2D
-func @transposeConv2D(%arg0: tensor<1x3x32x32xf32>, %arg1: tensor<1x1x3x8xf32>) -> tensor<1x3x32x32xf32> {
+func @transposeConv2D(%arg0: tensor<1x3x32x32xf32>, %arg1: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> {
 
   // Convert input: NCHW -> NHWC
   %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
@@ -14,20 +14,20 @@
       padding = "SAME",
       strides = [1, 1, 1, 1],
       dilations = [1, 1, 1, 1]
-    } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x3xf32>
+    } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
 
   // Convert result back: NHWC -> NCHW
   %3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %4 = "tf.Transpose"(%2, %3) : (tensor<1x32x32x3xf32>, tensor<4xi32>) -> tensor<1x3x32x32xf32>
+  %4 = "tf.Transpose"(%2, %3) : (tensor<1x32x32x8xf32>, tensor<4xi32>) -> tensor<1x8x32x32xf32>
 
   // Check that Conv2D computed in NCHW format, and all redundant transpose
   // operations removed from the function.
 
   // CHECK: %[[CONV:[0-9]*]] = "tf.Conv2D"(%arg0, %arg1)
   // CHECK-SAME: data_format = "NCHW"
-  // CHECK-SAME: -> tensor<1x3x32x32xf32>
+  // CHECK-SAME: -> tensor<1x8x32x32xf32>
 
   // CHECK: return %[[CONV]]
 
-  return %4 : tensor<1x3x32x32xf32>
+  return %4 : tensor<1x8x32x32xf32>
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index fbcf228..44c0910 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1610,16 +1610,16 @@
 
 // CHECK-LABEL:   func @convert_conv2d_valid_padding(
 // CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
-// CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
-// CHECK:           return %[[VAL_2]] : tensor<1x8x8x16xf32>
+// CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32>
+// CHECK:           return %[[VAL_2]] : tensor<1x6x6x16xf32>
 // CHECK:         }
-func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
        {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
        feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
-       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
-  return %0 : tensor<1x8x8x16xf32>
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32>
+  return %0 : tensor<1x6x6x16xf32>
 }
 
 // CHECK-LABEL:   func @convert_reduce_to_sum(
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index 15662ed..d0fc9b1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -908,3 +908,41 @@
   %resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
   return %resize: tensor<1x?x?x1xi32>
 }
+
+// CHECK-LABEL: func @xdivy
+// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
+func @xdivy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[MUL:.*]] = "tf.Div"(%[[X]], %[[Y]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %0 = "tf.Xdivy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]]
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @xlog1py
+// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
+func @xlog1py(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[LOG:.*]] = "tf.Log1p"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK:  %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %0 = "tf.Xlog1py"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]]
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @xlogy
+// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
+func @xlogy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[LOG:.*]] = "tf.Log"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK:  %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %0 = "tf.Xlogy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]]
+  return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
new file mode 100644
index 0000000..5307829
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
@@ -0,0 +1,398 @@
+// RUN: tf-opt %s -tf-merge-control-flow | FileCheck %s
+
+// Check that IfRegions with different predicates are not merged.
+
+// CHECK-LABEL: func @different_predicate_no_merge
+func @different_predicate_no_merge() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK:        "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+    "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> ()
+    "tf.IfRegion"(%1) ( {
+      %2 = "tf.B"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that IfRegions with the same predicate but in different blocks are not merged.
+
+// CHECK-LABEL: func @different_block_no_merge
+func @different_block_no_merge() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK:        "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+    %3 = "tf.A"() : () -> (tensor<?xf32>)
+    %4 = "tf.B"() : () -> (tensor<i32>)
+    "tf.WhileRegion"(%4, %3) ({
+    ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+      "tf.IfRegion"(%0) ( {
+        %2 = "tf.A"() : () -> (tensor<f32>)
+        "tf.Yield"() : () -> ()
+        }, {
+        "tf.Yield"() : () -> ()
+       }) {is_stateless = true} : (tensor<i1>) -> ()
+       "tf.Yield"(%1) : (tensor<i1>) -> ()
+    }, {
+    ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+      "tf.IfRegion"(%0) ( {
+        %2 = "tf.B"() : () -> (tensor<f32>)
+        "tf.Yield"() : () -> ()
+        }, {
+        "tf.Yield"() : () -> ()
+       }) {is_stateless = true} : (tensor<i1>) -> ()
+      "tf.Yield"(%arg1, %arg2) : (tensor<i32>, tensor<?xf32>) -> ()
+    }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that IfRegions with same predicates and no returns are merged.
+
+// CHECK-LABEL: func @same_predicate_no_returns_merged
+func @same_predicate_no_returns_merged() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> ()
+    "tf.IfRegion"(%0) ( {
+      %2 = "tf.B"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that IfRegions with the same predicate but an intermediate data dependency are not merged.
+
+// CHECK-LABEL: func @same_predicate_intermediate_dependency_no_merge
+func @same_predicate_intermediate_dependency_no_merge() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK:        "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+      }, {
+      %2 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+    %3 = "tf.D"(%1) : (tensor<f32>) -> (tensor<f32>)
+    %4 = "tf.E"(%3) : (tensor<f32>) -> (tensor<f32>)
+    "tf.IfRegion"(%0) ( {
+      %5 = "tf.B"(%4) : (tensor<f32>) -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that IfRegions with the same predicate but an intermediate side effect dependency are not merged.
+
+// CHECK-LABEL: func @same_predicate_side_effect_dependency_no_merge
+func @same_predicate_side_effect_dependency_no_merge() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK:        "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+      }, {
+      %2 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+    "tf.D"(%1) : (tensor<f32>) -> ()
+    "tf.IfRegion"(%0) ( {
+      %4 = "tf.B"(%1) : (tensor<f32>) -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = false} : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that a merged IfRegion correctly sets the is_stateless attribute.
+
+// CHECK-LABEL: func @same_predicate_stateless_merge
+func @same_predicate_stateless_merge() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK:        is_stateless = false
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+      }, {
+      %2 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+    "tf.IfRegion"(%0) ( {
+      %4 = "tf.B"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) {is_stateless = false} : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that IfRegions with same predicates and returns are merged.
+
+// CHECK-LABEL: func @same_predicate_returns_merged
+func @same_predicate_returns_merged() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion"
+  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK-NEXT:     "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]])
+  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK-NEXT:     "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
+  // CHECK:        "tf.E"(%[[IF_OUTPUT]]#0, %[[IF_OUTPUT]]#1)
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+      }, {
+      %3 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
+    %2 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.B"() : () -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+      }, {
+      %3 = "tf.D"() : () -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
+    "tf.E"(%1, %2) : (tensor<f32>, tensor<i32>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that IfRegions with the same predicate and unused returns are merged.
+
+// CHECK-LABEL: func @same_predicate_returns_unused
+func @same_predicate_returns_unused() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion"
+  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK-NEXT:     "tf.Yield"(%[[B_OUTPUT]])
+  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK-NEXT:     "tf.Yield"(%[[D_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
+  // CHECK:        "tf.E"(%[[IF_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+      }, {
+      %3 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
+    %2 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.B"() : () -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+      }, {
+      %3 = "tf.D"() : () -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
+    "tf.E"(%2) : (tensor<i32>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
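+// Check that IfRegions with the same predicate are merged when the second
+// IfRegion uses results of the first inside its branches.
+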
+// CHECK-LABEL: func @same_predicate_dependency
+func @same_predicate_dependency() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion"
+  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK-NEXT:     "tf.Yield"(%[[B_OUTPUT]])
+  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK-NEXT:     "tf.Yield"(%[[D_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
  // CHECK:        "tf.E"(%[[IF_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+      }, {
+      %3 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
+    %2 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.B"(%1) : (tensor<f32>) -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+      }, {
+      %3 = "tf.D"(%1) : (tensor<f32>) -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
+    "tf.E"(%2) : (tensor<i32>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Checks that ops using results of the first IfRegion are moved after the
+// merged IfRegion op as needed.
+
+// CHECK-LABEL: func @same_predicate_results_moved
+func @same_predicate_results_moved(%arg0: tensor<!tf.resource<tensor<f32>>>) {
+  // CHECK:      tf_device.cluster
+  // CHECK:        %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion"
+  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK-NEXT:     "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]])
+  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK-NEXT:     "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
+  // CHECK:        "tf.AssignVariableOp"(%arg0, %[[IF_OUTPUT]]#0)
+  // CHECK:        %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[IF_OUTPUT]]#1)
+  // CHECK-NEXT:   "tf.F"(%[[IF_OUTPUT]]#0, %[[E_OUTPUT]])
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+      }, {
+      %3 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
+    "tf.AssignVariableOp"(%arg0, %1) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+    %4 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
+    %5 = "tf.IfRegion"(%0) ( {
+      %3 = "tf.B"(%4) : (tensor<f32>) -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+      }, {
+      %3 = "tf.D"(%4) : (tensor<f32>) -> (tensor<i32>)
+      "tf.Yield"(%3) : (tensor<i32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
+    %6 = "tf.E"(%5) : (tensor<i32>) -> (tensor<f32>)
+    "tf.F"(%1, %6) : (tensor<f32>, tensor<f32>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that 3 IfRegions with same predicates and no intermediate dependencies are merged.
+
+// CHECK-LABEL: func @same_predicate_3_ifregions
+func @same_predicate_3_ifregions() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> ()
+    "tf.IfRegion"(%0) ( {
+      %2 = "tf.B"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> ()
+    "tf.IfRegion"(%0) ( {
+      %2 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Check that, of 3 IfRegions with the same predicate, the 2nd and 3rd
+// IfRegions can be merged but not the 1st IfRegion.
+
+// CHECK-LABEL: func @same_predicate_3_ifregions_only_merge2
+func @same_predicate_3_ifregions_only_merge2() {
+  // CHECK:      tf_device.cluster
+  // CHECK:        "tf.IfRegion"
+  // CHECK:          "tf.A"
+  // CHECK:        "tf.D"
+  // CHECK-NEXT:   "tf.IfRegion"
+  // CHECK:          "tf.E"
+  // CHECK-NEXT:     "tf.G"
+  // CHECK-NOT:    "tf.IfRegion"
+  "tf_device.cluster"() ( {
+    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1 = "tf.IfRegion"(%0) ( {
+      %2 = "tf.A"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+      }, {
+      %2 = "tf.C"() : () -> (tensor<f32>)
+      "tf.Yield"(%2) : (tensor<f32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
+    %3 = "tf.D"(%1) : (tensor<f32>) -> (tensor<f32>)
+    "tf.IfRegion"(%0) ( {
+      %4 = "tf.E"(%3) : (tensor<f32>) -> (tensor<f32>)
+      "tf.Yield"(%4) : (tensor<f32>) -> ()
+      }, {
+      %4 = "tf.F"() : () -> (tensor<f32>)
+      "tf.Yield"(%4) : (tensor<f32>) -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
+    "tf.IfRegion"(%0) ( {
+      %5 = "tf.G"(%3) : (tensor<f32>) -> (tensor<f32>)
+      "tf.Yield"() : () -> ()
+      }, {
+      "tf.Yield"() : () -> ()
+     }) { is_stateless = true } : (tensor<i1>) -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir
index fea105a..baafed0 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir
@@ -1,37 +1,37 @@
 // RUN: tf-opt -tf-optimize %s -o %t && FileCheck %s < %t
 
 // CHECK-LABEL: convbiasaddmul
-func @convbiasaddmul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
+func @convbiasaddmul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x8x7x16xf32> {
   %filter = constant dense<2.0> : tensor<3x3x3x16xf32>
   %bias = constant dense<3.0> : tensor<16xf32>
   %value = constant dense<4.0> : tensor<16xf32>
-  %0 = "tf.Conv2D"(%arg, %filter) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  %1 = "tf.BiasAdd"(%0, %bias) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"}: (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
-  %2 = "tf.Mul"(%1, %value) {T = "tfdtype$DT_FLOAT"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
-  return %2 : tensor<256x30x30x16xf32>
+  %0 = "tf.Conv2D"(%arg, %filter) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  %1 = "tf.BiasAdd"(%0, %bias) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"}: (tensor<256x8x7x16xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32>
+  %2 = "tf.Mul"(%1, %value) {T = "tfdtype$DT_FLOAT"} : (tensor<256x8x7x16xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32>
+  return %2 : tensor<256x8x7x16xf32>
 
 // CHECK-NEXT: %[[cst:.*]] = "tf.Const{{.*}} dense<8.000000e+00> : tensor<3x3x3x16xf32>
 // CHECK-NEXT: %[[cst_0:.*]] = "tf.Const{{.*}} dense<1.200000e+01> : tensor<16xf32>
 // CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]])
 // CHECK-NEXT: %[[bias:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
-// CHECK-NEXT: return %[[bias]] : tensor<256x30x30x16xf32>
+// CHECK-NEXT: return %[[bias]] : tensor<256x8x7x16xf32>
 }
 
 // CHECK-LABEL: convaddv2mul
-func @convaddv2mul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
+func @convaddv2mul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x8x7x16xf32> {
   %filter = constant dense<2.0> : tensor<3x3x3x16xf32>
   %bias = constant dense<3.0> : tensor<16xf32>
   %value = constant dense<4.0> : tensor<16xf32>
-  %0 = "tf.Conv2D"(%arg, %filter) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  %1 = "tf.AddV2"(%0, %bias) {T = "tfdtype$DT_FLOAT"}: (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
-  %2 = "tf.Mul"(%1, %value) {T = "tfdtype$DT_FLOAT"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
-  return %2 : tensor<256x30x30x16xf32>
+  %0 = "tf.Conv2D"(%arg, %filter) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  %1 = "tf.AddV2"(%0, %bias) {T = "tfdtype$DT_FLOAT"}: (tensor<256x8x7x16xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32>
+  %2 = "tf.Mul"(%1, %value) {T = "tfdtype$DT_FLOAT"} : (tensor<256x8x7x16xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32>
+  return %2 : tensor<256x8x7x16xf32>
 
 // CHECK-NEXT: %[[cst:.*]] = "tf.Const{{.*}} dense<8.000000e+00> : tensor<3x3x3x16xf32>
 // CHECK-NEXT: %[[cst_0:.*]] = "tf.Const{{.*}} dense<1.200000e+01> : tensor<16xf32>
 // CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]])
 // CHECK-NEXT: %[[add:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
-// CHECK-NEXT: return %[[add]] : tensor<256x30x30x16xf32>
+// CHECK-NEXT: return %[[add]] : tensor<256x8x7x16xf32>
 }
 
 // CHECK-LABEL: fold_cast_fft_to_rfft
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir
index af8e720..2154712 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt -tf-shape-inference -inline="disable-simplify" %s | FileCheck %s --dump-input=always
+// RUN: tf-opt -tf-shape-inference -inline='default-pipeline=''' %s | FileCheck %s --dump-input=always
 // RUN: tf-opt -tf-standard-pipeline=enable-inliner %s | FileCheck %s --dump-input=always
 
 // Tests function with argument has no resource subtype but caller operand has a
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 8c156e3..fddb45f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -551,49 +551,49 @@
 // -----
 
 // CHECK-LABEL: func @testValidConv2D
-func @testValidConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
-  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %0 : tensor<256x30x30x16xf32>
+func @testValidConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
+  return %0 : tensor<256x32x32x16xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @testValidDynamicConv2D
-func @testValidDynamicConv2D(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
-  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
+func @testValidDynamicConv2D(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<*xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @testValidConv3D
-func @testValidConv3D(%arg0: tensor<256x32x32x32x3xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32> {
-  %0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<256x32x32x32x3xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32>
-  return %0 : tensor<256x30x30x30x16xf32>
+func @testValidConv3D(%arg0: tensor<256x32x32x32x3xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x32x32x32x16xf32> {
+  %0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<256x32x32x32x3xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x32x32x32x16xf32>
+  return %0 : tensor<256x32x32x32x16xf32>
 }
 
 // -----
 
-func @testConv2D(%arg0: tensor<256x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
+func @testConv2D(%arg0: tensor<256x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> {
   // expected-error @+1 {{requires operands to be 4D tensor}}
-  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %0 : tensor<256x30x30x16xf32>
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
+  return %0 : tensor<256x32x32x16xf32>
 }
 
 // -----
 
-func @testConv3D(%arg0: tensor<256x32x32x32x3xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
-  // expected-error @+1 {{requires result to be 5D tensor}}
-  %0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<256x32x32x32x3xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %0 : tensor<256x30x30x16xf32>
+func @testConv3D(%arg0: tensor<256x32x32x32x3xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x32x32x16xf32> {
+  // expected-error @+1 {{'tf.Conv3D' op inferred type incompatible with return type of operation}}
+  %0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<256x32x32x32x3xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
+  return %0 : tensor<256x32x32x16xf32>
 }
 
 // -----
 
-func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x2x16xf32>) -> tensor<256x30x30x16xf32> {
+func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x2x16xf32>) -> tensor<256x32x32x16xf32> {
   // expected-error @+1 {{requires the number of input channels to be divisible by the number of filter input channels; found 3 and 2, respectively}}
-  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x2x16xf32>) -> tensor<256x30x30x16xf32>
-  return %0 : tensor<256x30x30x16xf32>
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x2x16xf32>) -> tensor<256x32x32x16xf32>
+  return %0 : tensor<256x32x32x16xf32>
 }
 
 // -----
@@ -607,7 +607,7 @@
 // -----
 
 func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
-  // expected-error @+1 {{requires explicit_paddings attribute length to be 8; actual length 4}}
+  // expected-error @+1 {{requires explicit_paddings attribute length to be 8}}
   %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "EXPLICIT", strides = [1, 1, 1, 1], explicit_paddings = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
   return %0 : tensor<256x30x30x16xf32>
 }
@@ -639,6 +639,38 @@
 // -----
 
 func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
+  // expected-error @+1 {{'tf.Conv2D' op inferred type incompatible with return type of operation}}
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 2, 3, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %0 : tensor<256x30x30x16xf32>
+}
+
+// -----
+
+func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x16xf32> {
+  // expected-error @+1 {{'tf.Conv2D' op inferred type incompatible with return type of operation}}
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 2, 3, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x16xf32>
+  return %0 : tensor<256x16x30x16xf32>
+}
+
+// -----
+
+func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> {
+  // expected-error @+1 {{'tf.Conv2D' op inferred type incompatible with return type of operation}}
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "EXPLICIT", dilations = [1, 2, 3, 4], explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], strides = [5, 6, 7, 8]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
+  return %0 : tensor<256x32x32x16xf32>
+}
+
+// -----
+
+func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> {
+  // expected-error @+1 {{'tf.Conv2D' op inferred type incompatible with return type of operation}}
+  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
+  return %0 : tensor<256x32x32x16xf32>
+}
+
+// -----
+
+func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   // expected-error @+1 {{requires dilations attribute length to be 4}}
   %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1], dilations = [1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
   return %0 : tensor<256x30x30x16xf32>
@@ -3530,6 +3562,22 @@
 
 // -----
 
+// Legal BatchMatMul op.
+func @testBatchMatMul(%lhs: tensor<2x?x2x?x3x5xf32>, %rhs: tensor<2x2x?x?x5x7xf32>) {
+  %0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<2x?x2x?x3x5xf32>, tensor<2x2x?x?x5x7xf32>) -> tensor<2x?x?x?x3x7xf32>
+  return
+}
+
+// -----
+
+// Mismatching batch dimensions.
+func @testBatchMatMul(%lhs: tensor<1x3x5xf32>, %rhs: tensor<2x5x7xf32>) {
+  // expected-error @+1 {{found mismatching batch dimensions for lhs shape 'tensor<1x3x5xf32>' and rhs shape 'tensor<2x5x7xf32>'}}
+  %0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<1x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
+}
+
+// -----
+
 func @testBatchMatMulV2(%lhs: tensor<f32>, %rhs: tensor<10x10xf32>) {
   // expected-error @+1 {{requires lhs operand to have rank at least two}}
   %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<f32>, tensor<10x10xf32>) -> tensor<10x10xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir
index 9d877f9..f6e8e41 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir
@@ -1,34 +1,34 @@
 // RUN: tf-opt %s -tf-optimize | FileCheck %s
 
 // CHECK-LABEL: @fuseMulIntoConv2d
-func @fuseMulIntoConv2d(%arg0: tensor<1x112x112x3xf32>) -> tensor<1x112x112x2xf32> {
+func @fuseMulIntoConv2d(%arg0: tensor<1x112x112x3xf32>) -> tensor<1x28x23x2xf32> {
   %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
   %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
-  %0 = "tf.Conv2D"(%arg0, %cst0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<1x112x112x3xf32>, tensor<1x3x3x2xf32>) -> tensor<1x112x112x2xf32>
-  %1 = "tf.Mul"(%0, %cst2) : (tensor<1x112x112x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
+  %0 = "tf.Conv2D"(%arg0, %cst0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<1x112x112x3xf32>, tensor<1x3x3x2xf32>) -> tensor<1x28x23x2xf32>
+  %1 = "tf.Mul"(%0, %cst2) : (tensor<1x28x23x2xf32>, tensor<2xf32>) -> tensor<1x28x23x2xf32>
 
-  return %1 : tensor<1x112x112x2xf32>
+  return %1 : tensor<1x28x23x2xf32>
   // CHECK: %[[CST:.*]] = "tf.Const{{.*}} dense<
   // CHECK-SAME: [1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]
   // CHECK-SAME: [7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]
   // CHECK-SAME: [1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]
   // CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CST]]) {data_format = "NHWC", dilations = [1, 2, 3, 1], explicit_paddings = [], padding = "SAME", strides = [1, 4, 5, 1], use_cudnn_on_gpu = true}
-  // CHECK: return %[[CONV]] : tensor<1x112x112x2xf32>
+  // CHECK: return %[[CONV]] : tensor<1x28x23x2xf32>
 }
 
 // CHECK-LABEL: @notfuseMulIntoConv2d
 // filter and multiply are not broadcastable
-func @notfuseMulIntoConv2d(%arg0: tensor<1x112x112x3xf32>) -> tensor<1x112x112x2xf32> {
+func @notfuseMulIntoConv2d(%arg0: tensor<1x112x112x3xf32>) -> tensor<1x28x23x2xf32> {
   %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
-  %cst2 = constant dense<3.0> : tensor<112x2xf32>
-  %0 = "tf.Conv2D"(%arg0, %cst0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<1x112x112x3xf32>, tensor<1x3x3x2xf32>) -> tensor<1x112x112x2xf32>
-  %1 = "tf.Mul"(%0, %cst2) : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
+  %cst2 = constant dense<3.0> : tensor<23x2xf32>
+  %0 = "tf.Conv2D"(%arg0, %cst0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<1x112x112x3xf32>, tensor<1x3x3x2xf32>) -> tensor<1x28x23x2xf32>
+  %1 = "tf.Mul"(%0, %cst2) : (tensor<1x28x23x2xf32>, tensor<23x2xf32>) -> tensor<1x28x23x2xf32>
 
-  return %1 : tensor<1x112x112x2xf32>
-  // CHECK: %cst_0 = constant dense<3.000000e+00> : tensor<112x2xf32>
+  return %1 : tensor<1x28x23x2xf32>
+  // CHECK: %cst_0 = constant dense<3.000000e+00> : tensor<23x2xf32>
   // CHECK: %0 = "tf.Conv2D"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]}
-  // CHECK: %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
-  // CHECK: return %1 : tensor<1x112x112x2xf32>
+  // CHECK: %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x28x23x2xf32>, tensor<23x2xf32>) -> tensor<1x28x23x2xf32>
+  // CHECK: return %1 : tensor<1x28x23x2xf32>
 }
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
index de61800..47f6b88 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
@@ -79,7 +79,7 @@
     if FLAGS.save_model_path:
       save_model_path = FLAGS.save_model_path
     else:
-      save_model_path = tempfile.mktemp(suffix='.saved_model')
+      save_model_path = tempfile.mkdtemp(suffix='.saved_model')
     save_options = tf.saved_model.SaveOptions(save_debug_info=show_debug_info)
     tf.saved_model.save(
         create_module_fn(), save_model_path, options=save_options)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py
index 9356163..70b5e28 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py
@@ -35,7 +35,7 @@
     # Basic check that the debug info file is being correctly saved and loaded.
     #
     # CHECK: "tf.AddV2"{{.*}}loc(#[[LOC:.*]])
-    # CHECK: #[[LOC]] = loc({{.*}}callsite("{{[^"]*}}/debug_info.py":{{[0-9]+}}:{{[0-9]+}}
+    # CHECK: #[[LOC]] = loc({{.*}}callsite("{{[^"]*}}/debug_info.py{{.*}}":{{[0-9]+}}:{{[0-9]+}}
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
index 10777cf..2924cb5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
@@ -1343,4 +1343,60 @@
 
     return %1 : tensor<?xi32>
   }
+
+  // Verifies that ops sitting between outside compiled ops and depending on
+  // results from the host are moved after the host compute op so that
+  // dominance is not violated (tf.C in this case).
+  // CHECK-LABEL: func @device_op_dominance
+  func @device_op_dominance() -> () {
+    // CHECK: tf._XlaRecvAtHost
+    // CHECK: tf.B
+    // CHECK: tf.D
+    // CHECK: tf._XlaSendFromHost
+
+    // CHECK: tf.A
+    // CHECK: tf._XlaHostComputeMlir
+    // CHECK: tf.C
+    // CHECK: tf.E
+
+    "tf_device.cluster"() ( {
+      %0 = "tf.A"() : () -> (tensor<i32>)
+      %1 = "tf.B"() {_xla_outside_compilation = "cluster0"} : () -> (tensor<i32>)
+      "tf.C"(%1) : (tensor<i32>) -> ()
+      "tf.D"(%1, %0) {_xla_outside_compilation = "cluster0"} : (tensor<i32>, tensor<i32>) -> ()
+      "tf.E"(%0, %1) : (tensor<i32>, tensor<i32>) -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> ()
+    return
+  }
+
+  // Verifies that ops indirectly depending on results from the host are also
+  // moved after the host compute op (tf.E in this case).
+
+  // CHECK-LABEL: func @device_op_dominance_with_indirect_dependency
+  func @device_op_dominance_with_indirect_dependency() -> () {
+    // CHECK: tf._XlaRecvAtHost
+    // CHECK: tf.B
+    // CHECK: tf.F
+    // CHECK: tf._XlaSendFromHost
+
+    // CHECK: tf.A
+    // CHECK: tf.D
+    // CHECK: tf._XlaHostComputeMlir
+    // CHECK: tf.C
+    // CHECK: tf.E
+    // CHECK: tf.G
+
+    "tf_device.cluster"() ( {
+      %0 = "tf.A"() : () -> (tensor<i32>)
+      %1 = "tf.B"() {_xla_outside_compilation = "cluster0"} : () -> (tensor<i32>)
+      %2 = "tf.C"(%1) : (tensor<i32>) -> (tensor<i32>)
+      %3 = "tf.D"() : () -> (tensor<i32>)
+      "tf.E"(%2, %3) : (tensor<i32>, tensor<i32>) -> ()
+      "tf.F"(%1, %0) {_xla_outside_compilation = "cluster0"} : (tensor<i32>, tensor<i32>) -> ()
+      "tf.G"(%0, %1) : (tensor<i32>, tensor<i32>) -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> ()
+    return
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
index 7cf5f19..ae553f1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
@@ -219,23 +219,3 @@
   // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[v0]] : tensor<4x6xf32>
 }
-
-// -----
-
-func @batchMatMulVectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
-  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
-  return %0 : tensor<10x20xf32>
-
-  // CHECK-LABEL: batchMatMulVectorLhs
-  // CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
-}
-
-// -----
-
-func @batchMatMulVectorRhsInputMatchFailure(%arg0: tensor<10x20xf32>, %arg1: tensor<10xf32>) -> tensor<10x20xf32> {
-  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
-  return %0 : tensor<10x20xf32>
-
-  // CHECK-LABEL: batchMatMulVectorRhs
-  // CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
-}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc
index 50df1f1..ec51b2f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc
@@ -58,10 +58,10 @@
   ModuleOp m = getOperation();
   OpBuilder builder(m.getContext());
   m.walk([&](tf_device::ClusterFuncOp cluster_func) {
-    auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
+    auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
     if (!replicate) return;
     auto mirrored_variable_indices_attr =
-        replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
+        replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
     llvm::SmallDenseSet<int64_t, 8> mirrored_replicate_args;
     if (mirrored_variable_indices_attr) {
       for (const auto& mirrored_index : mirrored_variable_indices_attr) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc
index 2a60706..99da136 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc
@@ -26,9 +26,9 @@
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index a5d18d6..c48a972 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -110,7 +110,7 @@
   pm.addPass(CreateTPUOutsideCompilationClusterPass());
   pm.addPass(CreateTPUExtractOutsideCompilationPass());
 
-  pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorConstantSinkingPass());
+  pm.addNestedPass<FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
   pm.addPass(TF::CreateResourceDeviceInferencePass());
   pm.addPass(TFDevice::CreateClusterOutliningPass());
   pm.addPass(CreateTPUDynamicPaddingMapperPass());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index 39340dc..dcd8931 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -55,6 +55,16 @@
 // BatchMatMul op patterns.
 //===----------------------------------------------------------------------===//
 
+// Statically shaped operands in a legal BatchMatMul op will have matching
+// batch dimensions and can be upgraded to the BatchMatMulV2 op. Canonicalizing
+// dynamically shaped operands is not correct, as that would let ops with
+// non-matching but broadcastable batch dimensions execute, which should fail
+// with V1.
+def BatchMatMulToV2 :
+  Pat<(TF_BatchMatMulOp AnyStaticShapeTensor:$x, AnyStaticShapeTensor:$y,
+                        $adj_x, $adj_y),
+      (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y)>;
+
 def BatchMatMulToMatMul : Pat<(TF_BatchMatMulOp $x, $y, $adj_x, $adj_y),
                               (TF_MatMulOp $x, $y, $adj_x, $adj_y),
                               [(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
index d8bcdee..854058e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -58,8 +58,8 @@
   operand_types.reserve(live_ins.size());
   for (Value v : live_ins) operand_types.emplace_back(v.getType());
 
-  auto func_type = FunctionType::get(operand_types, cluster_op.getResultTypes(),
-                                     builder->getContext());
+  auto func_type =
+      builder->getFunctionType(operand_types, cluster_op.getResultTypes());
 
   // TODO(lyandy): Define better name for outlined function. Potentially some
   // name can be added during cluster formation.
@@ -108,8 +108,8 @@
 
   FuncOp outlined_func =
       BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
-  cluster_op.setAttr(builder->getIdentifier(kFuncAttr),
-                     builder->getSymbolRefAttr(outlined_func.getName()));
+  cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
+                      builder->getSymbolRefAttr(outlined_func.getName()));
 
   builder->setInsertionPoint(cluster_op);
   auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc
index b45d981..6a786ea 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc
@@ -193,7 +193,7 @@
     std::replace(func_name.begin(), func_name.end(), '/', '_');
 
     FunctionType func_type =
-        FunctionType::get(input_types, result_types, context);
+        FunctionType::get(context, input_types, result_types);
     Location loc = metadata.ops.front()->getLoc();
     FuncOp func_op = FuncOp::create(loc, func_name, func_type);
     // Sets the device attribute for every input and every result of the
@@ -208,7 +208,7 @@
           StringAttr::get(metadata.result_devices[i], context));
     }
 
-    func_op.setAttr(kHostAttr, StringAttr::get(host, context));
+    func_op->setAttr(kHostAttr, StringAttr::get(host, context));
     func_op.setPublic();
     Block *block = func_op.addEntryBlock();
 
@@ -291,7 +291,7 @@
   void runOnFunction() override {
     MLIRContext *context = &getContext();
     FuncOp func_op = getOperation();
-    ModuleOp module_op = func_op.getParentOfType<mlir::ModuleOp>();
+    ModuleOp module_op = func_op->getParentOfType<mlir::ModuleOp>();
 
     llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
         GetFunctionMetadatas(func_op);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc
index 23ab4ff..1da5e36 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc
@@ -25,9 +25,9 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -53,7 +53,7 @@
   values.reserve(rank);
   for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
   auto result_type = RankedTensorType::get(
-      {rank}, IntegerType::get(bitwidth, builder.getContext()));
+      {rank}, IntegerType::get(builder.getContext(), bitwidth));
   return builder.create<TF::ConstOp>(
       loc, DenseElementsAttr::get(result_type, values));
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
index 31cfc5e..833d35c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
@@ -20,7 +20,6 @@
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
-#include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
@@ -106,6 +105,10 @@
     // The TFE_Context is created without an accompanying delete due to current
     // lifetime. This does not result in memory leaks reported (see totw/110).
     TFE_ContextOptions* opts = TFE_NewContextOptions();
+    // Input tensors are placed on the host CPU so use the explicit device
+    // policy to fail if no CPU kernels are available for the op.
+    TFE_ContextOptionsSetDevicePlacementPolicy(opts,
+                                               TFE_DEVICE_PLACEMENT_EXPLICIT);
     auto ctx = TFE_NewContext(opts, status);
     TFE_DeleteContextOptions(opts);
     TF_DeleteStatus(status);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc
index 2d7f412..d584a6a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc
@@ -42,7 +42,7 @@
 
   module.walk([&](TF::ConstOp op) {
     // Keep the ConstOp if the op already have the device attribute.
-    if (StringAttr device_attr = op.getAttrOfType<StringAttr>(kDeviceAttr)) {
+    if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
       return WalkResult::advance();
     }
     OpBuilder builder(op);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc
new file mode 100644
index 0000000..b1c63b8
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc
@@ -0,0 +1,160 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This pass inserts tf_device.send and tf_device.receive ops to make sure any
+// argument of any op is on the same host as the op itself.
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace mlir {
+namespace TF {
+
+namespace {
+
+using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
+
+constexpr const char *kOpDeviceAttr = "device";
+constexpr const char *kArgDeviceAttr = "tf.device";
+// TODO(b/175480458): Do not assign default host once every op in the TF
+// dialect has the device attribute.
+constexpr const char *kDefaultHost = "/job:localhost/replica:0/task:0";
+constexpr const char *kCPUDevice = "/device:CPU:0";
+
+// Return the job/replica/task from the device name as the host address. If no
+// job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
+// default host address.
+std::string GetHost(const std::string &device) {
+  DeviceNameUtils::ParsedName parsed_name;
+  DeviceNameUtils::ParseFullName(device, &parsed_name);
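+  // Drop the device type and id so that only the job/replica/task part remains.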
+  parsed_name.has_id = false;
+  parsed_name.has_type = false;
+
+  auto host = DeviceNameUtils::ParsedNameToString(parsed_name);
+  if (host.empty()) return kDefaultHost;
+
+  return host;
+}
+
+struct CrossHostTransferPass
+    : public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+
+ private:
+  // key_count is the total number of send/recv pairs generated before this
+  // method call, and it is incremented for each send/recv pair newly
+  // generated by this method call.
+  void runOnFunction(FuncOp func_op, int &key_count);
+};
+
+void CrossHostTransferPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  int key_count = 0;
+
+  module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); });
+}
+
+void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) {
+  // This map is used to avoid transferring the same value to the same host
+  // multiple times.
+  llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
+      transferred_value_by_value_and_host;
+
+  func_op.getBody().walk([&](Operation *op) {
+    if (op->isKnownTerminator()) return WalkResult::advance();
+
+    OpBuilder builder(op);
+    // Get the host address of the op.
+    std::string op_device = "";
+    if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
+      op_device = device_attr.getValue().str();
+    }
+    std::string dst_host = GetHost(op_device);
+
+    for (mlir::Value arg : op->getOperands()) {
+      // Get the host address of the argument.
+      std::string arg_device = "";
+      if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) {
+        // Do not send this argument if it is not a function's argument. This
+        // can happen when the argument is a while loop's argument.
+        if (block_arg.getParentRegion() != &func_op.getRegion()) continue;
+
+        if (StringAttr device_attr = func_op.getArgAttrOfType<StringAttr>(
+                block_arg.getArgNumber(), kArgDeviceAttr)) {
+          arg_device = device_attr.getValue().str();
+        }
+      } else {
+        Operation *defining_op = arg.getDefiningOp();
+        if (StringAttr device_attr =
+                defining_op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
+          arg_device = device_attr.getValue().str();
+        }
+      }
+      std::string src_host = GetHost(arg_device);
+
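+      // No transfer is needed when the argument already lives on the op's host.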
+      if (src_host == dst_host) continue;
+
+      // Re-use the transferred argument if the argument has already been
+      // transferred to the given host.
+      llvm::StringMap<mlir::Value> &transferred_value_by_host =
+          transferred_value_by_value_and_host[arg];
+      auto iter = transferred_value_by_host.find(dst_host);
+      if (iter != transferred_value_by_host.end()) {
+        op->replaceUsesOfWith(arg, iter->second);
+        continue;
+      }
+
+      // Create tf_device.send and tf_device.receive ops to transfer the
+      // argument to the same host as the operation.
+      std::string key = "key-" + std::to_string(key_count);
+      key_count++;
+
+      auto send_op =
+          builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
+      send_op->setAttr(kOpDeviceAttr,
+                       builder.getStringAttr(src_host + kCPUDevice));
+
+      auto receive_op = builder.create<tf_device::ReceiveOp>(
+          op->getLoc(), arg.getType(), key, src_host);
+      receive_op->setAttr(kOpDeviceAttr,
+                          builder.getStringAttr(dst_host + kCPUDevice));
+
+      transferred_value_by_host[dst_host] = receive_op.getResult();
+      op->replaceUsesOfWith(arg, receive_op.getResult());
+    }
+    return WalkResult::advance();
+  });
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() {
+  return std::make_unique<CrossHostTransferPass>();
+}
+
+static PassRegistration<CrossHostTransferPass> pass(
+    "tf-cross-host-transfer",
+    "This pass inserts tf_device.send and tf_device.receive ops to make sure "
+    "any argument of any op is on the same host of the op itself.");
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc
index d309c6d..09fac6e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc
@@ -18,8 +18,8 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc
index 28a5c58..7701d96 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc
@@ -15,7 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
 
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
index 7d13d60..ba53f22 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
@@ -458,7 +458,7 @@
      (TF_ConstOp:$zero (GetScalarOfType<0> $grad)),
      (TF_AddV2Op:$accum_new
        (CreateTFReadVariableOp $src_op, $grad, $accum_resource),
-       (TF_SqrtOp $grad)),
+       (TF_SquareOp $grad)),
      (TF_MulOp:$adagrad_lr $lr, (TF_RsqrtOp $accum_new)),
      (TF_SubOp:$prox_var
        (CreateTFReadVariableOp $src_op, $grad, $var_resource),
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc
new file mode 100644
index 0000000..cd55dbc
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace TF {
+
+namespace {
+
+constexpr char kShapeInvariantAttr[] = "shape_invariant";
+
+// Drops the `shape_invariant` attribute from tf.While and tf.WhileRegion ops.
+// This allows the shape inference pass to further refine operand/result shapes
+// of these ops. This is only safe to do when compiling to XLA.
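+//
+// For example (illustrative), a while op carrying the attribute
+//   "tf.While"(%arg0) {body = @body, cond = @cond, shape_invariant} : ...
+// is rewritten to
+//   "tf.While"(%arg0) {body = @body, cond = @cond} : ...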
+class DropWhileShapeInvariantPass
+    : public PassWrapper<DropWhileShapeInvariantPass, FunctionPass> {
+  void runOnFunction() override;
+};
+
+void DropWhileShapeInvariantPass::runOnFunction() {
+  getFunction().walk([](Operation* op) {
+    if (llvm::isa<WhileOp, WhileRegionOp>(op))
+      op->removeAttr(kShapeInvariantAttr);
+  });
+}
+
+static PassRegistration<DropWhileShapeInvariantPass> pass(
+    "tf-drop-while-shape-invariant",
+    "Drop `shape_invariant` attrbute from While/WhileRegion ops.");
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass() {
+  return std::make_unique<DropWhileShapeInvariantPass>();
+}
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
index 35e9e90..6277b1d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
@@ -33,9 +33,9 @@
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h
index 490fe1e..65e0528 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h
@@ -24,11 +24,11 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
index 2f612ac..801cbbb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
@@ -66,7 +66,7 @@
 // that is closest to the island in the graph. If no candidate can be found or
 // the op found is not an island, an empty optional is returned.
 llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
-  Operation* graph_op = island.getParentOp();
+  Operation* graph_op = island->getParentOp();
   Operation* candidate = nullptr;
 
   // Check island control operands.
@@ -95,7 +95,7 @@
 // an op, that is closest to the island in the graph. If no candidate can be
 // found or the op found is not an island, an empty optional is returned.
 llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
-  Operation* graph_op = island.getParentOp();
+  Operation* graph_op = island->getParentOp();
   Operation* candidate = nullptr;
 
   // Check island control results.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
index 278e283..e47f2fd 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
@@ -68,9 +68,9 @@
     return signalPassFailure();
   }
   ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
-  outlined_module.setAttrs(getOperation().getAttrs());
-  outlined_module.setAttr(SymbolTable::getSymbolAttrName(),
-                          StringAttr::get(kNestedModule, ctx));
+  outlined_module->setAttrs(getOperation().getAttrs());
+  outlined_module->setAttr(SymbolTable::getSymbolAttrName(),
+                           StringAttr::get(kNestedModule, ctx));
   symbol_table.insert(outlined_module);
   SymbolTable outlined_symbol_table(outlined_module);
 
@@ -78,7 +78,7 @@
   // in a new module to run the V1 bridge there.
   SmallVector<IslandOp, 8> islands_to_outline;
   getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
-    auto island_op = cast<IslandOp>(replicate_op.getParentOp());
+    auto island_op = cast<IslandOp>(replicate_op->getParentOp());
     if (!island_op || island_op.WrapsSingleOp()) return;
     islands_to_outline.push_back(island_op);
   });
@@ -100,7 +100,7 @@
     for (Value operand : island_op.GetYield().getOperands())
       func_result_types.push_back(operand.getType());
     FunctionType func_type =
-        FunctionType::get(func_operand_types, func_result_types, ctx);
+        FunctionType::get(ctx, func_operand_types, func_result_types);
 
     // Create the outlined function
     SmallString<32> name = kOutlinedFuncPrefix;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
index ce949ef..b71c9dd 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
@@ -21,9 +21,9 @@
 #include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
index cc24c98..f00c00d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
@@ -34,11 +34,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
index a5d7661..6adce66 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
@@ -17,6 +17,7 @@
 // TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
@@ -42,7 +43,7 @@
 // control flow op into an i1 value.
 static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
   auto zero_d = builder->create<ToBoolOp>(loc, value);
-  auto scalar = builder->create<ExtractElementOp>(loc, zero_d);
+  auto scalar = builder->create<tensor::ExtractOp>(loc, zero_d);
   return scalar.getResult();
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
index 0e0e874..1bf12a1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
@@ -22,8 +22,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
index 26c0126..e2db7a4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
@@ -27,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 
 namespace mlir {
 namespace tf_executor {
@@ -40,7 +41,7 @@
 // "tf.entry_function" attribute defined.
 bool CanPruneGraph(FuncOp func) {
   return func.getName() != "main" ||
-         func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
+         func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
 }
 
 // Visits an op's operand if it is an output of an Operation in the same
@@ -120,7 +121,8 @@
 namespace {
 
 // This transformation pass prunes a TF graph eliminating dead-nodes.
-struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
+struct GraphPruningPass
+    : public TF::ExecutorGraphPruningPassBase<GraphPruningPass> {
   void runOnFunction() override {
     if (!CanPruneGraph(getFunction())) return;
     getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
@@ -130,12 +132,8 @@
 }  // namespace
 
 std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass() {
-  return std::make_unique<GraphPruning>();
+  return std::make_unique<GraphPruningPass>();
 }
 
-static PassRegistration<GraphPruning> pass(
-    "tf-executor-graph-pruning",
-    "Prune unreachable nodes in a TensorFlow Graph.");
-
 }  // namespace tf_executor
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc
index 3d89454..f715890 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc
@@ -124,7 +124,7 @@
 
   // Get runtime devices information from the closest parent module.
   RuntimeDevices devices;
-  if (failed(::tensorflow::GetDevicesFromOp(func.getParentOfType<ModuleOp>(),
+  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                             &devices)))
     return signalPassFailure();
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
index 9a13c01..d7c8506 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -31,12 +31,12 @@
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
index ea34696..12bc3b3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
@@ -213,9 +213,9 @@
     }
 
     // Update the function type.
-    func.setType(mlir::FunctionType::get(func.getArgumentTypes(),
-                                         func.getType().getResults(),
-                                         module.getContext()));
+    func.setType(mlir::FunctionType::get(module.getContext(),
+                                         func.getArgumentTypes(),
+                                         func.getType().getResults()));
   }
   return success();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index e48e155..0f70647 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -20,11 +20,11 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeRange.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index f667080..1dbae15 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -298,3 +298,19 @@
       (TF_TensorScatterAddOp
        (TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
        $indices, $updates)>;
+
+//===----------------------------------------------------------------------===//
+// Xdivy, Xlog1py and Xlogy op patterns.
+//===----------------------------------------------------------------------===//
+
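+// These ops return 0 wherever $x is 0 (even where the plain expression would
+// produce inf or nan); each lowering therefore selects between a zero constant
+// and the corresponding Div / Mul(Log1p) / Mul(Log) expression.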
+class BinaryXopyPat<dag From, dag To>
+  : Pat<From,
+        (TF_SelectV2Op (TF_EqualOp $x,
+                                   (TF_ConstOp:$zero (GetScalarOfType<0> $x)),
+                       /*incompatible_shape_error*/ConstBoolAttrTrue),
+           $zero, To)>;
+
+foreach fromToPair = [[(TF_XdivyOp $x, $y), (TF_DivOp $x, $y)],
+                      [(TF_Xlog1pyOp $x, $y), (TF_MulOp $x, (TF_Log1pOp $y))],
+                      [(TF_XlogyOp $x, $y), (TF_MulOp $x, (TF_LogOp $y))]] in
+  def : BinaryXopyPat<fromToPair[0], fromToPair[1]>;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
index ddb4303..f9a85b7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
 
@@ -44,13 +45,9 @@
         "/tensorflow/core/use_auto_outside_compilation",
         "Tracks if auto outside compilation is enabled");
 
-// This pass marks unsupported ops in a device cluster with
-// `_xla_outside_compilation` attribute so the operations will run on the host
-// instead of the device.  Unsupported ops are ops that can not be code
-// generated to run on the device for the cluster.
 struct MarkOpsForOutsideCompilation
-    : public PassWrapper<MarkOpsForOutsideCompilation,
-                         OperationPass<ModuleOp>> {
+    : public TF::MarkOpsForOutsideCompilationPassBase<
+          MarkOpsForOutsideCompilation> {
   void runOnOperation() override;
 };
 
@@ -201,6 +198,9 @@
   int outside_compiled_cluster_counter = 0;
   block->walk([&](Operation* op) {
     if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
+      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
+              << " isn't compilable, adding outside_compilation attr. "
+                 "This op will automatically be placed on CPU.";
       op->setAttr(
           kXlaOutsideCompilationAttr,
           StringAttr::get(
@@ -264,7 +264,7 @@
     // Only if `allow_soft_placement` attribute is true should we mark ops
     // for outside compilation.
     auto soft_placement_attr =
-        cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
+        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
     if (!(soft_placement_attr && soft_placement_attr.getValue())) {
       return WalkResult::advance();
     }
@@ -281,7 +281,7 @@
     // Only if `allow_soft_placement` attribute is true should we unmark ops
     // for outside compilation.
     auto soft_placement_attr =
-        cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
+        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
     if (!(soft_placement_attr && soft_placement_attr.getValue())) {
       return;
     }
@@ -296,9 +296,5 @@
   return std::make_unique<MarkOpsForOutsideCompilation>();
 }
 
-static PassRegistration<MarkOpsForOutsideCompilation> pass(
-    "tf-mark-ops-for-outside-compilation",
-    "Marks unsupported ops a device cluster for outside compilation.");
-
 }  // namespace TFDevice
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc
new file mode 100644
index 0000000..48cf6b2
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc
@@ -0,0 +1,315 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <queue>
+#include <string>
+#include <utility>
+
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace TFDevice {
+
+namespace {
+
+// This pass merges IfRegion ops together if they have the same predicate and
+// it is safe to do so (there are no intermediate dependencies, they are in the
+// same block, etc.).
+//
+// A simple example:
+//    "tf.IfRegion"(%0) ( {
+//      %2 = "tf.A"() : () -> (tensor<f32>)
+//      "tf.Yield"() : () -> ()
+//      }, {
+//      "tf.Yield"() : () -> ()
+//     }) { is_stateless = true } : (tensor<i1>) -> ()
+//    "tf.IfRegion"(%0) ( {
+//      %2 = "tf.B"() : () -> (tensor<f32>)
+//      "tf.Yield"() : () -> ()
+//      }, {
+//      "tf.Yield"() : () -> ()
+//     }) { is_stateless = true } : (tensor<i1>) -> ()
+// Would become:
+//    "tf.IfRegion"(%0) ( {
+//      %2 = "tf.A"() : () -> (tensor<f32>)
+//      %3 = "tf.B"() : () -> (tensor<f32>)
+//      "tf.Yield"() : () -> ()
+//      }, {
+//      "tf.Yield"() : () -> ()
+//     }) { is_stateless = true } : (tensor<i1>) -> ()
+
+struct MergeControlFlow : public TF::PerFunctionAggregateAnalysisConsumerPass<
+                              MergeControlFlow, TF::SideEffectAnalysis> {
+  void runOnFunction(FuncOp func,
+                     const TF::SideEffectAnalysis::Info& side_effect_analysis);
+};
+
+// Returns whether it is safe to merge `source` IfRegion into `destination`
+// IfRegion. `source` must come after `destination`.
+bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
+                 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
+  // IfRegion ops must be in the same block.
+  if (source.getOperation()->getBlock() !=
+      destination.getOperation()->getBlock())
+    return false;
+  assert(destination.getOperation()->isBeforeInBlock(source.getOperation()));
+
+  llvm::SmallSetVector<Operation*, 4> source_ops;
+  source_ops.insert(source);
+  for (Operation& op : source.then_branch().front()) {
+    source_ops.insert(&op);
+  }
+  for (Operation& op : source.else_branch().front()) {
+    source_ops.insert(&op);
+  }
+
+  // If there is an intermediate data or side effect dependency between the
+  // ops in destination and the ops in the source, it's not safe to merge
+  // them.
+  llvm::SmallSetVector<Operation*, 4> op_stack;
+  for (auto* user : destination.getOperation()->getUsers()) {
+    if (!source_ops.contains(user)) op_stack.insert(user);
+  }
+  for (Operation& op : destination.then_branch().front()) {
+    for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
+      if (!source_ops.contains(successor)) op_stack.insert(successor);
+    }
+  }
+  for (Operation& op : destination.else_branch().front()) {
+    for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
+      if (!source_ops.contains(successor)) op_stack.insert(successor);
+    }
+  }
+
+  bool safe_to_merge = true;
+
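+  // Walk the transitive users and control successors of `destination`. If any
+  // of them feeds back into `source`, merging the two IfRegions would create a
+  // dependency cycle, so it is not safe.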
+  while (!op_stack.empty()) {
+    auto* next_op = op_stack.pop_back_val();
+    for (auto* user : next_op->getUsers()) {
+      if (source_ops.contains(user)) {
+        safe_to_merge = false;
+        break;
+      } else {
+        op_stack.insert(user);
+      }
+    }
+    for (auto* successor :
+         side_effect_analysis.DirectControlSuccessors(next_op)) {
+      if (source_ops.contains(successor)) {
+        safe_to_merge = false;
+        break;
+      } else {
+        op_stack.insert(successor);
+      }
+    }
+    if (!safe_to_merge) break;
+  }
+  return safe_to_merge;
+}
+
+// Checks which return indices should be kept for `first_if_op` by checking
+// whether its results have uses outside of `second_if_op`.
+llvm::SmallVector<int, 4> GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,
+                                                 TF::IfRegionOp second_if_op) {
+  llvm::SmallVector<int, 4> return_indices_to_keep;
+  for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) {
+    if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) {
+          return second_if_op->isProperAncestor(op);
+        })) {
+      return_indices_to_keep.push_back(index_and_value.index());
+    }
+  }
+  return return_indices_to_keep;
+}
+
+// Moves the bodies of the then and else regions, excluding their terminators,
+// from `source` to `destination`.
+void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
+  Block& destination_then_block = destination.then_branch().front();
+  auto& source_then_body = source.then_branch().front().getOperations();
+  destination_then_block.getOperations().splice(
+      destination_then_block.without_terminator().end(), source_then_body,
+      source_then_body.begin(), std::prev(source_then_body.end()));
+
+  Block& destination_else_block = destination.else_branch().front();
+  auto& source_else_body = source.else_branch().front().getOperations();
+  destination_else_block.getOperations().splice(
+      destination_else_block.without_terminator().end(), source_else_body,
+      source_else_body.begin(), std::prev(source_else_body.end()));
+}
+
+// Moves all ops that depend on the results of `result_op` after `after_op`.
+void MoveResultsAfter(Operation* result_op, Operation* after_op) {
+  std::queue<Operation*> queue;
+  for (Operation* user : result_op->getUsers()) {
+    queue.push(user);
+  }
+  while (!queue.empty()) {
+    auto* op = queue.front();
+    queue.pop();
+    for (Operation* user : op->getUsers()) queue.push(user);
+    if (op->isBeforeInBlock(after_op)) op->moveAfter(after_op);
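+    // Chain subsequent moves after the op just moved to preserve the original
+    // relative order of the dependent ops.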
+    after_op = op;
+  }
+}
+
+TF::IfRegionOp CreateMergedIf(ArrayRef<int> source_return_indices_to_keep,
+                              ArrayRef<int> destination_return_indices_to_keep,
+                              TF::IfRegionOp source,
+                              TF::IfRegionOp destination) {
+  llvm::SmallVector<Type, 4> merged_return_types;
+  for (int i : destination_return_indices_to_keep)
+    merged_return_types.push_back(destination.getResult(i).getType());
+  for (int i : source_return_indices_to_keep)
+    merged_return_types.push_back(source.getResult(i).getType());
+
+  OpBuilder builder(destination);
+  // Create new IfRegion with correct merged results.
+  builder.setInsertionPoint(source.getOperation());
+
+  auto new_if_op = builder.create<TF::IfRegionOp>(
+      destination.getLoc(), merged_return_types, destination.cond(),
+      destination.is_stateless() && source.is_stateless());
+  new_if_op.then_branch().push_back(new Block);
+  new_if_op.else_branch().push_back(new Block);
+  // Replace internal usages of merged if ops.
+  for (OpResult result : destination.getResults()) {
+    replaceAllUsesInRegionWith(
+        result,
+        destination.then_branch().front().getTerminator()->getOperand(
+            result.getResultNumber()),
+        source.then_branch());
+    replaceAllUsesInRegionWith(
+        result,
+        destination.else_branch().front().getTerminator()->getOperand(
+            result.getResultNumber()),
+        source.else_branch());
+  }
+
+  MoveResultsAfter(destination.getOperation(), new_if_op.getOperation());
+
+  // Replace external usages of merged if ops.
+  int new_return_index = 0;
+  for (int i : destination_return_indices_to_keep) {
+    destination.getResult(i).replaceAllUsesWith(
+        new_if_op.getResult(new_return_index++));
+  }
+  for (int i : source_return_indices_to_keep) {
+    source.getResult(i).replaceAllUsesWith(
+        new_if_op.getResult(new_return_index++));
+  }
+
+  // Create the Yield ops for both branches with merged results.
+  llvm::SmallVector<Value, 4> merged_then_yield_values;
+  for (int i : destination_return_indices_to_keep)
+    merged_then_yield_values.push_back(
+        destination.then_branch().front().getTerminator()->getOperand(i));
+  for (int i : source_return_indices_to_keep)
+    merged_then_yield_values.push_back(
+        source.then_branch().front().getTerminator()->getOperand(i));
+  builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
+  builder.create<TF::YieldOp>(
+      destination.then_branch().front().getTerminator()->getLoc(),
+      /*operands=*/merged_then_yield_values);
+
+  llvm::SmallVector<Value, 4> merged_else_yield_values;
+  for (int i : destination_return_indices_to_keep)
+    merged_else_yield_values.push_back(
+        destination.else_branch().front().getTerminator()->getOperand(i));
+  for (int i : source_return_indices_to_keep)
+    merged_else_yield_values.push_back(
+        source.else_branch().front().getTerminator()->getOperand(i));
+  builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
+  builder.create<TF::YieldOp>(
+      destination.else_branch().front().getTerminator()->getLoc(),
+      /*operands=*/merged_else_yield_values);
+
+  // Merge the branch regions from both IfRegionOps into the new IfRegionOp.
+  MoveBranches(/*source=*/destination, /*destination=*/new_if_op);
+  destination.erase();
+  MoveBranches(/*source=*/source, /*destination=*/new_if_op);
+  source.erase();
+  return new_if_op;
+}
+
+// Groups IfRegion ops by common predicate and attempts to merge them.
+void OptimizeIfRegions(
+    Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
+  // Determine IfRegions with the same predicate.
+  llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
+      grouped_if_ops;
+  block->walk([&](TF::IfRegionOp if_op) {
+    auto it = grouped_if_ops.try_emplace(if_op.cond());
+    it.first->getSecond().push_back(if_op);
+  });
+
+  for (auto& entry : grouped_if_ops) {
+    auto& if_ops = entry.second;
+    for (auto it = if_ops.begin(); it != if_ops.end(); ++it) {
+      TF::IfRegionOp first_if_op = *it;
+      for (auto it2 = std::next(it); it2 != if_ops.end(); ++it2) {
+        TF::IfRegionOp second_if_op = *it2;
+        if (!SafeToMerge(second_if_op, first_if_op, side_effect_analysis))
+          break;
+
+        // For both ops, check if any results are used outside of the
+        // IfRegion; keep those as part of the return and replace the internal
+        // uses.
+        auto first_return_indices_to_keep =
+            GetReturnIndicesToKeep(first_if_op, second_if_op);
+        auto second_return_indices_to_keep =
+            GetReturnIndicesToKeep(second_if_op, first_if_op);
+
+        auto new_if_op = CreateMergedIf(second_return_indices_to_keep,
+                                        first_return_indices_to_keep,
+                                        second_if_op, first_if_op);
+
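+        // `second_if_op` has been folded into `new_if_op`; drop it from the
+        // group so the loop continues with the next candidate.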
+        if_ops.erase(it2--);
+        first_if_op = new_if_op;
+      }
+    }
+  }
+}
+
+void MergeControlFlow::runOnFunction(
+    FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
+  auto result = func.walk([&](tf_device::ClusterOp cluster) {
+    OptimizeIfRegions(&cluster.GetBody(), side_effect_analysis);
+    return WalkResult::advance();
+  });
+
+  if (result.wasInterrupted()) return signalPassFailure();
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
+  return std::make_unique<MergeControlFlow>();
+}
+
+static PassRegistration<MergeControlFlow> pass(
+    "tf-merge-control-flow", "Merges control flow with a common predicate.");
+}  // namespace TFDevice
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
index 540527c..e0e0c45 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
@@ -23,8 +23,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc
index 86eea50..43f7fbb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc
@@ -165,7 +165,7 @@
       unused_execute_controls.push_back(execute.control());
 
   if (!unused_execute_controls.empty()) {
-    auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
+    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
     tf_executor::FetchOp fetch = graph_op.GetFetch();
     auto fetches = llvm::to_vector<8>(fetch.getOperands());
     fetches.append(unused_execute_controls.begin(),
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index e7a6e0e..ca89c7d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -39,6 +39,10 @@
 CreateExecutorDialectToFunctionalConversionPass();
 
 namespace TF {
+// Creates a pass that drops the `shape_invariant` attribute from
+// While/WhileRegion ops.
+std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass();
+
 // Transforms functional control flow operations in the TensorFlow dialect to
 // MLIR Control Flow Graph (CFG) form.
 std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
@@ -195,6 +199,10 @@
 // assignment of the result.
 std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
 
+// Creates a pass to insert tf_device.send and tf_device.receive ops to make
+// sure any argument of any op is on the same host as the op itself.
+std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
+
 // Creates a pass that adds the device attribute to every tf.Const op based on
 // the device attribute of the operations that read its result. If the result of
 // a tf.Const op is read by operations placed on multiple devices, then the pass
@@ -229,11 +237,6 @@
 
 // Creates a pass to prune tf_executor.graph from dead nodes.
 std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
-
-// Sink `tf.Const` operations in the LaunchOp region using them. This is
-// performed in order to limit the number of values implicitly captured in this
-// region before outlining.
-std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
 }  // namespace tf_executor
 
 namespace TFDevice {
@@ -241,6 +244,11 @@
 // same device.
 std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass();
 
+// Sinks `tf.Const` operations into the ClusterOp regions that use them. This
+// is done to limit the number of values implicitly captured in the region
+// before outlining.
+std::unique_ptr<OperationPass<FuncOp>> CreateClusterConstantSinkingPass();
+
 // Creates a pass that outlines regions of tf_device.launch operations.
 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass();
 
@@ -282,6 +290,9 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkOpsForOutsideCompilationPass();
 
+// Creates a pass that merges control flow with a common predicate.
+std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass();
+
 // Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
 // attribute to each TensorFlow dialect op in the body based on the `device`
 // attribute on the `tf_device.launch`.
@@ -394,6 +405,9 @@
 
 }  // namespace TFTPU
 
+#define GEN_PASS_REGISTRATION
+#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
+
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
index ca45645..0a23912 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
@@ -59,7 +59,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -180,8 +180,8 @@
   }
 
   if (!var_handle_shared_names->empty())
-    function.setType(FunctionType::get(func_arg_types, func_type.getResults(),
-                                       function.getContext()));
+    function.setType(FunctionType::get(function.getContext(), func_arg_types,
+                                       func_type.getResults()));
 
   return success();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc
index 4c65162..55f4168 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc
@@ -22,7 +22,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -138,7 +138,8 @@
     ShapedType shaped_type =
         variable_v2_op.getResult().getType().cast<ShapedType>();
     TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
-    StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
+    StringAttr device_attr =
+        variable_v2_op->getAttrOfType<StringAttr>("device");
     if (!device_attr) device_attr = builder.getStringAttr("");
     StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
     if (variable_name.empty()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
index f800450..0c5aa72 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
@@ -121,7 +121,7 @@
   if (extern_values_passthrough)
     for (auto input : extern_values) return_types.push_back(input.getType());
 
-  auto type = FunctionType::get(input_types, return_types, region.getContext());
+  auto type = FunctionType::get(region.getContext(), input_types, return_types);
 
   // Create new function and extract region body into the function.
   auto outlined_func = builder.create<FuncOp>(loc, name, type);
@@ -210,8 +210,8 @@
 bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
   if (first.getNumOperands() != second.getNumOperands()) return false;
 
-  Region& first_region = *first.getParentRegion();
-  Region& second_region = *second.getParentRegion();
+  Region& first_region = *first->getParentRegion();
+  Region& second_region = *second->getParentRegion();
 
   for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
     // Get the defining Op, skipping over casts.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index c051c50..86a80fd 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -121,7 +121,7 @@
     // Map aliased devices to explicit devices based on replica.
     if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
       if (auto device_by_replica = devices.getValue().get(launch.device()))
-        launch.setAttr(
+        launch->setAttr(
             kDeviceAttr,
             device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());
 
@@ -316,7 +316,7 @@
   });
 
   for (tf_executor::IslandOp island_op : replicate_op_islands) {
-    auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
+    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
     auto replicate_op =
         cast<tf_device::ReplicateOp>(island_op.GetBody().front());
     if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
index f58e0bf..a700deb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
@@ -192,7 +192,7 @@
             if (auto device = result->DeviceForResource(output)) {
               LLVM_DEBUG(llvm::dbgs()
                          << " Setting device = " << *device << "\n");
-              identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
+              identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
             }
           }
         } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 46fe6ae..9dd6671 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -33,10 +33,10 @@
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Region.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
@@ -785,9 +785,9 @@
     }
   }
   func_op.eraseArguments(indices_to_erase);
-  func_op.setType(FunctionType::get(
-      new_types, llvm::to_vector<4>(return_op->getOperandTypes()),
-      func_op.getContext()));
+  func_op.setType(
+      FunctionType::get(func_op.getContext(), new_types,
+                        llvm::to_vector<4>(return_op->getOperandTypes())));
 }
 
 // Lifts reads/writes of resource arguments from func_op and changes its
@@ -841,10 +841,9 @@
     assign_variable_op.erase();
   }
 
-  func_op.setType(
-      FunctionType::get(func_op.front().getArgumentTypes(),
-                        func_op.front().getTerminator()->getOperandTypes(),
-                        func_op.getContext()));
+  func_op.setType(FunctionType::get(
+      func_op.getContext(), func_op.front().getArgumentTypes(),
+      func_op.front().getTerminator()->getOperandTypes()));
 
   return success();
 }
@@ -1106,7 +1105,7 @@
 
   // Clone the callee before making changes.
   SmallString<64> name_base = callee.getName();
-  auto module = callee.getParentOfType<ModuleOp>();
+  auto module = callee->getParentOfType<ModuleOp>();
   name_base += "_resource_lifted";
   auto name = name_base;
   callee = callee.clone();
@@ -1153,9 +1152,9 @@
   auto new_return =
       builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
   old_return->erase();
-  callee.setType(FunctionType::get(
-      callee.getType().getInputs(),
-      llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
+  callee.setType(
+      FunctionType::get(callee.getContext(), callee.getType().getInputs(),
+                        llvm::to_vector<4>(new_return.getOperandTypes())));
   return success();
 }
 
@@ -1180,7 +1179,7 @@
   auto new_call = builder.create<CallOpType>(
       call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
       new_operands, call_op.getAttrs());
-  new_call.setAttr(
+  new_call->setAttr(
       "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
   AddLoadsStoresOutsideControlFlowOp(
       new_call, lifting_info.arg_data_type_and_updated_output_index);
@@ -1376,7 +1375,7 @@
   llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
       lifted_partitioned_call_callees;
   if (failed(HoistForControlFlow(
-          &function.front(), cast<ModuleOp>(function.getParentOp()),
+          &function.front(), cast<ModuleOp>(function->getParentOp()),
           /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
     return failure();
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
index a3754d0..6a149d3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
@@ -117,7 +117,7 @@
 // multiple uses or unknown uses (for external functions). The cloned function
 // will be marked as private.
 FuncOp CloneFunctionIfNeeded(FuncOp func) {
-  ModuleOp module = func.getParentOfType<ModuleOp>();
+  ModuleOp module = func->getParentOfType<ModuleOp>();
   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
   if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue()))
     return func;
@@ -184,9 +184,9 @@
   // Patch up function types (with less number of return values and potentially
   // less number of arguments)
   for (FuncOp func : cloned_branches) {
-    func.setType(FunctionType::get(
-        func.front().getArgumentTypes(),
-        func.front().getTerminator()->getOperandTypes(), func.getContext()));
+    func.setType(
+        FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
+                          func.front().getTerminator()->getOperandTypes()));
   }
 
   EliminateUnusedResults(op);
@@ -232,9 +232,9 @@
 
   // Patch up branch function types.
   for (FuncOp func : {cloned_cond, cloned_body}) {
-    func.setType(FunctionType::get(
-        func.front().getArgumentTypes(),
-        func.front().getTerminator()->getOperandTypes(), func.getContext()));
+    func.setType(
+        FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
+                          func.front().getTerminator()->getOperandTypes()));
   }
   EliminateUnusedResults(op, &can_eliminate);
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
index 5ea341a..c2b8a07 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
@@ -15,7 +15,7 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -94,8 +94,8 @@
 
     auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>(
         send_op, dedup_op, &builder);
-    new_send_op.setAttr(new_send_op.getOperandSegmentSizeAttr(),
-                        operand_size_attr);
+    new_send_op->setAttr(new_send_op.getOperandSegmentSizeAttr(),
+                         operand_size_attr);
   }
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 96c2927..f198ff3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -34,11 +34,11 @@
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
@@ -247,7 +247,7 @@
       continue;
     }
     if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
-      Operation* parent = yield.getParentOp();
+      Operation* parent = yield->getParentOp();
       if (!CanInferTensorListElementType(
               parent->getResult(use.getOperandNumber()), initial_element_shape,
               potential_element_type))
@@ -619,7 +619,7 @@
 ArrayRef<FuncOp> ShapeInference::GetCallers(FuncOp fn) {
   auto pair = callers_of_func_.try_emplace(fn);
   if (pair.second) {
-    ModuleOp module = fn.getParentOfType<ModuleOp>();
+    ModuleOp module = fn->getParentOfType<ModuleOp>();
     auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
     if (uses) {
       pair.first->second.reserve(pair.first->second.size());
@@ -1150,8 +1150,8 @@
     }
 
     FunctionType func_type = func.getType();
-    func.setType(FunctionType::get(input_types, func_type.getResults(),
-                                   func.getContext()));
+    func.setType(FunctionType::get(func.getContext(), input_types,
+                                   func_type.getResults()));
 
     auto res =
         PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
@@ -1493,8 +1493,8 @@
   }
 
   DCOMMENT("Updating function type");
-  func.setType(FunctionType::get(
-      func.getArgumentTypes(), return_op.getOperandTypes(), func.getContext()));
+  func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
+                                 return_op.getOperandTypes()));
 
   if (changed) EnqueueCallers(func);
 }
@@ -1611,8 +1611,8 @@
     return failure();
 
   context.InferShapeForFunctionReturnType(func);
-  func.setType(FunctionType::get(new_arg_types, func.getType().getResults(),
-                                 func.getContext()));
+  func.setType(FunctionType::get(func.getContext(), new_arg_types,
+                                 func.getType().getResults()));
 
   return success();
 }
@@ -1627,8 +1627,10 @@
     return success();
   }
   int64_t producer = producer_or.ValueOrDie();
+  // TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
+  // longer needed.
   ShapeInference context(producer, module.getContext(),
-                         /*propagate_caller_callee_constants=*/true);
+                         /*propagate_caller_callee_constants=*/false);
   if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
     context.enqueue(main);
   for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
index 36f62a7..9d77164 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -36,9 +36,6 @@
   }
 };
 
-PassRegistration<ShapeInference> pass(
-    "tf-shape-inference", "Simple Shape Inference on TensorFlow Dialect");
-
 }  // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
index e62df78..9065003 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
@@ -25,21 +25,21 @@
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 
 #define DEBUG_TYPE "tf-executor-sink-constant"
 
 namespace mlir {
-namespace tf_executor {
+namespace TFDevice {
 
 namespace {
 using ::mlir::TF::ConstOp;
 
-class ExecutorConstantSinking
-    : public mlir::PassWrapper<ExecutorConstantSinking, FunctionPass> {
+class ClusterConstantSinkingPass
+    : public TF::ClusterConstantSinkingPassBase<ClusterConstantSinkingPass> {
   void runOnFunction() override {
     getFunction().walk([](tf_device::ClusterOp cluster) {
       LLVM_DEBUG(llvm::dbgs() << "Visit " << *cluster.getOperation() << "\n");
@@ -82,16 +82,11 @@
   }
 };
 
-static mlir::PassRegistration<ExecutorConstantSinking> pass(
-    "tf-device-constant-sinking",
-    "Sink constants implicitly captured in a tf_device.cluster region. This "
-    "reduces the number of arguments when outlining later.");
-
 }  // anonymous namespace
 
-std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass() {
-  return std::make_unique<ExecutorConstantSinking>();
+std::unique_ptr<OperationPass<FuncOp>> CreateClusterConstantSinkingPass() {
+  return std::make_unique<ClusterConstantSinkingPass>();
 }
 
-}  // namespace tf_executor
+}  // namespace TFDevice
 }  // namespace mlir
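With the static `PassRegistration` object removed, command-line registration for this pass now comes from its `tf_passes.td` entry (see the `ClusterConstantSinkingPass` def later in this change), while programmatic use still goes through the factory. A minimal usage sketch, assuming the factory is declared in `passes.h` as the other TF transform factories are:

```cpp
#include "mlir/IR/BuiltinOps.h"  // mlir::FuncOp
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Schedules the renamed pass on each function nested under the module.
void AddConstantSinking(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateClusterConstantSinkingPass());
}
```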
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
index a36df74..33f9d34 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
@@ -27,9 +27,9 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
@@ -137,9 +137,9 @@
   if (handle_new_size_vars) {
     handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
   }
-  func.setType(FunctionType::get(
-      new_input_types, func.front().getTerminator()->getOperandTypes(),
-      func.getContext()));
+  func.setType(
+      FunctionType::get(func.getContext(), new_input_types,
+                        func.front().getTerminator()->getOperandTypes()));
 }
 
 // Contains cached information for decomposed callee functions for (stateful)
@@ -307,7 +307,7 @@
     auto new_call = builder.create<CallOp>(
         call.getLoc(), info.decomposed_callee.getType().getResults(),
         new_operands, call.getAttrs());
-    new_call.setAttr(
+    new_call->setAttr(
         "f", builder.getSymbolRefAttr(
                  const_cast<FuncOp&>(info.decomposed_callee).getName()));
     for (int64_t i = 0; i < call.getNumResults(); ++i) {
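The `new_call.setAttr` to `new_call->setAttr` changes here (and the similar `getAttrOfType`/`getParentOfType` changes in the files below) reflect these helpers being called on the underlying `mlir::Operation` through `operator->` rather than on the op wrapper itself. A small hedged sketch of the pattern; the attribute name and helper are placeholders, not part of this patch:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Reads and rewrites a string attribute through the Operation* accessors,
// mirroring the op->getAttrOfType / op->setAttr calls in this patch.
void RetargetDevice(mlir::Operation* op, mlir::OpBuilder& builder,
                    llvm::StringRef new_device) {
  mlir::StringAttr old_device = op->getAttrOfType<mlir::StringAttr>("device");
  if (!old_device || old_device.getValue() != new_device)
    op->setAttr("device", builder.getStringAttr(new_device));
}
```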
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
index 27096e7..ab8c86d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
@@ -27,9 +27,9 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
@@ -460,10 +460,9 @@
 void UpdateFuncType(FuncOp func) {
   llvm::SmallVector<Type, 8> arg_types;
   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
-  func.setType(FunctionType::get(
-      arg_types,
-      llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
-      func.getContext()));
+  func.setType(
+      FunctionType::get(func.getContext(), arg_types,
+                        func.front().getTerminator()->getOperandTypes()));
 }
 
 // Finds the accessed gradient sources for each tensor array argument.
@@ -752,7 +751,7 @@
     auto new_call = builder.create<CallOp>(
         call.getLoc(), info.decomposed_callee.getType().getResults(),
         new_operands, call.getAttrs());
-    new_call.setAttr(
+    new_call->setAttr(
         "f", builder.getSymbolRefAttr(
                  const_cast<FuncOp&>(info.decomposed_callee).getName()));
     for (const auto& entry : info.ret_forward_input) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
index 437fb27..20ac207 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
@@ -20,8 +20,8 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -60,7 +60,7 @@
         arg_device = attr;
       }
 
-      StringAttr op_device = op.getAttrOfType<StringAttr>(kDeviceAttr);
+      StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
       if (!op_device) op_device = empty_string;
       // Skip the folding logic if the argument's device is different from the
       // operation's device.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
index ec6c899..d2a4659 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
@@ -23,7 +23,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -71,9 +71,9 @@
 void UpdateFuncType(FuncOp func) {
   llvm::SmallVector<Type, 8> arg_types;
   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
-  func.setType(FunctionType::get(
-      arg_types, func.front().getTerminator()->getOperandTypes(),
-      func.getContext()));
+  func.setType(
+      FunctionType::get(func.getContext(), arg_types,
+                        func.front().getTerminator()->getOperandTypes()));
 }
 
 // Holds the size value of a tensor list and whether the size is statically
@@ -457,7 +457,7 @@
     auto new_call = builder.create<CallOp>(
         call.getLoc(), info.decomposed_callee.getType().getResults(),
         new_operands, call.getAttrs());
-    new_call.setAttr(
+    new_call->setAttr(
         "f", builder.getSymbolRefAttr(
                  const_cast<FuncOp&>(info.decomposed_callee).getName()));
     for (const auto& entry : info.buffer_ret_to_size_ret) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc
index f2321df..10d0bcc 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc
@@ -15,7 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"
 
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
index 4a8076d..e09a7a0 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
@@ -21,10 +21,324 @@
   let summary = "Simple Shape Inference on TensorFlow Dialect";
   // TODO(jpienaar): Write `description`.
 
-  let constructor = "CreateTFShapeInferencePass()";
+  let constructor = "TF::CreateTFShapeInferencePass()";
 
   let options = [
     Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
            "Maximum shape inference iterations">
   ];
 }
+
+def ExecutorGraphPruningPass : FunctionPass<"tf-executor-graph-pruning"> {
+  let summary = "Prunes unreachable ops in a tf_executor.graph";
+
+  let description = [{
+This pass removes ops from a `tf_executor.graph` that are not transitively
+connected, via data or control dependencies, to the associated
+`tf_executor.fetch` op. The order of the remaining ops is preserved. Functions
+named `main` with no
+`tf.entry_function` attribute will not be pruned, as such graphs/functions may
+have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are
+not provided at certain stages of IR transformation (e.g. pre-placement).
+
+For example, the following:
+
+```mlir
+func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %graph = tf_executor.graph {
+    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
+    %unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
+    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
+    %unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
+    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
+  }
+  return %graph : tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %graph = tf_executor.graph {
+    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
+    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
+    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
+    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
+  }
+  return %graph : tensor<i32>
+}
+```
+  }];
+
+  let constructor = "tf_executor::CreateTFExecutorGraphPruningPass()";
+}
+
+def ExecutorDialectToFunctionalPass : FunctionPass<"tf-executor-to-functional-conversion"> {
+  let summary = "Lifts tf_executor.island inner ops from a tf_executor.graph";
+
+  let description = [{
+This pass converts tf_executor.graphs consisting only of tf_executor.islands and
+a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by
+lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control
+flow ops are present in a tf_executor.graph, an error will be returned.
+
+For example, the following:
+
+```mlir
+func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %graph_results:2 = tf_executor.graph {
+    %island_0_result, %island_0_control = tf_executor.island {
+      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %identity : tensor<i32>
+    }
+    %island_1_result, %island_1_control = tf_executor.island {
+      %identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
+      tf_executor.yield %identity_n#0 : tensor<i32>
+    }
+    tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
+  }
+  return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+  %identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
+  return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
+}
+```
+  }];
+
+  let constructor = "CreateExecutorDialectToFunctionalConversionPass()";
+}
+
+def TPUClusterFormationPass : Pass<"tf-tpu-cluster-formation", "ModuleOp"> {
+  let summary = "Forms clusters from operations assigned to the same TPU computation";
+
+  let description = [{
+TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata`
+op, a subgraph of ops (TensorFlow Dialect) each with a matching `_tpu_replicate`
+attribute relative to the associated `tf.TPUReplicateMetadata` op, and
+optionally `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops that feed
+inputs into and outputs out of a replicated TPU computation. The number of
+times a TPU computation is replicated is defined by the `num_replicas`
+attribute of the `tf.TPUReplicateMetadata` op, and the operand count of
+`tf.TPUReplicatedInput` and the result count of `tf.TPUReplicatedOutput` must
+match it, excluding packed tensors. It is also assumed that no op outside of
+the TPU computation is both an input to and an output of that same
+computation.
+
+This pass takes the ops forming the TPU computation subgraph, moves them into
+a `tf_device.cluster`, and copies attributes from the associated
+`tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the
+computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is
+not copied over; instead, the `tf_device.cluster` is further wrapped in a
+`tf_device.replicate`, and the associated `tf.TPUReplicatedInput` and
+`tf.TPUReplicatedOutput` ops become the `tf_device.replicate` operands and
+results. Otherwise, the single operand and result of the associated
+`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded
+to the `tf_device.cluster`.
+
+For example, the following non-replicated computation:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
+  // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
+  // with topology `<topology>`.
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
+  %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
+  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
+  %replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
+  return %replicated_output : tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
+  %cluster = "tf_device.cluster"() ( {
+    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %identity : tensor<i32>
+  }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
+  return %cluster : tensor<i32>
+}
+```
+
+The following replicated computation:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
+  %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
+  %replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
+  return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
+    %cluster = "tf_device.cluster"() ( {
+      %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
+      tf_device.return %identity : tensor<i32>
+    }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
+    tf_device.return %cluster : tensor<i32>
+  }
+  return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
+}
+```
+  }];
+
+  let constructor = "TFTPU::CreateTPUClusterFormationPass()";
+}
+
+def ClusterConstantSinkingPass : FunctionPass<"tf-device-constant-sinking"> {
+  let summary = "Sinks constants implicitly captured in a tf_device.cluster region.";
+
+  let description = [{
+This pass sinks constants (`tf.Const` ops) that are implicitly captured by a
+`tf_device.cluster` region into that region. Performing this prior to outlining
+reduces the number of arguments of the outlined function.
+
+For example, the following:
+
+```mlir
+func @cluster() -> tensor<i32> {
+  %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %cluster = "tf_device.cluster"() ( {
+    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %identity : tensor<i32>
+  }) : () -> (tensor<i32>)
+  return %cluster : tensor<i32>
+}
+```
+
+will be transformed into:
+
+```mlir
+func @cluster() -> tensor<i32> {
+  %cluster = "tf_device.cluster"() ( {
+    %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %identity : tensor<i32>
+  }) : () -> (tensor<i32>)
+  return %cluster : tensor<i32>
+}
+```
+  }];
+
+  let constructor = "TFDevice::CreateClusterConstantSinkingPass()";
+}
+
+def TPUExtractOutsideCompilationPass : Pass<"tf-tpu-extract-outside-compilation", "ModuleOp"> {
+  let summary = "Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.";
+
+  let description = [{
+This pass extracts CPU computation clusters, annotated with the
+`_xla_outside_compilation` attribute to denote ops that should run on the
+CPU/host, out of a TPU cluster. Each outside compilation cluster is moved to a
+tf_device.parallel_execute region, and the TPU cluster is moved to a
+tf_device.parallel_execute region as well. Communication ops between device and
+host are added to pass inputs/outputs to/from the outside compiled region.
+
+For example, the following tf_device.cluster with an op marked for `_xla_outside_compilation`:
+
+```mlir
+func @outside_compilation() -> tensor<f32> {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
+    %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
+    %3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
+    tf_device.return %3 : tensor<f32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+```
+
+will become a tf_device.parallel_execute op with a CPU/host region and
+a tf_device.cluster with communication ops to send data to/from device/host:
+
+```mlir
+func @outside_compilation() -> tensor<f32> {
+  %0 = "tf_device.parallel_execute"() ( {
+    "tf_device.launch"() ( {
+      %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
+      %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
+      %3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+      "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
+      tf_device.return
+    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+    tf_device.return
+  },  {
+    %1 = "tf_device.cluster"() ( {
+      %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+      %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
+      %4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      tf_device.return %4 : tensor<f32>
+    }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
+    tf_device.return %1 : tensor<f32>
+  }) : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+```
+  }];
+
+  let constructor = "TFTPU::CreateTPUExtractOutsideCompilationPass()";
+}
+
+def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation", "ModuleOp"> {
+  let summary = "Marks ops in device cluster for outside compilation if they are unsupported on device.";
+
+  let description = [{
+This pass marks unsupported ops in a device cluster with the
+`_xla_outside_compilation` attribute so that those operations run on the host
+instead of the device. Unsupported ops are ops that cannot be code generated to
+run on the device for the cluster, including:
+
+1. String operations on TPUs.
+2. Operations that don't have a kernel defined for the device.
+
+This pass is conservative in that it marks all ops for outside compilation that
+cannot be compiled for the device. Exceptions are made for ops that will be
+rewritten or decomposed before compilation on device.
+
+For example, a tf_device.cluster op with an unsupported op, tf.UnsupportedOp:
+
+```mlir
+func @unsupported_op() -> tensor<i32> {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.UnsupportedOp"() : () -> tensor<i32>
+    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %2 : tensor<i32>
+  }) {allow_soft_placement = true, num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+```
+
+will mark tf.UnsupportedOp with the `_xla_outside_compilation` attribute:
+
+```mlir
+func @unsupported_op() -> tensor<i32> {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
+    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %2 : tensor<i32>
+  }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+```
+  }];
+
+  let constructor = "TFDevice::CreateMarkOpsForOutsideCompilationPass()";
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 46bc094..45c5ab0 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -13,15 +13,6 @@
 limitations under the License.
 ==============================================================================*/
 
-// This transformation pass takes ops with the same `_tpu_replicate` attribute
-// in a block and clusters them together under a `tf_device.cluster`.
-// Associated TPUReplicateMetadata ops are removed and its attributes are copied
-// over to the associated `tf_device.cluster`. If a cluster should be
-// replicated, the associated `tf_device::LaunchOp` will be wrapped further with
-// a `tf_device.replicate`. This pass also assumes ops of the same cluster do
-// not have ops outside of the cluster that are both operands and results of the
-// cluster. Note, this currently does not handle side effecting ops yet.
-
 #include <algorithm>
 #include <iterator>
 #include <memory>
@@ -51,6 +42,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 
 namespace mlir {
 namespace TFTPU {
@@ -77,16 +69,13 @@
 // Mapping for `_tpu_replicate` attribute to ops of a cluster.
 using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;
 
-struct TPUClusterFormation
-    : public TF::PerFunctionAggregateAnalysisConsumerPass<
-          TPUClusterFormation, TF::ResourceAliasAnalysis> {
+struct TPUClusterFormationPass
+    : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
   void getDependentDialects(DialectRegistry& registry) const override {
     registry.insert<tf_device::TensorFlowDeviceDialect>();
   }
 
-  void runOnFunction(
-      FuncOp func,
-      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
+  void runOnOperation() override;
 };
 
 // Creates a mapping from the TPUReplicateMetadata ops `_tpu_replicate`
@@ -210,8 +199,8 @@
 // cluster may be interleaved with other ops in the cluster. Resource id's are
 // also captured, to keep track of resource usage before, in, or after the
 // cluster.
-// TODO(lyandy): Extend this to handle all side effecting ops while handling
-// transitive data dependencies.
+// TODO(b/175701589): Extend this to handle all side effecting ops while
+// handling transitive data dependencies.
 llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
     Block* block, const ClusterOps& cluster_ops,
     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
@@ -419,12 +408,12 @@
       llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
       replicated_inputs, packed_inputs, cluster.getResultTypes());
   if (has_replicated_input_index)
-    replicate_op.setAttr(kReplicatedInputIndicesAttr,
-                         builder.getI64ArrayAttr(replicated_input_indices));
+    replicate_op->setAttr(kReplicatedInputIndicesAttr,
+                          builder.getI64ArrayAttr(replicated_input_indices));
 
   if (!mirrored_variable_indices.empty())
-    replicate_op.setAttr(kMirroredVariableIndicesAttr,
-                         builder.getI64ArrayAttr(mirrored_variable_indices));
+    replicate_op->setAttr(kMirroredVariableIndicesAttr,
+                          builder.getI64ArrayAttr(mirrored_variable_indices));
 
   // Replace replicated cluster results with replicate op results.
   for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
@@ -550,7 +539,7 @@
       return failure();
 
     // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
-    cluster.setAttrs(cluster_metadata->second);
+    cluster->setAttrs(cluster_metadata->second);
     // Exclude `num_replicas` as cluster should be replicated if necessary.
     cluster.removeAttr(kNumReplicasAttr);
   }
@@ -558,16 +547,14 @@
   return success();
 }
 
-void TPUClusterFormation::runOnFunction(
+LogicalResult FormClustersInFunction(
     FuncOp func,
     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
-  if (!llvm::hasSingleElement(func)) {
-    func.emitOpError("Expecting a single block function");
-    return signalPassFailure();
-  }
+  if (!llvm::hasSingleElement(func))
+    return func.emitOpError("Expecting a single block function");
 
   if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
-    return signalPassFailure();
+    return failure();
 
   // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
   auto remove_result = func.walk([&](Operation* op) {
@@ -593,17 +580,22 @@
     return WalkResult::advance();
   });
 
-  if (remove_result.wasInterrupted()) return signalPassFailure();
+  return failure(remove_result.wasInterrupted());
+}
+
+void TPUClusterFormationPass::runOnOperation() {
+  auto& resource_alias_analysis = getAnalysis<TF::ResourceAliasAnalysis>();
+  for (auto func : getOperation().getOps<FuncOp>())
+    if (!func.isExternal() &&
+        failed(FormClustersInFunction(
+            func, resource_alias_analysis.GetAnalysisForFunc(func))))
+      return signalPassFailure();
 }
 }  // anonymous namespace
 
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
-  return std::make_unique<TPUClusterFormation>();
+  return std::make_unique<TPUClusterFormationPass>();
 }
 
-static PassRegistration<TPUClusterFormation> pass(
-    "tf-tpu-cluster-formation",
-    "Form clusters from operations assigned to the same TPU cluster");
-
 }  // namespace TFTPU
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc
index 1ec9eba..24d5186 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc
@@ -116,7 +116,7 @@
 
   OpBuilder builder(&getContext());
   for (auto execute_launch : execute_launches) {
-    auto replicate = execute_launch.getParentOfType<tf_device::ReplicateOp>();
+    auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
     if (!replicate) continue;
 
     ColocateCompositeResourceOpsInReplicate(replicate, &builder);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_compile_op_replication_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_compile_op_replication_pass.cc
index e69cb28..f808d30 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_compile_op_replication_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_compile_op_replication_pass.cc
@@ -76,8 +76,8 @@
               builder.create<TF::TPUCompileSucceededAssertOp>(
                   new_compile_op->getLoc(),
                   new_compile_op->getResult(kStatusResultIndex));
-          new_assert_op.setAttr(kDeviceAttr,
-                                new_compile_op->getAttr(kDeviceAttr));
+          new_assert_op->setAttr(kDeviceAttr,
+                                 new_compile_op->getAttr(kDeviceAttr));
         }
         // Updates the operand to use the result of the newly created
         // tf._TPUCompileMlir op.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc
index b87faca..bcca0c6 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc
@@ -198,7 +198,7 @@
       if (auto sink =
               llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
         auto source = sink.GetSource();
-        source.setAttr(kDeviceAttr, new_device_attr);
+        source->setAttr(kDeviceAttr, new_device_attr);
         PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
                                    value_to_device);
         updated_next_iteration = true;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc
index d483a85..39c867c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc
@@ -21,11 +21,11 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -109,7 +109,7 @@
   };
 
   // Check all generator aliases (ops or function argument) are on CPU.
-  FuncOp func = iterator_op.getParentOfType<FuncOp>();
+  FuncOp func = iterator_op->getParentOfType<FuncOp>();
   return llvm::all_of(aliases, [&](Value alias) {
     // Ignore non-generator aliases.
     if (!is_generator(alias)) return true;
@@ -172,7 +172,7 @@
   builder.setInsertionPoint(execute_launch);
   auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
                                               get_layout, input, &builder);
-  copy_with_layout.setAttr(kDeviceAttr, execute_launch.deviceAttr());
+  copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
   execute.setOperand(execute_arg_index, copy_with_layout);
 }
 
@@ -206,8 +206,8 @@
                            .getValue()
                            .get(execute_launch.getDevice())
                            .cast<ArrayAttr>();
-    copy_with_layout.setAttr(kDeviceAttr,
-                             device_list.getValue()[entry.index()]);
+    copy_with_layout->setAttr(kDeviceAttr,
+                              device_list.getValue()[entry.index()]);
 
     replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
                          copy_with_layout);
@@ -230,7 +230,7 @@
 
   bool metadata_updated = false;
   auto maybe_replicate =
-      execute_launches.front().getParentOfType<tf_device::ReplicateOp>();
+      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();
 
   for (auto execute_and_input_mapping :
        llvm::zip(execute_launches, input_mappings)) {
@@ -274,8 +274,8 @@
   }
 
   if (metadata_updated)
-    compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
-                                                compile.getContext()));
+    compile->setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
+                                                 compile.getContext()));
 }
 
 void TPUDynamicLayoutPass::runOnFunction(
@@ -284,7 +284,7 @@
   func.walk([&](TF::_TPUCompileMlirOp compile) {
     // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
     auto compile_launch =
-        llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
+        llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
     if (!compile_launch || !compile_launch.WrapsSingleOp()) return;
 
     llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
@@ -295,7 +295,7 @@
       auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
       if (!execute) return;
       auto execute_launch =
-          llvm::dyn_cast<tf_device::LaunchOp>(execute.getParentOp());
+          llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
       if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
       execute_launches.push_back(execute_launch);
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc
index 9a0f6af..4ae900f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc
@@ -94,7 +94,7 @@
         .str();
   };
 
-  Attribute padding_map_attr = cluster_func.getAttr(kPaddingMapAttr);
+  Attribute padding_map_attr = cluster_func->getAttr(kPaddingMapAttr);
   if (!padding_map_attr) return success();
 
   auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();
@@ -180,7 +180,7 @@
 
 LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
                                         SymbolTable* symbol_table) {
-  auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
+  auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
   // LaunchFunc is not replicated, there will be no padding.
   if (!replicate) return success();
 
@@ -188,7 +188,7 @@
   if (!func) return success();
 
   auto replicated_input_indices_attr =
-      replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
+      replicate->getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
   if (!replicated_input_indices_attr) return success();
 
   llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
index 98b5f39..af21ee4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
@@ -131,7 +131,7 @@
     const TF::SideEffectAnalysis& side_effect_analysis,
     tf_device::ClusterOp cluster) {
   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
-      cluster.getParentOfType<FuncOp>());
+      cluster->getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.body();
   llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
 
@@ -227,7 +227,7 @@
     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
     llvm::SmallVectorImpl<Value>* cluster_results) {
   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
-      cluster.getParentOfType<FuncOp>());
+      cluster->getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.body();
   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
   Operation* terminator = cluster.GetBody().getTerminator();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 0d69120..2a4fb73 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -25,10 +25,10 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeRange.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -38,6 +38,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
 
@@ -53,39 +54,9 @@
 using OutsideClusterMap =
     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
 
-// This pass extracts a CPU computation cluster with `_xla_outside_compilation`
-// annotation from a TPU cluster. Each outside compilation cluster is moved to
-// a parallel_execute region. The TPU cluster is also moved to a
-// parallel_execute region. Communication ops between device and host are
-// added to pass inputs/outputs to/from the outside compiled region.
-//
-// A simple example:
-//   "tf_device.cluster"() ( {
-//     "tf.A"()
-//     "tf.B"() {_xla_outside_compilation = "cluster1"}
-//     "tf.C"()
-//     tf_device.return
-//   }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []}
-//
-// Would become the following ops (unimportant attribute, type are omitted):
-//   "tf_device.parallel_execute"() ( {
-//     "tf_device.launch"() ( {
-//       "tf.B()
-//       tf_device.return
-//     })
-//     tf_device.return
-//   }, {
-//     "tf_device.cluster"( {
-//       "tf.A"()
-//       "tf.C"()
-//       tf_device.return
-//     })
-//    tf_device.return
-//  })
-
 struct TPUExtractOutsideCompilation
-    : public PassWrapper<TPUExtractOutsideCompilation,
-                         OperationPass<ModuleOp>> {
+    : public TF::TPUExtractOutsideCompilationPassBase<
+          TPUExtractOutsideCompilation> {
   void runOnOperation() override;
 };
 
@@ -559,32 +530,46 @@
   return external_outputs;
 }
 
-// Sets the insertion point on `builder` for HostCompute op.  Sets insertion
-// point to the first op in `cluster_ops` that has one of `external_inputs`
-// as an operand.  If there are no external_inputs, set insertion point to first
-// cluster_op.
-void SetHostComputeInsertion(
-    OpBuilder& builder, llvm::ArrayRef<Operation*> cluster_ops,
-    const llvm::SmallSetVector<Value, 4>& external_inputs) {
-  if (external_inputs.empty()) builder.setInsertionPoint(cluster_ops.front());
-  for (const auto& cluster_op : cluster_ops) {
-    for (Value v : cluster_op->getOperands()) {
-      if (external_inputs.count(v)) {
-        builder.setInsertionPoint(cluster_op);
-        return;
-      }
-    }
+// Moves all ops that are in between the cluster ops and depend on any op in
+// the cluster to after the last op in the cluster. This also includes ops that
+// indirectly depend on cluster results, so that the resulting IR stays legal.
+void MoveDependentOpsAfter(llvm::ArrayRef<Operation*> cluster_ops) {
+  llvm::SmallPtrSet<Operation*, 8> outside_ops(cluster_ops.begin(),
+                                               cluster_ops.end());
+
+  // Collect all ops between the first and last op in the cluster that may
+  // need to be moved after the cluster.
+  llvm::SmallVector<Operation*, 8> ops;
+  Operation* first_op = *cluster_ops.begin();
+  Operation* last_op = *cluster_ops.rbegin();
+  for (Operation& op :
+       llvm::make_range(first_op->getIterator(), last_op->getIterator())) {
+    if (!outside_ops.contains(&op)) ops.push_back(&op);
   }
 
-  // If no operand usage can be found, this means that external input is
-  // implicitly captured inputs for ops inside internal regions of one of the
-  // `cluster_ops`. In that case, set the insertion point to the last op of the
-  // `cluster_ops` in the IR.
-  builder.setInsertionPoint(cluster_ops.back());
+  Operation* move_position = last_op;
+  for (Operation* op : ops) {
+    bool is_dependent = false;
+    for (Value operand : op->getOperands()) {
+      if (outside_ops.contains(operand.getDefiningOp())) {
+        is_dependent = true;
+        break;
+      }
+    }
+    // Op doesn't depend on any of the cluster ops' results.
+    if (!is_dependent) continue;
+
+    // Note that results of this op are never used as operands by any of the
+    // ops in this cluster. That would create a circular dependency between
+    // host and device, which is avoided by the cluster assignment pass.
+    op->moveAfter(move_position);
+    move_position = op;
+    outside_ops.insert(op);
+  }
 }
 
-// Creates the HostCompute with `inputs` and `outputs`
-// using `communication_key`.
+// Creates the HostCompute with `inputs` and `outputs` using
+// `communication_key`.
 TF::_XlaHostComputeMlirOp CreateHostCompute(
     OpBuilder& builder, tf_device::ClusterOp tpu_cluster,
     llvm::ArrayRef<Operation*> cluster_ops,
@@ -594,7 +579,10 @@
   llvm::SmallVector<Type, 4> device_output_types;
   for (const auto& output : outputs)
     device_output_types.push_back(output.getType());
-  SetHostComputeInsertion(builder, cluster_ops, inputs);
+
+  MoveDependentOpsAfter(cluster_ops);
+  builder.setInsertionPointAfter(*cluster_ops.rbegin());
+
   auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
       tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef(),
       builder.getStringAttr(args_communication_key),
@@ -738,7 +726,7 @@
     // If there is no replication/data parallelism, it is assumed the device
     // ordinal is always 0 (e.g. /device:TPU:0). In that case, a constant 0
     // attribute can be used instead for _XlaSendFromHost/_XlaRecvAtHost ops.
-    if (tpu_cluster.getParentOfType<tf_device::ReplicateOp>()) {
+    if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
       auto device_ordinal_op =
           builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
               host_launch_op.getLoc(),
@@ -876,9 +864,5 @@
   return std::make_unique<TPUExtractOutsideCompilation>();
 }
 
-static PassRegistration<TPUExtractOutsideCompilation> pass(
-    "tf-tpu-extract-outside-compilation",
-    "Extracts TPU outside compilation to separate parallel_execute.");
-
 }  // namespace TFTPU
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
index 1adc443..9fc0a53 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
@@ -127,7 +127,7 @@
   VariableAccessesForTPUExecute infos;
   Attribute device_attr = execute_launch.deviceAttr();
   if (check_device && !device_attr) return infos;
-  auto func = execute_launch.getParentOfType<mlir::FuncOp>();
+  auto func = execute_launch->getParentOfType<mlir::FuncOp>();
 
   // Track the first read op found, which is used later to check if there are
   // assign ops between it and the TPUExecute op. We will exclude reads before
@@ -137,7 +137,7 @@
   Operation* first_read = nullptr;
   Operation& execute = execute_launch.GetBody().front();
   auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
-      execute_launch.getParentOp());
+      execute_launch->getParentOp());
   Operation* execute_parent =
       parallel_execute ? parallel_execute.getOperation() : execute_launch;
   // Find inputs that are variable reads.
@@ -148,7 +148,7 @@
         operand.value().get().getDefiningOp());
     if (!read_op) continue;
     if (check_same_region &&
-        read_op.getParentRegion() != execute_parent->getParentRegion())
+        read_op->getParentRegion() != execute_parent->getParentRegion())
       continue;
 
     auto resource = read_op.resource();
@@ -240,7 +240,7 @@
   auto execute_outputs =
       parallel_execute
           ? parallel_execute.GetRegionOutputs(
-                execute_launch.getParentRegion()->getRegionNumber())
+                execute_launch->getParentRegion()->getRegionNumber())
           : execute_launch.getResults();
   for (auto execute_output : llvm::enumerate(execute_outputs)) {
     // TODO(lyandy): Handle updates to resource writes by remapping to parent
@@ -340,7 +340,7 @@
   llvm::SmallVector<Type, 8> output_types;
   const int parallel_execute_num_results = parallel_execute_op->getNumResults();
   output_types.reserve(parallel_execute_num_results);
-  Region* execute_region = merged_execute_launch.getParentRegion();
+  Region* execute_region = merged_execute_launch->getParentRegion();
   const int region_index = execute_region->getRegionNumber();
   const int num_results_before_region =
       AppendTypes(&output_types, parallel_execute, 0, region_index);
@@ -547,7 +547,7 @@
       merged_execute_launch.GetBody().getTerminator());
 
   if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
-          execute_launch.getParentOp()))
+          execute_launch->getParentOp()))
     ReplaceParallelExecute(parallel_execute, execute_launch,
                            merged_execute_launch, infos, builder);
   else
@@ -591,11 +591,11 @@
   for (auto execute_launch : execute_launches) {
     OpBuilder builder(&getContext());
     const bool parent_is_replicate =
-        llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp()) ||
+        llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
         (llvm::isa<tf_device::ParallelExecuteOp>(
-             execute_launch.getParentOp()) &&
+             execute_launch->getParentOp()) &&
          llvm::isa<tf_device::ReplicateOp>(
-             execute_launch.getParentOp()->getParentOp()));
+             execute_launch->getParentOp()->getParentOp()));
 
     // If this is inside a tf_device::ReplicateOp, the variables are guaranteed
     // to be on the same device as the TPUExecute op. Skip device checking in
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc
index 7f0acbd..558b877 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc
@@ -124,49 +124,44 @@
   }
 
  private:
+  // TODO(hinsu): Consider using GraphCycles data structure available in xla
+  // directory to avoid potentially full traversal for each new op and cluster
+  // pair.
   // Checks if it is safe for `op` to be merged into this cluster.
   bool IsSafeToAdd(Operation* op,
                    const TF::SideEffectAnalysis::Info& side_effect_analysis) {
-    // If the op is not marked for outside compilation it doesn't belong in a
-    // cluster.
-    if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
-      return false;
-
     if (host_cluster_ops_.empty()) return true;
 
     // If there is an intermediate data or side effect dependency between the op
     // and ops in the cluster, it's not safe to add.
-    llvm::SmallSetVector<Operation*, 4> op_stack;
-    for (auto* user : op->getUsers()) {
-      if (!host_cluster_ops_.contains(user)) op_stack.insert(user);
-    }
-    for (auto* successor : side_effect_analysis.DirectControlSuccessors(op)) {
-      if (!host_cluster_ops_.contains(successor)) op_stack.insert(successor);
-    }
-    bool safe_to_add = true;
-    while (!op_stack.empty()) {
-      auto* next_op = op_stack.pop_back_val();
-      for (auto* user : next_op->getUsers()) {
-        if (host_cluster_ops_.contains(user)) {
-          safe_to_add = false;
-          break;
-        } else {
-          op_stack.insert(user);
-        }
-      }
-      for (auto* successor :
-           side_effect_analysis.DirectControlSuccessors(next_op)) {
-        if (host_cluster_ops_.contains(successor)) {
-          safe_to_add = false;
-          break;
-        } else {
-          op_stack.insert(successor);
-        }
-      }
-      if (!safe_to_add) break;
+    std::vector<Operation*> dependencies;
+
+    // Materialize data dependencies, as llvm::concat doesn't support
+    // non-materialized iteration.
+    auto data_deps = llvm::to_vector<4>(op->getUsers());
+    llvm::SmallVector<Operation*, 4> control_deps =
+        side_effect_analysis.DirectControlSuccessors(op);
+    for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) {
+      if (!host_cluster_ops_.contains(dep)) dependencies.push_back(dep);
     }
 
-    return safe_to_add;
+    llvm::SmallPtrSet<Operation*, 4> visited;
+    while (!dependencies.empty()) {
+      Operation* next_op = dependencies.back();
+      dependencies.pop_back();
+      if (visited.count(next_op)) continue;
+      visited.insert(next_op);
+
+      auto data_deps = llvm::to_vector<4>(next_op->getUsers());
+      llvm::SmallVector<Operation*, 4> control_deps =
+          side_effect_analysis.DirectControlSuccessors(next_op);
+      for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) {
+        if (host_cluster_ops_.contains(dep)) return false;
+        dependencies.push_back(dep);
+      }
+    }
+
+    return true;
   }
 
   // `host_cluster_op_` stores a set of ops that will be grouped and computed
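The rewritten `IsSafeToAdd` above replaces the nested traversal with a single visited-set worklist over data and control successors. A self-contained sketch of the same reachability check in plain C++, with `Node` standing in for `Operation*` and `successors` covering both data users and direct control successors:

```cpp
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> successors;  // data users + direct control successors
};

// Returns false if any forward path from `op` through ops outside `cluster`
// re-enters the cluster, i.e. merging `op` would create an intermediate
// dependency between the op and the cluster.
bool IsSafeToAdd(Node* op, const std::unordered_set<Node*>& cluster) {
  std::vector<Node*> worklist;
  for (Node* s : op->successors)
    if (!cluster.count(s)) worklist.push_back(s);

  std::unordered_set<Node*> visited;
  while (!worklist.empty()) {
    Node* next = worklist.back();
    worklist.pop_back();
    if (!visited.insert(next).second) continue;
    for (Node* s : next->successors) {
      if (cluster.count(s)) return false;  // path leads back into the cluster
      worklist.push_back(s);
    }
  }
  return true;
}
```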
@@ -183,14 +178,15 @@
   int cluster_counter = 0;
 
   func.walk([&](tf_device::ClusterOp tpu_cluster) {
-    llvm::SmallVector<Operation*, 4> tpu_cluster_ops;
-    tpu_cluster_ops.reserve(tpu_cluster.getBody()->getOperations().size());
-
-    tpu_cluster.walk([&](Operation* op) { tpu_cluster_ops.emplace_back(op); });
+    llvm::SmallVector<Operation*, 4> outside_ops;
+    tpu_cluster.walk([&](Operation* op) {
+      if (op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
+        outside_ops.emplace_back(op);
+    });
 
     // In order to cluster ops feeding results to the same operation, traverse
     // the ops in reverse order.
-    for (Operation* op : llvm::reverse(tpu_cluster_ops)) {
+    for (Operation* op : llvm::reverse(outside_ops)) {
       // Try to add the op to existing clusters.
       bool added = false;
       for (auto& cluster : clusters)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
index 812cafa..cb69d77 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
@@ -118,8 +118,8 @@
     for (Value read_operand : read_operands)
       block.addArgument(read_operand.getType());
 
-    func.setType(FunctionType::get(block.getArgumentTypes(),
-                                   func.getCallableResults(), &getContext()));
+    func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
+                                   func.getCallableResults()));
     cluster_func.erase();
   }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 18fb3c6..ffcc552 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -28,8 +28,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
@@ -106,15 +106,15 @@
 
 LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
                                           std::string* serialized_func_module) {
-  ModuleOp module = entry_func.getParentOfType<ModuleOp>();
+  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
   SymbolTable entry_module_table(module);
   llvm::SmallVector<FuncOp, 4> referenced({entry_func});
 
   // Create a new module to hold func and all referenced functions.
   OwningModuleRef module_for_func =
       ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
-  auto parent_module = entry_func.getParentOfType<ModuleOp>();
-  auto versions_attr = parent_module.getAttr(kVersionsAttr);
+  auto parent_module = entry_func->getParentOfType<ModuleOp>();
+  auto versions_attr = parent_module->getAttr(kVersionsAttr);
   if (!versions_attr)
     return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
 
@@ -165,7 +165,7 @@
     tf_device::ClusterFuncOp op,
     tensorflow::tpu::TPUCompileMetadataProto* metadata) {
   auto step_marker_location =
-      op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
+      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
   if (!step_marker_location)
     return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));
 
@@ -190,7 +190,7 @@
 LogicalResult SetMetadataProtoPaddingMap(
     tf_device::ClusterFuncOp op,
     tensorflow::tpu::TPUCompileMetadataProto* metadata) {
-  auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
+  auto padding_map = op->getAttrOfType<ArrayAttr>(kPaddingMapAttr);
   if (!padding_map)
     return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
 
@@ -234,7 +234,7 @@
     tf_device::ClusterFuncOp op,
     tensorflow::tpu::TPUCompileMetadataProto* metadata) {
   auto input_shardings =
-      op.getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
+      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
   if (!input_shardings)
     return op.emitOpError(
         CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));
@@ -289,7 +289,7 @@
     tf_device::ClusterFuncOp op,
     tensorflow::tpu::TPUCompileMetadataProto* metadata) {
   auto output_shardings =
-      op.getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
+      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
   if (!output_shardings)
     return op.emitOpError(
         CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));
@@ -329,7 +329,7 @@
   if (xla_device_assignment.hasValue())
     *metadata->mutable_device_assignment() =
         std::move(xla_device_assignment.getValue());
-  auto use_spmd_attr = op.getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
+  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
   if (!use_spmd_attr)
     return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
   metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());
@@ -400,7 +400,7 @@
   }
 
   FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
-  FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
+  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
       func_attr.getValue());
 
   std::string txt_module;
@@ -457,7 +457,7 @@
         tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
   }
 
-  replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
+  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
 }
 
 // Creates a `tf.TPUExecute` op that executes TPU program.
@@ -637,16 +637,16 @@
     OpBuilder* builder) {
   // Skip non-tpu device cluster_func.
   auto replicate_attr =
-      cluster_func.getAttrOfType<StringAttr>("_tpu_replicate");
+      cluster_func->getAttrOfType<StringAttr>("_tpu_replicate");
   if (!replicate_attr) return success();
 
   // Collect `num_replicas` and `num_cores_per_replica` attributes.
   int num_replicas = 1;
   tf_device::ReplicateOp replicate =
-      cluster_func.getParentOfType<tf_device::ReplicateOp>();
+      cluster_func->getParentOfType<tf_device::ReplicateOp>();
   if (replicate) num_replicas = replicate.n();
 
-  auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
+  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
       tensorflow::kNumCoresPerReplicaAttr);
   if (!num_cores_per_replica_attr)
     return cluster_func.emitOpError(
@@ -655,12 +655,12 @@
   int num_cores_per_replica = num_cores_per_replica_attr.getInt();
 
   auto topology_attr =
-      cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
+      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
   if (!topology_attr)
     return cluster_func.emitOpError(
         CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
 
-  auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
+  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
       tensorflow::kDeviceAssignmentAttr);
   if (!device_assignment_attr)
     return cluster_func.emitOpError(
@@ -692,11 +692,11 @@
 
   // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
   // parallel_execute region if it exists.
-  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
+  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
     // Currently, outside compilation and model parallelism are not supported
     // together.
     assert(num_cores_per_replica == 1);
-    builder->setInsertionPoint(cluster_func.getParentOp());
+    builder->setInsertionPoint(cluster_func->getParentOp());
   }
 
   Operation* compile_op = BuildCompileOp(
@@ -711,7 +711,7 @@
   // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
   // structured lowering.
   if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
-          cluster_func.getParentOp())) {
+          cluster_func->getParentOp())) {
     parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
       key_op.replaceAllUsesWith(compile_op->getResult(1));
       key_op.erase();
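
The hunks above consistently replace `op.getAttrOfType`/`op.setAttr`/`op.getParentOfType` with the `op->...` form that routes through the underlying Operation*. A minimal, hypothetical sketch of the resulting idiom (attribute name and function are illustrative; diagnostics still go through the typed op):

    // Hypothetical example: read an attribute via the Operation* accessor and
    // keep using the typed op for error reporting.
    mlir::LogicalResult CheckStepMarker(mlir::tf_device::ClusterFuncOp op) {
      auto attr = op->getAttrOfType<mlir::StringAttr>("step_marker_location");
      if (!attr) return op.emitOpError("missing step_marker_location attribute");
      return mlir::success();
    }
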
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
index 76c22ef..5853a8a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
@@ -130,8 +130,8 @@
     }
   }
 
-  cluster_func_op.setAttr(tensorflow::kInputShardingAttr,
-                          builder->getStrArrayAttr(sharding_for_args));
+  cluster_func_op->setAttr(tensorflow::kInputShardingAttr,
+                           builder->getStrArrayAttr(sharding_for_args));
 }
 
 // Finds XlaSharding op connected to a result value. XlaSharding op may be
@@ -202,8 +202,8 @@
     }
   }
 
-  cluster_func.setAttr(tensorflow::kOutputShardingAttr,
-                       builder->getStrArrayAttr(sharding_for_rets));
+  cluster_func->setAttr(tensorflow::kOutputShardingAttr,
+                        builder->getStrArrayAttr(sharding_for_rets));
 }
 
 // Extracts input/output sharding configuration of `cluster_func` by parsing
@@ -211,7 +211,7 @@
 void IdentifyXlaShardingForTPUComputation(
     Builder* builder, tf_device::ClusterFuncOp cluster_func) {
   // Look up function definition from module.
-  FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
+  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
       cluster_func.func());
 
   // By default inputs/outputs have maximal sharding and are assigned to logical
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
index 1c1bc14..ccc31d2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
@@ -25,12 +25,12 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -117,7 +117,7 @@
 void UpdateFuncType(FuncOp func) {
   auto arg_types = func.front().getArgumentTypes();
   auto result_types = func.front().getTerminator()->getOperandTypes();
-  func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
+  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
 }
 
 void HandleFuncOp(Operation* op) {
@@ -196,11 +196,11 @@
   MLIRContext* context = conv2d.getContext();
   SmallVector<int64_t, 4> values = {1, 1, 1, 1};
   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-    return IntegerAttr::get(IntegerType::get(64, context), v);
+    return IntegerAttr::get(IntegerType::get(context, 64), v);
   });
   // TODO(b/157276506): change type of strides to DenseElementsAttr
   auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
-  conv2d.setAttr("strides", strides);
+  conv2d->setAttr("strides", strides);
 }
 
 // Transforms input shape for the first convolution.
@@ -351,7 +351,7 @@
   MLIRContext* context = backprop.getContext();
   SmallVector<int64_t, 4> values = {1, 1, 1, 1};
   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-    return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
+    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
   });
   auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
 
@@ -483,7 +483,7 @@
 void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
                    unsigned arg_num) {
   auto maybe_replicate =
-      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func.getParentOp());
+      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());
 
   llvm::SmallVector<int64_t, 8> transform_input_indices;
   for (auto input : llvm::enumerate(cluster_func.operands())) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc
index 6cd9f76..4328677 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc
@@ -19,8 +19,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -128,8 +128,8 @@
     auto outside_compilation_attr =
         embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
     if (outside_compilation_attr)
-      enqueue_mode.setAttr(kXlaOutsideCompilationAttr,
-                           outside_compilation_attr);
+      enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
+                            outside_compilation_attr);
 
     mode_enqueue_operand.set(enqueue_mode);
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 5bdee80..bc777bd 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -151,7 +151,7 @@
 
   llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
   auto mirrored_variable_indices_attr =
-      replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
+      replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
   if (!mirrored_variable_indices_attr) return mapping;
 
   // Finds the mapping from a replicate argument to an execute operand.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc
index 330e768..e4f70a0 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc
@@ -27,9 +27,9 @@
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
index e9cea13..7ee3f42 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
@@ -306,7 +306,7 @@
       if (auto other_island_op =
               llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
         (*new_control_inputs)[other_island_op].push_back(sink_island_control);
-      } else if (owner->getDialect() == island_op.getDialect() &&
+      } else if (owner->getDialect() == island_op->getDialect() &&
                  !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
                             tf_executor::NextIterationSourceOp>(owner)) {
         (*new_control_inputs)[owner].push_back(sink_island_control);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index bac78c6..36db48a 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -458,7 +458,7 @@
   llvm::SmallVector<llvm::StringRef, 2> input_names;
   llvm::SmallVector<llvm::StringRef, 2> output_names;
   auto dict_attr =
-      function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+      function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
   if (dict_attr) {
     TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
         << "inputs missing in entry function attribute";
@@ -474,7 +474,7 @@
 
   // Extract version info.
   VersionDef versions;
-  auto module = function.getParentOfType<mlir::ModuleOp>();
+  auto module = function->getParentOfType<mlir::ModuleOp>();
   if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
     graph->set_versions(versions);
   }
@@ -547,7 +547,7 @@
 
   auto convert_called_function = [&](llvm::StringRef name) {
     auto func =
-        function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+        function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
             name);
     if (func != nullptr) {
       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
@@ -648,9 +648,9 @@
   // and populates the GradientDef.
   auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
   if (auto attr =
-          function.getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
+          function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
     auto grad_func =
-        function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+        function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
             attr.getValue());
     TF_RETURN_IF_ERROR(
         ConvertLibFunction(configs, tf_dialect, grad_func, flib));
@@ -661,7 +661,7 @@
   }
 
   auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
-  if (auto attr = function.getAttrOfType<mlir::UnitAttr>(stateful_string)) {
+  if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
     func_def.mutable_signature()->set_is_stateful(true);
   }
 
@@ -670,7 +670,7 @@
   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
       grad_string.data(), stateful_string.data()};
   llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
-      function.getDialectAttrs());
+      function->getDialectAttrs());
   TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
                                        /*remove_ref_type=*/false,
                                        func_def.mutable_attr()));
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 4d33308..56c2c04 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -50,12 +50,12 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
@@ -190,8 +190,13 @@
       restrict_functionalization_to_tpu_nodes
           ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
           : NodeFilter{};
-  return FunctionalizeControlFlow(graph, flib_def, node_filter,
-                                  /*include_functions=*/true);
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(
+      FunctionalizeControlFlow(graph, flib_def, node_filter,
+                               /*include_functions=*/true),
+      "Failed to functionalize Control Flow V1 ops. Consider using Control "
+      "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
+      "compat/v1/enable_control_flow_v2.");
+  return Status::OK();
 }
 
 // Stateful helper class to import a TensorFlow model into an MLIR Module.
@@ -1449,9 +1454,9 @@
       all_equal = false;
     }
     if (!all_equal) {
-      function.setType(mlir::FunctionType::get(func_type.getInputs(),
-                                               graph.getResultTypes(),
-                                               function.getContext()));
+      function.setType(mlir::FunctionType::get(function.getContext(),
+                                               func_type.getInputs(),
+                                               graph.getResultTypes()));
     }
   }
 
@@ -2901,8 +2906,8 @@
       }
       new_input_types.push_back(arg.getType());
     }
-    func.setType(mlir::FunctionType::get(
-        new_input_types, func.getType().getResults(), module.getContext()));
+    func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
+                                         func.getType().getResults()));
   }
 }
 
@@ -3064,7 +3069,7 @@
             /*executor_type=*/builder.getStringAttr(""));
         body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
       }
-      func.setAttr(
+      func->setAttr(
           "tf_saved_model.exported_names",
           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
       const SavedConcreteFunction& concrete_function =
@@ -3162,7 +3167,7 @@
           value_attr,
           /*type=*/mlir::TypeAttr::get(type),
           /*is_mutable=*/builder.getUnitAttr());
-      op.setAttr(
+      op->setAttr(
           "tf_saved_model.exported_names",
           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
     } else if (object.kind_case() == SavedObject::kConstant) {
@@ -3182,13 +3187,13 @@
           value_attr,
           /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
           /*is_mutable=*/nullptr);
-      op.setAttr(
+      op->setAttr(
           "tf_saved_model.exported_names",
           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
     }
   }
   AdjustBoundInputArgTypes(module);
-  module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
+  module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
   SortSavedModelModule(module);
   MarkSavedModelFunctionVisibility(module);
   return Status::OK();
@@ -3448,7 +3453,7 @@
 
   // Set the exported name of init function to a reserved name for

   // tf_saved_model.
-  init_func_op.setAttr(
+  init_func_op->setAttr(
       "tf_saved_model.exported_names",
       builder.getStrArrayAttr({absl::StrCat(
           "__tf_saved_model_session_initializer_", target_node_name)}));
@@ -3508,8 +3513,8 @@
       << sig_def_key << ".";
 
   // Use unique SignatureDef key as exported name.
-  func_op.setAttr("tf_saved_model.exported_names",
-                  builder.getStrArrayAttr({sig_def_key}));
+  func_op->setAttr("tf_saved_model.exported_names",
+                   builder.getStrArrayAttr({sig_def_key}));
 
   // Transfer input and output parameter names to index_path attributes.
   for (auto input_and_idx : llvm::enumerate(inputs)) {
@@ -3623,7 +3628,7 @@
   builder.create<mlir::tf_saved_model::SessionInitializerOp>(
       module_->getLoc(), builder.getArrayAttr(init_sym_refs));
 
-  module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
+  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
 
   SortSavedModelModule(*module_);
   MarkSavedModelFunctionVisibility(*module_);
@@ -3653,7 +3658,8 @@
                             context, upgrade_legacy, /*import_restore=*/false));
 
     mlir::OpBuilder builder(module->getContext());
-    module->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
+    (*module)->setAttr("tf_saved_model.under_construction",
+                       builder.getUnitAttr());
     TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
     module->removeAttr("tf_saved_model.under_construction");
 
@@ -3733,6 +3739,7 @@
     mlir::MLIRContext* context) {
   tensorflow::GraphDebugInfo dummy_debug_info;
   tensorflow::GraphImportConfig specs;
+  specs.enable_shape_inference = false;
   specs.graph_as_function = true;
   for (const auto* control_ret_node : fbody->control_ret_nodes)
     specs.control_outputs.push_back(control_ret_node->name());
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_executor_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_executor_to_functional.cc
index 67c9447..8b5e0ec 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_executor_to_functional.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_executor_to_functional.cc
@@ -24,38 +24,15 @@
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 
 namespace mlir {
 
 namespace {
-// This pass lifts tf_executor.island inner ops from a tf_executor.graph that
-// contains only tf_executor.island ops.
-//
-// e.g.
-//   func @my_fn(%arg0, %arg1) -> (...) {
-//     %graph_results:2 = tf_executor.graph {
-//       %island_0_result, %island_0_control = tf_executor.island {
-//         %a = tf.opA(%arg0)
-//         tf_executor.yield %a
-//       }
-//       %island_1_result, %island_1_control = tf_executor.island {
-//         %b = tf.opB(%arg1, %island_0_result)
-//         tf_executor.yield %b
-//       }
-//       tf_executor.fetch %island_0_result, %island_1_result
-//     }
-//     return %graph_results#0, %graph_results#1
-//   }
-//
-// will be transformed into:
-//   func @my_fn(%arg0, %arg1) -> (...) {
-//     %a = tf.opA(%arg0)
-//     %b = tf.opB(%arg1, %a)
-//     return %a, %b
-//   }
 
 struct ExecutorDialectToFunctionalConversion
-    : public PassWrapper<ExecutorDialectToFunctionalConversion, FunctionPass> {
+    : public TF::ExecutorDialectToFunctionalPassBase<
+          ExecutorDialectToFunctionalConversion> {
   void runOnFunction() override;
 };
 
@@ -109,7 +86,3 @@
 
 }  // namespace mlir
 
-static mlir::PassRegistration<mlir::ExecutorDialectToFunctionalConversion> pass(
-    "tf-executor-to-functional-conversion",
-    "Transform from the TF executor dialect (tf_executor.graph containing only "
-    "tf_executor.island ops) to func op.");
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index dd19327..496bf83 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -19,10 +19,10 @@
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/bundle_v2.h"
 #include "tensorflow/cc/saved_model/reader.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc
index d7b5110..3565f58 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
 
+#include <atomic>
+
+#include "absl/strings/str_split.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Operation.h"  // from @llvm-project
@@ -23,17 +26,30 @@
 
 namespace tensorflow {
 
+// Counter is used as a prefix for filenames.
+static std::atomic<int> log_counter(0);
+
 BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
                                        bool print_after_only_on_change)
     : mlir::PassManager::IRPrinterConfig(print_module_scope,
-                                         print_after_only_on_change) {}
+                                         print_after_only_on_change) {
+  const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS");
+  if (log_pass_patterns) {
+    log_pass_patterns_ =
+        absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace());
+  }
+}
 
-// Logs op to file with name of format `mlir_bridge-pass_name-file_suffix.mlir`.
+// Logs op to file with name of format
+// `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
 inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
                        mlir::Pass* pass, mlir::Operation* op,
                        llvm::StringRef file_suffix) {
-  std::string name =
-      llvm::formatv("mlir_bridge_{0}_{1}", pass->getName(), file_suffix).str();
+  std::string pass_name = pass->getName().str();
+
+  // Add 4-digit counter as prefix so the order of the passes is obvious.
+  std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
+                                   pass_name, file_suffix);
 
   std::unique_ptr<llvm::raw_ostream> os;
   std::string filepath;
@@ -44,13 +60,30 @@
 void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
                                               mlir::Operation* operation,
                                               PrintCallbackFn print_callback) {
-  Log(print_callback, pass, operation, "before");
+  if (should_print(pass)) Log(print_callback, pass, operation, "before");
 }
 
 void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
                                              mlir::Operation* operation,
                                              PrintCallbackFn print_callback) {
-  Log(print_callback, pass, operation, "after");
+  if (should_print(pass)) Log(print_callback, pass, operation, "after");
+}
+
+bool BridgeLoggerConfig::should_print(mlir::Pass* pass) {
+  if (log_pass_patterns_.empty()) return true;
+
+  std::string pass_name = pass->getName().str();
+  for (const auto& pattern : log_pass_patterns_) {
+    if (pass_name.find(pattern) != std::string::npos) {
+      // pattern matches pass
+      return true;
+    }
+  }
+  // no pattern matches pass
+  VLOG(2) << "Not logging pass " << pass_name
+          << " because it does not match any pattern in "
+             "MLIR_BRIDGE_LOG_PASS_PATTERNS";
+  return false;
 }
 
 void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) {
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h
index eaf3a7c..c7cd22b 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h
@@ -23,7 +23,11 @@
 namespace tensorflow {
 
 // Logger for logging/dumping MLIR modules before and after passes in bridge
-// targeting TPUs.
+// targeting TPUs. The passes being logged can be restricted via environment
+// variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma-
+// separated list of strings, and only passes whose name contains any of those
+// strings as a substring are logged (no regex support). If
+// `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged.
 class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
  public:
   explicit BridgeLoggerConfig(bool print_module_scope = false,
@@ -42,6 +46,14 @@
   // with the stream to dump into.
   void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                            PrintCallbackFn print_callback) override;
+
+ private:
+  bool should_print(mlir::Pass *pass);
+
+  // Only print passes that match any of these patterns. A pass matches a
+  // pattern if its name contains the pattern as a substring. If
+  // `log_pass_patterns_` is empty, print all passes.
+  std::vector<std::string> log_pass_patterns_;
 };
 
 // Logger for logging/dumping pass pipeline timings after completion.
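
A hedged usage sketch for the new filtering behavior (assuming the standard mlir::PassManager::enableIRPrinting hook; the pattern list and function name are illustrative only):

    #include <cstdlib>
    #include <memory>

    #include "mlir/IR/MLIRContext.h"
    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"

    // Hypothetical setup: only dump IR for passes whose names contain one of
    // the given substrings, then attach the logger to a pass manager.
    void EnableFilteredBridgeLogging(mlir::MLIRContext& context) {
      setenv("MLIR_BRIDGE_LOG_PASS_PATTERNS", "TPURewrite,ShapeInference",
             /*overwrite=*/1);
      mlir::PassManager pm(&context);
      pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
          /*print_module_scope=*/false, /*print_after_only_on_change=*/true));
      // Passes subsequently run on `pm` are dumped to files named
      // <counter>_mlir_bridge_<pass_name>_{before,after}.mlir.
    }
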
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index e20aa64..6b09455 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -27,11 +27,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
@@ -60,6 +60,7 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
 
 namespace tensorflow {
 namespace {
@@ -210,7 +211,20 @@
   std::iota(input_mapping->begin(), input_mapping->end(), 0);
 }
 
-// Refine MLIR types based on new shape information.
+static void RegisterDialects(mlir::DialectRegistry& registry) {
+  mlir::RegisterAllTensorFlowDialects(registry);
+  mlir::mhlo::registerAllMhloDialects(registry);
+}
+
+// Checks if functions can be inlined after TF -> HLO legalization. Currently
+// only TPUs are supported, to follow the behavior of inlining functions via
+// the Graph-based bridge in the TPUCompile op kernel.
+bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
+  return device_type == DEVICE_TPU_XLA_JIT;
+}
+
+}  //  namespace
+
 Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                     mlir::ModuleOp module) {
   auto producer_or = GetTfGraphProducerVersion(module);
@@ -261,19 +275,18 @@
   return Status::OK();
 }
 
-static void RegisterDialects(mlir::DialectRegistry& registry) {
-  mlir::RegisterAllTensorFlowDialects(registry);
-  mlir::mhlo::registerAllMhloDialects(registry);
-}
-
-}  //  namespace
-
 void CreateConvertMlirToXlaHloPipeline(
     mlir::OpPassManager& pm, llvm::StringRef device_type,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
         custom_legalization_passes) {
   pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
+  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateDropWhileShapeInvariantPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  // The SCCP pass performs constant propagation across the IR, which, for
+  // example, propagates constant arguments into callee functions.
+  pm.addPass(mlir::createSCCPPass());
+  // Guarantee all functions have one use, which enables shape inference.
+  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
   // Run shape inference pass before tensorlist decomposition to get buffer
   // shape of uninitialized TensorLists.
   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -284,8 +297,6 @@
       mlir::TFDevice::CreateDecomposeResourceOpsPass());
   pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
   pm.addPass(mlir::createSymbolDCEPass());
-  // Guarantee all functions have one use, which enables shape inference.
-  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
   // TODO(b/171426148): We cannot completely remove region to functional control
   // flow conversion from this pipeline yet as it causes some unit tests to
@@ -295,9 +306,6 @@
   // with a tuple argument which break the assumption of resource lifting
   // inside PromoteResourcesToArgs.
   pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
-  // The SCCP pass performs constant propagation across the IR, which, for
-  // example, propagates constant arguments into callee functions.
-  pm.addPass(mlir::createSCCPPass());
 
   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
       /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
@@ -319,19 +327,19 @@
   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
       /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
       /*tf2xla_fallback_device_type=*/device_type));
+
+  if (CanInlineFunctionsPostLegalization(device_type))
+    pm.addPass(mlir::createInlinerPass());
+
   // In order to export to XLA, we must sink constants to control flow regions,
   // since XLA uses functional control flow.
   pm.addNestedPass<mlir::FuncOp>(
       mlir::mhlo::createSinkConstantsToControlFlowPass());
 }
 
-Status ConvertMLIRToXlaComputation(
-    mlir::ModuleOp module_op, llvm::StringRef device_type,
-    xla::XlaComputation* xla_computation, bool use_tuple_args,
-    bool return_tuple,
-    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
-    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes) {
+Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
+                     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                         custom_legalization_passes) {
   mlir::PassManager tf2xla(module_op.getContext());
   applyTensorflowAndCLOptions(tf2xla);
   CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
@@ -358,6 +366,32 @@
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_legalize_hlo", module_op);
 
+  return Status::OK();
+}
+
+Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
+                           llvm::ArrayRef<xla::XlaOp> xla_params,
+                           std::vector<xla::XlaOp>& returns,
+                           llvm::StringRef device_type,
+                           llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                               custom_legalization_passes) {
+  TF_RETURN_IF_ERROR(
+      LegalizeToHlo(module_op, device_type, custom_legalization_passes));
+
+  mlir::Block& block = module_op.lookupSymbol<mlir::FuncOp>("main").front();
+  return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
+}
+
+Status ConvertMLIRToXlaComputation(
+    mlir::ModuleOp module_op, llvm::StringRef device_type,
+    xla::XlaComputation* xla_computation, bool use_tuple_args,
+    bool return_tuple,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+        custom_legalization_passes) {
+  TF_RETURN_IF_ERROR(
+      LegalizeToHlo(module_op, device_type, custom_legalization_passes));
+
   xla::HloProto hlo_proto;
   TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
                                                use_tuple_args, return_tuple,
@@ -366,14 +400,9 @@
   return Status::OK();
 }
 
-Status CompileMlirToXlaHlo(
+Status CompileMlirSetup(
     mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
-    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
-    bool use_resource_updates_for_aliases,
-    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
-    XlaCompilationResult* compilation_result,
-    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes) {
+    XlaHelpers::ShapeRepresentationFn* shape_representation_fn) {
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
 
@@ -383,16 +412,39 @@
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
 
-  if (!shape_representation_fn)
-    shape_representation_fn = IdentityShapeRepresentationFn();
+  if (!*shape_representation_fn)
+    *shape_representation_fn = IdentityShapeRepresentationFn();
+
+  return Status::OK();
+}
+
+Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
+                      llvm::ArrayRef<xla::XlaOp> xla_params,
+                      std::vector<xla::XlaOp>& returns,
+                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+                      llvm::StringRef device_type,
+                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                          custom_legalization_passes) {
+  XlaHelpers::ShapeRepresentationFn shape_representation_fn;
+  TF_RETURN_IF_ERROR(
+      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
 
   // Convert MLIR module to XLA HLO proto contained in XlaComputation.
-  compilation_result->computation = std::make_shared<xla::XlaComputation>();
-  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
-      module_op, device_type, compilation_result->computation.get(),
-      use_tuple_args, use_return_tuple, shape_representation_fn,
-      custom_legalization_passes));
+  TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
+                                         returns, device_type,
+                                         custom_legalization_passes));
 
+  if (VLOG_IS_ON(1))
+    tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op);
+
+  return Status::OK();
+}
+
+Status PopulateResultIOInfo(
+    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+    bool use_tuple_args, bool use_resource_updates_for_aliases,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result) {
   // Construct mapping from XlaComputation's arg to input edges of execute
   // node.
   GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
@@ -414,6 +466,29 @@
   return Status::OK();
 }
 
+Status CompileMlirToXlaHlo(
+    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
+    bool use_resource_updates_for_aliases,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
+    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+        custom_legalization_passes) {
+  TF_RETURN_IF_ERROR(
+      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
+
+  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
+  compilation_result->computation = std::make_shared<xla::XlaComputation>();
+  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
+      module_op, device_type, compilation_result->computation.get(),
+      use_tuple_args, use_return_tuple, shape_representation_fn,
+      custom_legalization_passes));
+
+  return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
+                              use_resource_updates_for_aliases,
+                              shape_representation_fn, compilation_result);
+}
+
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
@@ -494,9 +569,9 @@
     for (mlir::BlockArgument& arg : main_fn.getArguments())
       updated_argument_types.push_back(arg.getType());
 
-    main_fn.setType(mlir::FunctionType::get(updated_argument_types,
-                                            main_fn.getType().getResults(),
-                                            main_fn.getContext()));
+    main_fn.setType(mlir::FunctionType::get(main_fn.getContext(),
+                                            updated_argument_types,
+                                            main_fn.getType().getResults()));
   }
 
   for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
@@ -504,18 +579,13 @@
   return params;
 }
 
-Status CompileGraphToXlaHlo(
+Status CompileGraphSetup(
     mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
-    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
-    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
-    XlaCompilationResult* compilation_result,
-    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes) {
-  TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
-                      RewriteWithArgs(module_op, args));
-  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
-  arg_shapes.reserve(remaining_params.size());
-  for (unsigned idx : remaining_params) {
+    std::vector<int>* remaining_params,
+    llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
+  TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
+  arg_shapes.reserve(remaining_params->size());
+  for (unsigned idx : *remaining_params) {
     const auto& arg = args[idx];
     TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                         GetTensorShapeFromXlaArgument(arg));
@@ -527,10 +597,39 @@
   applyTensorflowAndCLOptions(pm);
   mlir::TF::StandardPipelineOptions tf_options;
   mlir::TF::CreateTFStandardPipeline(pm, tf_options);
-  {
-    mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
-    if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
-  }
+
+  mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
+  if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
+
+  return Status::OK();
+}
+
+Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
+                          llvm::ArrayRef<xla::XlaOp> xla_params,
+                          std::vector<xla::XlaOp>& returns,
+                          llvm::ArrayRef<XlaArgument> args,
+                          llvm::StringRef device_type,
+                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                              custom_legalization_passes) {
+  std::vector<int> remaining_params;
+  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
+  TF_RETURN_IF_ERROR(
+      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
+  return BuildHloFromTf(module_op, builder, xla_params, returns, arg_shapes,
+                        device_type, custom_legalization_passes);
+}
+
+Status CompileGraphToXlaHlo(
+    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
+    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
+    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+        custom_legalization_passes) {
+  std::vector<int> remaining_params;
+  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
+  TF_RETURN_IF_ERROR(
+      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
 
   auto status = CompileMlirToXlaHlo(
       module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
@@ -540,6 +639,49 @@
   return status;
 }
 
+Status GraphToModule(const Graph& graph,
+                     llvm::ArrayRef<std::string> control_rets,
+                     const FunctionLibraryDefinition& flib_def,
+                     const GraphDebugInfo& debug_info,
+                     mlir::MLIRContext* context,
+                     mlir::OwningModuleRef* module) {
+  RegisterDialects(context->getDialectRegistry());
+  GraphImportConfig config;
+  config.graph_as_function = true;
+  config.control_outputs = control_rets;
+  // Disable shape inference during import, as some TensorFlow ops fail during
+  // shape inference with dynamically shaped operands, which in turn causes
+  // the import to fail. Shape inference during import is going to be removed;
+  // since the shape inference pass runs early in the pass pipeline, shape
+  // inference during import is not necessary.
+  config.enable_shape_inference = false;
+  auto module_or =
+      ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
+  if (!module_or.ok()) return module_or.status();
+
+  *module = std::move(module_or.ValueOrDie());
+
+  return Status::OK();
+}
+
+Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
+                         llvm::ArrayRef<xla::XlaOp> xla_params,
+                         std::vector<xla::XlaOp>& returns,
+                         llvm::ArrayRef<XlaArgument> args,
+                         llvm::ArrayRef<std::string> control_rets,
+                         llvm::StringRef device_type,
+                         const FunctionLibraryDefinition& flib_def,
+                         const GraphDebugInfo& debug_info,
+                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                             custom_legalization_passes) {
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module;
+  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
+                                   &context, &module));
+  return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
+                            device_type, custom_legalization_passes);
+}
+
 Status CompileGraphToXlaHlo(
     const Graph& graph, llvm::ArrayRef<XlaArgument> args,
     llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
@@ -550,22 +692,10 @@
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
         custom_legalization_passes) {
   mlir::MLIRContext context;
-  RegisterDialects(context.getDialectRegistry());
-  GraphImportConfig config;
-  config.graph_as_function = true;
-  config.control_outputs = control_rets;
-  // Disable shape inference during import as some TensorFlow op fails during
-  // shape inference with dynamic shaped operands. This in turn causes the
-  // import to fail. Shape inference during import is going to be removed and
-  // the shape inference pass is run early in the pass pipeline, shape inference
-  // during import is not necessary.
-  config.enable_shape_inference = false;
-  auto module_or =
-      ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
-  if (!module_or.ok()) return module_or.status();
-
-  mlir::ModuleOp module_op = module_or.ValueOrDie().get();
-  return CompileGraphToXlaHlo(module_op, args, device_type, use_tuple_args,
+  mlir::OwningModuleRef module;
+  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
+                                   &context, &module));
+  return CompileGraphToXlaHlo(module.get(), args, device_type, use_tuple_args,
                               /*use_return_tuple=*/true,
                               shape_representation_fn, compilation_result,
                               custom_legalization_passes);
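
To make the new BuildHloFromTf entry point above concrete, here is a hedged, hypothetical fragment (variable names, shapes, and device type string are illustrative; it assumes it runs inside a Status-returning function with module_op and arg_shapes already prepared). HLO-level parameters are added to the XlaBuilder first and handed in as xla_params, and the lowered outputs come back in returns:

    // Hypothetical: feed two f32[2,2] parameters into an existing XlaBuilder.
    xla::XlaBuilder xla_builder("tf_to_hlo_fragment");
    xla::Shape arg_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
    std::vector<xla::XlaOp> xla_params = {
        xla::Parameter(&xla_builder, 0, arg_shape, "arg0"),
        xla::Parameter(&xla_builder, 1, arg_shape, "arg1")};
    std::vector<xla::XlaOp> returns;
    TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, xla_builder, xla_params,
                                      returns, arg_shapes,
                                      /*device_type=*/"XLA_TPU_JIT",
                                      /*custom_legalization_passes=*/{}));
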
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index b02720b..64d48ee 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -81,8 +81,33 @@
   bool is_resource = false;
 };
 
+// Refine MLIR types based on new shape information.
+Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+                    mlir::ModuleOp module);
+
+// Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level
+// inputs to module_op that have already been added to the XlaBuilder. returns
+// are the returned XlaOps.
+Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
+                      llvm::ArrayRef<xla::XlaOp> xla_params,
+                      std::vector<xla::XlaOp>& returns,
+                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+                      llvm::StringRef device_type,
+                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                          custom_legalization_passes);
+
+// Apply shape, description, and resource information to inputs and outputs
+// in the XlaCompilationResult. This should be called after
+// compilation_result->computation has been set.
+Status PopulateResultIOInfo(
+    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+    bool use_tuple_args, bool use_resource_updates_for_aliases,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result);
+
 // Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
 // stores them in CompilationResult.
+// TODO(hinsu): Migrate options to separate struct.
 Status CompileMlirToXlaHlo(
     mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
@@ -127,6 +152,21 @@
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
         custom_legalization_passes = {});
 
+// Compiles a Graph from TF to HLO and adds the resulting HLO to the
+// XlaBuilder. This function adds HLO to a larger HLO computation, so
+// HLO-level inputs are supplied, and HLO-level outputs are produced.
+// xla_params holds the HLO-level inputs; returns holds the HLO-level outputs.
+Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
+                         llvm::ArrayRef<xla::XlaOp> xla_params,
+                         std::vector<xla::XlaOp>& returns,
+                         llvm::ArrayRef<XlaArgument> args,
+                         llvm::ArrayRef<std::string> control_rets,
+                         llvm::StringRef device_type,
+                         const FunctionLibraryDefinition& flib_def,
+                         const GraphDebugInfo& debug_info,
+                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+                             custom_legalization_passes = {});
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
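
A minimal, hypothetical caller of the BuildHloFromGraph declaration added above (the wrapper name and device type string are assumptions; it simply forwards an already-built Graph into an existing XlaBuilder instead of producing a standalone XlaComputation):

    // Hypothetical wrapper around the new Graph -> HLO entry point.
    Status BuildIntoExistingComputation(const Graph& graph,
                                        const FunctionLibraryDefinition& flib_def,
                                        const GraphDebugInfo& debug_info,
                                        llvm::ArrayRef<XlaArgument> args,
                                        llvm::ArrayRef<xla::XlaOp> xla_params,
                                        xla::XlaBuilder& builder) {
      std::vector<xla::XlaOp> returns;
      return BuildHloFromGraph(graph, builder, xla_params, returns, args,
                               /*control_rets=*/{},
                               /*device_type=*/"XLA_TPU_JIT", flib_def,
                               debug_info);
    }
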
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
index d1f4086..8ecf62d 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -26,7 +26,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
index 6266a5e..578bbab 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
@@ -20,9 +20,9 @@
 
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -136,30 +136,30 @@
       {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
 
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
-      {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
+      {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
-      {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
+      {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
-      {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
+      {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
-      {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
+      {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64)));
 
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
       {1, 2}, DT_UINT8,
       mlir::IntegerType::get(
-          8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 8, mlir::IntegerType::SignednessSemantics::Unsigned)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
       {1, 2}, DT_UINT16,
       mlir::IntegerType::get(
-          16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 16, mlir::IntegerType::SignednessSemantics::Unsigned)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
       {1, 2}, DT_UINT32,
       mlir::IntegerType::get(
-          32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 32, mlir::IntegerType::SignednessSemantics::Unsigned)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
       {1, 2}, DT_UINT64,
       mlir::IntegerType::get(
-          64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 64, mlir::IntegerType::SignednessSemantics::Unsigned)));
 
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
       {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
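
For reference, a small illustrative fragment (not part of the patch) of the context-first builtin type constructors the updated test exercises:

    // Hypothetical: the MLIRContext now comes first; width and signedness follow.
    mlir::MLIRContext context;
    mlir::Type i32 = mlir::IntegerType::get(&context, 32);
    mlir::Type u16 = mlir::IntegerType::get(
        &context, 16, mlir::IntegerType::SignednessSemantics::Unsigned);
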
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
index 0d035e8..3456901 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
@@ -17,7 +17,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc
index 07f6b12..ee44207 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc
@@ -17,8 +17,8 @@
 
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
index cca6981..0159281 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
@@ -55,26 +55,6 @@
   return false;
 }
 
-// Update node_def's device attribute (if any) to use a local device, that is
-// /job:localhost/replica:0/task:0/{DEVICE_TYPE}:{DEVICE_ID}.
-// This is because EvaluateOperation only has access to local devices but the
-// given node may carry a device assignment to a remote device. In that case,
-// evaluation would fail even if we have a device of same type locally. By
-// altering device assignment to a local one, we could successfully evaluate in
-// that case.
-void ForceUseLocalhostDevice(NodeDef* node_def) {
-  DeviceNameUtils::ParsedName parsed_name;
-
-  if (!DeviceNameUtils::ParseFullName(node_def->device(), &parsed_name)) return;
-
-  if (parsed_name.has_job) parsed_name.job = "localhost";
-  if (parsed_name.has_replica) parsed_name.replica = 0;
-  if (parsed_name.has_task) parsed_name.task = 0;
-
-  *node_def->mutable_device() =
-      DeviceNameUtils::ParsedNameToString(parsed_name);
-}
-
 mlir::LogicalResult EvaluateOperation(
     mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
     TFE_Context* context, llvm::SmallVectorImpl<mlir::Attribute>* results) {
@@ -104,12 +84,16 @@
   RETURN_FAILURE_IF_ERROR(node_def_or.status());
   const auto& node_def = node_def_or.ValueOrDie();
 
-  ForceUseLocalhostDevice(node_def.get());
-
   TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status);
   RETURN_FAILURE_IF_ERROR(status);
   auto clean_op = MakeCleanup([op] { TFE_DeleteOp(op); });
-  TFE_OpSetDevice(op, node_def->device().c_str(), status);
+
+  // Explicitly set the device to the host CPU instead of the device present in
+  // the device attribute of the MLIR op. The assigned device might be remote,
+  // unavailable during compilation, or a compilation-only device for on-demand
+  // execution, which may create a recursion if used for constant folding.
+  constexpr char kHostCpu[] = "/job:localhost/replica:0/task:0/CPU:0";
+  TFE_OpSetDevice(op, kHostCpu, status);
   RETURN_FAILURE_IF_ERROR(status);
   for (const auto& attr : node_def->attr()) {
     SetOpAttrValueScalar(context, op, attr.second, attr.first.c_str(), status);
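
(For orientation: a minimal, hedged sketch of the device pinning described in the comment above, written against the public TFE C API. The NewHostCpuOp helper name and the omitted error handling are illustrative assumptions, not part of this change.)

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"

// Creates an eager op pinned to the local host CPU, ignoring whatever device
// the originating NodeDef carries, so constant folding never targets a remote
// or compilation-only device.
TFE_Op* NewHostCpuOp(TFE_Context* ctx, const char* op_name, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, op_name, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op, "/job:localhost/replica:0/task:0/CPU:0", status);
  return op;
}
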
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
index 4130e72..e3e14af 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
@@ -25,8 +25,10 @@
 namespace tensorflow {
 
 // Attempts to evaluates an MLIR Operation in TensorFlow eager mode with the
-// specified operands. If successful, this fills in the results vector. If not,
-// results vector is unspecified.
+// specified operands. The op is always executed on the local host CPU
+// irrespective of the device attribute of the given op. If a CPU kernel is
+// registered for the op and it executes successfully, this fills in the
+// results vector. If not, the results vector is unspecified.
 //
 mlir::LogicalResult EvaluateOperation(
     mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index fb73abb..f6e1a9d 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -27,11 +27,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc
index d82d61e..6b97aba 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc
@@ -27,12 +27,12 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/Region.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc
index f6cf5a4..ac6bc63 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc
@@ -40,6 +40,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
 #include "tensorflow/compiler/mlir/utils/string_container_utils.h"
+#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
 #include "tensorflow/compiler/tf2xla/xla_argument.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -233,8 +234,64 @@
 
 }  // anonymous namespace
 
-static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
-    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
+// Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so
+// to make this test comparable to other compile tests, the test implements
+// the remaining parts of the conversion.
+Status CompileMlirToXlaHloViaBuilder(
+    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
+    llvm::StringRef device_type, XlaCompilationResult* compilation_result,
+    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+        custom_legalization_passes) {
+  // This call to RefineShapes is redundant with the call in BuildHloFromTf.
+  // It's here so xla::Parameters that are created from block.getArguments will
+  // have the proper shapes.
+  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));
+
+  mlir::FuncOp main = module_op.lookupSymbol<mlir::FuncOp>("main");
+  mlir::Block& block = main.getRegion().front();
+  xla::XlaBuilder builder("main");
+
+  // Create xla_params.
+  std::vector<xla::XlaOp> xla_params;
+  for (mlir::BlockArgument& arg : block.getArguments()) {
+    auto num = arg.getArgNumber();
+    xla::Shape shape = xla::TypeToShape(arg.getType());
+    xla::XlaOp argop =
+        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
+    xla_params.push_back(argop);
+  }
+
+  std::vector<xla::XlaOp> returns(1);
+  TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, builder, xla_params, returns,
+                                    arg_shapes, device_type,
+                                    custom_legalization_passes));
+
+  xla::XlaOp return_value;
+  if (returns.size() == 1)
+    return_value = returns[0];
+  else
+    return_value = xla::Tuple(&builder, returns);
+
+  TF_ASSIGN_OR_RETURN(
+      xla::XlaComputation computation,
+      return_value.valid() ? builder.Build(return_value) : builder.Build());
+  auto hlo_module = computation.proto();
+  xla::HloProto hlo_proto;
+  hlo_proto.mutable_hlo_module()->Swap(&hlo_module);
+
+  compilation_result->computation = std::make_shared<xla::XlaComputation>();
+  xla::XlaComputation* xla_computation = compilation_result->computation.get();
+  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
+
+  XlaHelpers::ShapeRepresentationFn shape_representation_fn =
+      IdentityShapeRepresentationFn();
+  return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false,
+                              /*use_resource_updates_for_aliases=*/false,
+                              shape_representation_fn, compilation_result);
+}
+
+static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl(
+    mlir::ModuleOp module_op, llvm::raw_ostream& output, bool via_builder) {
   if (!module_op) return mlir::failure();
 
   llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
@@ -245,12 +302,21 @@
     return mlir::failure();
   }
 
+  auto device_type = "XLA_CPU_JIT";
+  llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
+      custom_legalization_passes{};
   XlaCompilationResult compilation_result;
-  auto compilation_status = CompileMlirToXlaHlo(
-      module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
-      emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
-      IdentityShapeRepresentationFn(), &compilation_result,
-      /*custom_legalization_passes=*/{});
+  auto compilation_status =
+      via_builder
+          ? CompileMlirToXlaHloViaBuilder(module_op, arg_shapes, device_type,
+                                          &compilation_result,
+                                          custom_legalization_passes)
+          : CompileMlirToXlaHlo(module_op, arg_shapes, device_type,
+                                emit_use_tuple_arg, emit_return_tuple,
+                                /*use_resource_updates_for_aliases=*/true,
+                                IdentityShapeRepresentationFn(),
+                                &compilation_result,
+                                custom_legalization_passes);
   if (!compilation_status.ok()) {
     LOG(ERROR) << "TF/XLA compilation failed: "
                << compilation_status.ToString();
@@ -326,12 +392,27 @@
   return mlir::success();
 }
 
+static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
+    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
+  return MlirTfToHloTextTranslateFunctionImpl(module_op, output, false);
+}
+
+static mlir::LogicalResult MlirTfToHloTextViaBuilderTranslateFunction(
+    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
+  return MlirTfToHloTextTranslateFunctionImpl(module_op, output, true);
+}
+
 }  // namespace tensorflow
 
 static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate(
     "mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction,
     tensorflow::RegisterMlirInputDialects);
 
+static mlir::TranslateFromMLIRRegistration MlirTfToHloTextViaBuilderTranslate(
+    "mlir-tf-to-hlo-text-via-builder",
+    tensorflow::MlirTfToHloTextViaBuilderTranslateFunction,
+    tensorflow::RegisterMlirInputDialects);
+
 static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate(
     "mlir-tf-graph-to-hlo-text",
     tensorflow::MlirTfGraphToHloTextTranslateFunction,
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
index bf7c82b..5ab7857 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
@@ -489,13 +489,13 @@
 mlir::LogicalResult GetHostDeviceOutsideComputation(
     mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
     std::string* host_device) {
-  auto replicate = cluster.getParentOfType<mlir::tf_device::ReplicateOp>();
+  auto replicate = cluster->getParentOfType<mlir::tf_device::ReplicateOp>();
   if (replicate) {
     *host_device = tensorflow::kTPUReplicatedHost;
     return mlir::success();
   }
 
-  auto num_cores_per_replica_attr = cluster.getAttrOfType<mlir::IntegerAttr>(
+  auto num_cores_per_replica_attr = cluster->getAttrOfType<mlir::IntegerAttr>(
       tensorflow::kNumCoresPerReplicaAttr);
   if (!num_cores_per_replica_attr)
     return cluster.emitOpError(
@@ -506,12 +506,12 @@
         "outside compilation is not supported with model parallelism.");
 
   auto topology_attr =
-      cluster.getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
+      cluster->getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
   if (!topology_attr)
     return cluster.emitOpError("cluster op missing `topology` attribute");
 
-  auto device_assignment_attr =
-      cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
+  auto device_assignment_attr = cluster->getAttrOfType<mlir::ArrayAttr>(
+      tensorflow::kDeviceAssignmentAttr);
   if (!device_assignment_attr)
     return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
                                              tensorflow::kDeviceAssignmentAttr)
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
index 78547c8..63afd0a 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
@@ -650,10 +650,10 @@
   llvm::SmallVector<mlir::Type, 8> result_types;
   auto cluster = builder.create<mlir::tf_device::ClusterOp>(
       mlir::UnknownLoc::get(&context), result_types);
-  cluster.setAttr(kNumCoresPerReplicaAttr,
-                  builder.getIntegerAttr(builder.getIntegerType(64), 5));
-  cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
-  cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
+  cluster->setAttr(kNumCoresPerReplicaAttr,
+                   builder.getIntegerAttr(builder.getIntegerType(64), 5));
+  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
+  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
 
   mlir::TF::RuntimeDevices runtime_devices;
   std::string host_device;
@@ -671,9 +671,9 @@
   llvm::SmallVector<mlir::Type, 8> result_types;
   auto cluster = builder.create<mlir::tf_device::ClusterOp>(
       mlir::UnknownLoc::get(&context), result_types);
-  cluster.setAttr(kNumCoresPerReplicaAttr,
-                  builder.getIntegerAttr(builder.getIntegerType(64), 1));
-  cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
+  cluster->setAttr(kNumCoresPerReplicaAttr,
+                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
+  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
 
   mlir::TF::RuntimeDevices runtime_devices;
   std::string host_device;
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
index f32485e..075d33a 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
@@ -30,15 +30,16 @@
       "bad_consumers",
       b.getI32ArrayAttr(llvm::ArrayRef<int32_t>(
           versions.bad_consumers().begin(), versions.bad_consumers().end())));
-  module.setAttr("tf.versions",
-                 b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
-                     {producer, min_consumer, bad_consumers})));
+  module->setAttr("tf.versions",
+                  b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
+                      {producer, min_consumer, bad_consumers})));
 }
 
 mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,
                                       VersionDef* versions) {
   versions->Clear();
-  auto version_attr = module.getAttrOfType<mlir::DictionaryAttr>("tf.versions");
+  auto version_attr =
+      module->getAttrOfType<mlir::DictionaryAttr>("tf.versions");
   if (!version_attr) return mlir::failure();
 
   auto producer =
@@ -66,7 +67,7 @@
 
 ::stream_executor::port::StatusOr<int64_t> GetTfGraphProducerVersion(
     mlir::ModuleOp module) {
-  auto versions = module.getAttrOfType<::mlir::DictionaryAttr>("tf.versions");
+  auto versions = module->getAttrOfType<::mlir::DictionaryAttr>("tf.versions");
   if (!versions) {
     return errors::Internal(
         "Missing 'tf.versions' attribute on the module, abort.\n");
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc
index 8d429ec..d58f029 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc
@@ -15,7 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
 
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
index a3f8e83..46a88da 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
@@ -26,8 +26,8 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -92,8 +92,9 @@
   llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
   *split_op = builder->create<mlir::TF::SplitOp>(
       location, output_types, split_dimension_op.output(), src_input);
-  split_op->setAttr(kNumSplitAttr, builder->getIntegerAttr(
-                                       builder->getIntegerType(32), num_split));
+  (*split_op)->setAttr(
+      kNumSplitAttr,
+      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
   return mlir::success();
 }
 
diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
index ee02419..70c00ef 100644
--- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/compiler/mlir/init_mlir.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
 #include "tensorflow/core/platform/init_main.h"
 
@@ -29,6 +30,7 @@
   tensorflow::InitMlir y(&argc, &argv);
 
   mlir::registerAllPasses();
+  mlir::registerTensorFlowPasses();
   mlir::mhlo::registerAllMhloPasses();
   mlir::lmhlo::registerAllLmhloPasses();
 
diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
index bc52e3a..4e2c5f6 100644
--- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
+++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
@@ -23,12 +23,11 @@
 #define TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_
 
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
-
 #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_dialect.h.inc"
 
 #define GET_OP_CLASSES
diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
index 04811ff..353e961 100644
--- a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
@@ -20,9 +20,9 @@
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD
index 49c5091..2dfa85d 100644
--- a/tensorflow/compiler/mlir/tfr/BUILD
+++ b/tensorflow/compiler/mlir/tfr/BUILD
@@ -104,7 +104,9 @@
         "utils/utils.h",
     ],
     deps = [
+        ":tfr",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
index 9fc30ba..037536c 100644
--- a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
+++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
@@ -68,26 +68,22 @@
     tags = [
         "no_cuda_asan",  # Not needed, and there were issues with timeouts.
         "no_oss",  # Avoid downloading mnist data set in oss.
-        "notap",  # The test is too long to run as part of llvm
-        # presubmits (b/173661843).
-        "notsan",  # Not needed, and there were issues with timeouts.
         "nomultivm",  # Not needed. Save some resources and test time.
-
-        # TODO(b/172367622) Re-enable TPU test after issues with TPU are
-        # resolved.
-        "notpu",
+        "notap",  # The test is too long to run as part of llvm presubmits (b/173661843).
+        "notsan",  # Not needed, and there were issues with timeouts.
     ],
+
+    # TODO(b/175056184): Re-enable xla_enable_strict_auto_jit once the issues
+    # with GPU and the MLIR bridge are worked out.
+    xla_enable_strict_auto_jit = False,
     deps = [
         ":mnist_train",
-        "@absl_py//absl/testing:parameterized",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:extra_py_tests_deps",
+        "//tensorflow/python:is_mlir_bridge_test_true",
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:test_util",
-
-        # TODO(b/172367622) Switch to MLIR bridge after issues with MLIR bridge
-        # are resolved.
-        # "//tensorflow/python:is_mlir_bridge_test_true",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
index f7f39ab..f3e8780 100644
--- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
@@ -38,12 +38,17 @@
   if (!IsEnabled()) return Status::OK();
   // This can be the default cpu device.
   if (orig_op->Device() != kVariantDeviceNull) return Status::OK();
+  if (orig_op->is_function()) return Status::OK();
+
   // TODO(fengliuai): We need a better condition to skip the rewrite. Currently,
   // The rewrite is enabled for all the tf ops and it is a no-op if the tf op
-  // isn't a composite op. "VarHandleOp" is explicitly skipped here because its
-  // roundtrip fails due to some unknown reasons.
-  if (orig_op->is_function()) return Status::OK();
-  if (absl::StartsWith(orig_op->op_name(), "VarHandleOp")) return Status::OK();
+  // isn't a composite op. The following ops are explicitly skipped here because
+  // their "no-op" expansion is known to cause problems in some cases.
+  static const char* kOpsToSkip[] = {"IdentityOp", "NoOp", "OptionalHasValue",
+                                     "OptionalGetValue", "VarHandleOp"};
+  for (const char* skip : kOpsToSkip) {
+    if (absl::StartsWith(orig_op->op_name(), skip)) return Status::OK();
+  }
 
   tf_core_op_expansion_node_counter->GetCell()->IncrementBy(1);
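
(A self-contained sketch of the prefix-based skip list above, assuming only Abseil strings; the ShouldSkipExpansion helper is hypothetical and mirrors, rather than replaces, the logic in this pass.)

#include <string>

#include "absl/strings/match.h"

// Returns true if the op name starts with one of the prefixes whose "no-op"
// expansion is known to cause problems, matching the skip logic above.
bool ShouldSkipExpansion(const std::string& op_name) {
  static const char* const kOpsToSkip[] = {"IdentityOp", "NoOp",
                                           "OptionalHasValue",
                                           "OptionalGetValue", "VarHandleOp"};
  for (const char* skip : kOpsToSkip) {
    if (absl::StartsWith(op_name, skip)) return true;
  }
  return false;
}
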
 
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
index a3b0193..a5edc9f 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
@@ -28,10 +28,10 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
@@ -150,7 +150,7 @@
   mlir::Location loc = mlir::UnknownLoc::get(context);
   mlir::ModuleOp module = mlir::ModuleOp::create(loc);
   mlir::FunctionType func_type =
-      mlir::FunctionType::get(input_tys, output_tys, context);
+      mlir::FunctionType::get(context, input_tys, output_tys);
   llvm::StringRef func_name_str(func_name.data(), func_name.size());
   auto func = mlir::FuncOp::create(loc, func_name_str, func_type, {});
   module.push_back(func);
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
index 8736e9f..d451bea 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
@@ -19,9 +19,9 @@
 
 #include "absl/types/span.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
@@ -82,6 +82,7 @@
   tfr.return %res : !tfr.tensor
 }
 
+tfr.func @tf__my_add_n_(!tfr.tensor_list<N,T>, i64 {tfr.name="N"}) -> !tfr.tensor attributes{N,T}
 tfr.func @tf__risc_add_dummy_(!tfr.tensor<T>, !tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
 )";
 
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
index 6933875..be01511 100644
--- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
@@ -31,6 +31,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -38,7 +39,6 @@
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
@@ -121,6 +121,13 @@
   addInterfaces<TFRInlinerInterface>();
 }
 
+Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                           Type type, Location loc) {
+  if (ConstantOp::isBuildableWith(value, type))
+    return builder.create<ConstantOp>(loc, type, value);
+  return nullptr;
+}
+
 bool TFRType::classof(Type type) {
   return llvm::isa<TFRDialect>(type.getDialect());
 }
@@ -171,7 +178,7 @@
   // and returns. Also, collect the names of all the attribute arguments as the
   // defined list. Later on, the used attribute names will be verified to be in
   // the defined list.
-  llvm::SmallVector<StringAttr, 4> used_attrs;
+  llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;
 
   // While scanning the arguments, record the start/end indices of each argument
   // type, so the order can be verified as well.
@@ -179,7 +186,6 @@
   // at the end?
   int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
       last_tensor_list = -1, first_attr = -1;
-
   for (auto arg : llvm::enumerate(func.getType().getInputs())) {
     Type arg_type = arg.value();
 
@@ -189,7 +195,7 @@
       }
       last_tensor = arg.index();
       auto used = tensor.getAttrKeys();
-      used_attrs.append(used.begin(), used.end());
+      input_used_attrs.append(used.begin(), used.end());
       continue;
     }
 
@@ -199,7 +205,7 @@
       }
       last_tensor_list = arg.index();
       auto used = tensor_list.getAttrKeys();
-      used_attrs.append(used.begin(), used.end());
+      input_used_attrs.append(used.begin(), used.end());
       continue;
     }
 
@@ -222,46 +228,62 @@
     return failure();
   }
 
+  // Collect all the undefined attributes used in the inputs.
+  llvm::SmallVector<StringAttr, 4> undefined_attrs;
+  for (auto attr : input_used_attrs) {
+    if (!func->getAttr(attr.getValue())) {
+      undefined_attrs.push_back(attr);
+    }
+  }
+
   // Verify the argument order: tensors, tensor list, attributes; and also
   // verify there is at most one tensor list argument.
-  if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
+  if (first_attr != -1 &&
+      (first_attr < last_tensor_list || first_attr < last_tensor)) {
     func.emitError(
-        "tfr.tensor argument should be before tfr.tensor_list argument.");
+        "tfr.tensor/tfr.tensor_list argument should be before non tensor "
+        "arguments.");
     return failure();
   }
-  if (first_attr != -1 && first_attr < last_tensor_list) {
-    func.emitError(
-        "tfr.tensor_list argument should be before non tensor arguments.");
-    return failure();
-  }
-  if (first_tensor_list != last_tensor_list) {
-    func.emitError("More than one tfr.tensor_list argument isn't allowed.");
-    return failure();
+  // The order between tensor arguments and tensor list arguments, and the
+  // number of tensor list arguments, are verified only when they cannot be
+  // determined from the attributes.
+  if (!undefined_attrs.empty()) {
+    if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
+      func.emitError(
+          "tfr.tensor argument should be before tfr.tensor_list argument.");
+      return failure();
+    }
+    if (first_tensor_list != last_tensor_list) {
+      func.emitError("More than one tfr.tensor_list argument isn't allowed.");
+      return failure();
+    }
   }
 
   // Verify the result order: tensor, tensor list, and also verify at most one
   // tensor list result.
-  bool seen_tensor_list = false;
+  int undefined_input_attrs_number = undefined_attrs.size();
+  bool seen_tensor_list = false, has_tensor_list_order_error = false,
+       has_multiple_tensor_lists_error = false;
   for (auto result_type : func.getType().getResults()) {
     if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
       if (seen_tensor_list) {
-        func.emitError(
-            "tfr.tensor result should be before tfr.tensor_list result.");
-        return failure();
+        has_tensor_list_order_error = true;
+      } else {
+        auto used = tensor.getAttrKeys();
+        output_used_attrs.append(used.begin(), used.end());
       }
-      auto used = tensor.getAttrKeys();
-      used_attrs.append(used.begin(), used.end());
       continue;
     }
 
     if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
       if (seen_tensor_list) {
-        func.emitError("More than one tfr.tensor_list result isn't allowed.");
-        return failure();
+        has_multiple_tensor_lists_error = true;
+      } else {
+        seen_tensor_list = true;
+        auto used = tensor_list.getAttrKeys();
+        output_used_attrs.append(used.begin(), used.end());
       }
-      seen_tensor_list = true;
-      auto used = tensor_list.getAttrKeys();
-      used_attrs.append(used.begin(), used.end());
       continue;
     }
 
@@ -271,13 +293,28 @@
     return failure();
   }
 
-  // Verify that all the used attributes are in the attribute arguments.
-  llvm::SmallVector<StringAttr, 4> undefined_attrs;
-  for (auto attr : used_attrs) {
-    if (!func.getAttr(attr.getValue())) {
+  // Collect all the undefined attributes used in the outputs.
+  for (auto attr : output_used_attrs) {
+    if (!func->getAttr(attr.getValue())) {
       undefined_attrs.push_back(attr);
     }
   }
+
+  // Verify that there is no tensor/tensor list result order error and no
+  // multiple tensor list results error.
+  if (undefined_input_attrs_number != undefined_attrs.size()) {
+    if (has_tensor_list_order_error) {
+      func.emitError(
+          "tfr.tensor result should be before tfr.tensor_list result.");
+      return failure();
+    } else if (has_multiple_tensor_lists_error) {
+      func.emitError("More than one tfr.tensor_list result isn't allowed.");
+      return failure();
+    }
+  }
+
+  // TODO(fengliuai): We might want to refine this constraint because the
+  // tensor element type can be derived.
   if (!undefined_attrs.empty()) {
     llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
     std::transform(undefined_attrs.begin(), undefined_attrs.end(),
@@ -437,6 +474,23 @@
   }
 };
 
+struct RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
+  using OpRewritePattern<GetLengthOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetLengthOp gl_op,
+                                PatternRewriter &rewriter) const override {
+    auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
+        gl_op.tensor_list().getDefiningOp());
+    if (!preceding_build_list) {
+      return failure();
+    }
+    int64_t num_tensors = preceding_build_list.getNumOperands();
+    rewriter.replaceOpWithNewOp<ConstantOp>(gl_op,
+                                            rewriter.getIndexAttr(num_tensors));
+    return success();
+  }
+};
+
 struct BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
   using OpRewritePattern<BuildListOp>::OpRewritePattern;
 
@@ -477,6 +531,11 @@
   results.insert<RemoveRedundantGetElement>(context);
 }
 
+void GetLengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<RemoveRedundantGetLength>(context);
+}
+
 void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
   results.insert<BuildConstantListAsAttr>(context);
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h
index 5145f22..6732198 100644
--- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h
@@ -16,12 +16,13 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
 
+#include "llvm/ADT/StringSet.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/FunctionSupport.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
@@ -39,6 +40,9 @@
 
   static StringRef getDialectNamespace() { return "tfr"; }
 
+  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+                                 Location loc) override;
+
   // Parse a type registered to this dialect.
   Type parseType(DialectAsmParser &parser) const override;
 
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
index 4c2ecc0..9d1e7fb 100644
--- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
@@ -93,7 +93,8 @@
 // all allowed build list input types
 def TFR_allowedBuiltListType : Type<Or<[
     TFR_TensorType.predicate,
-    TF_ElementType.predicate]>, "single tfr.tensor or tensor element type">;
+    TF_ElementType.predicate,
+    TFR_AttrType.predicate]>, "single tfr.tensor or tensor element type">;
 
 // all allowed build list result types
 def TFR_allowedListResultType : Type<Or<[
@@ -349,6 +350,30 @@
   let hasCanonicalizer = 1;
 }
 
+def TFR_GetLengthOp : TFR_Op<"get_length", [NoSideEffect]> {
+  let description = [{
+    The `get_length` operation returns the number of tensors for a
+    tfr.tensor_list.
+
+    Example:
+
+    ```mlir
+    %2 = tfr.get_length(%1) : !tfr.tensor_list -> index
+    %2 = tfr.get_length %1 -> index
+    ```
+  }];
+
+  let arguments = (ins TFR_TensorListType:$tensor_list);
+
+  let results = (outs Index:$out);
+
+  let hasCanonicalizer = 1;
+
+  let assemblyFormat = [{
+    $tensor_list attr-dict `->` type($out)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Function related classes
 //===----------------------------------------------------------------------===//
@@ -411,6 +436,22 @@
     // Hooks for the input/output type enumeration in FunctionLike .
     unsigned getNumFuncArguments() { return getType().getNumInputs(); }
     unsigned getNumFuncResults() { return getType().getNumResults(); }
+
+    // Get the names of all defined attributes, including both derived and
+    // non-derived ones.
+    llvm::StringSet<> getDefinedAttributeNames() {
+      llvm::StringSet<> all_attrs;
+      for (auto& attr : getAttrs()) {
+        all_attrs.insert(attr.first.strref());
+      }
+      for (const auto& operand : llvm::enumerate(getType().getInputs())) {
+        if (auto attr_name = getArgAttrOfType<StringAttr>(
+            operand.index(), kAttrArgumentNameAttr)) {
+          all_attrs.insert(attr_name.getValue());
+        }
+      }
+      return all_attrs;
+    }
   }];
 
   let verifier = [{ return Verify(*this); }];
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h
index 4bda8f3..b27d56e 100644
--- a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h
@@ -17,10 +17,10 @@
 #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
 
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeSupport.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 
diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
index d399a10..3240a3a 100644
--- a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
+++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
@@ -151,9 +151,21 @@
 
 }  // namespace
 
-void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
-                                            MLIRContext *context) {
-  results.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
+void populateCanonicalizationPatterns(FuncOp func,
+                                      OwningRewritePatternList &patterns) {
+  MLIRContext *context = func.getContext();
+  mlir::Dialect *tf = context->getLoadedDialect<mlir::TF::TensorFlowDialect>();
+  // Load all official canonicalization patterns. Here we skip the
+  // canonicalization of the ops in the tf dialect, because they couldn't
+  // propagate the attributes correctly. These optimizations will be performed
+  // by the bridge.
+  func->walk([&](Operation *op) {
+    if (op->getDialect() != tf) {
+      op->getAbstractOperation()->getCanonicalizationPatterns(patterns,
+                                                              context);
+    }
+  });
+  patterns.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
 }
 
 }  // namespace TFR
diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc
index f51e460..c532bc1 100644
--- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc
+++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc
@@ -35,8 +35,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
@@ -99,22 +99,19 @@
 };
 
 void DecomposeTFOpsPass::ApplyCanonicalization() {
+  FuncOp func = getFunction();
   OwningRewritePatternList patterns;
 
-  auto* context = &getContext();
-  for (auto* op : context->getRegisteredOperations()) {
-    op->getCanonicalizationPatterns(patterns, context);
-  }
-  populateSCFOpsCanonicalizationPatterns(patterns, context);
+  populateCanonicalizationPatterns(func, patterns);
 
-  applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  applyPatternsAndFoldGreedily(func, std::move(patterns));
 }
 
 LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
   FuncOp func = getFunction();
   SymbolTable table(external_tfr_module.hasValue()
                         ? *external_tfr_module
-                        : func.getParentOfType<ModuleOp>());
+                        : func->getParentOfType<ModuleOp>());
   OpBuilder builder(func);
   bool changed = false;
   func.walk([&table, &builder, &changed](Operation* op) {
@@ -122,7 +119,7 @@
     // either will be constant folded or lowered by the rules defined in the
     // bridge.
     if (op->isRegistered()) {
-      return;
+      return WalkResult::advance();
     }
 
     // Find out the compose function
@@ -130,7 +127,17 @@
     auto compose_func = table.lookup<TFRFuncOp>(compose_func_name);
     if (!compose_func || compose_func.isExternal()) {
       // There are no decomposition methods defined for this op, skip.
-      return;
+      return WalkResult::advance();
+    }
+
+    // Make sure all the attributes are valid. An attribute is valid when it is
+    // in the signature or is explicitly allowed.
+    auto compose_func_signature =
+        table.lookup<TFRFuncOp>(compose_func_name + "_");
+    if (!compose_func_signature) compose_func_signature = compose_func;
+    auto defined_attrs = compose_func_signature.getDefinedAttributeNames();
+    if (failed(ValidateAttrs(op, defined_attrs))) {
+      return WalkResult::interrupt();
     }
 
     tensorflow::IncreaseOpExpansionExecuteCounterByOne(
@@ -215,8 +222,15 @@
           op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
       std::get<0>(res).replaceAllUsesWith(casted.out());
     }
+
+    // Copy all the unregistered attributes to the new op.
+    if (failed(CopyAllowedUnregisteredAttrs(op, new_op, defined_attrs))) {
+      return WalkResult::interrupt();
+    }
+
     op->erase();
     changed |= true;
+    return WalkResult::advance();
   });
 
   // If `changed` is false, it is considered as a failure, so the recursive
@@ -230,13 +244,22 @@
   FuncOp func = getFunction();
   SymbolTable table(external_tfr_module.hasValue()
                         ? *external_tfr_module
-                        : func.getParentOfType<ModuleOp>());
+                        : func->getParentOfType<ModuleOp>());
 
   // The inliner only inlines the TFR call op.
   bool changed = false;
   auto walk_result = func.walk([&](CallOp call_op) {
     auto callee = table.lookup<TFRFuncOp>(call_op.callee());
     if (!callee || callee.isExternal()) return WalkResult::advance();
+
+    // Record the boundary of the inlined operations. The inlined operations
+    // will be inserted between these two operations.
+    Operation* inlined_point = call_op.getOperation();
+    Operation* after_inlined_point =
+        &*std::next(Block::iterator(call_op.getOperation()));
+
+    // Use the inliner to replace all the uses of the call_op by its
+    // composition.
     if (failed(inlineCall(inliner,
                           cast<CallOpInterface>(call_op.getOperation()),
                           cast<CallableOpInterface>(callee.getOperation()),
@@ -246,6 +269,13 @@
       // This call will be raised to TF ops.
       return WalkResult::interrupt();
     }
+
+    // Propagate all the attributes to the inlined operations, which are
+    // delimited by the two boundary operations.
+    PropagateAttrsToOperations(call_op, Block::iterator(inlined_point),
+                               Block::iterator(after_inlined_point));
+
+    // Remove the call_op to finish the op expansion.
     call_op.erase();
     changed |= true;
     return WalkResult::advance();
diff --git a/tensorflow/compiler/mlir/tfr/passes/passes.h b/tensorflow/compiler/mlir/tfr/passes/passes.h
index bb227e9..8914cba 100644
--- a/tensorflow/compiler/mlir/tfr/passes/passes.h
+++ b/tensorflow/compiler/mlir/tfr/passes/passes.h
@@ -25,8 +25,10 @@
 namespace mlir {
 namespace TFR {
 
-void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
-                                            MLIRContext *context);
+// Scans the func op and adds the canonicalization patterns of all the ops
+// inside the function, except the tf dialect ops.
+void populateCanonicalizationPatterns(FuncOp func,
+                                      OwningRewritePatternList &patterns);
 
 // Decompose ops.
 std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
index b3e1983..d3780a4 100644
--- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
+++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
@@ -35,11 +35,11 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -377,6 +377,10 @@
       new_results.push_back(list_op.out());
     }
   }
+
+  // Copy all the allowed attributes to the new op.
+  if (failed(CopyNonSymbolRefAttrs(call_op, new_op))) return failure();
+
   rewriter.replaceOp(call_op, new_results);
   return success();
 }
@@ -446,13 +450,12 @@
   MLIRContext* ctx = &getContext();
   SymbolTable table(external_tfr_module.hasValue()
                         ? *external_tfr_module
-                        : func.getParentOfType<ModuleOp>());
+                        : func->getParentOfType<ModuleOp>());
 
   OwningRewritePatternList patterns;
   patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
-  for (auto* op : ctx->getRegisteredOperations()) {
-    op->getCanonicalizationPatterns(patterns, ctx);
-  }
+
+  populateCanonicalizationPatterns(func, patterns);
 
   applyPatternsAndFoldGreedily(func, std::move(patterns));
 }
diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py
index 09323b7..a5e275a 100644
--- a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py
+++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py
@@ -26,7 +26,6 @@
 import os
 import re
 import types
-from typing import List, Tuple
 import gast as ast
 
 from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
@@ -43,11 +42,14 @@
 from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
 from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
 from tensorflow.python.autograph.pyct.static_analysis import type_inference
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import load_library
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_inspect
 
+# TODO(mdan): Use class definitions so that we can mix these with Python types.
+
 
 class TFRTypes(enum.Enum):
   """All the supported types.
@@ -315,6 +317,13 @@
     float: TFRTypes.F32,
 }
 
+_TF_DTYPE_TO_TFR = {
+    'bool': TFRTypes.I1,
+    'int64': TFRTypes.I64,
+    'int32': TFRTypes.I32,
+    'float32': TFRTypes.F32,
+}
+
 _AG_FIXED_RETURN_TYPE = {
     'for_stmt': type(None),
     'if_stmt': type(None),
@@ -379,6 +388,9 @@
     if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
       return {types.ModuleType}
     if hasattr(value, '__module__'):
+      if isinstance(value, dtypes.DType):
+        return {TFRTypes.ATTR}
+
       # All the imported operations, which are not autograph built-ins, are
       # considered to be TF raw ops.
       # TODO(fengliuai): refine the condition so we only match TensorFlow
@@ -410,7 +422,7 @@
 
         iterated_type = args[0]
         assert iterated_type & {
-            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int]
+            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
         }, (
             iterated_type)
         self._for_loop_target_types[body_fn_name] = iterated_type
@@ -443,7 +455,7 @@
     elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
       assert name.is_simple()
       if name == QN('range'):
-        return {List[int]}, None
+        return {TFRTypes.ATTR}, None
 
       if name == QN('len'):
         return {TFRTypes.INDEX}, None
@@ -459,7 +471,7 @@
       if f_name_str in self._for_loop_target_types:
         # See autograph/converters/control_flow.py - the function has a single
         # argument, the iterate before any expansion.
-        assert self._for_loop_target_types[f_name_str] & {List[int]}
+        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
         # Assume all loops are TF loops. Then the iterates are autoboxed into
         # Tensors.
         return {TFRTypes.INDEX}
@@ -488,7 +500,7 @@
 
     raise ValueError('Argument is not defined in OpDef: ' + str(name))
 
-  def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
+  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
     assert len(value) == 1
     value, = tuple(value)
     if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
@@ -503,10 +515,40 @@
     # TODO(fengliuai): make sure left and right are compatible
     return {TFRTypes.I1}
 
+  def res_unop(self, ns, types_ns, node, opnd):
+    return opnd
+
   def res_binop(self, ns, types_ns, node, left, right):
     # TODO(fengliuai): make sure left and right are compatible
     return left
 
+  def _coerce_to_more_specific_type(self, elt_types):
+    # TODO(mdan): This needs some type theory study.
+    if TFRTypes.INDEX in elt_types:
+      # Constants collapse to indices.
+      elt_types.discard(TFRTypes.I64)
+    if TFRTypes.TENSOR in elt_types:
+      # Constants collapse to tensors.
+      elt_types.discard(TFRTypes.I64)
+      # Indices collapse to tensors.
+      elt_types.discard(TFRTypes.INDEX)
+    return elt_types
+
+  def res_list_literal(self, ns, elt_types):
+    all_elt_types = set()
+    for t in elt_types:
+      all_elt_types |= t
+
+    if len(all_elt_types) != 1:
+      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)
+
+    if len(all_elt_types) != 1:
+      raise ValueError('ambiguous list element types: {}'.format(elt_types))
+
+    if TFRTypes.TENSOR in all_elt_types:
+      return {TFRTypes.TENSOR_LIST}
+    return {TFRTypes.ATTR}
+
 
 class SymbolTable(object):
   """Symbol Table for python code."""
@@ -599,22 +641,6 @@
           node, types_))
 
     type_, = types_
-    # TODO(fengliuai): Tuple is added here to make return tuple work.
-    if type_ is list or type_ is Tuple:
-      # TODO(fengliuai): Seems like we need to move the followed list handling
-      # to the type inference and we shouldn't just put 'list' there. Otherwise
-      # we couldn't find out the right type for the Name node.
-      if not isinstance(node, ast.List):
-        return default
-      all_types = [
-          anno.getanno(elt, anno.Static.TYPES, None) for elt in node.elts
-      ]
-      if (TFRTypes.TENSOR,) in all_types:
-        # For the elt which is not tfr.tensor, tfr.constant_tensor needs to be
-        # use to cast it to a tfr.tensor.
-        return TFRTypes.TENSOR_LIST
-      else:
-        return TFRTypes.ATTR
 
     if default is not None and type_ != default:
       print('WARN: type annotation {}({}) does not match {}({})'.format(
@@ -643,6 +669,15 @@
     else:
       return value, ty
 
+  def _i64_to_index(self, value, ty):
+    if ty == TFRTypes.I64:
+      casted = self._ssa_name('casted')
+      self._emit_with_loc('\n{} = index_cast {} : i64 to index'.format(
+          casted, value))
+      return casted, TFRTypes.INDEX
+    else:
+      return value, ty
+
   def _value_to_tensor(self, value, ty, node):
     value, ty = self._index_to_I64(value, ty)
     cst_tensor = self._ssa_name('cst')
@@ -680,6 +715,13 @@
         # This branch is used when it is inside tensorflow
         return (node.attr, TFRTypes.TF_RAW_OP)
 
+      if node_type == TFRTypes.ATTR:
+        attr = self._ssa_name('attr')
+        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
+        self._emit_with_loc(
+            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type), node)
+        return (attr, TFRTypes.ATTR)
+
       value, _ = self.visit(node.value)
       tensor_type = self._get_inferred_type(node.value, None)
       # TODO(fengliuai): use node_type once it
@@ -695,7 +737,6 @@
     if isinstance(node.value, ast.Attribute):
       if isinstance(node.value.value, ast.Name):
         if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
-          # This branch is used when it is outside tensorflow
           return (node.attr, TFRTypes.TF_RAW_OP)
 
       value, ty = self.visit(node.value)
@@ -717,13 +758,24 @@
       raise NotImplementedError('Assignment target type not recognized.')
 
     if isinstance(values, list):
+      if isinstance(node.value, ast.Call):
+        expected = tuple(t for n, t in values)
+        if len(values) == 1:
+          expected = expected[0]
+      elif isinstance(node.value, ast.Tuple):
+        expected = tuple(t for n, t in values)
+      else:
+        raise ValueError('unknown assignment target node', node.value)
+      ty = self._get_inferred_type(node.value, expected)
+
       if len(targets) == len(values):
-        for key, value in zip(targets, values):
-          ssa_value, ty_ = value
-          ty = self._get_inferred_type(node.value, ty_)
-          self.symbol_table.insert_symbol(key, ssa_value, ty)
+        # TODO(mdan): This should already be a tuple.
+        ty_ = (ty,) if len(values) == 1 else ty
+        for key, value, t in zip(targets, values, ty_):
+          ssa_value, _ = value
+          self.symbol_table.insert_symbol(key, ssa_value, t)
       elif len(values) == 1:
-        n, ty = values[0]
+        n, _ = values[0]
         assert ty == TFRTypes.TENSOR_LIST
         # assign a tensor_list to multiple variables
         for idx, key in enumerate(targets):
@@ -738,10 +790,11 @@
           self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
       elif len(targets) == 1:
         ssa_names = [n for n, _ in values]
-        tys = [t for _, t in values]
-        self.symbol_table.insert_symbol(targets[0], ssa_names, tys)
-    else:
-      self.symbol_table.insert_symbol(targets[0], values[0], values[1])
+        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
+      return
+
+    ty = self._get_inferred_type(node.value, values[1])
+    self.symbol_table.insert_symbol(targets[0], values[0], ty)
 
   def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
     assert lhs_ty, rhs_ty
@@ -786,7 +839,7 @@
 
   def visit_Call(self, node):
     func_name, func_type = self.visit(node.func)
-    _ = self._get_inferred_type(node.func, func_type)
+    func_type = self._get_inferred_type(node.func, func_type)
     if func_type == TFRTypes.AG_BUILTIN_FUNC:
       if func_name == 'if_stmt':
         cond, _ = self.visit(node.args[0])
@@ -828,15 +881,19 @@
       if func_name == 'len':
         arg, ty = self.visit(node.args[0])
         ty = self._get_inferred_type(node.args[0], ty)
-        assert ty == TFRTypes.TF_TENSOR_SHAPE_LIST, ty
-        len_value = self._ssa_name('len')
-        self._emit_with_loc(
-            '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
-                len_value, arg), node)
-        size_value = self._ssa_name('len_size')
-        self._emit_with_loc(
-            '\n{} = shape.size_to_index {} : !shape.size'.format(
-                size_value, len_value), node)
+        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
+          len_value = self._ssa_name('len')
+          self._emit_with_loc(
+              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
+                  len_value, arg), node)
+          size_value = self._ssa_name('len_size')
+          self._emit_with_loc(
+              '\n{} = shape.size_to_index {} : !shape.size'.format(
+                  size_value, len_value), node)
+        elif ty == TFRTypes.TENSOR_LIST:
+          size_value = self._ssa_name('len')
+          self._emit_with_loc(
+              '\n{} = tfr.get_length {} -> index'.format(size_value, arg), node)
         return (size_value, TFRTypes.INDEX)
 
     raise NotImplementedError('call operator not recognized: {} {}'.format(
@@ -845,7 +902,7 @@
   def visit_Compare(self, node):
     lhs, lhs_ty = self.visit(node.left)
     for op, right in zip(node.ops, node.comparators):
-      rhs, _ = self.visit(right)
+      rhs, rhs_ty = self.visit(right)
       if isinstance(op, ast.Eq):
         pred = 'eq'
       elif isinstance(op, ast.Lt):
@@ -870,6 +927,10 @@
           code = 'cmpi'
         elif lhs_ty == TFRTypes.F32:
           code = 'cmpf'
+        elif lhs_ty == TFRTypes.INDEX:
+          code = 'cmpi'
+          # TODO(fengliuai): the reverse type inference should solve the issue.
+          rhs, _ = self._i64_to_index(rhs, rhs_ty)
         else:
           raise NotImplementedError('Compare operand type not recognized')
         self._emit_with_loc(
@@ -1268,6 +1329,7 @@
     tys = []
     for elt in node.elts:
       val, ty = self.visit(elt)
+      ty = self._get_inferred_type(elt, ty)
       if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
+        # This list is a tensor list, so cast all the input values to tensors.
         val, ty = self._value_to_tensor(val, ty, node)
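
The branches added above emit a few new ops during tracing: an i64-to-index cast, a tensor-list length query, and a dtype attribute constant. As a rough, hand-written illustration (SSA names are hypothetical, not actual generator output), the emitted IR looks like:

  %casted = index_cast %n : i64 to index
  %len = tfr.get_length %ins -> index
  %attr = tfr.constant i64 -> !tfr.attr
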
diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py
index 704f2e5..b68fac9 100644
--- a/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py
+++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py
@@ -28,6 +28,7 @@
 from tensorflow.compiler.mlir.tfr.python import composite
 from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module as tfr_gen
 from tensorflow.compiler.mlir.tfr.resources import gen_test_ops as test_ops
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_array_ops as array_ops
 from tensorflow.python.ops import gen_math_ops as math_ops
 from tensorflow.python.platform import test
@@ -126,6 +127,15 @@
   return x_sum
 
 
+@composite.Composite('TestInputNOp')
+def _tfr_control_flow_tensor_list_size(ins):
+  n = len(ins)
+  if n == 0:
+    return array_ops.Const(value=[[0, 1], [2, 3]], dtype=dtypes.int64)
+  else:
+    return math_ops.AddN(ins)
+
+
 #--- test fn for tf ops ---
 
 
@@ -403,6 +413,10 @@
       CHECK-NEXT:   %{{.*}} = constant true
       CHECK-NEXT:   tfr.return %[[for_stmt]] : !tfr.tensor
       CHECK-NEXT: }
+
+      CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) {
+      CHECK: %[[attr:.*]] = tfr.constant i64 -> !tfr.attr loc("tfr_gen_test.py":134:57)
+      CHECK: %Const = tfr.call @tf__const(%{{.*}}, %[[attr]]) : (!tfr.attr, !tfr.attr) -> (!tfr.tensor)
     """
     self._check_code(mlir_code, mlir_code_exp)
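
For reference, a hand-written sketch (not the exact generator output) of the tfr.func that the TestInputNOp composite above is expected to lower to, combining the new tfr.get_length and tfr.constant attribute paths:

  tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> !tfr.tensor {
    %n = tfr.get_length %ins -> index
    %c0 = constant 0 : index
    %eq = cmpi "eq", %n, %c0 : index
    // ... if_stmt dispatching to the tf__const / tf__add_n branches ...
  }
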
 
diff --git a/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir b/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
index f67d24c..55e1f2c 100644
--- a/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
+++ b/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
@@ -28,6 +28,8 @@
   tfr.return %res : !tfr.tensor
 }
 
+tfr.func @tf__my_add_n_(!tfr.tensor_list<N,T>, i64 {tfr.name="N"}) -> !tfr.tensor attributes {N,T}
+
 // Translated from tf.compose Python function.
 tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
                               %bias: !tfr.tensor,
@@ -55,6 +57,9 @@
   tfr.return %res : !tfr.tensor
 }
 
+tfr.func @tf__my_biased_dense_(!tfr.tensor<T>, !tfr.tensor<T>, !tfr.tensor<T>,
+    !tfr.attr{tfr.name="act", tfr.default=""}) -> !tfr.tensor attributes {T}
+
 // This is a wrong decomposition and is used to verify that tf.Elu isn't decomposed
 // since its kernel has been registered.
 tfr.func @tf__elu_(%input: !tfr.tensor) -> !tfr.tensor {
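
The two declarations added to this library are signature-only entries; decompositions reference them by name through tfr.call. A minimal, hypothetical call site (operand names invented for illustration):

  %n = constant 2 : i64
  %sum = tfr.call @tf__my_add_n_(%values, %n) : (!tfr.tensor_list, i64) -> !tfr.tensor
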
diff --git a/tensorflow/compiler/mlir/tfr/tests/decompose.mlir b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir
index 97f12c9..83361a0 100644
--- a/tensorflow/compiler/mlir/tfr/tests/decompose.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir
@@ -82,3 +82,43 @@
 // CHECK-NEXT: return %[[back]] : tensor<f32>
 }
 
+// CHECK-LABEL: attribute_propagate_direct
+func @attribute_propagate_direct(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
+  %0 = "tf.Intermediate"(%arg0) {_tpu_replicate, device="hello"} : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string>
+  return %0 : tensor<1x2x3x4x!tf.string>
+
+// CHECK-NEXT: %[[casted:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
+// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%[[casted]]) {_tpu_replicate, device = "hello"}
+// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id]]) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
+// CHECK-NEXT: return %[[back]]
+}
+
+// CHECK-LABEL: attribute_propagate
+func @attribute_propagate(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  %0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index, _tpu_replicate, device="hello"} : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
+  return %0#1 : tensor<f32>
+
+// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
+// CHECK-NEXT: %[[in1:.*]] = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
+// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) {_tpu_replicate, device = "hello"}
+// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__risc(%[[in1]]) {_tpu_replicate, device = "hello"}
+// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id1]]) : (!tfr.tensor) -> tensor<f32>
+// CHECK-NEXT: return %[[back]] : tensor<f32>
+}
+
+// CHECK-LABEL: no_tf_canonicalization
+func @no_tf_canonicalization(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
+  return %0: tensor<8x3xf32>
+
+// CHECK:   "tf.Select"
+}
+
+// CHECK-LABEL: denied_attribute
+func @denied_attribute(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  // expected-error@+1 {{Denied unregistered attribute was found: denied_attr}}
+  %0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index, denied_attr} : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
+  return %0#1 : tensor<f32>
+
+// CHECK-NEXT:   "tf.FusedN"(%arg0, %arg1, %arg2) {A = 0 : index, denied_attr}
+}
diff --git a/tensorflow/compiler/mlir/tfr/tests/ops.mlir b/tensorflow/compiler/mlir/tfr/tests/ops.mlir
index 0b35bb2..0440338 100644
--- a/tensorflow/compiler/mlir/tfr/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/ops.mlir
@@ -260,6 +260,35 @@
 
 // -----
 
+// CHECK-LABEL: build_high_dim_const_list
+// CANON-LABEL: build_high_dim_const_list
+func @build_high_dim_const_list() -> !tfr.attr {
+  %0 = "std.constant"() {value = 42 : i32} : () -> i32
+  %1 = "std.constant"() {value = 41 : i32} : () -> i32
+  %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
+  %3 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
+  %4 = "tfr.build_list"(%2, %3) : (!tfr.attr, !tfr.attr) -> !tfr.attr
+  return %4 : !tfr.attr
+
+// CANON-NEXT: %[[c:.*]] = tfr.constant {{\[}}[42 : i32, 41 : i32], [42 : i32, 41 : i32]] -> !tfr.attr
+// CANON-NEXT: return %[[c]] : !tfr.attr
+}
+
+// -----
+
+// CHECK-LABEL: get_length
+// CANON-LABEL: get_length
+func @get_length(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> index {
+  %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
+  %1 = "tfr.get_length"(%0) : (!tfr.tensor_list) -> index
+  return %1 : index
+
+// CANON-NEXT: %[[c:.*]] = constant 2 : index
+// CANON-NEXT: return %[[c]] : index
+}
+
+// -----
+
 // CHECK-LABEL: tfr.func
 tfr.func @External(%arg0: !tfr.tensor<A>,
               %arg1: !tfr.tensor_list<C>,
@@ -315,7 +344,7 @@
 
 // -----
 
-// expected-error@+1 {{tfr.tensor_list argument should be before non tensor arguments}}
+// expected-error@+1 {{tfr.tensor/tfr.tensor_list argument should be before non tensor arguments}}
 tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor<A>,
               %arg2: i32 {tfr.name = "A"},
               %arg1: !tfr.tensor_list<A>,
@@ -326,14 +355,25 @@
 
 // -----
 
+tfr.func @Foo_valid_arg_order0(
+              %arg1: !tfr.tensor_list,
+              %arg0: !tfr.tensor<T>,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor, !tfr.tensor_list) attributes {T} {
+  tfr.return %arg0, %arg1 : !tfr.tensor<T>, !tfr.tensor_list
+}
+
+// -----
+
 // expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}}
 tfr.func @Foo_invalid_arg_order0(
               %arg1: !tfr.tensor_list,
-              %arg0: !tfr.tensor,
+              %arg0: !tfr.tensor<T>,
               %arg2: i32 {tfr.name = "A"},
               %arg3: vector<1xi32> {tfr.name = "C"}) ->
     (!tfr.tensor, !tfr.tensor_list) {
-  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+  tfr.return %arg0, %arg1 : !tfr.tensor<T>, !tfr.tensor_list
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
index 41d0ee6..a54cc8c 100644
--- a/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
@@ -74,3 +74,16 @@
 // CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf.shape<>} : (tensor<*xi32>) -> tensor<i32>
 // CHECK: return %[[es]] : tensor<i32>
 }
+
+// CHECK-LABEL: attribute_propagate
+func @attribute_propagate(%arg0: tensor<f32>) -> tensor<i32> {
+  %0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
+  %t = tfr.constant i32 -> !tfr.attr
+  %concat = tfr.call @tf__risc_cast(%0, %t) {device = "hello", _tpu_replicate} : (!tfr.tensor, !tfr.attr) -> !tfr.tensor
+  %4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<i32>
+  return %4 : tensor<i32>
+
+// CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32, _tpu_replicate, device = "hello"} : (tensor<f32>) -> tensor<*xi32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf.shape<>} : (tensor<*xi32>) -> tensor<i32>
+// CHECK: return %[[es]] : tensor<i32>
+}
diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc
index 6c08b68..2dec560 100644
--- a/tensorflow/compiler/mlir/tfr/utils/utils.cc
+++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc
@@ -15,11 +15,58 @@
 
 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
-#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
 
 namespace mlir {
 namespace TFR {
+namespace {
+
+// TODO(b/174692018): Use the official allowlist of the unregistered attrs.
+const llvm::StringSet<>& GetAllowedAttributes() {
+  static auto* const ops = new llvm::StringSet<>({"device", "_tpu_replicate"});
+  return *ops;
+}
+
+void CollectAllowedAttrs(CallOp src, NamedAttrList* attrs) {
+  for (auto& attr : src->getAttrs()) {
+    if (GetAllowedAttributes().contains(attr.first.strref())) {
+      attrs->append(attr);
+    }
+  }
+}
+
+// Adds `attrs` to all the operations between `begin` and `end` in the same
+// block. Does not include `end`.
+void AddAttributesInSameBlock(Block::iterator begin, Block::iterator end,
+                              const NamedAttrList& attrs) {
+  for (Block::iterator it = begin; it != end; ++it) {
+    for (auto& attr : attrs) {
+      it->setAttr(attr.first, attr.second);
+    }
+  }
+}
+
+// Adds `attrs` to all the operations between `begin` and `end`. Does not
+// include `end`. The operations may span multiple blocks.
+void AddAttributes(Block::iterator begin, Block::iterator end,
+                   const NamedAttrList& attrs) {
+  if (begin->getBlock() == end->getBlock()) {
+    AddAttributesInSameBlock(begin, end, attrs);
+  } else {
+    Region::iterator begin_block = Region::iterator(begin->getBlock());
+    Region::iterator end_block = Region::iterator(end->getBlock());
+    AddAttributesInSameBlock(begin, begin_block->end(), attrs);
+    for (Region::iterator it = ++begin_block; it != end_block; ++it) {
+      AddAttributesInSameBlock(it->begin(), it->end(), attrs);
+    }
+  }
+}
+
+}  // namespace
 
 std::string GetComposeFuncName(StringRef tf_op_name) {
   std::string compose_func_name;
@@ -74,5 +121,59 @@
   return tf_op_name;
 }
 
+LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered) {
+  for (auto& attr : src->getAttrs()) {
+    StringRef attr_name = attr.first.strref();
+    if (!registered.contains(attr_name) &&
+        !GetAllowedAttributes().contains(attr_name)) {
+      src->emitError("Denied unregistered attribute was found: " + attr_name);
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
+                                           const StringSet<>& registered) {
+  for (auto& attr : src->getAttrs()) {
+    StringRef attr_name = attr.first.strref();
+    // Skip the registered attribute.
+    if (registered.contains(attr_name)) continue;
+
+    // Unregistered attribute.
+    if (GetAllowedAttributes().contains(attr_name)) {
+      dst->setAttr(attr.first, attr.second);
+    } else {
+      src->emitError("Denied unregistered attribute was found: " + attr_name);
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst) {
+  NamedAttrList attrs;
+  CollectAllowedAttrs(src, &attrs);
+
+  for (auto& attr : attrs) {
+    dst->setAttr(attr.first, attr.second);
+  }
+
+  return success();
+}
+
+void PropagateAttrsToOperations(CallOp src, Block::iterator begin,
+                                Block::iterator end) {
+  // Find all the attributes in the call op. These attributes are not in the
+  // op definition, so they need to be propagated to all the target ops.
+  NamedAttrList attrs;
+  CollectAllowedAttrs(src, &attrs);
+
+  // Add all the attributes to the operations in the range.
+  if (!attrs.empty()) {
+    AddAttributes(begin, end, attrs);
+  }
+}
+
 }  // namespace TFR
 }  // namespace mlir
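
In IR terms, the propagation helpers copy the allowed unregistered attributes (currently "device" and "_tpu_replicate") from a tfr.call onto the ops that replace it. Schematically, with hypothetical op names:

  // Call carrying unregistered attributes:
  %0 = tfr.call @tf__my_op(%arg) {device = "hello", _tpu_replicate} : (!tfr.tensor) -> !tfr.tensor

  // After decomposition and propagation, each resulting op carries the same attributes:
  %1 = tfr.call @tf__risc(%arg) {_tpu_replicate, device = "hello"} : (!tfr.tensor) -> !tfr.tensor
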
diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.h b/tensorflow/compiler/mlir/tfr/utils/utils.h
index 26c7250..f910981 100644
--- a/tensorflow/compiler/mlir/tfr/utils/utils.h
+++ b/tensorflow/compiler/mlir/tfr/utils/utils.h
@@ -16,9 +16,12 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
 
-#include <string>
-
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
 
 namespace mlir {
 namespace TFR {
@@ -36,6 +39,25 @@
 //   tf__concat_v2 => tf.ConcatV2
 std::string GetTFOpName(StringRef compose_func_name);
 
+// Validates that each attribute of 'src' is either contained in the registered
+// attribute set or in the allowed list.
+LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered);
+
+// Copies the allowed unregistered attributes of 'src' to 'dst', skipping the
+// registered ones. Returns a failure if any attribute is neither registered
+// nor allowed.
+LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
+                                           const StringSet<>& registered);
+
+// Copies all the allowed attributes in 'src' to 'dst'. FlatSymbolRefAttr is
+// excluded.
+LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst);
+
+// Propagates all the attributes in 'src' to the operations between 'begin' and
+// 'end'. Operation 'end' is excluded.
+void PropagateAttrsToOperations(CallOp src, Block::iterator begin,
+                                Block::iterator end);
+
 }  // namespace TFR
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index 4cfb216..210a4ee 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -12,6 +12,7 @@
     "@local_config_rocm//rocm:build_defs.bzl",
     "if_rocm_is_configured",
 )
+load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available")
 
 package(
     default_visibility = [":friends"],
@@ -55,8 +56,6 @@
         "//tensorflow/compiler/xla/service/gpu:stream_executor_util",
         "//tensorflow/compiler/xla/service/gpu:target_constants",
         "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
-        "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
-        "//tensorflow/compiler/xla/service/mlir_gpu:passes",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:cuda_libdevice_path",
         "@llvm-project//llvm:Support",
@@ -87,29 +86,6 @@
 )
 
 tf_cc_binary(
-    name = "tf_to_gpu_binary",
-    srcs = [
-        "crash_handler.h",
-        "tf_to_gpu_binary.cc",
-    ],
-    visibility = [
-        "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary:__pkg__",
-        "//tensorflow/core/kernels/mlir_generated:__pkg__",
-    ],
-    deps = [
-        ":kernel_creator",
-        "//tensorflow/compiler/mlir:init_mlir",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/platform",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Pass",
-    ],
-)
-
-tf_cc_binary(
     name = "tf_to_kernel",
     srcs = ["tf_to_kernel.cc"],
     visibility = [
@@ -125,8 +101,10 @@
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Analysis",
+        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
         "@llvm-project//llvm:CodeGen",
         "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:Target",
         "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
@@ -134,7 +112,11 @@
         "@llvm-project//mlir:ExecutionEngineUtils",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:TargetLLVMIR",
-    ],
+    ] + if_llvm_system_z_available([
+        "@llvm-project//llvm:SystemZCodeGen",  # fixdeps: keep
+    ]) + if_llvm_aarch64_available([
+        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
+    ]),
 )
 
 tf_cc_binary(
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h
index 30cefaa..c8f8439 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h
@@ -19,11 +19,11 @@
 #define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_
 
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.h.inc"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index 302c7e9..91d1c7c 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
@@ -55,8 +56,6 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/path.h"
@@ -89,67 +88,53 @@
 };
 }  // end anonymous namespace
 
-Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
-                    llvm::ArrayRef<uint32_t> tile_sizes,
+Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
                     llvm::ArrayRef<uint32_t> unroll_factors,
                     bool embed_memref_prints) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
 
-  if (gpu_binary_only) {
-    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
-        /*allow_partial_conversion=*/false, /*legalize_chlo=*/true));
-    pm.addNestedPass<mlir::FuncOp>(
-        mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass());
-    pm.addNestedPass<mlir::FuncOp>(
-        mlir::kernel_gen::transforms::CreateUnfuseBatchNormPass());
-  } else {
-    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
-        /*allow_partial_conversion=*/false, /*legalize_chlo=*/false));
-    pm.addNestedPass<mlir::FuncOp>(mlir::createTransformUnrankedHloPass());
-    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());
-    pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  }
+  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
+      /*allow_partial_conversion=*/false, /*legalize_chlo=*/false));
+  pm.addNestedPass<mlir::FuncOp>(mlir::createTransformUnrankedHloPass());
+  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());
+  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
 
-  // Partial bufferization: Transforms inparticular HLO operation to their
-  // corresponding LHLO operations and converts the function signature. Leaves
-  // shape operations untouched.
-  pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass());
-  // Run CSE to ensure that loads and stores to the same location get recognized
-  // as such.
-  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
-  // Forward stores to buffers to loads.
-  pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
-
-  // Clean up the IR for further processing.
+  // Transform HLO operations to LinAlg.
+  pm.addNestedPass<mlir::FuncOp>(::mlir::mhlo::createLegalizeHloToLinalgPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
   // We have to anticipate later unrolling in tiling to make sure that we get
   // the requested tiling after unrolling. Compute the new tiling here if
   // needed.
-  llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
-  llvm::SmallVector<int64_t, 4> as_int64;
+  llvm::SmallVector<int64_t, 4> tiling_for_unrolling, inner_tile;
   tiling_for_unrolling.reserve(tile_sizes.size());
   for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
     tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
-    as_int64.push_back(std::get<1>(pair));
+    inner_tile.push_back(std::get<1>(pair));
   }
   tiling_for_unrolling.append(
       tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end());
-  // Transform LHLO operations to LinAlg.
-  pm.addNestedPass<mlir::FuncOp>(
-      ::mlir::lmhlo::createLegalizeLhloToLinalgPass());
-  if (!gpu_binary_only) {
-    // Find candidates for buffer reuse. This is only successful if buffer size
-    // equality can be determined based on `linalg.generic` operations.
-    pm.addNestedPass<mlir::FuncOp>(
-        mlir::kernel_gen::transforms::CreateBufferReusePass());
-  }
   // Fuse linalg operations.
-  pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
-      /*use_parallel_loops=*/true, tiling_for_unrolling));
-  // Transform the Linalg operations inside of the loop nest into parallel
-  // loops.
+  pm.addNestedPass<mlir::FuncOp>(mlir::createLinalgFusionOfTensorOpsPass());
+
+  // Partial bufferization: Transforms, in particular, HLO and Linalg operations to
+  // their corresponding LHLO operations and converts the function signature.
+  // Leaves shape operations untouched.
+  //
+  // TODO(pifon): Rename the pass to CreateHloLinalgBufferizePass or bufferize
+  // in 2 steps: first Linalg, then Hlo. That would need refactoring of
+  // BufferizeTypeConverter.
+  pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass());
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Find candidates for buffer reuse. This is only successful if buffer size
+  // equality can be determined based on `linalg.generic` operations.
+  pm.addNestedPass<mlir::FuncOp>(
+      mlir::kernel_gen::transforms::CreateBufferReusePass());
+  pm.addNestedPass<mlir::FuncOp>(
+      mlir::createLinalgTilingToParallelLoopsPass(tiling_for_unrolling));
+  // Transform the Linalg ops inside of the loop nest into parallel loops.
   pm.addNestedPass<mlir::FuncOp>(
       ::mlir::createConvertLinalgToParallelLoopsPass());
   // Canonicalize the code to simplify index computations. This is needed so
@@ -158,24 +143,20 @@
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
   // Fuse the inner-most loops.
   pm.addNestedPass<mlir::FuncOp>(
-      xla::mlir_gpu::createFuseInnerParallelLoopsPass());
+      mlir::kernel_gen::transforms::CreateFuseInnerParallelLoopsPass());
   // Run CSE to ensure that loads and stores to the same subview get
   // recognized as such.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
-  // Forward stores to buffers to loads.
-  pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
-  // Remove now unused temporary buffers.
-  pm.addNestedPass<mlir::FuncOp>(
-      xla::mlir_gpu::createDeadTempBufferRemovalPass());
   if (!unroll_factors.empty()) {
     pm.addNestedPass<mlir::FuncOp>(
-        ::mlir::createParallelLoopTilingPass(as_int64));
+        ::mlir::createParallelLoopTilingPass(inner_tile));
   }
   // Some basic cleanup.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
   // Greedily map the remaining loop to GPU hardware dimensions.
-  pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
+  pm.addNestedPass<::mlir::FuncOp>(
+      mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
 
   // Now lower the shape computations, bufferize all remaining ops and insert
   // deallocs.
@@ -195,15 +176,13 @@
       std::make_unique<RemoveUnusedTensorToMemrefOperations>());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
-  if (!gpu_binary_only) {
-    // Before inserting more allocs, map the ones we already have to the
-    // tf runtime. That ensures that all allocations for the actual computation
-    // end up on the device, whereas allocations for shape computation and host
-    // side things remain on the host.
-    // Longer term, this should be handled by proper device placement.
-    pm.addPass(mlir::kernel_gen::tf_framework::
-                   CreateEmbedTFFrameworkFunctionAndAllocPass());
-  }
+  // Before inserting more allocs, map the ones we already have to the
+  // tf runtime. That ensures that all allocations for the actual computation
+  // end up on the device, whereas allocations for shape computation and host
+  // side things remain on the host.
+  // Longer term, this should be handled by proper device placement.
+  pm.addPass(mlir::kernel_gen::tf_framework::
+                 CreateEmbedTFFrameworkFunctionAndAllocPass());
   pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64));
   // TODO(herhut): Depends on https://bugs.llvm.org/show_bug.cgi?id=48385.
@@ -230,11 +209,6 @@
   // Take launches to launches with kernels.
   pm.addPass(::mlir::createGpuKernelOutliningPass());
 
-  if (gpu_binary_only) {
-    // Make kernel signature deterministic so that we can call it externally.
-    pm.addNestedPass<::mlir::FuncOp>(
-        xla::mlir_gpu::createRewriteKernelSignaturePass());
-  }
   pm.addPass(::mlir::createLowerAffinePass());
   // Constraints are removed as late as possible and before lowering to CFG.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass());
@@ -253,6 +227,33 @@
   return Status::OK();
 }
 
+Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module) {
+#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
+  return InternalError(
+      "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
+      " Did you specify either --config=rocm or --config=cuda ?");
+#endif
+  mlir::PassManager pm(module.getContext());
+  // We cannot verify as the signature of the kernel is rewritten.
+  // pm.enableVerifier(false);
+  tensorflow::applyTensorflowAndCLOptions(pm);
+  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
+  kernelPm.addPass(::mlir::createLowerToCFGPass());
+#if TENSORFLOW_USE_ROCM
+  kernelPm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToRocdlPass());
+#elif GOOGLE_CUDA
+  kernelPm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToNvvmPass());
+#endif
+  // Remove all location information to ensure we do not emit debug information.
+  pm.addPass(::mlir::createStripDebugInfoPass());
+
+  if (failed(pm.run(module))) {
+    return InternalError("Lowering to low-level device IR failed.");
+  }
+
+  return Status::OK();
+}
+
 Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
@@ -270,7 +271,8 @@
 Status GenerateDeviceCode(mlir::ModuleOp module,
                           llvm::StringRef gpu_binary_attr_name,
                           llvm::ArrayRef<std::string> architectures,
-                          bool generate_fatbin, bool print_ptx) {
+                          bool generate_fatbin, bool print_ptx,
+                          bool enable_ftz) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
 
@@ -278,7 +280,8 @@
   // Remove debug information to ensure we do not create debug PTX.
   kernel_pm.addPass(mlir::createStripDebugInfoPass());
   kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
-      gpu_binary_attr_name, architectures, generate_fatbin, print_ptx));
+      gpu_binary_attr_name, architectures, generate_fatbin, print_ptx,
+      enable_ftz));
 
   return failed(pm.run(module))
              ? InternalError("Generating device code failed.")
@@ -302,35 +305,23 @@
 }  // namespace
 
 StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
-    mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
+    mlir::MLIRContext& context, llvm::StringRef tf_code,
     llvm::ArrayRef<std::string> architectures,
     llvm::ArrayRef<uint32_t> tile_sizes,
     llvm::ArrayRef<uint32_t> unroll_factors, bool embed_memref_prints,
-    bool generate_fatbin, bool print_ptx) {
+    bool generate_fatbin, bool print_ptx, bool enable_ftz) {
   auto& registry = context.getDialectRegistry();
   mlir::RegisterAllTensorFlowDialects(registry);
   registry.insert<mlir::chlo::HloClientDialect, mlir::mhlo::MhloDialect>();
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
-  TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), gpu_binary_only, tile_sizes,
-                                  unroll_factors, embed_memref_prints));
-#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
-  return InternalError(
-      "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
-      " Did you specify either --config=rocm or --config=cuda ?");
-#endif
-
-#if TENSORFLOW_USE_ROCM
-  TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get()));
-#elif GOOGLE_CUDA
-  TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
-#endif
+  TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), tile_sizes, unroll_factors,
+                                  embed_memref_prints));
+  TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr(module.get()));
   TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
   TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName,
                                         architectures, generate_fatbin,
-                                        print_ptx));
-  if (!gpu_binary_only) {
-    TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get()));
-  }
+                                        print_ptx, enable_ftz));
+  TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get()));
   return module;
 }
 
@@ -340,7 +331,7 @@
     return InternalError("There should be exactly one GPU Module");
   }
   mlir::gpu::GPUModuleOp gpu_mod = *gpu_modules.begin();
-  auto blob = gpu_mod.getAttrOfType<mlir::StringAttr>(kGpuBinaryAttrName);
+  auto blob = gpu_mod->getAttrOfType<mlir::StringAttr>(kGpuBinaryAttrName);
   if (blob == nullptr) {
     return InternalError("No binary blob found in the module");
   }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h
index 33be8ae..8216656 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h
@@ -33,16 +33,14 @@
 namespace tensorflow {
 namespace kernel_gen {
 
-// Converts TF code to LLVM/NVVM. If `gpu_binary_only` is true, then the
-// conversion stops after gpu_binary blob is generated. If `gpu_binary_only` is
-// false, lowers the host side to LLVM Dialect.
+// Converts TF code to LLVM/NVVM. Lowers the host side to LLVM Dialect.
 xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
-    mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
+    mlir::MLIRContext& context, llvm::StringRef tf_code,
     llvm::ArrayRef<std::string> architectures = {"sm_75"},
     llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
     llvm::ArrayRef<uint32_t> unroll_factors = {},
     bool embed_memref_prints = false, bool generate_fatbin = true,
-    bool print_ptx = false);
+    bool print_ptx = false, bool enable_ftz = false);
 
 // Extracts gpu_binary from the converted module.
 xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
index e7211aa..a5286fd 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
@@ -2,14 +2,14 @@
 // RUN: kernel-gen-opt %s --func-bufferize --final-bufferize --promote-buffers-to-stack | FileCheck %s  --check-prefixes=CHECK,ALLOCA
 
 
-// CHECK-LABEL: @extract_element
+// CHECK-LABEL: @tensor.extract
 // CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32
-func @extract_element(%arg : tensor<?xf32>) -> f32 {
+func @tensor.extract(%arg : tensor<?xf32>) -> f32 {
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[RESULT:.*]] = load %[[ARG]][%[[C0]]]
   // CHECK: return %[[RESULT]]
   %c0 = constant 0 : index
-  %result = extract_element %arg[%c0] : tensor<?xf32>
+  %result = tensor.extract %arg[%c0] : tensor<?xf32>
   return %result : f32
 }
 
@@ -30,7 +30,7 @@
   %c = constant 2.3 : f32
   %tfe = tensor_from_elements %a, %b, %c : tensor<3xf32>
   %c0 = constant 0 : index
-  %result = extract_element %tfe[%c0] : tensor<3xf32>
+  %result = tensor.extract %tfe[%c0] : tensor<3xf32>
   return %result : f32
 }
 
@@ -54,7 +54,7 @@
     yield %elem : index
   } : tensor<?xindex>
   %c0 = constant 0 : index
-  %result = extract_element %tfe[%c0] : tensor<?xindex>
+  %result = tensor.extract %tfe[%c0] : tensor<?xindex>
   return %result : index
 }
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir
index e5d124b..dce998b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir
@@ -1,4 +1,8 @@
-// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --hlo-bufferize --canonicalize --shape-to-descriptors --canonicalize --final-bufferize | FileCheck %s
+// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | \
+// RUN: mlir-hlo-opt --transform-unranked-hlo --hlo-legalize-to-linalg | \
+// RUN: kernel-gen-opt -allow-unregistered-dialect --hlo-bufferize \
+// RUN: --canonicalize --shape-to-descriptors --canonicalize --final-bufferize \
+// RUN: | FileCheck %s
 
 // Test whether all shape computations required for isinf can be lowered to
 // the standard dialect, scf and descriptors.
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
index 7cd4841..26cf440 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
@@ -1,4 +1,8 @@
-// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --hlo-bufferize --shape-to-descriptors --canonicalize --final-bufferize | FileCheck %s
+// RUN: tf-opt %s --xla-legalize-tf | \
+// RUN: mlir-hlo-opt --transform-unranked-hlo --hlo-legalize-to-linalg | \
+// RUN: kernel-gen-opt -allow-unregistered-dialect --hlo-bufferize \
+// RUN: --canonicalize --shape-to-descriptors --canonicalize --final-bufferize \
+// RUN: | FileCheck %s
 
 // Test whether all shape computations required for tanh can be lowered to
 // the standard dialect, scf and descriptors. We check for a sparse pattern here,
@@ -13,7 +17,7 @@
   // CHECK: scf.for
   // CHECK-NOT: tensor_from_elements
   // CHECK: memref_reshape
-  // CHECK: lmhlo.tanh
+  // CHECK: linalg.generic
   // CHECK: memref_reshape
   %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir
index 4a2b2da..13b37f3 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir
@@ -1,5 +1,5 @@
 // RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | \
-// RUN: mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | \
+// RUN: mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo --hlo-legalize-to-linalg | \
 // RUN: kernel-gen-opt --hlo-bufferize --shape-to-descriptors --canonicalize --final-bufferize
 
 func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> {
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD
deleted file mode 100644
index 6aef5c0..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
-
-package(licenses = ["notice"])
-
-glob_lit_tests(
-    data = [
-        "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary",
-        "@llvm-project//mlir:run_lit.sh",
-    ],
-    default_tags = [
-        # We need access to the CUDA SDK.
-        "gpu",
-        "no_rocm",
-    ],
-    driver = "//tensorflow/compiler/mlir:run_lit.sh",
-    test_file_exts = ["mlir"],
-)
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir
deleted file mode 100644
index 5177309..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-// RUN: tf_to_gpu_binary --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70
-func @abs(%arg0: tensor<?xf16>) -> tensor<?xf16> attributes {tf_entry} {
-  %0 = "tf.Abs"(%arg0) { }
-    : (tensor<?xf16>) -> tensor<?xf16>
-  return %0 : tensor<?xf16>
-}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir
deleted file mode 100644
index bb50580..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-// RUN: tf_to_gpu_binary --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70
-func @ceil(%arg0: tensor<?xf64>) -> tensor<?xf64> attributes {tf_entry} {
-  %0 = "tf.Ceil"(%arg0) { }
-    : (tensor<?xf64>) -> tensor<?xf64>
-  return %0 : tensor<?xf64>
-}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir
deleted file mode 100644
index fa88fc7..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir
+++ /dev/null
@@ -1,5 +0,0 @@
-// RUN: tf_to_gpu_binary --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70
-func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> attributes {tf_entry} {
-  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc
deleted file mode 100644
index 6f1de7d..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- tf_to_gpu_binary.cc --------------------------------------*- C++ -*-===//
-//
-// This file implements the entry point to compile a tf op to a gpu binary
-//
-//===----------------------------------------------------------------------===//
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "llvm/Support/CommandLine.h"
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/init_mlir.h"
-#include "tensorflow/compiler/mlir/tools/kernel_gen/crash_handler.h"
-#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-
-namespace tensorflow {
-namespace kernel_gen {
-namespace {
-
-xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
-                std::string architecture, llvm::ArrayRef<uint32_t> tile_sizes,
-                llvm::ArrayRef<uint32_t> unroll_factors) {
-  // Read TF code.
-  std::string tf_code;
-  TF_RETURN_IF_ERROR(
-      ReadFileToString(Env::Default(), input_file.str(), &tf_code));
-  // Compile.
-  mlir::MLIRContext context;
-  TF_ASSIGN_OR_RETURN(
-      mlir::OwningModuleRef module,
-      GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true,
-                              architecture, tile_sizes, unroll_factors,
-                              /*embed_memref_prints=*/false,
-                              /*generate_fatbin=*/false));
-  // Extract gpu_binary.
-  TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module));
-
-  // Write gpu_binary blob.
-  TF_RETURN_IF_ERROR(
-      WriteStringToFile(Env::Default(), output_file.str(), gpu_binary));
-  return xla::Status::OK();
-}
-
-}  // namespace
-}  // namespace kernel_gen
-}  // namespace tensorflow
-
-int main(int argc, char** argv) {
-  tensorflow::kernel_gen::SetCrashReportMessage();
-  llvm::cl::opt<std::string> input_file("input", llvm::cl::desc("input file"),
-                                        llvm::cl::value_desc("filename"),
-                                        llvm::cl::init("foo.mlir"));
-  llvm::cl::opt<std::string> output_file(
-      "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
-      llvm::cl::init("foo.bin"));
-  llvm::cl::opt<std::string> architecture(
-      "arch", llvm::cl::desc("target architecture (e.g. sm_50)"),
-      llvm::cl::init("sm_50"));
-  llvm::cl::list<uint32_t> tile_sizes(
-      "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
-      llvm::cl::CommaSeparated);
-  llvm::cl::list<uint32_t> unroll_factors(
-      "unroll_factors",
-      llvm::cl::desc("factors to unroll by, separated by commas"),
-      llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);
-
-  tensorflow::InitMlir y(&argc, &argv);
-  mlir::registerPassManagerCLOptions();
-  llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
-
-  auto status = tensorflow::kernel_gen::Run(
-      input_file, output_file, architecture, tile_sizes, unroll_factors);
-  if (!status.ok()) {
-    LOG(ERROR) << status;
-    return 1;
-  }
-  return 0;
-}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
index a62a413..823e143 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
@@ -106,7 +106,7 @@
                 llvm::ArrayRef<std::string> architectures,
                 llvm::ArrayRef<uint32_t> tile_sizes,
                 llvm::ArrayRef<uint32_t> unroll_factors,
-                bool embed_memref_prints, bool print_ptx) {
+                bool embed_memref_prints, bool print_ptx, bool enable_ftz) {
   // Read TF code.
   std::string tf_code;
   TF_RETURN_IF_ERROR(
@@ -115,10 +115,9 @@
   mlir::MLIRContext context;
   TF_ASSIGN_OR_RETURN(
       mlir::OwningModuleRef module,
-      GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false,
-                              architectures, tile_sizes, unroll_factors,
-                              embed_memref_prints, /*generate_fatbin=*/true,
-                              /*print_ptx=*/print_ptx));
+      GenerateKernelForTfCode(context, tf_code, architectures, tile_sizes,
+                              unroll_factors, embed_memref_prints,
+                              /*generate_fatbin=*/true, print_ptx, enable_ftz));
   // Get binary.
   TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
 
@@ -147,6 +146,11 @@
       "print-ptx",
       llvm::cl::desc("Print generated PTX code per target architecture."),
       llvm::cl::init(false));
+  llvm::cl::opt<bool> enable_ftz(
+      "enable_ftz",
+      llvm::cl::desc(
+          "Enable the denormal flush to zero mode when generating code."),
+      llvm::cl::init(false));
   llvm::cl::list<std::string> architectures(
       "arch", llvm::cl::desc("target architectures (e.g. sm_70 or compute_75)"),
       llvm::cl::OneOrMore, llvm::cl::CommaSeparated);
@@ -166,7 +170,7 @@
 
   auto status = tensorflow::kernel_gen::Run(
       input_file, output_file, architectures, tile_sizes, unroll_factors,
-      embed_memref_prints, print_ptx);
+      embed_memref_prints, print_ptx, enable_ftz);
   if (!status.ok()) {
     LOG(ERROR) << status;
     return 1;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
index 4fa6f02..9c09755 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
@@ -76,61 +76,67 @@
         "bufferize_pass.cc",
         "embed_memref_prints.cc",
         "embed_tf_framework_pass.cc",
+        "fuse_inner_parallel_loops_pass.cc",
         "gpu_kernel_to_blob_pass.cc",
-        "materialize_broadcasts_pass.cc",
+        "kernel_lowering_passes.cc",
+        "map_parallel_loops_to_gpu.cc",
         "parallel_loops_to_sequential.cc",
         "same_shape_propagation.cc",
         "shape_to_descriptors_pass.cc",
         "tensorflow_abi_knowledge_propagation.cc",
         "tf_kernel_to_llvm_pass.cc",
-        "unfuse_batch_norm_pass.cc",
     ],
     hdrs = ["passes.h"],
     copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]),
     deps = [
-        "@llvm-project//mlir:Affine",
-        "//tensorflow/compiler/mlir/hlo:materialize_broadcasts",  # buildcleaner: keep
-        "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm",  # buildcleaner: keep
-        "//tensorflow/compiler/xla/service:hlo_module_config",
-        "//tensorflow/compiler/xla:debug_options_flags",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla/service/gpu:target_constants",
-        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
-        "//tensorflow/core/platform:cuda_libdevice_path",
-        "//tensorflow/core:lib",
         ":bufferize",
         ":embed_tf_framework",
         ":kernel_gen_passes_inc_gen",
         ":tf_framework_legalize_to_llvm",
         "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TransformUtils",
+        "@llvm-project//mlir:Affine",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToGPURuntimeTransforms",
+        "@llvm-project//mlir:GPUToNVVMTransforms",
+        "@llvm-project//mlir:GPUToROCDLTransforms",
+        "@llvm-project//mlir:GPUTransforms",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:LLVMDialect",
         "@llvm-project//mlir:LLVMTransforms",
-        "@llvm-project//mlir:LinalgTransforms",
-        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:NVVMDialect",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:ROCDLDialect",
         "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:Shape",
-        "@llvm-project//mlir:Analysis",
-        "@llvm-project//mlir:TargetNVVMIR",
-        "@llvm-project//mlir:TargetROCDLIR",
-        "@llvm-project//mlir:ShapeToStandard",
         "@llvm-project//mlir:SCFToStandard",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:ShapeToStandard",
         "@llvm-project//mlir:ShapeTransforms",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardOpsTransforms",
-        "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TargetNVVMIR",
+        "@llvm-project//mlir:TargetROCDLIR",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:Transforms",
-        "@llvm-project//llvm:TransformUtils",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
-        "//tensorflow/compiler/xla/service/gpu:stream_executor_util",
         "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
+        "//tensorflow/compiler/xla/service/gpu:stream_executor_util",
+        "//tensorflow/compiler/xla/service/gpu:target_constants",
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:cuda_libdevice_path",
     ] + if_cuda_is_configured([
         "//tensorflow/stream_executor/gpu:asm_compiler",
     ]) + if_rocm_is_configured([
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc
index 26184fa..59fb75e 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc
@@ -26,8 +26,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
@@ -324,7 +324,7 @@
 
 struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
   void runOnFunction() override {
-    if (!getFunction().getAttrOfType<UnitAttr>(
+    if (!getFunction()->getAttrOfType<UnitAttr>(
             tf_framework::TFFrameworkDialect::kTFEntryAttrName))
       return;
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
index 3a92192..8ff10c9 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
@@ -21,7 +21,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
 
@@ -40,15 +40,25 @@
     // We only need to bufferize tensor constants.
     Location loc = op.getLoc();
     auto result_type = op.getType().dyn_cast<RankedTensorType>();
-    if (!result_type || !result_type.hasStaticShape() ||
-        result_type.getRank() != 1)
+    if (!result_type || !result_type.hasStaticShape()) return failure();
+    int64_t result_rank = result_type.getRank();
+    if (result_rank > 1)
       return failure();
 
-    auto memref_type = MemRefType::get({result_type.getNumElements()},
-                                       result_type.getElementType());
+    auto memref_type =
+        MemRefType::get(result_type.getShape(), result_type.getElementType());
+    auto elements_attr = op.value().cast<DenseElementsAttr>();
+
+    if (result_rank == 0) {
+      Value buffer = rewriter.create<AllocOp>(loc, memref_type);
+      Value constant =
+          rewriter.create<ConstantOp>(loc, elements_attr.getValue({}));
+      rewriter.create<StoreOp>(loc, constant, buffer);
+      rewriter.replaceOp(op, {buffer});
+      return success();
+    }
+
     Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
 
-    auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
     bool all_same_elems = elements_attr.isSplat();
     Value value;
     if (all_same_elems)
@@ -92,8 +102,8 @@
 void populateExtraStdBufferizePattern(MLIRContext *context,
                                       BufferizeTypeConverter *converter,
                                       OwningRewritePatternList *patterns) {
-  patterns->insert<BufferizeConstantOp, BufferizeDimOp,
-                   BufferizeRankOp>(*converter, context);
+  patterns->insert<BufferizeConstantOp, BufferizeDimOp, BufferizeRankOp>(
+      *converter, context);
 }
 
 }  // namespace transforms
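
As context for the hunk above: the new rank-0 branch in BufferizeConstantOp allocates a scalar buffer and stores the constant's single element into it, while rank-1 constants keep the existing alloca-based path. The following is a rough, hypothetical sketch of the same builder sequence in isolation; the helper name and the f32 example are illustrative and not part of this change.

    // Illustrative sketch only, assuming a 0-d constant such as
    //   %c = constant dense<4.2> : tensor<f32>
    // The emitted IR is roughly:
    //   %buf = alloc() : memref<f32>
    //   %cst = constant 4.2 : f32
    //   store %cst, %buf[] : memref<f32>
    static mlir::Value BufferizeScalarConstant(mlir::OpBuilder &b,
                                               mlir::Location loc,
                                               mlir::DenseElementsAttr elements) {
      auto memref_type = mlir::MemRefType::get(
          /*shape=*/{}, elements.getType().getElementType());
      mlir::Value buffer = b.create<mlir::AllocOp>(loc, memref_type);
      mlir::Value scalar = b.create<mlir::ConstantOp>(loc, elements.getValue({}));
      b.create<mlir::StoreOp>(loc, scalar, buffer);
      return buffer;
    }
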
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index 2985e6b3..ba2e78b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -19,7 +19,10 @@
 #include <memory>
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
@@ -27,11 +30,13 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"  // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Transforms/Bufferize.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
@@ -51,6 +56,50 @@
 #define GEN_PASS_CLASSES
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
 
+/// A helper type converter class that automatically populates the relevant
+/// materializations and type conversions for bufferization.
+
+static Value materializeTensorLoad(OpBuilder& builder, TensorType type,
+                                   ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  assert(inputs[0].getType().isa<BaseMemRefType>());
+  return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+}
+
+// TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
+class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
+ public:
+  CustomBufferizeTypeConverter() {
+    // Keep all types unchanged.
+    addConversion([](Type type) { return type; });
+    // Convert RankedTensorType to MemRefType.
+    addConversion([](RankedTensorType type) -> Type {
+      return MemRefType::get(type.getShape(), type.getElementType());
+    });
+    // Convert UnrankedTensorType to UnrankedMemRefType.
+    addConversion([](UnrankedTensorType type) -> Type {
+      return UnrankedMemRefType::get(type.getElementType(), 0);
+    });
+    addArgumentMaterialization(materializeTensorLoad);
+    addSourceMaterialization(materializeTensorLoad);
+    addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
+                                ValueRange inputs, Location loc) -> Value {
+      assert(inputs.size() == 1);
+      // Target materialization is invoked if the new operand type does not
+      // match the expected type. A special case is when the new operand type is
+      // a memref with a specified layout, i.e. non-empty affine map.
+      // TODO(pifon): Change how target materialization is invoked in dialect
+      // conversion.
+      if (auto memref_type = inputs[0].getType().dyn_cast<MemRefType>()) {
+        assert(!memref_type.getAffineMaps().empty());
+        return inputs[0];
+      }
+      assert(inputs[0].getType().isa<TensorType>());
+      return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
+    });
+  }
+};
+
 struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
   // TODO(b/173201243): Move to tablegen.
   void getDependentDialects(DialectRegistry& registry) const override {
@@ -64,11 +113,14 @@
     ConversionTarget target(context);
     target.addLegalDialect<lmhlo::LmhloDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<lmhlo::LmhloDialect, StandardOpsDialect>();
+    target.addLegalDialect<tensor::TensorDialect>();
     target.addIllegalDialect<mhlo::MhloDialect>();
 
-    BufferizeTypeConverter converter;
+    CustomBufferizeTypeConverter converter;
     // Configure bufferize pattern for functions and lhlo.
-    mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
+    mhlo::populateDynamicHLOToLHLOConversionPattern(
+        &context, &converter, &patterns, /*insert_copy=*/false);
     populateFuncOpTypeConversionPattern(patterns, &context, converter);
     populateCallOpTypeConversionPattern(patterns, &context, converter);
     populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
@@ -76,6 +128,7 @@
 
     // Configure legality and structural patterns.
     populateBufferizeMaterializationLegality(target);
+    linalg::populateLinalgBufferizePatterns(&context, converter, patterns);
     populateShapeStructuralTypeConversionsAndLegality(&context, converter,
                                                       patterns, target);
     scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
@@ -87,8 +140,9 @@
       return converter.isLegal(inputs) && converter.isLegal(results) &&
              converter.isLegal(&op.getBody());
     });
-    target.addDynamicallyLegalOp<CallOp, ReturnOp>(
-        [&converter](Operation* op) { return converter.isLegal(op); });
+    auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
+    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOp);
+    target.addDynamicallyLegalOp<CallOp, ReturnOp>(isLegalOp);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -100,7 +154,8 @@
   // TODO(b/173201243): Move to tablegen.
   void getDependentDialects(DialectRegistry& registry) const override {
     registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect,
-                    tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>();
+                    tensor::TensorDialect, tf_framework::TFFrameworkDialect,
+                    lmhlo::LmhloDialect>();
   }
 
  public:
@@ -108,20 +163,16 @@
     auto& context = getContext();
     ConversionTarget target(context);
     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
+                           tensor::TensorDialect,
                            tf_framework::TFFrameworkDialect, AffineDialect,
-                           shape::ShapeDialect, lmhlo::LmhloDialect>();
+                           shape::ShapeDialect, lmhlo::LmhloDialect,
+                           linalg::LinalgDialect>();
     target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
 
     target.addIllegalDialect<mhlo::MhloDialect>();
-    target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
+    target.addIllegalOp<DynamicTensorFromElementsOp, tensor::ExtractOp,
                         TensorFromElementsOp, TensorCastOp, TensorLoadOp,
                         TensorToMemrefOp>();
-    // Certain operations are no longer legal on tensors but otherwise are.
-    target.addDynamicallyLegalOp<ConstantOp, SelectOp>([&](Operation* op) {
-      return llvm::none_of(op->getResultTypes(),
-                           [](Type t) { return t.isa<TensorType>(); });
-    });
-
     BufferizeTypeConverter converter;
     auto typesAreLegal = [&converter](Operation* op) {
       return converter.isLegal(op->getOperandTypes()) &&
@@ -131,6 +182,7 @@
         typesAreLegal);
 
     OwningRewritePatternList patterns;
+    populateTensorBufferizePatterns(&context, converter, patterns);
     populateStdBufferizePatterns(&context, converter, patterns);
     populateEliminateBufferizeMaterializationsPatterns(&context, converter,
                                                        patterns);
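
The CustomBufferizeTypeConverter above is a stop-gap until the upstream BufferizeTypeConverter grows the same hooks (see the TODO referencing D93126). Its core is two tensor-to-memref type conversions plus the tensor_load / tensor_to_memref materializations. A stripped-down sketch of just the type-conversion half, under a hypothetical class name, for readers skimming the pass:

    // Sketch of the type-conversion half only; the materialization callbacks
    // shown in the pass above are deliberately omitted here.
    class TensorToMemrefTypeConverter : public mlir::TypeConverter {
     public:
      TensorToMemrefTypeConverter() {
        // Leave all other types unchanged.
        addConversion([](mlir::Type type) { return type; });
        // tensor<AxBxT> -> memref<AxBxT>.
        addConversion([](mlir::RankedTensorType type) -> mlir::Type {
          return mlir::MemRefType::get(type.getShape(), type.getElementType());
        });
        // tensor<*xT> -> memref<*xT> in the default memory space.
        addConversion([](mlir::UnrankedTensorType type) -> mlir::Type {
          return mlir::UnrankedMemRefType::get(type.getElementType(),
                                               /*memorySpace=*/0);
        });
      }
    };
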
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
index a295c05..4b85962 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
@@ -19,8 +19,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
@@ -46,10 +46,10 @@
   if (!callee_func) {
     OpBuilder::InsertionGuard insertGuard(*b);
 
-    auto module = caller_func.getParentOfType<ModuleOp>();
+    auto module = caller_func->getParentOfType<ModuleOp>();
     b->setInsertionPointToStart(module.getBody());
-    auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
-                                       b->getContext());
+    auto func_type = FunctionType::get(b->getContext(), arg.getType(),
+                                       /*results=*/llvm::None);
     callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
     callee_func.setPrivate();
   }
@@ -106,7 +106,7 @@
     : public EmbedMemRefPrintsPassBase<EmbedMemRefPrintsPass> {
   void runOnFunction() override {
     FuncOp func = getFunction();
-    if (!func.getAttrOfType<UnitAttr>(TFFrameworkDialect::kTFEntryAttrName))
+    if (!func->getAttrOfType<UnitAttr>(TFFrameworkDialect::kTFEntryAttrName))
       return;
 
     Liveness liveness(func);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc
index db9599e..bf195bc 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc
@@ -15,7 +15,7 @@
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeRange.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
@@ -63,7 +63,7 @@
   LogicalResult matchAndRewrite(
       AllocOp alloc, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    auto func = alloc.getParentOfType<FuncOp>();
+    auto func = alloc->getParentOfType<FuncOp>();
     if (func.getNumArguments() == 0) {
       return failure();
     }
@@ -76,10 +76,10 @@
     if (!alloc.symbolOperands().empty()) {
       return failure();
     }
-    auto reuse_input_candidates = alloc.getAttrOfType<ArrayAttr>(
+    auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
         TFAllocOp::kReuseInputCandidatesAttrName);
     auto reuse_output_index =
-        alloc.getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
+        alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
     rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
                                            operands, reuse_input_candidates,
                                            reuse_output_index);
@@ -96,7 +96,7 @@
   LogicalResult matchAndRewrite(
       DeallocOp dealloc, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    FuncOp func = dealloc.getParentOfType<FuncOp>();
+    auto func = dealloc->getParentOfType<FuncOp>();
     if (func.getNumArguments() == 0) {
       return failure();
     }
@@ -125,7 +125,7 @@
   LogicalResult matchAndRewrite(
       AssertOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    FuncOp func = op.getParentOfType<FuncOp>();
+    auto func = op->getParentOfType<FuncOp>();
     if (func.getNumArguments() == 0) {
       return failure();
     }
@@ -134,8 +134,7 @@
       return failure();
     }
     Location loc = op.getLoc();
-    AssertOp::Adaptor transformed(operands,
-                                  op.getOperation()->getAttrDictionary());
+    AssertOp::Adaptor transformed(operands, op->getAttrDictionary());
 
     // Split the block to insert CondBr.
     OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
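
A note on the recurring accessor change in the transforms above (buffer_reuse_pass.cc, embed_memref_prints.cc, embed_tf_framework.cc): generic Operation-level queries such as getAttrOfType and getParentOfType are now reached through the op's underlying Operation* via operator->, while op-specific accessors keep the dot syntax. A minimal sketch of the new style, assuming a FuncOp obtained from getFunction():

    // Sketch only; mirrors the op->getAttrOfType style used throughout this diff.
    #include "mlir/IR/BuiltinOps.h"
    #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"

    static bool IsTfEntryFunction(mlir::FuncOp func) {
      // Generic attribute lookup goes through the wrapped Operation*.
      auto attr = func->getAttrOfType<mlir::UnitAttr>(
          mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName);
      return static_cast<bool>(attr);
    }
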
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc
similarity index 65%
copy from tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc
copy to tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc
index 5c347f4..d9bb794 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
 
 namespace mlir {
@@ -25,19 +25,19 @@
 #define GEN_PASS_CLASSES
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
 
-struct UnfuseBatchNormPass
-    : public UnfuseBatchNormPassBase<UnfuseBatchNormPass> {
+struct FuseInnerParallelLoopsPass
+    : FuseInnerParallelLoopsPassBase<FuseInnerParallelLoopsPass> {
   void runOnFunction() override {
-    mlir::OwningRewritePatternList patterns;
-    mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
-    mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    getFunction().walk([](mlir::scf::ParallelOp op) {
+      mlir::scf::naivelyFuseParallelOps(op.region());
+    });
   }
 };
 
 }  // namespace
 
-std::unique_ptr<mlir::FunctionPass> CreateUnfuseBatchNormPass() {
-  return std::make_unique<UnfuseBatchNormPass>();
+std::unique_ptr<mlir::FunctionPass> CreateFuseInnerParallelLoopsPass() {
+  return std::make_unique<FuseInnerParallelLoopsPass>();
 }
 
 }  // namespace transforms
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
index adeb14e..dc1365a 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
@@ -53,13 +53,14 @@
  public:
   GpuKernelToBlobPass(mlir::StringRef blob_annotation,
                       llvm::ArrayRef<std::string> architectures,
-                      bool generate_fatbin, bool print_ptx) {
+                      bool generate_fatbin, bool print_ptx, bool enable_ftz) {
     if (!blob_annotation.empty()) {
       blob_annotation_ = blob_annotation.str();
     }
     architectures_ = architectures;
     generate_fatbin_ = generate_fatbin;
     print_ptx_ = print_ptx;
+    enable_ftz_ = enable_ftz;
   }
 
   void runOnOperation() override {
@@ -68,8 +69,8 @@
     if (blob_or.ok()) {
       const auto& blob = blob_or.ValueOrDie();
       std::string blob_string(blob.begin(), blob.end());
-      gpu_module.setAttr(blob_annotation_,
-                         mlir::StringAttr::get(blob_string, &getContext()));
+      gpu_module->setAttr(blob_annotation_,
+                          mlir::StringAttr::get(blob_string, &getContext()));
       return;
     }
     // Forward the error by attaching the message to the gpu module.
@@ -99,7 +100,9 @@
     llvmModule->setModuleIdentifier("acme");
 
     xla::HloModuleConfig config;
-    config.set_debug_options(xla::GetDebugOptionsFromFlags());
+    xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
+    options.set_xla_gpu_ftz(enable_ftz_);
+    config.set_debug_options(options);
 
     using AmdGpuHsaco = std::vector<tensorflow::uint8>;
     std::vector<tensorflow::se::HsacoImage> images;
@@ -148,7 +151,9 @@
     llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);
 
     xla::HloModuleConfig config;
-    config.set_debug_options(xla::GetDebugOptionsFromFlags());
+    xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
+    options.set_xla_gpu_ftz(enable_ftz_);
+    config.set_debug_options(options);
 
     auto enable_fusion = [](llvm::TargetMachine* target) {
       target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
@@ -241,15 +246,16 @@
     return InternalError(
         "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice");
   }
+  bool enable_ftz_;
 };
 
 }  // namespace
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
     mlir::StringRef blob_annotation, ArrayRef<std::string> architectures,
-    bool generate_fatbin, bool print_ptx) {
-  return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architectures,
-                                               generate_fatbin, print_ptx);
+    bool generate_fatbin, bool print_ptx, bool enable_ftz) {
+  return std::make_unique<GpuKernelToBlobPass>(
+      blob_annotation, architectures, generate_fatbin, print_ptx, enable_ftz);
 }
 
 }  // namespace transforms
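
The new enable_ftz parameter is threaded from the pass constructor into the xla::DebugOptions used by both the NVVM and ROCDL code paths, so flush-to-zero can be toggled per kernel-generator invocation. A hypothetical call site is sketched below; the pipeline wiring is outside this diff, and the annotation string and architecture list are made up for illustration.

    // Hypothetical usage sketch; `kernel_pm` is assumed to be an OpPassManager
    // anchored on gpu::GPUModuleOp.
    static void AddBlobPass(mlir::OpPassManager &kernel_pm) {
      kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
          /*blob_annotation=*/"gpu.kernel_blob",  // made-up attribute name
          /*architectures=*/{"sm_70", "gfx906"},  // made-up target list
          /*generate_fatbin=*/true,
          /*print_ptx=*/false,
          /*enable_ftz=*/true));
    }
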
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc
new file mode 100644
index 0000000..1c82e97
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"  // from @llvm-project
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"  // from @llvm-project
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
+#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
+
+namespace mlir {
+namespace kernel_gen {
+namespace transforms {
+
+using gpu::GPUModuleOp;
+
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+/// A pass that does the final lowering to NVVM. It collects all the patterns
+/// that are currently required, mixing std, linalg and gpu.
+class GpuKernelToNVVMPass
+    : public GpuKernelToNVVMPassBase<GpuKernelToNVVMPass> {
+  void getDependentDialects(mlir::DialectRegistry& registry) const override {
+    registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
+  }
+
+ public:
+  void runOnOperation() override {
+    GPUModuleOp m = getOperation();
+
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter converter(m.getContext());
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    populateGpuToNVVMConversionPatterns(converter, patterns);
+    ConversionTarget target(getContext());
+    configureGpuToNVVMConversionLegality(target);
+    if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+/// A pass that does the final lowering to ROCDL. It collects all the patterns
+/// that are currently required, mixing std, linalg and gpu.
+class GpuKernelToROCDLPass
+    : public GpuKernelToROCDLPassBase<GpuKernelToROCDLPass> {
+  void getDependentDialects(mlir::DialectRegistry& registry) const override {
+    registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
+  }
+
+ public:
+  void runOnOperation() override {
+    gpu::GPUModuleOp m = getOperation();
+
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter converter(m.getContext());
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    populateGpuToROCDLConversionPatterns(converter, patterns);
+    ConversionTarget target(getContext());
+    configureGpuToROCDLConversionLegality(target);
+    if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<GPUModuleOp> > CreateGpuKernelToNvvmPass() {
+  return std::make_unique<GpuKernelToNVVMPass>();
+}
+
+std::unique_ptr<OperationPass<GPUModuleOp> > CreateGpuKernelToRocdlPass() {
+  return std::make_unique<GpuKernelToROCDLPass>();
+}
+
+}  // namespace transforms
+}  // namespace kernel_gen
+}  // namespace mlir
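
Both new passes operate on gpu.module ops, so in a module-level pipeline they would normally be nested rather than added at the top level. A minimal sketch follows; the helper name, the CUDA-versus-ROCm switch, and the assumption that gpu.modules are direct children of the top-level module are illustrative only.

    #include "mlir/Dialect/GPU/GPUDialect.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

    // Runs the NVVM (or ROCDL) lowering on every gpu.module nested in `module`.
    static mlir::LogicalResult LowerGpuModules(mlir::ModuleOp module,
                                               bool use_rocdl) {
      mlir::PassManager pm(module.getContext());
      mlir::OpPassManager &kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
      if (use_rocdl) {
        kernel_pm.addPass(
            mlir::kernel_gen::transforms::CreateGpuKernelToRocdlPass());
      } else {
        kernel_pm.addPass(
            mlir::kernel_gen::transforms::CreateGpuKernelToNvvmPass());
      }
      return pm.run(module);
    }
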
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/map_parallel_loops_to_gpu.cc
similarity index 65%
rename from tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc
rename to tensorflow/compiler/mlir/tools/kernel_gen/transforms/map_parallel_loops_to_gpu.cc
index 5c347f4..296b333 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/map_parallel_loops_to_gpu.cc
@@ -13,8 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/GPU/ParallelLoopMapper.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
 
 namespace mlir {
@@ -25,19 +24,16 @@
 #define GEN_PASS_CLASSES
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
 
-struct UnfuseBatchNormPass
-    : public UnfuseBatchNormPassBase<UnfuseBatchNormPass> {
+struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
   void runOnFunction() override {
-    mlir::OwningRewritePatternList patterns;
-    mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
-    mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
   }
 };
 
 }  // namespace
 
-std::unique_ptr<mlir::FunctionPass> CreateUnfuseBatchNormPass() {
-  return std::make_unique<UnfuseBatchNormPass>();
+std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass() {
+  return std::make_unique<MapParallelLoopsPass>();
 }
 
 }  // namespace transforms
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc
deleted file mode 100644
index e0c21f0..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
-#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
-
-namespace mlir {
-namespace kernel_gen {
-namespace transforms {
-namespace {
-
-#define GEN_PASS_CLASSES
-#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
-
-struct MaterializeBroadcastsPass
-    : public MaterializeBroadcastsPassBase<MaterializeBroadcastsPass> {
-  void runOnFunction() override {
-    mlir::ConversionTarget conversionTarget(getContext());
-    mlir::OwningRewritePatternList conversionPatterns;
-
-    // Consider the mhlo dialect legal for tests.
-    conversionTarget.addLegalDialect<mlir::mhlo::MhloDialect>();
-    // The conversion uses helpers from the Standard dialect.
-    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
-
-    mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(),
-                                                   &conversionTarget);
-    mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
-                                                      &conversionPatterns);
-
-    if (failed(applyPartialConversion(getFunction(), conversionTarget,
-                                      std::move(conversionPatterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::FunctionPass> CreateMaterializeBroadcastsPass() {
-  return std::make_unique<MaterializeBroadcastsPass>();
-}
-
-}  // namespace transforms
-}  // namespace kernel_gen
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index f5169a1..d1283cb 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -52,7 +52,7 @@
 
 // Pass to transform shape computations in shape dialect to standard and scf
 // using memref descriptors.
-std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateShapeToDescriptorsPass();
 
 // Pass to transform hlo-level computations on values to their corresponding
 // parts on buffers.
@@ -62,9 +62,6 @@
 // buffers.
 std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass();
 
-// Pass to materialize broadcasts.
-std::unique_ptr<FunctionPass> CreateMaterializeBroadcastsPass();
-
 // Pass to convert scf::ParallelOp to scf::ForOp.
 std::unique_ptr<FunctionPass> CreateParallelLoopsToSequential();
 
@@ -72,10 +69,7 @@
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
     mlir::StringRef blob_annotation = {},
     ArrayRef<std::string> architectures = {}, bool generate_fatbin = true,
-    bool print_ptx = false);
-
-// Pass to unfuse batch norm.
-std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
+    bool print_ptx = false, bool enable_ftz = false);
 
 // Pass to propagate tensorflow runtime ABI knowledge across kernel boundaries.
 std::unique_ptr<FunctionPass> CreatePropagateTfAbiKnowledgeToKernels();
@@ -86,6 +80,22 @@
 // Pass to print content of memrefs.
 std::unique_ptr<FunctionPass> CreateEmbedMemRefPrintsPass();
 
+/// Greedily maps loops to GPU hardware dimensions.
+std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass();
+
+/// We need to direct fusion to the inner loops. This cannot be done with
+/// a pass manager alone at the moment, as nested pass managers require
+/// operations to be isolated from above.
+std::unique_ptr<mlir::FunctionPass> CreateFuseInnerParallelLoopsPass();
+
+/// Pass that transforms gpu modules in standard dialect to NVVM.
+std::unique_ptr<OperationPass<mlir::gpu::GPUModuleOp>>
+CreateGpuKernelToNvvmPass();
+
+/// Pass that transforms gpu modules in standard dialect to ROCDL.
+std::unique_ptr<OperationPass<mlir::gpu::GPUModuleOp>>
+CreateGpuKernelToRocdlPass();
+
 }  // namespace transforms
 
 #define GEN_PASS_REGISTRATION
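
The comments on CreateFuseInnerParallelLoopsPass and CreateMapParallelLoopsPass suggest the two run together on the kernel function before GPU outlining. One plausible ordering inside a FuncOp-anchored pipeline is sketched below; the ordering and helper name are assumptions for illustration, not mandated by this header.

    // Hypothetical fragment; `func_pm` is assumed to be an OpPassManager
    // anchored on FuncOp, since both factories return FunctionPasses.
    static void AddParallelLoopPasses(mlir::OpPassManager &func_pm) {
      func_pm.addPass(
          mlir::kernel_gen::transforms::CreateFuseInnerParallelLoopsPass());
      func_pm.addPass(
          mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
    }
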
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index 2ec9bb3..5b34645 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -61,14 +61,14 @@
   let constructor = "transforms::CreateFinalBufferizePass()";
 }
 
-def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> {
-  let summary = "Pass to materialize broadcasts";
-  let constructor = "transforms::CreateMaterializeBroadcastsPass()";
+def GpuKernelToNVVMPass : Pass<"gpu-kernel-to-nvvm", "gpu::GPUModuleOp"> {
+  let summary = "Pass to transform a gpu module to nvvm.";
+  let constructor = "transforms::CreateGpuKernelToNvvmPass()";
 }
 
-def UnfuseBatchNormPass : FunctionPass<"unfuse-batch-norm"> {
-  let summary = "Pass to unfuse batch norm";
-  let constructor = "transforms::CreateUnfuseBatchNormPass()";
+def GpuKernelToROCDLPass : Pass<"gpu-kernel-to-rocdl", "gpu::GPUModuleOp"> {
+  let summary = "Pass to transform a gpu module to rocdl.";
+  let constructor = "transforms::CreateGpuKernelToRocdlPass()";
 }
 
 def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> {
@@ -108,4 +108,24 @@
   let constructor = "transforms::CreateEmbedMemRefPrintsPass()";
 }
 
+def MapParallelLoopsPass
+    : FunctionPass<"map-parallel-loops-to-gpu"> {
+  let summary = "Greedily maps loops to GPU hardware dimensions.";
+  let constructor = "transforms::CreateMapParallelLoopsPass()";
+  let description = [{
+    Greedily maps loops to GPU hardware dimensions.
+  }];
+}
+
+def FuseInnerParallelLoopsPass
+    : FunctionPass<"fuse-inner-parallel-loops"> {
+  let summary = "Directs parallel loop fusion to the inner loops.";
+  let constructor = "transforms::CreateFuseInnerParallelLoopsPass()";
+  let description = [{
+    Directs parallel loop fusion to the inner loops. This cannot be done with
+    a pass manager alone at the moment, as nested pass managers require
+    operations to be isolated from above.
+  }];
+}
+
 #endif // TF_KERNEL_GEN_PASSES
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc
index 0166141..06665bf 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc
@@ -29,8 +29,8 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/AsmState.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
@@ -321,7 +321,7 @@
     knowledge.build(getFunction());
 
     getFunction().walk([&](gpu::LaunchFuncOp launch) {
-      auto module = launch.getParentOfType<ModuleOp>();
+      auto module = launch->getParentOfType<ModuleOp>();
       auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
 
       if (!kernel || kernel.isExternal()) return;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
index d4a9baf..7743b03 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 #include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
@@ -49,6 +50,7 @@
     target.addIllegalDialect<shape::ShapeDialect>();
     target.addLegalDialect<scf::SCFDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<tensor::TensorDialect>();
     // Don't mark the primary Cstr/Assuming ops as illegal, so they can be
     // lowered at a later time to assertions.
     target.addLegalOp<shape::AssumingOp, shape::AssumingYieldOp,
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc
index 7345321..077ba03 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc
@@ -25,8 +25,8 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
@@ -48,7 +48,7 @@
     llvm::SmallVector<Value, 4> worklist;
     // We currently only handle entry functions and do not propagate across
     // functions.
-    if (function.getAttrOfType<mlir::UnitAttr>(
+    if (function->getAttrOfType<mlir::UnitAttr>(
             tf_framework::TFFrameworkDialect::kTFEntryAttrName)) {
       // For all operands of this function, we know they are aligned. Also, by
       // construction of kernel generator, we know that there is no offset and
@@ -81,7 +81,7 @@
 
     // Now look at launches and make use of the knowledge we have.
     function.walk([&](gpu::LaunchFuncOp launch) {
-      auto module = launch.getParentOfType<ModuleOp>();
+      auto module = launch->getParentOfType<ModuleOp>();
       auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
 
       if (!kernel || kernel.isExternal()) return;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc
index 03b6636..4fa06af 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc
@@ -19,8 +19,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
@@ -260,7 +260,7 @@
                                        op.getOperation()->getAttrDictionary());
 
     Location loc = op.getLoc();
-    auto module = op.getParentOfType<ModuleOp>();
+    auto module = op->getParentOfType<ModuleOp>();
     Value message_constant = GenerateErrorMessageConstant(
         loc, module, transformed.msg().getValue(), rewriter);
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc
index 80fee1f..034de3f 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc
@@ -152,7 +152,7 @@
   assert(kernel_module && "expected a kernel module");
 
   auto binary_attr =
-      kernel_module.getAttrOfType<StringAttr>(gpu_binary_annotation_);
+      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
   if (!binary_attr) {
     kernel_module.emitOpError()
         << "missing " << gpu_binary_annotation_ << " attribute";
@@ -180,13 +180,13 @@
       LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                                kernel_name_buffer, LLVM::Linkage::Internal);
 
-  auto adaptor = gpu::LaunchFuncOpAdaptor(
-      operands, launch_op.getOperation()->getAttrDictionary());
+  auto adaptor =
+      gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());
 
   // The TensorFlow OpKernelContext is the first argument of the surrounding
   // LLVMFunc.
   Value context_arg =
-      launch_op.getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
+      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
   auto kernel_params = generateParamsArray(launch_op, operands, rewriter);
 
   auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
@@ -208,7 +208,7 @@
             llvm_pointer_pointer_type_, /* void **kernel_params */
         });
     rewriter.setInsertionPointToStart(
-        launch_op.getParentOfType<ModuleOp>().getBody());
+        launch_op->getParentOfType<ModuleOp>().getBody());
     function = rewriter.create<LLVM::LLVMFuncOp>(
         loc, kTfWrapperLibaryLaunchHelperName, function_type);
   }
diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD
index 8082228..9032a40 100644
--- a/tensorflow/compiler/mlir/tosa/BUILD
+++ b/tensorflow/compiler/mlir/tosa/BUILD
@@ -6,6 +6,7 @@
 load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
 load("//third_party/mlir:tblgen.bzl", "gentbl")
 
+# TODO: Tighten visibility once targets are at the right granularity.
 package(
     default_visibility = [":internal"],
     licenses = ["notice"],  # Apache 2.0
@@ -25,29 +26,20 @@
         ":internal",
     ],
     packages = [
+        "//third_party/iree/...",
     ],
 )
 
-config_setting(
-    name = "enable-build",
-    values = {"define": "build-tosa=true"},
-    visibility = ["//visibility:public"],
-)
-
 filegroup(
     name = "tosa_ops_td_files",
     srcs = [
-        "@llvm-project//mlir:TdFiles",
+        "@llvm-project//mlir:TosaDialectTdFiles",
     ],
-    # TODO: Switch to pruned list of TD files once build file changes land.
-    # srcs = [
-    #     "@llvm-project//mlir:TosaDialectTdFiles",
-    # ],
     compatible_with = get_compatible_with_cloud(),
 )
 
 gentbl(
-    name = "tosa_pass_inc_gen",
+    name = "tosa_passes_inc_gen",
     compatible_with = get_compatible_with_cloud(),
     tbl_outs = [
         (
@@ -62,6 +54,40 @@
     ],
 )
 
+cc_library(
+    name = "passes_header",
+    hdrs = [
+        "transforms/passes.h",
+        "transforms/passes.h.inc",
+    ],
+    compatible_with = get_compatible_with_cloud(),
+    deps = ["@llvm-project//mlir:Pass"],
+)
+
+cc_library(
+    name = "legalize_common",
+    srcs = [
+        "transforms/legalize_common.cc",
+        "transforms/legalize_utils.cc",
+    ],
+    hdrs = [
+        "transforms/legalize_common.h",
+        "transforms/legalize_utils.h",
+    ],
+    compatible_with = get_compatible_with_cloud(),
+    deps = [
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/kernels:conv_grad_shape_utils",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TosaDialect",
+    ],
+    alwayslink = 1,
+)
+
 gentbl(
     name = "tosa_legalize_tf_inc_gen",
     compatible_with = get_compatible_with_cloud(),
@@ -80,6 +106,36 @@
     ],
 )
 
+cc_library(
+    name = "tf_passes",
+    srcs = [
+        "tf_passes.cc",
+        "transforms/fuse_bias_tf.cc",
+        "transforms/legalize_tf.cc",
+        "transforms/tf_legalize_patterns.inc",
+    ],
+    hdrs = [
+        "tf_passes.h",
+        "transforms/passes.h",
+    ],
+    compatible_with = get_compatible_with_cloud(),
+    visibility = [":friends"],
+    deps = [
+        ":legalize_common",
+        ":passes_header",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TosaDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 gentbl(
     name = "tosa_legalize_tfl_inc_gen",
     compatible_with = get_compatible_with_cloud(),
@@ -99,233 +155,31 @@
 )
 
 cc_library(
-    name = "tosa_legalize_tf",
+    name = "tfl_passes",
     srcs = [
-        "transforms/legalize_tf.cc",
-        "transforms/tf_legalize_patterns.inc",
-    ],
-    hdrs = [
-        "transforms/legalize_common.h",
-        "transforms/legalize_utils.h",
-        "transforms/passes.h",
-        "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        ":tosa_legalize_tf_inc_gen",
-        ":tosa_pass_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/kernels:conv_grad_shape_utils",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
-        "@flatbuffers",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Analysis",
-        "@llvm-project//mlir:Dialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TosaDialect",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tosa_legalize_tfl",
-    srcs = [
+        "tfl_passes.cc",
+        "transforms/convert_tfl_uint8.cc",
         "transforms/legalize_tfl.cc",
         "transforms/tfl_legalize_patterns.inc",
     ],
     hdrs = [
-        "transforms/legalize_common.h",
-        "transforms/legalize_utils.h",
+        "tfl_passes.h",
         "transforms/passes.h",
-        "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
-        "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
     ],
     compatible_with = get_compatible_with_cloud(),
+    visibility = [":friends"],
     deps = [
-        ":tosa_legalize_tfl_inc_gen",
-        ":tosa_pass_inc_gen",
+        ":legalize_common",
+        ":passes_header",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen",
-        "//tensorflow/compiler/mlir/lite:validators",
-        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
-        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
-        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/kernels:conv_grad_shape_utils",
-        "//tensorflow/lite/schema:schema_fbs",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
-        "@flatbuffers",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Analysis",
-        "@llvm-project//mlir:Dialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TosaDialect",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tosa_legalize_common",
-    srcs = [
-        "transforms/legalize_common.cc",
-        "transforms/legalize_utils.cc",
-        "transforms/tf_legalize_patterns.inc",
-    ],
-    hdrs = [
-        "transforms/legalize_common.h",
-        "transforms/legalize_utils.h",
-        "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen",
-        "//tensorflow/compiler/mlir/lite:validators",
-        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/kernels:conv_grad_shape_utils",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
-        "@flatbuffers",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Analysis",
-        "@llvm-project//mlir:Dialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TosaDialect",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tosa_fuse_bias_tf",
-    srcs = [
-        "transforms/fuse_bias_tf.cc",
-    ],
-    hdrs = [
-        "transforms/passes.h",
-        "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        ":tosa_legalize_common",
-        ":tosa_pass_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TosaDialect",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tosa_convert_tfl_uint8",
-    srcs = [
-        "transforms/convert_tfl_uint8.cc",
-    ],
-    hdrs = [
-        "transforms/passes.h",
-        "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        ":tosa_legalize_common",
-        ":tosa_pass_inc_gen",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TosaDialect",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tosa_pipelines",
-    srcs = [
-        "tosa_passpipes.cc",
-    ],
-    hdrs = [
-        "tosa_passpipes.h",
-        "transforms/passes.h",
-        "transforms/register_passes.h",
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        ":tosa_pass_inc_gen",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:TosaDialect",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tf_tosa_passes",
-    srcs = [
-        "tf_tosa_pipeline.cc",
-    ],
-    hdrs = [
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        ":tosa_fuse_bias_tf",
-        ":tosa_legalize_common",
-        ":tosa_legalize_tf",
-        ":tosa_pipelines",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "tfl_tosa_passes",
-    srcs = [
-        "tfl_tosa_pipeline.cc",
-    ],
-    hdrs = [
-    ],
-    compatible_with = get_compatible_with_cloud(),
-    deps = [
-        ":tosa_convert_tfl_uint8",
-        ":tosa_legalize_common",
-        ":tosa_legalize_tfl",
-        ":tosa_pipelines",
+        "@llvm-project//mlir:Transforms",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/mlir/tosa/tf_passes.cc b/tensorflow/compiler/mlir/tosa/tf_passes.cc
new file mode 100644
index 0000000..fadf7e5
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tf_passes.cc
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tosa/tf_passes.h"
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+namespace mlir {
+namespace tosa {
+
+void createTFtoTOSALegalizationPipeline(
+    OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts) {
+  //----------------------------------------------------------------------------
+  // Prepare the TF module for conversion.
+  //----------------------------------------------------------------------------
+  // Inline all functions into main and then delete the functions themselves.
+  pm.addPass(mlir::createInlinerPass());
+
+  // Now that there is only one function, run some MLIR passes on it.
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+
+  pm.addPass(mlir::createLoopFusionPass());
+  pm.addPass(mlir::createMemRefDataFlowOptPass());
+
+  //----------------------------------------------------------------------------
+  // Perform the main TF-to-TOSA conversion.
+  //----------------------------------------------------------------------------
+  pm.addPass(mlir::tosa::createFuseBiasTFPass());
+  pm.addPass(mlir::tosa::createLegalizeTFPass());
+
+  //----------------------------------------------------------------------------
+  // Post conversion cleanup.
+  //----------------------------------------------------------------------------
+  pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass());
+  // Inline the call/return basic blocks within TOSA control flow ops.
+  pm.addPass(mlir::createInlinerPass());
+  // Clean up with DCE.
+  pm.addPass(mlir::createSymbolDCEPass());
+}
+
+static mlir::PassPipelineRegistration<TOSATFLegalizationPipelineOptions>
+    tf_tosa_pipeline("tf-to-tosa-pipeline",
+                     "TensorFlow to TOSA legalization pipeline",
+                     createTFtoTOSALegalizationPipeline);
+
+}  // namespace tosa
+}  // namespace mlir
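
A minimal sketch of how the pipeline added above can be driven programmatically, assuming a module has already been parsed into `module`; the wrapper function name and surrounding setup are illustrative, not part of this change:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tosa/tf_passes.h"

// Runs the TF-to-TOSA legalization pipeline over an already-parsed module.
mlir::LogicalResult runTFToTosa(mlir::MLIRContext& context,
                                mlir::ModuleOp module) {
  mlir::PassManager pm(&context);
  mlir::tosa::TOSATFLegalizationPipelineOptions opts;
  // Populates pm with the same passes registered under "tf-to-tosa-pipeline".
  mlir::tosa::createTFtoTOSALegalizationPipeline(pm, opts);
  return pm.run(module);
}
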
diff --git a/tensorflow/compiler/mlir/tosa/tosa_passpipes.h b/tensorflow/compiler/mlir/tosa/tf_passes.h
similarity index 65%
rename from tensorflow/compiler/mlir/tosa/tosa_passpipes.h
rename to tensorflow/compiler/mlir/tosa/tf_passes.h
index eee7e63..18d11cd 100644
--- a/tensorflow/compiler/mlir/tosa/tosa_passpipes.h
+++ b/tensorflow/compiler/mlir/tosa/tf_passes.h
@@ -16,28 +16,20 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
 #define TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
 
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/PassManager.h"
-#include "llvm/ADT/Optional.h"
-#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Pass/PassOptions.h"  // from @llvm-project
 
 namespace mlir {
-
 namespace tosa {
 
-void addPreOptMlirPasses(mlir::OpPassManager& pm);
+struct TOSATFLegalizationPipelineOptions
+    : public PassPipelineOptions<TOSATFLegalizationPipelineOptions> {};
 
-void addPostOptMlirPasses(mlir::OpPassManager& pm);
-
+// Legalizes TF dialect(s) to Tosa.
 void createTFtoTOSALegalizationPipeline(
-    OpPassManager& pm, const TOSALegalizationPipelineOptions& opts);
-
-void createTFLtoTOSALegalizationPipeline(
-    OpPassManager& pm, const TOSALegalizationPipelineOptions& opts);
+    OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts);
 
 }  // namespace tosa
-
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
diff --git a/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc b/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc
deleted file mode 100644
index e8d1aa7..0000000
--- a/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
-
-namespace mlir {
-
-namespace tosa {
-
-static mlir::PassPipelineRegistration<TOSALegalizationPipelineOptions>
-    tf_tosa_pipeline("tf-to-tosa-pipeline",
-                     "TensorFlow to TOSA legalization pipeline",
-                     createTFtoTOSALegalizationPipeline);
-
-}  // namespace tosa
-
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.cc b/tensorflow/compiler/mlir/tosa/tfl_passes.cc
new file mode 100644
index 0000000..25d9041
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tfl_passes.cc
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+namespace mlir {
+namespace tosa {
+
+void createTFLtoTOSALegalizationPipeline(
+    OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts) {
+  //----------------------------------------------------------------------------
+  // Prepare TFL module for conversion
+  //----------------------------------------------------------------------------
+  // Inline all functions into main and then delete the functions themselves.
+  pm.addPass(mlir::createInlinerPass());
+
+  // Now that there is only one function, run some MLIR passes on it.
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+
+  pm.addPass(mlir::createLoopFusionPass());
+  pm.addPass(mlir::createMemRefDataFlowOptPass());
+
+  //----------------------------------------------------------------------------
+  // Perform main conversion.
+  //----------------------------------------------------------------------------
+  pm.addPass(mlir::tosa::createConvertTFLUint8Pass());
+  pm.addPass(mlir::tosa::createLegalizeTFLPass());
+
+  //----------------------------------------------------------------------------
+  // Post conversion cleanup.
+  //----------------------------------------------------------------------------
+  pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass());
+  // Inline the call/return basic blocks within TOSA control flow ops.
+  pm.addPass(mlir::createInlinerPass());
+  // Clean up with DCE.
+  pm.addPass(mlir::createSymbolDCEPass());
+}
+
+static mlir::PassPipelineRegistration<TOSATFLLegalizationPipelineOptions>
+    tfl_tosa_pipeline("tfl-to-tosa-pipeline",
+                      "TensorFlow Lite to TOSA legalization pipeline",
+                      createTFLtoTOSALegalizationPipeline);
+
+}  // namespace tosa
+}  // namespace mlir
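
Because mlir::PassPipelineRegistration exposes the pipeline under the textual name passed as its first argument, an opt-style driver that links in this library can invoke it directly by flag; for example (the tool name is an assumption about how these registrations get linked, any MlirOptMain-based binary works the same way):

    tf-opt --tfl-to-tosa-pipeline input.mlir -o legalized.mlir
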
diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.h b/tensorflow/compiler/mlir/tosa/tfl_passes.h
new file mode 100644
index 0000000..255418a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tfl_passes.h
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_
+
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Pass/PassOptions.h"  // from @llvm-project
+
+namespace mlir {
+namespace tosa {
+
+struct TOSATFLLegalizationPipelineOptions
+    : public PassPipelineOptions<TOSATFLLegalizationPipelineOptions> {};
+
+// Legalizes TFL (TensorFlow Lite) dialect(s) to Tosa.
+void createTFLtoTOSALegalizationPipeline(
+    OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts);
+
+}  // namespace tosa
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc b/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc
deleted file mode 100644
index 8552a68..0000000
--- a/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
-
-namespace mlir {
-
-namespace tosa {
-
-static mlir::PassPipelineRegistration<TOSALegalizationPipelineOptions>
-    tfl_tosa_pipeline("tfl-to-tosa-pipeline",
-                      "TensorFlow Lite to TOSA legalization pipeline",
-                      createTFLtoTOSALegalizationPipeline);
-
-}  // namespace tosa
-
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc b/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc
deleted file mode 100644
index 1bad415..0000000
--- a/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
-
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/STLExtras.h"
-
-namespace mlir {
-
-namespace tosa {
-
-void addPreOptMlirPasses(mlir::OpPassManager& pm) {
-  // Inline all functions into main and then delete the functions themselves.
-  pm.addPass(mlir::createInlinerPass());
-
-  // Now that there is only one function, run some MLIR passes on it.
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-
-  pm.addPass(mlir::createLoopFusionPass());
-  pm.addPass(mlir::createMemRefDataFlowOptPass());
-}
-
-void addPostOptMlirPasses(mlir::OpPassManager& pm) {
-  pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass());
-  // Inline the call/return basic blocks within TOSA control flow ops.
-  pm.addPass(mlir::createInlinerPass());
-  // Clean up with DCE.
-  pm.addPass(mlir::createSymbolDCEPass());
-}
-
-void createTFtoTOSALegalizationPipeline(
-    OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) {
-  addPreOptMlirPasses(pm);
-
-  pm.addPass(mlir::tosa::createFuseBiasTFPass());
-  pm.addPass(mlir::tosa::createLegalizeTFPass());
-
-  addPostOptMlirPasses(pm);
-}
-
-void createTFLtoTOSALegalizationPipeline(
-    OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) {
-  addPreOptMlirPasses(pm);
-
-  pm.addPass(mlir::tosa::createConvertTFLUint8Pass());
-  pm.addPass(mlir::tosa::createLegalizeTFLPass());
-
-  addPostOptMlirPasses(pm);
-}
-
-}  // namespace tosa
-
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
index 8a0e36d..08ee3c2 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
@@ -29,25 +29,14 @@
 #include <iterator>
 #include <numeric>
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
index 9c17b0a0..058ba48 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
@@ -21,41 +21,16 @@
 #include <iterator>
 #include <numeric>
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 
 #define PASS_NAME "tosa-fuse-bias-tf"
 #define DEBUG_TYPE PASS_NAME
 
-// TODO: remove macro when replacing common function return types with
-// llvm::Optional<> Helper macros for checking the return value of a common
-// legalization function that returns a single tensor.
-// Packs the result in a list.
-#define TOSA_REPLACE_LOWERED_OP(REWRITER, OP, LOWERED_OP)   \
-  if (LOWERED_OP) {                                         \
-    REWRITER.replaceOp((OP), {(LOWERED_OP)->getResults()}); \
-    return success();                                       \
-  } else {                                                  \
-    return failure();                                       \
-  }
-
 namespace mlir {
 
 namespace tosa {
@@ -118,13 +93,17 @@
   // Bias tensor that feeds into tosa.conv2d must be rank 1
   if (bias_shape.size() != 1) return failure();
 
-  auto lowered_op = convertTFConv2DCommon(
+  auto result = convertTFConv2DCommon(
       rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
       bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
       tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
       tf_conv2d_op.data_format());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 void FuseBiasTF::runOnFunction() {
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
index fc041dd..9f987ca 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
@@ -14,12 +14,12 @@
 ==============================================================================*/
 
 // This file contains legalizations common to mapping both TensorFlow and
-// TensorFlow Lite to TOSA.
+// TensorFlow Lite to TOSA. It operates generically on ops and does not have
+// a hard reference on either dialect.
 //
-// Conversion functions return nullptr on a lowerization failure or a
-// lowered operator on success.  Callers must check and return a
-// LogicalResult failure on nullptr.  Helper macros are provided in
-// legalize_common.h to canonicalize this handling.
+// Conversion functions return llvm::None on a legalization failure or a
+// legalized value on success.  Callers must check that the returned
+// llvm::Optional contains a value after each call.
 
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 
@@ -29,20 +29,21 @@
 #include <iterator>
 #include <numeric>
 
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
 
-// TODO for further work:
-// * It is better to return an llvm::Optional instead of an Operation*. It
-//   enables generic handling of some of the cases a bit better where
-//   we are doing different things with the ops.
-
 namespace mlir {
 namespace tosa {
 
 // Lowers the Pack operator to TOSA.
-Operation* convertPackOp(PatternRewriter& rewriter, Operation* op,
-                         Value result_value, SmallVector<Value, 8>& inputs,
-                         int32_t axis) {
+llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
+                                    Value result_value,
+                                    SmallVector<Value, 8>& inputs,
+                                    int32_t axis) {
   //////////////////////////////////////////////////
   // Operator: output = Pack([values], axis) or output = Stack([values], axis)
   // Lowering:
@@ -67,30 +68,32 @@
   // Sanity check 1: make sure all input tensors have the same shape
   // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
   // shape[A, B, C]
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
 
   // Check for ranked tensor type.
   if (!result_type) {
     op->emitOpError("PackOp: result type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
   // Valid axis in TF is [-rank(input), rank(input))
   // Valid axis in TOSA is [0, rank(input))
   // Plus rank(input) once if axis is negative.
-  auto input_type = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      op->getOperand(0).getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("PackOp: input type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_rank = input_type.getShape().size();
+  int32_t input_rank = input_type.getShape().size();
   if (axis < 0) axis += input_rank;
 
   input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("Input 0 type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
   ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
   int input_tensor_rank = input0_tensor_shape.size();
@@ -101,17 +104,17 @@
       op->emitOpError(llvm::formatv(
           "reduce axis {} is not in valid range [-rank(input), rank(input))",
           i));
-      return nullptr;
+      return llvm::None;
     }
     ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
     if (next_tensor_shape.size() != input_tensor_rank) {
       op->emitOpError("PackOp: input tensor rank mismatch.");
-      return nullptr;
+      return llvm::None;
     }
     for (int d = 0; d < input0_tensor_shape.size(); d++) {
       if (input0_tensor_shape[d] != next_tensor_shape[d]) {
         op->emitOpError("PackOp: input tensor shape mismatch.");
-        return nullptr;
+        return llvm::None;
       }
     }
   }
@@ -120,7 +123,7 @@
   // performing concat.
   if (input_tensor_rank == 0) {
     SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
-    auto reshape_rank1_size1_type =
+    RankedTensorType reshape_rank1_size1_type =
         RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
                               result_type.getElementType());
     ArrayAttr shape_rank1_size1_attr =
@@ -141,7 +144,7 @@
 
   if (axis > (input_tensor_rank + 1)) {
     op->emitOpError("PackOp: axis out of valid range.");
-    return nullptr;
+    return llvm::None;
   }
 
   // Sanity check 2: if input shape is [A, B, C], output shape should be [N,
@@ -151,12 +154,12 @@
                                             result_type.getShape().end());
   if (output_shape_vals.size() != (input_tensor_rank + 1)) {
     op->emitOpError("PackOp: output tensor rank mismatch.");
-    return nullptr;
+    return llvm::None;
   }
   // 2.b check output rank 0 is N
   if (output_shape_vals[axis] != inputs.size()) {
     op->emitOpError("PackOp: output tensor shape mismatch.");
-    return nullptr;
+    return llvm::None;
   }
   // Most of the cases when PackOp.axis() is within [0, rank(input) - 1].
   // We can directly concatenate along that axis and perform the reshape.
@@ -212,7 +215,7 @@
   }
 
   concat_output_shape[concat_axis] = orig_input_dim_on_axis * 2;
-  auto concat_type = RankedTensorType::get(
+  RankedTensorType concat_type = RankedTensorType::get(
       ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
   auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
       op->getLoc(), concat_type, inputs[0], inputs[1], concat_axis_attr);
@@ -228,42 +231,39 @@
                                                    inputs[i], concat_axis_attr);
   }
 
-  Operation* lowered_op = nullptr;
   // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
   // are reshaped beforehand.
-  if (input_tensor_rank == 0) {
-    lowered_op = a1_concat_op;
-  } else {
-    // Reshape [N * A, B, C] to [N, A, B, C].
-    auto reshape_output_type = RankedTensorType::get(
-        ArrayRef<int64_t>(reshape_output_shape), result_type.getElementType());
+  if (input_tensor_rank == 0) return a1_concat_op.getResult();
 
-    auto a2_reshape_op =
-        rewriter.create<tosa::ReshapeOp>(op->getLoc(), reshape_output_type,
-                                         a1_concat_op.getResult(), shape_attr);
+  // Reshape [N * A, B, C] to [N, A, B, C].
+  RankedTensorType reshape_output_type = RankedTensorType::get(
+      ArrayRef<int64_t>(reshape_output_shape), result_type.getElementType());
 
-    // If axis is equal to input tensor rank, then we need extra transpose
-    // [N, A, B, C] to [A, B, C, N]
-    if (axis == input_tensor_rank) {
-      auto a3_transpose_perm =
-          get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm);
-      auto a3_transpose_op = rewriter.create<tosa::TransposeOp>(
-          op->getLoc(), result_type, a2_reshape_op.getResult(),
-          a3_transpose_perm);
-      lowered_op = a3_transpose_op;
-    } else {
-      lowered_op = a2_reshape_op;
-    }
+  auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), reshape_output_type, a1_concat_op.getResult(), shape_attr);
+
+  // If axis is equal to input tensor rank, then we need extra transpose
+  // [N, A, B, C] to [A, B, C, N]
+  if (axis == input_tensor_rank) {
+    Value a3_transpose_perm =
+        get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm);
+
+    return rewriter
+        .create<tosa::TransposeOp>(op->getLoc(), result_type,
+                                   a2_reshape_op.getResult(), a3_transpose_perm)
+        .getResult();
   }
 
-  return lowered_op;
+  return a2_reshape_op.getResult();
 }
 
 // Lowers the Unpack operator to TOSA
-Operation* convertUnpackOp(PatternRewriter& rewriter, Operation* op,
-                           Value input_value, int32_t axis) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<ValueRange> convertUnpackOp(PatternRewriter& rewriter,
+                                           Operation* op, Value input_value,
+                                           int32_t axis) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   auto input_shape = input_type.getShape();
   int64_t input_rank = input_shape.size();
@@ -290,7 +290,7 @@
       perm_vec.push_back(i);
     }
 
-    auto a1_transpose_perm =
+    Value a1_transpose_perm =
         get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm_vec);
 
     for (int i = 0; i < input_rank; i++) {
@@ -310,9 +310,9 @@
   }
 
   // Step 2: slice [N, A, B, C] into N [A, B, C].
-  auto transposed_input_type =
+  RankedTensorType transposed_input_type =
       transposed_input_value.getType().dyn_cast<RankedTensorType>();
-  if (!transposed_input_type) return nullptr;
+  if (!transposed_input_type) return llvm::None;
 
   auto transposed_input_shape = transposed_input_type.getShape();
   int64_t transposed_input_rank = transposed_input_shape.size();
@@ -354,71 +354,73 @@
 
   // Combine the sequence of tosa.slice() ops into a list
   // using the IdentityN operator.
-  return rewriter.create<tosa::IdentityNOp>(
-      op->getLoc(), ArrayRef<Type>(outs_type_vec), results_vec);
+  return rewriter
+      .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
+                                 results_vec)
+      .getResults();
 }
 
 // Lowers the Select operator to TOSA.
-Operation* convertSelectOp(PatternRewriter& rewriter, Operation* op,
-                           Value result_value, Value condition_value,
-                           Value x_value, Value y_value) {
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
-  auto condition_type = condition_value.getType().dyn_cast<RankedTensorType>();
-  auto x_type = x_value.getType().dyn_cast<RankedTensorType>();
-  auto y_type = y_value.getType().dyn_cast<RankedTensorType>();
-
-  Operation* result_op = nullptr;
+llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
+                                      Value result_value, Value condition_value,
+                                      Value x_value, Value y_value) {
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType condition_type =
+      condition_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();
 
   if (!result_type || !condition_type || !x_type || !y_type) {
     op->emitOpError("Select: failed ranked tensor type check");
-    return nullptr;
+    return llvm::None;
   }
 
   // First check whether we need to reshape the condition to match
   // the same rank as the then/else clauses.
   if (result_type.getRank() == condition_type.getRank()) {
     // Nothing to reshape.
-    result_op = rewriter.create<tosa::SelectOp>(
-        op->getLoc(), result_type, condition_value, x_value, y_value);
-  } else {
-    // Need to reshape the condition.
-    SmallVector<int64_t, 8> new_cond_dims;
-    for (int i = 0; i < (result_type.getRank() - condition_type.getRank());
-         i++) {
-      new_cond_dims.push_back(1);
-    }
-    for (int i = 0; i < condition_type.getRank(); i++) {
-      new_cond_dims.push_back(condition_type.getShape()[i]);
-    }
-
-    auto reshape_op = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(ArrayRef<int64_t>(new_cond_dims),
-                              condition_type.getElementType()),
-        condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
-
-    auto new_select = rewriter.create<tosa::SelectOp>(
-        op->getLoc(), result_type, reshape_op, x_value, y_value);
-    result_op = new_select;
+    return rewriter
+        .create<tosa::SelectOp>(op->getLoc(), result_type, condition_value,
+                                x_value, y_value)
+        .getResult();
   }
 
-  return result_op;
+  // Need to reshape the condition.
+  SmallVector<int64_t, 8> new_cond_dims(
+      result_type.getRank() - condition_type.getRank(), 1);
+
+  for (int i = 0; i < condition_type.getRank(); i++) {
+    new_cond_dims.push_back(condition_type.getShape()[i]);
+  }
+
+  auto reshape_op = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(),
+      RankedTensorType::get(ArrayRef<int64_t>(new_cond_dims),
+                            condition_type.getElementType()),
+      condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
+
+  return rewriter
+      .create<tosa::SelectOp>(op->getLoc(), result_type, reshape_op, x_value,
+                              y_value)
+      .getResult();
 }
 
 // Lowers the ZerosLike operator to TOSA by creating a constant
 // of the desired type and shape.
-Operation* convertZerosLikeOp(PatternRewriter& rewriter, Operation* op,
-                              Value result, Value input) {
-  auto result_type = result.getType().dyn_cast<RankedTensorType>();
+llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
+                                         Operation* op, Value result,
+                                         Value input) {
+  RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     op->emitOpError("Zeroslike: result not ranked tensor type");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("Zeroslike: input not ranked tensor type");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_shape = input_type.getShape();
@@ -427,20 +429,26 @@
       RankedTensorType::get(input_shape, input_type.getElementType());
   Attribute zero_attr = rewriter.getZeroAttr(zero_type);
 
-  return rewriter.create<tosa::ConstOp>(op->getLoc(), zero_type,
-                                        zero_attr.cast<ElementsAttr>());
+  return rewriter
+      .create<tosa::ConstOp>(op->getLoc(), zero_type,
+                             zero_attr.cast<ElementsAttr>())
+      .getResult();
 }
 
 // Lowers the Mul operator to TOSA.  For quantized types, this requires
 // inserting rescale operators before and after the operation.
-Operation* convertMultiplyOp(PatternRewriter& rewriter, Operation* op,
-                             Value output_val, Value input_lhs_val,
-                             Value input_rhs_val) {
-  auto input_lhs_type = input_lhs_val.getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type = input_rhs_val.getType().dyn_cast<RankedTensorType>();
-  auto output_type = output_val.getType().dyn_cast<RankedTensorType>();
+llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
+                                        Operation* op, Value output_val,
+                                        Value input_lhs_val,
+                                        Value input_rhs_val) {
+  RankedTensorType input_lhs_type =
+      input_lhs_val.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_rhs_type =
+      input_rhs_val.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      output_val.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
-  if (!input_lhs_type || !input_rhs_type || !output_type) return nullptr;
+  if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;
 
   bool input_lhs_is_qtype =
       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
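
For the quantized path in the later hunks of this function, the effective multiplier folds the operand and result scales together as output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale. A quick worked example with illustrative scales: for input scales 0.02 and 0.05 and an output scale of 0.004, the final rescale factor is 0.02 * 0.05 / 0.004 = 0.25.
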
@@ -454,12 +462,12 @@
     op->emitOpError(
         "ConvertMultiplyOp: input/output tensor should "
         "be all quantized or all floating-point");
-    return nullptr;
+    return llvm::None;
   }
 
   Value output;
   if (output_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
     auto input_lhs_qtype = input_lhs_type.getElementType()
                                .cast<mlir::quant::UniformQuantizedType>();
@@ -467,99 +475,103 @@
                                .cast<mlir::quant::UniformQuantizedType>();
     auto output_qtype =
         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
-
     double in_lhs_scale = input_lhs_qtype.getScale();
     double in_rhs_scale = input_rhs_qtype.getScale();
     double output_scale = output_qtype.getScale();
 
     double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
 
-    auto op1_rescale_lhs = buildRescaleToInt32(
+    Value op1_rescale_lhs = buildRescaleToInt32(
         rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs = buildRescaleToInt32(
+    Value op2_rescale_rhs = buildRescaleToInt32(
         rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
     auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
-    auto op4_rescale_op3 = buildRescaleFromInt32(
+    return buildRescaleFromInt32(
         rewriter, op, output_type, op3_mul_op1_op2.getResult(),
         output_rescale_scale, output_qtype.getZeroPoint());
-    output = op4_rescale_op3;
-  } else {
-    auto op1_mul_in = rewriter.create<tosa::MulOp>(
-        op->getLoc(), output_type, input_lhs_val, input_rhs_val, 0);
-
-    output = op1_mul_in.getResult();
   }
 
-  return output.getDefiningOp();
+  return rewriter
+      .create<tosa::MulOp>(op->getLoc(), output_type, input_lhs_val,
+                           input_rhs_val, 0)
+      .getResult();
 }
 
 // Lowers the SquaredDifference operator to TOSA.
-Operation* convertSquaredDifferenceOp(PatternRewriter& rewriter, Operation* op,
-                                      Value result, Value x, Value y) {
+llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
+                                                 Operation* op, Value result,
+                                                 Value x, Value y) {
   // Squared-difference is (x-y)*(x-y).
   // This lowering calculates the difference and multiplies.
-  auto result_type = result.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     op->emitOpError("SquaredDifference: result not ranked tensor type");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto x_type = x.getType().dyn_cast<RankedTensorType>();
-  auto y_type = y.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType x_type = x.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType y_type = y.getType().dyn_cast<RankedTensorType>();
   if (!x_type || !y_type) {
     op->emitOpError("SquaredDifference: inputs not ranked tensor type");
-    return nullptr;
+    return llvm::None;
   }
 
   auto sub_op = rewriter.create<tosa::SubOp>(op->getLoc(), result_type, x, y);
-  return rewriter.create<tosa::MulOp>(
-      op->getLoc(), result_type, sub_op.getResult(), sub_op.getResult(), 0);
+  return rewriter
+      .create<tosa::MulOp>(op->getLoc(), result_type, sub_op.getResult(),
+                           sub_op.getResult(), 0)
+      .getResult();
 }
 
 // Lowers the Round operator to TOSA.
-Operation* convertRoundOp(PatternRewriter& rewriter, Operation* op,
-                          Value result, Value input) {
+llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
+                                     Value result, Value input) {
   // Implements banker's rounding by calculating floor(input + 0.5).
-  auto result_type = result.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     op->emitOpError("Round: result not ranked tensor type");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("Round: input not ranked tensor type");
-    return nullptr;
+    return llvm::None;
   }
 
   auto add_op = rewriter.create<tosa::AddOp>(
       op->getLoc(), result_type, input,
       getTosaConstTensorSingleF32(rewriter, op, 0.5));
-  return rewriter.create<tosa::FloorOp>(op->getLoc(), result_type,
-                                        add_op.getResult());
+
+  return rewriter
+      .create<tosa::FloorOp>(op->getLoc(), result_type, add_op.getResult())
+      .getResult();
 }
 
 // Lowers ConcatV2 to TOSA.
-Operation* convertConcatV2Op(PatternRewriter& rewriter, Operation* op,
-                             Value result_value, SmallVector<Value, 8>& values,
-                             int32_t axis) {
+llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        SmallVector<Value, 8>& values,
+                                        int32_t axis) {
   // ConcatV2 becomes a series of TOSA Concat operators that take pairs of
   // tensors as arguments.   Rank-0 tensors are reshaped to Rank-1,
   // shape (1,) tensors.
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     op->emitOpError("ConcatV2Op: result type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   // Valid axis in TF is [-rank(input), rank(input)).
   // Valid axis in TOSA is [0, rank(input)).
   // Plus rank(input) once if axis is negative.
-  auto input_type = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      op->getOperand(0).getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("ConcatV2Op: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_rank = input_type.getShape().size();
@@ -571,13 +583,13 @@
   if (!values[0].getType().dyn_cast<RankedTensorType>() ||
       !values[1].getType().dyn_cast<RankedTensorType>()) {
     op->emitOpError("ConcatV2Op: value type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   Value lhs_val = values[0];
   Value rhs_val = values[1];
-  auto lhs_type = lhs_val.getType().cast<RankedTensorType>();
-  auto rhs_type = rhs_val.getType().cast<RankedTensorType>();
+  RankedTensorType lhs_type = lhs_val.getType().cast<RankedTensorType>();
+  RankedTensorType rhs_type = rhs_val.getType().cast<RankedTensorType>();
   ArrayRef<int64_t> lhs_tensor_shape = lhs_type.getShape();
   ArrayRef<int64_t> rhs_tensor_shape = rhs_type.getShape();
   int input_tensor_rank = lhs_tensor_shape.size();
@@ -593,10 +605,10 @@
   if (input_tensor_rank == 0) {
     if (axis != 0) {
       op->emitOpError("ConcatV2Op: axis invalid.");
-      return nullptr;
+      return llvm::None;
     }
     SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
-    auto reshape_rank1_size1_type =
+    RankedTensorType reshape_rank1_size1_type =
         RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
                               result_type.getElementType());
     ArrayAttr shape_rank1_size1_attr =
@@ -611,7 +623,7 @@
   } else {
     if (axis < 0 || axis >= input_tensor_rank) {
       op->emitOpError("ConcatV2Op: axis invalid.");
-      return nullptr;
+      return llvm::None;
     }
     for (int i = 0; i < input_tensor_rank; i++) {
       concat_result_shape.push_back(lhs_tensor_shape[i]);
@@ -619,7 +631,7 @@
     concat_result_shape[axis] = lhs_tensor_shape[axis] + rhs_tensor_shape[axis];
   }
 
-  auto concat_type = RankedTensorType::get(
+  RankedTensorType concat_type = RankedTensorType::get(
       ArrayRef<int64_t>(concat_result_shape), result_type.getElementType());
 
   mlir::quant::UniformQuantizedType lhs_quant_type =
@@ -651,22 +663,22 @@
 
     // Rescale input if scale is not equal to output tensor scale.
     if (lhs_scale != result_scale) {
-      auto rescale_type =
+      RankedTensorType rescale_type =
           RankedTensorType::get(lhs_type.getShape(), result_quant_type);
 
-      auto rescale_op = buildRescale(rewriter, op, rescale_type, lhs_val,
-                                     lhs_scale / result_scale, lhs_zeropoint,
-                                     result_zeropoint);
+      Value rescale_op = buildRescale(rewriter, op, rescale_type, lhs_val,
+                                      lhs_scale / result_scale, lhs_zeropoint,
+                                      result_zeropoint);
 
       lhs_val = rescale_op;
     }
     if (rhs_scale != result_scale) {
-      auto rescale_type =
+      RankedTensorType rescale_type =
           RankedTensorType::get(rhs_type.getShape(), result_quant_type);
 
-      auto rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
-                                     rhs_scale / result_scale, rhs_zeropoint,
-                                     result_zeropoint);
+      Value rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
+                                      rhs_scale / result_scale, rhs_zeropoint,
+                                      result_zeropoint);
 
       rhs_val = rescale_op;
     }
@@ -696,12 +708,12 @@
       rhs_zeropoint = rhs_quant_type.getZeroPoint();
 
       if (rhs_scale != result_scale) {
-        auto rescale_type =
+        RankedTensorType rescale_type =
             RankedTensorType::get(rhs_type.getShape(), result_quant_type);
 
-        auto rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
-                                       rhs_scale / result_scale, rhs_zeropoint,
-                                       result_zeropoint);
+        Value rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
+                                        rhs_scale / result_scale, rhs_zeropoint,
+                                        result_zeropoint);
 
         rhs_val = rescale_op;
       }
@@ -712,14 +724,15 @@
         rewriter.getI64IntegerAttr(axis));
   }
 
-  return concat_op;
+  return concat_op.getResult();
 }
 
 // Lowers SpaceToBatchND to TOSA.
-Operation* convertSpaceToBatchNDOp(PatternRewriter& rewriter, Operation* op,
-                                   Value result_value, Value input_value,
-                                   Value block_shape_value,
-                                   Value paddings_value) {
+llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
+                                              Operation* op, Value result_value,
+                                              Value input_value,
+                                              Value block_shape_value,
+                                              Value paddings_value) {
   /////////////////////////////////////////////////
   // Operator: output = SpaceToBatchND(input, block_shape, paddings)
   // Lowering:
@@ -773,28 +786,31 @@
   //  shape=a3_shape)
   //
 
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  auto block_shape_type =
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType block_shape_type =
       block_shape_value.getType().dyn_cast<RankedTensorType>();
-  auto paddings_type = paddings_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType paddings_type =
+      paddings_value.getType().dyn_cast<RankedTensorType>();
 
   // Not a ranked tensor output.
   if (!result_type) {
     op->emitOpError("SpaceToBatchND: result type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
   if (!input_type) {
     op->emitOpError("SpaceToBatchND: input type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
   if (!block_shape_type) {
     op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
   if (!paddings_type) {
     op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
   // Follow implementation in
@@ -818,10 +834,10 @@
   ElementsAttr paddings_elems;
 
   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
-    return nullptr;
+    return llvm::None;
 
   if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
-    return nullptr;
+    return llvm::None;
 
   SmallVector<int32_t, 2> a0_pad_const(2 * (input_rank));
   SmallVector<int64_t, 2> padded_shape(input_rank);
@@ -863,7 +879,7 @@
     padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
   }
 
-  auto a0_pad_const_attr_type =
+  RankedTensorType a0_pad_const_attr_type =
       RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
 
   // Create a const op to generate the tensor type for the input padding array
@@ -936,7 +952,7 @@
     a3_transpose_shape[i] = a2_shape[a3_perm[i]];
   }
 
-  auto a3_transpose_const =
+  Value a3_transpose_const =
       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a3_perm);
 
   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
@@ -972,15 +988,19 @@
     a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
   }
 
-  return rewriter.create<tosa::ReshapeOp>(
-      op->getLoc(), result_type, a3_transpose_a2_op.getResult(),
-      rewriter.getI64ArrayAttr(a4_reshape_shape));
+  return rewriter
+      .create<tosa::ReshapeOp>(op->getLoc(), result_type,
+                               a3_transpose_a2_op.getResult(),
+                               rewriter.getI64ArrayAttr(a4_reshape_shape))
+      .getResult();
 }
 
 // Lowers BatchToSpaceND to TOSA.
-Operation* convertBatchToSpaceNDOp(PatternRewriter& rewriter, Operation* op,
-                                   Value result_value, Value input_value,
-                                   Value block_shape_value, Value crops_value) {
+llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
+                                              Operation* op, Value result_value,
+                                              Value input_value,
+                                              Value block_shape_value,
+                                              Value crops_value) {
   /////////////////////////////////////////////////
   // Operator: output = BatchToSpaceND(input, block_shape, clips)
   // Lowering:
@@ -1027,27 +1047,30 @@
   // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
   // size=a4_size)
 
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  auto block_shape_type =
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType block_shape_type =
       block_shape_value.getType().dyn_cast<RankedTensorType>();
-  auto crops_type = crops_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType crops_type =
+      crops_value.getType().dyn_cast<RankedTensorType>();
 
   if (!result_type) {
     op->emitOpError("BatchToSpaceND: result type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
   if (!input_type) {
     op->emitOpError("BatchToSpaceND: input type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
   if (!block_shape_type) {
     op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
   if (!crops_type) {
     op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
   // Another 4-step process
@@ -1062,12 +1085,12 @@
 
   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
     op->emitOpError("BatchToSpaceND: block_shape not a constant");
-    return nullptr;
+    return llvm::None;
   }
 
   if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
     op->emitOpError("BatchToSpaceND: crops not a constant");
-    return nullptr;
+    return llvm::None;
   }
 
   SmallVector<int64_t, 4> block_shape(block_rank);
@@ -1148,7 +1171,7 @@
     a2_transpose_shape[i] = a1_shape[a2_perm[i]];
   }
 
-  auto a2_transpose_perm =
+  Value a2_transpose_perm =
       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a2_perm);
   auto a2_transpose_a1_op = rewriter.create<tosa::TransposeOp>(
       op->getLoc(),
@@ -1201,36 +1224,40 @@
     }
   }
 
-  return rewriter.create<tosa::SliceOp>(
-      op->getLoc(),
-      RankedTensorType::get(ArrayRef<int64_t>(a4_size_vals),
-                            result_type.getElementType()),
-      a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
-      rewriter.getI64ArrayAttr(a4_size_vals));
+  return rewriter
+      .create<tosa::SliceOp>(
+          op->getLoc(),
+          RankedTensorType::get(ArrayRef<int64_t>(a4_size_vals),
+                                result_type.getElementType()),
+          a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
+          rewriter.getI64ArrayAttr(a4_size_vals))
+      .getResult();
 }
 
 // Lowers ExpandDims to TOSA.
-Operation* convertExpandDimsOp(PatternRewriter& rewriter, Operation* op,
-                               Value result_value, Value input_value,
-                               Value dim_value) {
+llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
+                                          Operation* op, Value result_value,
+                                          Value input_value, Value dim_value) {
   // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) {
     op->emitOpError("ExpandDims: output type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("ExpandDims: input type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_shape = input_type.getShape();
 
   ElementsAttr dim_elem;
-  if (!matchPattern(dim_value, m_Constant(&dim_elem))) return nullptr;
+  if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;
 
   assert(dim_elem.getType().getRank() == 0 && "expected scalar tensor");
   int32_t dim = dim_elem.getValue<IntegerAttr>({}).getInt();
@@ -1253,27 +1280,31 @@
 
   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
 
-  return rewriter.create<tosa::ReshapeOp>(op->getLoc(), output_type,
-                                          input_value, shape_attr);
+  return rewriter
+      .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
+                               shape_attr)
+      .getResult();
 }
 
 // Lowers Squeeze to TOSA.
-Operation* convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
-                            Value result_value, Value input_value,
-                            SmallVector<int32_t, 8>& squeeze_dims) {
+llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
+                                       Value result_value, Value input_value,
+                                       SmallVector<int32_t, 8>& squeeze_dims) {
   // Lowers to a reshape op where dimensions in squeeze_dims with size=1
   // are removed.
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) {
     op->emitOpError("Squeeze: output type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("Squeeze: input type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_shape = input_type.getShape();
@@ -1309,13 +1340,15 @@
 
   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
 
-  return rewriter.create<tosa::ReshapeOp>(op->getLoc(), output_type,
-                                          input_value, shape_attr);
+  return rewriter
+      .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
+                               shape_attr)
+      .getResult();
 }
 
 // Lowers ELU to a sequence of TOSA ops.
-Operation* convertEluOp(PatternRewriter& rewriter, Operation* op,
-                        Value result_value, Value features_value) {
+llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
+                                   Value result_value, Value features_value) {
   // Lowers Elu using the following formula:
   // elu(x) = x < 0 ? (exp(x) - 1) : x
   // one = const({1});
@@ -1326,24 +1359,22 @@
   // a2 = sub(a1, one_bcast)
   // a3 = ge(x, zero_bcast)
   // a4 = select(a3, x, a2)
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) {
     op->emitOpError("Elu: output type not ranked tensor");
-    return nullptr;
+    return llvm::None;
   }
 
   int32_t input_rank = output_type.getShape().size();
-  SmallVector<int64_t, 4> bcast_shape;
-  for (int i = 0; i < input_rank; i++) {
-    bcast_shape.push_back(1);
-  }
+  SmallVector<int64_t, 4> bcast_shape(input_rank, 1);
 
   // Can't directly create size=1, rank=rank(input) tensor because
   // it will be optimized out.  Instead, create rank0 tensor and reshape later.
-  auto one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
+  Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
 
-  auto zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
+  Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
 
   auto a1_exp_in_op =
       rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, features_value);
@@ -1356,14 +1387,16 @@
       RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
       features_value, zero_const_op);
 
-  return rewriter.create<tosa::SelectOp>(
-      op->getLoc(), output_type, a3_ge_in_zero_op.getResult(), features_value,
-      a2_sub_a1_one_op.getResult());
+  return rewriter
+      .create<tosa::SelectOp>(op->getLoc(), output_type,
+                              a3_ge_in_zero_op.getResult(), features_value,
+                              a2_sub_a1_one_op.getResult())
+      .getResult();
 }
 
 // Lowers Softmax to a sequence of TOSA ops.
-Operation* convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
-                            Value result_value, Value logits_value) {
+llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
+                                       Value result_value, Value logits_value) {
   // softmax = exp(logits) / reduce_sum(exp(logits), -1)
   //
   // or equivalently multiply exp(-max(logits)) to both numerator and
@@ -1375,13 +1408,15 @@
   // We'll use the first version for direct fp lowering, and the second version
   // for quantized lowering, since the second lets us restrict the input to
   // exp() to be negative, and thus the LUT can always be within [0.0, 1.0].
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
-  auto input_type = logits_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      logits_value.getType().dyn_cast<RankedTensorType>();
 
   // Not a ranked tensor input/output
   if (!output_type || !input_type) {
     op->emitOpError("Softmax: input and result not ranked tensors");
-    return nullptr;
+    return llvm::None;
   }
 
   // reduce_sum on last dimension
@@ -1403,17 +1438,17 @@
     auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
         true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
         -32768, 32767);
-    auto int16_logits_type =
+    RankedTensorType int16_logits_type =
         RankedTensorType::get(logits_shape, int16_element_qtype);
-    auto int32_logits_type =
+    RankedTensorType int32_logits_type =
         RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
-    auto int16_rsum_type =
+    RankedTensorType int16_rsum_type =
         RankedTensorType::get(rsum_shape, int16_element_qtype);
-    auto int32_rsum_type =
+    RankedTensorType int32_rsum_type =
         RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
 
     // Step 1. get x - max(x)
-    auto op1_rescale_in =
+    Value op1_rescale_in =
         buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
                      in_quant_type.getZeroPoint(), 0);
 
@@ -1434,10 +1469,10 @@
       return std::lround(32768.0 * v);
     };
 
-    auto exp_table_const = getTosa1DConstTensorTable(rewriter, op, exp_func);
+    Value exp_table_const = getTosa1DConstTensorTable(rewriter, op, exp_func);
 
     // Step 2. rescale input
-    auto op4_rescale_op3 = buildRescale(
+    Value op4_rescale_op3 = buildRescale(
         rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
         in_quant_type.getScale() * 128.0 / exp_sample_grain, 0, 0);
 
@@ -1500,7 +1535,7 @@
       return std::lround(32768.0 * v);
     };
 
-    auto one_over_one_plus_x_table_const =
+    Value one_over_one_plus_x_table_const =
         getTosa1DConstTensorTable(rewriter, op, one_over_one_plus_x_func);
 
     auto op14_table_op13 = rewriter.create<tosa::TableOp>(
@@ -1508,11 +1543,11 @@
         one_over_one_plus_x_table_const);
 
     // Rescale sum(exp(x)) from 0.23 back to 0.16
-    auto op15_rescale_op14 = buildRescale(rewriter, op, int32_rsum_type,
-                                          op14_table_op13, 1.0 / 128.0, 0, 0);
+    Value op15_rescale_op14 = buildRescale(rewriter, op, int32_rsum_type,
+                                           op14_table_op13, 1.0 / 128.0, 0, 0);
 
     // Rescale exp(x) from 0.23 back to 0.16
-    auto op16_rescale_op5 =
+    Value op16_rescale_op5 =
         buildRescale(rewriter, op, int32_logits_type, op5_table_op4.getResult(),
                      1.0 / 128.0, 0, 0);
 
@@ -1529,12 +1564,10 @@
 
     // Step 7. output scaling, extra 1.0 / 256.0 since we keep extra 8 bits
     // in op9_sub_op8
-    auto op19_rescale_op18 = buildRescale(
-        rewriter, op, output_type, op18_rshift_op17_op9.getResult(),
-        1.0 / (out_quant_type.getScale() * 256.0), 0,
-        out_quant_type.getZeroPoint());
-
-    return op19_rescale_op18.getDefiningOp();
+    return buildRescale(rewriter, op, output_type,
+                        op18_rshift_op17_op9.getResult(),
+                        1.0 / (out_quant_type.getScale() * 256.0), 0,
+                        out_quant_type.getZeroPoint());
 
   } else {
     SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
@@ -1550,7 +1583,7 @@
     // op4 = mul(op1, op3)
     auto op1_exp_in =
         rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
-    auto rsum_type =
+    RankedTensorType rsum_type =
         RankedTensorType::get(rsum_shape, output_type.getElementType());
 
     // Keep dims so we don't need to reshape later
@@ -1560,15 +1593,17 @@
     auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
         op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
 
-    return rewriter.create<tosa::MulOp>(op->getLoc(), output_type,
-                                        op1_exp_in.getResult(),
-                                        op3_reciprocal_op2.getResult(), 0);
+    return rewriter
+        .create<tosa::MulOp>(op->getLoc(), output_type, op1_exp_in.getResult(),
+                             op3_reciprocal_op2.getResult(), 0)
+        .getResult();
   }
 }
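
For reference, a self-contained sketch of the floating-point path above (exp, reduce_sum over the last axis, reciprocal, mul); illustrative only and not part of this change:

    #include <cmath>
    #include <vector>

    // softmax(x) = exp(x) / sum(exp(x)), reduced over the last (here: only) dimension.
    std::vector<float> SoftmaxRef(const std::vector<float>& logits) {
      std::vector<float> exps(logits.size());
      float sum = 0.0f;
      for (size_t i = 0; i < logits.size(); ++i) {
        exps[i] = std::exp(logits[i]);   // tosa.exp
        sum += exps[i];                  // tosa.reduce_sum
      }
      const float recip = 1.0f / sum;    // tosa.reciprocal
      for (float& e : exps) e *= recip;  // tosa.mul
      return exps;
    }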
 
 // Lowers LogSoftmax to a sequence of TOSA ops.
-Operation* convertLogSoftmaxOp(PatternRewriter& rewriter, Operation* op,
-                               Value result_value, Value logits_value) {
+llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
+                                          Operation* op, Value result_value,
+                                          Value logits_value) {
   // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
   // op1 = exp(logits)
   // op2 = reduce_sum(op1, -1)
@@ -1576,17 +1611,19 @@
   // op4 = mul(op1, op3)
   // op5 = log(op4)
 
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) {
     op->emitOpError("LogSoftmax: output type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      op->getOperand(0).getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("LogSoftmax: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   mlir::quant::UniformQuantizedType in_quant_type =
@@ -1597,7 +1634,7 @@
           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
   if (in_quant_type || out_quant_type) {
     op->emitOpError("Quantized log_softmax lowering not implemented yet");
-    return nullptr;
+    return llvm::None;
   }
 
   auto op1_exp_in =
@@ -1608,8 +1645,8 @@
   SmallVector<int64_t, 4> rsum_shape(output_type.getShape().begin(),
                                      output_type.getShape().end());
   rsum_shape[input_rank - 1] = 1;
-  auto rsum_type = RankedTensorType::get(ArrayRef<int64_t>(rsum_shape),
-                                         output_type.getElementType());
+  RankedTensorType rsum_type = RankedTensorType::get(
+      ArrayRef<int64_t>(rsum_shape), output_type.getElementType());
   // Keep dims so we don't need to reshape later
   auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
       op->getLoc(), rsum_type, op1_exp_in.getResult(),
@@ -1621,15 +1658,18 @@
       op->getLoc(), output_type, op1_exp_in.getResult(),
       op3_reciprocal_op2.getResult(), 0);
 
-  return rewriter.create<tosa::LogOp>(op->getLoc(), output_type,
-                                      op4_mul_op1_op3.getResult());
+  return rewriter
+      .create<tosa::LogOp>(op->getLoc(), output_type,
+                           op4_mul_op1_op3.getResult())
+      .getResult();
 }
 
 // Lowers SpaceToDepth to a sequence of TOSA ops.  Supports NHWC.
-Operation* convertSpaceToDepthOp(PatternRewriter& rewriter, Operation* op,
-                                 Value result_value, Value input_value,
-                                 IntegerAttr block_size_attr,
-                                 StringAttr data_format) {
+llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
+                                            Operation* op, Value result_value,
+                                            Value input_value,
+                                            IntegerAttr block_size_attr,
+                                            StringAttr data_format) {
   // NHWC lowering version:
   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
   // b, orig_shape[3]])
@@ -1637,30 +1677,32 @@
   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
   // orig_shape[3]*b*b])
   // return a4
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
 
   // Not a ranked tensor output.
   if (!output_type) {
     op->emitOpError("SpaceToDepth: output type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("SpaceToDepth: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   if (input_type.getRank() != 4) {
     op->emitOpError("SpaceToDepth: input rank not 4.");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_shape = input_type.getShape();
 
   if (!block_size_attr) {  // This is a required parameter
     op->emitOpError("SpaceToDepth: block size attribute not set.");
-    return nullptr;
+    return llvm::None;
   }
 
   SmallVector<int64_t, 2> block_size;
@@ -1670,7 +1712,7 @@
 
   if (data_format.getValue().str() != "NHWC") {
     op->emitOpError("SpaceToDepth: data format not NHWC.");
-    return nullptr;
+    return llvm::None;
   }
 
   assert(block_size[0] * block_size[1] != 0);
@@ -1683,13 +1725,13 @@
   a_reshape_dims.push_back(block_size[1]);
   a_reshape_dims.push_back(input_shape[3]);
 
-  auto a_reshape_output_type = RankedTensorType::get(
+  RankedTensorType a_reshape_output_type = RankedTensorType::get(
       ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
   auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
       op->getLoc(), a_reshape_output_type, input_value,
       rewriter.getI64ArrayAttr(a_reshape_dims));
 
-  auto a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
+  Value a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
       rewriter, op, {0, 1, 3, 2, 4, 5});
 
   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
@@ -1702,18 +1744,21 @@
   a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
   a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
 
-  auto a3_reshape_output_type = RankedTensorType::get(
+  RankedTensorType a3_reshape_output_type = RankedTensorType::get(
       ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
-  return rewriter.create<tosa::ReshapeOp>(
-      op->getLoc(), a3_reshape_output_type, a3_transpose_a2_op.getResult(),
-      rewriter.getI64ArrayAttr(a3_reshape_dims));
+  return rewriter
+      .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
+                               a3_transpose_a2_op.getResult(),
+                               rewriter.getI64ArrayAttr(a3_reshape_dims))
+      .getResult();
 }
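
An element-level sketch of what the reshape/transpose/reshape sequence above computes for NHWC input (illustrative only; assumes H and W are divisible by the block size b):

    #include <cstdint>
    #include <vector>

    // NHWC SpaceToDepth: each b x b spatial block becomes b*b*C output channels.
    std::vector<float> SpaceToDepthRef(const std::vector<float>& in, int64_t N,
                                       int64_t H, int64_t W, int64_t C,
                                       int64_t b) {
      std::vector<float> out(in.size());
      for (int64_t n = 0; n < N; ++n)
        for (int64_t h = 0; h < H; ++h)
          for (int64_t w = 0; w < W; ++w)
            for (int64_t c = 0; c < C; ++c) {
              // Output channel index after flattening the trailing [b, b, C] dims.
              int64_t oc = ((h % b) * b + (w % b)) * C + c;
              int64_t out_idx =
                  ((n * (H / b) + h / b) * (W / b) + w / b) * (C * b * b) + oc;
              int64_t in_idx = ((n * H + h) * W + w) * C + c;
              out[out_idx] = in[in_idx];
            }
      return out;
    }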
 
 // Lowers DepthToSpace to a sequence of TOSA ops.  Supports NHWC.
-Operation* convertDepthToSpaceOp(PatternRewriter& rewriter, Operation* op,
-                                 Value result_value, Value input_value,
-                                 IntegerAttr block_size_attr,
-                                 StringAttr data_format) {
+llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
+                                            Operation* op, Value result_value,
+                                            Value input_value,
+                                            IntegerAttr block_size_attr,
+                                            StringAttr data_format) {
   // NHWC version
   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
   // orig_shape[3] // (b*b)])
@@ -1722,26 +1767,28 @@
   // orig_shape[3] // (b*b)])
   // return a4
 
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
 
   // Not a ranked tensor output
   if (!output_type) {
     op->emitOpError("DepthToSpace: output type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("DepthToSpace: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
-  if (input_type.getRank() != 4) return nullptr;
+  if (input_type.getRank() != 4) return llvm::None;
   auto input_shape = input_type.getShape();
 
   if (!block_size_attr) {  // This is a required parameter
     op->emitOpError("DepthToSpace: block size attribute not set.");
-    return nullptr;
+    return llvm::None;
   }
 
   SmallVector<int64_t, 2> block_size;
@@ -1750,7 +1797,7 @@
   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
   if (data_format.getValue().str() != "NHWC") {
     op->emitOpError("DepthToSpace: data format not NHWC.");
-    return nullptr;
+    return llvm::None;
   }
 
   assert(block_size[0] * block_size[1] != 0);
@@ -1763,13 +1810,13 @@
   a_reshape_dims.push_back(block_size[1]);
   a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
 
-  auto a_reshape_output_type = RankedTensorType::get(
+  RankedTensorType a_reshape_output_type = RankedTensorType::get(
       ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
   auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
       op->getLoc(), a_reshape_output_type, input_value,
       rewriter.getI64ArrayAttr(a_reshape_dims));
 
-  auto a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
+  Value a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
       rewriter, op, {0, 1, 3, 2, 4, 5});
 
   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
@@ -1782,31 +1829,36 @@
   a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
   a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
 
-  auto a3_reshape_output_type = RankedTensorType::get(
+  RankedTensorType a3_reshape_output_type = RankedTensorType::get(
       ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
-  return rewriter.create<tosa::ReshapeOp>(
-      op->getLoc(), a3_reshape_output_type, a3_transpose_a2_op.getResult(),
-      rewriter.getI64ArrayAttr(a3_reshape_dims));
+  return rewriter
+      .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
+                               a3_transpose_a2_op.getResult(),
+                               rewriter.getI64ArrayAttr(a3_reshape_dims))
+      .getResult();
 }
 
 // Lowers Split to a sequence of TOSA ops.
-Operation* convertSplitOp(PatternRewriter& rewriter, Operation* op,
-                          Value result_value, Value input_value,
-                          int32_t num_split, int32_t axis) {
+llvm::Optional<ValueRange> convertSplitOp(PatternRewriter& rewriter,
+                                          Operation* op, Value result_value,
+                                          Value input_value, int32_t num_split,
+                                          int32_t axis) {
   // This lowering creates num_split slice ops and ties them together
   // with IdentityN to get from an array of Operations to a single Operation
   // with a list of result tensors.
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!result_type) {
     op->emitOpError("Split: output type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("Split: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_shape = input_type.getShape();
@@ -1854,28 +1906,34 @@
 
   // Combine the sequence of tosa.slice() ops into a list
   // using the IdentityN operator
-  return rewriter.create<tosa::IdentityNOp>(
-      op->getLoc(), ArrayRef<Type>(outs_type_vec), results_vec);
+  return rewriter
+      .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
+                                 results_vec)
+      .getResults();
 }
 
 // Lowers SplitV to a sequence of TOSA ops.
-Operation* convertSplitVOp(PatternRewriter& rewriter, Operation* op,
-                           Value result_value, Value input_value,
-                           SmallVector<int32_t, 4>& size_split, int32_t axis) {
+llvm::Optional<ValueRange> convertSplitVOp(PatternRewriter& rewriter,
+                                           Operation* op, Value result_value,
+                                           Value input_value,
+                                           SmallVector<int32_t, 4>& size_split,
+                                           int32_t axis) {
   // This lowering creates num_split slice ops and ties them together
   // with IdentityN to get from an array of Operations to a single Operation
   // with a list of result tensors.
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!result_type) {
     op->emitOpError("SplitV: output type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     op->emitOpError("SplitV: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   auto input_shape = input_type.getShape();
@@ -1931,18 +1989,18 @@
 
   // Combine the sequence of tosa.slice() ops into a list
   // using the IdentityN operator
-  return rewriter.create<tosa::IdentityNOp>(
-      op->getLoc(), ArrayRef<Type>(outs_type_vec), results_vec);
+  return rewriter
+      .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
+                                 results_vec)
+      .getResults();
 }
 
 // Lowers StridedSlice to a sequence of TOSA ops.
-Operation* convertStridedSliceOp(PatternRewriter& rewriter, Operation* op,
-                                 Value result_value, Value input_value,
-                                 Value begin_value, Value end_value,
-                                 Value strides_value, int32_t begin_mask,
-                                 int32_t end_mask, int32_t ellipsis_mask,
-                                 int32_t new_axis_mask,
-                                 int32_t shrink_axis_mask) {
+llvm::Optional<Value> convertStridedSliceOp(
+    PatternRewriter& rewriter, Operation* op, Value result_value,
+    Value input_value, Value begin_value, Value end_value, Value strides_value,
+    int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
+    int32_t new_axis_mask, int32_t shrink_axis_mask) {
   // The mask arguments are bitmasks where bit [i] applies to
   // dimension [i] of the input tensor.
   //
@@ -1972,17 +2030,19 @@
   // to insert tosa.Reverse operators for this.
   assert(ellipsis_mask == 0);
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
 
   if (!result_type) {
     op->emitOpError("StridedSlice: output type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   if (!input_type) {
     op->emitOpError("StridedSlice: input type not ranked tensor.");
-    return nullptr;
+    return llvm::None;
   }
 
   int32_t input_rank = input_type.getRank();
@@ -1993,15 +2053,15 @@
 
   if (getVectorFromValue32(begin_value, begin) != input_rank) {
     op->emitOpError("StridedSlice: begin doesn't match input_rank.");
-    return nullptr;
+    return llvm::None;
   }
   if (getVectorFromValue32(end_value, end) != input_rank) {
     op->emitOpError("StridedSlice: end doesn't match input_rank.");
-    return nullptr;
+    return llvm::None;
   }
   if (getVectorFromValue32(strides_value, strides) != input_rank) {
     op->emitOpError("StridedSlice: strides doesn't match input_rank.");
-    return nullptr;
+    return llvm::None;
   }
 
   SmallVector<int64_t, 2> a1_begin(input_rank), a1_size(input_rank);
@@ -2073,17 +2133,19 @@
       rewriter.getI64ArrayAttr(a3_size));
 
   // Step 4: reshape the now-strided tensor
-  return rewriter.create<tosa::ReshapeOp>(
-      op->getLoc(),
-      RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
-                            input_type.getElementType()),
-      a3_slice_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
+  return rewriter
+      .create<tosa::ReshapeOp>(
+          op->getLoc(),
+          RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
+                                input_type.getElementType()),
+          a3_slice_op.getResult(), rewriter.getI64ArrayAttr(a4_shape))
+      .getResult();
 }
 
 // Lowers FloorDiv to a sequence of TOSA operators.
-Operation* convertFloorDivOp(PatternRewriter& rewriter, Operation* op,
-                             Value result_value, Value lhs_value,
-                             Value rhs_value) {
+llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        Value lhs_value, Value rhs_value) {
   // FloorDiv lowering:
   // floor(1/rhs * lhs)
   //
@@ -2091,23 +2153,26 @@
   // a2 = mul(lhs, a1);
   // a3 = floor(a2);
   // return a3;
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
-  if (!output_type) return nullptr;
+  if (!output_type) return llvm::None;
 
   auto a1_reciprocal_rhs_op =
       rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
   auto a2_mul_lhs_a1_op =
       rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
                                    a1_reciprocal_rhs_op.getResult(), 0);
-  return rewriter.create<tosa::FloorOp>(op->getLoc(), output_type,
-                                        a2_mul_lhs_a1_op.getResult());
+  return rewriter
+      .create<tosa::FloorOp>(op->getLoc(), output_type,
+                             a2_mul_lhs_a1_op.getResult())
+      .getResult();
 }
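
Scalar sketch of the reciprocal/mul/floor sequence above (illustrative only, not part of this change):

    #include <cmath>

    float FloorDivRef(float lhs, float rhs) {
      float a1 = 1.0f / rhs;     // tosa.reciprocal
      float a2 = lhs * a1;       // tosa.mul
      return std::floor(a2);     // tosa.floor
    }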
 
 // Lowers FloorMod to a sequence of TOSA operators.
-Operation* convertFloorModOp(PatternRewriter& rewriter, Operation* op,
-                             Value result_value, Value lhs_value,
-                             Value rhs_value) {
+llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        Value lhs_value, Value rhs_value) {
   // FloorMod lowering:
   // (1/rhs * lhs) - floor(1/rhs * lhs)
   // a1 = reciprocal(rhs);
@@ -2116,9 +2181,10 @@
   // a4 = sub(a2, a3);
   // return a4;
 
-  auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
-  if (!output_type) return nullptr;
+  if (!output_type) return llvm::None;
 
   auto a1_reciprocal_rhs_op =
       rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
@@ -2127,17 +2193,20 @@
                                    a1_reciprocal_rhs_op.getResult(), 0);
   auto a3_floor_a2_op = rewriter.create<tosa::FloorOp>(
       op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
-  return rewriter.create<tosa::SubOp>(op->getLoc(), output_type,
-                                      a2_mul_lhs_a1_op.getResult(),
-                                      a3_floor_a2_op.getResult());
+  return rewriter
+      .create<tosa::SubOp>(op->getLoc(), output_type,
+                           a2_mul_lhs_a1_op.getResult(),
+                           a3_floor_a2_op.getResult())
+      .getResult();
 }
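
Likewise for FloorMod, which this lowering expresses as the fractional part of lhs/rhs (scalar sketch, illustrative only):

    #include <cmath>

    float FloorModRef(float lhs, float rhs) {
      float a1 = 1.0f / rhs;      // tosa.reciprocal
      float a2 = lhs * a1;        // tosa.mul
      float a3 = std::floor(a2);  // tosa.floor
      return a2 - a3;             // tosa.sub
    }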
 
 // Lowers FusedActivation to a sequence of TOSA ops.
-Operation* convertFusedActivation(PatternRewriter& rewriter, Operation* op,
-                                  Value input_value,
-                                  StringAttr fused_activation_fn) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
+                                             Operation* op, Value input_value,
+                                             StringAttr fused_activation_fn) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   bool input_is_qtype =
       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@@ -2149,16 +2218,18 @@
     if (fused_activation_fn.getValue() == "TANH") {
       // TODO: implement with TABLE
       op->emitWarning("Quantized TANH lowering TBD!");
-      return nullptr;
+      return llvm::None;
     } else {
-      auto rescale_type = RankedTensorType::get(input_type.getShape(),
-                                                rewriter.getIntegerType(32));
+      RankedTensorType rescale_type = RankedTensorType::get(
+          input_type.getShape(), rewriter.getIntegerType(32));
 
-      auto op1_rescale_in = buildRescaleToInt32(rewriter, op, input_value, 1.0f,
-                                                input_qtype.getZeroPoint());
+      Value op1_rescale_in = buildRescaleToInt32(
+          rewriter, op, input_value, 1.0f, input_qtype.getZeroPoint());
 
       Value op2_relu_op1;
-      if (fused_activation_fn.getValue() == "RELU") {
+      if (fused_activation_fn.getValue() == "NONE") {
+        return input_value;
+      } else if (fused_activation_fn.getValue() == "RELU") {
         auto relu_op = rewriter.create<tosa::ReluNOp>(
             op->getLoc(), rescale_type, op1_rescale_in,
             rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
@@ -2191,52 +2262,58 @@
 
         op2_relu_op1 = relu_op.getResult();
       } else {
-        return nullptr;
+        return llvm::None;
       }
 
-      auto op3_rescale_op2 =
-          buildRescaleFromInt32(rewriter, op, input_type, op2_relu_op1, 1.0f,
-                                input_qtype.getZeroPoint());
-
-      return op3_rescale_op2.getDefiningOp();
+      return buildRescaleFromInt32(rewriter, op, input_type, op2_relu_op1, 1.0f,
+                                   input_qtype.getZeroPoint());
     }
   } else {
-    if (fused_activation_fn.getValue() == "RELU") {
-      return rewriter.create<tosa::ReluNOp>(
-          op->getLoc(), input_type, input_value,
-          rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
-          rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+    if (fused_activation_fn.getValue() == "NONE") {
+      return input_value;
+    } else if (fused_activation_fn.getValue() == "RELU") {
+      return rewriter
+          .create<tosa::ReluNOp>(
+              op->getLoc(), input_type, input_value,
+              rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
+              rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
+          .getResult();
     } else if (fused_activation_fn.getValue() == "RELU6") {
-      return rewriter.create<tosa::ReluNOp>(
-          op->getLoc(), input_type, input_value, rewriter.getI64IntegerAttr(6),
-          rewriter.getF32FloatAttr(6.0));
+      return rewriter
+          .create<tosa::ReluNOp>(op->getLoc(), input_type, input_value,
+                                 rewriter.getI64IntegerAttr(6),
+                                 rewriter.getF32FloatAttr(6.0))
+          .getResult();
     } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
-      return rewriter.create<tosa::ClampOp>(
-          op->getLoc(), input_type, input_value, rewriter.getI64IntegerAttr(-1),
-          rewriter.getI64IntegerAttr(1), rewriter.getF32FloatAttr(-1.0),
-          rewriter.getF32FloatAttr(1.0));
+      return rewriter
+          .create<tosa::ClampOp>(
+              op->getLoc(), input_type, input_value,
+              rewriter.getI64IntegerAttr(-1), rewriter.getI64IntegerAttr(1),
+              rewriter.getF32FloatAttr(-1.0), rewriter.getF32FloatAttr(1.0))
+          .getResult();
     } else if (fused_activation_fn.getValue() == "TANH") {
-      return rewriter.create<tosa::TanhOp>(op->getLoc(), input_type,
-                                           input_value);
+      return rewriter
+          .create<tosa::TanhOp>(op->getLoc(), input_type, input_value)
+          .getResult();
     } else {
       // Unsupported activation type. Bail out.
-      return nullptr;
+      return llvm::None;
     }
   }
 
-  return nullptr;
+  return llvm::None;
 }
 
 // Common function for lowering reduce operations to TOSA ops.
 template <typename T>
-Value convertReduceOpCommon(PatternRewriter& rewriter, Operation* op,
-                            RankedTensorType output_type, Value input_value,
-                            ElementsAttr axes_elems, bool keep_dims,
-                            Type reduce_element_type, bool is_quantized,
-                            double input_scale, int64_t input_zp,
-                            double output_scale, int64_t output_zp) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceOpCommon(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims,
+    Type reduce_element_type, bool is_quantized, double input_scale,
+    int64_t input_zp, double output_scale, int64_t output_zp) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   ArrayRef<int64_t> input_shape = input_type.getShape();
   ArrayRef<int64_t> output_shape = output_type.getShape();
@@ -2262,7 +2339,7 @@
       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
 
       shape_vec[axis_val] = 1;
-      auto reduce_type = RankedTensorType::get(
+      RankedTensorType reduce_type = RankedTensorType::get(
           llvm::makeArrayRef<int64_t>(shape_vec), reduce_element_type);
 
       auto reduce_op =
@@ -2272,7 +2349,7 @@
     }
 
     if (is_quantized) {
-      auto output_rescale_type = RankedTensorType::get(
+      RankedTensorType output_rescale_type = RankedTensorType::get(
           llvm::makeArrayRef<int64_t>(shape_vec), output_type.getElementType());
       val = buildRescaleFromInt32(rewriter, op, output_rescale_type, val,
                                   output_scale, output_zp);
@@ -2291,67 +2368,64 @@
 }
 
 // Lowers ReduceAll to a sequence of TOSA ops.
-Operation* convertReduceAllOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceAllOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
-  Value val = convertReduceOpCommon<tosa::ReduceAllOp>(
+  return convertReduceOpCommon<tosa::ReduceAllOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
-
-  return val.getDefiningOp();
 }
 
 // Lowers ReduceAny to a sequence of TOSA ops.
-Operation* convertReduceAnyOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceAnyOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
-  Value val = convertReduceOpCommon<tosa::ReduceAnyOp>(
+  return convertReduceOpCommon<tosa::ReduceAnyOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
-
-  return val.getDefiningOp();
 }
 
 // Lowers ReduceMin to a sequence of TOSA ops.
-Operation* convertReduceMinOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceMinOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
-  Value val = convertReduceOpCommon<tosa::ReduceMinOp>(
+  return convertReduceOpCommon<tosa::ReduceMinOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
-
-  return val.getDefiningOp();
 }
 
 // Lowers ReduceMax to a sequence of TOSA ops.
-Operation* convertReduceMaxOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceMaxOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
-  Value val = convertReduceOpCommon<tosa::ReduceMaxOp>(
+  return convertReduceOpCommon<tosa::ReduceMaxOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
-
-  return val.getDefiningOp();
 }
 
 // Lowers ReduceProd to a sequence of TOSA ops.
-Operation* convertReduceProdOp(PatternRewriter& rewriter, Operation* op,
-                               RankedTensorType output_type, Value input_value,
-                               ElementsAttr axes_elems, bool keep_dims) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceProdOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   bool input_is_qtype =
       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@@ -2362,22 +2436,21 @@
     op->emitOpError(
         "ConvertReduceProdOp: input/output tensor should "
         "be all floating-point.");
-    return nullptr;
+    return llvm::None;
   }
 
-  Value val = convertReduceOpCommon<tosa::ReduceProdOp>(
+  return convertReduceOpCommon<tosa::ReduceProdOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
-
-  return val.getDefiningOp();
 }
 
 // Lowers ReduceSum to a sequence of TOSA ops.
-Operation* convertReduceSumOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertReduceSumOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   bool input_is_qtype =
       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@@ -2388,7 +2461,7 @@
     op->emitOpError(
         "ConvertReduceSumOp: input/output tensor should "
         "be all quantized or all floating-point.");
-    return nullptr;
+    return llvm::None;
   }
 
   double input_scale = 1.0f;
@@ -2415,24 +2488,23 @@
     reduce_element_type = rewriter.getI32Type();
   }
 
-  Value val = convertReduceOpCommon<tosa::ReduceSumOp>(
+  return convertReduceOpCommon<tosa::ReduceSumOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
       output_zp);
-
-  return val.getDefiningOp();
 }
 
 // Lowers ReduceMean to a sequence of TOSA ops.
-Operation* convertReduceMeanOp(PatternRewriter& rewriter, Operation* op,
-                               RankedTensorType output_type, Value input_value,
-                               ElementsAttr axes_elems, bool keep_dims) {
+llvm::Optional<Value> convertReduceMeanOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims) {
   // reduce_mean is lowered as follows:
   // op1 = reduce_sum(input)
   // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   bool input_is_qtype =
       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@@ -2443,7 +2515,7 @@
     op->emitOpError(
         "ConvertReduceSumOp: input/output tensor should "
         "be all quantized or all floating-point.");
-    return nullptr;
+    return llvm::None;
   }
 
   // Only supports float type mean() if it's non-quantized
@@ -2451,7 +2523,7 @@
     op->emitWarning(
         "Failed convertReduceMean: input unquantized type but output element "
         "not FloatType!");
-    return nullptr;
+    return llvm::None;
   }
 
   int64_t input_rank = input_type.getRank();
@@ -2487,27 +2559,31 @@
     reduce_element_type = rewriter.getI32Type();
   }
 
-  Value val = convertReduceOpCommon<tosa::ReduceSumOp>(
+  auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
       rewriter, op, output_type, input_value, axes_elems, keep_dims,
       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
       output_zp);
 
+  if (!val.hasValue()) return llvm::None;
+
   if (!input_is_qtype) {
     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
-    auto mul_op = rewriter.create<tosa::MulOp>(op->getLoc(), output_type, val,
-                                               div_const, 0);
-    val = mul_op.getResult();
+    return rewriter
+        .create<tosa::MulOp>(op->getLoc(), output_type, val.getValue(),
+                             div_const, 0)
+        .getResult();
   }
 
-  return val.getDefiningOp();
+  return val;
 }
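
The floating-point path above boils down to reduce_sum followed by a multiply with the reciprocal of the reduced element count; a one-axis sketch (illustrative only, not part of this change):

    #include <vector>

    float ReduceMeanRef(const std::vector<float>& vals) {
      float sum = 0.0f;
      for (float v : vals) sum += v;  // tosa.reduce_sum
      float div_scale = 1.0f / static_cast<float>(vals.size());
      return sum * div_scale;         // tosa.mul with a scalar constant
    }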
 
 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
-Operation* convertResizeOp(PatternRewriter& rewriter, Operation* op,
-                           RankedTensorType output_type, Value input_value,
-                           StringRef mode) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
+                                      RankedTensorType output_type,
+                                      Value input_value, StringRef mode) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   auto input_shape = input_type.getShape();
   auto output_shape = output_type.getShape();
@@ -2526,13 +2602,13 @@
     op->emitOpError(
         "ConvertResizeOp: input/output tensor should "
         "be all quantized or all floating-point.");
-    return nullptr;
+    return llvm::None;
   }
 
   if (!input_is_qtype) {
     // TODO: support float type
     op->emitOpError("ConvertResizeOp: floating-point type not supported yet ");
-    return nullptr;
+    return llvm::None;
   }
 
   int32_t shift = 11;  // Set default shift to maximum allowed
@@ -2559,17 +2635,22 @@
   IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
   StringAttr resize_mode = rewriter.getStringAttr(mode.str());
 
-  return rewriter.create<tosa::ResizeOp>(op->getLoc(), output_type, input_value,
-                                         output_size, stride, offset,
-                                         shift_attr, resize_mode);
+  return rewriter
+      .create<tosa::ResizeOp>(op->getLoc(), output_type, input_value,
+                              output_size, stride, offset, shift_attr,
+                              resize_mode)
+      .getResult();
 }
 
 // Lowers Quantize to a sequence of TOSA quantization ops.
-Operation* convertQuantizeOp(PatternRewriter& rewriter, Operation* op,
-                             RankedTensorType output_type, Value input_value,
-                             double scale, int64_t zeropoint) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
+                                        Operation* op,
+                                        RankedTensorType output_type,
+                                        Value input_value, double scale,
+                                        int64_t zeropoint) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   auto output_shape = output_type.getShape();
   auto output_element_type = output_type.getElementType();
@@ -2578,10 +2659,10 @@
   if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
     op->emitWarning(
         "Lowering quantizeOp but output element type not quantized!");
-    return nullptr;
+    return llvm::None;
   }
 
-  auto output_fp_type =
+  RankedTensorType output_fp_type =
       RankedTensorType::get(output_shape, rewriter.getF32Type());
 
   Value zp_val =
@@ -2602,22 +2683,23 @@
   auto op3_cast_op2 = rewriter.create<tosa::CastOp>(
       op->getLoc(), output_int32_type, op2_add_op1.getResult());
 
-  auto op4_rescale_op3 = buildRescale(rewriter, op, output_type,
-                                      op3_cast_op2.getResult(), 1.0, 0, 0);
-
-  return op4_rescale_op3.getDefiningOp();
+  return buildRescale(rewriter, op, output_type, op3_cast_op2.getResult(), 1.0,
+                      0, 0);
 }
 
 // Lowers Dequantize to a sequence of TOSA dequantization ops.
-Operation* convertDequantizeOp(PatternRewriter& rewriter, Operation* op,
-                               RankedTensorType output_type, Value input_value,
-                               double scale, int64_t zeropoint) {
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
+                                          Operation* op,
+                                          RankedTensorType output_type,
+                                          Value input_value, double scale,
+                                          int64_t zeropoint) {
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   // input element type could only be quantized integer
   if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
-    return nullptr;
+    return llvm::None;
 
   auto output_shape = output_type.getShape();
 
@@ -2629,7 +2711,7 @@
 
   // TOSA doesn't support CAST AINT8 -> FLOAT, need to RESCALE to INT32
   // followed by a CAST
-  auto op1_rescale_in =
+  Value op1_rescale_in =
       buildRescale(rewriter, op, output_int32_type, input_value, 1.0, 0, 0);
 
   auto op2_cast_op1 =
@@ -2638,27 +2720,33 @@
   auto op3_sub_op2 = rewriter.create<tosa::SubOp>(
       op->getLoc(), output_type, op2_cast_op1.getResult(), zp_val);
 
-  return rewriter.create<tosa::MulOp>(
-      op->getLoc(), output_type, op3_sub_op2.getResult(),
-      getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
+  return rewriter
+      .create<tosa::MulOp>(
+          op->getLoc(), output_type, op3_sub_op2.getResult(),
+          getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)),
+          0)
+      .getResult();
 }
 
 // Lowers FakeQuant to a sequence of TOSA quantization ops.
-Operation* convertFakeQuantOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              double min, double max, int64_t num_bits,
-                              bool narrow_range) {
+llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
+                                         Operation* op,
+                                         RankedTensorType output_type,
+                                         Value input_value, double min,
+                                         double max, int64_t num_bits,
+                                         bool narrow_range) {
   // FakeQuant is lowered as follows:
   // op1 = quantize(input)
   // op2 = dequantize(op1)
 
-  auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
-  if (!input_type) return nullptr;
+  RankedTensorType input_type =
+      input_value.getType().dyn_cast<RankedTensorType>();
+  if (!input_type) return llvm::None;
 
   // quantized as INT<num_bits>, where num_bits can only be 8, 16
   if (num_bits != 8 && num_bits != 16) {
     op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
-    return nullptr;
+    return llvm::None;
   }
 
   auto output_shape = output_type.getShape();
@@ -2672,30 +2760,33 @@
   auto int_element_qtype = mlir::quant::UniformQuantizedType::get(
       true, rewriter.getIntegerType(num_bits), rewriter.getF32Type(), 1.0f, 0,
       qmin, qmax);
-  auto output_int_type = RankedTensorType::get(output_shape, int_element_qtype);
+  RankedTensorType output_int_type =
+      RankedTensorType::get(output_shape, int_element_qtype);
 
   double scale = (max - min) / static_cast<double>(qmax - qmin);
   int64_t zeropoint = std::llround((-min) / scale + static_cast<double>(qmin));
 
   // Quantize: round(x / scale + zeropoint)
-  auto quantized_op = convertQuantizeOp(rewriter, op, output_int_type,
-                                        input_value, 1.0 / scale, zeropoint);
+  auto quantized_val = convertQuantizeOp(rewriter, op, output_int_type,
+                                         input_value, 1.0 / scale, zeropoint);
+
+  if (!quantized_val.hasValue()) return llvm::None;
 
   // Dequantize: ((float)x - zeropoint) * scale
   return convertDequantizeOp(rewriter, op, output_type,
-                             quantized_op->getResult(0), scale, zeropoint);
+                             quantized_val.getValue(), scale, zeropoint);
 }
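
A scalar sketch of the quantize-then-dequantize round trip that convertFakeQuantOp builds (8-bit, non-narrow-range case; the explicit clamp and the names are illustrative assumptions, not part of this change):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    float FakeQuantRef(float x, double min, double max) {
      const int64_t qmin = -128, qmax = 127;  // num_bits == 8
      double scale = (max - min) / static_cast<double>(qmax - qmin);
      int64_t zp = std::llround((-min) / scale + static_cast<double>(qmin));
      // Quantize: round(x / scale + zeropoint), kept in the representable range.
      int64_t q = std::llround(x / scale + static_cast<double>(zp));
      q = std::min(std::max(q, qmin), qmax);
      // Dequantize: (q - zeropoint) * scale.
      return static_cast<float>(static_cast<double>(q - zp) * scale);
    }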
 
-Operation* convertTFConv2DCommon(
+llvm::Optional<Value> convertTFConv2DCommon(
     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
     Value input, Value filter, Value bias, ArrayAttr strides_attr,
     ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
     StringRef padding_ref, StringRef data_format_ref) {
-  auto input_type = input.getType().dyn_cast<RankedTensorType>();
-  auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
-  if (!input_type) return nullptr;
-  if (!filter_type) return nullptr;
+  if (!input_type) return llvm::None;
+  if (!filter_type) return llvm::None;
 
   // Transpose [H, W, I, O] to [O, H, W, I]
   auto filter_shape = filter_type.getShape();
@@ -2704,7 +2795,7 @@
   a1_transpose_dims.push_back(filter_shape[0]);
   a1_transpose_dims.push_back(filter_shape[1]);
   a1_transpose_dims.push_back(filter_shape[2]);
-  auto a1_filter_transpose_perm =
+  Value a1_filter_transpose_perm =
       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {3, 0, 1, 2});
   auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
       op->getLoc(),
@@ -2715,7 +2806,7 @@
   // Only support NHWC now.
   if (data_format_ref.str() != "NHWC") {
     op->emitWarning("convertTDConv2DCommon only supports NHWC!");
-    return nullptr;
+    return llvm::None;
   }
 
   ArrayAttr stride;
@@ -2745,12 +2836,12 @@
     tensorflow::Padding tf_pad;
     if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
       op->emitWarning("Could not get padding data from padding string term!");
-      return nullptr;
+      return llvm::None;
     }
 
     tensorflow::TensorFormat data_format_tf;
     if (!FormatFromString(data_format_ref.str(), &data_format_tf))
-      return nullptr;
+      return llvm::None;
 
     if (tf_pad == tensorflow::Padding::EXPLICIT) {
       pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
@@ -2760,13 +2851,15 @@
                                        0,  // tensorflow::FORMAT_HWIO
                                        input_type, filter_type, stride,
                                        dilation, rewriter, pad))
-        return nullptr;
+        return llvm::None;
     }
   }
 
-  return rewriter.create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
-                                         a1_filter_transpose_op.getResult(),
-                                         bias, pad, stride, dilation);
+  return rewriter
+      .create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
+                              a1_filter_transpose_op.getResult(), bias, pad,
+                              stride, dilation)
+      .getResult();
 }
 
 };  // namespace tosa
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
index af95277..d5ef518 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
@@ -4,7 +4,7 @@
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
@@ -16,221 +16,222 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
 #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
 
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
 // This file contains legalizations common to mapping both TensorFlow and
 // TensorFlow Lite to TOSA.
 //
-// Conversion functions return nullptr on a lowerization failure or a lowered
-// operator on success.   Callers must check and return a LogicalResult failure
-// on nullptr.
+// Conversion functions return llvm::None on failure or the lowered Value on
+// success.  Callers must check and return a LogicalResult failure on None.
 //
 // For these functions, the framework-specific operands/attributes/defaults
 // are already extracted and placed in a common form for lowering.
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringSwitch.h"
-#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 namespace tosa {
 
 // Lowers the Pack operator to TOSA.
-Operation* convertPackOp(PatternRewriter& rewriter, Operation* op,
-                         Value result_value, SmallVector<Value, 8>& inputs,
-                         int32_t axis);
+llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
+                                    Value result_value,
+                                    SmallVector<Value, 8>& inputs,
+                                    int32_t axis);
 
 // Lowers the Unpack operator to TOSA.
-Operation* convertUnpackOp(PatternRewriter& rewriter, Operation* op,
-                           Value input_value, int32_t axis);
+llvm::Optional<ValueRange> convertUnpackOp(PatternRewriter& rewriter,
+                                           Operation* op, Value input_value,
+                                           int32_t axis);
 
 // Lowers the Select operator to TOSA.
-Operation* convertSelectOp(PatternRewriter& rewriter, Operation* op,
-                           Value result_value, Value condition_value,
-                           Value x_value, Value y_value);
+llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
+                                      Value result_value, Value condition_value,
+                                      Value x_value, Value y_value);
 
 // Lowers the ZerosLike operator to TOSA by creating a constant
 // of the desired type and shape.
-Operation* convertZerosLikeOp(PatternRewriter& rewriter, Operation* op,
-                              Value result, Value input);
+llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
+                                         Operation* op, Value result,
+                                         Value input);
 
 // Lowers the Mul operator to TOSA.  For quantized types, this requires
 // inserting rescale operators before and after the operation.
-Operation* convertMultiplyOp(PatternRewriter& rewriter, Operation* op,
-                             Value output_val, Value input_lhs_val,
-                             Value input_rhs_val);
+llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
+                                        Operation* op, Value output_val,
+                                        Value input_lhs_val,
+                                        Value input_rhs_val);
 
 // Lowers the SquaredDifference operator to TOSA.
-Operation* convertSquaredDifferenceOp(PatternRewriter& rewriter, Operation* op,
-                                      Value result, Value x, Value y);
+llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
+                                                 Operation* op, Value result,
+                                                 Value x, Value y);
 
 // Lowers the Round operator to TOSA.
-Operation* convertRoundOp(PatternRewriter& rewriter, Operation* op,
-                          Value result, Value input);
+llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
+                                     Value result, Value input);
 
 // Lowers ConcatV2 to TOSA.
-Operation* convertConcatV2Op(PatternRewriter& rewriter, Operation* op,
-                             Value result_value, SmallVector<Value, 8>& values,
-                             int32_t axis);
+llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        SmallVector<Value, 8>& values,
+                                        int32_t axis);
 
 // Lowers SpaceToBatchND to TOSA.
-Operation* convertSpaceToBatchNDOp(PatternRewriter& rewriter, Operation* op,
-                                   Value result_value, Value input_value,
-                                   Value block_shape_value,
-                                   Value paddings_value);
+llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
+                                              Operation* op, Value result_value,
+                                              Value input_value,
+                                              Value block_shape_value,
+                                              Value paddings_value);
 
 // Lowers BatchToSpaceND to TOSA.
-Operation* convertBatchToSpaceNDOp(PatternRewriter& rewriter, Operation* op,
-                                   Value result_value, Value input_value,
-                                   Value block_shape_value, Value crops_value);
+llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
+                                              Operation* op, Value result_value,
+                                              Value input_value,
+                                              Value block_shape_value,
+                                              Value crops_value);
 
 // Lowers ExpandDims to TOSA.
-Operation* convertExpandDimsOp(PatternRewriter& rewriter, Operation* op,
-                               Value result_value, Value input_value,
-                               Value dim_value);
+llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
+                                          Operation* op, Value result_value,
+                                          Value input_value, Value dim_value);
 
 // Lowers Squeeze to TOSA.
-Operation* convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
-                            Value result_value, Value input_value,
-                            SmallVector<int32_t, 8>& squeeze_dims);
+llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
+                                       Value result_value, Value input_value,
+                                       SmallVector<int32_t, 8>& squeeze_dims);
 
 // Lowers ELU to a sequence of TOSA ops.
-Operation* convertEluOp(PatternRewriter& rewriter, Operation* op,
-                        Value result_value, Value features_value);
+llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
+                                   Value result_value, Value features_value);
 
 // Lowers Softmax to a sequence of TOSA ops.
-Operation* convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
-                            Value result_value, Value logits_value);
+llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
+                                       Value result_value, Value logits_value);
 
 // Lowers LogSoftmax to a sequence of TOSA ops.
-Operation* convertLogSoftmaxOp(PatternRewriter& rewriter, Operation* op,
-                               Value result_value, Value logits_value);
+llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
+                                          Operation* op, Value result_value,
+                                          Value logits_value);
 
 // Lowers SpaceToDepth to a sequence of TOSA ops.  Supports NHWC.
-Operation* convertSpaceToDepthOp(PatternRewriter& rewriter, Operation* op,
-                                 Value result_value, Value input_value,
-                                 IntegerAttr block_size_attr,
-                                 StringAttr data_format);
+llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
+                                            Operation* op, Value result_value,
+                                            Value input_value,
+                                            IntegerAttr block_size_attr,
+                                            StringAttr data_format);
 
 // Lowers DepthToSpace to a sequence of TOSA ops.  Supports NHWC.
-Operation* convertDepthToSpaceOp(PatternRewriter& rewriter, Operation* op,
-                                 Value result_value, Value input_value,
-                                 IntegerAttr block_size_attr,
-                                 StringAttr data_format);
+llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
+                                            Operation* op, Value result_value,
+                                            Value input_value,
+                                            IntegerAttr block_size_attr,
+                                            StringAttr data_format);
 
 // Lowers Split to a sequence of TOSA ops.
-Operation* convertSplitOp(PatternRewriter& rewriter, Operation* op,
-                          Value result_value, Value input_value,
-                          int32_t num_split, int32_t axis);
+llvm::Optional<ValueRange> convertSplitOp(PatternRewriter& rewriter,
+                                          Operation* op, Value result_value,
+                                          Value input_value, int32_t num_split,
+                                          int32_t axis);
 
 // Lowers SplitV to a sequence of TOSA ops.
-Operation* convertSplitVOp(PatternRewriter& rewriter, Operation* op,
-                           Value result_value, Value input_value,
-                           SmallVector<int32_t, 4>& size_split, int32_t axis);
+llvm::Optional<ValueRange> convertSplitVOp(PatternRewriter& rewriter,
+                                           Operation* op, Value result_value,
+                                           Value input_value,
+                                           SmallVector<int32_t, 4>& size_split,
+                                           int32_t axis);
 
 // Lowers StridedSlice to a sequence of TOSA ops.
-Operation* convertStridedSliceOp(PatternRewriter& rewriter, Operation* op,
-                                 Value result_value, Value input_value,
-                                 Value begin_value, Value end_value,
-                                 Value strides_value, int32_t begin_mask,
-                                 int32_t end_mask, int32_t ellipsis_mask,
-                                 int32_t new_axis_mask,
-                                 int32_t shrink_axis_mask);
+llvm::Optional<Value> convertStridedSliceOp(
+    PatternRewriter& rewriter, Operation* op, Value result_value,
+    Value input_value, Value begin_value, Value end_value, Value strides_value,
+    int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
+    int32_t new_axis_mask, int32_t shrink_axis_mask);
 
 // Lowers FloorDiv to a sequence of TOSA operators.
-Operation* convertFloorDivOp(PatternRewriter& rewriter, Operation* op,
-                             Value result_value, Value lhs_value,
-                             Value rhs_value);
+llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        Value lhs_value, Value rhs_value);
 
 // Lowers FloorMod to a sequence of TOSA operators.
-Operation* convertFloorModOp(PatternRewriter& rewriter, Operation* op,
-                             Value result_value, Value lhs_value,
-                             Value rhs_value);
+llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        Value lhs_value, Value rhs_value);
 
 // Lowers FusedActivation to a sequence of TOSA ops.
-Operation* convertFusedActivation(PatternRewriter& rewriter, Operation* op,
-                                  Value input_value,
-                                  StringAttr fused_activation_fn);
+llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
+                                             Operation* op, Value input_value,
+                                             StringAttr fused_activation_fn);
 
 // Helper function for implementing quantized divide by power-of-two in TOSA
 // ops.
-Operation* convertRoundingDivideByPOT(PatternRewriter& rewriter, Operation* op,
-                                      Value input_value, Value rshift_value);
+llvm::Optional<Value> convertRoundingDivideByPOT(PatternRewriter& rewriter,
+                                                 Operation* op,
+                                                 Value input_value,
+                                                 Value rshift_value);
 
 // Lowers ReduceAll to a sequence of TOSA ops.
-Operation* convertReduceAllOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceAllOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ReduceAny to a sequence of TOSA ops.
-Operation* convertReduceAnyOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceAnyOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ReduceMin to a sequence of TOSA ops.
-Operation* convertReduceMinOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceMinOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ReduceMax to a sequence of TOSA ops.
-Operation* convertReduceMaxOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceMaxOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ReduceProd to a sequence of TOSA ops.
-Operation* convertReduceProdOp(PatternRewriter& rewriter, Operation* op,
-                               RankedTensorType output_type, Value input_value,
-                               ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceProdOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ReduceSum to a sequence of TOSA ops.
-Operation* convertReduceSumOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceSumOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ReduceMean to a sequence of TOSA ops.
-Operation* convertReduceMeanOp(PatternRewriter& rewriter, Operation* op,
-                               RankedTensorType output_type, Value input_value,
-                               ElementsAttr axes_elems, bool keep_dims);
+llvm::Optional<Value> convertReduceMeanOp(
+    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+    Value input_value, ElementsAttr axes_elems, bool keep_dims);
 
 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
-Operation* convertResizeOp(PatternRewriter& rewriter, Operation* op,
-                           RankedTensorType output_type, Value input_value,
-                           StringRef mode);
+llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
+                                      RankedTensorType output_type,
+                                      Value input_value, StringRef mode);
 
 // Lowers Quantize to a sequence of TOSA quantization ops.
-Operation* convertQuantizeOp(PatternRewriter& rewriter, Operation* op,
-                             RankedTensorType output_type, Value input_value,
-                             double scale, int64_t zeropoint);
+llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
+                                        Operation* op,
+                                        RankedTensorType output_type,
+                                        Value input_value, double scale,
+                                        int64_t zeropoint);
 
 // Lowers Dequantize to a sequence of TOSA dequantization ops.
-Operation* convertDequantizeOp(PatternRewriter& rewriter, Operation* op,
-                               RankedTensorType output_type, Value input_value,
-                               double scale, int64_t zeropoint);
+llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
+                                          Operation* op,
+                                          RankedTensorType output_type,
+                                          Value input_value, double scale,
+                                          int64_t zeropoint);
 
 // Lowers FakeQuant to a sequence of TOSA quantization ops.
-Operation* convertFakeQuantOp(PatternRewriter& rewriter, Operation* op,
-                              RankedTensorType output_type, Value input_value,
-                              double min, double max, int64_t num_bits,
-                              bool narrow_range);
-Operation* convertTFConv2DCommon(
+llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
+                                         Operation* op,
+                                         RankedTensorType output_type,
+                                         Value input_value, double min,
+                                         double max, int64_t num_bits,
+                                         bool narrow_range);
+
+// Lowers the TensorFlow Conv2D operator to a sequence of TOSA ops.
+llvm::Optional<Value> convertTFConv2DCommon(
     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
     Value input, Value filter, Value bias, ArrayAttr strides_attr,
     ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
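
Note (not part of the patch): the header above switches every lowering helper from returning an Operation* to returning llvm::Optional<Value> (or llvm::Optional<ValueRange> for multi-result lowerings such as Unpack/Split), where a disengaged optional means the lowering did not apply. A minimal caller-side sketch of that convention, using convertUnpackOp with a hypothetical operand and axis 0 purely for illustration; real call sites appear in legalize_tf.cc below:

  // Sketch only: assumes `op` is the matched TF op and `rewriter` the active
  // PatternRewriter, as in the matchAndRewrite bodies in legalize_tf.cc.
  llvm::Optional<ValueRange> results =
      convertUnpackOp(rewriter, op, op->getOperand(0), /*axis=*/0);
  if (!results) return failure();            // lowering not applicable
  rewriter.replaceOp(op, results.getValue());
  return success();
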
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index e2cb347..1219e14 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -21,30 +21,9 @@
 #include <iterator>
 #include <numeric>
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringSwitch.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
@@ -70,10 +49,10 @@
 
 #define DECL_CONVERT_OP(tf_op)                                               \
   struct ConvertTF##tf_op##Op : public RewritePattern {                      \
-    explicit ConvertTF##tf_op##Op(MLIRContext *context)                      \
+    explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
         : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
-    LogicalResult matchAndRewrite(Operation *op,                             \
-                                  PatternRewriter &rewriter) const override; \
+    LogicalResult matchAndRewrite(Operation* op,                             \
+                                  PatternRewriter& rewriter) const override; \
   }
 
 // All the explicitly implemented complex lowerings.
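
As a reading aid (not part of the patch), an instantiation such as DECL_CONVERT_OP(Relu); expands to roughly the following pattern declaration per the macro above; the matchAndRewrite bodies further down supply the definitions:

  struct ConvertTFReluOp : public RewritePattern {
    explicit ConvertTFReluOp(MLIRContext* context)
        : RewritePattern(TF::ReluOp::getOperationName(), 1, context) {}
    LogicalResult matchAndRewrite(Operation* op,
                                  PatternRewriter& rewriter) const override;
  };
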
@@ -155,34 +134,11 @@
 DECL_CONVERT_OP(FakeQuantWithMinMaxVars);
 #undef DECL_CONVERT_OP
 
-// TODO: remove macro when replacing common function return types with
-// llvm::Optional<> Helper macros for checking the return value of a common
-// legalization function that returns a single tensor.
-// Packs the result in a list.
-#define TOSA_REPLACE_LOWERED_OP(REWRITER, OP, LOWERED_OP)   \
-  if (LOWERED_OP) {                                         \
-    REWRITER.replaceOp((OP), {(LOWERED_OP)->getResults()}); \
-    return success();                                       \
-  } else {                                                  \
-    return failure();                                       \
-  }
-
-// TODO: remove macro when replacing common function return types with
-// llvm::Optional<> Helper macros for checking the return value of a common
-// legalization function that returns a tensor list.
-#define TOSA_REPLACE_LOWERED_OP_LIST(REWRITER, OP, LOWERED_OP) \
-  if (LOWERED_OP) {                                            \
-    REWRITER.replaceOp((OP), (LOWERED_OP)->getResults());      \
-    return success();                                          \
-  } else {                                                     \
-    return failure();                                          \
-  }
-
 LogicalResult ConvertTFReluOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_relu_op = cast<TF::ReluOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -201,10 +157,10 @@
 }
 
 LogicalResult ConvertTFRelu6Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_relu6_op = cast<TF::Relu6Op>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -222,10 +178,10 @@
 }
 
 LogicalResult ConvertTFEqualOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_equal_op = cast<TF::EqualOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -236,10 +192,10 @@
 }
 
 LogicalResult ConvertTFNotEqualOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_not_equal_op = cast<TF::NotEqualOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -256,10 +212,10 @@
 }
 
 LogicalResult ConvertTFGreaterOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_greater_op = cast<TF::GreaterOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -270,10 +226,10 @@
 }
 
 LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_greater_equal_op = cast<TF::GreaterEqualOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -283,11 +239,11 @@
   return success();
 }
 
-LogicalResult ConvertTFAddOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFAddOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_add_op = cast<TF::AddOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_add_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -298,10 +254,10 @@
 }
 
 LogicalResult ConvertTFAddV2Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_addv2_op = cast<TF::AddV2Op>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_addv2_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -313,10 +269,10 @@
 
 // AddN is commutative
 LogicalResult ConvertTFAddNOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_addn_op = cast<TF::AddNOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -337,11 +293,11 @@
   return success();
 }
 
-LogicalResult ConvertTFSubOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFSubOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_sub_op = cast<TF::SubOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -351,51 +307,65 @@
   return success();
 }
 
-LogicalResult ConvertTFMulOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFMulOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_mul_op = cast<TF::MulOp>(op);
 
-  auto lowered_op = convertMultiplyOp(rewriter, op, tf_mul_op.getResult(),
-                                      tf_mul_op.x(), tf_mul_op.y());
+  llvm::Optional<Value> result = convertMultiplyOp(
+      rewriter, op, tf_mul_op.getResult(), tf_mul_op.x(), tf_mul_op.y());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+  return success();
 }
 
 LogicalResult ConvertTFSquareOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_square_op = cast<TF::SquareOp>(op);
 
-  auto lowered_op = convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
-                                      tf_square_op.x(), tf_square_op.x());
+  llvm::Optional<Value> result =
+      convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
+                        tf_square_op.x(), tf_square_op.x());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+  return success();
 }
 
 LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_squared_op = cast<TF::SquaredDifferenceOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(),
                                  tf_squared_op.x(), tf_squared_op.y());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+  return success();
 }
 
 LogicalResult ConvertTFRoundOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_round_op = cast<TF::RoundOp>(op);
 
-  auto input_type = tf_round_op.x().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tf_round_op.x().getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     return op->emitOpError("Round: input not ranked tensor type");
   }
 
   if (input_type.getElementType().isa<FloatType>()) {
-    auto lowered_op =
+    llvm::Optional<Value> result =
         convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x());
 
-    TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+    if (!result) return failure();
+
+    rewriter.replaceOp(op, {result.getValue()});
+    return success();
 
   } else {
     tf_round_op.replaceAllUsesWith(tf_round_op.x());
@@ -404,37 +374,47 @@
 }
 
 LogicalResult ConvertTFFloorDivOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_floordiv_op = cast<TF::FloorDivOp>(op);
 
-  auto lowered_op = convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
-                                      tf_floordiv_op.x(), tf_floordiv_op.y());
+  llvm::Optional<Value> result =
+      convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
+                        tf_floordiv_op.x(), tf_floordiv_op.y());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFFloorModOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_floormod_op = cast<TF::FloorModOp>(op);
 
-  auto lowered_op = convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
-                                      tf_floormod_op.x(), tf_floormod_op.y());
+  llvm::Optional<Value> result =
+      convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
+                        tf_floormod_op.x(), tf_floormod_op.y());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFAssertOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   op->dropAllReferences();
   op->erase();
   return success();
 }
 
 LogicalResult ConvertTFMaximumOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_maximum_op = cast<TF::MaximumOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_maximum_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -445,10 +425,10 @@
 }
 
 LogicalResult ConvertTFMinimumOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_minimum_op = cast<TF::MinimumOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_minimum_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -459,11 +439,12 @@
 }
 
 LogicalResult ConvertTFRealDivOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_div_op = cast<TF::RealDivOp>(op);
 
-  auto y_type = tf_div_op.y().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType y_type =
+      tf_div_op.y().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tf_div_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type || !y_type) return failure();
@@ -479,11 +460,12 @@
 }
 
 LogicalResult ConvertTFArgMaxOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_argmax_op = cast<TF::ArgMaxOp>(op);
 
-  auto input_type = tf_argmax_op.input().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_type =
+      tf_argmax_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tf_argmax_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type || !input_type) return failure();
@@ -509,12 +491,12 @@
   return success();
 }
 LogicalResult ConvertTFAvgPoolOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_avgpool_op = cast<TF::AvgPoolOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tf_avgpool_op.value().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tf_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type || !output_type) return failure();
@@ -557,12 +539,12 @@
 
     SmallVector<int64_t, 2> i64array;
 
-    for (auto &elem : tf_avgpool_op.ksize()) {
+    for (auto& elem : tf_avgpool_op.ksize()) {
       int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
       i64array.emplace_back(value);
     }
 
-    auto filter_type = RankedTensorType::get(
+    RankedTensorType filter_type = RankedTensorType::get(
         llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
 
     if (!getPaddingValuesFromPadType(
@@ -579,12 +561,12 @@
 }
 
 LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_maxpool_op = cast<TF::MaxPoolOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tf_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tf_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type || !output_type) return failure();
@@ -627,12 +609,12 @@
 
     SmallVector<int64_t, 4> i64array;
 
-    for (auto &elem : tf_maxpool_op.ksize()) {
+    for (auto& elem : tf_maxpool_op.ksize()) {
       int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
       i64array.emplace_back(value);
     }
 
-    auto filter_type = RankedTensorType::get(
+    RankedTensorType filter_type = RankedTensorType::get(
         llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
 
     if (!getPaddingValuesFromPadType(
@@ -649,7 +631,7 @@
 }
 
 LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
   SmallVector<Value, 8> values(tf_concatv2_op.values());
 
@@ -659,17 +641,21 @@
 
   int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFReshapeOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_reshape_op = cast<TF::ReshapeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -688,15 +674,17 @@
 }
 
 LogicalResult ConvertTFRankOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_rank_op = cast<TF::RankOp>(op);
 
-  auto input_type = tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return failure();
 
   int32_t rank = input_type.getRank();
 
-  auto rank_type = RankedTensorType::get({1}, rewriter.getIntegerType(32));
+  RankedTensorType rank_type =
+      RankedTensorType::get({1}, rewriter.getIntegerType(32));
   auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
   auto rank_const =
       rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
@@ -707,15 +695,16 @@
 }
 
 LogicalResult ConvertTFShapeOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_shape_op = cast<TF::ShapeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto input_type = tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return failure();
 
   auto input_shape = input_type.getShape();
@@ -725,7 +714,7 @@
     shape_arr.emplace_back(input_shape[i]);
   }
 
-  auto shape_type = RankedTensorType::get(
+  RankedTensorType shape_type = RankedTensorType::get(
       {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
   auto shape_attr = DenseElementsAttr::get(
       shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
@@ -738,38 +727,47 @@
 }
 
 LogicalResult ConvertTFExpandDimsOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_expanddims_op = cast<TF::ExpandDimsOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(),
                           tf_expanddims_op.input(), tf_expanddims_op.dim());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFSqueezeOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_squeeze_op = cast<TF::SqueezeOp>(op);
 
   // Copy squeeze_dims into int32_t array
   auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr();
   SmallVector<int32_t, 8> squeeze_dims;
-  for (auto &squeeze_dim : squeeze_dims_attr) {
+  for (auto& squeeze_dim : squeeze_dims_attr) {
     squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
   }
 
-  auto lowered_op = convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
-                                     tf_squeeze_op.input(), squeeze_dims);
+  llvm::Optional<Value> result =
+      convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
+                       tf_squeeze_op.input(), squeeze_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFFillOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_fill_op = cast<TF::FillOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -788,8 +786,8 @@
   if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem)))
     return failure();
 
-  auto fill_type = RankedTensorType::get(ArrayRef<int64_t>(dims_vals),
-                                         value_elem.getType().getElementType());
+  RankedTensorType fill_type = RankedTensorType::get(
+      ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
   DenseElementsAttr fill_attr;
 
   // Convert to a compatible zero type
@@ -814,40 +812,44 @@
 }
 
 LogicalResult ConvertTFConv2DOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_conv2d_op = cast<TF::Conv2DOp>(op);
 
-  auto filter_type =
+  RankedTensorType filter_type =
       tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tf_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
 
   // Set up a zero attr for subsequent pattern replacement if required
   auto bias_dim = filter_type.getShape().back();
-  auto bias_type =
+  RankedTensorType bias_type =
       RankedTensorType::get({bias_dim}, filter_type.getElementType());
   auto bias_attr = rewriter.getZeroAttr(bias_type);
   auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
                                              bias_attr.cast<ElementsAttr>());
 
-  auto lowered_op = convertTFConv2DCommon(
+  llvm::Optional<Value> result = convertTFConv2DCommon(
       rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
       bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
       tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
       tf_conv2d_op.data_format());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_dwconv2d_op = cast<TF::DepthwiseConv2dNativeOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tf_dwconv2d_op.input().getType().dyn_cast<RankedTensorType>();
-  auto filter_type =
+  RankedTensorType filter_type =
       tf_dwconv2d_op.filter().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tf_dwconv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type) return failure();
@@ -909,7 +911,7 @@
 
   auto filter_shape = filter_type.getShape();
   auto bias_dim = filter_shape[2] * filter_shape[3];
-  auto bias_type =
+  RankedTensorType bias_type =
       RankedTensorType::get({bias_dim}, filter_type.getElementType());
   auto bias_attr = rewriter.getZeroAttr(bias_type);
   auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
@@ -922,13 +924,14 @@
 }
 
 LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_conv_op = cast<TF::Conv2DBackpropInputOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tf_conv_op.out_backprop().getType().dyn_cast<RankedTensorType>();
-  auto filter_type = tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType filter_type =
+      tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tf_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type) return failure();
@@ -942,7 +945,7 @@
   a1_transpose_dims.push_back(filter_shape[0]);
   a1_transpose_dims.push_back(filter_shape[1]);
   a1_transpose_dims.push_back(filter_shape[3]);
-  auto a1_filter_transpose_perm =
+  Value a1_filter_transpose_perm =
       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {2, 0, 1, 3});
   auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
       op->getLoc(),
@@ -1024,11 +1027,11 @@
   return success();
 }
 
-LogicalResult ConvertTFAllOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_all_op = cast<TF::AllOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_all_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1040,17 +1043,21 @@
   auto keep_dims_attr = tf_all_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceAllOp(
+  llvm::Optional<Value> result = convertReduceAllOp(
       rewriter, op, output_type, tf_all_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
-LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_any_op = cast<TF::AnyOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_any_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1062,17 +1069,21 @@
   auto keep_dims_attr = tf_any_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceAnyOp(
+  llvm::Optional<Value> result = convertReduceAnyOp(
       rewriter, op, output_type, tf_any_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
-LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_max_op = cast<TF::MaxOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_max_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1084,17 +1095,21 @@
   auto keep_dims_attr = tf_max_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceMaxOp(
+  llvm::Optional<Value> result = convertReduceMaxOp(
       rewriter, op, output_type, tf_max_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
-LogicalResult ConvertTFMinOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_min_op = cast<TF::MinOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_min_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1106,17 +1121,21 @@
   auto keep_dims_attr = tf_min_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceMinOp(
+  llvm::Optional<Value> result = convertReduceMinOp(
       rewriter, op, output_type, tf_min_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFMeanOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_mean_op = cast<TF::MeanOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1128,17 +1147,21 @@
   auto keep_dims_attr = tf_mean_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceMeanOp(
+  llvm::Optional<Value> result = convertReduceMeanOp(
       rewriter, op, output_type, tf_mean_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFProdOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_prod_op = cast<TF::ProdOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1150,17 +1173,21 @@
   auto keep_dims_attr = tf_prod_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceProdOp(
+  llvm::Optional<Value> result = convertReduceProdOp(
       rewriter, op, output_type, tf_prod_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
-LogicalResult ConvertTFSumOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_sum_op = cast<TF::SumOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1172,47 +1199,63 @@
   auto keep_dims_attr = tf_sum_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceSumOp(
+  llvm::Optional<Value> result = convertReduceSumOp(
       rewriter, op, output_type, tf_sum_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
-LogicalResult ConvertTFEluOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFEluOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_elu_op = cast<TF::EluOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFSoftmaxOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_softmax_op = cast<TF::SoftmaxOp>(op);
 
-  auto lowered_op = convertSoftmaxOp(rewriter, op, tf_softmax_op.getResult(),
-                                     tf_softmax_op.logits());
+  llvm::Optional<Value> result = convertSoftmaxOp(
+      rewriter, op, tf_softmax_op.getResult(), tf_softmax_op.logits());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_logsoftmax_op = cast<TF::LogSoftmaxOp>(op);
 
-  auto lowered_op = convertLogSoftmaxOp(
+  llvm::Optional<Value> result = convertLogSoftmaxOp(
       rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_batchnorm_op = cast<TF::FusedBatchNormOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1237,9 +1280,9 @@
   // op5 = mul(op4, bscale)
   // op6 = add(op5, boffset)
 
-  auto mean_type =
+  RankedTensorType mean_type =
       tf_batchnorm_op.mean().getType().dyn_cast<RankedTensorType>();
-  auto variance_type =
+  RankedTensorType variance_type =
       tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
   if (!variance_type || !mean_type) return failure();
 
@@ -1257,7 +1300,7 @@
     variance_val = tf_batchnorm_op.variance();
   }
 
-  auto epsilon_type =
+  RankedTensorType epsilon_type =
       RankedTensorType::get({1}, variance_type.getElementType());
   auto epsilon_attr =
       DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
@@ -1292,10 +1335,10 @@
 }
 
 LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_batchnorm_op = cast<TF::FusedBatchNormV3Op>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1313,7 +1356,7 @@
       op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
       tf_batchnorm_op.mean());
 
-  auto variance_type =
+  RankedTensorType variance_type =
       tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
   if (!variance_type) return failure();
 
@@ -1349,10 +1392,10 @@
 }
 
 LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1365,10 +1408,10 @@
 }
 
 LogicalResult ConvertTFSliceOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_slice_op = cast<TF::SliceOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1404,10 +1447,10 @@
 }
 
 LogicalResult ConvertTFTileOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_tile_op = cast<TF::TileOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1428,10 +1471,10 @@
 }
 
 LogicalResult ConvertTFTransposeOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_transpose_op = cast<TF::TransposeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) {
@@ -1445,7 +1488,7 @@
 }
 
 LogicalResult ConvertTFPackOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_pack_op = cast<TF::PackOp>(op);
 
   SmallVector<Value, 8> inputs(tf_pack_op.values());
@@ -1460,14 +1503,18 @@
   }
   int32_t axis_i32 = axis_attr.getInt();
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_unpack_op = cast<TF::UnpackOp>(op);
 
   IntegerAttr axis_attr;
@@ -1478,15 +1525,19 @@
   }
   int32_t axis_i32 = axis_attr.getInt();
 
-  auto lowered_op =
+  llvm::Optional<ValueRange> results =
       convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!results) return failure();
+
+  rewriter.replaceOp(op, results.getValue());
+
+  return success();
 }
 
 // Splits the input into num_split parts along split_dim
 LogicalResult ConvertTFSplitOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_split_op = cast<TF::SplitOp>(op);
 
   // Get the number of splits
@@ -1502,15 +1553,20 @@
     axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
   }
 
-  auto lowered_op = convertSplitOp(rewriter, op, tf_split_op.getResult(0),
-                                   tf_split_op.value(), num_split, axis);
+  llvm::Optional<ValueRange> results =
+      convertSplitOp(rewriter, op, tf_split_op.getResult(0),
+                     tf_split_op.value(), num_split, axis);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!results) return failure();
+
+  rewriter.replaceOp(op, results.getValue());
+
+  return success();
 }
 
 // TFSplitV op splits based on a vector of sizes
 LogicalResult ConvertTFSplitVOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_splitv_op = cast<TF::SplitVOp>(op);
 
   // Get the size_splits array
@@ -1533,17 +1589,22 @@
 
   int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
 
-  auto lowered_op = convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
-                                    tf_splitv_op.value(), size_split, axis);
+  llvm::Optional<ValueRange> results =
+      convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
+                      tf_splitv_op.value(), size_split, axis);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!results) return failure();
+
+  rewriter.replaceOp(op, results.getValue());
+
+  return success();
 }
 
 LogicalResult ConvertTFLessOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_less_op = cast<TF::LessOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_less_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1560,10 +1621,10 @@
 }
 
 LogicalResult ConvertTFLessEqualOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_less_equal_op = cast<TF::LessEqualOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1579,11 +1640,11 @@
   return success();
 }
 
-LogicalResult ConvertTFPadOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_pad_op = cast<TF::PadOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1596,42 +1657,52 @@
 }
 
 LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_resize_op = cast<TF::ResizeBilinearOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto lowered_op = convertResizeOp(
+  llvm::Optional<Value> result = convertResizeOp(
       rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"));
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_resize_op = cast<TF::ResizeNearestNeighborOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto lowered_op = convertResizeOp(
+  llvm::Optional<Value> result = convertResizeOp(
       rewriter, op, output_type, tf_resize_op.images(), StringRef("NEAREST"));
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_matmul_op = cast<TF::MatMulOp>(op);
 
-  auto a_type = tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
-  auto b_type = tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType a_type =
+      tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType b_type =
+      tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tf_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();
 
   if (!(a_type && b_type && output_type)) {
@@ -1648,10 +1719,10 @@
 }
 
 LogicalResult ConvertTFGatherOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_gather_op = cast<TF::GatherOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1667,10 +1738,10 @@
 }
 
 LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_gather_op = cast<TF::GatherV2Op>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1697,86 +1768,115 @@
 }
 
 LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_sel_op = cast<TF::SelectV2Op>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertSelectOp(rewriter, op, tf_sel_op.getResult(),
                       tf_sel_op.condition(), tf_sel_op.t(), tf_sel_op.e());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_s2d_op = cast<TF::SpaceToDepthOp>(op);
 
-  auto lowered_op = convertSpaceToDepthOp(
+  llvm::Optional<Value> result = convertSpaceToDepthOp(
       rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(),
       tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_d2s_op = cast<TF::DepthToSpaceOp>(op);
 
-  auto lowered_op = convertDepthToSpaceOp(
+  llvm::Optional<Value> result = convertDepthToSpaceOp(
       rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(),
       tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_s2b_op = cast<TF::SpaceToBatchNDOp>(op);
 
-  auto lowered_op = convertSpaceToBatchNDOp(
+  llvm::Optional<Value> result = convertSpaceToBatchNDOp(
       rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(),
       tf_s2b_op.block_shape(), tf_s2b_op.paddings());
+  if (!result) return failure();
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFBatchToSpaceNDOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_b2s_op = cast<TF::BatchToSpaceNDOp>(op);
 
-  auto lowered_op = convertBatchToSpaceNDOp(
+  llvm::Optional<Value> result = convertBatchToSpaceNDOp(
       rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(),
       tf_b2s_op.block_shape(), tf_b2s_op.crops());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_ss_op = cast<TF::StridedSliceOp>(op);
 
-  auto lowered_op = convertStridedSliceOp(
+  llvm::Optional<Value> result = convertStridedSliceOp(
       rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(),
       tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(),
       tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(),
       tf_ss_op.new_axis_maskAttr().getInt(),
       tf_ss_op.shrink_axis_maskAttr().getInt());
-  TOSA_REPLACE_LOWERED_OP_LIST(rewriter, op, lowered_op);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFZerosLikeOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_zeroslike_op = cast<TF::ZerosLikeOp>(op);
 
-  auto lowered_op = convertZerosLikeOp(
+  llvm::Optional<Value> result = convertZerosLikeOp(
       rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x());
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFSigmoidOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_sigmoid_op = cast<TF::SigmoidOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tf_sigmoid_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1787,9 +1887,9 @@
 }
 
 LogicalResult ConvertTFTanhOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_tanh_op = cast<TF::TanhOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tf_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1799,9 +1899,9 @@
 }
 
 LogicalResult ConvertTFLeakyReluOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_leakyrelu_op = cast<TF::LeakyReluOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tf_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1810,10 +1910,10 @@
   return failure();
 }
 
-LogicalResult ConvertTFNegOp::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ConvertTFNegOp::matchAndRewrite(Operation* op,
+                                              PatternRewriter& rewriter) const {
   auto tf_neg_op = cast<TF::NegOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tf_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1823,9 +1923,9 @@
 }
 
 LogicalResult ConvertTFStopGradientOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_stopgrad_op = cast<TF::StopGradientOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tf_stopgrad_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1836,11 +1936,11 @@
 }
 
 LogicalResult ConvertTFReverseV2Op::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_reverse_op = cast<TF::ReverseV2Op>(op);
-  auto input_type =
+  RankedTensorType input_type =
       tf_reverse_op.tensor().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tf_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!input_type || !output_type) return failure();
 
@@ -1872,29 +1972,33 @@
 }
 
 LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxArgsOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(),
                          tf_fakequant_op.minAttr().getValueAsDouble(),
                          tf_fakequant_op.maxAttr().getValueAsDouble(),
                          tf_fakequant_op.num_bitsAttr().getInt(),
                          tf_fakequant_op.narrow_rangeAttr().getValue());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    Operation* op, PatternRewriter& rewriter) const {
   auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1913,17 +2017,21 @@
   int64_t min_val = min_elems.getValue<IntegerAttr>(0).getInt();
   int64_t max_val = max_elems.getValue<IntegerAttr>(0).getInt();
 
-  auto lowered_op = convertFakeQuantOp(
+  llvm::Optional<Value> result = convertFakeQuantOp(
       rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val,
       tf_fakequant_op.num_bitsAttr().getInt(),
       tf_fakequant_op.narrow_rangeAttr().getValue());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 void LegalizeTF::runOnFunction() {
   OwningRewritePatternList patterns;
-  auto *ctx = &getContext();
+  auto* ctx = &getContext();
   auto func = getFunction();
 
   // Add the generated patterns to the list.
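For reference, a compact sketch of the control flow that replaces the removed TOSA_REPLACE_LOWERED_OP / TOSA_REPLACE_LOWERED_OP_LIST macro calls above: the legalize_common helpers now return llvm::Optional<Value> (or llvm::Optional<ValueRange> for multi-result ops), and each pattern turns an empty Optional into failure() explicitly. convertSomeOp and convertMultiOp below are hypothetical stand-ins, not functions from the patch, and the headers are assumed.

#include "llvm/ADT/Optional.h"
#include "mlir/IR/PatternMatch.h"        // PatternRewriter, Operation, Value (assumed)
#include "mlir/Support/LogicalResult.h"

// Hypothetical stand-ins for the legalize_common conversion helpers.
static llvm::Optional<mlir::Value> convertSomeOp(mlir::PatternRewriter&,
                                                 mlir::Operation*);
static llvm::Optional<mlir::ValueRange> convertMultiOp(mlir::PatternRewriter&,
                                                       mlir::Operation*);

// Single-result replacement (the former TOSA_REPLACE_LOWERED_OP shape).
static mlir::LogicalResult replaceSingle(mlir::Operation* op,
                                         mlir::PatternRewriter& rewriter) {
  llvm::Optional<mlir::Value> result = convertSomeOp(rewriter, op);
  if (!result) return mlir::failure();          // helper declined to legalize
  rewriter.replaceOp(op, {result.getValue()});  // pack the single value in a list
  return mlir::success();
}

// Multi-result replacement (the former TOSA_REPLACE_LOWERED_OP_LIST shape).
static mlir::LogicalResult replaceMulti(mlir::Operation* op,
                                        mlir::PatternRewriter& rewriter) {
  llvm::Optional<mlir::ValueRange> results = convertMultiOp(rewriter, op);
  if (!results) return mlir::failure();
  rewriter.replaceOp(op, results.getValue());   // forward every result
  return mlir::success();
}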
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index a5fe18f..4e51bd7 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -23,32 +23,9 @@
 #include <numeric>
 #include <unordered_set>
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSwitch.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
@@ -163,35 +140,13 @@
 DECL_CONVERT_OP(QConst);
 #undef DECL_CONVERT_OP
 
-// TODO: remove macro when replacing common function return types with
-// llvm::Optional<> Helper macros for checking the return value of a common
-// legalization function that returns a single tensor.
-// Packs the result in a list.
-#define TOSA_REPLACE_LOWERED_OP(REWRITER, OP, LOWERED_OP)   \
-  if (LOWERED_OP) {                                         \
-    REWRITER.replaceOp((OP), {(LOWERED_OP)->getResults()}); \
-    return success();                                       \
-  } else {                                                  \
-    return failure();                                       \
-  }
-
-// TODO: remove macro when replacing common function return types with
-// llvm::Optional<> Helper macros for checking the return value of a common
-// legalization function that returns a tensor list.
-#define TOSA_REPLACE_LOWERED_OP_LIST(REWRITER, OP, LOWERED_OP) \
-  if (LOWERED_OP) {                                            \
-    REWRITER.replaceOp((OP), (LOWERED_OP)->getResults());      \
-    return success();                                          \
-  } else {                                                     \
-    return failure();                                          \
-  }
-
 LogicalResult ConvertTFLReluOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_relu_op = cast<TFL::ReluOp>(op);
 
-  auto input_type = tfl_relu_op.x().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_type =
+      tfl_relu_op.x().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type || !output_type) return failure();
@@ -209,20 +164,22 @@
 
   Value output;
   if (output_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
-    auto input_qtype = input_type.getElementType()
-                           .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto output_qtype = output_type.getElementType()
-                            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_qtype =
+        input_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType output_qtype =
+        output_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
-    auto op1_rescale_in = buildRescaleToInt32(rewriter, op, tfl_relu_op.x(),
-                                              1.0f, input_qtype.getZeroPoint());
+    Value op1_rescale_in = buildRescaleToInt32(
+        rewriter, op, tfl_relu_op.x(), 1.0f, input_qtype.getZeroPoint());
     auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
         op->getLoc(), rescale_type, op1_rescale_in,
         rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
         rewriter.getF32FloatAttr(0.0f));
-    auto op3_rescale_op2 = buildRescaleFromInt32(
+    Value op3_rescale_op2 = buildRescaleFromInt32(
         rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
         output_qtype.getZeroPoint());
 
@@ -244,8 +201,9 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_relu6_op = cast<TFL::Relu6Op>(op);
 
-  auto input_type = tfl_relu6_op.x().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_type =
+      tfl_relu6_op.x().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type || !output_type) return failure();
@@ -263,21 +221,23 @@
 
   Value output;
   if (output_is_qtype && input_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
-    auto input_qtype = input_type.getElementType()
-                           .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto output_qtype = output_type.getElementType()
-                            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_qtype =
+        input_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType output_qtype =
+        output_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
     int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
                          input_qtype.getZeroPoint();
 
-    auto op1_rescale_in = buildRescaleToInt32(rewriter, op, tfl_relu6_op.x(),
-                                              1.0f, input_qtype.getZeroPoint());
+    Value op1_rescale_in = buildRescaleToInt32(
+        rewriter, op, tfl_relu6_op.x(), 1.0f, input_qtype.getZeroPoint());
     auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
         op->getLoc(), rescale_type, op1_rescale_in,
         rewriter.getI64IntegerAttr(rescaled_6), rewriter.getF32FloatAttr(0.0f));
-    auto op3_rescale_op2 = buildRescaleFromInt32(
+    Value op3_rescale_op2 = buildRescaleFromInt32(
         rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
         output_qtype.getZeroPoint());
 
@@ -299,9 +259,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_equal_op = cast<TFL::EqualOp>(op);
 
-  auto input_x_type = tfl_equal_op.x().getType().dyn_cast<RankedTensorType>();
-  auto input_y_type = tfl_equal_op.y().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_x_type =
+      tfl_equal_op.x().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_y_type =
+      tfl_equal_op.y().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_x_type || !input_y_type || !output_type) return failure();
@@ -322,10 +284,12 @@
 
   Value output;
   if (output_is_qtype && input_x_is_qtype && input_y_is_qtype) {
-    auto input_x_qtype = input_x_type.getElementType()
-                             .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_y_qtype = input_y_type.getElementType()
-                             .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_x_qtype =
+        input_x_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_y_qtype =
+        input_y_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
         input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
@@ -334,9 +298,9 @@
           "must be the same");
     }
 
-    auto op1_rescale_x = buildRescaleToInt32(
+    Value op1_rescale_x = buildRescaleToInt32(
         rewriter, op, tfl_equal_op.x(), 1.0f, input_x_qtype.getZeroPoint());
-    auto op2_rescale_y = buildRescaleToInt32(
+    Value op2_rescale_y = buildRescaleToInt32(
         rewriter, op, tfl_equal_op.y(), 1.0f, input_y_qtype.getZeroPoint());
     auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
         op->getLoc(), output_type, op1_rescale_x, op2_rescale_y);
@@ -357,11 +321,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_not_equal_op = cast<TFL::NotEqualOp>(op);
 
-  auto input_lhs_type =
+  RankedTensorType input_lhs_type =
       tfl_not_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type =
+  RankedTensorType input_rhs_type =
       tfl_not_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -382,10 +346,12 @@
 
   Value output;
   if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
         input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
@@ -394,10 +360,10 @@
           "must be the same");
     }
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_not_equal_op.lhs(), 1.0f,
                             input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_not_equal_op.rhs(), 1.0f,
                             input_rhs_qtype.getZeroPoint());
     auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
@@ -424,11 +390,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_greater_op = cast<TFL::GreaterOp>(op);
 
-  auto input_lhs_type =
+  RankedTensorType input_lhs_type =
       tfl_greater_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type =
+  RankedTensorType input_rhs_type =
       tfl_greater_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -449,10 +415,12 @@
 
   Value output;
   if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
         input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
@@ -461,10 +429,10 @@
           "must be the same");
     }
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_greater_op.lhs(), 1.0f,
                             input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_greater_op.rhs(), 1.0f,
                             input_rhs_qtype.getZeroPoint());
     auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
@@ -486,11 +454,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_greater_equal_op = cast<TFL::GreaterEqualOp>(op);
 
-  auto input_lhs_type =
+  RankedTensorType input_lhs_type =
       tfl_greater_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type =
+  RankedTensorType input_rhs_type =
       tfl_greater_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -511,10 +479,12 @@
 
   Value output;
   if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
         input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
@@ -523,10 +493,10 @@
           "must be the same");
     }
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.lhs(), 1.0f,
                             input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.rhs(), 1.0f,
                             input_rhs_qtype.getZeroPoint());
     auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
@@ -550,9 +520,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_add_op = cast<TFL::AddOp>(op);
 
-  auto input_lhs_type = tfl_add_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type = tfl_add_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_lhs_type =
+      tfl_add_op.lhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_rhs_type =
+      tfl_add_op.rhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_add_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -573,14 +545,17 @@
 
   Value output;
   if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto output_qtype = output_type.getElementType()
-                            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType output_qtype =
+        output_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     // Following quantization described in tensorflow/lite/kernels/add.cc
     // In detail, it does:
@@ -604,15 +579,15 @@
     double output_rescale_scale =
         max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
                             input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
                             input_rhs_qtype.getZeroPoint());
     auto op3_add_op1_op2 = rewriter.create<tosa::AddOp>(
         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
-    auto op4_rescale_op3 = buildRescaleFromInt32(
+    Value op4_rescale_op3 = buildRescaleFromInt32(
         rewriter, op, output_type, op3_add_op1_op2.getResult(),
         output_rescale_scale, output_qtype.getZeroPoint());
     output = op4_rescale_op3;
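A small numeric sketch of the rescale-scale formula visible in this hunk, with assumed example quantization parameters (the scales, the reading of max_scale_2x as twice the larger input scale, and input_shift = 20 are illustrative assumptions, not values from the patch):

#include <algorithm>
#include <cstdio>

int main() {
  double lhs_scale = 0.05, rhs_scale = 0.10, output_scale = 0.12;  // assumed
  int input_shift = 20;                                            // assumed
  // Assumed from the name: twice the larger of the two input scales.
  double max_scale_2x = 2.0 * std::max(lhs_scale, rhs_scale);      // 0.20
  // Formula as written in the hunk above.
  double output_rescale_scale =
      max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
  std::printf("output_rescale_scale = %g\n", output_rescale_scale);  // ~1.59e-06
  return 0;
}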
@@ -626,12 +601,13 @@
   auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();
 
   if (fused_activation_fn) {
-    auto fused_activation_op =
+    llvm::Optional<Value> fused_activation_val =
         convertFusedActivation(rewriter, op, output, fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
   rewriter.replaceOp(op, {output});
@@ -642,9 +618,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_sub_op = cast<TFL::SubOp>(op);
 
-  auto input_lhs_type = tfl_sub_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type = tfl_sub_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_lhs_type =
+      tfl_sub_op.lhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_rhs_type =
+      tfl_sub_op.rhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -665,13 +643,15 @@
 
   Value output;
   if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .cast<mlir::quant::UniformQuantizedType>();
-    auto output_qtype =
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType output_qtype =
         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
 
     // Following quantization described in tensorflow/lite/kernels/add.cc
@@ -696,15 +676,15 @@
     double output_rescale_scale =
         max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_sub_op.lhs(), lhs_rescale_scale,
                             input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_sub_op.rhs(), rhs_rescale_scale,
                             input_rhs_qtype.getZeroPoint());
     auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
-    auto op4_rescale_op3 = buildRescaleFromInt32(
+    Value op4_rescale_op3 = buildRescaleFromInt32(
         rewriter, op, output_type, op3_sub_op1_op2.getResult(),
         output_rescale_scale, output_qtype.getZeroPoint());
     output = op4_rescale_op3;
@@ -718,12 +698,13 @@
   auto fused_activation_fn = tfl_sub_op.fused_activation_functionAttr();
 
   if (fused_activation_fn) {
-    auto fused_activation_op =
+    llvm::Optional<Value> fused_activation_val =
         convertFusedActivation(rewriter, op, output, fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
   rewriter.replaceOp(op, {output});
@@ -734,25 +715,24 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_mul_op = cast<TFL::MulOp>(op);
 
-  auto lowered_op = convertMultiplyOp(rewriter, op, tfl_mul_op.getResult(),
-                                      tfl_mul_op.lhs(), tfl_mul_op.rhs());
+  llvm::Optional<Value> result = convertMultiplyOp(
+      rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(), tfl_mul_op.rhs());
 
-  if (!lowered_op) {
-    return failure();
-  }
+  if (!result) return failure();
 
   auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();
 
   if (fused_activation_fn) {
-    auto fused_activation_op = convertFusedActivation(
-        rewriter, op, lowered_op->getResult(0), fused_activation_fn);
+    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
+        rewriter, op, result.getValue(), fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
-  rewriter.replaceOp(op, {lowered_op->getResult(0)});
+  rewriter.replaceOp(op, {result.getValue()});
   return success();
 }
 
@@ -760,14 +740,13 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_square_op = cast<TFL::SquareOp>(op);
 
-  auto lowered_op = convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
-                                      tfl_square_op.x(), tfl_square_op.x());
+  llvm::Optional<Value> result =
+      convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
+                        tfl_square_op.x(), tfl_square_op.x());
 
-  if (!lowered_op) {
-    return failure();
-  }
+  if (!result) return failure();
 
-  rewriter.replaceOp(op, {lowered_op->getResult(0)});
+  rewriter.replaceOp(op, {result.getValue()});
   return success();
 }
 
@@ -775,27 +754,34 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
                                  tfl_squared_op.lhs(), tfl_squared_op.rhs());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+  return success();
 }
 
 LogicalResult ConvertTFLRoundOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_round_op = cast<TFL::RoundOp>(op);
 
-  auto input_type = tfl_round_op.x().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tfl_round_op.x().getType().dyn_cast<RankedTensorType>();
   if (!input_type) {
     return op->emitOpError("Round: input not ranked tensor type");
   }
 
   if (input_type.getElementType().isa<FloatType>()) {
-    auto lowered_op = convertRoundOp(rewriter, op, tfl_round_op.getResult(),
-                                     tfl_round_op.x());
+    llvm::Optional<Value> result = convertRoundOp(
+        rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());
 
-    TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+    if (!result) return failure();
+
+    rewriter.replaceOp(op, {result.getValue()});
+    return success();
 
   } else {
     // Round on int is nonsensical. Instead, replace uses of result with the
@@ -809,7 +795,7 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_div_op = cast<TFL::DivOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_div_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -823,12 +809,13 @@
                                    reciprocal_op.getResult(), 0);
 
   if (fused_activation_fn) {
-    auto fused_activation_op = convertFusedActivation(
+    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
         rewriter, op, mul_op.getResult(), fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
   rewriter.replaceOp(op, {mul_op.getResult()});
@@ -840,9 +827,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_max_op = cast<TFL::MaximumOp>(op);
 
-  auto input_lhs_type = tfl_max_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type = tfl_max_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_lhs_type =
+      tfl_max_op.lhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_rhs_type =
+      tfl_max_op.rhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
 
   // Not a ranked tensor output
@@ -864,16 +853,16 @@
 
   Value output;
   if (output_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
     auto op3_max_op1_op2 = rewriter.create<tosa::MaximumOp>(
         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
-    auto op4_rescale_op3 = buildRescaleFromInt32(
+    Value op4_rescale_op3 = buildRescaleFromInt32(
         rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);
 
     output = op4_rescale_op3;
@@ -893,9 +882,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_min_op = cast<TFL::MinimumOp>(op);
 
-  auto input_lhs_type = tfl_min_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type = tfl_min_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType input_lhs_type =
+      tfl_min_op.lhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_rhs_type =
+      tfl_min_op.rhs().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType output_type =
       tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -916,16 +907,16 @@
 
   Value output;
   if (output_is_qtype) {
-    auto rescale_type =
+    RankedTensorType rescale_type =
         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
     auto op3_min_op1_op2 = rewriter.create<tosa::MinimumOp>(
         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
-    auto op4_rescale_op3 = buildRescaleFromInt32(
+    Value op4_rescale_op3 = buildRescaleFromInt32(
         rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);
 
     output = op4_rescale_op3;
@@ -945,29 +936,37 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
                         tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_floormod_op = cast<TFL::FloorModOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
                         tfl_floormod_op.lhs(), tfl_floormod_op.rhs());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLAddNOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_addn_op = cast<TFL::AddNOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -992,9 +991,9 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_avgpool_op.input().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1025,7 +1024,7 @@
     // Pooling has no non-unit dilation
     ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
 
-    auto filter_type = RankedTensorType::get(
+    RankedTensorType filter_type = RankedTensorType::get(
         llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
 
     // TFLite doesn't support explicit padding
@@ -1046,9 +1045,9 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1079,7 +1078,7 @@
     // Pooling has no non-unit dilation
     ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
 
-    auto filter_type = RankedTensorType::get(
+    RankedTensorType filter_type = RankedTensorType::get(
         llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
 
     // TFLite doesn't support explicit padding
@@ -1100,11 +1099,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
-  auto filter_type =
+  RankedTensorType filter_type =
       tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type) return failure();
@@ -1171,12 +1170,13 @@
   auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
 
   if (fused_activation_fn) {
-    auto fused_activation_op = convertFusedActivation(
+    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
         rewriter, op, conv2d_output, fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
   rewriter.replaceOp(op, {conv2d_output});
@@ -1188,10 +1188,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);
 
-  auto input_type = tfl_conv_op.input().getType().dyn_cast<RankedTensorType>();
-  auto filter_type =
+  RankedTensorType input_type =
+      tfl_conv_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType filter_type =
       tfl_conv_op.weights().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type) return failure();
@@ -1299,11 +1300,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
-  auto filter_type =
+  RankedTensorType filter_type =
       tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_type) return failure();
@@ -1377,7 +1378,7 @@
   a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
   a2_reshape_dims.push_back(depth_multiplier.getInt());
 
-  auto a1_filter_transpose_perms =
+  Value a1_filter_transpose_perms =
       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {1, 2, 3, 0});
   auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
       op->getLoc(),
@@ -1412,12 +1413,13 @@
   auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
 
   if (fused_activation_fn) {
-    auto fused_activation_op = convertFusedActivation(
+    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
         rewriter, op, conv2d_output, fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
   rewriter.replaceOp(op, {conv2d_output});
@@ -1429,14 +1431,17 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_fc_op.getResult(0).getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto input_type = tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
-  auto filter_type = tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
-  auto bias_type = tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType filter_type =
+      tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType bias_type =
+      tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
   if (!input_type || !filter_type) return failure();
 
   bool input_is_qtype =
@@ -1468,8 +1473,8 @@
     int64_t num_batch = input_type.getNumElements() / num_elems;
     SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});
 
-    auto reshape_type = RankedTensorType::get(ArrayRef<int64_t>(shape_vals),
-                                              input_type.getElementType());
+    RankedTensorType reshape_type = RankedTensorType::get(
+        ArrayRef<int64_t>(shape_vals), input_type.getElementType());
     auto reshape_op = rewriter.create<tosa::ReshapeOp>(
         op->getLoc(), reshape_type, tfl_fc_op.input(),
         rewriter.getI64ArrayAttr(shape_vals));
@@ -1483,8 +1488,8 @@
     // value. TOSA requires bias to be an array of output_channel_count values,
     // so create a constant of the appropriate number and type of zeros.
     SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
-    auto bias_type = RankedTensorType::get(ArrayRef<int64_t>(bias_shape),
-                                           input_type.getElementType());
+    RankedTensorType bias_type = RankedTensorType::get(
+        ArrayRef<int64_t>(bias_shape), input_type.getElementType());
 
     DenseElementsAttr bias_attr;
     if (input_type.getElementType().isa<FloatType>()) {
@@ -1527,12 +1532,13 @@
   auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();
 
   if (fused_activation_fn) {
-    auto fused_activation_op =
+    llvm::Optional<Value> fused_activation_val =
         convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);
 
-    if (fused_activation_op) {
-      TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
-    }
+    if (!fused_activation_val) return failure();
+
+    rewriter.replaceOp(op, {fused_activation_val.getValue()});
+    return success();
   }
 
   rewriter.replaceOp(op, {fc_output});
@@ -1556,17 +1562,20 @@
   }
   int32_t axis = axis_attr.getInt();
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+  return success();
 }
 
 LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1586,12 +1595,14 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_rank_op = cast<TFL::RankOp>(op);
 
-  auto input_type = tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return failure();
 
   int32_t rank = input_type.getRank();
 
-  auto rank_type = RankedTensorType::get({1}, rewriter.getIntegerType(32));
+  RankedTensorType rank_type =
+      RankedTensorType::get({1}, rewriter.getIntegerType(32));
   auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
   auto rank_const =
       rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
@@ -1605,12 +1616,13 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_shape_op = cast<TFL::ShapeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto input_type = tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return failure();
 
   auto input_shape = input_type.getShape();
@@ -1620,7 +1632,7 @@
     shape_arr.emplace_back(input_shape[i]);
   }
 
-  auto shape_type = RankedTensorType::get(
+  RankedTensorType shape_type = RankedTensorType::get(
       {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
   auto shape_attr = DenseElementsAttr::get(
       shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
@@ -1636,11 +1648,15 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
                           tfl_expanddims_op.input(), tfl_expanddims_op.dim());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
@@ -1654,17 +1670,22 @@
     squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
   }
 
-  auto lowered_op = convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
-                                     tfl_squeeze_op.input(), squeeze_dims);
+  llvm::Optional<Value> result =
+      convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
+                       tfl_squeeze_op.input(), squeeze_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLFillOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_fill_op = cast<TFL::FillOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1683,8 +1704,8 @@
   if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
     return failure();
 
-  auto fill_type = RankedTensorType::get(ArrayRef<int64_t>(dims_vals),
-                                         value_elem.getType().getElementType());
+  RankedTensorType fill_type = RankedTensorType::get(
+      ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
   DenseElementsAttr fill_attr;
 
   // Convert to a compatible zero type.
@@ -1712,7 +1733,7 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1724,17 +1745,21 @@
   auto keep_dims_attr = tfl_any_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceAnyOp(
+  llvm::Optional<Value> result = convertReduceAnyOp(
       rewriter, op, output_type, tfl_any_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1746,17 +1771,21 @@
   auto keep_dims_attr = tfl_max_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceMaxOp(
+  llvm::Optional<Value> result = convertReduceMaxOp(
       rewriter, op, output_type, tfl_max_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_min_op = cast<TFL::ReduceMinOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1768,17 +1797,21 @@
   auto keep_dims_attr = tfl_min_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceMinOp(
+  llvm::Optional<Value> result = convertReduceMinOp(
       rewriter, op, output_type, tfl_min_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1790,17 +1823,21 @@
   auto keep_dims_attr = tfl_prod_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceProdOp(
+  llvm::Optional<Value> result = convertReduceProdOp(
       rewriter, op, output_type, tfl_prod_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLMeanOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_mean_op = cast<TFL::MeanOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1812,17 +1849,21 @@
   auto keep_dims_attr = tfl_mean_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceMeanOp(
+  llvm::Optional<Value> result = convertReduceMeanOp(
       rewriter, op, output_type, tfl_mean_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSumOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_sum_op = cast<TFL::SumOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -1834,47 +1875,63 @@
   auto keep_dims_attr = tfl_sum_op.keep_dimsAttr();
   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
 
-  auto lowered_op = convertReduceSumOp(
+  llvm::Optional<Value> result = convertReduceSumOp(
       rewriter, op, output_type, tfl_sum_op.input(), axes_elems, keep_dims);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLEluOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_elu_op = cast<TFL::EluOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
 
-  auto lowered_op = convertSoftmaxOp(rewriter, op, tfl_softmax_op.getResult(),
-                                     tfl_softmax_op.input());
+  llvm::Optional<Value> result = convertSoftmaxOp(
+      rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);
 
-  auto lowered_op = convertLogSoftmaxOp(
+  llvm::Optional<Value> result = convertLogSoftmaxOp(
       rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSliceOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_slice_op = cast<TFL::SliceOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1906,7 +1963,7 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_tile_op = cast<TFL::TileOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1929,7 +1986,7 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_transpose_op = cast<TFL::TransposeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -1955,10 +2012,14 @@
   }
   int32_t axis_i32 = axis_attr.getInt();
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
@@ -1973,10 +2034,14 @@
   }
   int32_t axis_i32 = axis_attr.getInt();
 
-  auto lowered_op =
+  llvm::Optional<ValueRange> results =
       convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!results) return failure();
+
+  rewriter.replaceOp(op, results.getValue());
+
+  return success();
 }
 
 // Splits in num_split parts along split_dim
@@ -2003,10 +2068,15 @@
   // an integer attribute in TFLite MLIR.
   int32_t axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
 
-  auto lowered_op = convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
-                                   tfl_split_op.value(), num_split, axis);
+  llvm::Optional<ValueRange> results =
+      convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
+                     tfl_split_op.value(), num_split, axis);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!results) return failure();
+
+  rewriter.replaceOp(op, results.getValue());
+
+  return success();
 }
 
 // Splits in num_split parts along split_dim
@@ -2036,21 +2106,26 @@
   // an integer attribute in TFLite MLIR.
   int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
 
-  auto lowered_op = convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
-                                    tfl_splitv_op.value(), size_split, axis);
+  llvm::Optional<ValueRange> results =
+      convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
+                      tfl_splitv_op.value(), size_split, axis);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!results) return failure();
+
+  rewriter.replaceOp(op, results.getValue());
+
+  return success();
 }
 
 LogicalResult ConvertTFLLessOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_less_op = cast<TFL::LessOp>(op);
 
-  auto input_lhs_type =
+  RankedTensorType input_lhs_type =
       tfl_less_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type =
+  RankedTensorType input_rhs_type =
       tfl_less_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_less_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -2071,10 +2146,12 @@
 
   Value output;
   if (output_is_qtype) {
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
         input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
@@ -2083,9 +2160,9 @@
           "must be the same");
     }
 
-    auto op1_rescale_lhs = buildRescaleToInt32(
+    Value op1_rescale_lhs = buildRescaleToInt32(
         rewriter, op, tfl_less_op.lhs(), 1.0f, input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs = buildRescaleToInt32(
+    Value op2_rescale_rhs = buildRescaleToInt32(
         rewriter, op, tfl_less_op.rhs(), 1.0f, input_rhs_qtype.getZeroPoint());
     auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
         op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
@@ -2110,11 +2187,11 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_less_equal_op = cast<TFL::LessEqualOp>(op);
 
-  auto input_lhs_type =
+  RankedTensorType input_lhs_type =
       tfl_less_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
-  auto input_rhs_type =
+  RankedTensorType input_rhs_type =
       tfl_less_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
@@ -2135,10 +2212,12 @@
 
   Value output;
   if (output_is_qtype) {
-    auto input_lhs_qtype = input_lhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
-    auto input_rhs_qtype = input_rhs_type.getElementType()
-                               .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_lhs_qtype =
+        input_lhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
+    UniformQuantizedType input_rhs_qtype =
+        input_rhs_type.getElementType()
+            .dyn_cast<mlir::quant::UniformQuantizedType>();
 
     if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
         input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
@@ -2147,10 +2226,10 @@
           "must be the same");
     }
 
-    auto op1_rescale_lhs =
+    Value op1_rescale_lhs =
         buildRescaleToInt32(rewriter, op, tfl_less_equal_op.lhs(), 1.0f,
                             input_lhs_qtype.getZeroPoint());
-    auto op2_rescale_rhs =
+    Value op2_rescale_rhs =
         buildRescaleToInt32(rewriter, op, tfl_less_equal_op.rhs(), 1.0f,
                             input_rhs_qtype.getZeroPoint());
     auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
@@ -2177,7 +2256,7 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_pad_op = cast<TFL::PadOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -2193,69 +2272,95 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto lowered_op = convertResizeOp(
+  llvm::Optional<Value> result = convertResizeOp(
       rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"));
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto lowered_op = convertResizeOp(
+  llvm::Optional<Value> result = convertResizeOp(
       rewriter, op, output_type, tfl_resize_op.input(), StringRef("NEAREST"));
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSelectOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_sel_op = cast<TFL::SelectOp>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
                       tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_sel_op = cast<TFL::SelectV2Op>(op);
 
-  auto lowered_op =
+  llvm::Optional<Value> result =
       convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
                       tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
-  auto lowered_op = convertSpaceToBatchNDOp(
+  llvm::Optional<Value> result = convertSpaceToBatchNDOp(
       rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
       tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);
 
-  auto lowered_op = convertBatchToSpaceNDOp(
+  llvm::Optional<Value> result = convertBatchToSpaceNDOp(
       rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
       tfl_b2s_op.block_shape(), tfl_b2s_op.indices());
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
@@ -2263,11 +2368,15 @@
   auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);
 
   auto block_size_attr = tfl_s2d_op.block_sizeAttr();
-  auto lowered_op = convertSpaceToDepthOp(rewriter, op, tfl_s2d_op.getResult(),
-                                          tfl_s2d_op.input(), block_size_attr,
-                                          rewriter.getStringAttr("NHWC"));
+  llvm::Optional<Value> result = convertSpaceToDepthOp(
+      rewriter, op, tfl_s2d_op.getResult(), tfl_s2d_op.input(), block_size_attr,
+      rewriter.getStringAttr("NHWC"));
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
@@ -2275,44 +2384,57 @@
   auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);
 
   auto block_size_attr = tfl_d2s_op.block_sizeAttr();
-  auto lowered_op = convertDepthToSpaceOp(rewriter, op, tfl_d2s_op.getResult(),
-                                          tfl_d2s_op.input(), block_size_attr,
-                                          rewriter.getStringAttr("NHWC"));
+  llvm::Optional<Value> result = convertDepthToSpaceOp(
+      rewriter, op, tfl_d2s_op.getResult(), tfl_d2s_op.input(), block_size_attr,
+      rewriter.getStringAttr("NHWC"));
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);
 
-  auto lowered_op = convertStridedSliceOp(
+  llvm::Optional<Value> result = convertStridedSliceOp(
       rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
       tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
       tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
       tfl_ss_op.new_axis_maskAttr().getInt(),
       tfl_ss_op.shrink_axis_maskAttr().getInt());
-  TOSA_REPLACE_LOWERED_OP_LIST(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);
 
-  auto lowered_op = convertZerosLikeOp(
+  llvm::Optional<Value> result = convertZerosLikeOp(
       rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());
-  TOSA_REPLACE_LOWERED_OP_LIST(rewriter, op, lowered_op);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor input
   if (!input_type) return failure();
@@ -2334,12 +2456,16 @@
         output_type.getElementType()
             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
 
-    auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
-        true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
-        -32768, 32767);
-    auto bool_type = RankedTensorType::get(input_shape, rewriter.getI1Type());
-    auto int16_type = RankedTensorType::get(input_shape, int16_element_qtype);
-    auto int32_type = RankedTensorType::get(input_shape, rewriter.getI32Type());
+    UniformQuantizedType int16_element_qtype =
+        mlir::quant::UniformQuantizedType::get(
+            true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+            -32768, 32767);
+    RankedTensorType bool_type =
+        RankedTensorType::get(input_shape, rewriter.getI1Type());
+    RankedTensorType int16_type =
+        RankedTensorType::get(input_shape, int16_element_qtype);
+    RankedTensorType int32_type =
+        RankedTensorType::get(input_shape, rewriter.getI32Type());
 
     // Table's real input range [-4.0, 4.0].
     // Use TABLE op to get relu6(x+3) / 6
@@ -2352,10 +2478,10 @@
       return std::lround(32768.0 * v);
     };
 
-    auto table_const = getTosa1DConstTensorTable(rewriter, op, hardswish_func);
+    Value table_const = getTosa1DConstTensorTable(rewriter, op, hardswish_func);
 
     // Rescale input to 9.7
-    auto op1_rescale_in =
+    Value op1_rescale_in =
         buildRescale(rewriter, op, int16_type, tfl_hardswish_op.input(),
                      (in_quant_type.getScale() * 128.0) / input_sample_grain,
                      in_quant_type.getZeroPoint(), 0);
@@ -2365,13 +2491,13 @@
         op->getLoc(), int32_type, op1_rescale_in, table_const);
 
     // scale table output back to quantized space
-    auto op3_rescale_op2 =
+    Value op3_rescale_op2 =
         buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
                      1.0 / (128.0 * 32768.0 * out_quant_type.getScale()), 0,
                      out_quant_type.getZeroPoint());
 
-    auto op4_rescale_in = buildRescale(rewriter, op, int32_type,
-                                       tfl_hardswish_op.input(), 1.0, 0, 0);
+    Value op4_rescale_in = buildRescale(rewriter, op, int32_type,
+                                        tfl_hardswish_op.input(), 1.0, 0, 0);
 
     // Get 3.0 in quantized space
     int32_t quantized_3 =
@@ -2398,7 +2524,7 @@
     // op5 = reciprocal(6)
     // op6 = mul (op4, op5)
 
-    auto op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
+    Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
 
     auto op2_add_x_op1 = rewriter.create<tosa::AddOp>(
         op->getLoc(), output_type, tfl_hardswish_op.input(), op1_value);
@@ -2429,9 +2555,10 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_logistic_op = cast<TFL::LogisticOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_logistic_op.getResult().getType().dyn_cast<RankedTensorType>();
-  auto input_type = tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
   if (!input_type || !output_type) return failure();
 
   bool input_is_qtype =
@@ -2446,13 +2573,14 @@
   }
 
   if (input_is_qtype) {
-    auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
-        true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
-        -32768, 32767);
-    auto int16_type =
+    UniformQuantizedType int16_element_qtype =
+        mlir::quant::UniformQuantizedType::get(
+            true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+            -32768, 32767);
+    RankedTensorType int16_type =
         RankedTensorType::get(output_type.getShape(), int16_element_qtype);
-    auto int32_type = RankedTensorType::get(output_type.getShape(),
-                                            rewriter.getIntegerType(32));
+    RankedTensorType int32_type = RankedTensorType::get(
+        output_type.getShape(), rewriter.getIntegerType(32));
     mlir::quant::UniformQuantizedType input_qtype =
         input_type.getElementType()
             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
@@ -2468,10 +2596,10 @@
       return std::lround(32768.0 * v);
     };
 
-    auto table_const = getTosa1DConstTensorTable(rewriter, op, sigmoid_func);
+    Value table_const = getTosa1DConstTensorTable(rewriter, op, sigmoid_func);
 
     // Rescale input to 9.7 precision.
-    auto op1_rescale_in =
+    Value op1_rescale_in =
         buildRescale(rewriter, op, int16_type, tfl_logistic_op.x(),
                      (input_qtype.getScale() * 128.0) / input_sample_grain,
                      input_qtype.getZeroPoint(), 0);
@@ -2482,7 +2610,7 @@
     double output_rescale_scale =
         1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
 
-    auto op3_rescale_op2 =
+    Value op3_rescale_op2 =
         buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
                      output_rescale_scale, 0, output_qtype.getZeroPoint());
 
@@ -2498,9 +2626,10 @@
 LogicalResult ConvertTFLTanhOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_tanh_op = cast<TFL::TanhOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tfl_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
-  auto input_type = tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType input_type =
+      tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
   if (!input_type || !output_type) return failure();
 
   bool input_is_qtype =
@@ -2515,13 +2644,14 @@
   }
 
   if (input_is_qtype) {
-    auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
-        true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
-        -32768, 32767);
-    auto int16_type =
+    UniformQuantizedType int16_element_qtype =
+        mlir::quant::UniformQuantizedType::get(
+            true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+            -32768, 32767);
+    RankedTensorType int16_type =
         RankedTensorType::get(output_type.getShape(), int16_element_qtype);
-    auto int32_type = RankedTensorType::get(output_type.getShape(),
-                                            rewriter.getIntegerType(32));
+    RankedTensorType int32_type = RankedTensorType::get(
+        output_type.getShape(), rewriter.getIntegerType(32));
     mlir::quant::UniformQuantizedType input_qtype =
         input_type.getElementType()
             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
@@ -2538,10 +2668,10 @@
       return std::lround(32768.0 * v);
     };
 
-    auto table_const = getTosa1DConstTensorTable(rewriter, op, tanh_func);
+    Value table_const = getTosa1DConstTensorTable(rewriter, op, tanh_func);
 
     // Rescale input to 9.7 precision.
-    auto op1_rescale_in =
+    Value op1_rescale_in =
         buildRescale(rewriter, op, int16_type, tfl_tanh_op.input(),
                      (input_qtype.getScale() * 128.0) / input_sample_grain,
                      input_qtype.getZeroPoint(), 0);
@@ -2552,7 +2682,7 @@
     double output_rescale_scale =
         1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
 
-    auto op3_rescale_op2 =
+    Value op3_rescale_op2 =
         buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
                      output_rescale_scale, 0, output_qtype.getZeroPoint());
 
@@ -2568,7 +2698,7 @@
 LogicalResult ConvertTFLPReluOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_prelu_op = cast<TFL::PReluOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tfl_prelu_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -2580,7 +2710,7 @@
 LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tfl_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -2592,7 +2722,7 @@
 LogicalResult ConvertTFLNegOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_neg_op = cast<TFL::NegOp>(op);
-  auto output_type =
+  RankedTensorType output_type =
       tfl_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!output_type) return failure();
 
@@ -2622,9 +2752,9 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
   if (!input_type || !output_type) return failure();
 
@@ -2659,22 +2789,22 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);
 
-  auto input_type =
+  RankedTensorType input_type =
       tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
-  auto output_type =
+  RankedTensorType output_type =
       tfl_quantize_op.getResult().getType().dyn_cast<RankedTensorType>();
 
   if (!input_type || !output_type) return failure();
 
-  auto qtype =
+  RankedTensorType qtype =
       tfl_quantize_op.qtypeAttr().getValue().dyn_cast<RankedTensorType>();
   if (!qtype) return failure();
 
-  auto element_type =
+  UniformQuantizedType element_type =
       qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
   if (!element_type) return failure();
 
-  auto input_element_type =
+  UniformQuantizedType input_element_type =
       input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
 
   // If input is already a quantized type, this is basically a RESCALE (or
@@ -2682,7 +2812,7 @@
   if (input_element_type) {
     double rescale_scale =
         input_element_type.getScale() / element_type.getScale();
-    auto rescale_op = buildRescale(
+    Value rescale_op = buildRescale(
         rewriter, op, output_type, tfl_quantize_op.input(), rescale_scale,
         input_element_type.getZeroPoint(), element_type.getZeroPoint());
 
@@ -2694,10 +2824,14 @@
     int64_t num_bits = element_type.getStorageTypeIntegralWidth();
     zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
 
-    auto lowered_op = convertQuantizeOp(rewriter, op, output_type,
-                                        tfl_quantize_op.input(), scale, zp);
+    llvm::Optional<Value> result = convertQuantizeOp(
+        rewriter, op, output_type, tfl_quantize_op.input(), scale, zp);
 
-    TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+    if (!result) return failure();
+
+    rewriter.replaceOp(op, {result.getValue()});
+
+    return success();
   }
 }
 
@@ -2705,15 +2839,16 @@
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_dequantize_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
 
-  auto qtype = tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType qtype =
+      tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
   if (!qtype) return failure();
 
-  auto element_type =
+  UniformQuantizedType element_type =
       qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
   if (!element_type) return failure();
 
@@ -2722,17 +2857,21 @@
   int64_t num_bits = element_type.getStorageTypeIntegralWidth();
   zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
 
-  auto lowered_op = convertDequantizeOp(rewriter, op, output_type,
-                                        tfl_dequantize_op.input(), scale, zp);
+  llvm::Optional<Value> result = convertDequantizeOp(
+      rewriter, op, output_type, tfl_dequantize_op.input(), scale, zp);
 
-  TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
 }
 
 LogicalResult ConvertTFLQConstOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_qconst_op = cast<TFL::QConstOp>(op);
 
-  auto output_type =
+  RankedTensorType output_type =
       tfl_qconst_op.getResult().getType().dyn_cast<RankedTensorType>();
   // Not a ranked tensor output
   if (!output_type) return failure();
@@ -2834,5 +2973,6 @@
 
 static PassRegistration<LegalizeTFL> pass(
     PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect");
+
 }  // namespace tosa
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
index 5bae8ec..7280d4c 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
@@ -15,13 +15,14 @@
 
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
 
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 
 // Implements legalization and post-legalization optimization helper functions
 
 namespace mlir {
-
 namespace tosa {
 
 // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
index 69671a6..f18e573 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
@@ -22,24 +22,10 @@
 #include <iterator>
 #include <numeric>
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
 #include "tensorflow/core/util/padding.h"
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h
index f944908..69d4e92 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h
@@ -18,15 +18,11 @@
 
 #include <memory>
 
-#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/Pass.h"  // from @llvm-project
 
 namespace mlir {
-
 namespace tosa {
 
-struct TOSALegalizationPipelineOptions
-    : public PassPipelineOptions<TOSALegalizationPipelineOptions> {};
-
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass();
 std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass();
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass();
@@ -36,7 +32,6 @@
 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
 
 }  // namespace tosa
-
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H
diff --git a/tensorflow/compiler/mlir/tosa/transforms/register_passes.h b/tensorflow/compiler/mlir/tosa/transforms/register_passes.h
deleted file mode 100644
index 7d13205..0000000
--- a/tensorflow/compiler/mlir/tosa/transforms/register_passes.h
+++ /dev/null
@@ -1,34 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
-#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
-
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Pass/Pass.h"
-#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
-
-namespace mlir {
-namespace tosa {
-
-inline void registerAllTosaPasses() {
-  registerLegalizeTosaPasses();
-  registerTosaOptPasses();
-}
-
-}  // namespace tosa
-}  // namespace mlir
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 00f2f4b..4d016d6 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -134,6 +134,7 @@
     srcs = ["transforms/mhlo_to_lhlo_with_xla.cc"],
     hdrs = ["transforms/mhlo_to_lhlo_with_xla.h"],
     deps = [
+        ":attribute_importer",
         ":hlo_module_importer",
         ":hlo_utils",
         ":mlir_hlo_to_hlo",
@@ -153,6 +154,8 @@
         "//tensorflow/compiler/xla/service/gpu:backend_configs_cc",
         "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
         "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/types:optional",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
@@ -248,6 +251,7 @@
     ],
     hdrs = ["mlir_hlo_to_hlo.h"],
     deps = [
+        ":attribute_exporter",
         ":type_to_shape",
         "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/compiler/mlir/hlo",
@@ -337,6 +341,24 @@
 )
 
 cc_library(
+    name = "attribute_exporter",
+    srcs = ["attribute_exporter.cc"],
+    hdrs = ["attribute_exporter.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:dnn",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+cc_library(
     name = "translate_cl_options",
     srcs = ["xla_mlir_translate_cl.cc"],
     hdrs = ["xla_mlir_translate_cl.h"],
diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.cc b/tensorflow/compiler/mlir/xla/attribute_exporter.cc
new file mode 100644
index 0000000..88296ab
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.cc
@@ -0,0 +1,87 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
+
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/stream_executor/dnn.h"
+
+namespace xla {
+
+ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
+    mlir::mhlo::ConvDimensionNumbers input) {
+  ConvolutionDimensionNumbers output;
+
+  output.set_input_batch_dimension(
+      input.input_batch_dimension().getValue().getSExtValue());
+  output.set_input_feature_dimension(
+      input.input_feature_dimension().getValue().getSExtValue());
+
+  for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
+    output.add_input_spatial_dimensions(v);
+  }
+
+  output.set_kernel_input_feature_dimension(
+      input.kernel_input_feature_dimension().getValue().getSExtValue());
+  output.set_kernel_output_feature_dimension(
+      input.kernel_output_feature_dimension().getValue().getSExtValue());
+
+  for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
+    output.add_kernel_spatial_dimensions(v);
+  }
+
+  output.set_output_batch_dimension(
+      input.output_batch_dimension().getValue().getSExtValue());
+  output.set_output_feature_dimension(
+      input.output_feature_dimension().getValue().getSExtValue());
+
+  for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
+    output.add_output_spatial_dimensions(v);
+  }
+
+  return output;
+}
+
+StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
+    llvm::StringRef input) {
+  llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
+      mlir::lmhlo_gpu::symbolizeActivation(input);
+  if (!activation) {
+    return InternalError("Unexpected activation");
+  }
+
+  switch (activation.getValue()) {
+    case mlir::lmhlo_gpu::Activation::None:
+      return stream_executor::dnn::kNone;
+    case mlir::lmhlo_gpu::Activation::Sigmoid:
+      return stream_executor::dnn::kSigmoid;
+    case mlir::lmhlo_gpu::Activation::Tanh:
+      return stream_executor::dnn::kTanh;
+    case mlir::lmhlo_gpu::Activation::Relu:
+      return stream_executor::dnn::kRelu;
+    case mlir::lmhlo_gpu::Activation::Relu6:
+      return stream_executor::dnn::kRelu6;
+    case mlir::lmhlo_gpu::Activation::ReluX:
+      return stream_executor::dnn::kReluX;
+    case mlir::lmhlo_gpu::Activation::BandPass:
+      return stream_executor::dnn::kBandPass;
+    default:
+      return InternalError("Unexpected activation");
+  }
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.h b/tensorflow/compiler/mlir/xla/attribute_exporter.h
new file mode 100644
index 0000000..c58cff0
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.h
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/dnn.h"
+
+namespace xla {
+
+// Converts the conv dimensions attribute to XLA HLO.
+ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
+    mlir::mhlo::ConvDimensionNumbers input);
+
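+// Converts the conv activation mode attribute (an MLIR string such as "Relu")
+// to the corresponding StreamExecutor DNN activation mode.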
+StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
+    llvm::StringRef input);
+
+}  // namespace xla
+#endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 23fab36..501f2a4 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -26,10 +26,10 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Region.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
@@ -127,7 +127,7 @@
   llvm::SmallVector<Type, 4> args, rets;
   TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
   TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
-  auto func_type = mlir::FunctionType::get(args, rets, context_);
+  auto func_type = mlir::FunctionType::get(context_, args, rets);
 
   string computation_name =
       computation.parent()->entry_computation() == &computation
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index d849b83..99fc64f 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -23,8 +23,8 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
index 9db5861..b554f38 100644
--- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
@@ -17,9 +17,9 @@
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc
index 51e00bc..16aaec0 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc
@@ -19,7 +19,7 @@
 
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/literal.h"
diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h
index 3ad39ae..88b775b 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.h
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.h
@@ -20,7 +20,7 @@
 
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 80c0180..2f1320a 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -17,7 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
@@ -113,6 +113,7 @@
       ConvertPadding(padding, &builder_),
       GetI64ElementsAttr(lhs_dilation, &builder_),
       GetI64ElementsAttr(rhs_dilation, &builder_),
+      /*window_reversal=*/nullptr,
       ConvertConvDimensionNumbers(dimension_numbers, &builder_),
       builder_.getI64IntegerAttr(feature_group_count),
       builder_.getI64IntegerAttr(batch_group_count), config_attr);
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 36aa31b..2e58bf2 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -31,16 +31,17 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -297,36 +298,7 @@
 
 static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
     mlir::mhlo::ConvDimensionNumbers input) {
-  xla::ConvolutionDimensionNumbers output;
-
-  output.set_input_batch_dimension(
-      input.input_batch_dimension().getValue().getSExtValue());
-  output.set_input_feature_dimension(
-      input.input_feature_dimension().getValue().getSExtValue());
-
-  for (int64 v : input.input_spatial_dimensions().getValues<int64>()) {
-    output.add_input_spatial_dimensions(v);
-  }
-
-  output.set_kernel_input_feature_dimension(
-      input.kernel_input_feature_dimension().getValue().getSExtValue());
-  output.set_kernel_output_feature_dimension(
-      input.kernel_output_feature_dimension().getValue().getSExtValue());
-
-  for (int64 v : input.kernel_spatial_dimensions().getValues<int64>()) {
-    output.add_kernel_spatial_dimensions(v);
-  }
-
-  output.set_output_batch_dimension(
-      input.output_batch_dimension().getValue().getSExtValue());
-  output.set_output_feature_dimension(
-      input.output_feature_dimension().getValue().getSExtValue());
-
-  for (int64 v : input.output_spatial_dimensions().getValues<int64>()) {
-    output.add_output_spatial_dimensions(v);
-  }
-
-  return output;
+  return xla::ConvertConvDimensionNumbers(input);
 }
 
 xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
@@ -516,6 +488,11 @@
   //
   // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
   LogicalResult Run() {
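+    // The conversion requires an entry function named `main`; bail out with a
+    // clear error if the module does not provide one.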
+    auto main = module_.lookupSymbol<mlir::FuncOp>("main");
+    if (!main)
+      return module_.emitError(
+          "conversion requires module with `main` function");
+
     for (auto func : module_.getOps<FuncOp>()) {
       if (func.empty()) continue;
       if (failed(RunOnFunction(func))) return failure();
@@ -539,8 +516,11 @@
       xla::XlaComputation* result);
 
   ::xla::HloModuleProto ConsumeMainProto() {
-    return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
-        .proto();
+    auto main = module_.lookupSymbol<mlir::FuncOp>("main");
+    // This is an invariant check as Run returns failure if there is no main
+    // function and so the main proto shouldn't be consumed in that case.
+    CHECK(main) << "requires module to have main function";  // Crash Ok.
+    return lowered_computation_[main].proto();
   }
 
   // Lower function call to HLO call instruction
@@ -757,6 +737,26 @@
   return failure();
 }
 
+LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
+  // XLA client builder API does not support generating convolution
+  // instructions with window reversal.
+  if (op.hasWindowReversal()) return failure();
+  auto& value_map = *ctx.values;
+  xla::XlaOp lhs, rhs;
+  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
+  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
+  xla::XlaOp xla_result = xla::ConvGeneralDilated(
+      lhs, rhs, Convert_window_strides(op.window_strides()),
+      Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
+      Convert_rhs_dilation(op.rhs_dilation()),
+      Convert_dimension_numbers(op.dimension_numbers()),
+      Convertuint64_t(op.feature_group_count()),
+      Convertuint64_t(op.batch_group_count()),
+      Unwrap(Convert_precision_config(op.precision_config())));
+  value_map[op] = xla_result;
+  return mlir::success();
+}
+
 LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   xla::XlaOp operand;
@@ -1256,10 +1256,10 @@
   }
 
   if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
-    // Construct the return value for the function. If there are multiple
-    // values returned, then create a tuple, else return value directly.
+    // Construct the return value for the function. If there is a single value
+    // returned, return it directly; otherwise create a tuple and return it.
     unsigned num_return_values = inst->getNumOperands();
-    if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
+    if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
       const bool has_ret_shardings =
           !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
 
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 0884230..358454f 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -218,3 +218,185 @@
   // CHECK-SAME:  replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
   ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
 }
+
+// -----
+
+HloModule ConvForward
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.conv_forward"
+// CHECK-SAME: algorithm = 2 : i64
+// CHECK-SAME: operand_0_layout = [3, 2, 1, 0]
+// CHECK-SAME: operand_1_layout = [3, 2, 1, 0]
+// CHECK-SAME: result_layout = [3, 2, 1, 0]
+// CHECK-SAME: tensor_ops_enabled = false
+// CHECK-SAME: batch_group_count = 1 : i64
+// CHECK-SAME: input_batch_dimension = 0 : i64
+// CHECK-SAME: input_feature_dimension = 1 : i64
+// CHECK-SAME: input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: kernel_input_feature_dimension = 1 : i64,
+// CHECK-SAME: kernel_output_feature_dimension = 0 : i64,
+// CHECK-SAME: kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: output_batch_dimension = 0 : i64
+// CHECK-SAME: output_feature_dimension = 1 : i64
+// CHECK-SAME: output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: feature_group_count = 1 : i64
+// CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>
+// CHECK-SAME: padding = dense<0> : tensor<2xi64>
+// CHECK-SAME: result_scale = 1.000000e+00 : f64
+// CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>
+// CHECK-SAME: window_reversal = dense<true> : tensor<2xi1>
+// CHECK-SAME: window_strides = dense<1> : tensor<2xi64>
+// CHECK: (memref<4x256x3x3xf32>, memref<256x256x2x2xf32>, memref<4x256x2x2xf32>, memref<65536xui8>)
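+// Note: dim_labels=bf01_oi01->bf01 in the custom-call below encodes the
+// dimension numbers checked above: input and output are (batch, feature,
+// spatial 0, spatial 1) and the kernel is (output feature, input feature,
+// spatial 0, spatial 1).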
+ENTRY main {
+  %input = f32[4,256,3,3]{3,2,1,0} parameter(0)
+  %filter = f32[256,256,2,2]{3,2,1,0} parameter(1)
+  ROOT %custom-call.1 = (f32[4,256,2,2]{3,2,1,0}, u8[65536]{0}) custom-call(f32[4,256,3,3]{3,2,1,0} %input, f32[256,256,2,2]{3,2,1,0} %filter),
+                        window={size=2x2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01,
+                        custom_call_target="__cudnn$convForward",
+                        backend_config="{\"algorithm\":\"2\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
+}
+
+// -----
+
+// CHECK: func @main
+// CHECK: "lmhlo_gpu.conv_forward_fused"
+// CHECK-SAME: activation_mode = "Relu"
+// CHECK-SAME: algorithm = 0 : i64
+// CHECK-SAME: operand_0_layout = [1, 3, 2, 0]
+// CHECK-SAME: operand_1_layout = [2, 1, 0, 3]
+// CHECK-SAME: result_layout = [1, 3, 2, 0]
+// CHECK-SAME: tensor_ops_enabled = false
+// CHECK-SAME: batch_group_count = 1 : i64
+// CHECK-SAME: input_batch_dimension = 0 : i64
+// CHECK-SAME: input_feature_dimension = 1 : i64
+// CHECK-SAME: input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: kernel_input_feature_dimension = 2 : i64
+// CHECK-SAME: kernel_output_feature_dimension = 3 : i64
+// CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>
+// CHECK-SAME: output_batch_dimension = 0 : i64
+// CHECK-SAME: output_feature_dimension = 1 : i64
+// CHECK-SAME: output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: feature_group_count = 1 : i64
+// CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>
+// CHECK-SAME: padding = dense<1> : tensor<2xi64>
+// CHECK-SAME: precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"]
+// CHECK-SAME: result_scale = 1.000000e+00 : f64
+// CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>
+// CHECK-SAME: window_reversal = dense<false> : tensor<2xi1>
+// CHECK-SAME: window_strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> ()
+
+HloModule FusedConvForward
+
+ENTRY main {
+  %input = f16[1,17,9,9]{1,3,2,0} parameter(0)
+  %filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
+  %bias = f16[32]{0} parameter(2)
+  ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}"
+}
+
+// -----
+
+// CHECK: func @main
+// CHECK: "lmhlo_gpu.conv_forward_fused_with_side_input"
+// CHECK-SAME: activation_mode = "Relu"
+// CHECK-SAME: algorithm = 0 : i64
+// CHECK-SAME: operand_0_layout = [1, 3, 2, 0]
+// CHECK-SAME: operand_1_layout = [2, 1, 0, 3]
+// CHECK-SAME: result_layout = [1, 3, 2, 0]
+// CHECK-SAME: tensor_ops_enabled = false
+// CHECK-SAME: batch_group_count = 1 : i64
+// CHECK-SAME: input_batch_dimension = 0 : i64
+// CHECK-SAME: input_feature_dimension = 1 : i64
+// CHECK-SAME: input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: kernel_input_feature_dimension = 2 : i64
+// CHECK-SAME: kernel_output_feature_dimension = 3 : i64
+// CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>
+// CHECK-SAME: output_batch_dimension = 0 : i64
+// CHECK-SAME: output_feature_dimension = 1 : i64
+// CHECK-SAME: output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: feature_group_count = 1 : i64
+// CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>
+// CHECK-SAME: padding = dense<1> : tensor<2xi64>
+// CHECK-SAME: precision_config = ["DEFAULT", "DEFAULT", "DEFAULT", "DEFAULT"]
+// CHECK-SAME: result_scale = 1.000000e+00 : f64
+// CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>
+// CHECK-SAME: side_input_scale = 1.000000e+00
+// CHECK-SAME: window_strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> ()
+
+HloModule FusedConvForwardSideInput
+
+ENTRY main {
+  %input = f16[1,17,9,9]{1,3,2,0} parameter(0)
+  %filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
+  %bias = f16[32]{0} parameter(2)
+  %side = f16[32]{0} parameter(3)
+  ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
+}
+
+// -----
+
+HloModule BatchNormForwardTraining
+
+// CHECK: func @main
+// CHECK: "lmhlo_gpu.batch_norm_training"
+// CHECK-SAME: epsilon = 1.000000e-03 : f32
+// CHECK-SAME: feature_index = 3 : i64
+// CHECK-SAME: (memref<1x1x10x1xf32>, memref<1xf32>, memref<1xf32>, memref<1x1x10x1xf32>, memref<1xf32>, memref<1xf32>) -> ()
+
+ENTRY main {
+  %input = f32[1,1,10,1]{3,2,1,0} parameter(0)
+  %scale = f32[1]{0} parameter(1)
+  %offset = f32[1]{0} parameter(2)
+  %constant = f32[] constant(0.001)
+  %constant_1 = s64[] constant(3)
+  %custom-call = (f32[1,1,10,1]{3,2,1,0}, f32[1]{0}, f32[1]{0})
+                 custom-call(f32[1,1,10,1]{3,2,1,0} %input, f32[1]{0} %scale, f32[1]{0} %offset, f32[] %constant, s64[] %constant_1),
+                 custom_call_target="__cudnn$batchNormalizationForwardTraining"
+}
+
+// -----
+
+HloModule BatchNormBackward
+
+// CHECK: func @main
+// CHECK: "lmhlo_gpu.batch_norm_grad"
+// CHECK-SAME: epsilon = 1.000000e-03 : f32
+// CHECK-SAME: feature_index = 2 : i64
+// CHECK-SAME: (memref<2x2x2x1xf16>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x1xf16>, memref<2x2x2x1xf16>, memref<2xf32>, memref<2xf32>)
+ENTRY main {
+  %input = f16[2,2,2,1]{3,2,1,0} parameter(0)
+  %scale = f32[2]{0} parameter(1)
+  %mean = f32[2]{0} parameter(2)
+  %stddev = f32[2]{0} parameter(3)
+  %grad = f16[2,2,2,1]{3,2,1,0} parameter(4)
+  %constant = f32[] constant(0.001)
+  %constant_2 = s64[] constant(2)
+  ROOT %custom-call = (f16[2,2,2,1]{3,2,1,0}, f32[2]{0}, f32[2]{0})
+                      custom-call(f16[2,2,2,1]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %stddev, f16[2,2,2,1]{3,2,1,0} %grad, f32[] %constant, s64[] %constant_2),
+                      custom_call_target="__cudnn$batchNormalizationBackward"
+}
+
+// -----
+
+HloModule BatchNormForwardInference
+
+// CHECK: func @main
+// CHECK: lmhlo_gpu.batch_norm_inference"
+// CHECK-SAME: epsilon = 1.000000e-03 : f32
+// CHECK-SAME: feature_index = 0 : i64
+// CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> ()
+ENTRY main {
+  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
+  %offset = f32[2]{0} parameter(1)
+  %scale = f32[2]{0} parameter(2)
+  %mean = f32[2]{0} parameter(3)
+  %variance = f32[2]{0} parameter(4)
+  %constant = f32[] constant(0.001)
+  %constant_1 = s64[] constant(0)
+  ROOT %custom-call = f32[2,2,2,2]{3,2,1,0}
+                      custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1),
+                      custom_call_target="__cudnn$batchNormalizationForwardInference"
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
index 04dc3c8..16f5c96 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
@@ -174,6 +174,13 @@
   return %0: tensor<4xi32>
 }
 
+// CHECK-LABEL: func @bitwise_xor
+func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK-NEXT: mhlo.xor
+  %0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %0: tensor<4xi32>
+}
+
 // CHECK-LABEL: func @bitwise_and
 func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK-NEXT: mhlo.and
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index c44715b..8aa39c6 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -835,7 +835,14 @@
 }
 
 // CHECK-LABEL: func @floordiv_unranked
-func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
+func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK-NOT: tf.FloorDiv
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0: tensor<*xf32>
+}
+
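+// FloorDiv on unranked integer tensors is not covered by the float-only
+// pattern in legalize_tf_patterns.td, so the op is expected to remain.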
+// CHECK-LABEL: func @floordiv_int
+func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
   // CHECK: tf.FloorDiv
   %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
   return %0: tensor<*xi32>
@@ -894,7 +901,7 @@
 
 // CHECK-LABEL: func @floormod_unranked
 func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
-  // CHECK: tf.FloorMod
+  // CHECK-NOT: tf.FloorMod
   %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
   return %0: tensor<*xi32>
 }
@@ -2045,6 +2052,14 @@
   return %0 : tensor<2xf32>
 }
 
+// CHECK-LABEL: @acos_complex
+// CHLO-LABEL: @acos_complex
+func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
+  // CHLO: tf.Acos
+  %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  return %0 : tensor<2xcomplex<f32>>
+}
+
 // CHECK-LABEL: @acos_dynamic
 // CHLO-LABEL: @acos_dynamic
 func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
@@ -2083,6 +2098,14 @@
   return %result : tensor<*xf32>
 }
 
+// CHECK-LABEL: @sinh_complex
+// CHLO-LABEL: @sinh_complex
+func @sinh_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
+  // CHLO: tf.Sinh
+  %0 = "tf.Sinh"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  return %0 : tensor<2xcomplex<f32>>
+}
+
 // CHECK-LABEL: func @cast_dynamic_i2f
 func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
   // CHECK: "mhlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
@@ -3531,7 +3554,7 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: conv_simple
-func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
+func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> {
 
   // CHECK: "mhlo.convolution"(%arg0, %arg1)
 
@@ -3557,12 +3580,12 @@
   // CHECK-DAG-SAME: feature_group_count = 2
   // CHECK-DAG-SAME: batch_group_count = 1
 
-  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
-  return %0 : tensor<256x30x30x16xf32>
+  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
+  return %0 : tensor<256x8x7x16xf32>
 }
 
 // CHECK-LABEL: conv3d_simple
-func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32> {
+func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> {
 
   // CHECK: "mhlo.convolution"(%arg0, %arg1)
 
@@ -3588,8 +3611,8 @@
   // CHECK-DAG-SAME: feature_group_count = 2
   // CHECK-DAG-SAME: batch_group_count = 1
 
-  %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32>
-  return %0 : tensor<256x30x30x30x16xf32>
+  %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32>
+  return %0 : tensor<256x7x6x5x16xf32>
 }
 
 // CHECK-LABEL: depthwiseconv_simple
@@ -3617,13 +3640,13 @@
 }
 
 // CHECK-LABEL: conv_explicit_paddings
-func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> {
+func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> {
 
   // CHECK: "mhlo.convolution"(%arg0, %arg1)
   // CHECK-SAME: padding = dense<{{\[\[}}6, 0], [3, 3]]>
 
-  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
-  return %0 : tensor<256x32x32x16xf32>
+  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32>
+  return %0 : tensor<256x9x7x16xf32>
 }
 
 // CHECK-LABEL: @conv2d_backprop_input
@@ -4950,6 +4973,16 @@
   return %1 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @cumsum_empty
+func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> {
+  %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+
+  // CHECK: "mhlo.reduce_window"
+  // CHECK: padding = dense<0> : tensor<1x2xi64>
+  %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32>
+  return %1 : tensor<0xf32>
+}
+
 // CHECK-LABEL: func @cumsum_dynamic
 func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> {
   // CHECK: "tf.Cumsum"
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index b3d3603..e797518 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -145,6 +145,27 @@
 
 // CHECK:  HloModule
 func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  call @empty_callee() : () -> ()
+  return %arg0 : tensor<4xi32>
+}
+
+func @empty_callee() {
+  return
+}
+
+// CHECK:       [[CALLEE:%.*]] () -> () {
+// CHECK-NEXT:    ROOT %{{.*}} = () tuple()
+// CHECK-NEXT:  }
+
+// CHECK:       ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
+// CHECK-NEXT:    ROOT %[[ARG]] = s32[4] parameter(0)
+// CHECK-NEXT:    [[CALL:%.*]] = () call(), to_apply=[[CALLEE]]
+// CHECK-NEXT:  }
+
+// -----
+
+// CHECK:  HloModule
+func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
   %1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
   return %1 : tensor<4xi32>
@@ -962,7 +983,7 @@
 
 // CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(f32[16,16] %Arg_0.1, s32[16,16] %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
 // CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0
-// CHECK: ROOT [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1
+// CHECK: [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1
 
 // -----
 
@@ -979,7 +1000,7 @@
 // CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
 // CHECK:   ROOT %[[CMP:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
 
-// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] sort(f32[16,16] %Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
+// CHECK: %[[RESULT:.*]] = f32[16,16] sort(f32[16,16] %Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
 
 // -----
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir b/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir
new file mode 100644
index 0000000..a2647d2
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir
@@ -0,0 +1,7 @@
+// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s
+
+// CHECK: conversion requires module with `main`
+func @non_main() {
+  %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index ad77a9d..0e754ce 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -34,12 +34,12 @@
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -144,7 +144,7 @@
 static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
   RankedTensorType ty =
       RankedTensorType::get(static_cast<int64_t>(attr.size()),
-                            IntegerType::get(64, attr.getContext()));
+                            IntegerType::get(attr.getContext(), 64));
   return DenseIntElementsAttr::get(ty, attr.getValue());
 }
 
@@ -184,7 +184,7 @@
   MLIRContext *ctx = input_type.getContext();
   if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
   if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
-    return IntegerType::get(32, ctx);
+    return IntegerType::get(ctx, 32);
   return input_type;
 }
 
@@ -828,7 +828,7 @@
     }
   }
 
-  auto element_type = IntegerType::get(64, input.getContext());
+  auto element_type = IntegerType::get(input.getContext(), 64);
   return DenseIntElementsAttr::get(
       RankedTensorType::get({shape[0]}, element_type), values);
 }
@@ -837,7 +837,7 @@
 // in TensorFlow PadV2 op.
 static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
   auto length = tf_padding.getType().getShape()[0];
-  auto element_type = IntegerType::get(64, tf_padding.getContext());
+  auto element_type = IntegerType::get(tf_padding.getContext(), 64);
   return DenseIntElementsAttr::get<int64_t>(
       RankedTensorType::get({length}, element_type), 0);
 }
@@ -1185,7 +1185,7 @@
       // Conv2D. So, fetch attribute by identifier instead of the
       // op.explicit_paddings() attribute getter.
       explicit_paddings =
-          op.template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
+          op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
     }
 
     SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
@@ -1473,7 +1473,7 @@
     // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the
     // left, "RIGHT" means rows will be padded ot the right.  The default is
     // "RIGHT_LEFT".
-    StringRef align = op.getAttrOfType<StringAttr>("align").getValue();
+    StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
     enum Alignment { kLeft, kRight };
 
     // default is RIGHT_LEFT
@@ -1691,7 +1691,7 @@
 
   LogicalResult matchAndRewrite(TF::EinsumOp op,
                                 PatternRewriter &rewriter) const override {
-    StringAttr equation = op.getAttrOfType<StringAttr>("equation");
+    StringAttr equation = op->getAttrOfType<StringAttr>("equation");
     if (op.N() == 1) {
       rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
           op, op.getType(), *op.inputs().begin(), equation);
@@ -1837,7 +1837,7 @@
       Type feature_type = RankedTensorType::get(
           {GetDimSize(act_type, feature_dim)}, kernel_type);
       Type result_type = TupleType::get(
-          {act.getType(), feature_type, feature_type}, rewriter.getContext());
+          rewriter.getContext(), {act.getType(), feature_type, feature_type});
 
       auto training_op = rewriter.create<BatchNormGradOp>(
           loc, result_type, act, scale, mean, var, grad, op.epsilon(),
@@ -1973,7 +1973,7 @@
       // batch_mean, and batch_var.
       SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
                                             mean_var_type, mean_var_type};
-      Type result_type = TupleType::get(operand_types, rewriter.getContext());
+      Type result_type = TupleType::get(rewriter.getContext(), operand_types);
 
       auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
           op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
@@ -4183,7 +4183,7 @@
       // Conv2DBackpropInput. So, fetch attribute by identifier instead of the
       // op.explicit_paddings() attribute getter.
       ArrayRef<Attribute> explicit_paddings_attr =
-          op.template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
+          op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
       explicit_paddings.reserve(explicit_paddings_attr.size());
       for (Attribute explicit_padding : explicit_paddings_attr)
         explicit_paddings.push_back(
@@ -4265,6 +4265,7 @@
                                    &rewriter),
         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
         GetI64ElementsAttr(rhs_dilation, &rewriter),
+        /*window_reversal=*/nullptr,
         ConvDimensionNumbers::get(
             /*input_batch_dimension=*/batch_dim_attr,
             /*input_feature_dimension=*/feature_dim_attr,
@@ -4346,7 +4347,7 @@
       // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the
       // op.explicit_paddings() attribute getter.
       ArrayRef<Attribute> explicit_paddings_attr =
-          op.template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
+          op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
       explicit_paddings.reserve(explicit_paddings_attr.size());
       for (Attribute explicit_padding : explicit_paddings_attr)
         explicit_paddings.push_back(
@@ -4479,6 +4480,7 @@
         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
                                    &rewriter),
         GetI64ElementsAttr(rhs_dilation, &rewriter),
+        /*window_reversal=*/nullptr,
         ConvDimensionNumbers::get(
             // Swap batch_dim and feature_dim in the activations.
             /*input_batch_dimension=*/feature_dim_attr,
@@ -4616,9 +4618,9 @@
     // Emit infeed op.
     // The result type of infeed is a tuple(tuple(result types), token type).
     auto data_tuple_type =
-        mlir::TupleType::get(result_types, rewriter.getContext());
+        mlir::TupleType::get(rewriter.getContext(), result_types);
     auto data_and_token_type = mlir::TupleType::get(
-        {data_tuple_type, token.getType()}, rewriter.getContext());
+        rewriter.getContext(), {data_tuple_type, token.getType()});
 
     auto data_and_token =
         rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
@@ -4635,11 +4637,11 @@
       if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
         *sharding_proto.add_tuple_shardings() =
             ::xla::sharding_builder::AssignDevice(0);
-        data_and_token.setAttr(
+        data_and_token->setAttr(
             kShardingAttr,
             rewriter.getStringAttr(sharding_proto.SerializeAsString()));
       } else {
-        data_and_token.setAttr(kShardingAttr, op._XlaShardingAttr());
+        data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
       }
     }
 
@@ -5155,7 +5157,7 @@
         /*call_target_name=*/rewriter.getStringAttr("Sharding"),
         /*has_side_effect=*/rewriter.getBoolAttr(false),
         /*backend_config=*/rewriter.getStringAttr(""));
-    custom_call.setAttr(kShardingAttr, op._XlaShardingAttr());
+    custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
     rewriter.replaceOp(op, custom_call.getResult(0));
 
     return success();
@@ -5407,7 +5409,8 @@
     window_dims[axis] = input_shape[axis];
 
     SmallVector<int64_t, 8> paddings(rank * 2, 0);
-    paddings[axis * 2] = input_shape[axis] - 1;
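+    // Clamp at zero so a zero-sized axis does not produce negative padding.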
+    paddings[axis * 2] =
+        std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
     auto paddings_attr = DenseIntElementsAttr::get(
         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
         paddings);
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
index 6056ebe..ef14408 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
@@ -29,7 +29,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
@@ -281,7 +281,7 @@
       /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
   auto result_type = result.getType();
   auto recv_result_type =
-      TupleType::get({result_type, token.getType()}, builder.getContext());
+      TupleType::get(builder.getContext(), {result_type, token.getType()});
   auto recv =
       builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
                              /*is_host_transfer=*/builder.getBoolAttr(true));
@@ -712,8 +712,8 @@
   auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
   auto new_result_types =
       llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
-  func.setType(FunctionType::get(new_argument_types, new_result_types,
-                                 builder.getContext()));
+  func.setType(FunctionType::get(builder.getContext(), new_argument_types,
+                                 new_result_types));
 }
 
 // Replaces a function terminator `return` with another `return` that has an
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
index 5cee6dd..eb602c3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
@@ -28,9 +28,9 @@
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index b363e2a..77d804c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -101,7 +101,7 @@
   def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
 
 def LowerRightShiftSigned :
-  Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+  Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r),
       (HLOClient_BroadcastShiftRightArithmeticOp $l, $r,
        (BinBroadcastDimensions $l, $r)),
       [(SignedIntTensor $r)]>;
@@ -114,7 +114,7 @@
 // Performs a substitution of FloorDiv, pseudo code below:
 //
 //  return floor(div(x, y))
-def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
           (HLO_FloorOp
            (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
           [(IEEEFloatTensor $l)]>;
@@ -166,7 +166,7 @@
 //   return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
 // Requires static shaped inputs to create constant splats and computation of
 // broadcast attributes.
-def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
       (HLO_SelectOp
        (HLOClient_BroadcastAndOp
         (HLOClient_BroadcastCompareOp
@@ -188,19 +188,22 @@
        (HLOClient_BroadcastAddOp $r,
         $rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
 
+def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>;
+
 //===----------------------------------------------------------------------===//
 // Logical & bitwise binary op patterns.
 //===----------------------------------------------------------------------===//
 
 class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
-  : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
         (ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
         [(SignedIntTensor $l)]>;
 
 foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
                          [TF_LogicalOrOp, HLOClient_BroadcastOrOp],
+                         [TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
                          [TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
-                         [TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in
+                         [TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in
   def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index b1d05aa..d455de1 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -28,11 +28,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -154,9 +154,11 @@
     TypeID::get<TF::IgammaOp>(),
     TypeID::get<TF::IgammacOp>(),
     TypeID::get<TF::IgammaGradAOp>(),
+    TypeID::get<TF::InplaceAddOp>(),
     TypeID::get<TF::InTopKV2Op>(),
     TypeID::get<TF::InvertOp>(),
     TypeID::get<TF::InvOp>(),
+    TypeID::get<TF::KthOrderStatisticOp>(),
     TypeID::get<TF::LRNOp>(),
     TypeID::get<TF::LRNGradOp>(),
     TypeID::get<TF::LeakyReluGradOp>(),
@@ -170,6 +172,7 @@
     TypeID::get<TF::LogicalOrOp>(),
     TypeID::get<TF::LogOp>(),
     TypeID::get<TF::LowerBoundOp>(),
+    TypeID::get<TF::MakeUniqueOp>(),
     TypeID::get<TF::MatMulOp>(),
     TypeID::get<TF::MatrixDiagV3Op>(),
     TypeID::get<TF::MatrixInverseOp>(),
@@ -248,6 +251,8 @@
     TypeID::get<TF::TensorScatterAddOp>(),
     TypeID::get<TF::TensorScatterSubOp>(),
     TypeID::get<TF::TPUEmbeddingActivationsOp>(),
+    TypeID::get<TF::TopKUniqueOp>(),
+    TypeID::get<TF::TopKWithUniqueOp>(),
     TypeID::get<TF::TransposeOp>(),
     TypeID::get<TF::TridiagonalSolveOp>(),
     TypeID::get<TF::TruncateDivOp>(),
@@ -255,7 +260,6 @@
     TypeID::get<TF::TruncateModOp>(),
     TypeID::get<TF::UnpackOp>(),
     TypeID::get<TF::UpperBoundOp>(),
-    TypeID::get<TF::XdivyOp>(),
     TypeID::get<TF::XlaBroadcastHelperOp>(),
     TypeID::get<TF::XlaConvOp>(),
     TypeID::get<TF::XlaDotOp>(),
@@ -265,8 +269,6 @@
     TypeID::get<TF::XlaKeyValueSortOp>(),
     TypeID::get<TF::XlaPadOp>(),
     TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
-    TypeID::get<TF::Xlog1pyOp>(),
-    TypeID::get<TF::XlogyOp>(),
     TypeID::get<TF::XlaSortOp>(),
     TypeID::get<TF::XlaSvdOp>()
   };
@@ -350,7 +352,8 @@
   // XlaCompiler within the context is only used by the functional ops to
   // compile functions. We are not handling those at the moment so XlaCompiler
   // is not required.
-  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
+  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
+                                        /*graph=*/nullptr);
   context_->Ref();
 
   device_mgr_ = CreateDeviceMgr(device_type_);
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index f4ad6f0..3bc8afc 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -19,6 +19,7 @@
 #include <memory>
 #include <tuple>
 
+#include "absl/algorithm/container.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -27,13 +28,13 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
@@ -42,6 +43,7 @@
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -120,7 +122,7 @@
   // Run all HLO passes to produce an optimized module.
   auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
       std::move(hlo_module), backend->default_stream_executor(),
-      backend->memory_allocator(), optimize_xla_hlo);
+      optimize_xla_hlo, {backend->memory_allocator()});
   TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
                                   "running XLA pass pipeline");
   std::unique_ptr<HloModule> optimized_hlo_module =
@@ -196,11 +198,16 @@
 
 }  // namespace
 
+// Creates MLIR operands corresponding to operands and results of the XLA HLO
+// instruction. If `num_operands` is set, then only the first `num_operands`
+// operands of the HLO instruction are considered.
 Status LhloDialectEmitter::CreateOperands(
     HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands,
-    size_t& num_arguments, size_t& num_results) {
-  for (const HloInstruction* operand : instr->operands()) {
-    TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
+    size_t& num_arguments, size_t& num_results,
+    absl::optional<xla::int64> num_operands) {
+  for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
+       i++) {
+    TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
   }
   num_arguments = operands.size();
   TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands));
@@ -209,17 +216,17 @@
 }
 
 template <typename OpType>
-StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
-                                                          size_t& num_arguments,
-                                                          size_t& num_results) {
+StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
+    HloInstruction* instr, size_t& num_arguments, size_t& num_results,
+    absl::optional<xla::int64> num_operands) {
   Location loc = getLocation(instr);
   std::pair<Identifier, Attribute> attrs[] = {
       {Identifier::get("name", builder_.getContext()),
        builder_.getStringAttr(instr->name())},
   };
   llvm::SmallVector<Value, 4> operands;
-  TF_RETURN_IF_ERROR(
-      CreateOperands(instr, operands, num_arguments, num_results));
+  TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results,
+                                    num_operands));
   return builder_.create<OpType>(loc, llvm::None, operands, attrs);
 }
 
@@ -398,7 +405,7 @@
     llvm::SmallVector<int64_t, 4> minor_to_major(
         shape.layout().minor_to_major().begin(),
         shape.layout().minor_to_major().end());
-    load.setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
+    load->setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
   }
   return load.getResult();
 }
@@ -556,6 +563,14 @@
     return EmitGemm(custom_call_instr);
   }
 
+  if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
+    return EmitDnnConvolution(custom_call_instr);
+  }
+
+  if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) {
+    return EmitDnnBatchNorm(custom_call_instr);
+  }
+
   size_t num_arguments, num_results;
   TF_ASSIGN_OR_RETURN(auto custom_call,
                       CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
@@ -566,8 +581,8 @@
       builder_.getStringAttr(custom_call_instr->opaque()));
   const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
                                static_cast<int32_t>(num_results)};
-  custom_call.setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
-                      builder_.getI32VectorAttr(segments));
+  custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
+                       builder_.getI32VectorAttr(segments));
   return custom_call.getOperation();
 }
 
@@ -623,6 +638,196 @@
   return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
 }
 
+static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
+    stream_executor::dnn::ActivationMode activation) {
+  switch (activation) {
+    case stream_executor::dnn::kNone:
+      return mlir::lmhlo_gpu::Activation::None;
+    case stream_executor::dnn::kSigmoid:
+      return mlir::lmhlo_gpu::Activation::Sigmoid;
+    case stream_executor::dnn::kRelu:
+      return mlir::lmhlo_gpu::Activation::Relu;
+    case stream_executor::dnn::kRelu6:
+      return mlir::lmhlo_gpu::Activation::Relu6;
+    case stream_executor::dnn::kReluX:
+      return mlir::lmhlo_gpu::Activation::ReluX;
+    case stream_executor::dnn::kTanh:
+      return mlir::lmhlo_gpu::Activation::Tanh;
+    case stream_executor::dnn::kBandPass:
+      return mlir::lmhlo_gpu::Activation::BandPass;
+    default:
+      return xla::InternalError("Unknown activation");
+  }
+}
+
+StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
+    HloCustomCallInstruction* custom_call) {
+  TF_ASSIGN_OR_RETURN(
+      auto const backend_config,
+      custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
+
+  TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
+                      xla::gpu::GetCudnnConvKind(custom_call));
+
+  auto get_layout_attribute = [&](const xla::Layout& layout) {
+    std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
+    absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
+                      [](xla::int64 x) { return static_cast<int64_t>(x); });
+    return builder_.getI64ArrayAttr(minor_to_major);
+  };
+
+  auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
+    const xla::Window& window = custom_call->window();
+    // Window size for Cudnn Conv is the same as the kernel size.
+    op.window_stridesAttr(
+        GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
+          return static_cast<int64_t>(dim.stride());
+        }));
+    // Cudnn Conv requires low and high padding to be equal.
+    op.paddingAttr(
+        GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
+          return static_cast<int64_t>(dim.padding_low());
+        }));
+    // LHS dilation is encoded in base_dilation of the backend config.
+    // RHS dilation is encoded in window_dilation of the backend config.
+    op.lhs_dilationAttr(
+        GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
+          return static_cast<int64_t>(dim.base_dilation());
+        }));
+    op.rhs_dilationAttr(
+        GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
+          return static_cast<int64_t>(dim.window_dilation());
+        }));
+    // Setup window reversal.
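+    // The reversal mask is an i1 vector with one entry per window dimension.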
+    auto window_reversal = llvm::to_vector<4>(llvm::map_range(
+        window.dimensions(), [](const ::xla::WindowDimension& dim) {
+          return dim.window_reversal();
+        }));
+    auto type = RankedTensorType::get(op.window_strides()->getType().getShape(),
+                                      builder_.getIntegerType(/*width=*/1));
+    op.window_reversalAttr(DenseElementsAttr::get(type, window_reversal));
+
+    op.dimension_numbersAttr(xla::ConvertConvDimensionNumbers(
+        custom_call->convolution_dimension_numbers(), &builder_));
+    op.feature_group_countAttr(
+        builder_.getI64IntegerAttr(custom_call->feature_group_count()));
+    op.batch_group_countAttr(
+        builder_.getI64IntegerAttr(custom_call->batch_group_count()));
+    op.precision_configAttr(xla::ConvertPrecisionConfig(
+        &custom_call->precision_config(), &builder_));
+    op.result_scaleAttr(
+        builder_.getF64FloatAttr(backend_config.conv_result_scale()));
+    auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
+        builder_.getI64IntegerAttr(backend_config.algorithm()),
+        builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
+        get_layout_attribute(custom_call->operand(0)->shape().layout()),
+        get_layout_attribute(custom_call->operand(1)->shape().layout()),
+        get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
+        builder_.getContext());
+    op.backend_configAttr(config);
+
+    return op.getOperation();
+  };
+
+  auto set_activation = [&, this](auto op) -> Status {
+    auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
+        backend_config.activation_mode());
+    TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
+                        GetLHLOActivation(se_activation));
+    StringAttr activation_attr = builder_.getStringAttr(
+        mlir::lmhlo_gpu::stringifyActivation(activation));
+    op.activation_modeAttr(activation_attr);
+    return Status::OK();
+  };
+
+  switch (kind) {
+    case xla::gpu::CudnnConvKind::kForward: {
+      TF_ASSIGN_OR_RETURN(
+          auto cnn_forward,
+          CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
+      return set_common_conv_attributes(cnn_forward);
+    }
+    case xla::gpu::CudnnConvKind::kBackwardInput: {
+      TF_ASSIGN_OR_RETURN(
+          auto cnn_backward,
+          CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
+      return set_common_conv_attributes(cnn_backward);
+    }
+    case xla::gpu::CudnnConvKind::kBackwardFilter: {
+      TF_ASSIGN_OR_RETURN(
+          auto cnn_backward,
+          CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
+      return set_common_conv_attributes(cnn_backward);
+    }
+    case xla::gpu::CudnnConvKind::kForwardActivation: {
+      // Fused conv can be either with side input or without.
+      if (custom_call->operand_count() == 3) {
+        TF_ASSIGN_OR_RETURN(
+            auto cnn_fused,
+            CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
+        TF_RETURN_IF_ERROR(set_activation(cnn_fused));
+        return set_common_conv_attributes(cnn_fused);
+      }
+
+      TF_RET_CHECK(custom_call->operand_count() == 4);
+      TF_ASSIGN_OR_RETURN(
+          auto cnn_fused_side_input,
+          CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
+              custom_call));
+      cnn_fused_side_input.side_input_scaleAttr(
+          builder_.getF64FloatAttr(backend_config.side_input_scale()));
+      TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
+      return set_common_conv_attributes(cnn_fused_side_input);
+    }
+  }
+}
+
+StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
+    HloCustomCallInstruction* custom_call) {
+  const xla::int64 num_operands = custom_call->operand_count();
+  auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
+    // The last 2 operands of a custom call for batch norm are the epsilon and
+    // feature_index.
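+    // Callers below pass `num_operands - 2` to CreateOpWithoutAttrs so these
+    // trailing constants do not become operands of the emitted LMHLO op.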
+    const HloInstruction* epsilon = custom_call->operand(num_operands - 2);
+    TF_RET_CHECK(epsilon->IsConstant());
+    float epsilon_value = epsilon->literal().Get<float>({});
+
+    const HloInstruction* feature_index =
+        custom_call->operand(num_operands - 1);
+    TF_RET_CHECK(feature_index->IsConstant());
+    xla::int64 feature_index_value =
+        feature_index->literal().Get<xla::int64>({});
+
+    op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value));
+    op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value));
+    return op.getOperation();
+  };
+
+  const std::string& target = custom_call->custom_call_target();
+  if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) {
+    TF_ASSIGN_OR_RETURN(auto fwd_training,
+                        CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>(
+                            custom_call, num_operands - 2));
+    return set_batchnorm_attributes(fwd_training);
+  }
+
+  if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) {
+    TF_ASSIGN_OR_RETURN(auto backward,
+                        CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>(
+                            custom_call, num_operands - 2));
+    return set_batchnorm_attributes(backward);
+  }
+
+  if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) {
+    TF_ASSIGN_OR_RETURN(auto fwd_inference,
+                        CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>(
+                            custom_call, num_operands - 2));
+    return set_batchnorm_attributes(fwd_inference);
+  }
+
+  return xla::Unimplemented("Unsupported batch norm operation");
+}
+
 // Convert an XLA HLO constant to a global_memref + get_global_memref pair.
 StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
     const HloInstruction* instr) {
@@ -664,12 +869,12 @@
 
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                       assignment_.GetUniqueTopLevelSlice(instr));
-  get_global_memref.setAttr("lmhlo.alloc",
-                            builder_.getIndexAttr(slice.index()));
-  get_global_memref.setAttr("lmhlo.slice_offset",
-                            builder_.getI64IntegerAttr(slice.offset()));
-  get_global_memref.setAttr("lmhlo.slice_size",
-                            builder_.getI64IntegerAttr(slice.size()));
+  get_global_memref->setAttr("lmhlo.alloc",
+                             builder_.getIndexAttr(slice.index()));
+  get_global_memref->setAttr("lmhlo.slice_offset",
+                             builder_.getI64IntegerAttr(slice.offset()));
+  get_global_memref->setAttr("lmhlo.slice_size",
+                             builder_.getI64IntegerAttr(slice.size()));
 
   // Update the cache to remember this value.
   auto& cached_value = slices_[std::make_pair(instr, ::xla::ShapeIndex())];
@@ -763,7 +968,7 @@
   auto* all_reduce = ::xla::Cast<::xla::HloAllReduceInstruction>(instr);
   auto replica_groups_attr = ::xla::HloFunctionImporter::ConvertReplicaGroups(
       all_reduce->replica_groups(), builder_);
-  all_reduce_op.setAttr(replica_groups_attr.first, replica_groups_attr.second);
+  all_reduce_op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
   all_reduce_op.constrain_layoutAttr(
       builder_.getBoolAttr(all_reduce->constrain_layout()));
   all_reduce_op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 6c7bdd8..73d1593 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -16,11 +16,12 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
 #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
 
+#include "absl/types/optional.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -63,6 +64,10 @@
       ::xla::HloCustomCallInstruction* custom_call);
   ::xla::StatusOr<Operation*> EmitGemm(
       ::xla::HloCustomCallInstruction* custom_call);
+  ::xla::StatusOr<Operation*> EmitDnnConvolution(
+      ::xla::HloCustomCallInstruction* custom_call);
+  ::xla::StatusOr<Operation*> EmitDnnBatchNorm(
+      ::xla::HloCustomCallInstruction* custom_call);
 
   ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
@@ -82,20 +87,23 @@
   ::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
       ::xla::HloInstruction* instr);
 
-  ::xla::Status CreateOperands(::xla::HloInstruction* instr,
-                               SmallVectorImpl<Value>& operands,
-                               size_t& num_arguments, size_t& num_results);
+  ::xla::Status CreateOperands(
+      ::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands,
+      size_t& num_arguments, size_t& num_results,
+      absl::optional<xla::int64> num_operands = absl::nullopt);
 
   template <typename OpType>
-  ::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr) {
+  ::xla::StatusOr<OpType> CreateOpWithoutAttrs(
+      ::xla::HloInstruction* instr,
+      absl::optional<xla::int64> num_operands = absl::nullopt) {
     size_t unused;
-    return CreateOpWithoutAttrs<OpType>(instr, unused, unused);
+    return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
   }
 
   template <typename OpType>
-  ::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr,
-                                               size_t& num_arguments,
-                                               size_t& num_results);
+  ::xla::StatusOr<OpType> CreateOpWithoutAttrs(
+      ::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
+      absl::optional<xla::int64> num_operands = absl::nullopt);
 
   template <typename T>
   DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc
index 3822e10..049e541 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc
@@ -18,9 +18,9 @@
 #include <string>
 
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
index 9741774..bb63611 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
@@ -18,8 +18,8 @@
 #include <iostream>
 
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index f0941d3..1f8e708 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -477,7 +477,6 @@
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-        "no_rocm",
     ],
     deps = [
         ":xla_test",
@@ -791,6 +790,7 @@
     name = "listdiff_op_test",
     size = "small",
     srcs = ["listdiff_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_cuda_asan",  # times out
@@ -1390,6 +1390,7 @@
     name = "unary_ops_test",
     size = "medium",
     srcs = ["unary_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_cuda_asan",  # times out
@@ -1627,7 +1628,6 @@
     shard_count = 5,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-        "no_rocm",
     ],
     xla_enable_strict_auto_jit = False,
     xla_enabled = True,
@@ -1652,7 +1652,6 @@
     srcs = ["dense_layer_test.py"],
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-        "no_rocm",
     ],
     xla_enable_strict_auto_jit = False,
     xla_enabled = True,
@@ -1747,7 +1746,6 @@
     srcs = ["lstm_test.py"],
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-        "no_rocm",
     ],
     xla_enable_strict_auto_jit = False,
     xla_enabled = True,
@@ -1871,7 +1869,6 @@
     tags = [
         "no_oss",  # TODO(b/148108508): Re-enable this test in OSS.
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-        "no_rocm",
     ],
     deps = [
         ":xla_test",
@@ -1965,3 +1962,22 @@
         "//tensorflow/python/compiler/xla:compiler_py",
     ],
 )
+
+tf_xla_py_test(
+    name = "risc_ops_test",
+    size = "small",
+    srcs = ["risc_ops_test.py"],
+    enabled_backends = ["cpu"],
+    python_version = "PY3",
+    tags = [
+        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
+    ],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:is_mlir_bridge_test_true",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python/eager:function",
+        "//tensorflow/python/ops/risc:risc_ops",
+    ],
+)
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 94b34cf..957a888 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1009,7 +1009,6 @@
           np.array([], dtype=dtype).reshape((0, 3)),
           expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testMatMul(self):
     self._testMatMul(math_ops.matmul, self.float_types | {np.float64})
 
@@ -1047,7 +1046,6 @@
     self._testMatMul(SparseMatmulWrapperFT, self.float_types)
     self._testMatMul(SparseMatmulWrapperTT, self.float_types)
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testBatchMatMul(self):
     # Tests with batches of matrices.
     for dtype in self.float_types | {np.float64}:
@@ -1099,7 +1097,6 @@
             x,
             expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testExpandDims(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1367,7 +1364,6 @@
               np.reshape(np.array([16, 18, 8], dtype=dtype), (3, 1)),
               (1, 2, 3, 1)))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testReshape(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1499,7 +1495,6 @@
                [1, 2]],
               dtype=dtype))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testTranspose(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1518,7 +1513,6 @@
           np.array([1, 0], dtype=np.int32),
           expected=np.array([[1, 3], [2, 4]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testConjugateTranspose(self):
     for dtype in self.complex_types:
       self._testBinary(
@@ -1555,7 +1549,6 @@
           np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
           expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testBroadcastArgs(self):
     self._testBinary(array_ops.broadcast_dynamic_shape,
                      np.array([2, 3, 5], dtype=np.int32),
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 5d2b8a6..41107d0 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -368,7 +368,6 @@
         ans = self.evaluate(packed)
         self.assertAllEqual(ans, [2, 3, 5])
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testEmpty(self):
     with self.session():
       with self.test_scope():
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index d0f6229..8210dff 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -24,7 +24,6 @@
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import googletest
@@ -159,7 +158,6 @@
                     np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
         self.assertAllEqual(output, expected)
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testStridedSlice(self):
     self._testNAry(lambda x: array_ops.strided_slice(*x),
                    [np.array([[], [], []], dtype=np.float32),
@@ -204,7 +202,6 @@
                              dtype=np.float32)],
                    expected=np.array([[4], [5], [6]], dtype=np.float32))
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testStridedSliceGrad(self):
     # Tests cases where input shape is empty.
     self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index d6d97fd..75fe7d6 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 import itertools
+import unittest
 
 from absl.testing import parameterized
 import numpy as np
@@ -129,6 +130,11 @@
     x_np = self._random_matrix(np.float32, (2000, 2000))
     self._test(x_np, full_matrices=True)
 
+  @unittest.skip("Test times out on CI")
+  def testLarge17500x128(self):
+    x_np = self._random_matrix(np.float32, (17500, 128))
+    self._test(x_np, full_matrices=True)
+
   @parameterized.parameters((23, 25), (513, 23))
   def testZeroColumn(self, rows, cols):
     x_np = self._random_matrix(np.complex64, (rows, cols))
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index fe1d2c5..b890960 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -98,27 +98,22 @@
   ]
   ONES = [np.ones([34000, 2])]
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceSumF32(self, index_dtype):
     self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
                         index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceSumC64(self, index_dtype):
     self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
                         self.COMPLEX_DATA, index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceProdF32(self, index_dtype):
     self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
                         self.REAL_DATA, index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceProdC64(self, index_dtype):
     self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
                         self.COMPLEX_DATA, index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceMin(self, index_dtype):
 
     def reference_min(dtype, inp, axis):
@@ -136,7 +131,6 @@
                           functools.partial(reference_min, dtype), dtype,
                           self.REAL_DATA, index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceMax(self, index_dtype):
 
     def reference_max(dtype, inp, axis):
@@ -155,7 +149,6 @@
                           functools.partial(reference_max, dtype), dtype,
                           self.REAL_DATA, index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceMeanF32(self, index_dtype):
     # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
     # reducing across zero inputs.
@@ -171,12 +164,10 @@
     self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
                         self.NONEMPTY_COMPLEX_DATA, index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceAll(self, index_dtype):
     self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA,
                         index_dtype)
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceAny(self, index_dtype):
     self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
                         index_dtype)
diff --git a/tensorflow/compiler/tests/risc_ops_test.py b/tensorflow/compiler/tests/risc_ops_test.py
new file mode 100644
index 0000000..0aa0936
--- /dev/null
+++ b/tensorflow/compiler/tests/risc_ops_test.py
@@ -0,0 +1,44 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RISC Ops."""
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.risc import risc_ops
+from tensorflow.python.platform import test
+
+
+class XlaRiscOpsTest(xla_test.XLATestCase):
+
+  def testRiscAddBasic(self):
+
+    @def_function.function(jit_compile=True)
+    def f(a, b):
+      return risc_ops.risc_add(a, b)
+
+    l1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                              dtype=dtypes.float32)
+    l2 = constant_op.constant([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
+                              dtype=dtypes.float32)
+    l = f(l1, l2)
+    self.assertAllEqual(l, [[8.0, 10.0], [12.0, 14.0], [16.0, 18.0]])
+
+
+if __name__ == "__main__":
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index aa72f47..440b767 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -91,7 +91,6 @@
       for reverse in [True, False]:
         self._compare(x, axis, exclusive, reverse)
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testEmpty(self):
     for dtype in self.valid_dtypes:
       x = np.zeros([0]).astype(dtype)
@@ -171,7 +170,6 @@
       for reverse in [True, False]:
         self._compare(x, axis, exclusive, reverse)
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testEmpty(self):
     for dtype in self.valid_dtypes:
       x = np.zeros([0]).astype(dtype)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index ea4e72b..b4b3b4f 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -18,14 +18,18 @@
 from __future__ import division
 from __future__ import print_function
 
+import unittest
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.compiler.tf2xla.python import xla
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
@@ -33,6 +37,7 @@
 class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def _assertOpOutputMatchesExpected(self, op, args, expected):
+    """Tests that op(*args) == expected."""
     with self.session() as session:
       with self.test_scope():
         placeholders = [
@@ -48,37 +53,34 @@
       for result, v in zip(results, expected):
         self.assertAllClose(v, result, rtol=1e-3)
 
+  def _shuffled_arange(self, shape, dtype):
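+    """Returns np.arange(np.prod(shape)) randomly shuffled and reshaped."""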
+    x = np.arange(np.prod(shape), dtype=dtype)
+    np.random.shuffle(x)
+    return x.reshape(shape)
+
+  def _supported_key_types(self):
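+    """Returns the subset of self.numeric_types usable as sort keys."""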
+    supported_key_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32, np.int16, np.uint16, np.int8, np.uint8
+    ])
+    res = supported_key_types.intersection(self.numeric_types)
+    assert res
+    return res
+
   def testSort(self):
-    supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
-    for dtype in supported_types.intersection(self.numeric_types):
-      # TPU implementation is not supported for double precision
-      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
-        continue
-      x = np.arange(101, dtype=dtype)
-      np.random.shuffle(x)
+    for dtype in self._supported_key_types():
+      x = self._shuffled_arange((101,), dtype)
       self._assertOpOutputMatchesExpected(
           xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
 
   def testKeyValueSort(self):
-    supported_key_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
-        np.int32, np.uint32
-    ])
-    supported_value_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
-        np.int32, np.uint32, dtypes.int64.as_numpy_dtype,
-        dtypes.uint64.as_numpy_dtype
-    ])
-    for key_type in supported_key_types.intersection(self.numeric_types):
-      for value_type in supported_value_types.intersection(self.numeric_types):
-        if key_type == np.float64 or value_type == np.float64 or \
-            key_type == np.float16 or value_type == np.float16:
-          # TPU implementation is not supported for double precision
-          if self.device == "TPU":
-            continue
-        x = np.arange(101, dtype=key_type)
-        np.random.shuffle(x)
+    for key_type in self._supported_key_types():
+      for value_type in self._supported_key_types():
+        if key_type == np.uint8 or value_type == np.uint8:
+          # The test fails on uint8 for reasons that are not yet understood.
+          # We plan to deprecate xla.key_value_sort in favor of
+          # xla.variadic_sort anyway.
+          continue
+        x = self._shuffled_arange((101,), key_type)
         y = (-x).astype(value_type)
         self._assertOpOutputMatchesExpected(
             xla.key_value_sort, [x, y],
@@ -87,6 +89,156 @@
                 -np.arange(101, dtype=value_type)
             ])
 
+  @parameterized.parameters(0, 1, 2)
+  @test_util.disable_mlir_bridge("Not supported yet")
+  def testVariadicSortDimension(self, dimension):
+    shape = (2, 3, 4)
+    for key_type in self._supported_key_types():
+      x = self._shuffled_arange(shape, key_type)
+      expected = np.sort(x, axis=dimension)
+
+      @function.Defun(key_type, key_type)
+      def compare_lt(x1, x2):
+        return x1 < x2
+
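+      # xla.variadic_sort takes its comparator as a Defun with two scalar
+      # arguments per operand; a single operand yields the pair (x1, x2).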
+      def wrap_sort(x):
+        return xla.variadic_sort([x],
+                                 dimension=dimension,
+                                 is_stable=False,
+                                 comparator=compare_lt)
+
+      self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])
+
+  @test_util.disable_mlir_bridge("Not supported yet")
+  def testVariadicSortReverse(self):
+    shape = (100,)
+    for key_type in self._supported_key_types():
+      x = self._shuffled_arange(shape, key_type)
+      expected = np.sort(x, axis=0)[::-1]
+
+      @function.Defun(key_type, key_type)
+      def compare_gt(x1, x2):
+        return x1 > x2
+
+      def wrap_sort(x):
+        return xla.variadic_sort([x],
+                                 dimension=0,
+                                 is_stable=False,
+                                 comparator=compare_gt)
+
+      self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])
+
+  @parameterized.parameters(0, 1, 2)
+  @test_util.disable_mlir_bridge("Not supported yet")
+  def testVariadicSortSeveral(self, dimension):
+    if np.__version__ < "1.15":
+      raise unittest.SkipTest("np.take_along_axis was added in 1.15")
+    shape = (2, 3, 4)
+    for key_type in self._supported_key_types():
+      for value_type_1 in self._supported_key_types():
+        for value_type_2 in self._supported_key_types():
+          inputs = [
+              self._shuffled_arange(shape, key_type),
+              self._shuffled_arange(shape, value_type_1),
+              self._shuffled_arange(shape, value_type_2)
+          ]
+
+          # The first array is sorted, and the others are shuffled the same way
+          sorted_indices = np.argsort(inputs[0], axis=dimension)
+          expected = [
+              np.take_along_axis(inp, sorted_indices, axis=dimension)
+              for inp in inputs
+          ]
+          self.assertAllEqual(np.sort(inputs[0], axis=dimension), expected[0])
+
+          @function.Defun(key_type, key_type, value_type_1, value_type_1,
+                          value_type_2, value_type_2)
+          def compare_lt(x1, x2, y1, y2, z1, z2):
+            del y1, y2, z1, z2
+            return x1 < x2
+
+          def wrap_sort(*args):
+            return xla.variadic_sort(
+                args,  # Pass the arguments as a tuple
+                comparator=compare_lt,
+                dimension=dimension,
+                is_stable=False)
+
+          self._assertOpOutputMatchesExpected(
+              wrap_sort, inputs, expected=expected)
+
+  @test_util.disable_mlir_bridge("Not supported yet")
+  def testVariadicSortLexicographic(self):
+    # Three inputs: the first two are used for lexicographic sort, and the
+    # third is just swapped accordingly.
+    # The first array will contain only 0 and 1, to test lexicographic order
+    if np.__version__ < "1.15":
+      raise unittest.SkipTest("np.take_along_axis was added in 1.15")
+    shape = (20,)
+    for key_type_1 in set([np.int16, np.uint16, np.int32, np.uint32]):
+      for key_type_2 in self._supported_key_types():
+        for value_type in self._supported_key_types():
+          inputs = [
+              # Ensure that some keys in the first input are equal
+              np.random.uniform(0, 2, shape).astype(key_type_1),
+              self._shuffled_arange(shape, key_type_2),
+              self._shuffled_arange(shape, value_type)
+          ]
+          # The first two arrays are sorted lexicographically, and the third
+          # is shuffled the same way
+          sorted_indices = np.argsort(100 * inputs[0] + inputs[1])
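+          # inputs[0] holds only 0s and 1s and inputs[1] is a permutation of
+          # range(20), so every value in inputs[1] is < 100 and the combined
+          # key orders first by inputs[0], then by inputs[1].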
+          expected = [
+              np.take_along_axis(inp, sorted_indices, axis=0) for inp in inputs
+          ]
+
+          @function.Defun(key_type_1, key_type_1, key_type_2, key_type_2,
+                          value_type, value_type)
+          def compare_lexicographic(x1, x2, y1, y2, z1, z2):
+            del z1, z2
+            return math_ops.logical_or(
+                x1 < x2, math_ops.logical_and(math_ops.equal(x1, x2), y1 < y2))
+
+          def wrap_sort(*args):
+            return xla.variadic_sort(
+                args,  # Pass the arguments as a tuple
+                comparator=compare_lexicographic,
+                dimension=0,
+                is_stable=False)
+
+          self._assertOpOutputMatchesExpected(
+              wrap_sort, inputs, expected=expected)
+
+  @parameterized.parameters(0, 1, 2)
+  @test_util.disable_mlir_bridge("Not supported yet")
+  def testVariadicSortSeveralStable(self, dimension):
+    shape = (2, 3, 4)
+    for key_type in self._supported_key_types():
+      for value_type_1 in self._supported_key_types():
+        for value_type_2 in self._supported_key_types():
+          # The first input is all 0s, so a stable sort should leave the
+          # values unchanged.
+          inputs = [
+              np.zeros(shape, key_type),
+              self._shuffled_arange(shape, value_type_1),
+              self._shuffled_arange(shape, value_type_2)
+          ]
+
+          @function.Defun(key_type, key_type, value_type_1, value_type_1,
+                          value_type_2, value_type_2)
+          def compare_lt(x1, x2, y1, y2, z1, z2):
+            del y1, y2, z1, z2
+            return x1 < x2
+
+          def wrap_sort(*args):
+            return xla.variadic_sort(
+                args,  # Pass the arguments as a tuple
+                comparator=compare_lt,
+                dimension=dimension,
+                is_stable=True)
+
+          self._assertOpOutputMatchesExpected(
+              wrap_sort, inputs, expected=inputs)
+
   def testTopK(self):
     supported_types = set([
         dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index 76b7e18..74f5f7b 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -22,7 +22,6 @@
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.platform import test
@@ -248,19 +247,16 @@
         outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
                  [[4, 41], [6, 61]]])
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testDirect0(self):
     # Test with zero-size remaining dimension.
     self._testDirect(
         input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testDirect1(self):
     # Test with zero-size blocked dimension.
     self._testDirect(
         input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testDirect2(self):
     # Test with padding up from zero size.
     self._testDirect(
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 3d310dd..4109fdc 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -63,7 +63,6 @@
     self.assertEqual(result[-1], expected[-1])
     self.assertEqual(result[0], expected[0])
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testRange(self):
     self._testTernary(
         math_ops.range,
@@ -183,7 +182,6 @@
           np.array([8, 9], dtype=dtype),
           expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testSlice(self):
     for dtype in self.numeric_types:
       self._testTernary(
diff --git a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
index 84aa725..ca50916 100644
--- a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
+++ b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
@@ -193,7 +193,6 @@
   def test1x1(self):
     self._test(diags=[[0], [3], [0]], rhs=[6], expected=[2])
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def test0x0(self):
     self._test(
         diags=np.zeros(shape=(3, 0), dtype=np.float32),
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 254f9ac..271bf66 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -24,7 +24,6 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_control_flow_ops
 from tensorflow.python.platform import test
@@ -32,7 +31,6 @@
 
 class XlaDeviceTest(xla_test.XLATestCase):
 
-  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testCopies(self):
     """Tests that copies onto and off XLA devices work."""
     shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index 100d243..cb19ab9 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -417,6 +417,7 @@
         "//tensorflow/core/grappler/clusters:virtual_cluster",
         "//tensorflow/core/grappler/costs:graph_properties",
         "//tensorflow/core/grappler/optimizers:meta_optimizer",
+        "//tensorflow/core/profiler/lib:annotated_traceme",
         "//tensorflow/stream_executor/lib",
         "//tensorflow/tools/graph_transforms:transform_utils",
     ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
@@ -595,8 +596,8 @@
     srcs = ["utils/py_utils.cc"],
     hdrs = ["utils/py_utils.h"],
     copts = tf_copts(),
-    defines = select({
-        "@local_config_tensorrt//:use_static_tensorrt": ["TF_OSS_TENSORRT_STATIC=1"],
+    local_defines = select({
+        "@local_config_tensorrt//:use_static_tensorrt": ["TF_USE_TENSORRT_STATIC=1"],
         "//conditions:default": [],
     }),
     deps = if_tensorrt([
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index eed37cd..1f54564 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -56,6 +56,7 @@
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/tensor_coding.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/annotated_traceme.h"
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/strided_slice_op.h"
@@ -1409,6 +1410,13 @@
     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
     size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
     TRTInt8Calibrator* calibrator, TrtShapeOptimizationProfile* profiles) {
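+  // Scoped profiler trace event covering the full TensorRT engine build.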
+  tensorflow::profiler::AnnotatedTraceMe activity(
+      [&]() {
+        return tensorflow::profiler::TraceMeOpOverride("TRTEngineOp",
+                                                       "BuildEngine");
+      },
+      tensorflow::profiler::TraceMeLevel::kInfo);
+
   VLOG(1) << "Configuring TensorRT builder";
   trt_builder_->setMaxBatchSize(max_batch_size);
   trt_builder_->setGpuAllocator(allocator);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 1d60ebb..f09ae20 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -1893,27 +1893,28 @@
 //   how TRT handles the precision inside the TRT network, but should not matter
 //   for the TF -> TRT conversion. Therefore it should be sufficient to test
 //   for FP32.
-class OpConverterTest1 : public ParameterizedOpConverterTestBase {};
+class OpConverter_FP32_Test : public ParameterizedOpConverterTestBase {};
+// Base class for tests that need to be tested for both FP32 and FP16.
+class OpConverter_FP32_FP16_Test : public ParameterizedOpConverterTestBase {};
+// Base class for tests that need to be tested for FP32, FP16, and INT32.
+class OpConverter_FP32_FP16_INT32_Test
+    : public ParameterizedOpConverterTestBase {};
 
-// Instantiate parameter combinations to OpConverterTest1
+// Instantiate parameter combinations to OpConverter_<DT_X...>_Test
 INSTANTIATE_TEST_CASE_P(
-    OpConvTestInstantiation, OpConverterTest1,
+    OpConvTestInstantiation, OpConverter_FP32_Test,
     ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                        ::testing::Values(DT_FLOAT),
                        ::testing::Values(TrtPrecisionMode::FP32)));
 
-// Base class for tests that need to be tested for both FP32 and FP16.
-class OpConverterTest2 : public ParameterizedOpConverterTestBase {};
 INSTANTIATE_TEST_CASE_P(
-    OpConvTestInstantiation, OpConverterTest2,
+    OpConvTestInstantiation, OpConverter_FP32_FP16_Test,
     ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                        ::testing::Values(DT_FLOAT, DT_HALF),
                        ::testing::Values(TrtPrecisionMode::FP32)));
 
-// Base class for tests that need to be tested for FP32, FP16, and INT32
-class OpConverterTest3 : public ParameterizedOpConverterTestBase {};
 INSTANTIATE_TEST_CASE_P(
-    OpConvTestInstantiation3, OpConverterTest3,
+    OpConvTestInstantiation, OpConverter_FP32_FP16_INT32_Test,
     ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                        ::testing::Values(DT_FLOAT, DT_HALF, DT_INT32),
                        ::testing::Values(TrtPrecisionMode::FP32)));
@@ -2078,7 +2079,7 @@
       ->def();
 }
 
-TEST_P(OpConverterTest1, ConvertFusedBatchNorm) {
+TEST_P(OpConverter_FP32_Test, ConvertFusedBatchNorm) {
   using OpFunc = std::function<NodeDef(DataType, std::string, bool, float)>;
   std::vector<OpFunc> get_node_def_vec{
       CreateFusedBatchNormOp<ops::FusedBatchNorm>,
@@ -2191,7 +2192,7 @@
   }
 }
 
-TEST_P(OpConverterTest1, ConvertTranspose) {
+TEST_P(OpConverter_FP32_Test, ConvertTranspose) {
   // Get the NodeDef for Transpose.
   Scope s = Scope::NewRootScope();
   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
@@ -2349,7 +2350,7 @@
   }
 }
 
-TEST_P(OpConverterTest1, ConvertShape) {
+TEST_P(OpConverter_FP32_Test, ConvertShape) {
   // Get the NodeDef for Shape op.
   Scope s = Scope::NewRootScope();
   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
@@ -2637,7 +2638,7 @@
   TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
 }
 
-TEST_P(OpConverterTest2, ConvertBiasAdd) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertBiasAdd) {
   // Note that kINT32 is not supported by IScaleLayer, so we don't test
   // DT_INT32 type here. DT_FLOAT and DT_HALF are tested.
   // Get the NodeDef for BiasAdd.
@@ -2710,7 +2711,7 @@
   return op.operation.node()->def();
 }
 
-TEST_P(OpConverterTest2, ConvertBinary) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertBinary) {
   {
     AttrValue dtype;
     dtype.set_type(tf_type_);
@@ -2974,7 +2975,7 @@
   }
 }
 
-TEST_P(OpConverterTest2, ConvertSquare) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertSquare) {
   {
     // Input is weights, should fail.
     Reset();
@@ -3127,7 +3128,7 @@
       ->def();
 }
 
-TEST_P(OpConverterTest1, ConvertActivation) {
+TEST_P(OpConverter_FP32_Test, ConvertActivation) {
   {
     // Input is weights, should fail.
     Reset();
@@ -3213,7 +3214,7 @@
   }
 }
 
-TEST_P(OpConverterTest1, ConvertExpandDims) {
+TEST_P(OpConverter_FP32_Test, ConvertExpandDims) {
   // Get the NodeDef for ExpandDims.
   Scope s = Scope::NewRootScope();
   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
@@ -3290,7 +3291,7 @@
   }
 }
 
-TEST_P(OpConverterTest1, ConvertSqueeze) {
+TEST_P(OpConverter_FP32_Test, ConvertSqueeze) {
   const bool use_implicit_batch = (trt_mode_ == TrtTestMode::kImplicitBatch);
   // Get the NodeDef for Squeeze.
   auto get_squeeze_nodedef = [](std::vector<int> axes,
@@ -4141,7 +4142,7 @@
   }
 }
 
-TEST_P(OpConverterTest1, ConvertConv2D) {
+TEST_P(OpConverter_FP32_Test, ConvertConv2D) {
   // Get nodedef for Conv2D layer.
   DataType tf_type = tf_type_;
   auto get_conv2d_nodedef =
@@ -4835,7 +4836,7 @@
       .operation.node()
       ->def();
 }
-TEST_P(OpConverterTest1, ConvertPool) {
+TEST_P(OpConverter_FP32_Test, ConvertPool) {
   // Get nodedef for MaxPool and AvgPool layers (2D or 3D).
   auto get_pool_nodedef =
       [](DataType tf_type, int nDim, std::vector<int> ksize = {},
@@ -5049,7 +5050,7 @@
   }
 }
 
-TEST_P(OpConverterTest3, ConvertGather) {
+TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertGather) {
   // Get the NodeDef for GatherV2.
   Scope s = Scope::NewRootScope();
   auto params = ops::Placeholder(s.WithOpName("params"), tf_type_);
@@ -5302,7 +5303,7 @@
   }
   return output;
 }
-TEST_P(OpConverterTest1, ConvertReduce) {
+TEST_P(OpConverter_FP32_Test, ConvertReduce) {
   {
     // Input is weights, should fail.
     Reset();
@@ -5428,7 +5429,7 @@
       ->def();
 }
 
-TEST_P(OpConverterTest1, ConvertUnary) {
+TEST_P(OpConverter_FP32_Test, ConvertUnary) {
   {
     // Input is weights, should fail.
     Reset();
@@ -6041,9 +6042,9 @@
 }
 
 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
-TEST_P(OpConverterTest3, ConvertPack) {
+TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertPack) {
 #else
-TEST_P(OpConverterTest2, ConvertPack) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertPack) {
 #endif
   struct TestParams {
     std::vector<std::vector<int>> input_shapes;
@@ -6606,83 +6607,26 @@
 }
 
 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
-// Get the NodeDef for ClipByValue.
-NodeDef GetClipByValueNodeDef(DataType dtype) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertClipByValue) {
   Scope s = Scope::NewRootScope();
-  auto t = ops::Placeholder(s.WithOpName("t"), dtype);
-  auto clip_value_min = ops::Placeholder(s.WithOpName("clip_value_min"), dtype);
-  auto clip_value_max = ops::Placeholder(s.WithOpName("clip_value_max"), dtype);
+  auto t = ops::Placeholder(s.WithOpName("t"), tf_type_);
+  auto clip_value_min =
+      ops::Placeholder(s.WithOpName("clip_value_min"), tf_type_);
+  auto clip_value_max =
+      ops::Placeholder(s.WithOpName("clip_value_max"), tf_type_);
   auto clip = ops::ClipByValue(s.WithOpName("my_clip"), t, clip_value_min,
                                clip_value_max);
-  return clip.operation.node()->def();
-}
+  const NodeDef& node_def = clip.operation.node()->def();
 
-template <DataType dtype>
-void TestConvertClipByValue(OpConverterTest* test) {
-  typedef typename EnumToDataType<dtype>::Type CType;
+  nvinfer1::DataType trt_type_;
+  TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type_));
 
-  struct TestParams {
-    std::vector<int> dims;
-    std::vector<CType> input_value;
-    CType clip_value_min;
-    CType clip_value_max;
-    std::vector<CType> expected_output;
-  };
-
-  const std::vector<CType> common_input = InitTestVector<CType>(6);
-  std::vector<TestParams> params = {
-      {
-          /*dims=*/{1, 2, 3},
-          /*input_value=*/common_input,
-          /*clip_value_min=*/CType(2),
-          /*clip_value_max=*/CType(5),
-          /*expected_output=*/
-          {CType(2), CType(2), CType(2), CType(3), CType(4), CType(5)},
-      },
-      {
-          /*dims=*/{2, 1, 3},
-          /*input_value=*/common_input,
-          /*clip_value_min=*/CType(-1),
-          /*clip_value_max=*/CType(8),
-          /*expected_output=*/common_input,
-      },
-  };
-
-  for (int i = 0; i < params.size(); ++i) {
-    test->Reset();
-
-    NodeDef node_def = GetClipByValueNodeDef(dtype);
-    nvinfer1::DataType trt_type;
-    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
-    test->AddTestTensor("t", params[i].dims, 1, trt_type);
-    test->AddTestWeights<CType>("clip_value_min", {1},
-                                {params[i].clip_value_min});
-    test->AddTestWeights<CType>("clip_value_max", {1},
-                                {params[i].clip_value_max});
-    test->RunValidationAndConversion(node_def);
-
-    TRT_TensorOrWeights output;
-    TF_EXPECT_OK(test->GetTensorOrWeights("my_clip", &output));
-    EXPECT_TRUE(output.is_tensor());
-    ExpectTrtDimsEqualsArray(params[i].dims, output.tensor()->getDimensions());
-
-    DataVec input_data{{"t", test->AsTensor<CType>(params[i].input_value)}};
-    DataVec output_data{{"my_clip", test->ConstructTensor<CType>(
-                                        params[i].expected_output.size())}};
-    TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
-    EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-                ElementsAreArray(params[i].expected_output));
-  }
-}
-
-TEST_F(OpConverterTest, ConvertClipByValue) {
   {
     // Input is a weight, should fail.
     Reset();
-    NodeDef node_def = GetClipByValueNodeDef(DT_FLOAT);
-    AddTestWeights<float>("t", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
-    AddTestWeights<float>("clip_value_min", {1}, {1});
-    AddTestWeights<float>("clip_value_max", {1}, {5});
+    AddTestWeights("t", {1, 2, 3}, {1, 2, 3, 4, 5, 6}, tf_type_);
+    AddTestWeights("clip_value_min", {1}, {1}, tf_type_);
+    AddTestWeights("clip_value_max", {1}, {5}, tf_type_);
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                "The input \"t\" for ClipByValue must be a "
                                "tensor, at my_clip");
@@ -6690,10 +6634,9 @@
   {
     // Clip min is a tensor, should fail.
     Reset();
-    NodeDef node_def = GetClipByValueNodeDef(DT_FLOAT);
     AddTestTensor("t", {1, 2, 3});
     AddTestTensor("clip_value_min", {1});
-    AddTestWeights<float>("clip_value_max", {1}, {1});
+    AddTestWeights("clip_value_max", {1}, {1}, tf_type_);
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                "The input \"clip_value_min\" for ClipByValue "
                                "must be a constant, at my_clip");
@@ -6701,17 +6644,78 @@
   {
     // Clip max is a tensor, should fail.
     Reset();
-    NodeDef node_def = GetClipByValueNodeDef(DT_FLOAT);
     AddTestTensor("t", {1, 2, 3});
-    AddTestWeights<float>("clip_value_min", {1}, {1});
+    AddTestWeights("clip_value_min", {1}, {1}, tf_type_);
     AddTestTensor("clip_value_max", {1});
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                "The input \"clip_value_max\" for ClipByValue "
                                "must be a constant, at my_clip");
   }
 
-  TestConvertClipByValue<DT_FLOAT>(this);
-  TestConvertClipByValue<DT_HALF>(this);
+  struct TestParams {
+    std::vector<int> dims;
+    int clip_value_min;
+    int clip_value_max;
+    std::vector<float> expected_output;
+  };
+
+  const std::vector<float> common_input = InitTestVector<float>(6);
+
+  std::vector<TestParams> params = {{
+                                        /*dims=*/{6},
+                                        /*clip_value_min=*/2,
+                                        /*clip_value_max=*/4,
+                                        /*expected_output=*/{2, 2, 2, 3, 4, 4},
+                                    },
+                                    {
+                                        /*dims=*/{1, 6},
+                                        /*clip_value_min=*/2,
+                                        /*clip_value_max=*/4,
+                                        /*expected_output=*/{2, 2, 2, 3, 4, 4},
+                                    },
+                                    {
+                                        /*dims=*/{1, 2, 3},
+                                        /*clip_value_min=*/2,
+                                        /*clip_value_max=*/4,
+                                        /*expected_output=*/{2, 2, 2, 3, 4, 4},
+                                    },
+                                    {
+                                        /*dims=*/{1, 2, 3, 1},
+                                        /*clip_value_min=*/2,
+                                        /*clip_value_max=*/4,
+                                        /*expected_output=*/{2, 2, 2, 3, 4, 4},
+                                    },
+                                    {
+                                        /*dims=*/{1, 1, 3, 1, 2},
+                                        /*clip_value_min=*/2,
+                                        /*clip_value_max=*/4,
+                                        /*expected_output=*/{2, 2, 2, 3, 4, 4},
+                                    },
+                                    {
+                                        /*dims=*/{1, 1, 3, 1, 2, 1},
+                                        /*clip_value_min=*/2,
+                                        /*clip_value_max=*/4,
+                                        /*expected_output=*/{2, 2, 2, 3, 4, 4},
+                                    },
+                                    {
+                                        /*dims=*/{2, 1, 3},
+                                        /*clip_value_min=*/-1,
+                                        /*clip_value_max=*/8,
+                                        /*expected_output=*/common_input,
+                                    }};
+
+  for (auto p : params) {
+    Reset();
+
+    AddTestTensor("t", p.dims, tf_type_, common_input);
+    AddTestWeights("clip_value_min", {1}, {p.clip_value_min}, tf_type_);
+    AddTestWeights("clip_value_max", {1}, {p.clip_value_max}, tf_type_);
+
+    TestOpConverter("my_clip", node_def, p.dims,
+                    /*expected_conversion_status=*/Status::OK(),
+                    /*expected_runtime_status=*/Status::OK(),
+                    /*matcher=*/ElementsAreArray(p.expected_output));
+  }
 }
 #endif  // IS_TRT_VERSION_GE(5, 1, 2, 0)
 
@@ -6725,7 +6729,7 @@
   return squared_diff.operation.node()->def();
 }
 
-TEST_P(OpConverterTest2, ConvertSquaredDifference) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertSquaredDifference) {
   {
     // Input is a weight, should fail.
     Reset();
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 2d56209..c00b4d7 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -45,6 +45,7 @@
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
@@ -433,6 +434,9 @@
 
 void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                        AsyncHelper* helper) {
+  tensorflow::profiler::TraceMe activity(
+      "TRTEngineOp::ExecuteNativeSegment",
+      tensorflow::profiler::TraceMeLevel::kInfo);
   std::vector<Tensor> inputs;
   std::vector<Tensor>* outputs = new std::vector<Tensor>();
   if (native_execution_func_handle_ == kInvalidHandle) {
@@ -457,18 +461,21 @@
   lib->Run(opts, native_execution_func_handle_, inputs, outputs,
            [this, ctx, outputs, helper](const Status& s) {
              core::ScopedUnref sc(helper);
+             std::unique_ptr<std::vector<Tensor>> outputs_wrapper(outputs);
              OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
              VLOG(1) << "Native Segment completed";
              for (size_t t = 0; t < outputs->size(); ++t) {
                ctx->set_output(t, outputs->at(t));
              }
-             delete outputs;
            });
 }
 
 void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
                                      TRTEngineCacheResource* cache_res,
                                      AsyncHelper* helper) {
+  tensorflow::profiler::TraceMe activity(
+      "TRTEngineOp::ExecuteCalibration",
+      tensorflow::profiler::TraceMeLevel::kInfo);
   VLOG(1) << "Executing TRT calibration: " << name();
   helper->Ref();
   core::ScopedUnref sc(helper);
@@ -594,6 +601,8 @@
 
 void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                                AsyncOpKernel::DoneCallback done) {
+  tensorflow::profiler::TraceMe activity(
+      "TRTEngineOp::ComputeAsync", tensorflow::profiler::TraceMeLevel::kInfo);
   auto helper = new AsyncHelper(done);
   core::ScopedUnref sc(helper);
 
@@ -718,6 +727,9 @@
 Status TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
                                      EngineContext* engine_context,
                                      int trt_context_idx) {
+  tensorflow::profiler::TraceMe activity(
+      "TRTEngineOp::ExecuteTrtEngine",
+      tensorflow::profiler::TraceMeLevel::kInfo);
   VLOG(1) << "Executing TRT engine: " << name();
   auto& cuda_engine = engine_context->cuda_engine;
 
diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
index f470d96..fd75c48 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
@@ -26,24 +26,22 @@
 
 bool IsGoogleTensorRTEnabled() {
 #if GOOGLE_CUDA && GOOGLE_TENSORRT
-#if TF_OSS_TENSORRT_STATIC
+#if TF_USE_TENSORRT_STATIC
   LOG(INFO) << "TensorRT libraries are statically linked, skip dlopen check";
   return true;
-#else
+#else   // TF_USE_TENSORRT_STATIC
   auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
   if (!handle_or.ok()) {
     LOG_WARNING_WITH_PREFIX
         << "Cannot dlopen some TensorRT libraries. If you would like "
            "to use Nvidia GPU with TensorRT, please make sure the "
            "missing libraries mentioned above are installed properly.";
-    return false;
-  } else {
-    return true;
   }
-#endif
-#else
+  return handle_or.ok();
+#endif  // TF_USE_TENSORRT_STATIC
+#else   // GOOGLE_CUDA && GOOGLE_TENSORRT
   return false;
-#endif
+#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
 }
 
 }  // namespace tensorrt
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index eea7262..d1b83cf 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -350,15 +350,12 @@
         ":xla_helpers",
         ":xla_op_registry",
         ":xla_resource",
-        "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/types:span",
-        "@com_google_absl//absl/types:variant",
         "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
+        "//tensorflow/compiler/mlir:array_container_utils",
+        "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
+        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -376,13 +373,12 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
-    ] + if_libtpu(
-        if_false = [
-            "//tensorflow/compiler/mlir:array_container_utils",
-            "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
-        ],
-        if_true = [],
-    ),
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/types:span",
+        "@com_google_absl//absl/types:variant",
+    ],
     alwayslink = 1,
 )
 
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index b4f8970..cec7b9a 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -121,6 +121,7 @@
         "tridiagonal_ops.cc",
         "unary_ops.cc",
         "unary_ops_composition.cc",
+        "unique_op.cc",
         "unpack_op.cc",
         "variable_ops.cc",
         "where_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index b461aa4..dc2a270 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -91,6 +91,13 @@
     xla::PrimitiveType type;
     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type));
     xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx);
+    bool num_samples_is_dynamic = false;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
+    if (num_samples_is_dynamic && num_samples != 1) {
+      // The number of samples is dimension 1 of uniform_shape_array.
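+      // SetDimensionSize marks that dimension as dynamic, sized at runtime
+      // by the num_samples input.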
+      log_uniforms = xla::SetDimensionSize(log_uniforms, ctx->Input(1), 1);
+    }
 
     // Use Gumbel softmax trick to generate categorical samples.
     // See:
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index e9f32f2..9d64aaf 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -179,7 +179,7 @@
 
 }  // anonymous namespace
 
-absl::Span<const DataType> GetXlaConvTypes() {
+std::vector<DataType> GetXlaConvTypes() {
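+  // Returned by value: an absl::Span built from this braced initializer
+  // list would reference a temporary array and dangle once the call returns.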
   return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE};
 }
 
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
index 94451ae..179f5fc 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -37,7 +37,7 @@
 
 // We don't support integers for convolutions, so we list the supported types
 // here.
-absl::Span<const DataType> GetXlaConvTypes();
+std::vector<DataType> GetXlaConvTypes();
 
 // ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
 // convolution.
diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc
index 028f5fa..7b8921e 100644
--- a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc
@@ -40,7 +40,7 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     xla::XlaOp lhs = ctx->Input(0);
-    if (equation_.find(",") == equation_.npos) {
+    if (equation_.find(',') == equation_.npos) {
       ctx->SetOutput(0, xla::Einsum(lhs, equation_));
     } else {
       xla::XlaOp rhs = ctx->Input(1);
@@ -68,7 +68,7 @@
     OP_REQUIRES_OK(ctx,
                    ctx->InputList("inputs", &input_handles, &input_shapes));
 
-    if (equation_.find(",") == equation_.npos) {
+    if (equation_.find(',') == equation_.npos) {
       OP_REQUIRES(
           ctx, input_handles.size() == 1,
           errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index e8149d3..dfbad70 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -110,16 +110,15 @@
     OP_REQUIRES_OK(ctx, output.status());
 
     if (type == DT_INT32 || type == DT_INT64) {
-      // If input has dynamic dimension (value is -1), propagate the dynamic
-      // dimension to output using set-dimension-size.
-      ctx->set_dynamic_dimension_is_minus_one(true);
-      OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
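+      // When the limit input is not a compile-time constant, mark dimension
+      // 0 of the output as dynamic, sized at runtime by the limit input.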
+      bool limit_is_dynamic = false;
+      OP_REQUIRES_OK(ctx,
+                     ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
       if (type == DT_INT32) {
-        if (limit.Get<int32>({}) == -1) {
+        if (limit_is_dynamic) {
           output = xla::SetDimensionSize(output.ValueOrDie(), ctx->Input(1), 0);
         }
       } else {
-        if (limit.Get<int64>({}) == -1) {
+        if (limit_is_dynamic) {
           output = xla::SetDimensionSize(
               output.ValueOrDie(),
               xla::ConvertElementType(ctx->Input(1), xla::S32), 0);
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index 8cfd985..7b2acae 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -13,6 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/comparators.h"
@@ -53,5 +55,76 @@
 
 REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
 
+class XlaVariadicSortOp : public XlaOpKernel {
+ public:
+  explicit XlaVariadicSortOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("T", &input_types_));
+    OP_REQUIRES_OK(context, context->GetAttr("comparator", &comparator_));
+    OP_REQUIRES_OK(context, context->GetAttr("dimension", &dimension_));
+    OP_REQUIRES_OK(context, context->GetAttr("is_stable", &is_stable_));
+  }
+
+  void Compile(XlaOpKernelContext* context) override {
+    std::vector<xla::XlaOp> inputs(input_types_.size());
+    std::vector<xla::PrimitiveType> input_xla_types(input_types_.size());
+    std::vector<XlaCompiler::Argument> comparator_args(2 * input_types_.size());
+
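+    // Each input contributes two scalar parameters (the lhs and rhs
+    // elements being compared), matching the calling convention of
+    // xla::Sort comparators.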
+    for (int i = 0; i < input_types_.size(); ++i) {
+      inputs[i] = context->Input(i);
+      OP_REQUIRES_OK(context, DataTypeToPrimitiveType(input_types_[i],
+                                                      &input_xla_types[i]));
+      XlaCompiler::Argument comparator_arg;
+      comparator_arg.kind = XlaCompiler::Argument::kParameter;
+      comparator_arg.type = input_types_[i];
+      comparator_arg.shape = TensorShape();
+      comparator_args[2 * i] = comparator_arg;
+      comparator_args[2 * i + 1] = comparator_arg;
+    }
+
+    // Build the comparator function.
+    XlaCompiler::CompilationResult comparator;
+    XlaCompiler::CompileOptions compile_options;
+    compile_options.use_tuple_arg = false;
+    compile_options.always_return_tuple = false;
+    compile_options.is_entry_computation = false;
+    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+                                compile_options, *comparator_, comparator_args,
+                                &comparator));
+
+    xla::Shape expected_comparator_output_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeToXLAShape(DT_BOOL, TensorShape(),
+                                         &expected_comparator_output_shape));
+    OP_REQUIRES(
+        context,
+        xla::ShapeUtil::Compatible(comparator.xla_output_shape,
+                                   expected_comparator_output_shape),
+        errors::InvalidArgument(
+            "Invalid output shape of XlaReduce reducer. Expected ",
+            xla::ShapeUtil::HumanString(expected_comparator_output_shape),
+            " got ", xla::ShapeUtil::HumanString(comparator.xla_output_shape)));
+
+    xla::XlaOp outputs =
+        xla::Sort(inputs, *comparator.computation, dimension_, is_stable_);
+
+    for (int i = 0; i < input_types_.size(); ++i) {
+      xla::XlaOp output_handle =
+          (input_types_.size() > 1 ? xla::GetTupleElement(outputs, i)
+                                   : outputs);
+      context->SetOutput(i, output_handle);
+    }
+  }
+
+ private:
+  DataTypeVector input_types_;
+  const NameAttrList* comparator_;
+  int64 dimension_;
+  bool is_stable_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaVariadicSortOp);
+};
+
+REGISTER_XLA_OP(Name("XlaVariadicSort"), XlaVariadicSortOp);
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc
new file mode 100644
index 0000000..2704554
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc
@@ -0,0 +1,189 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/comparators.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/ops_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+
+class UniqueOp : public XlaOpKernel {
+ public:
+  explicit UniqueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  // We use a two-level loop algorithm to compute Unique.
+  //
+  // i = 0
+  // output_size = 0
+  // output_indices = broadcast(0, {input_size})
+  // while (i < input_size) {
+  //   search_result_index = output_size
+  //   j = 0
+  //   while (j < output_size) {
+  //     if(input[j]==input[i]) {
+  //       search_result_index = j
+  //     }
+  //     ++j
+  //   }
+  //   input[search_result_index] = input[i]
+  //   output_indices[i] = search_result_index
+  //   if (search_result_index == output_size) {
+  //     // Not found
+  //     output_size ++;
+  //   }
+  //   i ++;
+  // }
+  //
+  // The algorithm is then functionalized into XLA while loops. Outer-scoped
+  // variables are captured as inputs and outputs to the while loop.
+  // Conditionals are rewritten into XLA selects for simplicity.
+  xla::XlaComputation BuildInnerLoopCond(XlaOpKernelContext* ctx,
+                                         xla::Shape inner_loop_shape) {
+    std::unique_ptr<xla::XlaBuilder> builder =
+        ctx->builder()->CreateSubBuilder("inner_loop_cond");
+    auto param = xla::Parameter(builder.get(), 0, inner_loop_shape, "param");
+    auto j = xla::GetTupleElement(param, 2);
+    auto output_element_size = xla::GetTupleElement(param, 3);
+    xla::Lt(j, output_element_size);
+    return builder->Build().ConsumeValueOrDie();
+  }
+
+  xla::XlaComputation BuildInnerLoopBody(XlaOpKernelContext* ctx,
+                                         xla::Shape inner_loop_shape,
+                                         xla::Shape single_element_shape) {
+    std::unique_ptr<xla::XlaBuilder> builder =
+        ctx->builder()->CreateSubBuilder("inner_loop_body");
+    auto param = xla::Parameter(builder.get(), 0, inner_loop_shape, "param");
+    auto input = xla::GetTupleElement(param, 0);
+    auto target = xla::GetTupleElement(param, 1);
+    auto j = xla::GetTupleElement(param, 2);
+    auto output_element_size = xla::GetTupleElement(param, 3);
+    auto output_index = xla::GetTupleElement(param, 4);
+    auto input_elem = xla::DynamicSlice(input, {j}, {1});
+    auto input_elem_scalar = xla::Reshape(single_element_shape, input_elem);
+    auto eq = xla::Eq(input_elem_scalar, target);
+    auto select = xla::Select(eq, j, output_index);
+    auto next_j = xla::Add(j, xla::One(builder.get(), xla::S32));
+    xla::Tuple(builder.get(),
+               {input, target, next_j, output_element_size, select});
+    return builder->Build().ConsumeValueOrDie();
+  }
+
+  xla::XlaComputation BuildOuterLoopCond(XlaOpKernelContext* ctx,
+                                         xla::Shape outer_loop_shape,
+                                         int64 list_size) {
+    std::unique_ptr<xla::XlaBuilder> builder =
+        ctx->builder()->CreateSubBuilder("outer_loop_body");
+    auto param =
+        xla::Parameter(builder.get(), 0, outer_loop_shape, "outer_loop_param");
+    auto i = xla::GetTupleElement(param, 2);
+    auto bound = xla::ConstantR0<int32>(builder.get(), list_size);
+    xla::Lt(i, bound);
+    return builder->Build().ConsumeValueOrDie();
+  }
+
+  xla::XlaComputation BuildOuterLoopBody(
+      XlaOpKernelContext* ctx, xla::Shape outer_loop_shape,
+      xla::Shape single_element_shape, const xla::XlaComputation& inner_cond,
+      const xla::XlaComputation& inner_body) {
+    std::unique_ptr<xla::XlaBuilder> builder =
+        ctx->builder()->CreateSubBuilder("outer_loop_body");
+    auto param = xla::Parameter(builder.get(), 0, outer_loop_shape, "param");
+    auto input = xla::GetTupleElement(param, 0);
+    auto indices = xla::GetTupleElement(param, 1);
+    auto i = xla::GetTupleElement(param, 2);
+    auto output_element_size = xla::GetTupleElement(param, 3);
+    auto zero = xla::Zero(builder.get(), xla::S32);
+    auto target = xla::DynamicSlice(input, {i}, {1});
+    auto target_scalar = xla::Reshape(single_element_shape, target);
+    auto inner_loop_param = xla::Tuple(
+        builder.get(),
+        {input, target_scalar, zero, output_element_size, output_element_size});
+    auto inner_loop = xla::While(inner_cond, inner_body, inner_loop_param);
+    auto output_index = xla::GetTupleElement(inner_loop, 4);
+    auto one = xla::One(builder.get(), xla::S32);
+    auto update_output_element_size =
+        xla::Select(xla::Eq(output_index, output_element_size),
+                    xla::Add(output_element_size, one), output_element_size);
+    auto update_input = xla::DynamicUpdateSlice(input, target, {output_index});
+    auto update_indices =
+        xla::DynamicUpdateSlice(indices, xla::Reshape(output_index, {1}), {i});
+    xla::Tuple(builder.get(), {update_input, update_indices, xla::Add(i, one),
+                               update_output_element_size});
+    return builder->Build().ConsumeValueOrDie();
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaOp input = ctx->Input(0);
+    xla::StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
+    OP_REQUIRES_OK(ctx, input_shape_or.status());
+    auto input_shape = input_shape_or.ValueOrDie();
+    xla::Shape single_index_shape = xla::ShapeUtil::MakeScalarShape(xla::S32);
+    xla::Shape single_element_shape =
+        xla::ShapeUtil::MakeScalarShape(input_shape.element_type());
+    OP_REQUIRES(ctx, input_shape.rank() == 1,
+                xla::InvalidArgument("Input to UniqueOp must be rank-1: %s",
+                                     input_shape.ToString()));
+    int64 list_size = input_shape.dimensions()[0];
+    auto indices_shape =
+        xla::ShapeUtil::ChangeElementType(input_shape, xla::S32);
+    auto outer_loop_shape = xla::ShapeUtil::MakeTupleShape(
+        {input_shape, indices_shape, single_index_shape, single_index_shape});
+    auto inner_loop_shape = xla::ShapeUtil::MakeTupleShape(
+        {input_shape, single_element_shape, single_index_shape,
+         single_index_shape, single_index_shape});
+    xla::XlaComputation inner_loop_cond =
+        BuildInnerLoopCond(ctx, inner_loop_shape);
+    xla::XlaComputation inner_loop_body =
+        BuildInnerLoopBody(ctx, inner_loop_shape, single_element_shape);
+    xla::XlaComputation outer_loop_cond =
+        BuildOuterLoopCond(ctx, outer_loop_shape, list_size);
+    xla::XlaComputation outer_loop_body =
+        BuildOuterLoopBody(ctx, outer_loop_shape, single_element_shape,
+                           inner_loop_cond, inner_loop_body);
+    auto zero = xla::Zero(ctx->builder(), xla::S32);
+    auto init_indices = xla::Broadcast(zero, {list_size});
+    auto init = xla::Tuple(ctx->builder(), {input, init_indices, zero, zero});
+    auto outer_while = xla::While(outer_loop_cond, outer_loop_body, init);
+    auto output = xla::GetTupleElement(outer_while, 0);
+    auto output_indices = xla::GetTupleElement(outer_while, 1);
+    auto output_size = xla::GetTupleElement(outer_while, 3);
+    auto output_dynamic = xla::SetDimensionSize(output, output_size, 0);
+    ctx->SetOutput(0, output_dynamic);
+    ctx->SetOutput(1, output_indices);
+  }
+};
+
+REGISTER_XLA_OP(Name("Unique").Device(DEVICE_TPU_XLA_JIT), UniqueOp);
+
+}  // namespace
+}  // namespace tensorflow
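For reference, a minimal pure-Python sketch of the two-level loop described in the UniqueOp comment above. It mirrors only the pseudocode; the kernel itself lowers both loops into XLA `While` computations and rewrites the conditionals as selects:

```python
# Pure-Python reference of the two-level unique algorithm (a sketch of the
# pseudocode only, not of the XLA lowering).
def unique_reference(values):
  values = list(values)
  n = len(values)
  output_indices = [0] * n
  output_size = 0
  for i in range(n):                         # outer loop over the input
    search_result_index = output_size
    for j in range(output_size):             # inner loop over uniques found so far
      if values[j] == values[i]:
        search_result_index = j
    values[search_result_index] = values[i]
    output_indices[i] = search_result_index
    if search_result_index == output_size:   # not found: grow the output
      output_size += 1
  return values[:output_size], output_indices

print(unique_reference([1, 2, 1, 3, 2]))  # ([1, 2, 3], [0, 1, 0, 2, 1])
```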
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index c557813..0e780a7 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -651,6 +651,36 @@
 sorted_values: A `Tensor` of type V.
 )doc");
 
+REGISTER_OP("XlaVariadicSort")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: list(type) >= 1")
+    .Attr("comparator: func")
+    .Attr("dimension: int")
+    .Attr("is_stable: bool")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      for (int i = 0; i < c->num_inputs(); ++i) {
+        c->set_output(i, c->input(i));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Wraps the XLA Sort operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#sort
+.
+
+Sorts one or more tensors, with support for custom comparator, dimension, and
+is_stable attributes.
+
+input: A list of `Tensor` of identical shape but possibly different types.
+comparator: A comparator function to apply to 2*N scalars, returning a
+  boolean. N is the number of sort inputs. To sort in ascending order, the
+  comparator should perform a less-than comparison.
+output: A list of `Tensor` of type T.
+dimension: The dimension along which to sort.
+is_stable: Whether to use stable sort.
+)doc");
+
 // TODO(b/37549631) setting the While Op to always be stateful is too
 // conservative.
 REGISTER_OP("XlaWhile")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index f9d7181..df2b1c3 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -342,6 +342,8 @@
 reduce = gen_xla_ops.xla_reduce
 variadic_reduce = gen_xla_ops.xla_variadic_reduce
 
+ops.no_gradient("XlaVariadicReduce")
+
 
 def reduce_window(operand,
                   init,
@@ -471,6 +473,7 @@
 
 sort = gen_xla_ops.xla_sort
 key_value_sort = gen_xla_ops.xla_key_value_sort
+variadic_sort = gen_xla_ops.xla_variadic_sort
 while_loop = gen_xla_ops.xla_while
 dequantize = gen_xla_ops.xla_dequantize
 
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index 0642301..0de0058 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -83,14 +83,31 @@
   return allocator_.get();
 }
 
+// Attaches location from the node stack trace to metadata. As a heuristic,
+// picks the last frame which does not contain the "tensorflow/python" substring
+// (making an exception for frames containing "test" to allow for testing the
+// feature).
+static void AttachLocationToMetadata(xla::OpMetadata& metadata,
+                                     OpKernel* op_kernel, XlaContext& context) {
+  if (const AbstractStackTrace* stack_trace =
+          context.StackTraceForNodeName(op_kernel->def().name())) {
+    if (absl::optional<StackFrame> frame = stack_trace->LastUserFrame()) {
+      metadata.set_source_file(frame->file_name);
+      metadata.set_source_line(frame->line_number);
+    }
+  }
+}
+
 void XlaCompilationDevice::Compute(OpKernel* op_kernel,
                                    OpKernelContext* context) {
   VLOG(4) << "XlaCompilationDevice::Compute "
           << FormatNodeDefForError(op_kernel->def());
-  auto* b = XlaContext::Get(context).builder();
+  XlaContext& xla_context = XlaContext::Get(context);
+  auto* b = xla_context.builder();
   xla::OpMetadata metadata;
   metadata.set_op_type(op_kernel->type_string());
   metadata.set_op_name(op_kernel->name());
+  AttachLocationToMetadata(metadata, op_kernel, xla_context);
   b->SetOpMetadata(metadata);
 
   auto sharding_parse_result = ParseShardingFromDevice(
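A minimal Python sketch of the frame-selection heuristic described in `AttachLocationToMetadata` above; the frame tuples are hypothetical, and the real code goes through `AbstractStackTrace::LastUserFrame`:

```python
# Sketch of the heuristic: take the last frame outside "tensorflow/python",
# except that frames containing "test" still count as user frames.
def last_user_frame(frames):
  """frames: list of (file_name, line_number) tuples, innermost last."""
  for file_name, line_number in reversed(frames):
    if "tensorflow/python" not in file_name or "test" in file_name:
      return file_name, line_number
  return None

frames = [("my_model.py", 42),
          ("tensorflow/python/framework/ops.py", 1200)]
print(last_user_frame(frames))  # ('my_model.py', 42)
```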
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 716146e..8dc8da7 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -25,6 +25,8 @@
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -58,11 +60,6 @@
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/util/dump_graph.h"
 
-#ifndef LIBTPU_ON_GCE
-#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
-#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
-#endif
-
 namespace tensorflow {
 namespace {
 
@@ -805,13 +802,6 @@
   VLOG(1) << "====================================================";
   MlirBridgeRolloutPolicy policy =
       GetMlirBridgeRolloutPolicy(*graph, config_proto);
-#ifdef LIBTPU_ON_GCE
-  if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
-    VLOG(1) << "MLIR is not supported in this environment.";
-  }
-  TF_RETURN_IF_ERROR(
-      CompileGraph(options, function_id, std::move(graph), args, result));
-#else
   if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
     VLOG(1) << "Using MLIR bridge";
     GraphDebugInfo debug_info;
@@ -828,7 +818,6 @@
     TF_RETURN_IF_ERROR(
         CompileGraph(options, function_id, std::move(graph), args, result));
   }
-#endif
   VLOG(1) << "====================================================";
 
   cache_[{function_id, arg_vector}] = *result;
@@ -1309,7 +1298,7 @@
                                    options_.device_type, name));
 
   xla::XlaBuilder builder(name);
-  XlaContext* context = new XlaContext(this, &builder);
+  XlaContext* context = new XlaContext(this, &builder, graph.get());
   core::ScopedUnref context_unref(context);
 
   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index cb5bf34..7e81644 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -57,8 +57,15 @@
   args_ = std::move(args);
 }
 
-XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
-    : compiler_(compiler), builder_(builder) {}
+XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
+                       const Graph* graph)
+    : compiler_(compiler), builder_(builder) {
+  if (graph) {
+    for (const Node* node : graph->nodes()) {
+      stack_traces_[node->name()] = node->GetStackTrace();
+    }
+  }
+}
 
 string XlaContext::DebugString() const { return "XLA JIT context"; }
 
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index e44ac05..8376471 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -27,6 +27,7 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/macros.h"
 
 namespace tensorflow {
@@ -44,13 +45,22 @@
 
   // Creates a new XlaContext. See the documentation on the class data fields
   // for descriptions of the arguments.
-  XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder);
+  XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
+             const Graph* graph);
 
   // Virtual method defined by ResourceBase.
   string DebugString() const override;
 
   XlaCompiler* compiler() const { return compiler_; }
 
+  const AbstractStackTrace* StackTraceForNodeName(const std::string& name) {
+    const auto& it = stack_traces_.find(name);
+    if (it != stack_traces_.end()) {
+      return it->second.get();
+    }
+    return nullptr;
+  }
+
   // Returns the XlaBuilder that Ops use for compiling new expressions.
   xla::XlaBuilder* builder() { return builder_; }
 
@@ -100,6 +110,9 @@
   // The XlaBuilder used to construct the subgraph's compiled representation.
   xla::XlaBuilder* builder_;
 
+  // Stack traces for the graph used for compilation.
+  StackTracesMap stack_traces_;
+
   // Arguments to the Tensorflow graph, indexed by _Arg index.
   // Includes both compile-time constant arguments and runtime parameters.
   std::vector<XlaExpression> args_;
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 6de4ae5..6e99618 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -248,7 +248,6 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 171afa4..6baeca8 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -95,6 +95,7 @@
     hdrs = ["executable_build_options.h"],
     deps = [
         "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index f39a3e7..6472323 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -17,6 +17,7 @@
 
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 
 namespace xla {
@@ -99,4 +100,34 @@
       device_ordinal_, result_layout, num_replicas_);
 }
 
+ExecutionOptions CreateExecutionOptions(
+    const ExecutableBuildOptions& build_options,
+    const ProgramShape* program_shape) {
+  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+  if (build_options.has_debug_options()) {
+    *execution_options.mutable_debug_options() = build_options.debug_options();
+  }
+  if (build_options.result_layout() != nullptr) {
+    *execution_options.mutable_shape_with_output_layout() =
+        build_options.result_layout()->ToProto();
+  } else {
+    Shape result_shape(program_shape->result());
+    LayoutUtil::SetToDefaultLayout(&result_shape);
+    *execution_options.mutable_shape_with_output_layout() =
+        result_shape.ToProto();
+  }
+  execution_options.set_num_replicas(build_options.num_replicas());
+  execution_options.set_num_partitions(build_options.num_partitions());
+  execution_options.set_use_spmd_partitioning(
+      build_options.use_spmd_partitioning());
+  execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
+  if (build_options.has_device_assignment()) {
+    TF_CHECK_OK(build_options.device_assignment().Serialize(
+        execution_options.mutable_device_assignment()));
+  }
+  execution_options.set_alias_passthrough_params(
+      build_options.alias_passthrough_params());
+  return execution_options;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index d3f5dd3..000d2ad 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -115,6 +115,16 @@
     return *this;
   }
 
+  // Thread pool for parallel compilation.
+  tensorflow::thread::ThreadPool* compile_thread_pool() const {
+    return compile_thread_pool_;
+  }
+  ExecutableBuildOptions& set_compile_thread_pool(
+      tensorflow::thread::ThreadPool* compile_thread_pool) {
+    compile_thread_pool_ = compile_thread_pool;
+    return *this;
+  }
+
  private:
   int device_ordinal_ = -1;
   Shape result_layout_;
@@ -128,8 +138,15 @@
   absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
   bool run_backend_only_ = false;
+  tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
 };
 
+// Creates an ExecutionOptions based on a given ExecutableBuildOptions and
+// ProgramShape.
+ExecutionOptions CreateExecutionOptions(
+    const ExecutableBuildOptions& build_options,
+    const ProgramShape* program_shape);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc
index 6008677..264835b 100644
--- a/tensorflow/compiler/xla/client/lib/prng.cc
+++ b/tensorflow/compiler/xla/client/lib/prng.cc
@@ -158,58 +158,96 @@
   return std::make_pair(Uint64ToUint32s(input_u64), new_state);
 }
 
+// Result for SplitShapeIntoHalves().
+struct SplitShapePair {
+  Shape half_shape;
+  Shape concat_shape;
+  int64 split_dim;
+  int64 new_concat_dim;
+};
+
+// Splits the shape into two halves along one dimension, preferring an
+// even-sized dimension.
+SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
+  SplitShapePair pair;
+  if (shape.rank() == 0) {
+    pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
+    pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
+    pair.split_dim = 0;
+    pair.new_concat_dim = 0;
+    return pair;
+  }
+  pair.split_dim = -1;
+  for (int64 i = 0; i < shape.rank(); ++i) {
+    if (shape.dimensions(i) % 2 == 0) {
+      pair.split_dim = i;
+      break;
+    }
+  }
+  if (pair.split_dim == -1) {
+    // No even dims. Find a dimension with maximum size.
+    for (int64 i = 0; i < shape.rank(); ++i) {
+      if (pair.split_dim == -1 ||
+          shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
+        pair.split_dim = i;
+      }
+    }
+  }
+  CHECK_GE(pair.split_dim, 0);
+  std::vector<int64> half_shape_dims;
+  std::vector<int64> concat_shape_dims;
+  for (int64 i = 0; i < shape.rank(); ++i) {
+    if (i == pair.split_dim) {
+      // Create a new trivial dim for the later concat, which is more friendly
+      // to sharding propagation.
+      half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
+      half_shape_dims.push_back(1);
+      concat_shape_dims.push_back(half_shape_dims[i]);
+      concat_shape_dims.push_back(2);
+    } else {
+      half_shape_dims.push_back(shape.dimensions(i));
+      concat_shape_dims.push_back(shape.dimensions(i));
+    }
+  }
+  pair.new_concat_dim = pair.split_dim + 1;
+  pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
+  pair.concat_shape =
+      ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
+  return pair;
+}
+
+// Combines a pair of split shapes. It works with scalar and non-scalar shapes.
+XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
+                       const SplitShapePair& shape_pair,
+                       const Shape& original_shape) {
+  if (original_shape.rank() == 0) {
+    return Reshape(pair[0], {});
+  }
+  XlaBuilder* builder = pair[0].builder();
+  XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
+  const int64 pre_split_size = original_shape.dimensions(shape_pair.split_dim);
+  std::vector<int64> reshape_dims(original_shape.dimensions().begin(),
+                                  original_shape.dimensions().end());
+  reshape_dims[shape_pair.split_dim] =
+      RoundUpToNearest<int64>(pre_split_size, 2);
+  result = Reshape(result, reshape_dims);
+  if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
+    result = Slice(result, std::vector<int64>(original_shape.rank(), 0),
+                   original_shape.dimensions(),
+                   std::vector<int64>(original_shape.rank(), 1));
+  }
+  return result;
+}
+
 // Generates random 32bits with the given shape using the Three Fry
 // implementation. Returns the random bits and the new state.
 RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
-  XlaBuilder* builder = key.builder();
-  // Try to split the shape on a dimension > 1 into two halves, each
-  // representing a U32 value.
-  std::vector<int64> half_shape_dims;
-  std::vector<int64> padded_full_shape_dims;
-  int64 split_dim = -1;
-  for (int64 i = 0; i < shape.rank(); ++i) {
-    if (shape.dimensions(i) > 1 && split_dim < 0) {
-      half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
-      // Create a new trivial dim for the later concat, which is more friendly
-      // to sharding propagation.
-      half_shape_dims.push_back(1);
-      split_dim = i;
-      padded_full_shape_dims.push_back(half_shape_dims[i] * 2);
-    } else {
-      half_shape_dims.push_back(shape.dimensions(i));
-      padded_full_shape_dims.push_back(shape.dimensions(i));
-    }
-  }
-  auto half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
-  if (split_dim >= 0) {
-    std::pair<ThreeFry2x32State, XlaOp> inputs_state =
-        GetThreeFryInputsAndUpdatedState(initial_state, half_shape);
-    ThreeFry2x32State inputs = inputs_state.first;
-    ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
-    XlaOp result = ConcatInDim(builder, outputs, split_dim + 1);
-    result = Reshape(result, padded_full_shape_dims);
-    if (shape.dimensions(split_dim) % 2 != 0) {
-      result = Slice(result, std::vector<int64>(shape.rank(), 0),
-                     shape.dimensions(), std::vector<int64>(shape.rank(), 1));
-    }
-    return {result, inputs_state.second};
-  }
-  // Use an R1 shape if the previous attempt failed.
-  const int64 size = ShapeUtil::ElementsIn(shape);
-  const int64 half_size = CeilOfRatio<int64>(size, 2);
-  const bool size_is_odd = (half_size * 2 != size);
+  auto shape_pair = SplitShapeIntoHalves(shape);
   std::pair<ThreeFry2x32State, XlaOp> inputs_state =
-      GetThreeFryInputsAndUpdatedState(
-          initial_state,
-          ShapeUtil::MakeShape(shape.element_type(), {half_size}));
+      GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
   ThreeFry2x32State inputs = inputs_state.first;
   ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
-  if (size_is_odd) {
-    outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
-  }
-  XlaOp result = ConcatInDim(builder, outputs, 0);
-  return {Reshape(result, AsInt64Slice(shape.dimensions())),
-          inputs_state.second};
+  XlaOp result = CombineShapePair(outputs, shape_pair, shape);
+  return {result, inputs_state.second};
 }
 
 // Generates random 64bits with the given shape using the Three Fry
@@ -577,27 +615,27 @@
   DCHECK(primitive_type == F32 || primitive_type == F64);
 
   XlaBuilder* builder = key.builder();
-  const int64 num_elems = ShapeUtil::ElementsIn(shape);
-  const int64 num_pairs = CeilOfRatio<int64>(num_elems, 2);
+  auto shape_pair = SplitShapeIntoHalves(shape);
   RngOutput bits_state = UniformFloatingPointDistribution(
       key, initial_state, bit_generator,
       xla::ConstantR0WithType(builder, primitive_type, 0.0),
       xla::ConstantR0WithType(builder, primitive_type, 1.0),
-      ShapeUtil::MakeShape(primitive_type, {num_pairs * 2}));
+      shape_pair.concat_shape);
 
   // Separate the bits into two groups to perform the Box-Muller transform.
-  XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1});
-  XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1});
+  XlaOp bits_0 = Slice(bits_state.value,
+                       std::vector<int64>(shape_pair.half_shape.rank(), 0),
+                       shape_pair.half_shape.dimensions(),
+                       std::vector<int64>(shape_pair.half_shape.rank(), 1));
+  std::vector<int64> bits_1_starts(shape_pair.half_shape.rank(), 0);
+  bits_1_starts[shape_pair.new_concat_dim] = 1;
+  XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
+                       shape_pair.concat_shape.dimensions(),
+                       std::vector<int64>(shape_pair.half_shape.rank(), 1));
   std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
 
   // Put the numbers in the two groups back to form the requested shape.
-  XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0);
-  if (num_elems != num_pairs * 2) {
-    normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems},
-                   /*strides=*/{1});
-  }
-  normal = Reshape(normal, shape.dimensions());
-
+  XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
   return {normal, bits_state.state};
 }
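To make the shape bookkeeping concrete, a plain-Python sketch (helper name assumed) of how `SplitShapeIntoHalves` chooses the split dimension and derives the half and concat shapes:

```python
# Sketch of the split rules above: prefer the first even dimension, otherwise
# fall back to the largest one, and insert a trivial dimension for the concat.
def split_shape_into_halves(dims):
  if not dims:  # scalar: 1-element half shape, 2-element concat shape
    return {"half": [1], "concat": [2], "split_dim": 0, "concat_dim": 0}
  split_dim = next((i for i, d in enumerate(dims) if d % 2 == 0), None)
  if split_dim is None:
    split_dim = max(range(len(dims)), key=lambda i: dims[i])
  half, concat = [], []
  for i, d in enumerate(dims):
    if i == split_dim:
      half_d = -(-d // 2)        # ceil(d / 2)
      half += [half_d, 1]        # trivial dim of size 1 for the later concat
      concat += [half_d, 2]
    else:
      half.append(d)
      concat.append(d)
  return {"half": half, "concat": concat,
          "split_dim": split_dim, "concat_dim": split_dim + 1}

print(split_shape_into_halves([3, 5]))
# {'half': [3, 3, 1], 'concat': [3, 3, 2], 'split_dim': 1, 'concat_dim': 2}
```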
 
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index d3b8856..3f582c7 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -94,9 +94,12 @@
   *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
-// Converts a HloComputation into ReducerOr with predicate types.
-HloComputationProto CreateReduceOr(int64 reducer_id,
-                                   HloComputationProto* original_reducer) {
+// Copies `original_reducer` into a new computation proto with `reducer_id` as
+// the new id. If `rewrite_into_pred` is true, the instructions in the reducer are
+// rewritten into predicate form.
+HloComputationProto CopyReducer(int64 reducer_id,
+                                HloComputationProto* original_reducer,
+                                bool rewrite_into_pred, int64* global_id) {
   HloComputationProto reducer;
   SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
   std::vector<int64> operands_id;
@@ -106,19 +109,28 @@
         HloOpcode::kParameter) {
       HloInstructionProto* new_param = reducer.add_instructions();
       *new_param = inst;
-      *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      operands_id.push_back(inst.id());
+      new_param->set_id((*global_id)++);
+      *new_param->mutable_name() =
+          GetFullName(inst.name(), '.', new_param->id());
+      if (rewrite_into_pred) {
+        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+      }
+      operands_id.push_back(new_param->id());
     }
     if (inst.id() == original_reducer->root_id()) {
       HloInstructionProto* new_root = reducer.add_instructions();
       *new_root = inst;
-      *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      new_root->set_id((*global_id)++);
+      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
+      if (rewrite_into_pred) {
+        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      }
       new_root->clear_operand_ids();
       for (int64 operand_id : operands_id) {
         new_root->add_operand_ids(operand_id);
       }
-      reducer.set_root_id(inst.id());
+      reducer.set_root_id(new_root->id());
     }
   }
   return reducer;
@@ -132,6 +144,7 @@
   }
   return false;
 }
+
 }  // namespace
 
 namespace internal {
@@ -3323,7 +3336,7 @@
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::vector<HloComputationProto> called_computatons;
+  std::vector<HloComputationProto> called_computations;
   auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
                                  int64 operand_index) -> StatusOr<bool> {
     int64 operand_id = instr_proto->operand_ids(operand_index);
@@ -3336,7 +3349,8 @@
   // graph with have id set to `id`.
   auto process_instruction = [&](const HloInstructionProto* instr_proto,
                                  bool need_rewrite, int64 id,
-                                 absl::Span<int64 const> operand_ids) {
+                                 absl::Span<int64 const> operand_ids,
+                                 int64* global_id) {
     // Rewrite the instruction with the following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3347,6 +3361,8 @@
     // constant False if dimension is static.
     // - Reduce: Convert to reduce or.
     // - Constant: Convert to constant False.
+    // - Reshape, slice, transpose, pad:
+    //   Convert into predicate type with same opcode.
     // - Other ops: Not supported.
     // Create the instruction for the new handle.
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
@@ -3362,6 +3378,17 @@
     if (!need_rewrite) {
       *new_instr->mutable_name() =
           GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      if (opcode == HloOpcode::kReduce) {
+        // Copy the reducer to the new module under a fresh id, which the
+        // reduce op then references.
+        HloComputationProto* reducer =
+            &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
+      }
       return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
@@ -3437,9 +3464,12 @@
         break;
       }
       case HloOpcode::kReduce: {
-        int64 reducer_id = new_instr->called_computation_ids(0);
-        called_computatons.push_back(
-            CreateReduceOr(reducer_id, &embedded_[reducer_id]));
+        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
         break;
       }
       case HloOpcode::kTuple:
@@ -3449,6 +3479,7 @@
       case HloOpcode::kBroadcast:
       case HloOpcode::kConcatenate:
       case HloOpcode::kReshape:
+      case HloOpcode::kPad:
         break;
       case HloOpcode::kGetDimensionSize: {
         int64 dimension = instr_proto->dimensions(0);
@@ -3564,10 +3595,11 @@
     if (next_operand >= instr_proto->operand_ids_size() ||
         !should_visit_operand || InstrIsSetBound(instr_proto)) {
       // No more operands to process, process self.
-      int64 new_id = ++global_id;
+      int64 new_id = global_id++;
       VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
       TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
-                                             new_id, item.processed_operands));
+                                             new_id, item.processed_operands,
+                                             &global_id));
       stacktop_id = new_id;
       seen[item_key] = stacktop_id;
       worklist.pop_back();
@@ -3599,10 +3631,14 @@
   module->set_entry_computation_name(entry.name());
   module->set_entry_computation_id(entry.id());
   *module->mutable_host_program_shape() = *program_shape;
-  for (auto& called_comp : called_computatons) {
+  for (auto& called_comp : called_computations) {
     *module->add_computations() = called_comp;
   }
   *module->add_computations() = std::move(entry);
+  // Make sure all ids appear in the computations in ascending order.
+  absl::c_sort(*module->mutable_computations(),
+               [](const HloComputationProto& c1,
+                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
   XLA_VLOG_LINES(3, module->DebugString());
   return std::move(computation);
 }
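Conceptually, the dynamism-inference rewrite listed in the comments above maps every value to a predicate ("could this element be dynamic?"), turns constants into False, and turns most other ops into an OR (or pass-through) of their operands' dynamism. A toy Python sketch of that idea only, not of the proto rewriting itself:

```python
# Toy sketch of the rewrite rules: constants are static, GetDimensionSize
# becomes a constant describing the dimension, everything else ORs its operands.
def infer_dynamism(node):
  op = node["op"]
  if op == "constant":
    return False                          # constants are always static
  if op == "get_dimension_size":
    return node["dimension_is_dynamic"]   # constant False if the dim is static
  # Unary/binary/reduce/reshape/slice/transpose/pad: OR over operand dynamism.
  return any(infer_dynamism(operand) for operand in node["operands"])

example = {"op": "add",
           "operands": [{"op": "get_dimension_size", "dimension_is_dynamic": True},
                        {"op": "constant"}]}
print(infer_dynamism(example))  # True
```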
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index df93e39..2004026 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -593,6 +593,18 @@
       "fragmentation. The constraint is soft, so it works with tensors "
       "larger than the given constraint size. -1 corresponds to no "
       "constraints."));
+  flag_objects->push_back(tensorflow::Flag(
+      "xla_gpu_force_compilation_parallelism",
+      int32_setter_for(
+          &DebugOptions::set_xla_gpu_force_compilation_parallelism),
+      flag_values->xla_gpu_force_compilation_parallelism(),
+      "Overrides normal multi-threaded compilation settting to use this many "
+      "threads. Setting to 0 (the default value) means no enforcement."));
+  flag_objects->push_back(tensorflow::Flag(
+      "xla_gpu_deterministic_ops",
+      bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_ops),
+      flag_values->xla_gpu_deterministic_ops(),
+      "Guarantees run-to-run determinism on GPU."));
 
   ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
 }
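The new options are ordinary XLA flags, so they can presumably be supplied through the `XLA_FLAGS` environment variable parsed above. A hedged example, with flag values that are illustrative only:

```python
# Assumed usage: XLA flags are read from XLA_FLAGS, so set the variable before
# TensorFlow/XLA first compiles anything in this process.
import os

os.environ["XLA_FLAGS"] = (
    "--xla_gpu_force_compilation_parallelism=8 "
    "--xla_gpu_deterministic_ops=true"
)
```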
diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml
index 8d217b8..a3948fb 100644
--- a/tensorflow/compiler/xla/g3doc/_book.yaml
+++ b/tensorflow/compiler/xla/g3doc/_book.yaml
@@ -39,6 +39,6 @@
       - title: XLA autoclustering
         path: /xla/tutorials/autoclustering_xla
       - title: Use XLA with tf.function
-        path: /xla/tutorials/compile
+        path: /xla/tutorials/jit_compile
 
 - include: /_upper_tabs_right.yaml
diff --git a/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png b/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png
index 70087f5..7ab49fd 100644
--- a/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png
+++ b/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png
Binary files differ
diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md
index d749e8b..7bbcd43 100644
--- a/tensorflow/compiler/xla/g3doc/index.md
+++ b/tensorflow/compiler/xla/g3doc/index.md
@@ -4,9 +4,10 @@
 algebra that can accelerate TensorFlow models with potentially no source code
 changes.
 
-The results are improvements in speed and memory usage: most internal benchmarks
-run ~1.15x faster after XLA is enabled. The dataset below is evaluated on a
-single NVidia V100 GPU:
+The results are improvements in speed and memory usage: e.g. the BERT
+[MLPerf](https://blog.tensorflow.org/2020/07/tensorflow-2-mlperf-submissions.html)
+submission using 8 Volta V100 GPUs achieved a ~7x performance improvement and
+a ~5x batch size improvement with XLA:
 
 <div style="width:90%; margin:auto; margin-bottom:10px; margin-top:20px;">
 <img style="width:90%" src="./images/tf_xla_performance.png">
@@ -42,40 +43,11 @@
 
 ## Enable XLA for TensorFlow models
 
-### Auto-clustering
+### Explicit compilation with `tf.function(jit_compile=True)`
 
-A simplest way to start using XLA in TensorFlow models is to enable
-_auto-clustering_, which automatically finds _clusters_ (connected subgraphs)
-within the TensorFlow graph which can be compiled and executed using XLA.
-Auto-clustering on GPU can be enabled by setting the `TF_XLA_FLAGS` environment
-variable:
-
-```
-$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program
-```
-
-Auto-clustering is currently optimized for GPU workloads, but it can also be
-enabled on CPU by additionally using the flag `--tf_xla_cpu_global_jit`:
-
-```
-$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program
-```
-
-Note: Auto-clustering support on CPU and on multi-GPU environments is
-experimental.
-
-For a detailed usage example see the [auto-clustering tutorial
-colab](./tutorials/autoclustering_xla.ipynb).
-
-### Explicit compilation with tf.function
-
-Auto-clustering is a great tool for making the model faster without any changes
-to the code, but it may be hard to understand what changes have been performed.
-
-Explicit compilation API offers a more fine-grained control for choosing which
-functions should be compiled.
-For example, the following TensorFlow function which performs the MNIST training
-is compiled with XLA:
+The explicit compilation API offers fine-grained control for choosing which
+functions should be compiled. For example, the following TensorFlow function,
+which performs MNIST training, is compiled with XLA:
 
 ```
 @tf.function(jit_compile=True)
@@ -116,8 +88,38 @@
 recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))
 ```
 
-See the [tutorial colab](./tutorials/compile.ipynb) for a more detailed usage
-example.
+Note: Nesting behavior: a function will be compiled if at least one function
+in its call stack has `jit_compile=True`.
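A small sketch illustrating the nesting note (function names are illustrative): `inner` below is compiled because it is reached from a function that has `jit_compile=True`:

```python
import tensorflow as tf

@tf.function
def inner(x):
  return x * 2.0          # compiled with XLA when called from `outer`

@tf.function(jit_compile=True)
def outer(x):
  return inner(x) + 1.0   # `outer` requests XLA, so its whole call stack is compiled

outer(tf.ones([2, 2]))
```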
+
+See the [tutorial colab](./tutorials/jit_compile.ipynb) for a more detailed
+usage example.
+
+### Auto-clustering
+
+A simple way to start using XLA in TensorFlow models without any changes is to
+enable _auto-clustering_, which automatically finds _clusters_ (connected
+subgraphs) within the TensorFlow functions which can be compiled and executed
+using XLA. Auto-clustering on GPU can be enabled by setting the `TF_XLA_FLAGS`
+environment variable:
+
+Note: In TF2, only the code inside `tf.function` will be clustered.
+
+```
+$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program
+```
+
+Auto-clustering is currently optimized for GPU workloads, but it can also be
+enabled on CPU by additionally using the flag `--tf_xla_cpu_global_jit`:
+
+```
+$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program
+```
+
+Note: Auto-clustering support on CPU and on multi-GPU environments is
+experimental.
+
+For a detailed usage example see the
+[auto-clustering tutorial colab](./tutorials/autoclustering_xla.ipynb).
 
 ### AOT (Ahead-of-time) compilation for CPU with `tfcompile`
 
@@ -177,6 +179,16 @@
 [`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc)
 and iteratively running it on generated programs.
 
+## Further reading
+
+-   [Known Issues](./known_issues.md): List of known issues with XLA
+-   [XLA Architecture](./architecture.md): Overview of the XLA architecture
+-   [XLA - TensorFlow, Compiled](https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html):
+    Read on Google Developers Blog
+-   Check out the
+    [XLA source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla)
+    on Github!
+
 ## XLA Frontends
 
 Apart from TensorFlow, XLA programs can be generated by:
@@ -187,15 +199,6 @@
     scientific computing
 -   [PyTorch](https://github.com/pytorch/xla): PyTorch framework
 
-## Further reading
-
--   [XLA Architecture](./architecture.md): Overview of the XLA architecture
--   [XLA - TensorFlow, Compiled](https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html):
-    Read on Google Developers Blog
--   Check out the
-    [XLA source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla)
-    on Github!
-
 <iframe frameborder="0" allow="accelerometer; autoplay;
 encrypted-media; gyroscope; picture-in-picture; fullscreen" width="640" height="360"
 src="https://www.youtube.com/embed/kAOanJczHA0?origin=https%3A%2F%2Fwww.tensorflow.org&amp;autohide=1&amp;showinfo=0&amp;video-id=kAOanJczHA0&amp;enablejsapi=1&amp;widgetid=1"
diff --git a/tensorflow/compiler/xla/g3doc/known_issues.md b/tensorflow/compiler/xla/g3doc/known_issues.md
index 516d96b..0bcd4b0 100644
--- a/tensorflow/compiler/xla/g3doc/known_issues.md
+++ b/tensorflow/compiler/xla/g3doc/known_issues.md
@@ -3,7 +3,7 @@
 Compilation with XLA can greatly improve the performance of your programs, but
 the TensorFlow interop has a number of known sharp corners.
 
-## TensorArray TF/XLA interconversion
+## TensorArray TF/XLA interconversion is not supported
 
 *Error message*:
 `Support for TensorList crossing the XLA/TF boundary is not implemented`.
@@ -31,7 +31,7 @@
 parameter set to a constant value known at compile time, or backpropagation
 disabled using `back_prop=False`.
 
-## Dynamic `tf.TensorArray`
+## Dynamic `tf.TensorArray` is not supported
 
 Writes into `tf.TensorArray(..., dynamic_size=True)` are not compilable with
 XLA, as such writes require an unknown number of reallocations when the array
@@ -39,9 +39,16 @@
 
 *Workaround*: provide a statically known bound to your arrays.
 
-## Random number generation
+## Random number generation ignores TF seed
 
 XLA currently ignores TF seeds to random operations. This affects stateful TF
 random operations, such as `tf.random.normal`, or `tf.nn.dropout`.  XLA will
 behave as if the compilation was seeded with a new unique seed at each run. This
 limitation does not apply to stateless random ops.
+
+## TensorFlow Asserts are ignored
+
+Assertions created using `tf.Assert` and similar functions are no-ops when
+compiled to XLA. While proper assertion support is in principle possible, it
+might make certain optimizations impossible (mainly fusing the buffer on which
+the assertion is performed).
diff --git a/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb
similarity index 100%
rename from tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb
rename to tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb
diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD
index 02fc84c..b368afd 100644
--- a/tensorflow/compiler/xla/pjrt/BUILD
+++ b/tensorflow/compiler/xla/pjrt/BUILD
@@ -119,13 +119,59 @@
 
 cc_library(
     name = "pjrt_client",
-    srcs = ["pjrt_client.cc"],
     hdrs = ["pjrt_client.h"],
     visibility = ["//tensorflow/compiler/xla:friends"],
     deps = [
+        "//tensorflow/compiler/xla:executable_run_options",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:executable_build_options",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_cost_analysis",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "utils",
+    srcs = ["utils.cc"],
+    hdrs = ["utils.h"],
+    visibility = ["//tensorflow/compiler/xla:friends"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:executable_build_options",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+    ],
+)
+
+cc_library(
+    name = "pjrt_stream_executor_client",
+    srcs = ["pjrt_stream_executor_client.cc"],
+    hdrs = ["pjrt_stream_executor_client.h"],
+    visibility = ["//tensorflow/compiler/xla:friends"],
+    deps = [
         ":event_pool",
         ":local_device_state",
+        ":pjrt_client",
         ":tracked_device_buffer",
+        ":utils",
         "//tensorflow/compiler/xla:cpu_function_runtime",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:literal",
@@ -181,7 +227,7 @@
     ],
     deps = [
         ":local_device_state",
-        ":pjrt_client",
+        ":pjrt_stream_executor_client",
         ":tracked_device_buffer",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
@@ -215,7 +261,7 @@
     srcs = ["interpreter_device.cc"],
     hdrs = ["interpreter_device.h"],
     deps = [
-        ":pjrt_client",
+        ":pjrt_stream_executor_client",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/service:interpreter_plugin",
@@ -229,7 +275,7 @@
     srcs = ["cpu_device.cc"],
     hdrs = ["cpu_device.h"],
     deps = [
-        ":pjrt_client",
+        ":pjrt_stream_executor_client",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/service:platform_util",
@@ -242,7 +288,7 @@
     srcs = ["gpu_device.cc"],
     hdrs = ["gpu_device.h"],
     deps = [
-        ":pjrt_client",
+        ":pjrt_stream_executor_client",
         "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/compiler/xla:statusor",
@@ -279,6 +325,7 @@
     deps = [
         ":gpu_device",
         ":pjrt_client",
+        ":pjrt_stream_executor_client",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/client:executable_build_options",
         "//tensorflow/compiler/xla/client:xla_builder",
diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc
index 9b0f060..72da2d2 100644
--- a/tensorflow/compiler/xla/pjrt/cpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc
@@ -17,7 +17,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 
 namespace xla {
@@ -26,8 +26,8 @@
 
 CpuDevice::CpuDevice(int id,
                      std::unique_ptr<LocalDeviceState> local_device_state)
-    : PjRtDevice(id, std::move(local_device_state),
-                 /*device_kind=*/kCpuPlatformName) {}
+    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
+                               /*device_kind=*/kCpuPlatformName) {}
 
 StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -40,7 +40,7 @@
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutorConfig config;
     config.ordinal = i;
@@ -57,11 +57,11 @@
     devices.push_back(std::move(device));
   }
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       kCpuName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h
index 1036d8f..e0106fd 100644
--- a/tensorflow/compiler/xla/pjrt/cpu_device.h
+++ b/tensorflow/compiler/xla/pjrt/cpu_device.h
@@ -18,12 +18,12 @@
 
 #include <memory>
 
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
 
-class CpuDevice : public PjRtDevice {
+class CpuDevice : public PjRtStreamExecutorDevice {
  public:
   CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
 };
diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.cc b/tensorflow/compiler/xla/pjrt/gpu_device.cc
index 26f38c2..8c860d5 100644
--- a/tensorflow/compiler/xla/pjrt/gpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/gpu_device.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/xla/pjrt/gpu_device.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 
 #ifdef NCCL_ENABLED
 #include "third_party/nccl/nccl.h"
@@ -35,9 +36,9 @@
 namespace {
 
 // A custom PjRtClient that overrides the device assignment method.
-class GpuClient : public xla::PjRtClient {
+class GpuClient : public xla::PjRtStreamExecutorClient {
  public:
-  using xla::PjRtClient::PjRtClient;
+  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
 
   xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -55,7 +56,8 @@
     return assignment;
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }
 
 // Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +227,9 @@
   return result.first->second;
 }
 
-std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
+std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (auto& local_device : local_device_states) {
     int device_ordinal = local_device->device_ordinal();
     const se::DeviceDescription& description =
@@ -243,7 +245,7 @@
 Status BuildDistributedDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
-    std::vector<std::unique_ptr<PjRtDevice>>* devices,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
     gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   LocalTopologyProto local_topology;
   local_topology.set_node_id(node_id);
@@ -306,8 +308,8 @@
 GpuDevice::GpuDevice(int id,
                      std::unique_ptr<LocalDeviceState> local_device_state,
                      std::string device_kind, int node_id)
-    : PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
-                 node_id) {}
+    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
+                               std::move(device_kind), node_id) {}
 
 StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
     bool asynchronous, const GpuAllocatorConfig& allocator_config,
@@ -322,7 +324,7 @@
   auto host_memory_allocator =
       GetGpuHostAllocator(local_device_states.front()->executor());
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
   if (distributed_client) {
     TF_RETURN_IF_ERROR(BuildDistributedDevices(
diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.h b/tensorflow/compiler/xla/pjrt/gpu_device.h
index 7ea85db..3e11c31 100644
--- a/tensorflow/compiler/xla/pjrt/gpu_device.h
+++ b/tensorflow/compiler/xla/pjrt/gpu_device.h
@@ -19,13 +19,13 @@
 #include <memory>
 
 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
 
 namespace xla {
 
-class GpuDevice : public PjRtDevice {
+class GpuDevice : public PjRtStreamExecutorDevice {
  public:
   GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
             std::string device_kind, int node_id);
diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
index 2819cab..818740c 100644
--- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc
+++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
@@ -17,7 +17,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 
 namespace xla {
@@ -26,8 +26,8 @@
 
 InterpreterDevice::InterpreterDevice(
     int id, std::unique_ptr<LocalDeviceState> local_device_state)
-    : PjRtDevice(id, std::move(local_device_state),
-                 /*device_kind=*/kInterpreterPlatformName) {}
+    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
+                               /*device_kind=*/kInterpreterPlatformName) {}
 
 StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -41,7 +41,7 @@
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   se::StreamExecutor* executor =
       client->backend().stream_executor(0).ValueOrDie();
   auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@
       absl::make_unique<InterpreterDevice>(0, std::move(device_state));
   devices.push_back(std::move(device));
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       "interpreter", client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
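Editorial note, not part of the change: GetInterpreterClient() now constructs a PjRtStreamExecutorClient while its return type remains StatusOr<std::unique_ptr<PjRtClient>>, so the result is wrapped in an explicit std::unique_ptr<PjRtClient>(...). A minimal standalone sketch of the derived-to-base unique_ptr conversion involved, using only standard C++ (Base, Derived, and MakeBase are illustrative names, not identifiers from this tree):

#include <memory>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

std::unique_ptr<Base> MakeBase() {
  // std::unique_ptr<Derived> converts to std::unique_ptr<Base>; the wrapper
  // used in GetInterpreterClient() above spells out the same conversion at
  // the return site.
  return std::unique_ptr<Base>(std::make_unique<Derived>());
}
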
diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h
index 4038d8d..4a4477a 100644
--- a/tensorflow/compiler/xla/pjrt/interpreter_device.h
+++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h
@@ -18,12 +18,12 @@
 
 #include <memory>
 
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
 
-class InterpreterDevice : public PjRtDevice {
+class InterpreterDevice : public PjRtStreamExecutorDevice {
  public:
   InterpreterDevice(int id,
                     std::unique_ptr<LocalDeviceState> local_device_state);
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index b32d288..b54c93b 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -20,31 +20,21 @@
 #include <string>
 #include <vector>
 
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
 #include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/synchronization/notification.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
-#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
-#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
-#include "tensorflow/compiler/xla/service/computation_placer.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
@@ -67,74 +57,42 @@
 
 class PjRtDevice {
  public:
-  explicit PjRtDevice(int id,
-                      std::unique_ptr<LocalDeviceState> local_device_state,
-                      std::string device_kind, int host_id = 0)
-      : id_(id),
-        local_device_id_(
-            local_device_state ? local_device_state->device_ordinal() : -1),
-        local_device_state_(std::move(local_device_state)),
-        host_id_(host_id),
-        device_kind_(std::move(device_kind)) {}
   virtual ~PjRtDevice() {}
 
-  // Must set client exactly once.
-  void SetClient(PjRtClient* client) {
-    CHECK(client_ == nullptr);
-    client_ = client;
-  }
+  // Return the client that owns this device.
+  virtual PjRtClient* client() const = 0;
+
+  // Whether the client can issue commands to this device.
+  virtual bool IsAddressable() const = 0;
 
   // The ID of this device. IDs are unique among devices of this type
   // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
   // hosts' devices.  This is the ID that should be used in a DeviceAssignment.
-  int id() const { return id_; }
+  virtual int id() const = 0;
 
-  bool IsLocalDevice() const { return local_device_id_ != -1; }
+  // The task ID of this device according to TpuTopology. This is not the same
+  // as PjRtClient::host_id() in a multi-task setting, where each client can see
+  // devices from all tasks, but only a subset of them are addressable and have
+  // the same task_id as the client.
+  virtual int host_id() const = 0;
 
-  int local_device_id() const { return local_device_id_; }
+  // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
+  // which GPU when interacting with non-JAX code. In general, not guaranteed to
+  // be dense, and -1 if undefined.
+  virtual int local_hardware_id() const = 0;
 
-  // If this is a device local to this host, returns a LocalDeviceState object
-  // that can be used to manipulate the device. Returns nullptr if the device is
-  // not local to this host.
-  LocalDeviceState* local_device_state() const {
-    return local_device_state_.get();
-  }
+  // A vendor-dependent string that uniquely identifies the kind of device,
+  // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
+  // compatible compilation targets.
+  virtual const std::string& device_kind() const = 0;
 
-  // If this is a device local to this host, returns a LocalDeviceState object
-  // that can be used to manipulate the device. Returns an error if the device
-  // is not local to this host.
-  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
+  virtual std::string DebugString() const = 0;
 
-  // The ID of this device's host. This is always 0 on single-host platforms.
-  int host_id() const { return host_id_; }
+  // Transfer the given literal to the infeed queue.
+  virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
 
-  // Return `platform_id` from client.
-  PjRtPlatformId platform_id() const;
-
-  // Return `platform_name` from client.
-  const std::string& platform_name() const;
-
-  // A vendor-dependent string that uniquely identifies the kind of device.
-  const std::string& device_kind() const { return device_kind_; }
-
-  virtual std::string DebugString() const;
-
-  PjRtClient* client() const { return client_; }
-
-  // Transfer the given literal to the infeed queue of the given localdevice.
-  virtual Status TransferToInfeed(const LiteralSlice& literal) const;
-
-  // Transfer and return a value of the given shape from the outfeed of the
-  // given device.
-  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
-
- private:
-  const int id_;
-  const int local_device_id_;  // -1 means not local.
-  const std::unique_ptr<LocalDeviceState> local_device_state_;
-  const int host_id_;
-  const std::string device_kind_;
-  PjRtClient* client_ = nullptr;
+  // Transfer and return a value of the given shape from the outfeed queue.
+  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
 };
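Editorial sketch, not part of the change: PjRtDevice is now a pure-virtual interface, so each backend supplies its own implementation of the accessors declared above. A minimal, hypothetical device that compiles against just this interface (ToyDevice and its members are illustrative names only):

#include <string>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

namespace xla {

class ToyDevice : public PjRtDevice {
 public:
  ToyDevice(int id, PjRtClient* client) : id_(id), client_(client) {}

  PjRtClient* client() const override { return client_; }
  bool IsAddressable() const override { return true; }
  int id() const override { return id_; }
  int host_id() const override { return 0; }
  int local_hardware_id() const override { return id_; }
  const std::string& device_kind() const override { return kind_; }
  std::string DebugString() const override {
    return "toy:" + std::to_string(id_);
  }
  Status TransferToInfeed(const LiteralSlice& literal) const override {
    return Unimplemented("Infeed is not supported on ToyDevice.");
  }
  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
    return Unimplemented("Outfeed is not supported on ToyDevice.");
  }

 private:
  const int id_;
  PjRtClient* const client_;
  const std::string kind_ = "toy";
};

}  // namespace xla
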
 
 // Forward declaration.
@@ -178,86 +136,62 @@
 // alive as long as any of the other runtime objects are alive.
 class PjRtClient {
  public:
-  // `allocator` may null, in which case the platform default allocator is used.
-  explicit PjRtClient(
-      std::string platform_name, LocalClient* client,
-      std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
-      std::unique_ptr<se::DeviceMemoryAllocator> allocator,
-      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
-      bool should_stage_host_to_device_transfers,
-      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;
 
+  // TODO(zhangqiaorjc): Rename to task_id.
+  // Return the task id of this client. In single-task setting, always 0.
+  virtual int host_id() const = 0;
+
+  // Return the number of devices in the entire computation. In multi-headed
+  // client setting, some are addressable by this client, some are not. In a
+  // single-client setting, this is equal to the number of addressable devices.
+  virtual int device_count() const = 0;
+
+  // Return the number of addressable devices. Addressable devices are those
+  // that the client can issue commands to.
+  virtual int addressable_device_count() const = 0;
+
+  // Return all devices in the entire computation, including addressable and
+  // non-addressable devices.
+  virtual absl::Span<PjRtDevice* const> devices() const = 0;
+
+  // TODO(zhangqiaorjc): Rename to addressable_devices.
+  // Return only addressable devices.
+  virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
+
+  // Lookup any PjRtDevice for a given PjRtDevice::id().
+  virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
+
+  // Return an addressable PjRtDevice for a given
+  // PjRtDevice::local_hardware_id().
+  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const = 0;
+
+  // Return an ID that identifies the platform (CPU/GPU/TPU).
+  virtual PjRtPlatformId platform_id() const = 0;
+
+  // Returns a string that identifies the platform (CPU/GPU/TPU).
+  virtual const std::string& platform_name() const = 0;
+
+  // Return a device-specific default device assignment, e.g., GPU and TPU may
+  // be different.
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
-      int num_replicas, int num_partitions) const;
-
-  int device_count() const { return devices_.size(); }
-  int local_device_count() const { return local_devices_.size(); }
-  const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
-    return devices_;
-  }
-  const std::vector<PjRtDevice*>& local_devices() const {
-    return local_devices_;
-  }
-  const std::map<int, PjRtDevice*>& id_to_device() const {
-    return id_to_device_;
-  }
-  int host_id() const { return host_id_; }
-  PjRtPlatformId platform_id() const { return platform_id_; }
-  const std::string& platform_name() const { return platform_name_; }
-
-  LocalDeviceState& device_state(int device_ordinal) const {
-    return *local_devices_.at(device_ordinal)->local_device_state();
-  }
-
-  // Return a local PjRtDevice for a given `local_device_id`.
-  virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
-
-  LocalClient* client() const { return client_; }
-  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
-  tensorflow::Allocator* host_memory_allocator() const {
-    return host_memory_allocator_.get();
-  }
-  bool should_stage_host_to_device_transfers() const {
-    return should_stage_host_to_device_transfers_;
-  }
-
-  gpu::GpuExecutableRunOptions* gpu_run_options() const {
-    return gpu_run_options_.get();
-  }
-
-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
-
-  // Most platforms expect device-to-device transfers to be enqueued on the
-  // source d2d stream, but some platforms use the destination d2d stream. This
-  // function specifies which one the platform expects.
-  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
-
-  // Generates a unique fingerprint for `executable`.
-  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
-      const PjRtExecutable& executable) const {
-    return absl::optional<std::string>();
-  }
+      int num_replicas, int num_partitions) const = 0;
 
   // Returns a backend-specific HLO cost analysis visitor.
-  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
+  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
 
+  // Compile `computation` with given `options`.
   virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, CompileOptions options);
+      const XlaComputation& computation, CompileOptions options) = 0;
+
+  // Generates a unique fingerprint for `executable`; the result may be
+  // absl::nullopt.
+  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const = 0;
 
   // Creates a buffer on the device without initializing or copying any data.
-  // An optional `definition_event` may be speficied that can be used to
-  // ensure the buffer isn't referenced until some external mechanism has
-  // initialized the data.
-  // NOTE: The sequencing mechanism is not guaranteed to be supported by all
-  // future backends and so callers should avoid wherever possible.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device);
-  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device,
-      std::shared_ptr<BufferSequencingEvent> definition_event);
+      const Shape& shape, PjRtDevice* device) = 0;
 
   // Describes the semantics the caller to BufferFromHostBuffer expects from the
   // runtime, in a total order from most restrictive to least restrictive.
@@ -289,13 +223,13 @@
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
       const void* data, const Shape& shape,
       HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
 
   // Note that literal must remain in scope until the transfer has completed, so
   // the caller should, for example, wait for BlockHostUntilReady() completes on
   // the return value before letting literal go out of scope.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
-      const LiteralSlice& literal, PjRtDevice* device);
+      const LiteralSlice& literal, PjRtDevice* device) = 0;
 
   // Asynchronously makes a vector of PjRtBuffers that can be used to receive
   // cross host transfers using `client` on `device'. `shapes` must be the exact
@@ -308,65 +242,14 @@
   // buffers will become ready until *all* of the sends have completed.
   virtual void MakeCrossHostReceiveBuffers(
       absl::Span<const Shape> shapes, PjRtDevice* device,
-      PjRtCrossHostRecvNotifier&& notifier);
+      PjRtCrossHostRecvNotifier&& notifier) = 0;
 
-  virtual StatusOr<ChannelHandle> CreateChannelHandle() {
-    return client()->CreateChannelHandle();
-  }
-  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
-    return client()->CreateDeviceToHostChannelHandle();
-  }
-  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
-    return client()->CreateHostToDeviceChannelHandle();
-  }
-
- protected:
-  friend class PjRtBuffer;
-  virtual void EnqueueCrossHostReceive(
-      std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
-      std::shared_ptr<BufferSequencingEvent> definition_event,
-      PjRtCrossHostRecvNotifier&& notifier) const {
-    notifier(Unimplemented("Cross host receives not implemented."));
-  }
-
-  virtual Status CopyToRemoteDevice(
-      PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
-    return Unimplemented("Cross host sends not implemented.");
-  }
-
-  const PjRtPlatformId platform_id_;
-  const std::string platform_name_;
-  LocalClient* client_;
-
-  // Allocator to be used for staging memory transfers to devices.
-  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
-
-  // Includes all devices, including non-local devices on multi-host platforms.
-  std::vector<std::unique_ptr<PjRtDevice>> devices_;
-  // Maps Device::id() to the corresponding Device. Includes all devices.
-  std::map<int, PjRtDevice*> id_to_device_;
-  // Local devices indexed by local device ordinal.
-  std::vector<PjRtDevice*> local_devices_;
-  int host_id_;
-
-  se::DeviceMemoryAllocator* allocator_;
-  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
-
-  // Should we always prefer to stage host-to-device transfers via memory
-  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
-  // transfer via pinned memory.
-  bool should_stage_host_to_device_transfers_;
-
-  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
-
-  tensorflow::thread::ThreadPool h2d_transfer_pool_;
+  // Create ChannelHandles for XLA send/recv.
+  virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
 };
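Editorial sketch, hedged: a caller driving the abstract PjRtClient interface above. RunOnFirstAddressableDevice is a hypothetical helper, not part of this change; the options are passed in rather than default-constructed so the sketch assumes nothing beyond the methods declared in this header:

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status RunOnFirstAddressableDevice(PjRtClient* client,
                                   const XlaComputation& computation,
                                   const LiteralSlice& argument,
                                   CompileOptions compile_options,
                                   const ExecuteOptions& execute_options) {
  if (client->local_devices().empty()) {
    return InvalidArgument("Client %s has no addressable devices.",
                           client->platform_name());
  }
  PjRtDevice* device = client->local_devices().front();
  // Compile against the client, stage the argument onto the chosen device,
  // then run a single replica/partition.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      client->Compile(computation, std::move(compile_options)));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> input,
                      client->BufferFromHostLiteral(argument, device));
  std::vector<PjRtBuffer*> args = {input.get()};
  TF_ASSIGN_OR_RETURN(auto outputs,
                      executable->Execute({args}, execute_options));
  (void)outputs;  // outputs[replica][result_index] own the result buffers.
  return Status::OK();
}

}  // namespace xla
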
 
-// Converts a 2D set of Device objects indexed by [replica][partition] into an
-// xla::DeviceAssignment.
-StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
-    absl::Span<const std::vector<PjRtDevice*>> devices);
-
 // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
 // can be either valid or invalid. An invalid buffer is one that has never been
 // initialized, or a buffer that has been deleted (e.g., by calling Delete, or
@@ -376,200 +259,30 @@
 // references if needed. Thread-safe.
 class PjRtBuffer {
  public:
-  // Helper class to retain a "hold" on a PjRtBuffer. A ScopedHold may not
-  // outlive its parent PjRtBuffer.
-  //
-  // There are three types of hold, as follows:
-  //
-  // 1) Usage hold: a transient hold while an operation using the buffer is
-  //    being enqueued onto a stream.
-  // A client acquires a usage hold by calling
-  // PjRtBuffer::GetBufferWithHold(kUsage) or the convenience wrapper
-  // GetBufferWithUsageHold(). If the enqueue completes successfully the hold
-  // should be released using a call to ConvertUsageHold. If the ScopedHold is
-  // deleted without ConvertUsageHold being called, e.g., on error, the hold is
-  // dropped. It is legal to drop a usage hold instead of calling
-  // ConvertUsageHold, even if the buffer was successfully enqueued, as long as
-  // the client ensures that all necessary synchronization has been done.
-  //
-  // 2) External hold: a potentially long-lived hold while the buffer is being
-  //    shared by an external framework, e.g., NumPy.
-  // A client acquires an external hold by calling
-  // PjRtBuffer::GetBufferWithHold(kExternal) or the convenience wrapper
-  // GetBufferWithExternalReference and releases it by deleting the ScopedHold.
-  // The external framework should not modify the underlying buffer unless it is
-  // confident via its own synchronization that modifications do not race with
-  // reads from the PjRtBuffer.
-  //
-  // 3) Donation hold: a transient hold while an execution that donates the
-  //    buffer is being enqueued onto the compute stream.
-  // A client acquires a donation hold by calling
-  // PjRtBuffer::GetBufferWithHold(kDonation). If the enqueue completes
-  // successfully the hold should be released using a call to ConfirmDonation
-  // after which the buffer is invalid. If the ScopedHold is deleted without
-  // ConfirmDonation being called, e.g., on error, the hold is dropped and the
-  // buffer remains valid. If the buffer is successfully enqueued the client
-  // *must* call ConfirmDonation.
-  //
-  // Donation holds behave like exclusive write locks: when a donation hold
-  // has been acquired, any attempt to acquire another hold of any type will
-  // block until the donation hold is dropped or confirmed. Acquiring a donation
-  // hold will fail with an error if there is any outstanding external hold, and
-  // will block if there are any outstanding usage holds until those holds are
-  // dropped or converted.
-  //
-  // Calls to PjRtBuffer::Release (and transitively to
-  // PjRtBuffer::Delete() and ~PjRtBuffer()) will block until all usage
-  // and donation holds are either deleted or converted/confirmed.
-  class ScopedHold {
-   public:
-    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
-    // Use a State enum instead of encoding the state in an error Status to
-    // avoid creating Status values in non-error cases. Creating a Status
-    // entails several allocations and can add O(us) to every use of a hold.
-    enum State {
-      kUninitialized = 0,
-      kValid,
-      kMoved,
-      kConverted,
-      kReleased,
-      kDonated,
-      kError
-    };
+  virtual ~PjRtBuffer() = default;
 
-    ~ScopedHold();
-    ScopedHold(ScopedHold&& other);
-    ScopedHold(const ScopedHold&) = delete;
-    ScopedHold& operator=(const ScopedHold&) = delete;
-
-    Type type() const { return type_; }
-
-    Status status() const {
-      // Lazily create Status values only when they are requested.
-      switch (state_) {
-        case kUninitialized:
-          return InvalidArgument("Buffer has not been initialized");
-        case kValid:
-          return Status::OK();
-        case kMoved:
-          return InvalidArgument("Buffer has been moved.");
-        case kConverted:
-          return InvalidArgument("Buffer has been converted");
-        case kReleased:
-          return InvalidArgument("Buffer has been released");
-        case kDonated:
-          return InvalidArgument("Buffer has been donated");
-        case kError:
-          return buffer_or_.status();
-        default:
-          CHECK(false) << "Unexpected state value " << state_;
-      }
-    }
-    bool ok() const { return state_ == kValid; }
-
-    // Access to the underlying device buffer storage. Requires this->ok().
-    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
-      CHECK_EQ(state_, kValid);
-      CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
-      return buffer_or_.ValueOrDie();
-    }
-    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
-    const TrackedDeviceBuffer& operator*() const { return *buffer(); }
-
-    // Converts the hold into a usage event. Only valid for holds of type
-    // kUsage.
-    //
-    //   usage_stream:   the stream that the buffer was used on.
-    //   event:          an event that has been recorded on usage_stream after
-    //                   the buffer was used.
-    //   reference_held: true if and only if the caller has caused a
-    //                   reference to this->buffer() to stay live until after
-    //                   the host is sure that the usage (transfer or execution)
-    //                   has completed.
-    void ConvertUsageHold(se::Stream* usage_stream,
-                          std::shared_ptr<BufferSequencingEvent> event,
-                          bool reference_held);
-
-    // Confirms that the buffer was successfully donated to an execution.
-    // Only valid for holds of type kDonation. Causes the buffer to become
-    // invalid.
-    void ConfirmDonation();
-
-    // Adds the held device buffers in order to 'iterator'. Used to add the
-    // buffers to an ExecutionInput. We require but do not verify that
-    // 'iterator' when passed in is pointing to a sub-tuple of the
-    // ExecutionInput whose on_device_shape matches that of the
-    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
-    // out of bounds. Donates the device buffers if the hold type is kDonation,
-    // otherwise retains ownership of the device buffers.
-    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
-                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
-                    ExecutionInput* execution_input,
-                    se::DeviceMemoryAllocator* allocator) const;
-
-   private:
-    friend class PjRtBuffer;
-    friend class PjRtClient;
-
-    // Helper struct that makes it possible to move a ScopedHold through a
-    // closure.
-    using ForClosure =
-        std::tuple<PjRtBuffer*, Type, State,
-                   StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
-
-    ScopedHold(PjRtBuffer* parent, Type type)
-        : parent_(parent), type_(type), state_(kUninitialized) {}
-    explicit ScopedHold(const ForClosure& closure_helper)
-        : parent_(std::get<0>(closure_helper)),
-          type_(std::get<1>(closure_helper)),
-          state_(std::get<2>(closure_helper)),
-          buffer_or_(std::get<3>(closure_helper)) {
-      // Check the buffer is not in an error state.
-      CHECK(buffer_or_.ValueOrDie() != nullptr);
-    }
-
-    // Sets buffer state.
-    void SetState(State state) { state_ = state; }
-
-    // Sets buffer_or_. Called by parent_ to initialize the hold.
-    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
-    // Releases the contents of *this, so *this can subsequently be
-    // deleted without releasing the parent's hold. Should be passed to the
-    // appropriate constructor of another ScopedHold, e.g., when a hold must be
-    // passed through a closure that is incompatible with std::move.
-    ForClosure ToClosure();
-
-    PjRtBuffer* const parent_;
-    const Type type_;
-
-    // There is an invariant that if ok() then
-    // buffer_or_.ValueOrDie() != nullptr.
-    State state_;
-    StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
-  };
-
-  PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
-             std::shared_ptr<TrackedDeviceBuffer> device_buffer,
-             PjRtClient* client, PjRtDevice* device);
-  virtual ~PjRtBuffer();
-
-  PjRtBuffer(const PjRtBuffer&) = delete;
-  PjRtBuffer(PjRtBuffer&&) = delete;
-  PjRtBuffer& operator=(const PjRtBuffer&) = delete;
-  PjRtBuffer& operator=(PjRtBuffer&&) = delete;
-
-  const Shape& on_host_shape() const { return on_host_shape_; }
-  const Shape& on_device_shape() const { return on_device_shape_; }
-  PjRtDevice* device() const { return device_; }
-  PjRtPlatformId platform_id() const { return client_->platform_id(); }
-  const std::string& platform_name() const { return client_->platform_name(); }
-  PjRtClient* client() const { return client_; }
-  bool IsEmptyTuple() const {
-    return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
-  }
+  virtual const Shape& on_host_shape() const = 0;
+  virtual const Shape& on_device_shape() const = 0;
+  virtual PjRtDevice* device() const = 0;
+  virtual PjRtClient* client() const = 0;
 
   // Returns the size of the on-device representation of this buffer in bytes.
-  int64 OnDeviceSizeInBytes() const;
+  virtual int64 OnDeviceSizeInBytes() const = 0;
+
+  // ExternalReferenceHold is a potentially long-lived hold while the buffer is
+  // being shared by an external framework, e.g., NumPy. A client acquires an
+  // external hold by calling PjRtBuffer::AcquireExternalReference() and
+  // releases it by deleting the ExternalReferenceHold. The external framework
+  // should not modify the underlying buffer unless it is confident via its own
+  // synchronization that modifications do not race with reads from the
+  // PjRtBuffer.
+  struct ExternalReferenceHold {
+    virtual ~ExternalReferenceHold() = default;
+    // Return opaque device memory pointer to root buffer.
+    virtual void* OpaqueDeviceMemoryDataPointer() const = 0;
+  };
+  virtual StatusOr<std::unique_ptr<ExternalReferenceHold>>
+  AcquireExternalReference() = 0;
 
   // Returns the buffer's value as an XLA Literal. If the value has previously
   // been prefetched to the host, then returns the prefetched version, otherwise
@@ -578,15 +291,21 @@
   // cached copy of the literal (i.e. The reference to the host value will be
   // removed.) If a layout is passed then a literal with this layout will be
   // returned.
-  StatusOr<std::shared_ptr<Literal>> ToLiteral(
-      bool discard_cached_copy = false,
-      absl::optional<xla::Layout> layout = {});
+  StatusOr<std::shared_ptr<Literal>> ToLiteral() {
+    return ToLiteral(/*discard_cached_copy=*/false, /*layout=*/{});
+  }
+  StatusOr<std::shared_ptr<Literal>> ToLiteral(bool discard_cached_copy) {
+    return ToLiteral(discard_cached_copy, /*layout=*/{});
+  }
+  virtual StatusOr<std::shared_ptr<Literal>> ToLiteral(
+      bool discard_cached_copy, absl::optional<xla::Layout> layout) = 0;
 
   // Initiates a copy of the buffer to the host. Does not block waiting for
   // the transfer to complete. The value can be retrieved by a later call to
   // ToLiteral(). If a layout is passed then a cached copy with this layout will
   // be created.
-  Status CopyToHostAsync(absl::optional<xla::Layout> layout = {});
+  Status CopyToHostAsync() { return CopyToHostAsync(/*layout=*/{}); }
+  virtual Status CopyToHostAsync(absl::optional<xla::Layout> layout) = 0;
 
   // Drops the buffer's reference to its associated device memory, leaving the
   // buffer in an invalid state. The memory will be freed lazily when all async
@@ -597,47 +316,36 @@
   // framework holds a reference to the TrackedDeviceBuffer via
   // GetBufferWithExternalReference, the memory will not be freed until the
   // external framework drops the reference.
-  void Delete();
+  virtual void Delete() = 0;
 
   // Similar to Delete, drops the buffer's reference to its associated device
-  // memory, leaving the buffer in an invalid state, but returns the
-  // TrackedDeviceBuffer rather than freeing the device memory, so that another
-  // framework can take ownership of it. The buffer returned from Release may
-  // be safely dropped at any time even if it still has pending async
-  // operations. The client should call BlockHostUntilReady before calling
-  // Release with wait_for_operations_to_complete=false, to ensure that the host
-  // has synchronized past any outstanding write operations to the buffer. If
+  // memory, leaving the buffer in an invalid state, but transfers the device
+  // memory ownership out via absl::optional<std::shared_ptr<void>> rather than
+  // freeing the device memory, so that another framework can take ownership of
+  // it. A return value of absl::nullopt indicates that PjRtBuffer has been
+  // deleted. The buffer returned from Release may be safely dropped at any time
+  // even if it still has pending async operations. The client should call
+  // BlockHostUntilReady before calling ReleaseDeviceMemoryOwnership with
+  // wait_for_operations_to_complete=false, to ensure that the host has
+  // synchronized past any outstanding write operations to the buffer. If
   // wait_for_operations_to_complete=true the host will block until any
   // potentially outstanding asynchronous operations have completed before
   // returning, in which case it is safe to read or mutate the returned buffer.
   // If the buffer was shared via an external reference it is the client's
   // responsibility that accesses via that reference do not interfere with
-  // accesses via the buffer returned from Release.
-  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
-      bool wait_for_operations_to_complete);
+  // accesses via the buffer returned from ReleaseDeviceMemoryOwnership.
+  virtual StatusOr<absl::optional<std::shared_ptr<void>>>
+  ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
 
   // True if and only if Delete or Release has previously been called.
-  bool IsDeleted();
-
-  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
-  // PjRtBuffer retains ownership of the device buffers.
-  StatusOr<ShapedBuffer> AsShapedBuffer() const;
-
-  // Returns a hold on the TrackedDeviceBuffer holding the device
-  // buffers. See comment on ScopedHold.
-  ScopedHold GetBufferWithHold(ScopedHold::Type type);
-  ScopedHold GetBufferWithUsageHold() {
-    return GetBufferWithHold(ScopedHold::kUsage);
-  }
-  ScopedHold GetBufferWithExternalReference() {
-    return GetBufferWithHold(ScopedHold::kExternalReference);
-  }
+  virtual bool IsDeleted() = 0;
 
   // Copies the buffer to device `dst_device`, performing a d2d transfer when
   // `dst_device` is sharing the same Client, and performing a d2h and h2d copy
   // if `dst_device` lives on a different Client.
   // Returns an error if the buffer is already on dst_device.
-  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(PjRtDevice* dst_device);
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
+      PjRtDevice* dst_device) = 0;
 
   // Copies the buffer to the remote device encoded in serialized_descriptor.
   // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
@@ -648,90 +356,15 @@
   // matching call to src->CopyToRemoteDevice on a remote host for a src buffer
   // of the corresponding shape. serialized_descriptor is the string returned by
   // the callback along with the corresponding destination buffer.
-  Status CopyToRemoteDevice(absl::string_view serialized_descriptor);
+  virtual Status CopyToRemoteDevice(
+      absl::string_view serialized_descriptor) = 0;
 
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
-  Status BlockHostUntilReady();
+  virtual Status BlockHostUntilReady() = 0;
 
   // Whether this buffer is on CPU and thus allows for certain optimizations.
-  bool IsOnCpu() const;
-
- private:
-  friend class PjRtClient;
-  // The cached value of the buffer on the host, produced either from a call to
-  // CopyToHost or from a call to ToLiteral. Once a value has been fetched to
-  // the host, it persists Delete() is called or the PjRtBuffer is destroyed.
-  struct HostValue {
-    absl::Notification ready;
-    // status and value are valid for reading only after `ready` has been
-    // notified.
-    Status status;
-    std::shared_ptr<Literal> value;
-  };
-
-  // Blocks in mu_.Await until there are no more usage holds.
-  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Blocks in mu_.Await until there is no donation hold.
-  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
-  // device_buffer_ is null, or if a donation hold was requested when there is
-  // an outstanding external hold.
-  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
-      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
-  // Initializes hold with an error if device_buffer_ is null, or if a donation
-  // hold was requested when there is an outstanding external hold.
-  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
-  // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
-  // device_buffer_ was successfully enqueued on a stream.
-  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
-                        std::shared_ptr<BufferSequencingEvent> event,
-                        bool reference_held);
-
-  // Drops a donation hold and makes *this invalid for further use. Does a
-  // sanity check that buffer==device_buffer_. Called after device_buffer_ was
-  // successfully donated to an execution.
-  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
-
-  // Initiates a copy of the buffer to the host. Does not block waiting for
-  // the transfer to complete. A host value is returned and if
-  // `discard_cached_copy` is false stored in an internal buffer so that future
-  // transfers don't have to transfer the data from host again. If a layout is
-  // passed then a literal of this layout will be returned and possibly cached.
-  StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
-      bool discard_cached_copy, absl::optional<xla::Layout> layout);
-
-  // Drops a hold without taking any other action. Does a sanity check that
-  // buffer==device_buffer_ or device_buffer_==nullptr.
-  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
-
-  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
-                     std::shared_ptr<BufferSequencingEvent>>>
-  CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
-                     LocalDeviceState* transfer_local_device,
-                     se::Stream* transfer_stream,
-                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
-
-  PjRtClient* const client_;
-  const Shape on_host_shape_;
-  const Shape on_device_shape_;
-  PjRtDevice* const device_;
-
-  mutable absl::Mutex mu_;
-  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
-  absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
-      TF_GUARDED_BY(mu_);
-  std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
-  // Count of holds on the buffer.
-  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
-  // Semaphore used to ensure there is only one outstanding donation hold.
-  Semaphore donation_semaphore_;
+  virtual bool IsOnCpu() const = 0;
 };
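Editorial sketch, hedged: exercising the abstract PjRtBuffer surface above. InspectAndDelete is a hypothetical helper, not part of this change; it reads the value back via the convenience ToLiteral() overload, inspects raw device memory through the new ExternalReferenceHold, and then deletes the buffer:

#include <memory>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status InspectAndDelete(PjRtBuffer* buffer) {
  // Host copy of the buffer's value; reuses a cached copy if one exists.
  TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, buffer->ToLiteral());
  (void)literal;

  // The raw device pointer is only valid while `external_ref` is alive.
  TF_ASSIGN_OR_RETURN(auto external_ref, buffer->AcquireExternalReference());
  void* raw_device_ptr = external_ref->OpaqueDeviceMemoryDataPointer();
  (void)raw_device_ptr;
  external_ref.reset();  // Drop the external hold before deleting.

  // Device memory is freed lazily once outstanding async work completes.
  buffer->Delete();
  return Status::OK();
}

}  // namespace xla
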
 
 class ExecuteContext {
@@ -769,7 +402,7 @@
   virtual PjRtClient* client() const = 0;
 
   // Unique name for this executable, e.g., HloModule name.
-  virtual const string& name() const = 0;
+  virtual const std::string& name() const = 0;
 
   virtual int num_replicas() const = 0;
 
@@ -791,6 +424,7 @@
   virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
       const = 0;
 
+  // An addressable_device is one that the client can issue commands to.
   // addressable_devices()[i] is the Device to which
   // addressable_device_logical_ids()[i] is assigned.
   virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
@@ -804,167 +438,26 @@
   // by the client.
   virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
   Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
-          const ExecuteOptions& options) const = 0;
+          const ExecuteOptions& options) = 0;
 
   // Execute the assigned replica/partition on a given `device`. Requires
   // executable has a device_assignment, `device` is present in the
   // device_assignment and addressable by the client.
   virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
-      const ExecuteOptions& options) const = 0;
+      const ExecuteOptions& options) = 0;
 
   // Execute on a given `device`. Requires `device` to be addressable by client.
   // Requires executable has exactly 1 replica and 1 partition and no
   // device_assignment (thus portable).
   virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
-      const ExecuteOptions& options) const = 0;
+      const ExecuteOptions& options) = 0;
 
   // Asynchronously free resources after the last execution completes.
   virtual void Delete() = 0;
 };
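Editorial sketch, hedged: per the comment above, ExecutePortable requires a 1-replica, 1-partition executable without a device assignment and picks the target device at call time; since Execute/ExecuteSharded/ExecutePortable are no longer const in this change, the helper takes a mutable executable. RunPortableOn is a hypothetical name:

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

StatusOr<std::unique_ptr<PjRtBuffer>> RunPortableOn(
    PjRtExecutable* executable, PjRtBuffer* argument, PjRtDevice* device,
    const ExecuteOptions& options) {
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> outputs,
                      executable->ExecutePortable({argument}, device, options));
  if (outputs.empty()) {
    return InvalidArgument("Expected at least one output buffer.");
  }
  return std::move(outputs.front());
}

}  // namespace xla
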
 
-// Wraps one or more XLA LocalExecutables (one per partition, as specified by
-// the build options).
-class PjRtStreamExecutorExecutable : public PjRtExecutable {
- public:
-  PjRtStreamExecutorExecutable(
-      std::vector<std::unique_ptr<LocalExecutable>> executables,
-      bool parameter_is_tupled_arguments,
-      std::shared_ptr<DeviceAssignment> device_assignment,
-      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-      std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
-
-  ~PjRtStreamExecutorExecutable() override = default;
-
-  PjRtClient* client() const override { return client_; }
-
-  const string& name() const override;
-
-  int num_replicas() const override {
-    return executables_[0]->build_options().num_replicas();
-  }
-
-  int num_partitions() const override {
-    return executables_[0]->build_options().num_partitions();
-  }
-
-  int64 SizeOfGeneratedCodeInBytes() const override {
-    int64 size = 0;
-    for (auto& executable : executables_) {
-      size += executable->executable()->SizeOfGeneratedCodeInBytes();
-    }
-    return size;
-  }
-
-  const DeviceAssignment& device_assignment() const override {
-    return *device_assignment_;
-  }
-
-  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
-      const override {
-    return addressable_device_logical_ids_;
-  }
-
-  absl::Span<PjRtDevice* const> addressable_devices() const override {
-    return addressable_devices_;
-  }
-
-  // Return an HloModule per partition.
-  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
-      const override;
-
-  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
-      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
-      const ExecuteOptions& options) const override;
-
-  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
-      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
-      const ExecuteOptions& options) const override;
-
-  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
-      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
-      const ExecuteOptions& options) const override;
-
-  void Delete() override { executables_.clear(); }
-
-  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
-    return executables_;
-  }
-
- protected:
-  bool parameter_is_tupled_arguments() const {
-    return parameter_is_tupled_arguments_;
-  }
-
- private:
-  friend class PjRtClient;
-  // Initializes information about which arguments to which executables must be
-  // donated due to aliases that were specified by the computation.
-  Status SetUpDonation(bool tuple_inputs);
-
-  virtual bool MustDonateParameter(int executable_idx, int parameter) const;
-
-  virtual StatusOr<std::vector<ExecutionInput>>
-  MakeExecutionInputsAndWaitForEvents(
-      int device_ordinal, const ExecuteOptions& options,
-      absl::Span<PjRtBuffer* const> argument_handles,
-      absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
-      absl::flat_hash_set<BufferSequencingEvent*>& events) const;
-
-  StatusOr<ScopedShapedBuffer> EnqueueExecution(
-      absl::Span<PjRtBuffer* const> argument_handles, int replica,
-      int partition, int executable_idx, const RunId& run_id,
-      const ExecuteOptions& options, PjRtDevice* device,
-      std::vector<PjRtBuffer::ScopedHold>* device_buffers,
-      std::shared_ptr<DeviceAssignment> device_assignment) const;
-
-  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
-      int device_ordinal, const ExecuteOptions& options,
-      ScopedShapedBuffer result_buffer,
-      std::shared_ptr<BufferSequencingEvent> definition_event,
-      PjRtDevice* device) const;
-
-  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
-      absl::Span<PjRtBuffer* const> argument_handles, int replica,
-      int partition, const RunId& run_id, const ExecuteOptions& options,
-      PjRtDevice* device = nullptr) const;
-
-  // Create shared pointers so we can free them after the execution: with
-  // asynchronous execution, the process being executed can outlive the
-  // executable itself.
-  PjRtClient* const client_;
-  // One executable per partition.
-  std::vector<std::shared_ptr<LocalExecutable>> executables_;
-  // Per-executable set of parameters that have any aliased buffers and thus
-  // must be donated when executing the computation.
-  std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
-  std::shared_ptr<DeviceAssignment> device_assignment_;
-
-  // True if the executables were compiled expecting arguments in a single
-  // tuple.
-  const bool parameter_is_tupled_arguments_;
-
-  // The replica and partition indices of device_assignment_ to be run by this
-  // client. On single-host platforms without partitioning, this is all replicas
-  // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
-  // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
-  // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
-  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
-
-  // addressable_devices_[i] is the Device to which
-  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
-  // unique_ptrs to play well with the Python bindings (see xla.cc).
-  std::vector<PjRtDevice*> addressable_devices_;
-};
-
-// Executables can donate buffers so that buffers can be aliased from inputs
-// to outputs. This function returns the list of parameters that must be
-// donated when executable is run. tuple_inputs reflects the option that
-// executable was compiled with.
-StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
-    const HloModule& hlo_module, bool tuple_inputs);
-
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
similarity index 79%
rename from tensorflow/compiler/xla/pjrt/pjrt_client.cc
rename to tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index 191b346..e31db15 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -62,9 +62,10 @@
 // See the comment on LocalDeviceState::AllocationModel for a discussion of the
 // different allocation semantics on CPU, GPU, and TPU.
 
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 
 #include <cstddef>
+#include <cstdlib>
 #include <memory>
 #include <string>
 #include <vector>
@@ -89,6 +90,7 @@
 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
+#include "tensorflow/compiler/xla/pjrt/utils.h"
 #include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -97,6 +99,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mem.h"
@@ -114,21 +117,22 @@
 
 namespace xla {
 
-PjRtPlatformId PjRtDevice::platform_id() const {
+PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
   return client_->platform_id();
 }
-const std::string& PjRtDevice::platform_name() const {
+const std::string& PjRtStreamExecutorDevice::platform_name() const {
   return client_->platform_name();
 }
 
-StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
+StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
+    const {
   if (local_device_state_) {
     return local_device_state_.get();
   }
   return InvalidArgument("Device %s is not a local device.", DebugString());
 }
 
-std::string PjRtDevice::DebugString() const {
+std::string PjRtStreamExecutorDevice::DebugString() const {
   return absl::StrCat(platform_name(), ":", id());
 }
 
@@ -153,14 +157,15 @@
           devices[replica].size(), replica, devices[0].size());
     }
     for (int partition = 0; partition < devices[replica].size(); ++partition) {
-      if (devices[0][0]->platform_id() !=
-          devices[replica][partition]->platform_id()) {
+      if (devices[0][0]->client()->platform_id() !=
+          devices[replica][partition]->client()->platform_id()) {
         return InvalidArgument(
             "Device assignment passed to Compile() must have devices of a "
             "single kind, got %s for replica 0 partition 0 and %s for replica "
             "%d partition %d.",
-            devices[0][0]->platform_name(),
-            devices[replica][partition]->platform_name(), replica, partition);
+            devices[0][0]->client()->platform_name(),
+            devices[replica][partition]->client()->platform_name(), replica,
+            partition);
       }
       xla_assignment(replica, partition) = devices[replica][partition]->id();
     }
@@ -182,9 +187,22 @@
   }
 };
 
-PjRtClient::PjRtClient(
+static int DefaultThreadPoolSize() {
+  // Google's CI system exposes an environment variable NPROC that describes
+  // a CPU reservation for tests.
+  // TODO(phawkins): expose a better thought-out set of knobs to control
+  // parallelism.
+  const char* nproc_str = std::getenv("NPROC");
+  int nproc = 0;
+  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
+    return std::max(0, nproc);
+  }
+  return tensorflow::port::MaxParallelism();
+}
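Editorial note, hedged: the helper above sizes the new client thread pool from the NPROC environment variable when it parses, falling back to tensorflow::port::MaxParallelism() otherwise; the constructor below additionally lower-bounds the pool by the device count. A small illustration of the intended behaviour; DefaultThreadPoolSize() is file-static in the real source, so calling it like this is purely illustrative, and setenv/unsetenv assume a POSIX environment:

#include <cstdlib>

void ThreadPoolSizeIllustration() {
  setenv("NPROC", "8", /*overwrite=*/1);
  int with_reservation = DefaultThreadPoolSize();    // expected: 8
  unsetenv("NPROC");
  int without_reservation = DefaultThreadPoolSize();  // expected: MaxParallelism()
  (void)with_reservation;
  (void)without_reservation;
}
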
+
+PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
@@ -193,14 +211,15 @@
       platform_name_(std::move(platform_name)),
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
-      devices_(std::move(devices)),
+      owned_devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
       should_stage_host_to_device_transfers_(
           should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
-      h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
-                         client->device_count()) {
+      thread_pool_(
+          tensorflow::Env::Default(), "pjrt_thread_pool",
+          std::max<int>(DefaultThreadPoolSize(), client->device_count())) {
   if (owned_allocator_ != nullptr) {
     allocator_ = owned_allocator_.get();
   } else {
@@ -211,12 +230,14 @@
     host_memory_allocator_ = std::make_unique<CpuAllocator>();
   }
 
-  for (const std::unique_ptr<PjRtDevice>& device : devices_) {
+  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
+       owned_devices_) {
+    devices_.push_back(device.get());
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
 
-    if (device->IsLocalDevice()) {
-      int idx = device->local_device_id();
+    if (device->IsAddressable()) {
+      int idx = device->local_hardware_id();
       if (idx >= local_devices_.size()) {
         local_devices_.resize(idx + 1);
       }
@@ -230,13 +251,14 @@
   }
 }
 
-StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
+StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
   return client_->backend().computation_placer()->AssignDevices(num_replicas,
                                                                 num_partitions);
 }
 
-std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
+std::unique_ptr<HloCostAnalysis>
+PjRtStreamExecutorClient::GetHloCostAnalysis() {
   return absl::make_unique<HloCostAnalysis>(
       client_->backend().compiler()->ShapeSizeBytesFunction());
 }
@@ -303,7 +325,7 @@
 // a reference to the buffer until the copy completes or serialize the compute
 // stream behind the copy. It is often better to retain a reference since while
 // that keeps memory alive longer, it avoids stalling the compute stream.
-void RecordUsage(PjRtBuffer::ScopedHold device_buffer,
+void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
                  LocalDeviceState* buffer_local_device,
                  LocalDeviceState* stream_local_device,
                  std::shared_ptr<BufferSequencingEvent> event,
@@ -337,7 +359,7 @@
 //
 // The caller may optionally provide a definition event to be recorded in
 // the buffer.
-StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
+StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
     const Shape& on_host_shape, PjRtDevice* device,
     LocalDeviceState* local_device, se::Stream* copy_stream,
     bool is_uninitialized_create, PjRtClient* client,
@@ -346,12 +368,13 @@
     return InvalidArgument("Can't make a buffer from an empty tuple");
   }
 
+  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer dst_buffer,
-      transfer_manager->AllocateScopedShapedBuffer(
-          on_host_shape, client->allocator(), local_device->device_ordinal()));
+      se_client->client()->backend().transfer_manager();
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          on_host_shape, se_client->allocator(),
+                          local_device->device_ordinal()));
   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
     if (copy_stream == nullptr) {
@@ -429,9 +452,9 @@
       TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
                                                   definition_events);
 
-  auto py_buffer = absl::make_unique<PjRtBuffer>(on_host_shape, on_device_shape,
-                                                 std::move(dst_device_buffer),
-                                                 client, device);
+  auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
+      on_host_shape, on_device_shape, std::move(dst_device_buffer), client,
+      device);
 
   if (on_device_shape.IsTuple()) {
     // Add a usage hold for the tuple table write and immediately convert it to
@@ -454,7 +477,8 @@
 // definition_event was added when the buffer was allocated, but has not yet
 // had an event recorded.
 Status AddDestinationBufferSynchronization(
-    LocalDeviceState* local_device, PjRtBuffer::ScopedHold device_buffer,
+    LocalDeviceState* local_device,
+    PjRtStreamExecutorBuffer::ScopedHold device_buffer,
     std::shared_ptr<BufferSequencingEvent> definition_event,
     se::Stream* copy_stream) {
   StatusOr<EventPool::Handle> event_or =
@@ -479,13 +503,13 @@
 
 }  // namespace
 
-PjRtBuffer::ScopedHold::~ScopedHold() {
+PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
   if (ok()) {
     parent_->DropHold(type_, buffer().get());
   }
 }
 
-PjRtBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
+PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
     : parent_(other.parent_),
       type_(other.type_),
       state_(other.state_),
@@ -494,7 +518,7 @@
   other.SetState(kMoved);
 }
 
-void PjRtBuffer::ScopedHold::Acquire(
+void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
   CHECK(!ok());
   buffer_or_ = std::move(buffer_or);
@@ -503,31 +527,32 @@
   CHECK(!ok() || buffer_or_.ValueOrDie() != nullptr);
 }
 
-PjRtBuffer::ScopedHold::ForClosure PjRtBuffer::ScopedHold::ToClosure() {
+PjRtStreamExecutorBuffer::ScopedHold::ForClosure
+PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
   CHECK(ok());
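+  // Package the hold's state into a tuple that a callback can capture and
+  // later reconstruct; the hold itself is marked kReleased and no longer
+  // owns the buffer.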
   ForClosure for_closure(parent_, type_, state_, std::move(buffer_or_));
   SetState(kReleased);
   return for_closure;
 }
 
-void PjRtBuffer::ScopedHold::ConvertUsageHold(
+void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
     se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
     bool reference_held) {
   CHECK(ok());
-  CHECK(type_ == kUsage);
+  CHECK_EQ(type_, kUsage);
   parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
                             reference_held);
   SetState(kConverted);
 }
 
-void PjRtBuffer::ScopedHold::ConfirmDonation() {
+void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
   CHECK(ok());
-  CHECK(type_ == kDonation);
+  CHECK_EQ(type_, kDonation);
   parent_->ConfirmDonation(buffer().get());
   SetState(kDonated);
 }
 
-void PjRtBuffer::ScopedHold::AddToInput(
+void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
     ExecutionInput* execution_input,
@@ -536,25 +561,60 @@
   if (type_ == kDonation) {
     buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
   } else {
-    CHECK(type_ == kUsage);
+    CHECK_EQ(type_, kUsage);
     buffer()->AddToInputAsImmutable(iterator, end);
   }
 }
 
-bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
+bool PjRtStreamExecutorBuffer::IsOnCpu() const {
+  return client()->platform_id() == kCpuId;
+}
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer::ExternalReferenceHold>>
+PjRtStreamExecutorBuffer::AcquireExternalReference() {
+  ScopedHold hold = GetBufferWithExternalReference();
+  Status hold_status = hold.status();
+  if (!hold_status.ok()) return hold_status;
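+  // Wrap the hold in a ScopedHoldAsExternalReference; the explicit
+  // std::unique_ptr wrapper upcasts to the ExternalReferenceHold interface
+  // expected by the StatusOr return type.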
+  return std::unique_ptr<ExternalReferenceHold>(
+      std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
+}
+
+StatusOr<absl::optional<std::shared_ptr<void>>>
+PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
+    bool wait_for_operations_to_complete) {
+  if (on_device_shape_.IsTuple()) {
+    return InvalidArgument(
+        "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
+  }
+  TF_ASSIGN_OR_RETURN(
+      std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
+      Release(wait_for_operations_to_complete));
+
+  if (!tracked_device_buffer) {
+    // Buffer has been deleted.
+    return {absl::nullopt};
+  }
+  void* opaque_ptr = tracked_device_buffer->device_memory()[0].opaque();
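+  // The returned shared_ptr's deleter does nothing except capture
+  // tracked_device_buffer, which keeps the underlying device memory alive
+  // for as long as the caller holds the pointer.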
+  return absl::make_optional<std::shared_ptr<void>>(
+      opaque_ptr,
+      [tracked_device_buffer = std::move(tracked_device_buffer)](void*) {});
+}
+
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
-  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   if (shape.IsTuple()) {
     return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
+                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                          ->GetLocalDeviceState());
   int64 size = ShapeUtil::ByteSizeOf(shape);
 
   TransferManager* transfer_manager = client()->backend().transfer_manager();
@@ -607,18 +667,20 @@
           /*allocator=*/nullptr, local_device->device_ordinal(),
           std::initializer_list<se::DeviceMemoryBase>{buffer},
           definition_events, std::move(on_delete_callback));
-      return absl::make_unique<PjRtBuffer>(
-          shape, shape, std::move(device_buffer), this, device);
+      return std::unique_ptr<PjRtBuffer>(
+          std::make_unique<PjRtStreamExecutorBuffer>(
+              shape, shape, std::move(device_buffer), this, device));
     }
   }
 
   TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtBuffer> py_buffer,
+      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
                                 /*is_uninitialized_create=*/false, this));
 
-  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtStreamExecutorBuffer::ScopedHold device_buffer(
+      py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // If necessary, allocate a host-side buffer for staging host-to-device
@@ -658,7 +720,7 @@
                        staging_buffer{std::move(staging_buffer)},
                        buffer_reference{std::move(buffer_reference)},
                        host_buffer_semantics]() {
-    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
+    PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -699,59 +761,69 @@
         std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
   if (is_cpu_platform) {
-    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // Using the thread_pool would be a double thread hop; the code
     // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    h2d_transfer_pool()->Schedule(transfer_h2d);
+    thread_pool()->Schedule(transfer_h2d);
   }
-  return py_buffer;
+  return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
-    const Shape& shape, PjRtDevice* device) {
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
+                                                    PjRtDevice* device) {
   return CreateUninitializedBuffer(shape, device, nullptr);
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(
     const Shape& shape, PjRtDevice* device,
     std::shared_ptr<BufferSequencingEvent> definition_event) {
   tensorflow::profiler::TraceMe traceme(
-      "PjRtClient::CreateUninitializedBuffer");
-  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+      "PjRtStreamExecutorClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
           << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
+                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                          ->GetLocalDeviceState());
 
   TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
-  return AllocateDestinationBuffer(compact_shape, device, local_device,
-                                   /*copy_stream=*/nullptr,
-                                   /*is_uninitialized_create=*/true, this,
-                                   definition_event);
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
+      AllocateDestinationBuffer(compact_shape, device, local_device,
+                                /*copy_stream=*/nullptr,
+                                /*is_uninitialized_create=*/true, this,
+                                definition_event));
+  return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
-    const LiteralSlice& literal, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
-  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
+                                                PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
           << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
+                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                          ->GetLocalDeviceState());
 
   TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(
       Shape compact_shape,
       transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
   TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtBuffer> py_buffer,
+      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
                                 /*is_uninitialized_create=*/false, this));
 
-  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtStreamExecutorBuffer::ScopedHold device_buffer(
+      py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // The host to device transfer is performed on a thread pool, mostly because
@@ -764,7 +836,7 @@
                        movable_device_buffer{device_buffer.ToClosure()},
                        literal, py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()}]() {
-    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
+    PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -788,11 +860,11 @@
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  h2d_transfer_pool()->Schedule(transfer_h2d);
-  return py_buffer;
+  thread_pool()->Schedule(transfer_h2d);
+  return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
 
-void PjRtClient::MakeCrossHostReceiveBuffers(
+void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
     absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
@@ -801,7 +873,9 @@
     return;
   }
 
-  auto local_device_or = device->GetLocalDeviceState();
+  auto local_device_or =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+          ->GetLocalDeviceState();
   if (!local_device_or.ok()) {
     notifier(local_device_or.status());
     return;
@@ -828,36 +902,40 @@
 }
 
 // Transfer the given literal to the infeed queue of the given local device.
-Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
+Status PjRtStreamExecutorDevice::TransferToInfeed(
+    const LiteralSlice& literal) const {
   // Only support infeed to local device.
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferToInfeedLocal(
       literal, local_device->device_ordinal());
 }
 
-StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
+StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
+    const Shape& shape) const {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferFromOutfeedLocal(
       shape, local_device->device_ordinal());
 }
 
-StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
+StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
+    int local_hardware_id) const {
   for (auto* device : local_devices_) {
-    if (local_device_id == device->local_device_id()) {
+    if (local_hardware_id == device->local_hardware_id()) {
       return device;
     }
   }
-  return InvalidArgument("No matching device found for local_device_id %d",
-                         local_device_id);
+  return InvalidArgument("No matching device found for local_hardware_id %d",
+                         local_hardware_id);
 }
 
-PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
-                       std::shared_ptr<TrackedDeviceBuffer> device_buffer,
-                       PjRtClient* client, PjRtDevice* device)
-    : client_(client),
+PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
+    Shape on_host_shape, Shape on_device_shape,
+    std::shared_ptr<TrackedDeviceBuffer> device_buffer, PjRtClient* client,
+    PjRtDevice* device)
+    : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
       on_host_shape_(std::move(on_host_shape)),
       on_device_shape_(std::move(on_device_shape)),
-      device_(device),
+      device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
       device_buffer_(std::move(device_buffer)),
       donation_semaphore_(/*capacity=*/1) {
   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
@@ -865,39 +943,38 @@
   }
 }
 
-PjRtBuffer::~PjRtBuffer() {
+PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
   Delete();
   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
     CHECK_EQ(holds_[i], 0);
   }
 }
 
-int64 PjRtBuffer::OnDeviceSizeInBytes() const {
-  return client_->client()
+int64 PjRtStreamExecutorBuffer::OnDeviceSizeInBytes() const {
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
       ->backend()
       .transfer_manager()
       ->GetByteSizeRequirement(on_device_shape_);
 }
 
-void PjRtBuffer::WaitForOutstandingUsageHolds() {
-  auto not_in_usage_hold = [&]() {
-    mu_.AssertHeld();
+void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
+  auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     return holds_[ScopedHold::kUsage] == 0;
   };
   mu_.Await(absl::Condition(&not_in_usage_hold));
 }
 
-void PjRtBuffer::WaitForOutstandingDonationHold() {
-  auto not_in_donation_hold = [&]() {
-    mu_.AssertHeld();
+void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
+  auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     return holds_[ScopedHold::kDonation] == 0;
   };
   mu_.Await(absl::Condition(&not_in_donation_hold));
 }
 
-StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
-    bool wait_for_operations_to_complete) {
-  tensorflow::profiler::TraceMe trace_me("PjRtBuffer::Release");
+StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
+PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
+  tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
   TrackedDeviceBuffer::StreamAndEventContainer events;
   {
@@ -919,7 +996,9 @@
     // the final set of usage events.
     events = device_buffer->LockUseAndTransferUsageEvents();
   }
-  LocalDeviceState* local_device_state = device_->local_device_state();
+  LocalDeviceState* local_device_state =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state();
   if (wait_for_operations_to_complete) {
     // Block the host until all usage events have completed. Usage events
     // dominate definition events, so this also waits for the buffer to be
@@ -972,18 +1051,18 @@
   return device_buffer;
 }
 
-void PjRtBuffer::Delete() {
+void PjRtStreamExecutorBuffer::Delete() {
   // When wait_for_operations_to_complete is false, Release should never fail.
   TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
 }
 
-bool PjRtBuffer::IsDeleted() {
+bool PjRtStreamExecutorBuffer::IsDeleted() {
   absl::MutexLock lock(&mu_);
   return device_buffer_ == nullptr;
 }
 
 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
-PjRtBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
+PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
   if (type == ScopedHold::kDonation) {
     if (device_buffer_ == nullptr) {
       return InvalidArgument("Donation requested for invalid buffer");
@@ -1017,14 +1096,13 @@
   return device_buffer_;
 }
 
-void PjRtBuffer::AcquireHoldLocked(ScopedHold* hold) {
+void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
   hold->Acquire(GetBufferForHoldLocked(hold->type()));
 }
 
-void PjRtBuffer::ConvertUsageHold(TrackedDeviceBuffer* buffer,
-                                  se::Stream* usage_stream,
-                                  std::shared_ptr<BufferSequencingEvent> event,
-                                  bool reference_held) {
+void PjRtStreamExecutorBuffer::ConvertUsageHold(
+    TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
+    std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
   absl::MutexLock lock(&mu_);
   CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
   buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
@@ -1032,7 +1110,8 @@
   --holds_[ScopedHold::kUsage];
 }
 
-void PjRtBuffer::ConfirmDonation(TrackedDeviceBuffer* device_buffer) {
+void PjRtStreamExecutorBuffer::ConfirmDonation(
+    TrackedDeviceBuffer* device_buffer) {
   {
     absl::MutexLock lock(&mu_);
     CHECK_EQ(holds_[ScopedHold::kUsage], 0);
@@ -1054,7 +1133,8 @@
   donation_semaphore_.Release(1);
 }
 
-void PjRtBuffer::DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer) {
+void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
+                                        TrackedDeviceBuffer* buffer) {
   absl::MutexLock lock(&mu_);
   CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
   CHECK_GT(holds_[type], 0);
@@ -1067,20 +1147,23 @@
   }
 }
 
-Status PjRtBuffer::CopyToHostAsync(absl::optional<xla::Layout> layout) {
+Status PjRtStreamExecutorBuffer::CopyToHostAsync(
+    absl::optional<xla::Layout> layout) {
   return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout)
       .status();
 }
 
-StatusOr<std::shared_ptr<PjRtBuffer::HostValue>>
-PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
-                                    absl::optional<xla::Layout> layout) {
+StatusOr<std::shared_ptr<PjRtStreamExecutorBuffer::HostValue>>
+PjRtStreamExecutorBuffer::CopyToHostAsyncInternal(
+    bool discard_cached_copy, absl::optional<xla::Layout> layout) {
   if (IsEmptyTuple()) {
     return InvalidArgument("CopyToHostAsync called on empty tuple");
   }
   ScopedHold device_buffer(this, ScopedHold::kUsage);
   std::shared_ptr<HostValue> host_value;
-  LocalDeviceState* local_device = device_->local_device_state();
+  LocalDeviceState* local_device =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state();
   se::Stream* stream = local_device->GetDeviceToHostStream();
   const xla::Layout& host_layout =
       layout.has_value() ? layout.value() : on_host_shape_.layout();
@@ -1122,12 +1205,16 @@
   host_value->value = std::make_shared<Literal>(host_shape);
   ShapedBuffer shaped_buffer =
       device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
-  client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
-      stream, shaped_buffer, host_value->value.get(),
-      [host_value](Status done_status) {
-        host_value->status = done_status;
-        host_value->ready.Notify();
-      });
+  tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
+      ->backend()
+      .transfer_manager()
+      ->TransferLiteralFromDevice(stream, shaped_buffer,
+                                  host_value->value.get(),
+                                  [host_value](Status done_status) {
+                                    host_value->status = done_status;
+                                    host_value->ready.Notify();
+                                  });
 
   auto usage_event = std::make_shared<BufferSequencingEvent>();
   StatusOr<EventPool::Handle> event_or =
@@ -1154,9 +1241,9 @@
   return host_value;
 }
 
-StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
+StatusOr<std::shared_ptr<Literal>> PjRtStreamExecutorBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1167,7 +1254,7 @@
   return host_value->value;
 }
 
-StatusOr<ShapedBuffer> PjRtBuffer::AsShapedBuffer() const {
+StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
   absl::MutexLock lock(&mu_);
   if (device_buffer_ == nullptr) {
     return InvalidArgument(
@@ -1176,7 +1263,8 @@
   return device_buffer_->AsShapedBuffer(on_host_shape_, on_device_shape_);
 }
 
-PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) {
+PjRtStreamExecutorBuffer::ScopedHold
+PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
   if (type == ScopedHold::kDonation) {
     // Ensure that at most one donation hold can be in progress at a time.
     donation_semaphore_.Acquire(1);
@@ -1192,12 +1280,12 @@
 
 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                    std::shared_ptr<BufferSequencingEvent>>>
-PjRtBuffer::CopyToDeviceHelper(
+PjRtStreamExecutorBuffer::CopyToDeviceHelper(
     PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
     LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
   TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtBuffer> py_buffer,
+      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
       AllocateDestinationBuffer(on_host_shape_, dst_device, dst_local_device,
                                 transfer_stream,
                                 /*is_uninitialized_create=*/false, client_));
@@ -1241,20 +1329,23 @@
       // StallStreamOnError only makes sure the destination device is ok, so
       // make sure that the src buffer remains valid until after any transfers
       // have completed.
-      device_->local_device_state()->ThenRelease(transfer_stream,
-                                                 src_device_buffer);
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state()
+          ->ThenRelease(transfer_stream, src_device_buffer);
     }
     return copy_event_or.status();
   }
 
   return std::pair<std::unique_ptr<PjRtBuffer>,
                    std::shared_ptr<BufferSequencingEvent>>(
-      std::move(py_buffer), copy_event_or.ConsumeValueOrDie());
+      std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
+      copy_event_or.ConsumeValueOrDie());
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
     PjRtDevice* dst_device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::CopyToDevice");
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorBuffer::CopyToDevice");
   if (dst_device == device_) {
     return InvalidArgument(
         "CopyToDevice cannot accept the same source and destination devices");
@@ -1265,14 +1356,20 @@
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
     return dst_device->client()->BufferFromHostBuffer(
         literal->untyped_data(), literal->shape(),
-        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
+        PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
+        dst_device);
   }
 
-  TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
-                      dst_device->GetLocalDeviceState());
+  TF_ASSIGN_OR_RETURN(
+      LocalDeviceState * dst_local_device,
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
+          ->GetLocalDeviceState());
   LocalDeviceState* transfer_local_device =
-      client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
-                                                : dst_local_device;
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->EnqueueD2DTransfersOnSrcStream()
+          ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+                ->local_device_state()
+          : dst_local_device;
   CHECK_EQ(dst_local_device->allocation_model(),
            transfer_local_device->allocation_model());
 
@@ -1310,19 +1407,24 @@
   // alternative is to ensure, before freeing the buffer, that the compute
   // stream is synchronized past the transfer, but it seems better to hold onto
   // the buffer too long than to stall the compute stream.
-  RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
+  RecordUsage(std::move(src_device_buffer),
+              tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+                  ->local_device_state(),
               transfer_local_device, event, transfer_stream,
               /*prefer_to_retain_reference=*/true);
 
   return std::move(buffer);
 }
 
-Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
-  return client_->CopyToRemoteDevice(this, serialized_descriptor);
+Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
+    absl::string_view serialized_descriptor) {
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->CopyToRemoteDevice(this, serialized_descriptor);
 }
 
-Status PjRtBuffer::BlockHostUntilReady() {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::BlockHostUntilReady");
+Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorBuffer::BlockHostUntilReady");
   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
   {
     absl::MutexLock lock(&mu_);
@@ -1332,7 +1434,9 @@
     }
     device_buffer = device_buffer_;
   }
-  LocalDeviceState* local_device_state = device_->local_device_state();
+  LocalDeviceState* local_device_state =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state();
   std::unique_ptr<se::Stream> stream;
   for (auto& event : device_buffer->definition_events()) {
     if (!event->IsComplete()) {
@@ -1365,7 +1469,7 @@
 StatusOr<TupleHandle> MakeTupleHelper(
     PjRtClient* client, LocalDeviceState* local_device,
     absl::Span<PjRtBuffer* const> py_buffers,
-    absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
+    absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
     int device_ordinal) {
   std::vector<Shape> host_shapes;
   std::vector<Shape> device_shapes;
@@ -1378,9 +1482,13 @@
   Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
 
-  se::DeviceMemoryAllocator* allocator = client->allocator();
+  se::DeviceMemoryAllocator* allocator =
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
+          ->client()
+          ->backend()
+          .transfer_manager();
   se::Stream* stream = local_device->host_to_device_stream();
   TF_ASSIGN_OR_RETURN(
       se::OwningDeviceMemory root_table_memory,
@@ -1407,7 +1515,8 @@
       MaybeOwningDeviceMemory(std::move(root_table_memory)));
   ++input_iterator;
   // Then set each sub-tuple in turn from the parameters.
-  for (const PjRtBuffer::ScopedHold& device_buffer : device_buffers) {
+  for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
+       device_buffers) {
     device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
                              allocator);
   }
@@ -1436,22 +1545,14 @@
   std::shared_ptr<TrackedDeviceBuffer> out_buffer =
       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
                                                   {definition_event});
-  auto pjrt_buffer = absl::make_unique<PjRtBuffer>(
+  auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
       result_buffer->on_host_shape(), result_buffer->on_device_shape(),
       std::move(out_buffer), client, device);
   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
               definition_event, local_device->compute_stream(),
               /*prefer_to_retain_reference=*/false);
-  return pjrt_buffer;
+  return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
 }
-
-static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
-  auto it = client.id_to_device().find(device_id);
-  CHECK(it != client.id_to_device().end())
-      << "Unknown device id: " << device_id;
-  return it->second;
-}
-
 }  // namespace
 
 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1459,7 +1560,8 @@
     bool parameter_is_tupled_arguments,
     std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-    std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
+    std::vector<PjRtDevice*> addressable_devices,
+    PjRtStreamExecutorClient* client)
     : client_(client),
       device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1482,7 +1584,7 @@
     VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
             << device_assignment_->ToString();
     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
-    CHECK_LE(addressable_devices_.size(), client_->local_device_count())
+    CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
         << "Inconsistent local device count.";
     num_partitions = device_assignment_->computation_count();
   }
@@ -1495,60 +1597,6 @@
   }
 }
 
-StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
-    const HloModule& module, bool tuple_inputs) {
-  HloComputation* computation = module.entry_computation();
-  int number_of_parameters = [&]() -> int {
-    if (tuple_inputs) {
-      CHECK_EQ(computation->num_parameters(), 1);
-      const Shape& input_tuple_shape =
-          computation->parameter_instruction(0)->shape();
-      CHECK(input_tuple_shape.IsTuple());
-      return input_tuple_shape.tuple_shapes_size();
-    } else {
-      return computation->num_parameters();
-    }
-  }();
-  // If any buffer in a parameter is aliased we will donate the entire input
-  // parameter.
-  absl::flat_hash_set<int> parameters_to_donate;
-  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
-  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
-      [&](const ShapeIndex& output_index,
-          const HloInputOutputAliasConfig::Alias& alias) {
-        if (tuple_inputs) {
-          if (alias.parameter_number != 0) {
-            return InvalidArgument(
-                "Unexpected parameter number %d in alias config with tupled "
-                "inputs",
-                alias.parameter_number);
-          }
-          const ShapeIndex& index = alias.parameter_index;
-          if (!index.empty()) {
-            int this_parameter = index.data()[0];
-            if (this_parameter >= number_of_parameters) {
-              return InvalidArgument(
-                  "Unexpected parameter index %s in alias config with tupled "
-                  "inputs and %d parameters",
-                  index.ToString(), number_of_parameters);
-            }
-            parameters_to_donate.insert(this_parameter);
-          }
-        } else {
-          int this_parameter = alias.parameter_number;
-          if (this_parameter >= number_of_parameters) {
-            return InvalidArgument(
-                "Unexpected parameter number %d in alias config without tupled "
-                "inputs and %d parameters",
-                this_parameter, number_of_parameters);
-          }
-          parameters_to_donate.insert(this_parameter);
-        }
-        return Status::OK();
-      }));
-  return parameters_to_donate;
-}
-
 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
   parameters_that_must_be_donated_.reserve(executables_.size());
   for (auto& executable : executables_) {
@@ -1581,10 +1629,10 @@
 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     int device_ordinal, const ExecuteOptions& options,
     absl::Span<PjRtBuffer* const> argument_handles,
-    absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
+    absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
   std::vector<ExecutionInput> execution_inputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
   absl::optional<TupleHandle> tuple_handle;
@@ -1607,8 +1655,10 @@
           execution_input.MutableBuffers()->begin();
       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
           execution_input.MutableBuffers()->end();
-      device_buffers[i].AddToInput(&input_iterator, iterator_end,
-                                   &execution_input, client_->allocator());
+      device_buffers[i].AddToInput(
+          &input_iterator, iterator_end, &execution_input,
+          tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->allocator());
       CHECK(input_iterator == iterator_end);
     }
   }
@@ -1626,10 +1676,13 @@
 StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
-    PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+    PjRtDevice* device,
+    std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
     std::shared_ptr<DeviceAssignment> device_assignment) const {
-  int device_ordinal = device->local_device_state()->device_ordinal();
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                           ->local_device_state()
+                           ->device_ordinal();
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
       run_id.ToInt());
@@ -1639,7 +1692,8 @@
   absl::flat_hash_set<BufferSequencingEvent*> events;
   device_buffers->reserve(argument_handles.size());
   for (int i = 0; i < argument_handles.size(); ++i) {
-    PjRtBuffer* handle = argument_handles[i];
+    auto* handle =
+        tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
     if (handle->device() != device) {
       return InvalidArgument(
           "Buffer passed to Execute() as argument %d to replica %d is on "
@@ -1648,9 +1702,10 @@
     }
     bool must_donate = MustDonateParameter(executable_idx, i);
     device_buffers->emplace_back(handle->GetBufferWithHold(
-        must_donate ? PjRtBuffer::ScopedHold::kDonation
-                    : PjRtBuffer::ScopedHold::kUsage));
-    PjRtBuffer::ScopedHold& device_buffer = device_buffers->back();
+        must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
+                    : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
+    PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
+        device_buffers->back();
     if (!device_buffer.ok()) {
       return InvalidArgument(
           "Invalid buffer passed to Execute() as argument %d to replica %d: "
@@ -1765,7 +1820,7 @@
     std::shared_ptr<BufferSequencingEvent> definition_event,
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
     int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
@@ -1802,7 +1857,7 @@
   if (device == nullptr) {
     CHECK(device_assignment_ != nullptr);
     const int device_id = (*device_assignment_)(replica, partition);
-    device = LookupDevice(*client_, device_id);
+    TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
     device_assignment = device_assignment_;
   } else {
     CHECK(device_assignment_ == nullptr);
@@ -1814,7 +1869,9 @@
   }
 
   CHECK_EQ(device->host_id(), client_->host_id());
-  int device_ordinal = device->local_device_state()->device_ordinal();
+  int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                           ->local_device_state()
+                           ->device_ordinal();
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
   VLOG(3) << "Replica " << replica << ", partition " << partition
           << " mapped to device ordinal for execution: " << device_ordinal;
@@ -1822,7 +1879,7 @@
   // SPMD sharding produces a single executable for multiple partitions.
   int executable_idx = executables_.size() > 1 ? partition : 0;
 
-  std::vector<PjRtBuffer::ScopedHold> device_buffers;
+  std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
   device_buffers.reserve(argument_handles.size());
   StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
       argument_handles, replica, partition, executable_idx, run_id, options,
@@ -1836,14 +1893,14 @@
   ScopedShapedBuffer result_buffer =
       result_buffer_or_status.ConsumeValueOrDie();
 
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   se::Stream* stream = device_state->compute_stream();
   StatusOr<EventPool::Handle> event_or =
       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
   if (!event_or.ok()) {
     StallStreamOnError(device_state, stream);
-    for (PjRtBuffer::ScopedHold& b : device_buffers) {
-      if (b.type() == PjRtBuffer::ScopedHold::kDonation) {
+    for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
+      if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
         // Even though there was an error we need to call ConfirmDonation, which
         // renders b invalid, since the computation has been enqueued and b has
         // been donated.
@@ -1858,17 +1915,17 @@
       MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
                         definition_event, device);
 
-  for (PjRtBuffer::ScopedHold& b : device_buffers) {
+  for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
     // prefer_to_retain_reference=false because when using the
     // ComputeSynchronized allocation model we don't need to retain a reference
     // to the device_buffer during execution because by definition the compute
     // stream is synchronized past the execution.
-    if (b.type() == PjRtBuffer::ScopedHold::kUsage) {
+    if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
       RecordUsage(std::move(b), device_state, device_state, definition_event,
                   stream,
                   /*prefer_to_retain_reference=*/false);
     } else {
-      CHECK(b.type() == PjRtBuffer::ScopedHold::kDonation);
+      CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
       b.ConfirmDonation();
     }
   }
@@ -1879,7 +1936,7 @@
 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
 PjRtStreamExecutorExecutable::Execute(
     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
-    const ExecuteOptions& options) const {
+    const ExecuteOptions& options) {
   if (device_assignment_ == nullptr) {
     return InvalidArgument("Execute expects a non-null device_assignment");
   }
@@ -1922,7 +1979,9 @@
       const int replica = addressable_device_logical_ids_[i].replica;
       const int partition = addressable_device_logical_ids_[i].partition;
       PjRtDevice* device = addressable_devices_[i];
-      const LocalDeviceState& device_state = *device->local_device_state();
+      const LocalDeviceState& device_state =
+          *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+               ->local_device_state();
       device_state.execute_thread()->Schedule([&, replica, partition, i] {
         results[i] = ExecuteHelper(argument_handles[i], replica, partition,
                                    run_id, options);
@@ -1988,7 +2047,7 @@
 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtStreamExecutorExecutable::ExecuteSharded(
     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
-    const ExecuteOptions& options) const {
+    const ExecuteOptions& options) {
   if (device_assignment_ == nullptr) {
     return InvalidArgument("ExecuteShard expects a non-null device_assignment");
   }
@@ -2011,7 +2070,7 @@
 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtStreamExecutorExecutable::ExecutePortable(
     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
-    const ExecuteOptions& options) const {
+    const ExecuteOptions& options) {
   if (device_assignment_ != nullptr) {
     return InvalidArgument("ExecutePortable gets a non-portable executable");
   }
@@ -2044,98 +2103,14 @@
   return std::move(modules);
 }
 
-namespace {
-
-StatusOr<Shape> GetShardedShape(const Shape& shape,
-                                const OpSharding& sharding) {
-  if (sharding.type() == OpSharding::TUPLE) {
-    if (!shape.IsTuple()) {
-      return InvalidArgument(
-          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
-          sharding.DebugString(), shape.ToString());
-    }
-    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
-      return InvalidArgument(
-          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
-          " (OpSharding: %s, shape: %s)",
-          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
-          sharding.DebugString(), shape.ToString());
-    }
-    std::vector<Shape> sharded_subshapes;
-    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
-      TF_ASSIGN_OR_RETURN(
-          Shape sharded_subshape,
-          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
-      sharded_subshapes.emplace_back(std::move(sharded_subshape));
-    }
-    return ShapeUtil::MakeTupleShape(sharded_subshapes);
-  }
-  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
-                      HloSharding::FromProto(sharding));
-  return hlo_sharding.TileShape(shape);
-}
-
-StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
-  const Shape unsharded_shape(instr.shape());
-  Shape sharded_shape;
-  if (instr.has_sharding()) {
-    TF_ASSIGN_OR_RETURN(sharded_shape,
-                        GetShardedShape(unsharded_shape, instr.sharding()));
-  } else {
-    sharded_shape = unsharded_shape;
-  }
-  LayoutUtil::ClearLayout(&sharded_shape);
-  return sharded_shape;
-}
-
-// Returns sharded (argument shapes, result shape) without layouts.
-StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
-    const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
-                      computation.GetProgramShape());
-  std::vector<Shape> arg_shapes;
-  arg_shapes.resize(program_shape.parameters_size());
-  Shape result_shape;
-  for (const HloComputationProto& comp : computation.proto().computations()) {
-    if (comp.id() != computation.proto().entry_computation_id()) {
-      continue;
-    }
-    for (const HloInstructionProto& instr : comp.instructions()) {
-      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
-        if (instr.parameter_number() >= program_shape.parameters_size()) {
-          return InvalidArgument(
-              "Got invalid parameter number %d, expected %d parameters",
-              instr.parameter_number(), program_shape.parameters_size());
-        }
-        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
-                            GetShardedShape(instr));
-      }
-      if (instr.id() == comp.root_id()) {
-        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
-          return InvalidArgument("Found multiple root instructions");
-        }
-        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
-      }
-    }
-  }
-  for (int i = 0; i < arg_shapes.size(); ++i) {
-    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
-      return InvalidArgument("Couldn't find parameter %d", i);
-    }
-  }
-  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
-    return InvalidArgument("Couldn't find root instruction");
-  }
-  return std::make_pair(arg_shapes, result_shape);
-}
-
-}  // namespace
-
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     const XlaComputation& computation, CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
+  if (!build_options.compile_thread_pool()) {
+    build_options.set_compile_thread_pool(thread_pool());
+  }
   if (!build_options.device_allocator()) {
     build_options.set_device_allocator(allocator());
   }
@@ -2143,87 +2118,23 @@
   int num_replicas;
   int num_partitions;
   std::shared_ptr<DeviceAssignment> device_assignment;
-  if (options.compile_portable_executable) {
-    if (build_options.has_device_assignment()) {
-      return InvalidArgument(
-          "CompileOptions requests portable executable but "
-          "ExecutableBuildOptions includes a device assignment");
-    }
-    num_replicas = 1;
-    num_partitions = 1;
-  } else {
-    if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
-      TF_ASSIGN_OR_RETURN(
-          DeviceAssignment device_assignment,
-          GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                     build_options.num_partitions()));
-      build_options.set_device_assignment(device_assignment);
-    }
-    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
-            << build_options.device_assignment().ToString();
-    num_replicas = build_options.device_assignment().replica_count();
-    num_partitions = build_options.device_assignment().computation_count();
-    device_assignment =
-        std::make_shared<DeviceAssignment>(build_options.device_assignment());
-  }
+  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
+      options.compile_portable_executable, &options.executable_build_options,
+      [this](int num_replicas, int num_partitions) {
+        return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
+      },
+      &num_replicas, &num_partitions, &device_assignment));
 
-  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
-                      computation.GetProgramShape());
-  if (!options.argument_layouts) {
-    options.argument_layouts = program_shape.parameters();
-    for (Shape& shape : *options.argument_layouts) {
-      LayoutUtil::ClearLayout(&shape);
-    }
-  } else if (options.argument_layouts->size() !=
-             program_shape.parameters_size()) {
-    return InvalidArgument(
-        "CompileOptions specify %d argument layouts, but computation has %d "
-        "arguments",
-        options.argument_layouts->size(), program_shape.parameters_size());
-  }
   std::vector<const Shape*> argument_layout_pointers;
-  argument_layout_pointers.reserve(options.argument_layouts->size());
-
-  // Assign a default layout based on `sharded_shape` to any array subshapes in
-  // `dst_shape` that are missing layouts.
-  auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
-                                                  Shape* dst_shape) {
-    return ShapeUtil::ForEachMutableSubshapeWithStatus(
-        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
-          if (subshape->IsArray() && !subshape->has_layout()) {
-            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
-            const Shape& sharded_subshape =
-                ShapeUtil::GetSubshape(sharded_shape, idx);
-            LayoutUtil::SetToDefaultLayout(subshape);
-            TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
-                                                  .transfer_manager()
-                                                  ->ChooseCompactLayoutForShape(
-                                                      sharded_subshape));
-            *subshape->mutable_layout() = layout.layout();
-          }
-          return Status::OK();
-        });
-  };
-  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
-                      GetShardedProgramShapes(computation));
-
-  CHECK_EQ(sharded_shapes.first.size(), options.argument_layouts->size());
-  for (int i = 0; i < options.argument_layouts->size(); ++i) {
-    Shape* layout = &(*options.argument_layouts)[i];
-    argument_layout_pointers.push_back(layout);
-    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
-  }
-
-  Shape result_layout;
-  if (build_options.result_layout()) {
-    result_layout = *build_options.result_layout();
-  } else {
-    result_layout = program_shape.result();
-    LayoutUtil::ClearLayout(&result_layout);
-  }
-  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
-  build_options.set_result_layout(result_layout);
+  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
+      computation,
+      [local_client = client()](Shape shape) {
+        return local_client->backend()
+            .transfer_manager()
+            ->ChooseCompactLayoutForShape(shape);
+      },
+      options.argument_layouts, &options.executable_build_options,
+      &argument_layout_pointers));
 
   // Find devices that are addressable by this client/task.
   std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
@@ -2234,7 +2145,7 @@
     for (int replica = 0; replica < num_replicas; ++replica) {
       for (int partition = 0; partition < num_partitions; ++partition) {
         int device_id = (*device_assignment)(replica, partition);
-        PjRtDevice* device = LookupDevice(*this, device_id);
+        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
         if (device->host_id() != host_id()) {
           VLOG(3) << "Non-local device: " << device_id;
           continue;
@@ -2254,7 +2165,7 @@
 
     if (build_options.device_ordinal() < 0) {
       build_options.set_device_ordinal(
-          addressable_devices.front()->local_device_state()->device_ordinal());
+          addressable_devices.front()->local_hardware_id());
     }
   }
 
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
new file mode 100644
index 0000000..2f55a71
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
@@ -0,0 +1,780 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
+#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/synchronization/notification.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class PjRtStreamExecutorDevice : public PjRtDevice {
+ public:
+  explicit PjRtStreamExecutorDevice(
+      int id, std::unique_ptr<LocalDeviceState> local_device_state,
+      std::string device_kind, int host_id = 0)
+      : id_(id),
+        device_ordinal_(
+            local_device_state ? local_device_state->device_ordinal() : -1),
+        local_device_state_(std::move(local_device_state)),
+        host_id_(host_id),
+        device_kind_(std::move(device_kind)) {}
+  ~PjRtStreamExecutorDevice() override {}
+
+  // Must set client exactly once.
+  void SetClient(PjRtClient* client) {
+    CHECK(client_ == nullptr);
+    client_ = client;
+  }
+
+  int host_id() const override { return host_id_; }
+
+  // Return `platform_id` from client.
+  PjRtPlatformId platform_id() const;
+
+  // Return `platform_name` from client.
+  const std::string& platform_name() const;
+
+  PjRtClient* client() const override { return client_; }
+
+  int id() const override { return id_; }
+
+  bool IsAddressable() const override { return device_ordinal_ != -1; }
+
+  int local_hardware_id() const override { return device_ordinal_; }
+
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns nullptr if the device is
+  // not local to this host.
+  LocalDeviceState* local_device_state() const {
+    return local_device_state_.get();
+  }
+
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns an error if the device
+  // is not local to this host.
+  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
+
+  const std::string& device_kind() const override { return device_kind_; }
+
+  std::string DebugString() const override;
+
+  Status TransferToInfeed(const LiteralSlice& literal) const override;
+
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
+
+ private:
+  const int id_;
+  const int device_ordinal_;  // -1 means not local.
+  const std::unique_ptr<LocalDeviceState> local_device_state_;
+  const int host_id_;
+  const std::string device_kind_;
+  PjRtClient* client_ = nullptr;
+};
+
+class PjRtStreamExecutorClient : public PjRtClient {
+ public:
+  // `allocator` may be null, in which case the platform default allocator is used.
+  explicit PjRtStreamExecutorClient(
+      std::string platform_name, LocalClient* client,
+      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
+      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
+  ~PjRtStreamExecutorClient() override = default;
+
+  int host_id() const override { return host_id_; }
+
+  int device_count() const override { return devices_.size(); }
+  int addressable_device_count() const override {
+    return local_devices_.size();
+  }
+  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
+  absl::Span<PjRtDevice* const> local_devices() const override {
+    return local_devices_;
+  }
+
+  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
+    auto it = id_to_device_.find(device_id);
+    if (it != id_to_device_.end()) {
+      return it->second;
+    }
+    return InvalidArgument("No matching device found for device_id %d",
+                           device_id);
+  }
+
+  StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const override;
+
+  PjRtPlatformId platform_id() const override { return platform_id_; }
+  const std::string& platform_name() const override { return platform_name_; }
+
+  // Most platforms expect device-to-device transfers to be enqueued on the
+  // source d2d stream, but some platforms use the destination d2d stream. This
+  // function specifies which one the platform expects.
+  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
+
+  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
+      int num_replicas, int num_partitions) const override;
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options) override;
+
+  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const override {
+    return absl::optional<std::string>();
+  }
+
+  std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
+
+  // Creates a buffer on the device without initializing or copying any data.
+  // An optional `definition_event` may be specified that can be used to
+  // ensure the buffer isn't referenced until some external mechanism has
+  // initialized the data.
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device) override;
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device,
+      std::shared_ptr<BufferSequencingEvent> definition_event);
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device) override;
+
+  void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier) override;
+
+  StatusOr<ChannelHandle> CreateChannelHandle() override {
+    return client()->CreateChannelHandle();
+  }
+  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
+    return client()->CreateDeviceToHostChannelHandle();
+  }
+  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
+    return client()->CreateHostToDeviceChannelHandle();
+  }
+
+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
+                local_devices_.at(device_ordinal))
+                ->local_device_state();
+  }
+  LocalClient* client() const { return client_; }
+  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
+  tensorflow::Allocator* host_memory_allocator() const {
+    return host_memory_allocator_.get();
+  }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
+
+  gpu::GpuExecutableRunOptions* gpu_run_options() const {
+    return gpu_run_options_.get();
+  }
+
+  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }
+
+ protected:
+  friend class PjRtStreamExecutorBuffer;
+  virtual void EnqueueCrossHostReceive(
+      std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
+      std::shared_ptr<BufferSequencingEvent> definition_event,
+      PjRtCrossHostRecvNotifier&& notifier) const {
+    notifier(Unimplemented("Cross host receives not implemented."));
+  }
+
+  virtual Status CopyToRemoteDevice(
+      PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
+    return Unimplemented("Cross host sends not implemented.");
+  }
+
+  const PjRtPlatformId platform_id_;
+  const std::string platform_name_;
+  LocalClient* client_;
+
+  // Allocator to be used for staging memory transfers to devices.
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+
+  // Includes all devices, including non-local devices on multi-host platforms.
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
+  // Pointers to `owned_devices_`.
+  std::vector<PjRtDevice*> devices_;
+  // Maps Device::id() to the corresponding Device. Includes all devices.
+  std::map<int, PjRtDevice*> id_to_device_;
+  // Local devices indexed by local device ordinal.
+  std::vector<PjRtDevice*> local_devices_;
+  int host_id_;
+
+  se::DeviceMemoryAllocator* allocator_;
+  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
+
+  // Should we always prefer to stage host-to-device transfers via memory
+  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
+  // transfer via pinned memory.
+  bool should_stage_host_to_device_transfers_;
+
+  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
+
+  tensorflow::thread::ThreadPool thread_pool_;
+};
+
+// Converts a 2D set of Device objects indexed by [replica][partition] into an
+// xla::DeviceAssignment.
+StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
+    absl::Span<const std::vector<PjRtDevice*>> devices);
+
+class PjRtStreamExecutorBuffer : public PjRtBuffer {
+ public:
+  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold
+  // may not outlive its parent PjRtStreamExecutorBuffer.
+  //
+  // There are three types of hold, as follows:
+  //
+  // 1) Usage hold: a transient hold while an operation using the buffer is
+  //    being enqueued onto a stream.
+  // A client acquires a usage hold by calling
+  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
+  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the
+  // hold should be released using a call to ConvertUsageHold. If the ScopedHold
+  // is deleted without ConvertUsageHold being called, e.g., on error, the hold
+  // is dropped. It is legal to drop a usage hold instead of calling
+  // ConvertUsageHold, even if the buffer was successfully enqueued, as long as
+  // the client ensures that all necessary synchronization has been done.
+  //
+  // 2) External hold: a potentially long-lived hold while the buffer is being
+  //    shared by an external framework, e.g., NumPy.
+  // A client acquires an external hold by calling
+  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
+  // wrapper GetBufferWithExternalReference and releases it by deleting the
+  // ScopedHold. The external framework should not modify the underlying buffer
+  // unless it is confident via its own synchronization that modifications do
+  // not race with reads from the PjRtStreamExecutorBuffer.
+  //
+  // 3) Donation hold: a transient hold while an execution that donates the
+  //    buffer is being enqueued onto the compute stream.
+  // A client acquires a donation hold by calling
+  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
+  // completes successfully the hold should be released using a call to
+  // ConfirmDonation after which the buffer is invalid. If the ScopedHold is
+  // deleted without ConfirmDonation being called, e.g., on error, the hold is
+  // dropped and the buffer remains valid. If the buffer is successfully
+  // enqueued the client *must* call ConfirmDonation.
+  //
+  // Donation holds behave like exclusive write locks: when a donation hold
+  // has been acquired, any attempt to acquire another hold of any type will
+  // block until the donation hold is dropped or confirmed. Acquiring a donation
+  // hold will fail with an error if there is any outstanding external hold, and
+  // will block if there are any outstanding usage holds until those holds are
+  // dropped or converted.
+  //
+  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
+  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
+  // block until all usage and donation holds are either deleted or
+  // converted/confirmed.
+  class ScopedHold {
+   public:
+    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
+    // Use a State enum instead of encoding the state in an error Status to
+    // avoid creating Status values in non-error cases. Creating a Status
+    // entails several allocations and can add O(us) to every use of a hold.
+    enum State {
+      kUninitialized = 0,
+      kValid,
+      kMoved,
+      kConverted,
+      kReleased,
+      kDonated,
+      kError
+    };
+
+    ~ScopedHold();
+    ScopedHold(ScopedHold&& other);
+    ScopedHold(const ScopedHold&) = delete;
+    ScopedHold& operator=(const ScopedHold&) = delete;
+
+    Type type() const { return type_; }
+
+    Status status() const {
+      // Lazily create Status values only when they are requested.
+      switch (state_) {
+        case kUninitialized:
+          return InvalidArgument("Buffer has not been initialized");
+        case kValid:
+          return Status::OK();
+        case kMoved:
+          return InvalidArgument("Buffer has been moved.");
+        case kConverted:
+          return InvalidArgument("Buffer has been converted");
+        case kReleased:
+          return InvalidArgument("Buffer has been released");
+        case kDonated:
+          return InvalidArgument("Buffer has been donated");
+        case kError:
+          return buffer_or_.status();
+        default:
+          CHECK(false) << "Unexpected state value " << state_;
+      }
+    }
+    bool ok() const { return state_ == kValid; }
+
+    // Access to the underlying device buffer storage. Requires this->ok().
+    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
+      CHECK_EQ(state_, kValid);
+      CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
+      return buffer_or_.ValueOrDie();
+    }
+    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
+    const TrackedDeviceBuffer& operator*() const { return *buffer(); }
+
+    // Converts the hold into a usage event. Only valid for holds of type
+    // kUsage.
+    //
+    //   usage_stream:   the stream that the buffer was used on.
+    //   event:          an event that has been recorded on usage_stream after
+    //                   the buffer was used.
+    //   reference_held: true if and only if the caller has caused a
+    //                   reference to this->buffer() to stay live until after
+    //                   the host is sure that the usage (transfer or execution)
+    //                   has completed.
+    void ConvertUsageHold(se::Stream* usage_stream,
+                          std::shared_ptr<BufferSequencingEvent> event,
+                          bool reference_held);
+
+    // Confirms that the buffer was successfully donated to an execution.
+    // Only valid for holds of type kDonation. Causes the buffer to become
+    // invalid.
+    void ConfirmDonation();
+
+    // Adds the held device buffers, in order, to 'iterator'. Used to add the
+    // buffers to an ExecutionInput. We require but do not verify that
+    // 'iterator' when passed in is pointing to a sub-tuple of the
+    // ExecutionInput whose on_device_shape matches that of the
+    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
+    // out of bounds. Donates the device buffers if the hold type is kDonation,
+    // otherwise retains ownership of the device buffers.
+    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
+                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
+                    ExecutionInput* execution_input,
+                    se::DeviceMemoryAllocator* allocator) const;
+
+   private:
+    friend class PjRtStreamExecutorBuffer;
+    friend class PjRtStreamExecutorClient;
+
+    // Helper struct that makes it possible to move a ScopedHold through a
+    // closure.
+    using ForClosure =
+        std::tuple<PjRtStreamExecutorBuffer*, Type, State,
+                   StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
+
+    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
+        : parent_(parent), type_(type), state_(kUninitialized) {}
+    explicit ScopedHold(const ForClosure& closure_helper)
+        : parent_(std::get<0>(closure_helper)),
+          type_(std::get<1>(closure_helper)),
+          state_(std::get<2>(closure_helper)),
+          buffer_or_(std::get<3>(closure_helper)) {
+      // Check the buffer is not in an error state.
+      CHECK(buffer_or_.ValueOrDie() != nullptr);
+    }
+
+    // Sets buffer state.
+    void SetState(State state) { state_ = state; }
+
+    // Sets buffer_or_. Called by parent_ to initialize the hold.
+    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
+    // Releases the contents of *this, so *this can subsequently be
+    // deleted without releasing the parent's hold. Should be passed to the
+    // appropriate constructor of another ScopedHold, e.g., when a hold must be
+    // passed through a closure that is incompatible with std::move.
+    ForClosure ToClosure();
+
+    PjRtStreamExecutorBuffer* const parent_;
+    const Type type_;
+
+    // There is an invariant that if ok() then
+    // buffer_or_.ValueOrDie() != nullptr.
+    State state_;
+    StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
+  };
+
+  PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
+                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
+                           PjRtClient* client, PjRtDevice* device);
+  ~PjRtStreamExecutorBuffer() override;
+
+  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
+  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
+  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
+  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
+
+  const Shape& on_host_shape() const override { return on_host_shape_; }
+  const Shape& on_device_shape() const override { return on_device_shape_; }
+  PjRtStreamExecutorDevice* device() const override { return device_; }
+  PjRtPlatformId platform_id() const { return client_->platform_id(); }
+  const std::string& platform_name() const { return client_->platform_name(); }
+  PjRtStreamExecutorClient* client() const override { return client_; }
+  bool IsEmptyTuple() const {
+    return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
+  }
+
+  int64 OnDeviceSizeInBytes() const override;
+
+  // Implements PjRtBuffer::ExternalReferenceHold as a wrapped
+  // ScopedHold::kExternalReference.
+  class ScopedHoldAsExternalReference
+      : public PjRtBuffer::ExternalReferenceHold {
+   public:
+    explicit ScopedHoldAsExternalReference(ScopedHold hold)
+        : external_reference_(std::move(hold)) {
+      // Check the member rather than `hold`, which has just been moved from.
+      CHECK(external_reference_.type() == ScopedHold::kExternalReference);
+    }
+
+    ~ScopedHoldAsExternalReference() override = default;
+
+    void* OpaqueDeviceMemoryDataPointer() const override {
+      return external_reference_->device_memory().front().opaque();
+    }
+
+   private:
+    ScopedHold external_reference_;
+  };
+  StatusOr<std::unique_ptr<ExternalReferenceHold>> AcquireExternalReference()
+      override;
+
+  StatusOr<absl::optional<std::shared_ptr<void>>> ReleaseDeviceMemoryOwnership(
+      bool wait_for_operations_to_complete) override;
+
+  using PjRtBuffer::ToLiteral;
+  StatusOr<std::shared_ptr<Literal>> ToLiteral(
+      bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
+
+  using PjRtBuffer::CopyToHostAsync;
+  Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
+
+  // Drops the buffer's reference to its associated device memory, leaving the
+  // buffer in an invalid state. The memory will be freed lazily when all async
+  // operations using the buffer have completed, according to the allocation
+  // semantics of the underlying platform. Delete may briefly block if another
+  // thread is in the process of enqueuing an operation on this buffer, but it
+  // will never block for a stream operation to complete. If an external
+  // framework holds a reference to the TrackedDeviceBuffer via
+  // GetBufferWithExternalReference, the memory will not be freed until the
+  // external framework drops the reference.
+  void Delete() override;
+
+  bool IsDeleted() override;
+
+  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
+  // PjRtBuffer retains ownership of the device buffers.
+  StatusOr<ShapedBuffer> AsShapedBuffer() const;
+
+  // Returns a hold on the TrackedDeviceBuffer holding the device
+  // buffers. See comment on ScopedHold.
+  ScopedHold GetBufferWithHold(ScopedHold::Type type);
+  ScopedHold GetBufferWithUsageHold() {
+    return GetBufferWithHold(ScopedHold::kUsage);
+  }
+  ScopedHold GetBufferWithExternalReference() {
+    return GetBufferWithHold(ScopedHold::kExternalReference);
+  }
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
+      PjRtDevice* dst_device) override;
+
+  Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;
+
+  Status BlockHostUntilReady() override;
+
+  bool IsOnCpu() const override;
+
+  // Similar to Delete, drops the buffer's reference to its associated device
+  // memory, leaving the buffer in an invalid state, but returns the
+  // TrackedDeviceBuffer rather than freeing the device memory, so that another
+  // framework can take ownership of it. The buffer returned from Release may
+  // be safely dropped at any time even if it still has pending async
+  // operations. The client should call BlockHostUntilReady before calling
+  // Release with wait_for_operations_to_complete=false, to ensure that the host
+  // has synchronized past any outstanding write operations to the buffer. If
+  // wait_for_operations_to_complete=true the host will block until any
+  // potentially outstanding asynchronous operations have completed before
+  // returning, in which case it is safe to read or mutate the returned buffer.
+  // If the buffer was shared via an external reference it is the client's
+  // responsibility that accesses via that reference do not interfere with
+  // accesses via the buffer returned from Release.
+  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
+      bool wait_for_operations_to_complete);
+
+ private:
+  friend class PjRtClient;
+  // The cached value of the buffer on the host, produced either from a call to
+  // CopyToHost or from a call to ToLiteral. Once a value has been fetched to
+  // the host, it persists until Delete() is called or the PjRtBuffer is destroyed.
+  struct HostValue {
+    absl::Notification ready;
+    // status and value are valid for reading only after `ready` has been
+    // notified.
+    Status status;
+    std::shared_ptr<Literal> value;
+  };
+
+  // Blocks in mu_.Await until there are no more usage holds.
+  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Blocks in mu_.Await until there is no donation hold.
+  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
+  // device_buffer_ is null, or if a donation hold was requested when there is
+  // an outstanding external hold.
+  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
+      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
+  // Initializes hold with an error if device_buffer_ is null, or if a donation
+  // hold was requested when there is an outstanding external hold.
+  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
+  // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
+  // device_buffer_ was successfully enqueued on a stream.
+  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
+                        std::shared_ptr<BufferSequencingEvent> event,
+                        bool reference_held);
+
+  // Drops a donation hold and makes *this invalid for further use. Does a
+  // sanity check that buffer==device_buffer_. Called after device_buffer_ was
+  // successfully donated to an execution.
+  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
+
+  // Initiates a copy of the buffer to the host. Does not block waiting for
+  // the transfer to complete. A host value is returned and, if
+  // `discard_cached_copy` is false, stored in an internal buffer so that
+  // future calls don't have to transfer the data from the device again. If a
+  // layout is passed then a literal of this layout will be returned and
+  // possibly cached.
+  StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
+      bool discard_cached_copy, absl::optional<xla::Layout> layout);
+
+  // Drops a hold without taking any other action. Does a sanity check that
+  // buffer==device_buffer_ or device_buffer_==nullptr.
+  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
+
+  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
+                     std::shared_ptr<BufferSequencingEvent>>>
+  CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
+                     LocalDeviceState* transfer_local_device,
+                     se::Stream* transfer_stream,
+                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
+
+  PjRtStreamExecutorClient* const client_;
+  const Shape on_host_shape_;
+  const Shape on_device_shape_;
+  PjRtStreamExecutorDevice* const device_;
+
+  mutable absl::Mutex mu_;
+  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
+  absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
+      TF_GUARDED_BY(mu_);
+  std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
+  // Count of holds on the buffer.
+  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
+  // Semaphore used to ensure there is only one outstanding donation hold.
+  Semaphore donation_semaphore_;
+};
+
+// Wraps one or more XLA LocalExecutables (one per partition, as specified by
+// the build options).
+class PjRtStreamExecutorExecutable : public PjRtExecutable {
+ public:
+  PjRtStreamExecutorExecutable(
+      std::vector<std::unique_ptr<LocalExecutable>> executables,
+      bool parameter_is_tupled_arguments,
+      std::shared_ptr<DeviceAssignment> device_assignment,
+      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
+      std::vector<PjRtDevice*> addressable_devices,
+      PjRtStreamExecutorClient* client);
+
+  ~PjRtStreamExecutorExecutable() override = default;
+
+  PjRtStreamExecutorClient* client() const override { return client_; }
+
+  const std::string& name() const override;
+
+  int num_replicas() const override {
+    return executables_[0]->build_options().num_replicas();
+  }
+
+  int num_partitions() const override {
+    return executables_[0]->build_options().num_partitions();
+  }
+
+  int64 SizeOfGeneratedCodeInBytes() const override {
+    int64 size = 0;
+    for (auto& executable : executables_) {
+      size += executable->executable()->SizeOfGeneratedCodeInBytes();
+    }
+    return size;
+  }
+
+  const DeviceAssignment& device_assignment() const override {
+    return *device_assignment_;
+  }
+
+  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
+      const override {
+    return addressable_device_logical_ids_;
+  }
+
+  absl::Span<PjRtDevice* const> addressable_devices() const override {
+    return addressable_devices_;
+  }
+
+  // Return an HloModule per partition.
+  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
+      const override;
+
+  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
+      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
+      const ExecuteOptions& options) override;
+
+  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
+      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+      const ExecuteOptions& options) override;
+
+  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
+      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+      const ExecuteOptions& options) override;
+
+  void Delete() override { executables_.clear(); }
+
+  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
+    return executables_;
+  }
+
+ protected:
+  bool parameter_is_tupled_arguments() const {
+    return parameter_is_tupled_arguments_;
+  }
+
+ private:
+  friend class PjRtStreamExecutorClient;
+  // Initializes information about which arguments to which executables must be
+  // donated due to aliases that were specified by the computation.
+  Status SetUpDonation(bool tuple_inputs);
+
+  virtual bool MustDonateParameter(int executable_idx, int parameter) const;
+
+  virtual StatusOr<std::vector<ExecutionInput>>
+  MakeExecutionInputsAndWaitForEvents(
+      int device_ordinal, const ExecuteOptions& options,
+      absl::Span<PjRtBuffer* const> argument_handles,
+      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
+      absl::flat_hash_set<BufferSequencingEvent*>& events) const;
+
+  StatusOr<ScopedShapedBuffer> EnqueueExecution(
+      absl::Span<PjRtBuffer* const> argument_handles, int replica,
+      int partition, int executable_idx, const RunId& run_id,
+      const ExecuteOptions& options, PjRtDevice* device,
+      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
+      std::shared_ptr<DeviceAssignment> device_assignment) const;
+
+  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
+      int device_ordinal, const ExecuteOptions& options,
+      ScopedShapedBuffer result_buffer,
+      std::shared_ptr<BufferSequencingEvent> definition_event,
+      PjRtDevice* device) const;
+
+  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
+      absl::Span<PjRtBuffer* const> argument_handles, int replica,
+      int partition, const RunId& run_id, const ExecuteOptions& options,
+      PjRtDevice* device = nullptr) const;
+
+  // Create shared pointers so we can free them after the execution: with
+  // asynchronous execution, the process being executed can outlive the
+  // executable itself.
+  PjRtStreamExecutorClient* const client_;
+  // One executable per partition.
+  std::vector<std::shared_ptr<LocalExecutable>> executables_;
+  // Per-executable set of parameters that have any aliased buffers and thus
+  // must be donated when executing the computation.
+  std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
+  std::shared_ptr<DeviceAssignment> device_assignment_;
+
+  // True if the executables were compiled expecting arguments in a single
+  // tuple.
+  const bool parameter_is_tupled_arguments_;
+
+  // The replica and partition indices of device_assignment_ to be run by this
+  // client. On single-host platforms without partitioning, this is all replicas
+  // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
+  // case on multi-host platforms. If there are 4 replicas and 2 partitions on
+  // a single-host platform, the size of addressable_device_logical_ids_ is
+  // 4*2 = 8.
+  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
+
+  // addressable_devices_[i] is the Device to which
+  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
+  // unique_ptrs to play well with the Python bindings (see xla.cc).
+  std::vector<PjRtDevice*> addressable_devices_;
+};
+
+// Executables can donate buffers so that buffers can be aliased from inputs
+// to outputs. This function returns the list of parameters that must be
+// donated when the executable is run. tuple_inputs reflects the option that
+// the executable was compiled with.
+StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
+    const HloModule& hlo_module, bool tuple_inputs);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
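
As a rough sketch of the usage-hold protocol documented in the header above (the helper function, `usage_stream`, and `event` are assumed for illustration; only the ScopedHold calls come from the header, and the real call sites in this patch may differ), enqueuing a read of a buffer might look like:

namespace xla {

// Hypothetical helper; not part of this patch.
void EnqueueReadOfBuffer(PjRtStreamExecutorBuffer* buffer,
                         se::Stream* usage_stream,
                         std::shared_ptr<BufferSequencingEvent> event) {
  // Acquire a transient usage hold; this blocks if a donation hold is pending.
  PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithUsageHold();
  if (!hold.ok()) {
    return;  // Buffer was already deleted or donated.
  }
  // ... enqueue work on `usage_stream` that reads hold->device_memory() ...
  // On success, convert the hold into a usage event so that Delete() or a
  // later donation waits for the enqueued work. Dropping `hold` instead would
  // also be legal, per the ScopedHold comment, if the caller handles the
  // necessary synchronization itself.
  hold.ConvertUsageHold(usage_stream, std::move(event),
                        /*reference_held=*/false);
}

}  // namespace xla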
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
index 9d6fa92..830f7c6 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -23,7 +23,7 @@
 #include "absl/status/status.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape.h"
@@ -94,10 +94,11 @@
   return Status::OK();
 }
 
-class PjRtTpuClient : public PjRtClient {
+class PjRtTpuClient : public PjRtStreamExecutorClient {
  public:
   PjRtTpuClient(LocalClient* client,
-                std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
+                std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+                int host_id);
 
   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -108,14 +109,14 @@
       const PjRtExecutable& executable) const override;
 };
 
-PjRtTpuClient::PjRtTpuClient(LocalClient* client,
-                             std::vector<std::unique_ptr<PjRtDevice>> devices,
-                             int host_id)
-    : PjRtClient(kTpuName, client, std::move(devices), host_id,
-                 /*allocator=*/nullptr,
-                 /*host_memory_allocator=*/nullptr,
-                 /*should_stage_host_to_device_transfers=*/false,
-                 /*gpu_run_options=*/nullptr) {}
+PjRtTpuClient::PjRtTpuClient(
+    LocalClient* client,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
+                               /*allocator=*/nullptr,
+                               /*host_memory_allocator=*/nullptr,
+                               /*should_stage_host_to_device_transfers=*/false,
+                               /*gpu_run_options=*/nullptr) {}
 
 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
@@ -128,7 +129,8 @@
                                                             num_partitions);
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }
 
 StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@@ -152,10 +154,10 @@
   return absl::optional<std::string>(tpu_executable->fingerprint());
 }
 
-StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
+StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
     LocalClient* client,
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   tf_tpu::TpuTopologyExternal topology =
       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
 
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h
index cdc68bc..d9847bb 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.h
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.h
@@ -20,20 +20,20 @@
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/stream_executor/tpu/tpu_topology.h"
 
 namespace xla {
 
-class PjRtTpuDevice : public PjRtDevice {
+class PjRtTpuDevice : public PjRtStreamExecutorDevice {
  public:
   PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
                 std::unique_ptr<LocalDeviceState> local_device_state,
                 int host_id, const std::array<int, 3>& coords,
                 std::string device_kind)
-      : PjRtDevice(core.Id(), std::move(local_device_state),
-                   std::move(device_kind), host_id),
+      : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
+                                 std::move(device_kind), host_id),
         core_(core),
         coords_(coords) {}
 
diff --git a/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc
index 2843c9b..cd7a37a 100644
--- a/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc
+++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc
@@ -160,13 +160,6 @@
   }
 }
 
-namespace {
-
-using MoveIterator =
-    absl::Span<const std::shared_ptr<BufferSequencingEvent>>::iterator;
-
-}  // namespace
-
 TrackedDeviceBuffer::TrackedDeviceBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
     absl::Span<se::DeviceMemoryBase const> device_memory,
@@ -175,9 +168,8 @@
     : allocator_(allocator),
       device_ordinal_(device_ordinal),
       device_memory_(device_memory.begin(), device_memory.end()),
-      definition_events_(
-          std::move_iterator<MoveIterator>(definition_events.begin()),
-          std::move_iterator<MoveIterator>(definition_events.end())),
+      definition_events_(std::make_move_iterator(definition_events.begin()),
+                         std::make_move_iterator(definition_events.end())),
       in_use_(true),
       on_delete_callback_(std::move(on_delete_callback)) {}
 
diff --git a/tensorflow/compiler/xla/pjrt/utils.cc b/tensorflow/compiler/xla/pjrt/utils.cc
new file mode 100644
index 0000000..a919c84
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/utils.cc
@@ -0,0 +1,263 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/pjrt/utils.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+namespace {
+StatusOr<Shape> GetShardedShape(const Shape& shape,
+                                const OpSharding& sharding) {
+  if (sharding.type() == OpSharding::TUPLE) {
+    if (!shape.IsTuple()) {
+      return InvalidArgument(
+          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
+          sharding.DebugString(), shape.ToString());
+    }
+    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
+      return InvalidArgument(
+          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
+          " (OpSharding: %s, shape: %s)",
+          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
+          sharding.DebugString(), shape.ToString());
+    }
+    std::vector<Shape> sharded_subshapes;
+    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
+      TF_ASSIGN_OR_RETURN(
+          Shape sharded_subshape,
+          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
+      sharded_subshapes.emplace_back(std::move(sharded_subshape));
+    }
+    return ShapeUtil::MakeTupleShape(sharded_subshapes);
+  }
+  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
+                      HloSharding::FromProto(sharding));
+  return hlo_sharding.TileShape(shape);
+}
+
+StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
+  const Shape unsharded_shape(instr.shape());
+  Shape sharded_shape;
+  if (instr.has_sharding()) {
+    TF_ASSIGN_OR_RETURN(sharded_shape,
+                        GetShardedShape(unsharded_shape, instr.sharding()));
+  } else {
+    sharded_shape = unsharded_shape;
+  }
+  LayoutUtil::ClearLayout(&sharded_shape);
+  return sharded_shape;
+}
+
+// Returns sharded (argument shapes, result shape) without layouts.
+StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
+    const XlaComputation& computation, const ProgramShape& program_shape) {
+  std::vector<Shape> arg_shapes;
+  arg_shapes.resize(program_shape.parameters_size());
+  Shape result_shape;
+  for (const HloComputationProto& comp : computation.proto().computations()) {
+    if (comp.id() != computation.proto().entry_computation_id()) {
+      continue;
+    }
+    for (const HloInstructionProto& instr : comp.instructions()) {
+      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
+        if (instr.parameter_number() >= program_shape.parameters_size()) {
+          return InvalidArgument(
+              "Got invalid parameter number %d, expected %d parameters",
+              instr.parameter_number(), program_shape.parameters_size());
+        }
+        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
+                            GetShardedShape(instr));
+      }
+      if (instr.id() == comp.root_id()) {
+        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
+          return InvalidArgument("Found multiple root instructions");
+        }
+        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
+      }
+    }
+  }
+  for (int i = 0; i < arg_shapes.size(); ++i) {
+    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
+      return InvalidArgument("Couldn't find parameter %d", i);
+    }
+  }
+  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+    return InvalidArgument("Couldn't find root instruction");
+  }
+  return std::make_pair(arg_shapes, result_shape);
+}
+}  // namespace
+
+Status ParseDeviceAssignmentCompileOptions(
+    bool compile_portable_executable, ExecutableBuildOptions* build_options,
+    std::function<StatusOr<DeviceAssignment>(int, int)>
+        GetDefaultDeviceAssignmentFunction,
+    int* num_replicas, int* num_partitions,
+    std::shared_ptr<DeviceAssignment>* device_assignment) {
+  if (compile_portable_executable) {
+    if (build_options->has_device_assignment()) {
+      return InvalidArgument(
+          "CompileOptions requests portable executable but "
+          "ExecutableBuildOptions includes a device assignment");
+    }
+    *num_replicas = 1;
+    *num_partitions = 1;
+  } else {
+    if (!build_options->has_device_assignment()) {
+      VLOG(2) << "Compile using default device_assignment.";
+      TF_ASSIGN_OR_RETURN(
+          DeviceAssignment device_assignment,
+          GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
+                                             build_options->num_partitions()));
+      build_options->set_device_assignment(device_assignment);
+    }
+    VLOG(2) << "Compile device_assignment:\n"
+            << build_options->device_assignment().ToString();
+    *num_replicas = build_options->device_assignment().replica_count();
+    *num_partitions = build_options->device_assignment().computation_count();
+    *device_assignment =
+        std::make_shared<DeviceAssignment>(build_options->device_assignment());
+  }
+  return Status::OK();
+}
+
+Status DetermineArgumentLayoutsFromCompileOptions(
+    const XlaComputation& computation,
+    std::function<StatusOr<Shape>(Shape)>
+        choose_compact_layout_for_shape_function,
+    absl::optional<std::vector<Shape>>& argument_layouts,
+    ExecutableBuildOptions* build_options,
+    std::vector<const Shape*>* argument_layout_pointers) {
+  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
+                      computation.GetProgramShape());
+  if (!argument_layouts) {
+    argument_layouts.emplace(program_shape.parameters());
+    for (Shape& shape : *argument_layouts) {
+      LayoutUtil::ClearLayout(&shape);
+    }
+  } else if (argument_layouts->size() != program_shape.parameters_size()) {
+    return InvalidArgument(
+        "CompileOptions specify %d argument layouts, but computation has %d "
+        "arguments",
+        argument_layouts->size(), program_shape.parameters_size());
+  }
+  argument_layout_pointers->reserve(argument_layouts->size());
+
+  // Assign a default layout based on `sharded_shape` to any array subshapes in
+  // `dst_shape` that are missing layouts.
+  auto assign_layouts = [&choose_compact_layout_for_shape_function](
+                            const Shape& sharded_shape, Shape* dst_shape) {
+    return ShapeUtil::ForEachMutableSubshapeWithStatus(
+        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
+          if (subshape->IsArray() && !subshape->has_layout()) {
+            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
+            const Shape& sharded_subshape =
+                ShapeUtil::GetSubshape(sharded_shape, idx);
+            LayoutUtil::SetToDefaultLayout(subshape);
+            TF_ASSIGN_OR_RETURN(
+                Shape layout,
+                choose_compact_layout_for_shape_function(sharded_subshape));
+            *subshape->mutable_layout() = layout.layout();
+          }
+          return Status::OK();
+        });
+  };
+  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
+                      GetShardedProgramShapes(computation, program_shape));
+
+  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
+  for (int i = 0; i < argument_layouts->size(); ++i) {
+    Shape* layout = &(*argument_layouts)[i];
+    argument_layout_pointers->push_back(layout);
+    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
+  }
+
+  Shape result_layout;
+  if (build_options->result_layout()) {
+    result_layout = *build_options->result_layout();
+  } else {
+    result_layout = program_shape.result();
+    LayoutUtil::ClearLayout(&result_layout);
+  }
+  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
+  build_options->set_result_layout(result_layout);
+  return Status::OK();
+}
+
+StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
+    const HloModule& module, bool tuple_inputs) {
+  HloComputation* computation = module.entry_computation();
+  int number_of_parameters = [&]() -> int {
+    if (tuple_inputs) {
+      CHECK_EQ(computation->num_parameters(), 1);
+      const Shape& input_tuple_shape =
+          computation->parameter_instruction(0)->shape();
+      CHECK(input_tuple_shape.IsTuple());
+      return input_tuple_shape.tuple_shapes_size();
+    } else {
+      return computation->num_parameters();
+    }
+  }();
+  // If any buffer in a parameter is aliased we will donate the entire input
+  // parameter.
+  absl::flat_hash_set<int> parameters_to_donate;
+  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
+  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
+      [&](const ShapeIndex& output_index,
+          const HloInputOutputAliasConfig::Alias& alias) {
+        if (tuple_inputs) {
+          if (alias.parameter_number != 0) {
+            return InvalidArgument(
+                "Unexpected parameter number %d in alias config with tupled "
+                "inputs",
+                alias.parameter_number);
+          }
+          const ShapeIndex& index = alias.parameter_index;
+          if (!index.empty()) {
+            int this_parameter = index.data()[0];
+            if (this_parameter >= number_of_parameters) {
+              return InvalidArgument(
+                  "Unexpected parameter index %s in alias config with tupled "
+                  "inputs and %d parameters",
+                  index.ToString(), number_of_parameters);
+            }
+            parameters_to_donate.insert(this_parameter);
+          }
+        } else {
+          int this_parameter = alias.parameter_number;
+          if (this_parameter >= number_of_parameters) {
+            return InvalidArgument(
+                "Unexpected parameter number %d in alias config without tupled "
+                "inputs and %d parameters",
+                this_parameter, number_of_parameters);
+          }
+          parameters_to_donate.insert(this_parameter);
+        }
+        return Status::OK();
+      }));
+  return parameters_to_donate;
+}
+
+}  // namespace xla
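
As a hedged illustration of the calling convention of ParseDeviceAssignmentCompileOptions above (the wrapper function and its parameter names are invented for this sketch; only the utility's signature and the client's GetDefaultDeviceAssignment come from this patch), a compile path might resolve its device assignment like so:

namespace xla {

// Illustrative only; not part of this patch.
Status ResolveDeviceAssignment(PjRtStreamExecutorClient* client,
                               bool compile_portable_executable,
                               ExecutableBuildOptions* build_options,
                               std::shared_ptr<DeviceAssignment>* assignment) {
  int num_replicas;
  int num_partitions;
  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
      compile_portable_executable, build_options,
      [client](int replicas, int partitions) {
        return client->GetDefaultDeviceAssignment(replicas, partitions);
      },
      &num_replicas, &num_partitions, assignment));
  // For a portable executable, num_replicas == num_partitions == 1 and
  // *assignment is left unset; otherwise *assignment mirrors the device
  // assignment stored in build_options.
  return Status::OK();
}

}  // namespace xla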
diff --git a/tensorflow/compiler/xla/pjrt/utils.h b/tensorflow/compiler/xla/pjrt/utils.h
new file mode 100644
index 0000000..ff4b5ba
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/utils.h
@@ -0,0 +1,57 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Returns the num_replicas, num_partitions and device assignment given an
+// ExecutableBuildOptions and whether we want a portable executable.
+Status ParseDeviceAssignmentCompileOptions(
+    bool compile_portable_executable, ExecutableBuildOptions* build_options,
+    std::function<StatusOr<DeviceAssignment>(int, int)>
+        GetDefaultDeviceAssignmentFunction,
+    int* num_replicas, int* num_partitions,
+    std::shared_ptr<DeviceAssignment>* device_assignment);
+
+// Returns pointers to the argument layouts given an XlaComputation and
+// ExecutableBuildOptions.
+Status DetermineArgumentLayoutsFromCompileOptions(
+    const XlaComputation& computation,
+    std::function<StatusOr<Shape>(Shape)>
+        choose_compact_layout_for_shape_function,
+    absl::optional<std::vector<Shape>>& argument_layouts,
+    ExecutableBuildOptions* build_options,
+    std::vector<const Shape*>* argument_layout_pointers);
+
+// Executables can donate buffers so that buffers can be aliased from inputs
+// to outputs. This function returns the list of parameters that must be
+// donated when the executable is run. tuple_inputs reflects the option that
+// the executable was compiled with.
+StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
+    const HloModule& module, bool tuple_inputs);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
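
Similarly, a small sketch of how GetParametersThatMustBeDonated might be consumed when setting up donation; the logging wrapper is hypothetical, and only the utility call itself is declared above:

namespace xla {

// Hypothetical wrapper; `module` and `tuple_inputs` would come from the
// compiled executable in real code.
Status LogDonatedParameters(const HloModule& module, bool tuple_inputs) {
  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> donated,
                      GetParametersThatMustBeDonated(module, tuple_inputs));
  for (int parameter : donated) {
    VLOG(2) << "Parameter " << parameter
            << " is aliased to an output and must be donated.";
  }
  return Status::OK();
}

}  // namespace xla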
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 050f300..055c8d9 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -22,6 +22,22 @@
     deps = [":xla_extension"],
 )
 
+cc_library(
+    name = "absl_casters",
+    hdrs = ["absl_casters.h"],
+    compatible_with = [],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    deps = [
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+        "@pybind11",
+    ],
+)
+
 pyx_library(
     name = "custom_call_for_test",
     testonly = True,
@@ -97,13 +113,14 @@
     name = "types",
     srcs = ["types.cc"],
     hdrs = ["types.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
     deps = [
-        ":bfloat16",
+        ":absl_casters",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
@@ -113,6 +130,7 @@
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
         "//tensorflow/core:lib",
+        "//tensorflow/python:bfloat16_lib",
         "//third_party/py/numpy:headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
@@ -159,42 +177,6 @@
 )
 
 cc_library(
-    name = "bfloat16",
-    srcs = ["bfloat16.cc"],
-    hdrs = ["bfloat16.h"],
-    copts = [
-        "-fexceptions",
-        "-fno-strict-aliasing",
-    ],
-    features = ["-use_header_modules"],
-    deps = [
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/core/platform:bfloat16",
-        "//tensorflow/core/platform:logging",
-        "//third_party/py/numpy:headers",
-        "//third_party/python_runtime:headers",  # buildcleaner: keep
-        "@com_google_absl//absl/strings",
-        "@pybind11",
-    ],
-)
-
-py_test(
-    name = "bfloat16_test",
-    srcs = ["bfloat16_test.py"],
-    main = "bfloat16_test.py",
-    python_version = "PY3",
-    tags = ["no_oss"],
-    deps = [
-        ":xla_client",
-        ":xla_extension",
-        "@absl_py//absl/testing:absltest",
-        "@absl_py//absl/testing:parameterized",
-    ] + xla_py_test_deps(),
-)
-
-cc_library(
     name = "py_client",
     srcs = [
         "py_buffer.cc",
@@ -206,6 +188,7 @@
         "py_client.h",
         "py_executable.h",
     ],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -221,6 +204,7 @@
         "//tensorflow/core/platform:fingerprint",
         "//tensorflow/core/profiler:protos_all_cc",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
@@ -232,6 +216,7 @@
     name = "dlpack",
     srcs = ["dlpack.cc"],
     hdrs = ["dlpack.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -244,6 +229,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",  # TODO(zhangqiaorjc): Remove after adding a factory method for PjRtBuffer.
         "//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
         "//tensorflow/stream_executor:device_memory",
         "//tensorflow/stream_executor:platform",
@@ -263,6 +249,7 @@
     name = "jax_jit",
     srcs = ["jax_jit.cc"],
     hdrs = ["jax_jit.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -292,6 +279,7 @@
     name = "ops",
     srcs = ["ops.cc"],
     hdrs = ["ops.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -356,6 +344,7 @@
     name = "outfeed_receiver_py",
     srcs = ["outfeed_receiver_py.cc"],
     hdrs = ["outfeed_receiver_py.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -379,12 +368,14 @@
     name = "pytree",
     srcs = ["pytree.cc"],
     hdrs = ["pytree.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
     deps = [
+        ":types",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/hash",
@@ -435,6 +426,7 @@
     name = "xla_compiler",
     srcs = ["xla_compiler.cc"],
     hdrs = ["xla_compiler.h"],
+    compatible_with = [],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -481,7 +473,6 @@
     features = ["-use_header_modules"],
     module_name = "xla_extension",
     deps = [
-        ":bfloat16",
         ":dlpack",
         ":jax_jit",
         ":ops",
@@ -534,6 +525,7 @@
         # without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
         # not require Tensorflow.
         "//tensorflow/core:lib_internal_impl",  # buildcleaner: keep
+        "//tensorflow/python:bfloat16_lib",
         "//tensorflow/stream_executor:device_memory_allocator",
         "//tensorflow/stream_executor:platform",
     ] + select({
diff --git a/tensorflow/compiler/xla/python/absl_casters.h b/tensorflow/compiler/xla/python/absl_casters.h
new file mode 100644
index 0000000..d7892a6
--- /dev/null
+++ b/tensorflow/compiler/xla/python/absl_casters.h
@@ -0,0 +1,73 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_ABSL_CASTERS_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_ABSL_CASTERS_H_
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "pybind11/cast.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace pybind11 {
+namespace detail {
+
+// absl::Span
+template <typename T>
+struct type_caster<absl::Span<const T>> {
+  using value_conv = make_caster<T>;
+
+  PYBIND11_TYPE_CASTER(absl::Span<const T>,
+                       _("Span[") + value_conv::name + _("]"));
+
+  // absl::Span doesn't hold ownership. We therefore need a temporary array.
+  // Pybind appears to keep type_casters alive until the callee has run.
+  std::vector<T> storage;
+
+  bool load(handle src, bool convert) {
+    if (!isinstance<sequence>(src)) {
+      return false;
+    }
+    auto seq = reinterpret_borrow<sequence>(src);
+    storage.clear();
+    storage.reserve(seq.size());
+    for (const auto& it : seq) {
+      value_conv conv;
+      if (!conv.load(it, convert)) {
+        return false;
+      }
+      storage.push_back(cast_op<T&&>(std::move(conv)));
+    }
+    value = absl::Span<const T>(storage);
+    return true;
+  }
+};
+
+// When absl::optional is an alias for std::optional, the type_caster
+// specializations are provided by pybind11.
+#ifndef ABSL_HAVE_STD_OPTIONAL
+// absl::optional
+template <typename T>
+struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {};
+
+template <>
+struct type_caster<absl::nullopt_t> : public void_caster<absl::nullopt_t> {};
+#endif
+
+}  // namespace detail
+}  // namespace pybind11
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_ABSL_CASTERS_H_
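
For context, a minimal sketch of how the absl::Span<const T> caster added above can be used from a pybind11 extension. This is illustrative only and not part of the change; the module name `span_demo` and the function `total` are invented for the example, and the caster's temporary std::vector storage is what keeps the span valid for the duration of the call.

// Illustrative only -- not part of this diff.
#include <cstdint>

#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/absl_casters.h"

namespace {

// Sums the elements of a span. The type_caster above copies the Python
// sequence into storage it owns, so the span remains valid while this runs.
int64_t Total(absl::Span<const int64_t> values) {
  int64_t sum = 0;
  for (int64_t v : values) sum += v;
  return sum;
}

}  // namespace

PYBIND11_MODULE(span_demo, m) {
  // Python callers can pass any sequence of ints: span_demo.total([1, 2, 3]).
  m.def("total", &Total);
}
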
diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc
deleted file mode 100644
index 5f96c49..0000000
--- a/tensorflow/compiler/xla/python/bfloat16.cc
+++ /dev/null
@@ -1,1576 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/python/bfloat16.h"
-
-#include <array>
-#include <locale>
-// Place `<locale>` before <Python.h> to avoid a build failure in macOS.
-#include <Python.h>
-
-#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
-
-#include "numpy/arrayobject.h"
-#include "numpy/ufuncobject.h"
-#include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/platform/bfloat16.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace xla {
-namespace {
-
-namespace py = pybind11;
-
-struct PyDecrefDeleter {
-  void operator()(PyObject* p) const { Py_DECREF(p); }
-};
-
-// Safe container for an owned PyObject. On destruction, the reference count of
-// the contained object will be decremented.
-using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
-Safe_PyObjectPtr make_safe(PyObject* object) {
-  return Safe_PyObjectPtr(object);
-}
-
-bool PyLong_CheckNoOverflow(PyObject* object) {
-  if (!PyLong_Check(object)) {
-    return false;
-  }
-  int overflow = 0;
-  PyLong_AsLongAndOverflow(object, &overflow);
-  return (overflow == 0);
-}
-
-// Registered numpy type ID. Global variable populated by the registration code.
-// Protected by the GIL.
-int npy_bfloat16 = -1;
-
-// Forward declaration.
-extern PyTypeObject PyBfloat16_Type;
-
-// Representation of a Python bfloat16 object.
-struct PyBfloat16 {
-  PyObject_HEAD;  // Python object header
-  bfloat16 value;
-};
-
-// Returns true if 'object' is a PyBfloat16.
-bool PyBfloat16_Check(PyObject* object) {
-  return PyObject_IsInstance(object,
-                             reinterpret_cast<PyObject*>(&PyBfloat16_Type));
-}
-
-// Extracts the value of a PyBfloat16 object.
-bfloat16 PyBfloat16_Bfloat16(PyObject* object) {
-  return reinterpret_cast<PyBfloat16*>(object)->value;
-}
-
-// Constructs a PyBfloat16 object from a bfloat16.
-Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
-  Safe_PyObjectPtr ref =
-      make_safe(PyBfloat16_Type.tp_alloc(&PyBfloat16_Type, 0));
-  PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
-  if (p) {
-    p->value = x;
-  }
-  return ref;
-}
-
-// Converts a Python object to a bfloat16 value. Returns true on success,
-// returns false and reports a Python error on failure.
-bool CastToBfloat16(PyObject* arg, bfloat16* output) {
-  if (PyBfloat16_Check(arg)) {
-    *output = PyBfloat16_Bfloat16(arg);
-    return true;
-  }
-  if (PyFloat_Check(arg)) {
-    double d = PyFloat_AsDouble(arg);
-    if (PyErr_Occurred()) {
-      return false;
-    }
-    // TODO(phawkins): check for overflow
-    *output = bfloat16(d);
-    return true;
-  }
-  if (PyLong_CheckNoOverflow(arg)) {
-    long l = PyLong_AsLong(arg);  // NOLINT
-    if (PyErr_Occurred()) {
-      return false;
-    }
-    // TODO(phawkins): check for overflow
-    *output = bfloat16(static_cast<float>(l));
-    return true;
-  }
-  if (PyArray_IsScalar(arg, Half)) {
-    Eigen::half f;
-    PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
-    return true;
-  }
-  if (PyArray_IsScalar(arg, Float)) {
-    float f;
-    PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
-    return true;
-  }
-  if (PyArray_IsScalar(arg, Double)) {
-    double f;
-    PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
-    return true;
-  }
-  if (PyArray_IsZeroDim(arg)) {
-    Safe_PyObjectPtr ref;
-    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
-    if (PyArray_TYPE(arr) != npy_bfloat16) {
-      ref = make_safe(PyArray_Cast(arr, npy_bfloat16));
-      if (PyErr_Occurred()) {
-        return false;
-      }
-      arg = ref.get();
-      arr = reinterpret_cast<PyArrayObject*>(arg);
-    }
-    *output = *reinterpret_cast<bfloat16*>(PyArray_DATA(arr));
-    return true;
-  }
-  return false;
-}
-
-bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) {
-  if (PyBfloat16_Check(arg)) {
-    *output = PyBfloat16_Bfloat16(arg);
-    return true;
-  }
-  return false;
-}
-
-// Converts a PyBfloat16 into a PyFloat.
-PyObject* PyBfloat16_Float(PyObject* self) {
-  bfloat16 x = PyBfloat16_Bfloat16(self);
-  return PyFloat_FromDouble(static_cast<double>(x));
-}
-
-// Converts a PyBfloat16 into a PyInt.
-PyObject* PyBfloat16_Int(PyObject* self) {
-  bfloat16 x = PyBfloat16_Bfloat16(self);
-  long y = static_cast<long>(x);  // NOLINT
-  return PyLong_FromLong(y);
-}
-
-// Negates a PyBfloat16.
-PyObject* PyBfloat16_Negative(PyObject* self) {
-  bfloat16 x = PyBfloat16_Bfloat16(self);
-  return PyBfloat16_FromBfloat16(-x).release();
-}
-
-PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x + y).release();
-  }
-  return PyArray_Type.tp_as_number->nb_add(a, b);
-}
-
-PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x - y).release();
-  }
-  return PyArray_Type.tp_as_number->nb_subtract(a, b);
-}
-
-PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x * y).release();
-  }
-  return PyArray_Type.tp_as_number->nb_multiply(a, b);
-}
-
-PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x / y).release();
-  }
-  return PyArray_Type.tp_as_number->nb_true_divide(a, b);
-}
-
-// Python number methods for PyBfloat16 objects.
-PyNumberMethods PyBfloat16_AsNumber = {
-    PyBfloat16_Add,       // nb_add
-    PyBfloat16_Subtract,  // nb_subtract
-    PyBfloat16_Multiply,  // nb_multiply
-    nullptr,              // nb_remainder
-    nullptr,              // nb_divmod
-    nullptr,              // nb_power
-    PyBfloat16_Negative,  // nb_negative
-    nullptr,              // nb_positive
-    nullptr,              // nb_absolute
-    nullptr,              // nb_nonzero
-    nullptr,              // nb_invert
-    nullptr,              // nb_lshift
-    nullptr,              // nb_rshift
-    nullptr,              // nb_and
-    nullptr,              // nb_xor
-    nullptr,              // nb_or
-    PyBfloat16_Int,       // nb_int
-    nullptr,              // reserved
-    PyBfloat16_Float,     // nb_float
-
-    nullptr,  // nb_inplace_add
-    nullptr,  // nb_inplace_subtract
-    nullptr,  // nb_inplace_multiply
-    nullptr,  // nb_inplace_remainder
-    nullptr,  // nb_inplace_power
-    nullptr,  // nb_inplace_lshift
-    nullptr,  // nb_inplace_rshift
-    nullptr,  // nb_inplace_and
-    nullptr,  // nb_inplace_xor
-    nullptr,  // nb_inplace_or
-
-    nullptr,                // nb_floor_divide
-    PyBfloat16_TrueDivide,  // nb_true_divide
-    nullptr,                // nb_inplace_floor_divide
-    nullptr,                // nb_inplace_true_divide
-    nullptr,                // nb_index
-};
-
-// Constructs a new PyBfloat16.
-PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
-  if (kwds && PyDict_Size(kwds)) {
-    PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
-    return nullptr;
-  }
-  Py_ssize_t size = PyTuple_Size(args);
-  if (size != 1) {
-    PyErr_SetString(PyExc_TypeError,
-                    "expected number as argument to bfloat16 constructor");
-    return nullptr;
-  }
-  PyObject* arg = PyTuple_GetItem(args, 0);
-
-  bfloat16 value;
-  if (PyBfloat16_Check(arg)) {
-    Py_INCREF(arg);
-    return arg;
-  } else if (CastToBfloat16(arg, &value)) {
-    return PyBfloat16_FromBfloat16(value).release();
-  } else if (PyArray_Check(arg)) {
-    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
-    if (PyArray_TYPE(arr) != npy_bfloat16) {
-      return PyArray_Cast(arr, npy_bfloat16);
-    } else {
-      Py_INCREF(arg);
-      return arg;
-    }
-  }
-  PyErr_Format(PyExc_TypeError, "expected number, got %s",
-               arg->ob_type->tp_name);
-  return nullptr;
-}
-
-// Comparisons on PyBfloat16s.
-PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
-  bfloat16 x, y;
-  if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) {
-    return PyGenericArrType_Type.tp_richcompare(a, b, op);
-  }
-  bool result;
-  switch (op) {
-    case Py_LT:
-      result = x < y;
-      break;
-    case Py_LE:
-      result = x <= y;
-      break;
-    case Py_EQ:
-      result = x == y;
-      break;
-    case Py_NE:
-      result = x != y;
-      break;
-    case Py_GT:
-      result = x > y;
-      break;
-    case Py_GE:
-      result = x >= y;
-      break;
-    default:
-      LOG(FATAL) << "Invalid op type " << op;
-  }
-  return PyBool_FromLong(result);
-}
-
-// Implementation of repr() for PyBfloat16.
-PyObject* PyBfloat16_Repr(PyObject* self) {
-  bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  std::string v = absl::StrCat(static_cast<float>(x));
-  return PyUnicode_FromString(v.c_str());
-}
-
-// Implementation of str() for PyBfloat16.
-PyObject* PyBfloat16_Str(PyObject* self) {
-  bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  std::string v = absl::StrCat(static_cast<float>(x));
-  return PyUnicode_FromString(v.c_str());
-}
-
-// Hash function for PyBfloat16. We use the identity function, which is a weak
-// hash function.
-Py_hash_t PyBfloat16_Hash(PyObject* self) {
-  bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  return x.value;
-}
-
-// Python type for PyBfloat16 objects.
-PyTypeObject PyBfloat16_Type = {
-    PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16",  // tp_name
-    sizeof(PyBfloat16),                            // tp_basicsize
-    0,                                             // tp_itemsize
-    nullptr,                                       // tp_dealloc
-#if PY_VERSION_HEX < 0x03080000
-    nullptr,  // tp_print
-#else
-    0,  // tp_vectorcall_offset
-#endif
-    nullptr,               // tp_getattr
-    nullptr,               // tp_setattr
-    nullptr,               // tp_compare / tp_reserved
-    PyBfloat16_Repr,       // tp_repr
-    &PyBfloat16_AsNumber,  // tp_as_number
-    nullptr,               // tp_as_sequence
-    nullptr,               // tp_as_mapping
-    PyBfloat16_Hash,       // tp_hash
-    nullptr,               // tp_call
-    PyBfloat16_Str,        // tp_str
-    nullptr,               // tp_getattro
-    nullptr,               // tp_setattro
-    nullptr,               // tp_as_buffer
-                           // tp_flags
-    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
-    "bfloat16 floating-point values",  // tp_doc
-    nullptr,                           // tp_traverse
-    nullptr,                           // tp_clear
-    PyBfloat16_RichCompare,            // tp_richcompare
-    0,                                 // tp_weaklistoffset
-    nullptr,                           // tp_iter
-    nullptr,                           // tp_iternext
-    nullptr,                           // tp_methods
-    nullptr,                           // tp_members
-    nullptr,                           // tp_getset
-    nullptr,                           // tp_base
-    nullptr,                           // tp_dict
-    nullptr,                           // tp_descr_get
-    nullptr,                           // tp_descr_set
-    0,                                 // tp_dictoffset
-    nullptr,                           // tp_init
-    nullptr,                           // tp_alloc
-    PyBfloat16_New,                    // tp_new
-    nullptr,                           // tp_free
-    nullptr,                           // tp_is_gc
-    nullptr,                           // tp_bases
-    nullptr,                           // tp_mro
-    nullptr,                           // tp_cache
-    nullptr,                           // tp_subclasses
-    nullptr,                           // tp_weaklist
-    nullptr,                           // tp_del
-    0,                                 // tp_version_tag
-};
-
-// Numpy support
-
-PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
-
-PyArray_Descr NPyBfloat16_Descr = {
-    PyObject_HEAD_INIT(nullptr)  //
-                                 /*typeobj=*/
-    (&PyBfloat16_Type),
-    // We must register bfloat16 with a kind other than "f", because numpy
-    // considers two types with the same kind and size to be equal, but
-    // float16 != bfloat16.
-    // The downside of this is that NumPy scalar promotion does not work with
-    // bfloat16 values.
-    /*kind=*/'V',
-    // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
-    // character is unique.
-    /*type=*/'E',
-    /*byteorder=*/'=',
-    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
-    /*type_num=*/0,
-    /*elsize=*/sizeof(bfloat16),
-    /*alignment=*/alignof(bfloat16),
-    /*subarray=*/nullptr,
-    /*fields=*/nullptr,
-    /*names=*/nullptr,
-    /*f=*/&NPyBfloat16_ArrFuncs,
-    /*metadata=*/nullptr,
-    /*c_metadata=*/nullptr,
-    /*hash=*/-1,  // -1 means "not computed yet".
-};
-
-// Implementations of NumPy array methods.
-
-PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
-  bfloat16 x;
-  memcpy(&x, data, sizeof(bfloat16));
-  return PyBfloat16_FromBfloat16(x).release();
-}
-
-int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
-  bfloat16 x;
-  if (!CastToBfloat16(item, &x)) {
-    PyErr_Format(PyExc_TypeError, "expected number, got %s",
-                 item->ob_type->tp_name);
-    return -1;
-  }
-  memcpy(data, &x, sizeof(bfloat16));
-  return 0;
-}
-
-void ByteSwap16(void* value) {
-  char* p = reinterpret_cast<char*>(value);
-  std::swap(p[0], p[1]);
-}
-
-int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
-  bfloat16 x;
-  memcpy(&x, a, sizeof(bfloat16));
-
-  bfloat16 y;
-  memcpy(&y, b, sizeof(bfloat16));
-
-  if (x < y) {
-    return -1;
-  }
-  if (y < x) {
-    return 1;
-  }
-  // NaNs sort to the end.
-  if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
-    return -1;
-  }
-  if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
-    return 1;
-  }
-  return 0;
-}
-
-void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
-                           npy_intp sstride, npy_intp n, int swap, void* arr) {
-  char* dst = reinterpret_cast<char*>(dstv);
-  char* src = reinterpret_cast<char*>(srcv);
-  if (!src) {
-    return;
-  }
-  if (swap) {
-    for (npy_intp i = 0; i < n; i++) {
-      char* r = dst + dstride * i;
-      memcpy(r, src + sstride * i, sizeof(uint16_t));
-      ByteSwap16(r);
-    }
-  } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) {
-    memcpy(dst, src, n * sizeof(uint16_t));
-  } else {
-    for (npy_intp i = 0; i < n; i++) {
-      memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t));
-    }
-  }
-}
-
-void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) {
-  if (!src) {
-    return;
-  }
-  memcpy(dst, src, sizeof(uint16_t));
-  if (swap) {
-    ByteSwap16(dst);
-  }
-}
-
-npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
-  bfloat16 x;
-  memcpy(&x, data, sizeof(x));
-  return x != static_cast<bfloat16>(0);
-}
-
-int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
-  bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
-  const float start(buffer[0]);
-  const float delta = static_cast<float>(buffer[1]) - start;
-  for (npy_intp i = 2; i < length; ++i) {
-    buffer[i] = static_cast<bfloat16>(start + i * delta);
-  }
-  return 0;
-}
-
-void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
-                         void* op, npy_intp n, void* arr) {
-  char* c1 = reinterpret_cast<char*>(ip1);
-  char* c2 = reinterpret_cast<char*>(ip2);
-  float acc = 0.0f;
-  for (npy_intp i = 0; i < n; ++i) {
-    bfloat16* const b1 = reinterpret_cast<bfloat16*>(c1);
-    bfloat16* const b2 = reinterpret_cast<bfloat16*>(c2);
-    acc += static_cast<float>(*b1) * static_cast<float>(*b2);
-    c1 += is1;
-    c2 += is2;
-  }
-  bfloat16* out = reinterpret_cast<bfloat16*>(op);
-  *out = static_cast<bfloat16>(acc);
-}
-
-int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) {
-  bfloat16 b1 = *reinterpret_cast<const bfloat16*>(v1);
-  bfloat16 b2 = *reinterpret_cast<const bfloat16*>(v2);
-  if (b1 < b2) {
-    return -1;
-  }
-  if (b1 > b2) {
-    return 1;
-  }
-  return 0;
-}
-
-int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
-                           void* arr) {
-  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
-  float max_val = -std::numeric_limits<float>::infinity();
-  for (npy_intp i = 0; i < n; ++i) {
-    if (static_cast<float>(bdata[i]) > max_val) {
-      max_val = static_cast<float>(bdata[i]);
-      *max_ind = i;
-    }
-  }
-  return 0;
-}
-
-int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
-                           void* arr) {
-  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
-  float min_val = std::numeric_limits<float>::infinity();
-  for (npy_intp i = 0; i < n; ++i) {
-    if (static_cast<float>(bdata[i]) < min_val) {
-      min_val = static_cast<float>(bdata[i]);
-      *min_ind = i;
-    }
-  }
-  return 0;
-}
-
-// NumPy casts
-
-template <typename T, typename Enable = void>
-struct TypeDescriptor {
-  // typedef ... T;  // Representation type in memory for NumPy values of type
-  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
-};
-
-template <>
-struct TypeDescriptor<bfloat16> {
-  typedef bfloat16 T;
-  static int Dtype() { return npy_bfloat16; }
-};
-
-template <>
-struct TypeDescriptor<uint8> {
-  typedef uint8 T;
-  static int Dtype() { return NPY_UINT8; }
-};
-
-template <>
-struct TypeDescriptor<uint16> {
-  typedef uint16 T;
-  static int Dtype() { return NPY_UINT16; }
-};
-
-// We register "int", "long", and "long long" types for portability across
-// Linux, where "int" and "long" are the same type, and Windows, where "long"
-// and "longlong" are the same type.
-template <>
-struct TypeDescriptor<unsigned int> {
-  typedef unsigned int T;
-  static int Dtype() { return NPY_UINT; }
-};
-
-template <>
-struct TypeDescriptor<unsigned long> {  // NOLINT
-  typedef unsigned long T;              // NOLINT
-  static int Dtype() { return NPY_ULONG; }
-};
-
-template <>
-struct TypeDescriptor<unsigned long long> {  // NOLINT
-  typedef unsigned long long T;              // NOLINT
-  static int Dtype() { return NPY_ULONGLONG; }
-};
-
-template <>
-struct TypeDescriptor<int8> {
-  typedef int8 T;
-  static int Dtype() { return NPY_INT8; }
-};
-
-template <>
-struct TypeDescriptor<int16> {
-  typedef int16 T;
-  static int Dtype() { return NPY_INT16; }
-};
-
-template <>
-struct TypeDescriptor<int> {
-  typedef int T;
-  static int Dtype() { return NPY_INT; }
-};
-
-template <>
-struct TypeDescriptor<long> {  // NOLINT
-  typedef long T;              // NOLINT
-  static int Dtype() { return NPY_LONG; }
-};
-
-template <>
-struct TypeDescriptor<long long> {  // NOLINT
-  typedef long long T;              // NOLINT
-  static int Dtype() { return NPY_LONGLONG; }
-};
-
-template <>
-struct TypeDescriptor<bool> {
-  typedef int8 T;
-  static int Dtype() { return NPY_BOOL; }
-};
-
-template <>
-struct TypeDescriptor<Eigen::half> {
-  typedef Eigen::half T;
-  static int Dtype() { return NPY_HALF; }
-};
-
-template <>
-struct TypeDescriptor<float> {
-  typedef float T;
-  static int Dtype() { return NPY_FLOAT; }
-};
-
-template <>
-struct TypeDescriptor<double> {
-  typedef double T;
-  static int Dtype() { return NPY_DOUBLE; }
-};
-
-template <>
-struct TypeDescriptor<complex64> {
-  typedef complex64 T;
-  static int Dtype() { return NPY_COMPLEX64; }
-};
-
-template <>
-struct TypeDescriptor<complex128> {
-  typedef complex128 T;
-  static int Dtype() { return NPY_COMPLEX128; }
-};
-
-// Performs a NumPy array cast from type 'From' to 'To'.
-template <typename From, typename To>
-void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
-             void* toarr) {
-  const auto* from =
-      reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
-  auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
-  for (npy_intp i = 0; i < n; ++i) {
-    to[i] =
-        static_cast<typename TypeDescriptor<To>::T>(static_cast<To>(from[i]));
-  }
-}
-
-// Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
-// type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be
-// safely coerced to T.
-template <typename T>
-bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
-  if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16,
-                               NPyCast<T, bfloat16>) < 0) {
-    return false;
-  }
-  if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
-                               NPyCast<bfloat16, T>) < 0) {
-    return false;
-  }
-  if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type,
-                                              NPY_NOSCALAR) < 0) {
-    return false;
-  }
-  return true;
-}
-
-template <typename InType, typename OutType, typename Functor>
-struct UnaryUFunc {
-  static std::vector<int> Types() {
-    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype()};
-  }
-  static void Call(char** args, const npy_intp* dimensions,
-                   const npy_intp* steps, void* data) {
-    const char* i0 = args[0];
-    char* o = args[1];
-    for (npy_intp k = 0; k < *dimensions; k++) {
-      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
-      *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) = Functor()(x);
-      i0 += steps[0];
-      o += steps[1];
-    }
-  }
-};
-
-template <typename InType, typename OutType, typename OutType2,
-          typename Functor>
-struct UnaryUFunc2 {
-  static std::vector<int> Types() {
-    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype(),
-            TypeDescriptor<OutType2>::Dtype()};
-  }
-  static void Call(char** args, const npy_intp* dimensions,
-                   const npy_intp* steps, void* data) {
-    const char* i0 = args[0];
-    char* o0 = args[1];
-    char* o1 = args[2];
-    for (npy_intp k = 0; k < *dimensions; k++) {
-      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
-      std::tie(*reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o0),
-               *reinterpret_cast<typename TypeDescriptor<OutType2>::T*>(o1)) =
-          Functor()(x);
-      i0 += steps[0];
-      o0 += steps[1];
-      o1 += steps[2];
-    }
-  }
-};
-
-template <typename InType, typename OutType, typename Functor>
-struct BinaryUFunc {
-  static std::vector<int> Types() {
-    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType>::Dtype(),
-            TypeDescriptor<OutType>::Dtype()};
-  }
-  static void Call(char** args, const npy_intp* dimensions,
-                   const npy_intp* steps, void* data) {
-    const char* i0 = args[0];
-    const char* i1 = args[1];
-    char* o = args[2];
-    for (npy_intp k = 0; k < *dimensions; k++) {
-      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
-      auto y = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i1);
-      *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
-          Functor()(x, y);
-      i0 += steps[0];
-      i1 += steps[1];
-      o += steps[2];
-    }
-  }
-};
-
-template <typename InType, typename InType2, typename OutType, typename Functor>
-struct BinaryUFunc2 {
-  static std::vector<int> Types() {
-    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType2>::Dtype(),
-            TypeDescriptor<OutType>::Dtype()};
-  }
-  static void Call(char** args, const npy_intp* dimensions,
-                   const npy_intp* steps, void* data) {
-    const char* i0 = args[0];
-    const char* i1 = args[1];
-    char* o = args[2];
-    for (npy_intp k = 0; k < *dimensions; k++) {
-      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
-      auto y =
-          *reinterpret_cast<const typename TypeDescriptor<InType2>::T*>(i1);
-      *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
-          Functor()(x, y);
-      i0 += steps[0];
-      i1 += steps[1];
-      o += steps[2];
-    }
-  }
-};
-
-template <typename UFunc>
-bool RegisterUFunc(PyObject* numpy, const char* name) {
-  std::vector<int> types = UFunc::Types();
-  PyUFuncGenericFunction fn =
-      reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
-  Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
-  if (!ufunc_obj) {
-    return false;
-  }
-  PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
-  if (static_cast<int>(types.size()) != ufunc->nargs) {
-    PyErr_Format(PyExc_AssertionError,
-                 "ufunc %s takes %d arguments, loop takes %lu", name,
-                 ufunc->nargs, types.size());
-    return false;
-  }
-  if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn,
-                                  const_cast<int*>(types.data()),
-                                  nullptr) < 0) {
-    return false;
-  }
-  return true;
-}
-
-namespace ufuncs {
-
-struct Add {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; }
-};
-struct Subtract {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; }
-};
-struct Multiply {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a * b; }
-};
-struct TrueDivide {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
-};
-
-std::pair<float, float> divmod(float a, float b) {
-  if (b == 0.0f) {
-    float nan = std::numeric_limits<float>::quiet_NaN();
-    return {nan, nan};
-  }
-  float mod = std::fmod(a, b);
-  float div = (a - mod) / b;
-  if (mod != 0.0f) {
-    if ((b < 0.0f) != (mod < 0.0f)) {
-      mod += b;
-      div -= 1.0f;
-    }
-  } else {
-    mod = std::copysign(0.0f, b);
-  }
-
-  float floordiv;
-  if (div != 0.0f) {
-    floordiv = std::floor(div);
-    if (div - floordiv > 0.5f) {
-      floordiv += 1.0f;
-    }
-  } else {
-    floordiv = std::copysign(0.0f, a / b);
-  }
-  return {floordiv, mod};
-}
-
-struct FloorDivide {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
-  }
-};
-struct Remainder {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(
-        divmod(static_cast<float>(a), static_cast<float>(b)).second);
-  }
-};
-struct DivmodUFunc {
-  static std::vector<int> Types() {
-    return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16};
-  }
-  static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
-                   void* data) {
-    const char* i0 = args[0];
-    const char* i1 = args[1];
-    char* o0 = args[2];
-    char* o1 = args[3];
-    for (npy_intp k = 0; k < *dimensions; k++) {
-      bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
-      bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
-      float floordiv, mod;
-      std::tie(floordiv, mod) =
-          divmod(static_cast<float>(x), static_cast<float>(y));
-      *reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
-      *reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
-      i0 += steps[0];
-      i1 += steps[1];
-      o0 += steps[2];
-      o1 += steps[3];
-    }
-  }
-};
-struct Fmod {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::fmod(static_cast<float>(a), static_cast<float>(b)));
-  }
-};
-struct Negative {
-  bfloat16 operator()(bfloat16 a) { return -a; }
-};
-struct Positive {
-  bfloat16 operator()(bfloat16 a) { return a; }
-};
-struct Power {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::pow(static_cast<float>(a), static_cast<float>(b)));
-  }
-};
-struct Abs {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::abs(static_cast<float>(a)));
-  }
-};
-struct Cbrt {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::cbrt(static_cast<float>(a)));
-  }
-};
-struct Ceil {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::ceil(static_cast<float>(a)));
-  }
-};
-struct CopySign {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(
-        std::copysign(static_cast<float>(a), static_cast<float>(b)));
-  }
-};
-struct Exp {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::exp(static_cast<float>(a)));
-  }
-};
-struct Exp2 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::exp2(static_cast<float>(a)));
-  }
-};
-struct Expm1 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::expm1(static_cast<float>(a)));
-  }
-};
-struct Floor {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::floor(static_cast<float>(a)));
-  }
-};
-struct Frexp {
-  std::pair<bfloat16, int> operator()(bfloat16 a) {
-    int exp;
-    float f = std::frexp(static_cast<float>(a), &exp);
-    return {bfloat16(f), exp};
-  }
-};
-struct Heaviside {
-  bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
-    float x = static_cast<float>(bx);
-    if (Eigen::numext::isnan(x)) {
-      return bx;
-    }
-    if (x < 0) {
-      return bfloat16(0.0f);
-    }
-    if (x > 0) {
-      return bfloat16(1.0f);
-    }
-    return h0;  // x == 0
-  }
-};
-struct Conjugate {
-  bfloat16 operator()(bfloat16 a) { return a; }
-};
-struct IsFinite {
-  bool operator()(bfloat16 a) { return std::isfinite(static_cast<float>(a)); }
-};
-struct IsInf {
-  bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
-};
-struct IsNan {
-  bool operator()(bfloat16 a) {
-    return Eigen::numext::isnan(static_cast<float>(a));
-  }
-};
-struct Ldexp {
-  bfloat16 operator()(bfloat16 a, int exp) {
-    return bfloat16(std::ldexp(static_cast<float>(a), exp));
-  }
-};
-struct Log {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log(static_cast<float>(a)));
-  }
-};
-struct Log2 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log2(static_cast<float>(a)));
-  }
-};
-struct Log10 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log10(static_cast<float>(a)));
-  }
-};
-struct Log1p {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log1p(static_cast<float>(a)));
-  }
-};
-struct LogAddExp {
-  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
-    float x = static_cast<float>(bx);
-    float y = static_cast<float>(by);
-    if (x == y) {
-      // Handles infinities of the same sign.
-      return bfloat16(x + std::log(2.0f));
-    }
-    float out = std::numeric_limits<float>::quiet_NaN();
-    if (x > y) {
-      out = x + std::log1p(std::exp(y - x));
-    } else if (x < y) {
-      out = y + std::log1p(std::exp(x - y));
-    }
-    return bfloat16(out);
-  }
-};
-struct LogAddExp2 {
-  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
-    float x = static_cast<float>(bx);
-    float y = static_cast<float>(by);
-    if (x == y) {
-      // Handles infinities of the same sign.
-      return bfloat16(x + 1.0f);
-    }
-    float out = std::numeric_limits<float>::quiet_NaN();
-    if (x > y) {
-      out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
-    } else if (x < y) {
-      out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
-    }
-    return bfloat16(out);
-  }
-};
-struct Modf {
-  std::pair<bfloat16, bfloat16> operator()(bfloat16 a) {
-    float integral;
-    float f = std::modf(static_cast<float>(a), &integral);
-    return {bfloat16(f), bfloat16(integral)};
-  }
-};
-
-struct Reciprocal {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(1.f / static_cast<float>(a));
-  }
-};
-struct Rint {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::rint(static_cast<float>(a)));
-  }
-};
-struct Sign {
-  bfloat16 operator()(bfloat16 a) {
-    float f(a);
-    if (f < 0) {
-      return bfloat16(-1);
-    }
-    if (f > 0) {
-      return bfloat16(1);
-    }
-    return a;
-  }
-};
-struct SignBit {
-  bool operator()(bfloat16 a) { return std::signbit(static_cast<float>(a)); }
-};
-struct Sqrt {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::sqrt(static_cast<float>(a)));
-  }
-};
-struct Square {
-  bfloat16 operator()(bfloat16 a) {
-    float f(a);
-    return bfloat16(f * f);
-  }
-};
-struct Trunc {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::trunc(static_cast<float>(a)));
-  }
-};
-
-// Trigonometric functions
-struct Sin {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::sin(static_cast<float>(a)));
-  }
-};
-struct Cos {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::cos(static_cast<float>(a)));
-  }
-};
-struct Tan {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::tan(static_cast<float>(a)));
-  }
-};
-struct Arcsin {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::asin(static_cast<float>(a)));
-  }
-};
-struct Arccos {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::acos(static_cast<float>(a)));
-  }
-};
-struct Arctan {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::atan(static_cast<float>(a)));
-  }
-};
-struct Arctan2 {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::atan2(static_cast<float>(a), static_cast<float>(b)));
-  }
-};
-struct Hypot {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::hypot(static_cast<float>(a), static_cast<float>(b)));
-  }
-};
-struct Sinh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::sinh(static_cast<float>(a)));
-  }
-};
-struct Cosh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::cosh(static_cast<float>(a)));
-  }
-};
-struct Tanh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::tanh(static_cast<float>(a)));
-  }
-};
-struct Arcsinh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::asinh(static_cast<float>(a)));
-  }
-};
-struct Arccosh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::acosh(static_cast<float>(a)));
-  }
-};
-struct Arctanh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::atanh(static_cast<float>(a)));
-  }
-};
-struct Deg2rad {
-  bfloat16 operator()(bfloat16 a) {
-    static constexpr float radians_per_degree = M_PI / 180.0f;
-    return bfloat16(static_cast<float>(a) * radians_per_degree);
-  }
-};
-struct Rad2deg {
-  bfloat16 operator()(bfloat16 a) {
-    static constexpr float degrees_per_radian = 180.0f / M_PI;
-    return bfloat16(static_cast<float>(a) * degrees_per_radian);
-  }
-};
-
-struct Eq {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
-};
-struct Ne {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
-};
-struct Lt {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
-};
-struct Gt {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
-};
-struct Le {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
-};
-struct Ge {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
-};
-struct Maximum {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    float fa(a), fb(b);
-    return Eigen::numext::isnan(fa) || fa > fb ? a : b;
-  }
-};
-struct Minimum {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    float fa(a), fb(b);
-    return Eigen::numext::isnan(fa) || fa < fb ? a : b;
-  }
-};
-struct Fmax {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    float fa(a), fb(b);
-    return Eigen::numext::isnan(fb) || fa > fb ? a : b;
-  }
-};
-struct Fmin {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    float fa(a), fb(b);
-    return Eigen::numext::isnan(fb) || fa < fb ? a : b;
-  }
-};
-
-struct LogicalNot {
-  npy_bool operator()(bfloat16 a) { return !a; }
-};
-struct LogicalAnd {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; }
-};
-struct LogicalOr {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; }
-};
-struct LogicalXor {
-  npy_bool operator()(bfloat16 a, bfloat16 b) {
-    return static_cast<bool>(a) ^ static_cast<bool>(b);
-  }
-};
-
-struct NextAfter {
-  bfloat16 operator()(bfloat16 from, bfloat16 to) {
-    uint16_t from_as_int, to_as_int;
-    const uint16_t sign_mask = 1 << 15;
-    float from_as_float(from), to_as_float(to);
-    memcpy(&from_as_int, &from, sizeof(bfloat16));
-    memcpy(&to_as_int, &to, sizeof(bfloat16));
-    if (Eigen::numext::isnan(from_as_float) ||
-        Eigen::numext::isnan(to_as_float)) {
-      return bfloat16(std::numeric_limits<float>::quiet_NaN());
-    }
-    if (from_as_int == to_as_int) {
-      return to;
-    }
-    if (from_as_float == 0) {
-      if (to_as_float == 0) {
-        return to;
-      } else {
-        // Smallest subnormal signed like `to`.
-        uint16_t out_int = (to_as_int & sign_mask) | 1;
-        bfloat16 out;
-        memcpy(&out, &out_int, sizeof(bfloat16));
-        return out;
-      }
-    }
-    uint16_t from_sign = from_as_int & sign_mask;
-    uint16_t to_sign = to_as_int & sign_mask;
-    uint16_t from_abs = from_as_int & ~sign_mask;
-    uint16_t to_abs = to_as_int & ~sign_mask;
-    uint16_t magnitude_adjustment =
-        (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
-    uint16_t out_int = from_as_int + magnitude_adjustment;
-    bfloat16 out;
-    memcpy(&out, &out_int, sizeof(bfloat16));
-    return out;
-  }
-};
-
-// TODO(phawkins): implement spacing
-
-}  // namespace ufuncs
-
-}  // namespace
-
-// Initializes the module.
-bool Initialize() {
-  import_array1(false);
-  import_umath1(false);
-
-  Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
-  if (!numpy_str) {
-    return false;
-  }
-  Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
-  if (!numpy) {
-    return false;
-  }
-
-  PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
-
-  if (PyType_Ready(&PyBfloat16_Type) < 0) {
-    return false;
-  }
-
-  // Initializes the NumPy descriptor.
-  PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
-  NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
-  NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
-  NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
-  NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
-  NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
-  NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
-  NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
-  NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc;
-  NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc;
-  NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc;
-  NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc;
-
-  Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
-  npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr);
-  if (npy_bfloat16 < 0) {
-    return false;
-  }
-
-  // Support dtype(bfloat16)
-  if (PyDict_SetItemString(PyBfloat16_Type.tp_dict, "dtype",
-                           reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
-      0) {
-    return false;
-  }
-
-  // Register casts
-  if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<bool>(NPY_BOOL, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<uint8>(NPY_UINT8, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG,  // NOLINT
-                                           /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned long long>(  // NOLINT
-          NPY_ULONGLONG, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<int8>(NPY_INT8, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<long>(NPY_LONG,  // NOLINT
-                                  /*cast_is_safe=*/false)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<long long>(  // NOLINT
-          NPY_LONGLONG, /*cast_is_safe=*/false)) {
-    return false;
-  }
-  // Following the numpy convention. imag part is dropped when converting to
-  // float.
-  if (!RegisterBfloat16Cast<complex64>(NPY_COMPLEX64, /*cast_is_safe=*/true)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<complex128>(NPY_COMPLEX128,
-                                        /*cast_is_safe=*/true)) {
-    return false;
-  }
-
-  bool ok =
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Add>>(numpy.get(),
-                                                                  "add") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Subtract>>(
-          numpy.get(), "subtract") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Multiply>>(
-          numpy.get(), "multiply") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
-          numpy.get(), "divide") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp>>(
-          numpy.get(), "logaddexp") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp2>>(
-          numpy.get(), "logaddexp2") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Negative>>(
-          numpy.get(), "negative") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Positive>>(
-          numpy.get(), "positive") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
-          numpy.get(), "true_divide") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::FloorDivide>>(
-          numpy.get(), "floor_divide") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Power>>(numpy.get(),
-                                                                    "power") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
-          numpy.get(), "remainder") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
-          numpy.get(), "mod") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmod>>(numpy.get(),
-                                                                   "fmod") &&
-      RegisterUFunc<ufuncs::DivmodUFunc>(numpy.get(), "divmod") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
-                                                                 "absolute") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
-                                                                 "fabs") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rint>>(numpy.get(),
-                                                                  "rint") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sign>>(numpy.get(),
-                                                                  "sign") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Heaviside>>(
-          numpy.get(), "heaviside") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Conjugate>>(
-          numpy.get(), "conjugate") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp>>(numpy.get(),
-                                                                 "exp") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp2>>(numpy.get(),
-                                                                  "exp2") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Expm1>>(numpy.get(),
-                                                                   "expm1") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log>>(numpy.get(),
-                                                                 "log") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log2>>(numpy.get(),
-                                                                  "log2") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log10>>(numpy.get(),
-                                                                   "log10") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log1p>>(numpy.get(),
-                                                                   "log1p") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sqrt>>(numpy.get(),
-                                                                  "sqrt") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Square>>(numpy.get(),
-                                                                    "square") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cbrt>>(numpy.get(),
-                                                                  "cbrt") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Reciprocal>>(
-          numpy.get(), "reciprocal") &&
-
-      // Trigonometric functions
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sin>>(numpy.get(),
-                                                                 "sin") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cos>>(numpy.get(),
-                                                                 "cos") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tan>>(numpy.get(),
-                                                                 "tan") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsin>>(numpy.get(),
-                                                                    "arcsin") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccos>>(numpy.get(),
-                                                                    "arccos") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctan>>(numpy.get(),
-                                                                    "arctan") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Arctan2>>(
-          numpy.get(), "arctan2") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Hypot>>(numpy.get(),
-                                                                    "hypot") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sinh>>(numpy.get(),
-                                                                  "sinh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cosh>>(numpy.get(),
-                                                                  "cosh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tanh>>(numpy.get(),
-                                                                  "tanh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsinh>>(
-          numpy.get(), "arcsinh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccosh>>(
-          numpy.get(), "arccosh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctanh>>(
-          numpy.get(), "arctanh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Deg2rad>>(
-          numpy.get(), "deg2rad") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rad2deg>>(
-          numpy.get(), "rad2deg") &&
-
-      // Comparison functions
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Eq>>(numpy.get(),
-                                                             "equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ne>>(numpy.get(),
-                                                             "not_equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Lt>>(numpy.get(),
-                                                             "less") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Gt>>(numpy.get(),
-                                                             "greater") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Le>>(numpy.get(),
-                                                             "less_equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ge>>(numpy.get(),
-                                                             "greater_equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Maximum>>(
-          numpy.get(), "maximum") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Minimum>>(
-          numpy.get(), "minimum") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmax>>(numpy.get(),
-                                                                   "fmax") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmin>>(numpy.get(),
-                                                                   "fmin") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalAnd>>(
-          numpy.get(), "logical_and") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalOr>>(
-          numpy.get(), "logical_or") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalXor>>(
-          numpy.get(), "logical_xor") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::LogicalNot>>(
-          numpy.get(), "logical_not") &&
-
-      // Floating point functions
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsFinite>>(numpy.get(),
-                                                                  "isfinite") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsInf>>(numpy.get(),
-                                                               "isinf") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsNan>>(numpy.get(),
-                                                               "isnan") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::SignBit>>(numpy.get(),
-                                                                 "signbit") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::CopySign>>(
-          numpy.get(), "copysign") &&
-      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, bfloat16, ufuncs::Modf>>(
-          numpy.get(), "modf") &&
-      RegisterUFunc<BinaryUFunc2<bfloat16, int, bfloat16, ufuncs::Ldexp>>(
-          numpy.get(), "ldexp") &&
-      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, int, ufuncs::Frexp>>(
-          numpy.get(), "frexp") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Floor>>(numpy.get(),
-                                                                   "floor") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
-                                                                  "ceil") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
-                                                                   "trunc") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
-          numpy.get(), "nextafter");
-
-  return ok;
-}
-
-StatusOr<py::object> Bfloat16Dtype() {
-  if (npy_bfloat16 < 0) {
-    // Not yet initialized. We assume the GIL protects npy_bfloat16.
-    if (!Initialize()) {
-      return InternalError("Bfloat16 numpy type initialization failed.");
-    }
-  }
-  return py::object(reinterpret_cast<PyObject*>(&PyBfloat16_Type),
-                    /*is_borrowed=*/true);
-}
-
-}  // namespace xla
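
For context: the floor_divide, remainder and divmod ufuncs removed here all route through the file's float-level divmod helper, which follows Python's convention (the remainder takes the sign of the divisor and the quotient is floored rather than truncated), presumably provided after this change via the //tensorflow/python:bfloat16_lib dependency added in the BUILD diff above. A self-contained restatement of that helper, written here purely for illustration:

// Illustrative only -- not part of this diff. Restates the Python-style
// divmod used by the deleted bfloat16 ufuncs so the sign handling can be
// seen in isolation.
#include <cmath>
#include <cstdio>
#include <limits>
#include <utility>

std::pair<float, float> PyStyleDivmod(float a, float b) {
  if (b == 0.0f) {
    float nan = std::numeric_limits<float>::quiet_NaN();
    return {nan, nan};
  }
  float mod = std::fmod(a, b);
  float div = (a - mod) / b;
  if (mod != 0.0f) {
    // Unlike C's fmod, the remainder takes the sign of the divisor; the
    // quotient is decremented to compensate.
    if ((b < 0.0f) != (mod < 0.0f)) {
      mod += b;
      div -= 1.0f;
    }
  } else {
    mod = std::copysign(0.0f, b);
  }
  float floordiv;
  if (div != 0.0f) {
    floordiv = std::floor(div);
    if (div - floordiv > 0.5f) {
      floordiv += 1.0f;
    }
  } else {
    floordiv = std::copysign(0.0f, a / b);
  }
  return {floordiv, mod};
}

int main() {
  // C-style fmod(7, -2) == 1, whereas Python-style divmod(7, -2) == (-4, -1).
  auto r = PyStyleDivmod(7.0f, -2.0f);
  std::printf("fmod: %g, divmod: (%g, %g)\n", std::fmod(7.0f, -2.0f), r.first,
              r.second);
  return 0;
}
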
diff --git a/tensorflow/compiler/xla/python/bfloat16.h b/tensorflow/compiler/xla/python/bfloat16.h
deleted file mode 100644
index 9e52d086..0000000
--- a/tensorflow/compiler/xla/python/bfloat16.h
+++ /dev/null
@@ -1,28 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
-#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
-
-#include "pybind11/pybind11.h"
-#include "tensorflow/compiler/xla/statusor.h"
-
-namespace xla {
-
-xla::StatusOr<pybind11::object> Bfloat16Dtype();
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
diff --git a/tensorflow/compiler/xla/python/bfloat16_test.py b/tensorflow/compiler/xla/python/bfloat16_test.py
deleted file mode 100644
index 4c7321a..0000000
--- a/tensorflow/compiler/xla/python/bfloat16_test.py
+++ /dev/null
@@ -1,440 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test cases for the bfloat16 Python type."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import copy
-import itertools
-import math
-
-from absl.testing import absltest
-from absl.testing import parameterized
-
-import numpy as np
-
-from tensorflow.compiler.xla.python import xla_client
-
-bfloat16 = xla_client.bfloat16
-
-
-def numpy_assert_allclose(a, b, **kwargs):
-  a = a.astype(np.float32) if a.dtype == bfloat16 else a
-  b = b.astype(np.float32) if b.dtype == bfloat16 else b
-  return np.testing.assert_allclose(a, b, **kwargs)
-
-
-epsilon = float.fromhex("1.0p-7")
-
-# Values that should round trip exactly to float and back.
-FLOAT_VALUES = [
-    0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-    -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
-    float("inf"),
-    float("-inf"),
-    float("nan")
-]
-
-
-class Bfloat16Test(parameterized.TestCase):
-  """Tests the non-numpy Python methods of the bfloat16 type."""
-
-  def testRoundTripToFloat(self):
-    for v in FLOAT_VALUES:
-      np.testing.assert_equal(v, float(bfloat16(v)))
-
-  def testRoundTripNumpyTypes(self):
-    for dtype in [np.float16, np.float32, np.float64]:
-      np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
-      np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
-      np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
-      np.testing.assert_equal(
-          np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
-
-  def testRoundTripToInt(self):
-    for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
-      self.assertEqual(v, int(bfloat16(v)))
-
-  # pylint: disable=g-complex-comprehension
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + dtype.__name__,
-      "dtype": dtype
-  } for dtype in [bfloat16, np.float16, np.float32, np.float64]))
-  def testRoundTripToNumpy(self, dtype):
-    for v in FLOAT_VALUES:
-      np.testing.assert_equal(v, bfloat16(dtype(v)))
-      np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
-      np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
-    if dtype != bfloat16:
-      np.testing.assert_equal(
-          np.array(FLOAT_VALUES, dtype),
-          bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
-
-  def testStr(self):
-    self.assertEqual("0", str(bfloat16(0.0)))
-    self.assertEqual("1", str(bfloat16(1.0)))
-    self.assertEqual("-3.5", str(bfloat16(-3.5)))
-    self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
-    self.assertEqual("inf", str(bfloat16(float("inf"))))
-    self.assertEqual("-inf", str(bfloat16(float("-inf"))))
-    self.assertEqual("nan", str(bfloat16(float("nan"))))
-
-  def testRepr(self):
-    self.assertEqual("0", repr(bfloat16(0)))
-    self.assertEqual("1", repr(bfloat16(1)))
-    self.assertEqual("-3.5", repr(bfloat16(-3.5)))
-    self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
-    self.assertEqual("inf", repr(bfloat16(float("inf"))))
-    self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
-    self.assertEqual("nan", repr(bfloat16(float("nan"))))
-
-  def testHash(self):
-    self.assertEqual(0, hash(bfloat16(0.0)))
-    self.assertEqual(0x3f80, hash(bfloat16(1.0)))
-    self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
-
-  # Tests for Python operations
-  def testNegate(self):
-    for v in FLOAT_VALUES:
-      np.testing.assert_equal(-v, float(-bfloat16(v)))
-
-  def testAdd(self):
-    np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
-    np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
-    np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
-    np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
-    np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
-
-    # Test type promotion against Numpy scalar values.
-    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
-    self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
-    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
-    self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
-    self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
-    self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
-    self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
-    self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
-    self.assertEqual(np.float32,
-                     type(bfloat16(3.5) + np.array(2.25, np.float32)))
-    self.assertEqual(np.float32,
-                     type(np.array(3.5, np.float32) + bfloat16(2.25)))
-
-  def testSub(self):
-    np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
-    np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
-    np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
-    np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
-    np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
-
-  def testMul(self):
-    np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
-    np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
-    np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
-    np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
-
-  def testDiv(self):
-    self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
-    np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
-    np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
-    np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
-
-  def testLess(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
-
-  def testLessEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
-
-  def testGreater(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
-
-  def testGreaterEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
-
-  def testEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
-
-  def testNotEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
-
-  def testNan(self):
-    a = np.isnan(bfloat16(float("nan")))
-    self.assertTrue(a)
-    numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
-
-    a = np.array([bfloat16(1.34375),
-                  bfloat16(1.4375),
-                  bfloat16(float("nan"))],
-                 dtype=bfloat16)
-    b = np.array(
-        [bfloat16(1.3359375),
-         bfloat16(1.4375),
-         bfloat16(float("nan"))],
-        dtype=bfloat16)
-    numpy_assert_allclose(
-        a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
-
-  def testSort(self):
-    values_to_sort = np.float32(FLOAT_VALUES)
-    sorted_f32 = np.sort(values_to_sort)
-    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
-    np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
-
-
-BinaryOp = collections.namedtuple("BinaryOp", ["op"])
-
-UNARY_UFUNCS = [
-    np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
-    np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
-    np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
-    np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
-    np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
-]
-
-BINARY_UFUNCS = [
-    np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
-    np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
-    np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
-]
-
-BINARY_PREDICATE_UFUNCS = [
-    np.equal, np.not_equal, np.less, np.greater, np.less_equal,
-    np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
-]
-
-
-class Bfloat16NumPyTest(parameterized.TestCase):
-  """Tests the NumPy integration of the bfloat16 type."""
-
-  def testDtype(self):
-    self.assertEqual(bfloat16, np.dtype(bfloat16))
-
-  def testDeepCopyDoesNotAlterHash(self):
-    # For context, see https://github.com/google/jax/issues/4651. If the hash
-    # value of the type descriptor is not initialized correctly, a deep copy
-    # can change the type hash.
-    dtype = np.dtype(bfloat16)
-    h = hash(dtype)
-    _ = copy.deepcopy(dtype)
-    self.assertEqual(h, hash(dtype))
-
-  def testArray(self):
-    x = np.array([[1, 2, 3]], dtype=bfloat16)
-    self.assertEqual(bfloat16, x.dtype)
-    self.assertEqual("[[1 2 3]]", str(x))
-    np.testing.assert_equal(x, x)
-    numpy_assert_allclose(x, x)
-    self.assertTrue((x == x).all())
-
-  def testComparisons(self):
-    x = np.array([401408, 7, -32], dtype=np.float32)
-    bx = x.astype(bfloat16)
-    y = np.array([82432, 7, 0], dtype=np.float32)
-    by = y.astype(bfloat16)
-    np.testing.assert_equal(x == y, bx == by)
-    np.testing.assert_equal(x != y, bx != by)
-    np.testing.assert_equal(x < y, bx < by)
-    np.testing.assert_equal(x > y, bx > by)
-    np.testing.assert_equal(x <= y, bx <= by)
-    np.testing.assert_equal(x >= y, bx >= by)
-
-  def testEqual2(self):
-    a = np.array([401408], bfloat16)
-    b = np.array([82432], bfloat16)
-    self.assertFalse(a.__eq__(b))
-
-  def testCasts(self):
-    for dtype in [
-        np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
-        np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
-        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
-    ]:
-      x = np.array([[1, 2, 3]], dtype=dtype)
-      y = x.astype(bfloat16)
-      z = y.astype(dtype)
-      self.assertTrue(np.all(x == y))
-      self.assertEqual(bfloat16, y.dtype)
-      self.assertTrue(np.all(x == z))
-      self.assertEqual(dtype, z.dtype)
-
-  def testConformNumpyComplex(self):
-    for dtype in [np.complex64, np.complex128]:
-      x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
-      y_np = x.astype(np.float32)
-      y_tf = x.astype(bfloat16)
-      numpy_assert_allclose(y_np, y_tf, atol=2e-2)
-
-      z_np = y_np.astype(dtype)
-      z_tf = y_tf.astype(dtype)
-      numpy_assert_allclose(z_np, z_tf, atol=2e-2)
-
-  def testArange(self):
-    np.testing.assert_equal(
-        np.arange(100, dtype=np.float32).astype(bfloat16),
-        np.arange(100, dtype=bfloat16))
-    np.testing.assert_equal(
-        np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
-        np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
-    np.testing.assert_equal(
-        np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
-        np.arange(-0., -7., -0.25, dtype=bfloat16))
-    np.testing.assert_equal(
-        np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
-        np.arange(-16384., 16384., 64., dtype=bfloat16))
-
-  # pylint: disable=g-complex-comprehension
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in UNARY_UFUNCS))
-  def testUnaryUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7, 10).astype(bfloat16)
-    numpy_assert_allclose(
-        op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
-
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in BINARY_UFUNCS))
-  def testBinaryUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7, 10).astype(bfloat16)
-    y = rng.randn(4, 1, 7, 10).astype(bfloat16)
-    numpy_assert_allclose(
-        op(x, y).astype(np.float32),
-        op(x.astype(np.float32), y.astype(np.float32)),
-        rtol=1e-2)
-
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in BINARY_PREDICATE_UFUNCS))
-  def testBinaryPredicateUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    y = rng.randn(4, 1, 7).astype(bfloat16)
-    np.testing.assert_equal(
-        op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
-
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
-  def testPredicateUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    shape = (3, 7, 10)
-    posinf_flips = rng.rand(*shape) < 0.1
-    neginf_flips = rng.rand(*shape) < 0.1
-    nan_flips = rng.rand(*shape) < 0.1
-    vals = rng.randn(*shape)
-    vals = np.where(posinf_flips, np.inf, vals)
-    vals = np.where(neginf_flips, -np.inf, vals)
-    vals = np.where(nan_flips, np.nan, vals)
-    vals = vals.astype(bfloat16)
-    np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
-
-  def testDivmod(self):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    y = rng.randn(4, 1, 7).astype(bfloat16)
-    o1, o2 = np.divmod(x, y)
-    e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
-    numpy_assert_allclose(o1, e1, rtol=1e-2)
-    numpy_assert_allclose(o2, e2, rtol=1e-2)
-
-  def testModf(self):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    o1, o2 = np.modf(x)
-    e1, e2 = np.modf(x.astype(np.float32))
-    numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
-    numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
-
-  def testLdexp(self):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    y = rng.randint(-50, 50, (1, 7))
-    numpy_assert_allclose(
-        np.ldexp(x, y).astype(np.float32),
-        np.ldexp(x.astype(np.float32), y),
-        rtol=1e-2,
-        atol=1e-6)
-
-  def testFrexp(self):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    mant1, exp1 = np.frexp(x)
-    mant2, exp2 = np.frexp(x.astype(np.float32))
-    np.testing.assert_equal(exp1, exp2)
-    numpy_assert_allclose(mant1, mant2, rtol=1e-2)
-
-  def testNextAfter(self):
-    one = np.array(1., dtype=bfloat16)
-    two = np.array(2., dtype=bfloat16)
-    zero = np.array(0., dtype=bfloat16)
-    nan = np.array(np.nan, dtype=bfloat16)
-    np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
-    np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
-    np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
-    np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
-    np.testing.assert_equal(np.nextafter(one, one), one)
-    smallest_denormal = float.fromhex("1.0p-133")
-    np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
-    np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
-    for a, b in itertools.permutations([0., -0., nan], 2):
-      np.testing.assert_equal(
-          np.nextafter(
-              np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
-          np.nextafter(
-              np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
-
-
-if __name__ == "__main__":
-  absltest.main()
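The deleted test file documents the Python-level semantics of the bfloat16 scalar type. The checks below are a condensed sketch of what it covered, using the same import path the test used; they are not a replacement for the test:

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    bfloat16 = xla_client.bfloat16

    # Exact round trips for representable values, including the 2**-7 epsilon.
    assert float(bfloat16(-3.5)) == -3.5
    assert str(bfloat16(float.fromhex("1.0p-7"))) == "0.0078125"

    # Mixed arithmetic never stays bfloat16: it promotes to a wider NumPy float.
    assert type(bfloat16(3.5) + np.float16(2.25)) is np.float32
    assert type(bfloat16(3.5) + np.float64(2.25)) is np.float64

    # Casting through bfloat16 and back is exact for small integers.
    x = np.array([[1, 2, 3]], dtype=np.float32)
    assert np.all(x == x.astype(bfloat16).astype(np.float32))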
diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc
index 8525225..a673584 100644
--- a/tensorflow/compiler/xla/python/dlpack.cc
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -25,6 +25,7 @@
 #include "include/dlpack/dlpack.h"  // from @dlpack
 #include "pybind11/pytypes.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 #include "tensorflow/compiler/xla/python/traceback.h"
@@ -45,15 +46,16 @@
 struct DLPackTensor {
   ~DLPackTensor();
 
-  // At most one of buffer and buffer_reference/scoped_hold is populated.
+  // At most one of owned_buffer and buffer_reference/external_reference_hold is
+  // populated.
 
-  // `buffer` is populated if we have exclusive (read-write) access.
-  std::shared_ptr<TrackedDeviceBuffer> buffer;
+  // `owned_buffer` is populated if we have exclusive (read-write) access.
+  std::shared_ptr<void> owned_buffer;
 
-  // `buffer_reference` and `scoped_hold` are populated if we have
+  // `buffer_reference` and `external_reference_hold` are populated if we have
   // shared (read-only) access.
   py::object buffer_reference;
-  absl::optional<PjRtBuffer::ScopedHold> scoped_hold;
+  std::unique_ptr<PjRtBuffer::ExternalReferenceHold> external_reference_hold;
 
   std::vector<int64> shape;
   std::vector<int64> strides;
@@ -214,11 +216,9 @@
 }
 
 StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
-  const se::Platform* platform =
-      device.local_device_state()->executor()->platform();
-  if (platform->id() == se::host::kHostPlatformId) {
+  if (device.client()->platform_id() == kCpuId) {
     return kDLCPU;
-  } else if (platform->id() == se::cuda::kCudaPlatformId) {
+  } else if (device.client()->platform_id() == kGpuId) {
     return kDLGPU;
   }
   return InvalidArgument("Device %s cannot be used as a DLPack device.",
@@ -228,7 +228,7 @@
 StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
   DLContext context;
   TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
-  context.device_id = device.local_device_id();
+  context.device_id = device.local_hardware_id();
   return context;
 }
 
@@ -241,14 +241,14 @@
             "DLPack CPU device type mismatch with PjRtClient platform %s",
             client.platform_name());
       }
-      return client.LookupLocalDevice(context.device_id);
+      return client.LookupAddressableDevice(context.device_id);
     case kDLGPU:
       if (client.platform_id() != kGpuId) {
         return InvalidArgument(
             "DLPack GPU device type mismatch with PjRtClient platform %s",
             client.platform_name());
       }
-      return client.LookupLocalDevice(context.device_id);
+      return client.LookupAddressableDevice(context.device_id);
     default:
       return InvalidArgument("Unknown/unsupported DLPack device type %d",
                              context.device_type);
@@ -271,33 +271,35 @@
   if (take_ownership) {
     // Block on outstanding operations, so that it is safe to read or mutate the
     // returned buffer.
-    StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or =
-        buffer->buffer()->Release(/*wait_for_operations_to_complete=*/true);
+    StatusOr<absl::optional<std::shared_ptr<void>>> buffer_or =
+        buffer->buffer()->ReleaseDeviceMemoryOwnership(
+            /*wait_for_operations_to_complete=*/true);
     if (!buffer_or.ok()) {
       return InvalidArgument(
           "Buffer synchronization failed converting to DLPack tensor: %s",
           buffer_or.status().ToString());
     }
-    pack->buffer = buffer_or.ConsumeValueOrDie();
-    if (!pack->buffer) {
+    absl::optional<std::shared_ptr<void>> owned_buffer_opt =
+        buffer_or.ConsumeValueOrDie();
+    if (!owned_buffer_opt.has_value()) {
       return InvalidArgument(
           "Cannot convert deleted/invalid buffer to DLPack tensor.");
     }
-    TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
-    dt.data = pack->buffer->device_memory().front().opaque();
+    pack->owned_buffer = owned_buffer_opt.value();
+    dt.data = pack->owned_buffer.get();
   } else {
     // Block on outstanding operations, so that it is safe to read or mutate the
     // returned buffer.
     TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
     pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
-    pack->scoped_hold.emplace(
-        buffer->buffer()->GetBufferWithExternalReference());
-    dt.data = pack->scoped_hold->buffer()->device_memory().front().opaque();
+    TF_ASSIGN_OR_RETURN(pack->external_reference_hold,
+                        buffer->buffer()->AcquireExternalReference());
+    dt.data = pack->external_reference_hold->OpaqueDeviceMemoryDataPointer();
   }
   pack->tensor.manager_ctx = pack.get();
   pack->tensor.deleter = DLPackTensorDeleter;
   TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
-  dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
+  dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
   dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
   TF_ASSIGN_OR_RETURN(dt.dtype,
                       PrimitiveTypeToDLDataType(
@@ -379,7 +381,10 @@
   // capsule it cannot be used again.
   PyCapsule_SetName(tensor.ptr(), "used_dltensor");
   PyCapsule_SetDestructor(tensor.ptr(), nullptr);
-  auto pjrt_buffer = std::make_unique<PjRtBuffer>(
+  // TODO(zhangqiaorjc): Add a factory method that avoids StreamExecutor
+  // specifics. The challenge may be what generic data structures to use for
+  // definition events.
+  auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
       shape, shape, std::move(device_buffer), client->pjrt_client(), device);
   return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
                                     Traceback::Get());
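The dlpack.cc change replaces the StreamExecutor-specific buffer hold with two explicit modes: releasing device-memory ownership outright, or acquiring a shared ExternalReferenceHold. A rough Python-side illustration of that distinction using the jax.dlpack wrappers follows; the take_ownership keyword mirrors the C++ flag, and the exact Python signature is an assumption of this sketch, not something introduced by this diff:

    import jax
    import jax.dlpack
    import jax.numpy as jnp

    x = jnp.arange(8, dtype=jnp.float32)

    # Shared (read-only) export: the buffer keeps an external reference hold,
    # so `x` stays usable after the capsule is created.
    capsule = jax.dlpack.to_dlpack(x, take_ownership=False)

    # Exclusive hand-off: device memory ownership is released to the consumer
    # and the source buffer is invalid afterwards.
    y = jnp.arange(8, dtype=jnp.float32)
    moved = jax.dlpack.to_dlpack(y, take_ownership=True)

    # Importing a capsule consumes it (its name is set to "used_dltensor").
    z = jax.dlpack.from_dlpack(capsule)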
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index 7d216e0..0c6d2078 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -57,115 +57,41 @@
 
 // TODO(phawkins): Add support for Tracers.
 // TODO(jblespiau): Use absl Status.
+// TODO(jblespiau): Remove the "xla::" prefixes when not needed.
 
-namespace {
-
-thread_local bool disable_jit;
-void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; }
-bool GetDisableJit() { return disable_jit; }
-
-// Describes the abstract shape and dtype of an argument.
-struct ArgSignature {
-  // This is the XLA dtype of the object.
-  xla::PrimitiveType dtype;
-  // JAX arguments can be of weak type, if and only if they are Python scalars
-  // or `DeviceArray` values such that `aval.weak_type` is true.
-  bool weak_type;
-  absl::InlinedVector<int64, 4> shape;
-  bool operator==(const ArgSignature& other) const {
-    return std::tie(dtype, weak_type, shape) ==
-           std::tie(other.dtype, other.weak_type, other.shape);
+std::string ArgSignature::DebugString() const {
+  std::string result = "";
+  if (weak_type) {
+    absl::StrAppend(&result, "weak_");
   }
-  bool operator!=(const ArgSignature& other) const { return !(*this == other); }
-
-  std::string DebugString() const {
-    std::string result = "";
-    if (weak_type) {
-      absl::StrAppend(&result, "weak_");
-    }
-    absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
-    absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
-    return result;
-  }
-};
-
-template <typename H>
-H AbslHashValue(H h, const ArgSignature& s) {
-  h = H::combine(std::move(h), s.dtype);
-  h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
-  return h;
+  absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
+  absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
+  return result;
 }
 
-// The signature of Python jitted function call, partitioned into:
-// - dynamic positional arguments (i.e. positional args which are not static)
-// - static positional arguments (i.e. the args associated to static_argnums)
-// - keyword arguments
-// The CallSignature should unambiguously identify a function call, thus,
-// equality is based on:
-// (a) Same PyTree for all dynamic positional arguments and keyword arguments
-// (a) equality of the arguments and keyword arguments ArgSignature
-// (a) equality (delegated to Python) of the static arguments.
-struct CallSignature {
-  struct KwargEntry {
-    // To avoid comparing strings, we intern the kwargs strings.
-    // The compilation cache holds a reference to all the keys.
-    py::handle key;
-    PyTreeDef value_treedef;
-    bool operator==(const KwargEntry& other) const {
-      return key.ptr() == other.key.ptr() &&
-             value_treedef == other.value_treedef;
-    }
-    bool operator!=(const KwargEntry& other) const { return !(*this == other); }
-  };
-
-  // Only contains the arguments associated to `static_argnums`, sorted in the
-  // order of their argnum index.
-  std::vector<py::object> static_args;
-  // A PyTreeDef for each positional dynamic (i.e. not static) argument.
-  std::vector<PyTreeDef> dynamic_positional_args_treedef;
-  // Keyword arguments. Sorted by the keyword name.
-  std::vector<KwargEntry> keyword_args;
-  // Shape and dtype for both the dynamic positional arguments and the keyword
-  // arguments (sorted by keyword name).
-  std::vector<ArgSignature> dynamic_args_signatures;
-  PjRtDevice* device;
-
-  bool operator==(const CallSignature& other) const {
-    return std::tie(dynamic_positional_args_treedef, keyword_args,
-                    dynamic_args_signatures, device) ==
-               std::tie(other.dynamic_positional_args_treedef,
-                        other.keyword_args, other.dynamic_args_signatures,
-                        other.device) &&
-           // `==` on py:objects is the Python `is`. We need equal.
-           std::equal(
-               static_args.begin(), static_args.end(),
-               other.static_args.begin(), other.static_args.end(),
-               [](const py::object& a, const py::object& b) {
-                 try {
-                   return a.equal(b);
-                 } catch (const py::error_already_set& e) {
-                   throw std::invalid_argument(absl::StrCat(
-                       "static arguments should be comparable using __eq__."
-                       "The following error was raised when comparing two "
-                       "objects of types ",
-                       py::cast<std::string>(py::str(py::type::of(a))), " and ",
-                       py::cast<std::string>(py::str(py::type::of(b))),
-                       ". The error was:\n", e.what()));
-                 }
-               });
-  }
-  bool operator!=(const CallSignature& other) const {
-    return !(*this == other);
-  }
-
-  // To be used when we want to keep ownership of Python values referenced by
-  // the `CallSignature` (i.e. when we insert an entry).
-  void IncRef() const;
-  // The destructor of the cache should call this on all entries.
-  void DecRef() const;
-
-  std::string DebugString() const;
-};
+bool CallSignature::operator==(const CallSignature& other) const {
+  return std::tie(dynamic_positional_args_treedef, keyword_args,
+                  dynamic_args_signatures, device) ==
+             std::tie(other.dynamic_positional_args_treedef, other.keyword_args,
+                      other.dynamic_args_signatures, other.device) &&
+         // `==` on py::object is the Python `is`; we need value equality.
+         std::equal(
+             static_args.begin(), static_args.end(), other.static_args.begin(),
+             other.static_args.end(),
+             [](const py::object& a, const py::object& b) {
+               try {
+                 return a.equal(b);
+               } catch (const py::error_already_set& e) {
+                 throw std::invalid_argument(absl::StrCat(
+                     "static arguments should be comparable using __eq__."
+                     "The following error was raised when comparing two "
+                     "objects of types ",
+                     py::cast<std::string>(py::str(py::type::of(a))), " and ",
+                     py::cast<std::string>(py::str(py::type::of(b))),
+                     ". The error was:\n", e.what()));
+               }
+             });
+}
 
 void CallSignature::IncRef() const {
   for (const auto& kw : keyword_args) {
@@ -179,38 +105,13 @@
   }
 }
 
-template <typename H>
-H AbslHashValue(H h, const CallSignature::KwargEntry& kw) {
-  h = H::combine(std::move(h), kw.key.ptr(), kw.value_treedef);
-  return h;
-}
+namespace {
 
-template <typename H>
-H AbslHashValue(H h, const CallSignature& s) {
-  h = H::combine_contiguous(std::move(h),
-                            s.dynamic_positional_args_treedef.data(),
-                            s.dynamic_positional_args_treedef.size());
-  h = H::combine_contiguous(std::move(h), s.keyword_args.data(),
-                            s.keyword_args.size());
-  h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(),
-                            s.dynamic_args_signatures.size());
-  h = H::combine(std::move(h), s.device);
-  for (const auto& static_arg : s.static_args) {
-    ssize_t hash;
-    try {
-      hash = py::hash(static_arg);
-    } catch (const py::error_already_set& e) {
-      throw std::invalid_argument(absl::StrCat(
-          "Non-hashable static arguments are not supported. An error occured "
-          "while trying to hash an object of type ",
-          py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
-          py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
-          e.what(), "\n"));
-    }
-    h = H::combine(std::move(h), hash);
-  }
-  return h;
-}
+thread_local bool disable_jit;
+void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; }
+bool GetDisableJit() { return disable_jit; }
+
+}  // namespace
 
 std::string CallSignature::DebugString() const {
   std::vector<std::string> static_args_str;
@@ -248,6 +149,297 @@
       absl::StrJoin(tree_def_str, " | "));
 }
 
+template <typename H>
+H AbslHashValue(H h, const CallSignature& s) {
+  h = H::combine_contiguous(std::move(h),
+                            s.dynamic_positional_args_treedef.data(),
+                            s.dynamic_positional_args_treedef.size());
+  h = H::combine_contiguous(std::move(h), s.keyword_args.data(),
+                            s.keyword_args.size());
+  h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(),
+                            s.dynamic_args_signatures.size());
+  h = H::combine(std::move(h), s.device);
+  for (const auto& static_arg : s.static_args) {
+    ssize_t hash;
+    try {
+      hash = py::hash(static_arg);
+    } catch (const py::error_already_set& e) {
+      throw std::invalid_argument(absl::StrCat(
+          "Non-hashable static arguments are not supported. An error occured "
+          "while trying to hash an object of type ",
+          py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
+          py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
+          e.what(), "\n"));
+    }
+    h = H::combine(std::move(h), hash);
+  }
+  return h;
+}
+
+// Filter out static arguments, flatten and concatenate other arguments (i.e.
+// dynamic positional and keyword arguments), filling `arguments` in place.
+Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
+                      absl::Span<int const> static_argnums,
+                      ParsedArgumentsAsBuffers& arguments) {
+  if (static_argnums.size() > args.size()) {
+    return InvalidArgument(
+        "%s", "[jaxjit] Error with static argnums, executing the Python path.");
+  }
+  arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
+                                      static_argnums.size());
+  arguments.signature.dynamic_positional_args_treedef.reserve(
+      args.size() - static_argnums.size());
+
+  // Positional arguments.
+  for (size_t i = 0; i < args.size(); ++i) {
+    if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
+        static_argnums.end()) {
+      PyTreeDef pytree_def;
+      pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
+      arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
+    } else {
+      arguments.signature.static_args.emplace_back(
+          // borrow is mandatory here.
+          py::reinterpret_borrow<py::object>(args[i]));
+    }
+  }
+
+  // Keyword arguments.
+  std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
+                                                        py_kwargs.end());
+  // We first intern the keys, then sort them (by name, as in the Python path)
+  // (see also PyTreeDef::Flatten) and then create the signatures.
+  // TODO(jblespiau): We should be able to sort the keys by interned-key
+  // pointers, but this requires the Python compilation to do the same.
+  arguments.signature.keyword_args.resize(kwargs.size());
+  for (size_t i = 0; i < kwargs.size(); ++i) {
+    // Intern the key if not already interned.
+    if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
+      PyObject* key = kwargs[i].first.ptr();
+      kwargs[i].first.inc_ref();
+      PyUnicode_InternInPlace(&key);
+      arguments.keep_alive_objects.push_back(
+          py::reinterpret_steal<py::object>(key));
+      kwargs[i].first = py::handle(key);
+    }
+  }
+
+  std::sort(kwargs.begin(), kwargs.end(),
+            [](const std::pair<py::handle, py::handle>& a,
+               const std::pair<py::handle, py::handle>& b) {
+              return a.first < b.first;
+            });
+  for (size_t i = 0; i < kwargs.size(); ++i) {
+    arguments.signature.keyword_args[i].key = kwargs[i].first;
+    arguments.signature.keyword_args[i].value_treedef.FlattenInto(
+        kwargs[i].second, arguments.flat_dynamic_args);
+  }
+  return Status::OK();
+}
+
+namespace {
+const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
+  static const auto* int64_dt = new py::dtype("int64");
+  static const auto* int32_dt = new py::dtype("int32");
+  static const auto* uint64_dt = new py::dtype("uint64");
+  static const auto* uint32_dt = new py::dtype("uint32");
+  static const auto* float64_dt = new py::dtype("float64");
+  static const auto* float32_dt = new py::dtype("float32");
+  static const auto* complex64_dt = new py::dtype("complex64");
+  static const auto* complex128_dt = new py::dtype("complex128");
+
+  if (dtype.equal(*int64_dt)) {
+    return int32_dt;
+  }
+  if (dtype.equal(*float64_dt)) {
+    return float32_dt;
+  }
+  if (dtype.equal(*uint64_dt)) {
+    return uint32_dt;
+  }
+  if (dtype.equal(*complex128_dt)) {
+    return complex64_dt;
+  }
+
+  return nullptr;
+}
+
+// The equivalent of the Python jax/lazy.py::is_trivial:
+// return (type(lexpr.input) is ArrayVar and
+//         lexpr.dims == tuple(range(len(lexpr.shape))))
+//
+// Expects *only* `None` or a `LazyExpr` object.
+bool IsTrivialLazyExpr(py::handle lexpr) {
+  if (lexpr.is_none()) {
+    return true;
+  }
+
+  static const auto* lazy_module =
+      new py::module(py::module::import("jax.lazy"));
+  auto input = py::getattr(lexpr, "input");
+  if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
+    return false;
+  }
+  py::tuple dims = py::cast<py::tuple>(lexpr.attr("dims"));
+  py::tuple shape = py::cast<py::tuple>(lexpr.attr("shape"));
+
+  for (int i = 0; i < shape.size(); ++i) {
+    if (dims[i].is_none()) {
+      return false;
+    }
+    if (py::cast<int>(dims[i]) != i) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool IsFloat0(py::array arg) {
+  static const auto* dtypes_module =
+      new py::module(py::module::import("jax.dtypes"));
+  static const auto* float0_dtype =
+      new py::handle(dtypes_module->attr("float0"));
+  return float0_dtype->is(arg.attr("dtype"));
+}
+
+template <typename CppType, typename Pybind11Type>
+std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
+    const py::handle& scalar, xla::PjRtClient* client,
+    xla::PjRtDevice* device) {
+  CppType data = py::cast<Pybind11Type>(scalar);
+  xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
+  return ValueOrThrow(client->BufferFromHostBuffer(
+      &data, shape,
+      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
+      device));
+}
+
+// Converts a scalar to the associated PjRtBuffer, or raises an error if it is
+// not convertible (thus, this must be called after other checks).
+StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
+    py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client,
+    xla::PjRtDevice* device) {
+  // Important: In Python, isinstance(True, int) returns True. Thus, we have
+  // to check for bool before int.
+  if (py::isinstance<py::bool_>(scalar)) {
+    return ConvertToScalarBuffer<bool, py::bool_>(scalar, client, device);
+  } else if (py::isinstance<py::int_>(scalar)) {
+    if (jax_enable_x64) {
+      return ConvertToScalarBuffer<int64, py::int_>(scalar, client, device);
+    } else {
+      return ConvertToScalarBuffer<int, py::int_>(scalar, client, device);
+    }
+  } else if (py::isinstance<py::float_>(scalar)) {
+    if (jax_enable_x64) {
+      return ConvertToScalarBuffer<double, py::float_>(scalar, client, device);
+
+    } else {
+      return ConvertToScalarBuffer<float, py::float_>(scalar, client, device);
+    }
+  } else if (PyComplex_Check(scalar.ptr())) {
+    Py_complex result = PyComplex_AsCComplex(scalar.ptr());
+    if (result.real == -1.0 && PyErr_Occurred()) {
+      PyErr_Clear();
+      throw std::runtime_error("Could not convert the complex number");
+    }
+    if (jax_enable_x64) {
+      xla::complex128 data(result.real, result.imag);
+      xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
+      return ValueOrThrow(client->BufferFromHostBuffer(
+          &data, shape,
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
+    } else {
+      xla::complex64 data(result.real, result.imag);
+      xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
+      return ValueOrThrow(client->BufferFromHostBuffer(
+          &data, shape,
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
+    }
+  }
+  return InvalidArgument(
+      "%s", absl::StrCat(
+                "Not supported: The C++ jax jit execution path, only accepts "
+                "DeviceArray, Numpy arrays, or Python scalars. Got type ",
+                py::cast<std::string>(py::str(scalar.get_type()))));
+}
+
+}  // namespace
+
+StatusOr<DevicePutResult> DevicePut(pybind11::handle obj, PjRtDevice* to_device,
+                                    bool jax_enable_x64,
+                                    xla::PyClient& pyclient) {
+  static const auto* xla_module =
+      new py::module(py::module::import("jax.interpreters.xla"));
+  const auto& device_array = xla_module->attr("_DeviceArray");
+
+  static const auto* numpy_module = new py::module(py::module::import("numpy"));
+  const auto& np_array = numpy_module->attr("array");
+
+  bool is_py_buffer = py::isinstance<PyBuffer>(obj);
+  if (is_py_buffer) {
+    // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
+    PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj);
+    bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
+    if (buffer->device().contents == to_device) {
+      return DevicePutResult(buffer->buffer(), weak_type);
+    } else {
+      // Performs a device-to-device copy if the devices are on the same
+      // platform.
+      // Buffers from different XLA backends are passed through the host.
+      std::unique_ptr<PjRtBuffer> copied_buffer =
+          ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
+      return DevicePutResult(std::move(copied_buffer), weak_type);
+    }
+
+  } else if (obj.get_type().is(device_array)) {
+    if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
+      return InvalidArgument(
+          "Non-trivial lazy expression not supported in C++. "
+          "Falling back to Python.");
+    }
+    PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
+    bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
+    // Same block as in the previous `if (is_py_buffer)`.
+    if (buffer->device().contents == to_device) {
+      return DevicePutResult(buffer->buffer(), weak_type);
+    } else {
+      std::unique_ptr<PjRtBuffer> copied_buffer =
+          ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
+      return DevicePutResult(std::move(copied_buffer), weak_type);
+    }
+  } else if (py::isinstance<py::array>(obj)) {
+    py::array numpy_array = py::cast<py::array>(obj);
+    if (IsFloat0(numpy_array)) {
+      return InvalidArgument(
+          "float0 numpy arrays not supported in C++. "
+          "Falling back to Python.");
+    }
+    // If jax_enable_x64 is not set, we need to coerce 64-bit types to 32 bits.
+    // Note that this is calling back to Python!
+    if (!jax_enable_x64) {
+      const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
+      if (to_dtype) {
+        numpy_array = np_array(numpy_array, *to_dtype);
+      }
+    }
+    std::unique_ptr<xla::PjRtBuffer> buffer =
+        ValueOrThrow(pyclient.PjRtBufferFromPyval(
+            numpy_array, to_device,
+            /*force_copy=*/false, /*host_buffer_semantics=*/
+            xla::PjRtClient::HostBufferSemantics::kZeroCopy));
+    return DevicePutResult(std::move(buffer), /*weak_type=*/false);
+  } else {
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<xla::PjRtBuffer> buffer,
+        ScalarToBuffer(obj, jax_enable_x64, to_device->client(), to_device));
+    return DevicePutResult(std::move(buffer), /*weak_type=*/true);
+  }
+}
+
+namespace {
+
 struct CacheEntry {
   std::shared_ptr<xla::PyExecutable> executable;
   PyTreeDef out_pytree_def;
@@ -370,216 +562,6 @@
   }
 }
 
-namespace {
-
-// The equivalent of the Python jax/lazy.py::is_trivial:
-// return (type(lexpr.input) is ArrayVar and
-//         lexpr.dims == tuple(range(len(lexpr.shape))))
-//
-// Expects *only* instances of `DeviceArray`.
-bool IsTrivialLazyExpr(py::handle lexpr) {
-  if (lexpr.is_none()) {
-    return true;
-  }
-
-  static const auto* lazy_module =
-      new py::module(py::module::import("jax.lazy"));
-  auto input = py::getattr(lexpr, "input");
-  if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
-    return false;
-  }
-  py::tuple dims = py::cast<py::tuple>(lexpr.attr("dims"));
-  py::tuple shape = py::cast<py::tuple>(lexpr.attr("shape"));
-
-  for (int i = 0; i < shape.size(); ++i) {
-    if (dims[i].is_none()) {
-      return false;
-    }
-    if (py::cast<int>(dims[i]) != i) {
-      return false;
-    }
-  }
-  return true;
-}
-
-// The resulting information of the parsing and conversion of the arguments.
-struct ParsedArgumentsAsBuffers {
-  // The call signature will be filled during 2 steps:
-  // - `FlattenArguments` will fill the static arguments and the pytree
-  //    structures
-  // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`.
-  CallSignature signature;
-  // The concatenation of the dynamic positional arguments and the sorted
-  // keyword arguments. We do not need ownership, thus the py::handle.
-  // TODO(jblespiau): We do not need py::object here and py::handle suffice and
-  // will prevent any counter increment.
-  std::vector<py::object> flat_dynamic_args;
-  std::vector<py::object> keep_alive_objects;
-
-  // The following is only valid if the parsing succeeds.
-  std::vector<xla::PjRtBuffer*> arg_buffers;
-  // We may need to keep some objects around, because:
-  // (a) we need to extend the lifetime of objects created within
-  //    `ConvertArgsToBuffers`
-  // (b) `arg_buffers` do not maintain ownership
-  std::vector<absl::variant<std::unique_ptr<xla::PyBuffer>,
-                            std::unique_ptr<xla::PjRtBuffer>>>
-      keep_alive;
-};
-
-// Filter out static arguments, flatten and concatenate other arguments (i.e.
-// dynamic positional and keyword arguments), filling `arguments` in place.
-void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs,
-                      absl::Span<int const> static_argnums,
-                      ParsedArgumentsAsBuffers& arguments) {
-  arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
-                                      static_argnums.size());
-  arguments.signature.dynamic_positional_args_treedef.reserve(
-      args.size() - static_argnums.size());
-
-  // Positional arguments.
-  for (size_t i = 0; i < args.size(); ++i) {
-    if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
-        static_argnums.end()) {
-      PyTreeDef pytree_def;
-      pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
-      arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
-    } else {
-      arguments.signature.static_args.emplace_back(
-          // borrow is mandatory here.
-          py::reinterpret_borrow<py::object>(args[i]));
-    }
-  }
-
-  // Keyword arguments.
-  std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
-                                                        py_kwargs.end());
-  // We first intern the keys, then sort them (by name, as in the Python path)
-  // (see also PyTreeDef::Flatten) and then create the signatures.
-  // TODO(jblespiau): We should be able to sort the keys by interned-key
-  // pointers, but this requires the Python compilation to do the same.
-  arguments.signature.keyword_args.resize(kwargs.size());
-  for (size_t i = 0; i < kwargs.size(); ++i) {
-    // Intern the key if not already interned.
-    if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
-      PyObject* key = kwargs[i].first.ptr();
-      kwargs[i].first.inc_ref();
-      PyUnicode_InternInPlace(&key);
-      arguments.keep_alive_objects.push_back(
-          py::reinterpret_steal<py::object>(key));
-      kwargs[i].first = py::handle(key);
-    }
-  }
-
-  std::sort(kwargs.begin(), kwargs.end(),
-            [](const std::pair<py::handle, py::handle>& a,
-               const std::pair<py::handle, py::handle>& b) {
-              return a.first < b.first;
-            });
-  for (size_t i = 0; i < kwargs.size(); ++i) {
-    arguments.signature.keyword_args[i].key = kwargs[i].first;
-    arguments.signature.keyword_args[i].value_treedef.FlattenInto(
-        kwargs[i].second, arguments.flat_dynamic_args);
-  }
-}
-
-template <typename CppType, typename Pybind11Type>
-std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
-    const py::handle& scalar, xla::PjRtClient* client,
-    xla::PjRtDevice* device) {
-  CppType data = py::cast<Pybind11Type>(scalar);
-  xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
-  return ValueOrThrow(client->BufferFromHostBuffer(
-      &data, shape,
-      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
-      device));
-}
-
-// Convert a scalar to the associated PjRtBuffer or raises an error if it is
-// not convertible (thus, this must be called after other checks).
-StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
-    py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client,
-    xla::PjRtDevice* device) {
-  // Important: In Python, isinstance(True, int) returns True. Thus, we have
-  // to check for bool before int.
-  if (py::isinstance<py::bool_>(scalar)) {
-    return ConvertToScalarBuffer<bool, py::bool_>(scalar, client, device);
-  } else if (py::isinstance<py::int_>(scalar)) {
-    if (jax_enable_x64) {
-      return ConvertToScalarBuffer<int64, py::int_>(scalar, client, device);
-    } else {
-      return ConvertToScalarBuffer<int, py::int_>(scalar, client, device);
-    }
-  } else if (py::isinstance<py::float_>(scalar)) {
-    if (jax_enable_x64) {
-      return ConvertToScalarBuffer<double, py::float_>(scalar, client, device);
-
-    } else {
-      return ConvertToScalarBuffer<float, py::float_>(scalar, client, device);
-    }
-  } else if (PyComplex_Check(scalar.ptr())) {
-    Py_complex result = PyComplex_AsCComplex(scalar.ptr());
-    if (result.real == -1.0 && PyErr_Occurred()) {
-      PyErr_Clear();
-      throw std::runtime_error("Could not convert the complex number");
-    }
-    if (jax_enable_x64) {
-      xla::complex128 data(result.real, result.imag);
-      xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
-      return ValueOrThrow(client->BufferFromHostBuffer(
-          &data, shape,
-          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, device));
-    } else {
-      xla::complex64 data(result.real, result.imag);
-      xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
-      return ValueOrThrow(client->BufferFromHostBuffer(
-          &data, shape,
-          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, device));
-    }
-  }
-  return InvalidArgument(
-      "%s", absl::StrCat(
-                "Not supported: The C++ jax jit execution path, only accepts "
-                "DeviceArray, Numpy arrays, or Python scalars. Got type ",
-                py::cast<std::string>(py::str(scalar.get_type()))));
-}
-
-const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
-  static const auto* int64_dt = new py::dtype("int64");
-  static const auto* int32_dt = new py::dtype("int32");
-  static const auto* uint64_dt = new py::dtype("uint64");
-  static const auto* uint32_dt = new py::dtype("uint32");
-  static const auto* float64_dt = new py::dtype("float64");
-  static const auto* float32_dt = new py::dtype("float32");
-  static const auto* complex64_dt = new py::dtype("complex64");
-  static const auto* complex128_dt = new py::dtype("complex128");
-
-  if (dtype.equal(*int64_dt)) {
-    return int32_dt;
-  }
-  if (dtype.equal(*float64_dt)) {
-    return float32_dt;
-  }
-  if (dtype.equal(*uint64_dt)) {
-    return uint32_dt;
-  }
-  if (dtype.equal(*complex128_dt)) {
-    return complex64_dt;
-  }
-
-  return nullptr;
-}
-
-bool IsFloat0(py::array arg) {
-  static const auto* dtypes_module =
-      new py::module(py::module::import("jax.dtypes"));
-  static const auto* float0_dtype =
-      new py::handle(dtypes_module->attr("float0"));
-  return float0_dtype->is(arg.attr("dtype"));
-}
-
 // Converts flattened arguments contained in ParsedArgumentsAsBuffers in
 // place. If arguments are `DeviceArray`, they must all be on the same `Device`.
 //
@@ -599,9 +581,6 @@
       new py::module(py::module::import("jax.interpreters.xla"));
   const auto& device_array = xla_module->attr("_DeviceArray");
 
-  static const auto* numpy_module = new py::module(py::module::import("numpy"));
-  const auto& np_array = numpy_module->attr("array");
-
   // When the jitted function is not committed, we first check whether any
   // sticky `DeviceArray` is present and on which device they live. See also:
   // https://github.com/google/jax/pull/1884
@@ -652,94 +631,24 @@
   }
   CHECK(data_device);
   arguments.signature.device = data_device;
-  xla::PjRtClient* pjrt_client = data_device->client();
 
   for (py::handle arg : arguments.flat_dynamic_args) {
-    bool is_py_buffer = py::isinstance<PyBuffer>(arg);
-    if (is_py_buffer || arg.get_type().is(device_array)) {
-      PyBuffer* buffer;
-      if (is_py_buffer) {
-        // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
-        buffer = py::cast<xla::PyBuffer*>(arg);
-      } else {
-        if (!IsTrivialLazyExpr(py::getattr(arg, "_lazy_expr"))) {
-          return InvalidArgument(
-              "Non-trivial lazy expression not supported in C++. "
-              "Falling back to Python.");
-        }
-        buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
-      }
+    TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
+                        DevicePut(arg, data_device, jax_enable_x64, pyclient));
 
-      if (buffer->device().contents == data_device) {
-        arg_buffers.push_back(buffer->buffer());
-      } else {
-        // source and target platforms are the same, but different device.
-        // Perform a device-to-device copy.
-        // buffers from different XLA backends are passed through the host.
-        std::unique_ptr<PjRtBuffer> copied_buffer =
-            ValueOrThrow(buffer->buffer()->CopyToDevice(data_device));
-        arg_buffers.push_back(copied_buffer.get());
-        keep_alive.emplace_back(std::move(copied_buffer));
-      }
-
-      ArgSignature sig;
-      sig.dtype = buffer->shape().element_type();
-      sig.shape.assign(buffer->shape().dimensions().begin(),
-                       buffer->shape().dimensions().end());
-      sig.weak_type = py::cast<py::bool_>(arg.attr("aval").attr("weak_type"));
-      arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
-    } else if (py::isinstance<py::array>(arg)) {
-      // TODO(jblespiau): Can we improve this call? Do we need the underlying
-      // GlobalPyRefManager() and co?
-      py::array numpy_array = py::cast<py::array>(arg);
-      if (IsFloat0(numpy_array)) {
-        return InvalidArgument(
-            "float0 numpy arrays not supported in C++. "
-            "It will fallback to Python.");
-      }
-      // If jax_enable_x64 is not set, we need to coerce 32 bits types.
-      // Note that this is calling back to Python!
-      if (!jax_enable_x64) {
-        const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
-        if (to_dtype) {
-          numpy_array = np_array(numpy_array, *to_dtype);
-        }
-      }
-      std::unique_ptr<xla::PyBuffer> buffer =
-          ValueOrThrow(pyclient.BufferFromPyval(
-              numpy_array, data_device,
-              /*force_copy=*/false, /*host_buffer_semantics=*/
-              xla::PjRtClient::HostBufferSemantics::kZeroCopy));
-      arg_buffers.push_back(buffer->buffer());
-
-      ArgSignature sig;
-      sig.dtype = buffer->shape().element_type();
-      sig.weak_type = false;
-      sig.shape.assign(buffer->shape().dimensions().begin(),
-                       buffer->shape().dimensions().end());
-      arguments.signature.dynamic_args_signatures.push_back(sig);
-
-      keep_alive.emplace_back(std::move(buffer));
-    } else {
-      StatusOr<std::unique_ptr<xla::PjRtBuffer>> buffer =
-          ScalarToBuffer(arg, jax_enable_x64, pjrt_client, data_device);
-      if (!buffer.ok()) {
-        return buffer.status();
-      }
-      arg_buffers.push_back(buffer.ValueOrDie().get());
-      ArgSignature sig;
-      sig.dtype = buffer.ValueOrDie()->on_host_shape().element_type();
-      sig.weak_type = true;
-      arguments.signature.dynamic_args_signatures.push_back(sig);
-
-      keep_alive.emplace_back(std::move(buffer).ValueOrDie());
+    PjRtBuffer* buffer = on_device.buffer;
+    arg_buffers.push_back(buffer);
+    if (on_device.owned_buffer) {
+      keep_alive.emplace_back(std::move(on_device.owned_buffer));
     }
+
+    ArgSignature sig(buffer->on_host_shape().element_type(),
+                     buffer->on_host_shape().dimensions(), on_device.weak_type);
+    arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
   }
   return Status::OK();
 }
 
-}  // namespace
-
 CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
     const CallSignature& signature) {
   auto found_iterator = executables_.find(signature);
@@ -860,7 +769,9 @@
     return fun_(*args, **kwargs);
   }
   ParsedArgumentsAsBuffers arguments;
-  FlattenArguments(args, kwargs, static_argnums_, arguments);
+  if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
+    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
+  }
 
   // The C++ jit does not support Tracer arguments yet. The Python-based
   // jit function will be called if any of the dynamic arguments is unsupported.
diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h
index 2b1603a..08ab7c8 100644
--- a/tensorflow/compiler/xla/python/jax_jit.h
+++ b/tensorflow/compiler/xla/python/jax_jit.h
@@ -16,10 +16,153 @@
 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
 
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
 #include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/python/py_client.h"
+#include "tensorflow/compiler/xla/python/pytree.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
 
+// Describes the abstract shape and dtype of an argument.
+struct ArgSignature {
+  ArgSignature(PrimitiveType dtype, absl::Span<const int64> shape,
+               bool weak_type)
+      : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
+  // This is the XLA dtype of the object.
+  const PrimitiveType dtype;
+  const absl::InlinedVector<int64, 4> shape;
+  // JAX arguments can be of weak type if and only if they are Python scalars
+  // or `DeviceArray` values whose `aval.weak_type` is true.
+  const bool weak_type;
+  bool operator==(const ArgSignature& other) const {
+    return std::tie(dtype, weak_type, shape) ==
+           std::tie(other.dtype, other.weak_type, other.shape);
+  }
+  bool operator!=(const ArgSignature& other) const { return !(*this == other); }
+  std::string DebugString() const;
+};
+
+template <typename H>
+H AbslHashValue(H h, const ArgSignature& s) {
+  h = H::combine(std::move(h), s.dtype);
+  h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
+  return h;
+}
+
+// The signature of Python jitted function call, partitioned into:
+// - dynamic positional arguments (i.e. positional args which are not static)
+// - static positional arguments (i.e. the args associated with static_argnums)
+// - keyword arguments
+// The CallSignature should unambiguously identify a function call; thus,
+// equality is based on:
+// (a) the same PyTree for all dynamic positional and keyword arguments,
+// (b) equality of the ArgSignature of the dynamic and keyword arguments, and
+// (c) equality (delegated to Python) of the static arguments.
+struct CallSignature {
+  struct KwargEntry {
+    // To avoid comparing strings, we intern the kwargs strings.
+    // The compilation cache holds a reference to all the keys.
+    pybind11::handle key;
+    PyTreeDef value_treedef;
+    bool operator==(const KwargEntry& other) const {
+      return key.ptr() == other.key.ptr() &&
+             value_treedef == other.value_treedef;
+    }
+    bool operator!=(const KwargEntry& other) const { return !(*this == other); }
+  };
+
+  // Only contains the arguments associated with `static_argnums`, sorted in
+  // the order of their argnum index.
+  std::vector<pybind11::object> static_args;
+  // A PyTreeDef for each positional dynamic (i.e. not static) argument.
+  std::vector<PyTreeDef> dynamic_positional_args_treedef;
+  // Keyword arguments. Sorted by the keyword name.
+  std::vector<KwargEntry> keyword_args;
+  // Shape and dtype for both the dynamic positional arguments and the keyword
+  // arguments (sorted by keyword name).
+  std::vector<ArgSignature> dynamic_args_signatures;
+  PjRtDevice* device;
+
+  bool operator==(const CallSignature& other) const;
+  bool operator!=(const CallSignature& other) const {
+    return !(*this == other);
+  }
+
+  // To be used when we want to keep ownership of Python values referenced by
+  // the `CallSignature` (i.e. when we insert an entry).
+  void IncRef() const;
+  // The destructor of the cache should call this on all entries.
+  void DecRef() const;
+
+  std::string DebugString() const;
+};
+
+template <typename H>
+H AbslHashValue(H h, const CallSignature::KwargEntry& kw) {
+  h = H::combine(std::move(h), kw.key.ptr(), kw.value_treedef);
+  return h;
+}
+
+template <typename H>
+H AbslHashValue(H h, const CallSignature& s);
+
+// The resulting information of the parsing and conversion of the arguments.
+struct ParsedArgumentsAsBuffers {
+  // The call signature will be filled in two steps:
+  // - `ParseArguments` will fill the static arguments and the pytree
+  //    structures
+  // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`.
+  CallSignature signature;
+  // The concatenation of the dynamic positional arguments and the sorted
+  // keyword arguments.
+  std::vector<pybind11::object> flat_dynamic_args;
+  std::vector<pybind11::object> keep_alive_objects;
+
+  // The following is only valid if the parsing succeeds.
+  std::vector<xla::PjRtBuffer*> arg_buffers;
+  // We may need to keep these objects around, because:
+  // (a) we need to extend the lifetime of objects created within
+  //    `ConvertArgsToBuffers`
+  // (b) `arg_buffers` does not maintain ownership
+  std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
+};
+
+// Filter out static arguments, flatten and concatenate other arguments (i.e.
+// dynamic positional and keyword arguments), filling `arguments` in place.
+Status ParseArguments(const pybind11::args& args,
+                      const pybind11::kwargs& py_kwargs,
+                      absl::Span<int const> static_argnums,
+                      ParsedArgumentsAsBuffers& arguments);
+
+struct DevicePutResult {
+  explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
+      : buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
+  DevicePutResult(std::unique_ptr<PjRtBuffer> new_buffer, bool weak_type)
+      : buffer(new_buffer.get()),
+        weak_type(weak_type),
+        owned_buffer(std::move(new_buffer)) {}
+
+  PjRtBuffer* buffer;
+  bool weak_type;
+  std::unique_ptr<PjRtBuffer> owned_buffer;
+};
+
+// Moves a device-like object to be on device.
+// - If the object is already on device, `owned_buffer` will be nullptr.
+// - If it's not, a new buffer will be created and returned using
+//   `owned_buffer`.
+// In all cases, `buffer` will point to the already existing or newly created
+// buffer.
+// If `obj` is not convertible to a `PjRtBuffer` from C++, an error will be
+// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
+// supported yet.
+StatusOr<DevicePutResult> DevicePut(pybind11::handle obj, PjRtDevice* to_device,
+                                    bool jax_enable_x64, PyClient& pyclient);
+
+// The function to call in `xla.cc` to add the bindings for this module.
 void BuildJaxjitSubmodule(pybind11::module& m);
 
 }  // namespace xla
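
Note: DevicePutResult above distinguishes a borrowed buffer (already on the target device, `owned_buffer` stays null) from a freshly created one (ownership travels in `owned_buffer`). Below is a minimal standalone sketch of that borrowed-vs-owned pattern; `Buffer`, `PutResult` and `DevicePutSketch` are illustrative stand-ins, not the XLA API.

    #include <iostream>
    #include <memory>
    #include <utility>

    // Illustrative stand-in for xla::PjRtBuffer.
    struct Buffer {
      int device_id;
    };

    // Mirrors the shape of DevicePutResult: `buffer` always points at a valid
    // Buffer; `owned_buffer` is non-null only when a new buffer had to be made.
    struct PutResult {
      explicit PutResult(Buffer* borrowed) : buffer(borrowed) {}
      explicit PutResult(std::unique_ptr<Buffer> fresh)
          : buffer(fresh.get()), owned_buffer(std::move(fresh)) {}

      Buffer* buffer = nullptr;
      std::unique_ptr<Buffer> owned_buffer;  // Null when the buffer is borrowed.
    };

    // If the value already lives on `device`, borrow it; otherwise create a copy.
    PutResult DevicePutSketch(Buffer& existing, int device) {
      if (existing.device_id == device) {
        return PutResult(&existing);
      }
      return PutResult(std::make_unique<Buffer>(Buffer{device}));
    }

    int main() {
      Buffer host{0};
      PutResult same = DevicePutSketch(host, 0);
      PutResult moved = DevicePutSketch(host, 1);
      std::cout << "borrowed: " << (same.owned_buffer == nullptr) << "\n";   // 1
      std::cout << "owned:    " << (moved.owned_buffer != nullptr) << "\n";  // 1
      return 0;
    }

The caller's job is the same as in ConvertArgsToBuffers above: use `buffer` for dispatch and, if `owned_buffer` is set, move it into a keep-alive list so it outlives the call.
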
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc
index 3c0f975..92aae35 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc
@@ -18,6 +18,7 @@
 #include <sys/types.h>
 
 #include <memory>
+#include <queue>
 #include <sstream>
 
 #include "absl/container/flat_hash_map.h"
@@ -230,8 +231,8 @@
   callback_ = callback;
   max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
   for (const auto& client : clients) {
-    for (const auto& device : client->devices()) {
-      devices_.push_back(device.get());
+    for (auto device : client->devices()) {
+      devices_.push_back(device);
     }
   }
   CHECK_GT(devices_.size(), 0);
@@ -342,11 +343,7 @@
     const PjRtDevice* device, const Shape& shape) {
   std::shared_ptr<Literal> literal_shared;
 
-  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
-  TF_ASSIGN_OR_RETURN(Literal literal,
-                      local_device->client()->TransferFromOutfeedLocal(
-                          shape, local_device->device_ordinal()));
+  TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
 
   return absl::make_unique<Literal>(std::move(literal));
 }
diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc
index 1f39266..f8db0c8 100644
--- a/tensorflow/compiler/xla/python/py_buffer.cc
+++ b/tensorflow/compiler/xla/python/py_buffer.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/xla/python/py_buffer.h"
 
+#include "absl/base/casts.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 #include "tensorflow/compiler/xla/python/types.h"
@@ -75,19 +76,24 @@
   return buffer_->BlockHostUntilReady();
 }
 
+// TODO(zhangqiaorjc): Delete UnsafeBufferPointer.
 StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
-  TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer_->AsShapedBuffer());
-  if (shaped_buffer.on_device_shape().IsTuple()) {
+  if (buffer_->on_device_shape().IsTuple()) {
     return Unimplemented(
         "unsafe_buffer_pointer is not implemented for tuple "
         "buffers.");
   }
-  return absl::bit_cast<std::uintptr_t>(shaped_buffer.root_buffer().opaque());
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
+                          external_reference_hold,
+                      buffer_->AcquireExternalReference());
+  const void* ptr = external_reference_hold->OpaqueDeviceMemoryDataPointer();
+  return absl::bit_cast<std::uintptr_t>(ptr);
 }
 
 StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
-  if (buffer_->device()->local_device_state()->executor()->platform_kind() !=
-      se::PlatformKind::kCuda) {
+  // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
+  if (buffer_->client()->platform_id() != kGpuId) {
     return InvalidArgument(
         "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
   }
@@ -101,17 +107,20 @@
   }
   TF_RET_CHECK(
       LayoutUtil::IsMonotonicWithDim0Major(buffer_->on_host_shape().layout()));
-  TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer_->AsShapedBuffer());
 
   py::dict result;
-  result["shape"] = IntSpanToTuple(shaped_buffer.on_host_shape().dimensions());
-  TF_ASSIGN_OR_RETURN(py::str typestr,
-                      TypeDescriptorForPrimitiveType(
-                          shaped_buffer.on_host_shape().element_type()));
+  result["shape"] = IntSpanToTuple(buffer_->on_host_shape().dimensions());
+  TF_ASSIGN_OR_RETURN(
+      py::str typestr,
+      TypeDescriptorForPrimitiveType(buffer_->on_host_shape().element_type()));
   result["typestr"] = std::move(typestr);
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
+                          external_reference_hold,
+                      buffer_->AcquireExternalReference());
+  const void* root_ptr =
+      external_reference_hold->OpaqueDeviceMemoryDataPointer();
   py::tuple data(2);
-  data[0] = py::int_(
-      absl::bit_cast<std::uintptr_t>(shaped_buffer.root_buffer().opaque()));
+  data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
   data[1] = py::bool_(true);  // read-only
   result["data"] = std::move(data);
   result["version"] = py::int_(2);
@@ -124,16 +133,17 @@
 
 // Extra data to be kept alive by the consumer of the buffer protocol.
 struct ExtraBufferInfo {
-  explicit ExtraBufferInfo(PjRtBuffer::ScopedHold device_buffer)
-      : device_buffer(std::move(device_buffer)) {}
+  explicit ExtraBufferInfo(std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
+                               external_reference_hold)
+      : external_reference_hold(std::move(external_reference_hold)) {}
 
   std::string format;
   std::vector<Py_ssize_t> strides;
-  // We keep a reference to the TrackedDeviceBuffer that backs the
-  // PjRtBuffer. This prevents a use-after-free in the event that Delete() is
-  // called on a buffer with an live buffer protocol view. It does however mean
-  // that Delete() sometimes won't actually delete immediately.
-  PjRtBuffer::ScopedHold device_buffer;
+  // We keep an external reference hold to the PjRtBuffer. This prevents a
+  // use-after-free in the event that Delete() is called on a buffer with a
+  // live buffer protocol view. It does, however, mean that Delete() sometimes
+  // won't actually delete immediately.
+  std::unique_ptr<PjRtBuffer::ExternalReferenceHold> external_reference_hold;
 };
 
 int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
@@ -163,9 +173,10 @@
     if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
       return InvalidArgument("XLA buffers are read-only.");
     }
-    PjRtBuffer::ScopedHold device_buffer(
-        buffer.GetBufferWithExternalReference());
-    if (!device_buffer.status().ok()) {
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
+                            external_reference_hold,
+                        buffer.AcquireExternalReference());
+    if (buffer.IsDeleted()) {
       return InvalidArgument("Deleted buffer used in buffer protocol.");
     }
     const Shape& shape = buffer.on_host_shape();
@@ -182,10 +193,11 @@
       return InvalidArgument("Buffer is not in contiguous layout.");
     }
     std::memset(view, 0, sizeof(Py_buffer));
-    CHECK_EQ(device_buffer->device_memory().size(), 1);
-    view->buf =
-        const_cast<void*>(device_buffer->device_memory().front().opaque());
-    auto extra = absl::make_unique<ExtraBufferInfo>(std::move(device_buffer));
+    const void* root_ptr =
+        external_reference_hold->OpaqueDeviceMemoryDataPointer();
+    view->buf = const_cast<void*>(root_ptr);
+    auto extra =
+        absl::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
     view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
     view->len = ShapeUtil::ByteSizeOf(shape);
     view->readonly = 1;
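
Note: the py_buffer.cc changes swap the old ScopedHold on the backing device buffer for an ExternalReferenceHold obtained from AcquireExternalReference(); as long as the hold is alive, Delete() on the buffer cannot free the memory that a buffer-protocol view or __cuda_array_interface__ pointer refers to. Here is a minimal sketch of that keep-alive idea using only standard-library types; `Hold`, `Storage` and `BufferSketch` are illustrative stand-ins, not XLA classes.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <memory>
    #include <utility>
    #include <vector>

    // Illustrative stand-in for a device allocation.
    struct Storage {
      std::vector<std::uint8_t> bytes;
    };

    // Stand-in for PjRtBuffer::ExternalReferenceHold: while the hold is alive,
    // the storage it references cannot be freed.
    class Hold {
     public:
      explicit Hold(std::shared_ptr<Storage> storage)
          : storage_(std::move(storage)) {}
      const void* OpaqueDataPointer() const { return storage_->bytes.data(); }

     private:
      std::shared_ptr<Storage> storage_;
    };

    class BufferSketch {
     public:
      explicit BufferSketch(std::size_t size)
          : storage_(std::make_shared<Storage>()) {
        storage_->bytes.resize(size);
      }

      // "Delete" drops the buffer's own reference only; outstanding holds keep
      // the storage alive, so already-exported pointers stay valid.
      void Delete() { storage_.reset(); }

      std::unique_ptr<Hold> AcquireExternalReference() {
        return std::make_unique<Hold>(storage_);
      }

     private:
      std::shared_ptr<Storage> storage_;
    };

    int main() {
      BufferSketch buffer(16);
      std::unique_ptr<Hold> hold = buffer.AcquireExternalReference();
      buffer.Delete();  // Storage survives because `hold` still references it.
      std::cout << "exported pointer still valid: "
                << (hold->OpaqueDataPointer() != nullptr) << "\n";
      return 0;
    }

Delete() only drops the buffer's own reference; the exported pointer stays valid until every outstanding hold is destroyed, which is the property the buffer protocol needs.
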
diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h
index c7c62d8..efef764 100644
--- a/tensorflow/compiler/xla/python/py_buffer.h
+++ b/tensorflow/compiler/xla/python/py_buffer.h
@@ -57,12 +57,16 @@
   PjRtBuffer* buffer() const { return buffer_.get(); }
 
   ClientAndPtr<PjRtDevice> device() const;
-  const std::string& platform_name() const { return buffer_->platform_name(); }
+  const std::string& platform_name() const {
+    return buffer_->client()->platform_name();
+  }
   bool is_deleted() const { return buffer_->IsDeleted(); }
 
   StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
       const ClientAndPtr<PjRtDevice>& dst_device) const;
 
+  int64 OnDeviceSizeInBytes() { return buffer_->OnDeviceSizeInBytes(); }
+
   void Delete() {
     buffer_->Delete();
     npy_value_ = pybind11::none();
diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc
index d42bbdc..7d38472 100644
--- a/tensorflow/compiler/xla/python/py_client.cc
+++ b/tensorflow/compiler/xla/python/py_client.cc
@@ -37,9 +37,10 @@
 
 std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
   std::vector<ClientAndPtr<PjRtDevice>> devices;
-  devices.reserve(pjrt_client_->devices().size());
-  for (const auto& device : pjrt_client_->devices()) {
-    devices.push_back(WrapWithClient(shared_from_this(), device.get()));
+  auto span = pjrt_client_->devices();
+  devices.reserve(span.size());
+  for (PjRtDevice* device : span) {
+    devices.push_back(WrapWithClient(shared_from_this(), device));
   }
   return devices;
 }
@@ -64,9 +65,9 @@
     result[r].resize(num_partitions);
     for (int p = 0; p < num_partitions; ++p) {
       int device_id = device_assignment(r, p);
-      auto iter = pjrt_client_->id_to_device().find(device_id);
-      CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-      result[r][p] = WrapWithClient(shared_from_this(), iter->second);
+      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                          pjrt_client_->LookupDevice(device_id));
+      result[r][p] = WrapWithClient(shared_from_this(), device);
     }
   }
   return result;
@@ -80,14 +81,14 @@
   std::vector<ClientAndPtr<PjRtDevice>> result;
   for (int i = 0; i < num_replicas; ++i) {
     int device_id = device_assignment(i, 0);
-    auto iter = pjrt_client_->id_to_device().find(device_id);
-    CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-    result.push_back(WrapWithClient(shared_from_this(), iter->second));
+    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                        pjrt_client_->LookupDevice(device_id));
+    result.push_back(WrapWithClient(shared_from_this(), device));
   }
   return result;
 }
 
-StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
+StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
     const pybind11::object& argument, PjRtDevice* device, bool force_copy,
     PjRtClient::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
@@ -95,8 +96,9 @@
     device = pjrt_client_->local_devices().front();
   }
   CHECK(device != nullptr);
-  auto iter = pjrt_client_->id_to_device().find(device->id());
-  if (iter->second != device) {
+  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
+                      pjrt_client_->LookupDevice(device->id()));
+  if (found_device != device) {
     return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                            device->DebugString(),
                            pjrt_client_->platform_name());
@@ -118,6 +120,15 @@
                                     c->buf_ptr, c->shape, host_buffer_semantics,
                                     std::move(py_buffer_ref), device));
   }
+  return buffer;
+}
+StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
+    const pybind11::object& argument, PjRtDevice* device, bool force_copy,
+    PjRtClient::HostBufferSemantics host_buffer_semantics) {
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<PjRtBuffer> buffer,
+      PjRtBufferFromPyval(argument, device, force_copy, host_buffer_semantics));
+
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
                                     std::move(traceback));
diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h
index 37f5333..f2690fd 100644
--- a/tensorflow/compiler/xla/python/py_client.h
+++ b/tensorflow/compiler/xla/python/py_client.h
@@ -97,7 +97,9 @@
   const std::string& platform_name() const {
     return pjrt_client_->platform_name();
   }
-  int local_device_count() const { return pjrt_client_->local_device_count(); }
+  int addressable_device_count() const {
+    return pjrt_client_->addressable_device_count();
+  }
   int device_count() const { return pjrt_client_->device_count(); }
   int host_id() const { return pjrt_client_->host_id(); }
 
@@ -121,6 +123,9 @@
     return pjrt_client_->CreateHostToDeviceChannelHandle();
   }
 
+  StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
+      const pybind11::object& argument, PjRtDevice* device, bool force_copy,
+      PjRtClient::HostBufferSemantics host_buffer_semantics);
   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
       const pybind11::object& argument, PjRtDevice* device, bool force_copy,
       PjRtClient::HostBufferSemantics host_buffer_semantics);
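
Note: PyClient now builds buffers in two layers: PjRtBufferFromPyval produces the raw PjRtBuffer, and BufferFromPyval wraps that result in a PyBuffer together with a traceback, so callers that never need the Python wrapper (such as the jit fast path) can stop at the lower layer. A minimal sketch of the same layered-factory shape follows; RawBuffer, WrappedBuffer and the function names are illustrative stand-ins, not the XLA classes.

    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>

    // Stand-in for PjRtBuffer: the device-level object.
    struct RawBuffer {
      std::string contents;
    };

    // Stand-in for PyBuffer: the wrapper that adds Python-facing bookkeeping.
    struct WrappedBuffer {
      std::unique_ptr<RawBuffer> raw;
      std::string traceback;  // Illustrative extra state carried by the wrapper.
    };

    // Lower layer: build just the device buffer.
    std::unique_ptr<RawBuffer> RawBufferFromValue(const std::string& value) {
      return std::make_unique<RawBuffer>(RawBuffer{value});
    }

    // Upper layer: delegate to the lower layer, then wrap the result.
    WrappedBuffer WrappedBufferFromValue(const std::string& value) {
      std::unique_ptr<RawBuffer> raw = RawBufferFromValue(value);
      return WrappedBuffer{std::move(raw), "created in WrappedBufferFromValue"};
    }

    int main() {
      auto raw = RawBufferFromValue("42");
      WrappedBuffer wrapped = WrappedBufferFromValue("42");
      std::cout << raw->contents << " / " << wrapped.raw->contents << "\n";
      return 0;
    }
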
diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc
index bf0bb1a8..d9a7a05 100644
--- a/tensorflow/compiler/xla/python/pytree.cc
+++ b/tensorflow/compiler/xla/python/pytree.cc
@@ -32,6 +32,7 @@
 #include "pybind11/pybind11.h"
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
+#include "tensorflow/compiler/xla/python/types.h"
 
 namespace xla {
 
@@ -106,59 +107,66 @@
   }
 }
 
-void PyTreeDef::FlattenInto(py::handle handle,
-                            std::vector<py::object>& leaves) {
+void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
+                            absl::optional<py::function> leaf_predicate) {
   Node node;
   int start_num_nodes = traversal_.size();
   int start_num_leaves = leaves.size();
-  node.kind = GetKind(handle, &node.custom);
-  if (node.kind == Kind::kNone) {
-    // Nothing to do.
-  } else if (node.kind == Kind::kTuple) {
-    py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
-    node.arity = tuple.size();
-    for (py::handle entry : tuple) {
-      FlattenInto(entry, leaves);
-    }
-  } else if (node.kind == Kind::kList) {
-    py::list list = py::reinterpret_borrow<py::list>(handle);
-    node.arity = list.size();
-    for (py::handle entry : list) {
-      FlattenInto(entry, leaves);
-    }
-  } else if (node.kind == Kind::kDict) {
-    py::dict dict = py::reinterpret_borrow<py::dict>(handle);
-    py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
-    if (PyList_Sort(keys.ptr())) {
-      throw std::runtime_error("Dictionary key sort failed.");
-    }
-    for (py::handle key : keys) {
-      FlattenInto(dict[key], leaves);
-    }
-    node.arity = dict.size();
-    node.node_data = std::move(keys);
-  } else if (node.kind == Kind::kCustom) {
-    py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
-    if (out.size() != 2) {
-      throw std::runtime_error(
-          "PyTree custom to_iterable function should return a pair");
-    }
-    node.node_data = out[1];
-    node.arity = 0;
-    for (py::handle entry : py::cast<py::iterable>(out[0])) {
-      ++node.arity;
-      FlattenInto(entry, leaves);
-    }
-  } else if (node.kind == Kind::kNamedTuple) {
-    py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
-    node.arity = tuple.size();
-    node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
-    for (py::handle entry : tuple) {
-      FlattenInto(entry, leaves);
-    }
+  if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
+    leaves.push_back(py::reinterpret_borrow<py::object>(handle));
   } else {
-    assert(node.kind == Kind::kLeaf);
-    leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle));
+    node.kind = GetKind(handle, &node.custom);
+    auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
+      FlattenInto(child, leaves, leaf_predicate);
+    };
+    if (node.kind == Kind::kNone) {
+      // Nothing to do.
+    } else if (node.kind == Kind::kTuple) {
+      py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
+      node.arity = tuple.size();
+      for (py::handle entry : tuple) {
+        recurse(entry);
+      }
+    } else if (node.kind == Kind::kList) {
+      py::list list = py::reinterpret_borrow<py::list>(handle);
+      node.arity = list.size();
+      for (py::handle entry : list) {
+        recurse(entry);
+      }
+    } else if (node.kind == Kind::kDict) {
+      py::dict dict = py::reinterpret_borrow<py::dict>(handle);
+      py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
+      if (PyList_Sort(keys.ptr())) {
+        throw std::runtime_error("Dictionary key sort failed.");
+      }
+      for (py::handle key : keys) {
+        recurse(dict[key]);
+      }
+      node.arity = dict.size();
+      node.node_data = std::move(keys);
+    } else if (node.kind == Kind::kCustom) {
+      py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
+      if (out.size() != 2) {
+        throw std::runtime_error(
+            "PyTree custom to_iterable function should return a pair");
+      }
+      node.node_data = out[1];
+      node.arity = 0;
+      for (py::handle entry : py::cast<py::iterable>(out[0])) {
+        ++node.arity;
+        recurse(entry);
+      }
+    } else if (node.kind == Kind::kNamedTuple) {
+      py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
+      node.arity = tuple.size();
+      node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
+      for (py::handle entry : tuple) {
+        recurse(entry);
+      }
+    } else {
+      assert(node.kind == Kind::kLeaf);
+      leaves.push_back(py::reinterpret_borrow<py::object>(handle));
+    }
   }
   node.num_nodes = traversal_.size() - start_num_nodes + 1;
   node.num_leaves = leaves.size() - start_num_leaves;
@@ -166,10 +174,10 @@
 }
 
 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
-PyTreeDef::Flatten(py::handle x) {
+PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
   std::vector<py::object> leaves;
   auto tree = absl::make_unique<PyTreeDef>();
-  tree->FlattenInto(x, leaves);
+  tree->FlattenInto(x, leaves, leaf_predicate);
   return std::make_pair(std::move(leaves), std::move(tree));
 }
 
@@ -618,7 +626,8 @@
 
 void BuildPytreeSubmodule(py::module& m) {
   py::module pytree = m.def_submodule("pytree", "Python tree library");
-  pytree.def("flatten", &PyTreeDef::Flatten);
+  pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
+             py::arg("leaf_predicate") = absl::nullopt);
   pytree.def("tuple", &PyTreeDef::Tuple);
   pytree.def("all_leaves", &PyTreeDef::AllLeaves);
 
diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h
index 69cd93a..c0a99a1 100644
--- a/tensorflow/compiler/xla/python/pytree.h
+++ b/tensorflow/compiler/xla/python/pytree.h
@@ -85,11 +85,13 @@
 
   // Flattens a Pytree into a list of leaves and a PyTreeDef.
   static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
-  Flatten(pybind11::handle x);
+  Flatten(pybind11::handle x,
+          absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
 
   // Recursive helper used to implement Flatten().
-  void FlattenInto(pybind11::handle handle,
-                   std::vector<pybind11::object>& leaves);
+  void FlattenInto(
+      pybind11::handle handle, std::vector<pybind11::object>& leaves,
+      absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
 
   // Tests whether the given list is a flat list of leaves.
   static bool AllLeaves(const pybind11::iterable& x);
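
Note: both pytree.cc and pytree.h above thread an optional leaf_predicate through Flatten/FlattenInto, so a caller can force arbitrary subtrees to be treated as leaves instead of being recursed into; the predicate is checked before the node kind is inspected. A standalone sketch of that idea over a toy nested-tree type follows; Tree and FlattenSketch are illustrative, not the real PyTreeDef API.

    #include <functional>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Toy tree: a node is a leaf when it has no children.
    struct Tree {
      int value = 0;
      std::vector<Tree> children;
    };

    // Flatten, treating any subtree matched by `leaf_predicate` as a leaf
    // (mirrors the optional leaf_predicate added to PyTreeDef::FlattenInto).
    void FlattenSketch(const Tree& t, std::vector<const Tree*>& leaves,
                       const std::optional<std::function<bool(const Tree&)>>& pred) {
      if (pred && (*pred)(t)) {
        leaves.push_back(&t);
        return;
      }
      if (t.children.empty()) {
        leaves.push_back(&t);
        return;
      }
      for (const Tree& child : t.children) {
        FlattenSketch(child, leaves, pred);
      }
    }

    int main() {
      // (1, (2, 3))
      Tree tree{0, {Tree{1, {}}, Tree{0, {Tree{2, {}}, Tree{3, {}}}}}};

      std::vector<const Tree*> all;
      FlattenSketch(tree, all, std::nullopt);  // Leaves 1, 2, 3.

      // Treat any node with exactly two children as a leaf; the root matches.
      std::vector<const Tree*> pruned;
      FlattenSketch(tree, pruned,
                    std::function<bool(const Tree&)>(
                        [](const Tree& t) { return t.children.size() == 2; }));

      std::cout << all.size() << " leaves without predicate, " << pruned.size()
                << " with it\n";  // "3 leaves without predicate, 1 with it"
      return 0;
    }

In the Python binding this is exposed as pytree.flatten(tree, leaf_predicate=...), with leaf_predicate defaulting to none.
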
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
index 9d98d0c..3296d29 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
@@ -26,7 +26,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:executable_build_options",
-        "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
         "//tensorflow/compiler/xla/pjrt:semaphore",
         "//tensorflow/compiler/xla/python/tpu_driver",
         "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
@@ -37,6 +37,7 @@
         "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core/framework:allocator",
+        "//tensorflow/core/platform:casts",
         "//tensorflow/core/platform:env",
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/memory",
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
index c6a7480..a9aa218 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
@@ -37,8 +37,8 @@
 
 TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
                      int core_on_chip)
-    : xla::PjRtDevice(id, /*local_device_state=*/nullptr,
-                      /*device_kind=*/"Cloud TPU", host_id),
+    : xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
+                                    /*device_kind=*/"Cloud TPU", host_id),
       coords_(coords),
       core_on_chip_(core_on_chip) {}
 
@@ -531,7 +531,7 @@
           << "Inserting duplicate replica:" << replica;
       executables_[replica] =
           client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
-      addressable_device_logical_ids_.emplace_back(replica, partition);
+      local_logical_device_ids_.emplace_back(replica, partition);
       local_devices_.push_back(device);
     }
   }
@@ -711,8 +711,8 @@
     // long time and we want all cores to be scheduled in parallel.
     thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
                            &execute_semaphore]() {
-      const int replica = addressable_device_logical_ids_[i].first;
-      const int partition = addressable_device_logical_ids_[i].second;
+      const int replica = local_logical_device_ids_[i].first;
+      const int partition = local_logical_device_ids_[i].second;
       RunId run_id;
       auto result = ExecuteHelper(argument_handles, argument_handles[i],
                                   replica, partition, run_id);
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
index 20c2f74..cc4e447 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
@@ -24,7 +24,7 @@
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -32,13 +32,14 @@
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/threadpool.h"
 
 namespace xla {
 
 constexpr char kTpuPlatform[] = "tpu";
 
-class TpuDevice : public PjRtDevice {
+class TpuDevice : public PjRtStreamExecutorDevice {
  public:
   TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
             int core_on_chip);
@@ -298,9 +299,8 @@
     return device_assignment_;
   }
 
-  const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
-      const {
-    return addressable_device_logical_ids_;
+  const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
+    return local_logical_device_ids_;
   }
 
   const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
@@ -341,14 +341,16 @@
 
   // The replica and partition indices of device_assignment_ to be run by this
   // client. On single-host platforms without partitioning, this is all replicas
-  // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
-  // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
-  // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
-  std::vector<std::pair<int, int>> addressable_device_logical_ids_;
+  // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
+  // on multi-host platforms.
+  // If there are 4 replicas and 2 partitions on a single-host platform, the
+  // size of local_logical_device_ids_ is 4*2 = 8.
+  std::vector<std::pair<int, int>> local_logical_device_ids_;
 
-  // local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
-  // is assigned. shared_ptrs instead of unique_ptrs to play well with the
-  // Python bindings (see xla.cc).
+  // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
+  // assigned.
+  // shared_ptrs instead of unique_ptrs to play well with the Python bindings
+  // (see xla.cc).
   std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
 
   xla::Shape result_shape_;
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
index 0562ff2..a9fd70b 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
@@ -186,7 +186,7 @@
 
   py::class_<PyTpuExecutable>(m, "TpuExecutable")
       .def("local_logical_device_ids",
-           &PyTpuExecutable::addressable_device_logical_ids)
+           &PyTpuExecutable::local_logical_device_ids)
       .def("local_devices", &PyTpuExecutable::local_devices)
       .def_property_readonly("client", &PyTpuExecutable::client)
       .def("size_of_generated_code_in_bytes",
diff --git a/tensorflow/compiler/xla/python/types.cc b/tensorflow/compiler/xla/python/types.cc
index 882b38d..40a3e58 100644
--- a/tensorflow/compiler/xla/python/types.cc
+++ b/tensorflow/compiler/xla/python/types.cc
@@ -16,8 +16,8 @@
 #include "tensorflow/compiler/xla/python/types.h"
 
 #include "absl/container/flat_hash_map.h"
-#include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/python/lib/core/bfloat16.h"
 
 namespace xla {
 
@@ -81,8 +81,8 @@
     case U64:
       return py::dtype::of<uint64>();
     case BF16: {
-      TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
-      return py::dtype::from_args(bfloat16);
+      py::handle bfloat16(tensorflow::Bfloat16Dtype());
+      return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
     }
     case F16:
       return py::dtype("e");  // PEP 3118 code for "float16"
@@ -237,10 +237,11 @@
     // We requested an array of uint16 since NumPy doesn't know how
     // to produce our custom bfloat16 type. Reinterpret the array as bfloat16
     // before handing it back to the caller.
-    TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
+    py::handle bfloat16(tensorflow::Bfloat16Dtype());
+    bfloat16.inc_ref();
     array = py::reinterpret_steal<py::array>(
         PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
-                     reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()),
+                     reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
                      static_cast<PyTypeObject*>(nullptr)));
   }
   return array;
diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h
index 8c7dc18..60c6f41 100644
--- a/tensorflow/compiler/xla/python/types.h
+++ b/tensorflow/compiler/xla/python/types.h
@@ -25,6 +25,7 @@
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
+#include "tensorflow/compiler/xla/python/absl_casters.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
@@ -112,48 +113,6 @@
 namespace pybind11 {
 namespace detail {
 
-// When absl::optional is an alias for std::optional, the type_caster
-// specializations are provided by pybind11.
-#ifndef ABSL_HAVE_STD_OPTIONAL
-// absl::optional
-template <typename T>
-struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {};
-
-template <>
-struct type_caster<absl::nullopt_t> : public void_caster<absl::nullopt_t> {};
-#endif
-
-// absl::Span
-template <typename T>
-struct type_caster<absl::Span<const T>> {
-  using value_conv = make_caster<T>;
-
-  PYBIND11_TYPE_CASTER(absl::Span<const T>,
-                       _("Span[") + value_conv::name + _("]"));
-
-  // absl::Span doesn't hold ownership. We therefore need a temporary array.
-  // Pybind appears to keep type_casters alive until the callee has run.
-  std::vector<T> storage_;
-
-  bool load(handle src, bool convert) {
-    if (!isinstance<sequence>(src)) {
-      return false;
-    }
-    auto seq = reinterpret_borrow<sequence>(src);
-    storage_.clear();
-    storage_.reserve(seq.size());
-    for (const auto& it : seq) {
-      value_conv conv;
-      if (!conv.load(it, convert)) {
-        return false;
-      }
-      storage_.push_back(cast_op<T&&>(std::move(conv)));
-    }
-    value = absl::Span<const T>(storage_);
-    return true;
-  }
-};
-
 // Status, StatusOr. Failing statuses become Python exceptions; Status::OK()
 // becomes None.
 template <>
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index dee1b14..e6539ef 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -40,7 +40,6 @@
 #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
-#include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/python/dlpack.h"
 #include "tensorflow/compiler/xla/python/jax_jit.h"
 #include "tensorflow/compiler/xla/python/ops.h"
@@ -59,6 +58,7 @@
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/python/lib/core/bfloat16.h"
 #include "tensorflow/stream_executor/platform.h"
 
 namespace xla {
@@ -110,6 +110,8 @@
     throw std::runtime_error("Unable to initialize Numpy API");
   }
 
+  CHECK(tensorflow::RegisterNumpyBfloat16());
+
   // Types
   py::enum_<PrimitiveType>(m, "PrimitiveType")
       .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
@@ -132,7 +134,8 @@
       .value("OPAQUE_TYPE", OPAQUE_TYPE)
       .value("TOKEN", TOKEN);
 
-  m.def("bfloat16_dtype", Bfloat16Dtype);
+  m.def("bfloat16_dtype",
+        []() { return py::handle(tensorflow::Bfloat16Dtype()); });
 
   // Must be before PyClient.compile.
   BuildXlaCompilerSubmodule(m);
@@ -149,7 +152,10 @@
       .def_property_readonly("host_id", &PjRtDevice::host_id,
                              "Integer ID of this device's host.\n\n"
                              "This is always 0 except on multi-host platforms.")
-      .def_property_readonly("platform", &PjRtDevice::platform_name)
+      .def_property_readonly("platform",
+                             [](const PjRtDevice& device) {
+                               return device.client()->platform_name();
+                             })
       .def_property_readonly("device_kind", &PjRtDevice::device_kind)
       .def_property_readonly(
           "client",
@@ -234,7 +240,7 @@
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
-      .def("local_device_count", &PyClient::local_device_count)
+      .def("local_device_count", &PyClient::addressable_device_count)
       .def("devices", &PyClient::Devices)
       .def("local_devices", &PyClient::LocalDevices)
       .def("host_id", &PyClient::host_id)
@@ -343,6 +349,7 @@
             return npy_value_;
           })
       .def("copy_to_device", &PyBuffer::CopyToDevice)
+      .def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
       .def("delete", &PyBuffer::Delete)
       // The GIL is released within BlockHostUntilReady.
       .def("block_until_ready",
@@ -381,10 +388,10 @@
            [](PyExecutable* exec) {
              auto span = exec->addressable_device_logical_ids();
              // Not on dispatch critical path, so ok to have heap allocation.
-             std::vector<std::pair<int, int>> addressable_device_logical_ids;
-             addressable_device_logical_ids.reserve(span.size());
+             std::vector<std::pair<int, int>> addressable_device_logic_ids;
+             addressable_device_logic_ids.reserve(span.size());
              for (const auto& logical_device_id : span) {
-               addressable_device_logical_ids.push_back(std::make_pair(
+               addressable_device_logic_ids.push_back(std::make_pair(
                    logical_device_id.replica, logical_device_id.partition));
              }
            })
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 4a78e5c..d1d3de9 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -431,7 +431,10 @@
     fn: a PyCapsule object containing the function pointer.
     platform: the target platform.
   """
-  _xla.register_custom_call_target(name, fn, xla_platform_names[platform])
+  # To support AMD GPUs, we would need xla_platform_names["gpu"] == "ROCM".
+  # Since that entry is hardcoded to CUDA, we use the following as a workaround.
+  _xla.register_custom_call_target(name, fn,
+                                   xla_platform_names.get(platform, platform))
 
 
 # Deprecated. Use register_custom_call_target instead.
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fc30883..bca3ca5 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -486,6 +486,21 @@
       with self.assertRaises(RuntimeError):
         buffer.block_until_ready()
 
+    def testOnDeviceSizeInBytes(self):
+      if not isinstance(self.backend, xla_client.Client):
+        self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
+      arg0 = np.array([])
+      arg1 = np.array([[0., 1., 2.]], np.float32)
+      arg2 = np.array([[3., 4., 5.]], bfloat16)
+      arg0_buffer = self.backend.buffer_from_pyval(arg0)
+      arg1_buffer = self.backend.buffer_from_pyval(arg1)
+      arg2_buffer = self.backend.buffer_from_pyval(arg2)
+      self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
+      # OnDeviceSizeInBytes varies depending on the platform. Confirm it
+      # returns a reasonable value.
+      self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
+      self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)
+
     def testCopyToHost(self):
       arg0 = np.array([[1., 2.]], np.float32)
       arg1 = np.array([[3., 4.]], np.float32)
diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h
index a9d07a7..7b79ea9 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map.h
+++ b/tensorflow/compiler/xla/refcounting_hash_map.h
@@ -73,20 +73,19 @@
                         value_factory) {
     absl::MutexLock lock(&mu_);
     auto it = map_.find(key);
-    // We ensure that the entry has not expired in case deleter was running when
-    // we have entered this block.
     if (it != map_.end()) {
+      // We ensure that the entry has not expired in case the deleter was
+      // running when we entered this block.
       if (std::shared_ptr<V> value = it->second.lock()) {
         return value;
       }
-      map_.erase(it);
     }
 
     // Create entry in the map and then set its value, so the value can
     // contain a pointer back into the map.
     TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
     it = map_.emplace(key, std::weak_ptr<V>()).first;
-    std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this});
+    std::shared_ptr<V> value(value_unique.release(), Deleter{it->first, *this});
     it->second = value;  // Set the weak ptr to the shared ptr.
     return value;
   }
@@ -108,15 +107,17 @@
 
  private:
   struct Deleter {
-    const K* key;  // Points into parent->map_.
-    RefcountingHashMap* parent;
+    const K& key;  // Points into parent->map_.
+    RefcountingHashMap& parent;
 
     void operator()(V* v) {
       delete v;
-      absl::MutexLock lock(&parent->mu_);
-      auto it = parent->map_.find(*key);
-      if (it != parent->map_.end() && it->second.expired()) {
-        parent->map_.erase(it);
+      absl::MutexLock lock(&parent.mu_);
+      // We must check that the entry is still expired in case the value was
+      // replaced while the deleter was running.
+      auto it = parent.map_.find(key);
+      if (it != parent.map_.end() && it->second.expired()) {
+        parent.map_.erase(it);
       }
     }
   };
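
Note: the refcounting_hash_map.h hunk makes the Deleter hold references instead of pointers and stops erasing a stale entry eagerly in the lookup path; instead the deleter re-checks that the entry is still expired before erasing, since a new value may have been inserted for the same key while the deleter was running. Below is a self-contained, single-threaded sketch of that weak-pointer cache pattern; WeakCache and its members are illustrative, not the XLA class, and the real class additionally guards the map with a mutex.

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    // A cache that hands out shared_ptrs but only stores weak_ptrs; the custom
    // deleter erases the entry, but only if it is still expired, in case a new
    // value was inserted for the same key in the meantime.
    class WeakCache {
     public:
      std::shared_ptr<std::string> GetOrCreate(const std::string& key) {
        auto it = map_.find(key);
        if (it != map_.end()) {
          if (std::shared_ptr<std::string> value = it->second.lock()) {
            return value;  // Entry is still alive; reuse it.
          }
        }
        it = map_.emplace(key, std::weak_ptr<std::string>()).first;
        std::shared_ptr<std::string> value(new std::string("value for " + key),
                                           Deleter{it->first, *this});
        it->second = value;  // Set the weak ptr to the shared ptr.
        return value;
      }

      std::size_t size() const { return map_.size(); }

     private:
      struct Deleter {
        const std::string& key;  // Refers to the key stored in map_.
        WeakCache& parent;

        void operator()(std::string* v) const {
          delete v;
          // Erase only if the entry is still expired.
          auto it = parent.map_.find(key);
          if (it != parent.map_.end() && it->second.expired()) {
            parent.map_.erase(it);
          }
        }
      };

      std::map<std::string, std::weak_ptr<std::string>> map_;
    };

    int main() {
      WeakCache cache;
      {
        auto value = cache.GetOrCreate("a");
        std::cout << "live entries: " << cache.size() << "\n";  // 1
      }
      std::cout << "after last reference dropped: " << cache.size() << "\n";  // 0
      return 0;
    }
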
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 482b5f3..9957230 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -936,6 +936,7 @@
         ":hlo_evaluator",
         ":hlo_execution_profile",
         ":hlo_module_config",
+        ":hlo_module_util",
         ":hlo_proto_util",
         ":platform_util",
         ":source_map_util",
@@ -977,6 +978,7 @@
         ":hlo",
         ":hlo_execution_profile",
         ":hlo_module_config",
+        ":hlo_module_util",
         ":platform_util",
         ":service",
         ":shaped_buffer",
@@ -1034,31 +1036,8 @@
     ],
 )
 
-# This flag enables experimental MLIR GPU support.
-config_setting(
-    name = "with_mlir_gpu_support",
-    define_values = {"with_mlir_gpu_support": "true"},
-    visibility = ["//visibility:public"],
-)
-
-# Lets us choose the right GPU plugin depending on whether the experimental MLIR
-# GPU plugin should be used or not.
 cc_library(
     name = "gpu_plugin",
-    deps = select(
-        {
-            ":with_mlir_gpu_support": [
-                ":gpu_plugin_mlir",
-            ],
-            "//conditions:default": [
-                ":gpu_plugin_no_mlir",
-            ],
-        },
-    ),
-)
-
-cc_library(
-    name = "gpu_plugin_no_mlir",
     deps = [
         ":service",
         "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
@@ -1074,17 +1053,6 @@
 )
 
 cc_library(
-    name = "gpu_plugin_mlir",
-    deps = [
-        ":service",
-        "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
-        "//tensorflow/core/platform:stream_executor_no_cuda",
-    ] + if_cuda_is_configured([
-        "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler_impl",
-    ]) + internal_cuda_deps(),
-)
-
-cc_library(
     name = "interpreter_plugin",
     deps = [
         ":service",
@@ -1529,6 +1497,21 @@
 )
 
 cc_library(
+    name = "hlo_module_util",
+    srcs = ["hlo_module_util.cc"],
+    hdrs = ["hlo_module_util.h"],
+    deps = [
+        ":compiler",
+        ":hlo_module_config",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
     name = "hlo_module_group_util",
     srcs = ["hlo_module_group_util.cc"],
     hdrs = ["hlo_module_group_util.h"],
@@ -3046,6 +3029,7 @@
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
@@ -3925,6 +3909,7 @@
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -5170,6 +5155,7 @@
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
     ],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 9a725cd..10e19e7 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -4696,6 +4696,10 @@
 
 Status AlgebraicSimplifierVisitor::HandleReduceWindow(
     HloInstruction* reduce_window) {
+  // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
+  if (reduce_window->shape().IsTuple()) {
+    return Status::OK();
+  }
   if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
     return ReplaceWithNewInstruction(
         reduce_window,
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc
index a86ab60..d39ba17 100644
--- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc
@@ -365,10 +365,13 @@
   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
     return *a == *b;
   };
+  // Two MPMD AllReduces are identical if they have the same channel_id. Their
+  // operands don't have to be identical.
+  auto eq_operands = [](const HloInstruction*, const HloInstruction*) {
+    return true;
+  };
   if (i1->IsCrossModuleAllReduce()) {
-    return i1->Identical(*i2,
-                         /*eq_operands=*/std::equal_to<const HloInstruction*>(),
-                         eq_computations,
+    return i1->Identical(*i2, eq_operands, eq_computations,
                          /*layout_sensitive=*/false);
   }
   visited_pairs->emplace(min_uid, max_uid);
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 19a0e6e..515b58d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -358,6 +358,14 @@
     return allocations_;
   }
 
+  // This is similar to copying Allocations(), but because the vector is moved
+  // out, the element addresses are preserved. Since BufferAllocation::Slice
+  // keeps a BufferAllocation*, and some backends keep BufferAllocation::Slice
+  // in xla::Executables, migrating off the use of addresses can be hard.
+  std::vector<BufferAllocation> ReleaseAllocations() {
+    return std::move(allocations_);
+  }
+
   // Returns the total size allocation holding all temporary buffers.
   int64 temp_allocation_total_size() const {
     return temp_allocation_total_size_;
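
Note: ReleaseAllocations() relies on the fact that moving a std::vector transfers its heap storage, so the BufferAllocation elements keep their addresses and any BufferAllocation::Slice pointing at them stays valid; a copy would not have that property. A tiny standalone demonstration of the distinction, with plain ints standing in for BufferAllocation:

    #include <iostream>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<int> allocations = {1, 2, 3};
      const int* element_before = allocations.data();

      std::vector<int> copied = allocations;               // Allocates a new buffer.
      std::vector<int> released = std::move(allocations);  // Steals the old buffer.

      // Addresses survive the move but not the copy, which is why callers holding
      // raw pointers into the container can keep using them after a move.
      std::cout << "copy keeps addresses: "
                << (copied.data() == element_before) << "\n";    // 0
      std::cout << "move keeps addresses: "
                << (released.data() == element_before) << "\n";  // 1
      return 0;
    }
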
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index d7f63cb..b896264 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -208,11 +208,6 @@
               std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
 class Rendezvous {
  public:
-  struct ParticipantImplOutput {
-    bool is_primary;
-    O custom_output;
-  };
-
   virtual ~Rendezvous() {}
   explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
 
@@ -241,13 +236,12 @@
           "rendezvous: %p",
           rendezvous.get());
     });
-    return p.first;
+    return std::move(p.first);
   }
 
  protected:
   // Returns the domain-specific output O for this participant.
-  virtual StatusOr<ParticipantImplOutput> RunCollectiveOp(
-      const I& participant) = 0;
+  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;
 
   // Initialize the rendezvous by the first ("primary") thread which reaches the
   // barrier. Returns whether this thread is primary.
@@ -300,8 +294,8 @@
           participant.device_ordinal, participant.stream, key_.ToString());
     });
 
-    TF_ASSIGN_OR_RETURN(ParticipantImplOutput p, RunCollectiveOp(participant));
-    return std::make_pair(p.custom_output, returned_blocking_counter_);
+    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
+    return std::make_pair(std::move(output), returned_blocking_counter_);
   }
 
   const RendezvousKey key_;
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 9e169fd..623b826 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -42,6 +42,7 @@
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/threadpool.h"
 
 namespace xla {
 
@@ -158,6 +159,18 @@
 // platform.
 class Compiler {
  public:
+  struct CompileOptions {
+    // If device_allocator is not null, the compiler may use it to allocate temp
+    // space on the device for use during compilation.  For example, the
+    // compiler may allocate buffers on the device and then run variants of a
+    // given algorithm over those buffers, to see which variant is fastest.  Any
+    // space allocated will be deallocated before the compilation returns.
+    se::DeviceMemoryAllocator* device_allocator = nullptr;
+
+    // An optional thread pool for parallel compilation.
+    tensorflow::thread::ThreadPool* thread_pool = nullptr;
+  };
+
   virtual ~Compiler() {}
 
   // Returns the ID of the platform that this compiler targets.
@@ -165,31 +178,24 @@
 
   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
   // module.
-  //
-  // If device_allocator is not null, the compiler may use it to allocate temp
-  // space on the device for use during compilation.  For example, the compiler
-  // may allocate buffers on the device and then run variants of a given
-  // algorithm over those buffers, to see which variant is fastest.  Any space
-  // allocated should be deallocated before this function returns.
   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator) = 0;
+      const CompileOptions& options) = 0;
+  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
+      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
+      se::DeviceMemoryAllocator* device_allocator) {
+    return RunHloPasses(std::move(module), executor,
+                        CompileOptions{device_allocator});
+  }
 
   // Runs HLO passes to optimize the given HloModule, perform scheduling and
   // buffer assignment, returns the optimized module and the buffer assignments.
   // This interface is intentionally narrow.
-  //
-  // If device_allocator is not null, the compiler may use it to allocate temp
-  // space on the device for use during compilation. For example, the compiler
-  // may allocate buffers on the device and then run variants of a given
-  // algorithm over those buffers, to see which variant is fastest. Any space
-  // allocated should be deallocated before this function returns.
   virtual StatusOr<
       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
-                                   se::StreamExecutor* executor,
-                                   se::DeviceMemoryAllocator* device_allocator,
-                                   bool optimize) {
+                                   se::StreamExecutor* executor, bool optimize,
+                                   const CompileOptions& options) {
     return Unimplemented("This compiler does not support this method");
   }
 
@@ -201,24 +207,33 @@
   //
   // The compiler may optionally specialize to the individual device
   // (not just type of device) indicated by the executor.
-  //
-  // device_allocator is optional; see RunHloPasses.
   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator) = 0;
+      const CompileOptions& options) = 0;
+  StatusOr<std::unique_ptr<Executable>> RunBackend(
+      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
+      se::DeviceMemoryAllocator* device_allocator) {
+    return RunBackend(std::move(module), executor,
+                      CompileOptions{device_allocator});
+  }
 
   // Compiles a set of HLO modules that can run in parallel, potentially
   // communicating data between the modules, and returns a corresponding
   // sequence of executable objects.
   //
-  // device_allocator is optional; see RunHloPasses.
-  //
   // TODO(b/68666782): Remove this method after adding support for multiple
   // modules to RunHloPasses and RunBackends.
   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) = 0;
+      const CompileOptions& options) = 0;
+  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) {
+    return Compile(std::move(module_group), stream_exec,
+                   CompileOptions{device_allocator});
+  }
 
   // Returns the backend configurations that the backend will consider for the
   // given HLO. Returns no configurations if the backend does not support
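
Usage note: the new CompileOptions struct bundles the optional device
allocator with an optional compilation thread pool, and the non-virtual
device_allocator overloads above simply forward to the options-based entry
points. A minimal caller-side sketch; `my_compiler`, `my_module`,
`my_executor`, `my_allocator`, and `my_thread_pool` are illustrative
placeholders, not names introduced by this change:

  // Preferred form: pass a CompileOptions explicitly.
  xla::Compiler::CompileOptions options;
  options.device_allocator = my_allocator;  // may be left as nullptr
  options.thread_pool = my_thread_pool;     // optional, enables parallel backend compilation
  auto optimized =
      my_compiler->RunHloPasses(std::move(my_module), my_executor, options);

  // Legacy form: still accepted, forwards to the options-based overload.
  auto optimized_legacy = my_compiler->RunHloPasses(
      std::move(another_module), my_executor, my_allocator);
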
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ca67fe6..5bd2d13 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -553,7 +553,7 @@
 
 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
-    se::DeviceMemoryAllocator* /*device_allocator*/) {
+    const CompileOptions& /*options*/) {
   std::unique_ptr<llvm::TargetMachine> jit_target_machine =
       SimpleOrcJIT::InferTargetMachineForJIT(
           CompilerTargetOptions(module->config()),
@@ -566,12 +566,13 @@
 
 StatusOr<
     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-CpuCompiler::RunHloPassesAndBufferAssignement(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-    se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
+                                              se::StreamExecutor* executor,
+                                              bool optimize,
+                                              const CompileOptions& options) {
   if (optimize) {
-    TF_ASSIGN_OR_RETURN(
-        module, RunHloPasses(std::move(module), executor, device_allocator));
+    TF_ASSIGN_OR_RETURN(module,
+                        RunHloPasses(std::move(module), executor, options));
   }
 
   // Select an order for emitting the HLO instructions for each computation.
@@ -632,11 +633,13 @@
 
 StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* /*device_allocator*/) {
+    const CompileOptions& options) {
   VLOG(1) << "Compiling: " << module->name();
   XLA_SCOPED_LOGGING_TIMER(
       absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
-  auto slow_compile_alarm = SlowCompilationAlarm();
+  std::string slow_compilation_msg =
+      absl::StrCat("Compiling module ", module->name());
+  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
 
   TF_RET_CHECK(stream_exec != nullptr);
   absl::call_once(llvm_command_line_options_initialized,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 5c056fc..9f5e6a9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -134,18 +134,17 @@
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
 
   StatusOr<
       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
-                                   se::StreamExecutor* executor,
-                                   se::DeviceMemoryAllocator* device_allocator,
-                                   bool optimize) override;
+                                   se::StreamExecutor* executor, bool optimize,
+                                   const CompileOptions& options) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index d4d78f5..437dbd5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -323,7 +323,7 @@
       : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}
 
  protected:
-  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
       const AllToAllParticipantData& /*participant*/) override {
     bool is_primary = InitializationBarrier();
 
@@ -373,7 +373,7 @@
         }
       }
     }
-    return ParticipantImplOutput{is_primary, nullptr};
+    return nullptr;
   }
 };
 
@@ -384,7 +384,7 @@
       : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}
 
  protected:
-  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
       const CollectivePermuteParticipantData& /*participant*/) override {
     bool primary = InitializationBarrier();
 
@@ -415,7 +415,7 @@
         std::memset(p.destination_data.opaque(), 0, p.byte_size);
       }
     }
-    return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
+    return nullptr;
   }
 };
 
@@ -426,7 +426,7 @@
       : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}
 
  protected:
-  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
       const xla::AllReduceParticipantData& participant) override {
     xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
     bool primary = InitializationBarrier();
@@ -465,7 +465,7 @@
           LOG(FATAL) << "Unexpected datatype;";
       }
     }
-    return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
+    return nullptr;
   }
 
  private:
diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
index 6984938..ee0ffd6 100644
--- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
@@ -108,7 +108,7 @@
   // Create the function and call the emission callback.
   mlir::Location loc = mlir::UnknownLoc::get(context);
   auto function = mlir::FuncOp::create(
-      loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
+      loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
   function.addEntryBlock();
   mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
   mlir_module->push_back(function);
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index 167033b..ab94695 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -93,6 +93,9 @@
       return inst->mutable_operand(init_value_index);
     }
     case HloOpcode::kReduceWindow: {
+      if (inst->shape().IsTuple()) {
+        return Unimplemented("Variadic reduce window not yet supported. ");
+      }
       // Because of the way we do reduce, we already require the `init`
       // operand of hlo reduce instruction to be identity value. Here we reuse
       // the operand.
@@ -1015,6 +1018,10 @@
 StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
     HloInstruction* hlo,
     DynamicDimensionInference* dynamic_dimension_inference) {
+  if (hlo->shape().IsTuple()) {
+    // TODO(b/73062247): Variadic reduce window is not yet supported here.
+    return Unimplemented("Variadic reduce window not yet supported.");
+  }
   HloInstruction* input = hlo->mutable_operand(0);
   HloInstruction* init = hlo->mutable_operand(1);
   HloComputation* comp = hlo->parent();
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index ab6f22b..a456b3f 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -115,7 +115,7 @@
 
 tf_cc_test(
     name = "custom_call_test",
-    srcs = if_cuda_is_configured(["custom_call_test.cc"]),
+    srcs = if_cuda_or_rocm(["custom_call_test.cc"]),
     tags = tf_cuda_tests_tags(),
     deps = [
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
@@ -259,6 +259,7 @@
         ":hlo_to_ir_bindings",
         ":ir_emission_utils",
         ":launch_dimensions",
+        ":nccl_all_gather_thunk",
         ":nccl_all_reduce_thunk",
         ":nccl_all_to_all_thunk",
         ":parallel_loop_emitter",
@@ -269,6 +270,7 @@
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
+        "//tensorflow/compiler/mlir/xla:attribute_exporter",
         "//tensorflow/compiler/mlir/xla:hlo_module_importer",
         "//tensorflow/compiler/mlir/xla:hlo_utils",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
@@ -439,7 +441,7 @@
     name = "nccl_collective_thunk_src",
     srcs = if_nccl(
         ["nccl_collective_thunk.cc"],
-        ["dummy_collective_thunk.cc"],
+        ["nccl_collective_thunk_dummy.cc"],
     ),
 )
 
@@ -447,7 +449,7 @@
     name = "nccl_collective_thunk",
     srcs = if_cuda_or_rocm(
         [":nccl_collective_thunk_src"],
-        ["dummy_collective_thunk.cc"],
+        ["nccl_collective_thunk_dummy.cc"],
     ),
     hdrs = ["nccl_collective_thunk.h"],
     deps = [
@@ -455,6 +457,7 @@
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "//tensorflow/compiler/xla/service:collective_ops_utils",
         "//tensorflow/compiler/xla/service:global_device_id",
         "//tensorflow/compiler/xla/service:hlo",
@@ -476,10 +479,50 @@
 
 # First level of nested select. NCCL requires both if_cuda and if_nccl.
 filegroup(
+    name = "nccl_all_gather_thunk_src",
+    srcs = if_nccl(
+        ["nccl_all_gather_thunk.cc"],
+        ["nccl_all_gather_thunk_dummy.cc"],
+    ),
+)
+
+tf_cuda_library(
+    name = "nccl_all_gather_thunk",
+    srcs = if_cuda_or_rocm(
+        [":nccl_all_gather_thunk_src"],
+        ["nccl_all_gather_thunk_dummy.cc"],
+    ),
+    hdrs = ["nccl_all_gather_thunk.h"],
+    deps = [
+        ":buffer_allocations",
+        ":gpu_executable_run_options",
+        ":hlo_execution_profiler",
+        ":nccl_collective_thunk",
+        ":thunk",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/strings:str_format",
+        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:collective_ops_utils",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
+        "//tensorflow/compiler/xla/service:pattern_matcher",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/core:lib",
+    ] + if_nccl([
+        ":virtual_nccl",
+        ":virtual_nccl_utils",
+        ":virtual_rccl",
+    ]),
+)
+
+# First level of nested select. NCCL requires both if_cuda and if_nccl.
+filegroup(
     name = "nccl_all_reduce_thunk_src",
     srcs = if_nccl(
         ["nccl_all_reduce_thunk.cc"],
-        ["dummy_all_reduce_thunk.cc"],
+        ["nccl_all_reduce_thunk_dummy.cc"],
     ),
 )
 
@@ -487,7 +530,7 @@
     name = "nccl_all_reduce_thunk",
     srcs = if_cuda_or_rocm(
         [":nccl_all_reduce_thunk_src"],
-        ["dummy_all_reduce_thunk.cc"],
+        ["nccl_all_reduce_thunk_dummy.cc"],
     ),
     hdrs = ["nccl_all_reduce_thunk.h"],
     deps = [
@@ -519,7 +562,7 @@
     name = "nccl_all_to_all_thunk_src",
     srcs = if_nccl(
         ["nccl_all_to_all_thunk.cc"],
-        ["dummy_all_to_all_thunk.cc"],
+        ["nccl_all_to_all_thunk_dummy.cc"],
     ),
 )
 
@@ -527,7 +570,7 @@
     name = "nccl_all_to_all_thunk",
     srcs = if_cuda_or_rocm(
         [":nccl_all_to_all_thunk_src"],
-        ["dummy_all_to_all_thunk.cc"],
+        ["nccl_all_to_all_thunk_dummy.cc"],
     ),
     hdrs = ["nccl_all_to_all_thunk.h"],
     deps = [
@@ -559,7 +602,7 @@
     name = "nccl_test_utils_src",
     srcs = if_nccl(
         ["nccl_test_utils.cc"],
-        ["dummy_nccl_test_utils.cc"],
+        ["nccl_test_utils_dummy.cc"],
     ),
 )
 
@@ -567,7 +610,7 @@
     name = "nccl_test_utils",
     srcs = if_cuda_or_rocm(
         [":nccl_test_utils_src"],
-        ["dummy_nccl_test_utils.cc"],
+        ["nccl_test_utils_dummy.cc"],
     ),
     hdrs = ["nccl_test_utils.h"],
     deps = [
@@ -629,8 +672,8 @@
         "gpu_debug_info_manager.h",
     ],
     deps = [
-        "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/service:hlo_proto_util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -642,15 +685,10 @@
     srcs = ["gpu_debug_info_manager_test.cc"],
     tags = tf_cuda_tests_tags(),
     deps = [
-        ":gpu_constants",
         ":gpu_debug_info_manager",
-        ":gpu_hlo_schedule",
-        ":stream_assignment",
-        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
-        "//tensorflow/core:test",
     ],
 )
 
@@ -679,6 +717,8 @@
     ] + if_cuda_is_configured([
         "cholesky_thunk.cc",
         "custom_call_thunk.cc",
+    ]) + if_rocm_is_configured([
+        "custom_call_thunk.cc",
     ]),
     hdrs = [
         "collective_permute_thunk.h",
@@ -703,6 +743,8 @@
     ] + if_cuda_is_configured([
         "cholesky_thunk.h",
         "custom_call_thunk.h",
+    ]) + if_rocm_is_configured([
+        "custom_call_thunk.h",
     ]),
     deps = [
         ":backend_configs_cc",
@@ -798,10 +840,27 @@
         "//tensorflow/core/platform:stream_executor_no_cuda",
         "//tensorflow/stream_executor:device_description",
         "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
     ],
 )
 
+tf_cc_test(
+    name = "ir_emission_utils_test",
+    srcs = ["ir_emission_utils_test.cc"],
+    deps = [
+        ":ir_emission_utils",
+        "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
+        "//tensorflow/compiler/mlir/hlo:lhlo",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//tensorflow/core:test",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:StandardOps",
+    ],
+)
+
 cc_library(
     name = "gemm_rewriter",
     srcs = ["gemm_rewriter.cc"],
@@ -1411,7 +1470,11 @@
         "//tensorflow/stream_executor:stream_executor_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:AsmParser",
+        "@llvm-project//llvm:BitReader",
+        "@llvm-project//llvm:BitWriter",
         "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:TransformUtils",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
     ],
@@ -1476,7 +1539,7 @@
         "//tensorflow/stream_executor:stream_executor_headers",
         "//tensorflow/stream_executor/cuda:cuda_diagnostics",
         "//tensorflow/stream_executor/gpu:asm_compiler",
-    ]),
+    ]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
 )
 
 cc_library(
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 974db02..f6409b4 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -108,12 +108,17 @@
 AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
                                     llvm::Module* llvm_module,
                                     GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+                                    se::StreamExecutor* stream_exec,
+                                    bool relocatable) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
   }
 
+  if (relocatable) {
+    return Unimplemented("relocatable target binary is not implemented");
+  }
+
   std::vector<uint8> hsaco;
   {
     XLA_SCOPED_LOGGING_TIMER(
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
index acc5e02..36318ba 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -41,7 +41,8 @@
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) override;
 
  private:
   // The parent directory of ROCm-Device-Libs IR libraries.
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index cac335c..a89cb43 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -34,13 +34,13 @@
 
 Status BufferAllocations::TearDown(
     const std::set<se::DeviceMemoryBase>& live_addresses,
-    const BufferAssignment* buffer_assignment) {
+    absl::Span<const BufferAllocation> allocations) {
   // Deallocate temporary buffers, taking care to try to deallocate all of them
   // even if one of the deallocations fails.
   Status status;
-  const int64 num_buffers = buffer_assignment->Allocations().size();
+  const int64 num_buffers = allocations.size();
   for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
-    const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
+    const BufferAllocation& allocation = allocations[i];
     se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
     // Deallocate buffers marked "maybe_live_out" that aren't actually live out,
     // and temp buffers.
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
index 0d534b0..d5fa8c5 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -70,7 +70,7 @@
   // Tears down all buffers allocated by this object that are not in
   // `live_addresses`.
   Status TearDown(const std::set<se::DeviceMemoryBase>& live_addresses,
-                  const BufferAssignment* buffer_assignment);
+                  absl::Span<const BufferAllocation> allocations);
 
   std::string ToString() {
     std::string out;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index dae490e..8d70bb2 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -76,8 +76,7 @@
     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
     const BufferAllocation::Slice& output_data,
     const BufferAllocation::Slice& output_mean,
-    const BufferAllocation::Slice& output_inv_stddev,
-    const BufferAllocation::Slice& output_tuple)
+    const BufferAllocation::Slice& output_inv_stddev)
     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
       config_(std::move(config)),
       operand_(operand),
@@ -85,8 +84,7 @@
       offset_(offset),
       output_data_(output_data),
       output_mean_(output_mean),
-      output_inv_stddev_(output_inv_stddev),
-      output_tuple_(output_tuple) {}
+      output_inv_stddev_(output_inv_stddev) {}
 
 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -110,16 +108,6 @@
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
       &stream));
 
-  // Write the output tuple.
-  const int kNumOutputs = 3;
-  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
-  ptrs[0] = output_data.opaque();
-  ptrs[1] = output_mean.opaque();
-  ptrs[2] = output_inv_stddev.opaque();
-  se::DeviceMemory<void*> tuple_addr(
-      buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
-                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationTraining call failed.");
   }
@@ -134,8 +122,7 @@
     const BufferAllocation::Slice& grad_output,
     const BufferAllocation::Slice& output_grad_data,
     const BufferAllocation::Slice& output_grad_scale,
-    const BufferAllocation::Slice& output_grad_offset,
-    const BufferAllocation::Slice& output_tuple)
+    const BufferAllocation::Slice& output_grad_offset)
     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
       config_(std::move(config)),
       operand_(operand),
@@ -145,8 +132,7 @@
       grad_output_(grad_output),
       output_grad_data_(output_grad_data),
       output_grad_scale_(output_grad_scale),
-      output_grad_offset_(output_grad_offset),
-      output_tuple_(output_tuple) {}
+      output_grad_offset_(output_grad_offset) {}
 
 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -172,17 +158,6 @@
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
       stream));
 
-  // Write the output tuple.
-  const int kNumOutputs = 3;
-  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
-  ptrs[0] = output_grad_data.opaque();
-  ptrs[1] = output_grad_scale.opaque();
-  ptrs[2] = output_grad_offset.opaque();
-  se::DeviceMemory<void*> tuple_addr(
-      buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream,
-                params.deferred_host_callbacks);
-
   if (!stream->ok()) {
     return InternalError("BatchNormalizationBackward call failed.");
   }
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
index d45e284..48c46a6 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
@@ -82,8 +82,7 @@
       const BufferAllocation::Slice& offset,
       const BufferAllocation::Slice& output_data,
       const BufferAllocation::Slice& output_mean,
-      const BufferAllocation::Slice& output_inv_stddev,
-      const BufferAllocation::Slice& output_tuple);
+      const BufferAllocation::Slice& output_inv_stddev);
 
   CudnnBatchNormForwardTrainingThunk(
       const CudnnBatchNormForwardTrainingThunk&) = delete;
@@ -100,22 +99,19 @@
   BufferAllocation::Slice output_data_;
   BufferAllocation::Slice output_mean_;
   BufferAllocation::Slice output_inv_stddev_;
-  BufferAllocation::Slice output_tuple_;
 };
 
 class CudnnBatchNormBackwardThunk : public Thunk {
  public:
-  CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
-                              CudnnBatchNormConfig&& config,
-                              const BufferAllocation::Slice& operand,
-                              const BufferAllocation::Slice& scale,
-                              const BufferAllocation::Slice& mean,
-                              const BufferAllocation::Slice& inv_stddev,
-                              const BufferAllocation::Slice& grad_output,
-                              const BufferAllocation::Slice& output_grad_data,
-                              const BufferAllocation::Slice& output_grad_scale,
-                              const BufferAllocation::Slice& output_grad_offset,
-                              const BufferAllocation::Slice& output_tuple);
+  CudnnBatchNormBackwardThunk(
+      ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+      const BufferAllocation::Slice& operand,
+      const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
+      const BufferAllocation::Slice& inv_stddev,
+      const BufferAllocation::Slice& grad_output,
+      const BufferAllocation::Slice& output_grad_data,
+      const BufferAllocation::Slice& output_grad_scale,
+      const BufferAllocation::Slice& output_grad_offset);
 
   CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
   CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
@@ -133,7 +129,6 @@
   BufferAllocation::Slice output_grad_data_;
   BufferAllocation::Slice output_grad_scale_;
   BufferAllocation::Slice output_grad_offset_;
-  BufferAllocation::Slice output_tuple_;
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
index bd6aa6e..afaaa80 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
@@ -324,7 +324,7 @@
       input = f32[1,17,9,9] parameter(0)
       filter = f32[3,3,17,32] parameter(1)
 
-      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo"}
+      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
       ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
     })";
 
@@ -337,9 +337,9 @@
               backend().default_stream_executor(), backend().memory_allocator())
           .ConsumeValueOrDie()
           ->ToString();
-  EXPECT_THAT(
-      optimized_hlo_string,
-      ::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})"));
+  EXPECT_THAT(optimized_hlo_string,
+              ::testing::ContainsRegex(
+                  R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
 }
 
 TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 7d06451..7e1a9d2 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -22,16 +22,41 @@
 namespace xla {
 namespace gpu {
 
-// An HLO pass that attempts to merge fusion instructions to reduce kernel
-// launch overhead and improve data locality.
+// An HLO pass that attempts to merge fusion instructions to reduce memory
+// bandwidth requirements and kernel launch overhead.
 //
-// Fusion instructions are merged into their users if two conditions are met:
+// Consider the example below. On the left-hand side, op A is the producer and
+// ops B and C are its consumers. FusionMerger duplicates producer ops and fuses
+// them into all consumers. The result is depicted on the right-hand side below.
 //
-// 1) The flops_to_bytes ratio of the fusion instruction is below the threshold
-//    value of 1.0.
-// 2) The result of merging the fusion instruction into its users would not
-//    increase bytes transferred.
+//        p                    p
+//        |                  /   \
+//        v                 /     \
+//        A            +fusion+  +fusion+
+//      /   \          |  A'  |  |  A"  |
+//     |     |         |  |   |  |  |   |
+//     v     v         |  v   |  |  v   |
+//     B     C         |  B   |  |  C   |
+//                     +------+  +------+
 //
+// Op A has been cloned twice and fused with B and C. The kernel launch overhead
+// is reduced from 3 to 2. The memory bandwidth requirements may be reduced.
+// We trade 1 read of input(A) + 1 write and 2 reads of output(A) for 2 reads of
+// input(A). In general, the achievable bandwidth savings depend on how much
+// memory is read and written and on the number of consumers. The FusionMerger
+// pass takes this into account when making fusion decisions.
+//
+// The pass traverses the HLO module in reverse post-order (defs before uses).
+// Fusion instructions are merged into their users if all of the following
+// conditions are met:
+// * The result of merging the fusion instruction into its users would not
+//   increase bytes transferred.
+// * Producer ops are fusible with _all_ consumers. If a producer is not
+//   fusible with even one of its consumers, it is not fused into any of them.
+// * Producers are kLoop fusion ops.
+//
+// None of these restrictions are necessary for correctness. In fact, lifting
+// the latter two could be beneficial.
+
 class FusionMerger : public HloModulePass {
  public:
   absl::string_view name() const override { return "fusion_merger"; }
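
The rewritten comment above describes FusionMerger's criteria in prose. A
condensed sketch of that per-producer decision, using hypothetical helpers
(IsLoopFusion, FusibleInto, BytesTransferredIfMerged, CurrentBytesTransferred)
that stand in for the pass's internal analysis and are not its real API:

  // Returns true if `producer` should be duplicated into all of its consumers.
  bool ShouldMerge(const xla::HloInstruction* producer) {
    // Only kLoop fusion producers are considered.
    if (!IsLoopFusion(producer)) return false;
    // The producer must be fusible into every consumer, or it is not merged at all.
    for (const xla::HloInstruction* consumer : producer->users()) {
      if (!FusibleInto(producer, consumer)) return false;
    }
    // Merging must not increase the total bytes transferred.
    return BytesTransferredIfMerged(producer) <= CurrentBytesTransferred(producer);
  }
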
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 3eee882..8084e0e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -24,11 +24,15 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
+#include "llvm/Transforms/Utils/SplitModule.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/protobuf_util.h"
@@ -114,11 +118,13 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/regexp.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/util/env_var.h"
@@ -415,7 +421,8 @@
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
 
   if (RequireDeterminism() ||
-      hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) {
+      hlo_module->config().debug_options().xla_gpu_deterministic_reductions() ||
+      hlo_module->config().debug_options().xla_gpu_deterministic_ops()) {
     pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>();
   }
 
@@ -470,14 +477,14 @@
 
 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   // We dump the post-optimization HLO in RunBackend so no need to dump it here.
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
   tensorflow::profiler::TraceMe activity(
       [&] { return absl::StrCat("HLO Transforms:", module->name()); },
       tensorflow::profiler::TraceMeLevel::kInfo);
   TF_RETURN_IF_ERROR(
-      OptimizeHloModule(module.get(), stream_exec, device_allocator));
+      OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
 
   TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
 
@@ -494,10 +501,10 @@
     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 GpuCompiler::RunHloPassesAndBufferAssignement(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
-    se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+    bool optimize, const CompileOptions& options) {
   if (optimize) {
-    TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module),
-                                                 executor, device_allocator));
+    TF_ASSIGN_OR_RETURN(hlo_module,
+                        RunHloPasses(std::move(hlo_module), executor, options));
   }
 
   std::unique_ptr<StreamAssignment> stream_assignment =
@@ -641,24 +648,149 @@
   return Status::OK();
 }
 
+StatusOr<std::pair<std::string, std::vector<uint8>>>
+GpuCompiler::CompileToTargetBinary(const HloModule& module,
+                                   std::unique_ptr<llvm::Module> llvm_module,
+                                   se::StreamExecutor* stream_exec,
+                                   const CompileOptions& options) {
+  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
+
+  const auto compile_single_module =
+      [this, stream_exec, &module](
+          llvm::Module* llvm_module,
+          bool relocatable) -> StatusOr<BackendCompileResult> {
+    {
+      XLA_SCOPED_LOGGING_TIMER(
+          "GpuCompiler::RunBackend - Running LLVM verifier");
+
+      std::string err;
+      llvm::raw_string_ostream err_stream(err);
+
+      // verifyModule() returns true if the module is broken.
+      TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
+          << "Invalid LLVM IR before optimizations:\n"
+          << err_stream.str()
+          << "\nThis probably indicates a bug in the HLO -> LLVM IR "
+             "lowering. "
+             "Rerun with --xla_dump_to to get the IR and look for files "
+             "with "
+             "name containing: *"
+          << FilenameFor(module, "", "") << "*";
+    }
+    GpuVersion gpu_version = GetGpuVersion(stream_exec);
+    return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec,
+                               relocatable);
+  };
+
+  tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
+
+  absl::optional<tensorflow::thread::ThreadPool> overriding_thread_pool;
+  if (module.config().debug_options().xla_gpu_force_compilation_parallelism() !=
+      0) {
+    overriding_thread_pool.emplace(
+        tensorflow::Env::Default(), "",
+        module.config()
+            .debug_options()
+            .xla_gpu_force_compilation_parallelism());
+    thread_pool = &*overriding_thread_pool;
+  }
+
+  if (!thread_pool) {
+    return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+  }
+
+  // Test whether LinkModules is supported.
+  if (this->LinkModules(stream_exec, {}).status().code() ==
+      tensorflow::error::Code::UNIMPLEMENTED) {
+    return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+  }
+
+  std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
+  int num_functions = 0;
+  for (llvm::Function& func : llvm_module->functions()) {
+    if (!func.isDeclaration() &&
+        func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
+      num_functions++;
+    }
+  }
+
+  llvm::SplitModule(
+      std::move(llvm_module),
+      std::max<unsigned>(
+          1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
+      [&](std::unique_ptr<llvm::Module> module) {
+        llvm_modules.push_back(std::move(module));
+      },
+      /*PreserveLocals=*/true);
+
+  std::vector<StatusOr<BackendCompileResult>> compile_results(
+      llvm_modules.size());
+  tensorflow::BlockingCounter counter(llvm_modules.size());
+  for (int i = 0; i < llvm_modules.size(); i++) {
+    thread_pool->Schedule([&compile_results, compile_single_module, i,
+                           &llvm_modules, &counter] {
+      llvm::Module* original_module = llvm_modules[i].get();
+      llvm::LLVMContext context;
+      std::string buffer;
+      llvm::raw_string_ostream error(buffer);
+      llvm::DiagnosticPrinterRawOStream printer(error);
+      auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
+                                  void* Context) {
+        auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
+        diag_info.print(*printer);
+      };
+      context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
+
+      std::unique_ptr<llvm::Module> new_llvm_module;
+      // Switch to a new context by dumping and re-parsing LLVM IR. Each thread
+      // has its own context to avoid race conditions.
+      {
+        std::string ir;
+        {
+          llvm::raw_string_ostream os(ir);
+          original_module->print(os, nullptr);
+        }
+        llvm::SMDiagnostic err;
+        new_llvm_module = llvm::parseAssemblyString(ir, err, context);
+      }
+
+      compile_results[i] =
+          compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
+      counter.DecrementCount();
+    });
+  }
+  counter.Wait();
+
+  std::string ptx_snippets;
+  std::vector<std::vector<uint8>> submodule_compile_results;
+  for (auto& maybe_result : compile_results) {
+    TF_ASSIGN_OR_RETURN(auto result, maybe_result);
+    if (result.second.empty()) {
+      continue;
+    }
+    ptx_snippets += result.first;
+    ptx_snippets += "\n";
+    submodule_compile_results.push_back(result.second);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<uint8> backend_result,
+      this->LinkModules(stream_exec, std::move(submodule_compile_results)));
+
+  return std::make_pair(ptx_snippets, backend_result);
+}
+
 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
-  auto slow_compile_alarm = SlowCompilationAlarm();
+  std::string slow_compilation_msg =
+      absl::StrCat("Compiling module ", module->name());
+  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
 
   TF_RET_CHECK(stream_exec != nullptr);
 
   llvm::LLVMContext llvm_context;
-  std::string buffer;
-  llvm::raw_string_ostream error(buffer);
-  llvm::DiagnosticPrinterRawOStream printer(error);
-  auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
-                              void* Context) {
-    auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
-    diag_info.print(*printer);
-  };
-  llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
 
   GpuDeviceInfo gpu_device_info;
   gpu_device_info.threads_per_block_limit =
@@ -724,39 +856,31 @@
 
   llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
 
-  {
-    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
-
-    std::string err;
-    llvm::raw_string_ostream err_stream(err);
-
-    // verifyModule() returns true if the module is broken.
-    TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
-        << "Invalid LLVM IR before optimizations:\n"
-        << err_stream.str()
-        << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
-           "Rerun with --xla_dump_to to get the IR and looks for files with "
-           "name containing: *"
-        << FilenameFor(*module, "", "") << "*";
-  }
-
-  GpuVersion gpu_version = GetGpuVersion(stream_exec);
-
   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
   TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
-                      CompileTargetBinary(module.get(), llvm_module.get(),
-                                          gpu_version, stream_exec));
-
+                      CompileToTargetBinary(*module, std::move(llvm_module),
+                                            stream_exec, options));
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
                             thunk_schedule->ToString());
   }
 
+  using OutputInfoMap =
+      absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
+  TF_ASSIGN_OR_RETURN(OutputInfoMap output_info,
+                      GetOutputInfo(*module, *buffer_assignment));
+  auto buffer_assignment_proto =
+      std::make_unique<BufferAssignmentProto>(buffer_assignment->ToProto());
+  std::vector<BufferAllocation> allocations =
+      buffer_assignment->ReleaseAllocations();
+
+  GpuVersion gpu_version = GetGpuVersion(stream_exec);
   auto* gpu_executable = new GpuExecutable(
-      backend_result.first, backend_result.second, gpu_version,
-      std::move(thunk_schedule), std::move(module),
-      std::move(buffer_assignment), std::move(profile_printer),
-      std::move(profile_index_map), std::move(constants));
+      {std::move(backend_result.first), std::move(backend_result.second),
+       gpu_version, std::move(thunk_schedule), std::move(constants),
+       std::move(output_info), std::move(module), std::move(allocations),
+       std::move(buffer_assignment_proto), std::move(profile_printer),
+       std::move(profile_index_map)});
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
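
CompileToTargetBinary above splits the LLVM module and compiles the pieces on
a thread pool. Since an llvm::LLVMContext is not thread-safe, each worker
clones its submodule into a private context by round-tripping through textual
IR before compiling it. A standalone illustration of that round-trip, with the
surrounding TensorFlow plumbing omitted:

  #include <memory>
  #include <string>

  #include "llvm/AsmParser/Parser.h"
  #include "llvm/IR/LLVMContext.h"
  #include "llvm/IR/Module.h"
  #include "llvm/Support/SourceMgr.h"
  #include "llvm/Support/raw_ostream.h"

  // Re-parses `original` into `new_context` so another thread can own the copy.
  std::unique_ptr<llvm::Module> CloneIntoNewContext(
      const llvm::Module& original, llvm::LLVMContext& new_context) {
    std::string ir;
    {
      llvm::raw_string_ostream os(ir);
      original.print(os, /*AAW=*/nullptr);
    }  // The stream flushes into `ir` when it goes out of scope.
    llvm::SMDiagnostic err;
    // Returns nullptr if the IR fails to parse; `err` holds the diagnostic.
    return llvm::parseAssemblyString(ir, err, new_context);
  }
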
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 824d740..1d42976 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -53,14 +53,13 @@
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
 
   StatusOr<
       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module,
-                                   se::StreamExecutor* executor,
-                                   se::DeviceMemoryAllocator* device_allocator,
-                                   bool optimize) override;
+                                   se::StreamExecutor* executor, bool optimize,
+                                   const CompileOptions& options) override;
 
   Status OptimizeHloModule(HloModule* hlo_module,
                            se::StreamExecutor* stream_exec,
@@ -84,19 +83,23 @@
 
   virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
   CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+                      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+                      bool relocatable) = 0;
 
   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      AotCompilationOptions const& options) override;
 
+  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileToTargetBinary(
+      const HloModule& module, std::unique_ptr<llvm::Module> llvm_module,
+      se::StreamExecutor* stream_exec, const CompileOptions& options);
+
   se::Platform::Id PlatformId() const override { return platform_id_; }
 
   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
@@ -116,6 +119,12 @@
   }
 
  private:
+  virtual StatusOr<std::vector<uint8>> LinkModules(
+      se::StreamExecutor* stream_exec,
+      std::vector<std::vector<uint8>> modules) {
+    return Unimplemented("LinkModules is not implemented.");
+  }
+
   se::Platform::Id platform_id_;
 
   // The triple that represents our target.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index 925caad..0368d1c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -636,7 +636,8 @@
   }
 
   auto selected_result = filtered_results.begin();
-  if (!RequireCudnnDeterminism()) {
+  if (!RequireCudnnDeterminism() &&
+      !hlo_module_config.debug_options().xla_gpu_deterministic_ops()) {
     selected_result = absl::c_min_element(
         filtered_results,
         [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
index e0ccbad..ba41ad1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
 
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
@@ -255,15 +256,17 @@
 }  // anonymous namespace
 
 StatusOr<GpuConvConfig> GetGpuConvConfig(
-    const HloCustomCallInstruction* cudnn_call) {
+    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
   GpuConvConfig config;
 
-  config.input_type = cudnn_call->operand(0)->shape().element_type();
-  config.output_type = cudnn_call->shape().tuple_shapes(0).element_type();
+  const Shape& operand0_shape = desc.operand0_shape;
+  const Shape& operand1_shape = desc.operand1_shape;
+  const Shape& result_shape = desc.result_shape;
+  const CudnnConvBackendConfig& backend_config = desc.backend_config;
 
-  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
-                      cudnn_call->backend_config<CudnnConvBackendConfig>());
-  TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call));
+  config.input_type = operand0_shape.element_type();
+  config.output_type = result_shape.element_type();
+  config.kind = desc.kind;
 
   // The third field is scratch size stored from conv_algorithm_picker
   // The operand is added to the shape field of the conv instruction
@@ -271,13 +274,9 @@
   config.algorithm = se::dnn::AlgorithmConfig(
       se::dnn::AlgorithmDesc(backend_config.algorithm(),
                              backend_config.tensor_ops_enabled()),
-      cudnn_call->shape().tuple_shapes(1).dimensions(0));
+      desc.scratch_size);
   config.conv_result_scale = backend_config.conv_result_scale();
 
-  Shape operand0_shape = cudnn_call->operand(0)->shape();
-  Shape operand1_shape = cudnn_call->operand(1)->shape();
-  Shape result_shape = cudnn_call->shape().tuple_shapes(0);
-
   switch (config.kind) {
     case CudnnConvKind::kForward:
     case CudnnConvKind::kForwardActivation:
@@ -311,9 +310,8 @@
     fusion.side_input_scale = backend_config.side_input_scale();
   }
 
-  const Window& window = cudnn_call->window();
-  const ConvolutionDimensionNumbers& dnums =
-      cudnn_call->convolution_dimension_numbers();
+  const Window& window = desc.window;
+  const ConvolutionDimensionNumbers& dnums = desc.dnums;
 
   VLOG(3) << "Convolution Algorithm: "
           << config.algorithm.algorithm()->algo_id();
@@ -330,7 +328,7 @@
   VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
 
   const int num_dimensions = window.dimensions_size();
-  CHECK_LE(num_dimensions, 3) << cudnn_call->ToString();
+  CHECK_LE(num_dimensions, 3) << inst_as_string;
 
   // cuDNN does not support 1D convolutions. We therefore express 1D
   // convolutions as 2D convolutions where the first spatial dimension is 1.
@@ -344,18 +342,18 @@
       window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
 
   CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
-      << cudnn_call->ToString();
+      << inst_as_string;
   CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
-      << cudnn_call->ToString();
+      << inst_as_string;
   CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
-      << cudnn_call->ToString();
+      << inst_as_string;
   for (const WindowDimension& dim : window.dimensions()) {
-    CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString();
-    CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString();
+    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
+    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
     CHECK_EQ(dim.base_dilation(), 1)
         << "cudnn does not support base dilation; it "
            "must be made explicit with a kPad: "
-        << cudnn_call->ToString();
+        << inst_as_string;
   }
 
   // cuDNN's convolution APIs support the BDYX layout for activations/output and
@@ -364,43 +362,43 @@
   FilterLayout filter_dl;
   DataLayout output_dl;
 
-  const Shape* input_shape = &config.input_shape;
-  const Shape* filter_shape = &config.filter_shape;
-  const Shape* output_shape = &config.output_shape;
+  const Shape& input_shape = config.input_shape;
+  const Shape& filter_shape = config.filter_shape;
+  const Shape& output_shape = config.output_shape;
 
   TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                       XlaConvLayoutsToStreamExecutorLayouts(
-                          dnums, input_shape->layout(), filter_shape->layout(),
-                          output_shape->layout()));
+                          dnums, input_shape.layout(), filter_shape.layout(),
+                          output_shape.layout()));
 
   BatchDescriptor& input_descriptor = config.input_descriptor;
   input_descriptor = BatchDescriptor(effective_num_dimensions);
   input_descriptor.set_layout(input_dl)
       .set_feature_map_count(
-          input_shape->dimensions(dnums.input_feature_dimension()))
-      .set_count(input_shape->dimensions(dnums.input_batch_dimension()));
+          input_shape.dimensions(dnums.input_feature_dimension()))
+      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
   for (int dim = 0; dim < num_dimensions; ++dim) {
     // Note that the dimensions are reversed. The same holds below.
     input_descriptor.set_spatial_dim(
         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
-        input_shape->dimensions(dnums.input_spatial_dimensions(dim)));
+        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
   }
 
   FilterDescriptor& filter_descriptor = config.filter_descriptor;
   filter_descriptor = FilterDescriptor(effective_num_dimensions);
   filter_descriptor.set_layout(filter_dl)
       .set_input_feature_map_count(
-          filter_shape->dimensions(dnums.kernel_input_feature_dimension()))
+          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
       .set_output_feature_map_count(
-          filter_shape->dimensions(dnums.kernel_output_feature_dimension()));
+          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
   for (int dim = 0; dim < num_dimensions; ++dim) {
     filter_descriptor.set_spatial_dim(
         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
-        filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim)));
+        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
   }
 
   config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
-  config.conv_desc.set_group_count(cudnn_call->feature_group_count());
+  config.conv_desc.set_group_count(desc.feature_group_count);
   config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
   for (int dim = 0; dim < num_dimensions; ++dim) {
     config.conv_desc
@@ -419,12 +417,12 @@
   output_descriptor = BatchDescriptor(effective_num_dimensions);
   output_descriptor.set_layout(output_dl)
       .set_feature_map_count(
-          output_shape->dimensions(dnums.output_feature_dimension()))
-      .set_count(output_shape->dimensions(dnums.output_batch_dimension()));
+          output_shape.dimensions(dnums.output_feature_dimension()))
+      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
   for (int dim = 0; dim < num_dimensions; ++dim) {
     output_descriptor.set_spatial_dim(
         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
-        output_shape->dimensions(dnums.output_spatial_dimensions(dim)));
+        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
   }
 
   // Add a singleton dimension in the 1D convolution case.
@@ -439,6 +437,23 @@
   return config;
 }
 
+StatusOr<GpuConvConfig> GetGpuConvConfig(
+    const HloCustomCallInstruction* cudnn_call) {
+  GpuConvDescriptor descriptor;
+
+  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
+  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
+                      cudnn_call->backend_config<CudnnConvBackendConfig>());
+  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
+  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
+  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
+  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
+  descriptor.window = cudnn_call->window();
+  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
+  descriptor.feature_group_count = cudnn_call->feature_group_count();
+  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
+}
+
 StatusOr<GpuConvParams> GetGpuConvParams(
     const GpuConvConfig& config,
     absl::Span<se::DeviceMemoryBase> operand_buffers,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h
index 5d27e6d..af63dee 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h
@@ -119,9 +119,31 @@
                   se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                   RunConvOptions = {});
 
+// Struct to describe properties of a convolution without being tied to specific
+// IR. Will be used to help build Convolution thunks from either XLA HLO or
+// LHLO GPU dialect in MLIR.
+struct GpuConvDescriptor {
+  CudnnConvKind kind;
+  CudnnConvBackendConfig backend_config;
+  Shape operand0_shape;
+  Shape operand1_shape;
+  Shape result_shape;
+  size_t scratch_size;
+  Window window;
+  ConvolutionDimensionNumbers dnums;
+  int64 feature_group_count;
+};
+
+// Returns the convolution configuration given an XLA HLO instruction.
 StatusOr<GpuConvConfig> GetGpuConvConfig(
     const HloCustomCallInstruction* cudnn_call);
 
+// Returns the convolution configuration given a convolution descriptor `desc`
+// and a string representation of the convolution instruction `inst_as_string`
+// (for error reporting).
+StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
+                                         absl::string_view inst_as_string);
+
 // Implementation details exposed for debugging and log analysis.
 StatusOr<GpuConvParams> GetGpuConvParams(
     const GpuConvConfig& conv_config,
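
Because GpuConvDescriptor is not tied to HLO, the same GetGpuConvConfig
overload can serve an MLIR (LHLO GPU dialect) lowering path. A hedged sketch
of such a caller; `ConvOpInfo` and its fields are hypothetical placeholders
for whatever the MLIR side extracts, not an existing API:

  xla::StatusOr<xla::gpu::GpuConvConfig> GetConfigFromMlirConv(
      const ConvOpInfo& op /* hypothetical */) {
    xla::gpu::GpuConvDescriptor desc;
    desc.kind = op.kind;                      // e.g. CudnnConvKind::kForward
    desc.backend_config = op.backend_config;  // CudnnConvBackendConfig proto
    desc.operand0_shape = op.input_shape;
    desc.operand1_shape = op.filter_shape;
    desc.result_shape = op.output_shape;
    desc.scratch_size = op.scratch_bytes;
    desc.window = op.window;
    desc.dnums = op.dimension_numbers;
    desc.feature_group_count = op.feature_group_count;
    return xla::gpu::GetGpuConvConfig(desc, op.debug_string);
  }
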
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.cc
index 51888c0..9851ce0 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.cc
@@ -22,7 +22,7 @@
 
 void GpuDebugInfoManager::RegisterModule(
     const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
-    std::shared_ptr<const BufferAssignment> buffer_assignment) {
+    std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
   tensorflow::mutex_lock lock(mutex_);
   if (active_modules_.find(module_id) != active_modules_.end()) {
     active_modules_[module_id].instances.emplace_back(hlo_module,
@@ -40,7 +40,7 @@
 // However during tracing, we will defer the cleanup after serialization.
 void GpuDebugInfoManager::UnregisterModule(
     const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
-    std::shared_ptr<const BufferAssignment> buffer_assignment) {
+    std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
   tensorflow::mutex_lock lock(mutex_);
   CHECK(active_modules_.find(module_id) != active_modules_.end());
   GpuModuleEntry& active_module = active_modules_[module_id];
@@ -146,8 +146,10 @@
       // non-nullptr. Due to the inconvenience of creation of buffer_assignment
       // object in test, we set it to nullptr and guard this for it.
       if (m.instances[0].hlo_module && m.instances[0].buffer_assignment) {
-        info.hlo_proto = absl::make_unique<HloProto>(MakeHloProto(
-            *m.instances[0].hlo_module, *m.instances[0].buffer_assignment));
+        info.hlo_proto = absl::make_unique<HloProto>(
+            MakeHloProto(*m.instances[0].hlo_module));
+        *info.hlo_proto->mutable_buffer_assignment() =
+            *m.instances[0].buffer_assignment;
       }
       module_debug_info->emplace_back(std::move(info));
     }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h
index 0a8b444..36d4435 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h
@@ -17,7 +17,7 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_DEBUG_INFO_MANAGER_H_
 
 #include "absl/container/flat_hash_map.h"
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/core/lib/core/status.h"
 
@@ -56,14 +56,14 @@
   // Modules with same module id can be registered and tracked separately.
   void RegisterModule(
       const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
-      std::shared_ptr<const BufferAssignment> buffer_assignment);
+      std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
 
   // Unregister an active module. When the last active module of the same
   // module id is out of scope, we remove it from our database.
   // However during tracing, we will defer the cleanup after serialization.
   void UnregisterModule(
       const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
-      std::shared_ptr<const BufferAssignment> buffer_assignment);
+      std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
 
   // Register when the module start execution on certain device.
   // TODO(jiesun): Do we need to track which device this is?
@@ -110,10 +110,10 @@
   // tracking, they need to be tracked separately.
   struct GpuModuleInstance {
     GpuModuleInstance(std::shared_ptr<HloModule> m,
-                      std::shared_ptr<const BufferAssignment> b)
+                      std::shared_ptr<const BufferAssignmentProto> b)
         : hlo_module(std::move(m)), buffer_assignment(std::move(b)) {}
     std::shared_ptr<HloModule> hlo_module;
-    std::shared_ptr<const BufferAssignment> buffer_assignment;
+    std::shared_ptr<const BufferAssignmentProto> buffer_assignment;
     bool active = true;
   };
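A minimal sketch of the registration flow with the new proto-based signature. The wrapper name RegisterForDebugInfo is hypothetical, and producing the proto via BufferAssignment::ToProto() is an assumption; the compiler-side wiring is not shown here.

  // Sketch only, inside namespace xla::gpu. The manager now keeps the
  // serialized BufferAssignmentProto alive for tracing instead of the full
  // BufferAssignment.
  void RegisterForDebugInfo(std::shared_ptr<HloModule> hlo_module,
                            const BufferAssignment& assignment) {
    auto proto = std::make_shared<BufferAssignmentProto>(assignment.ToProto());
    GpuDebugInfoManager::Get()->RegisterModule(hlo_module->name(), hlo_module,
                                               proto);
    // ... and symmetrically, when the owning executable is destroyed:
    GpuDebugInfoManager::Get()->UnregisterModule(hlo_module->name(), hlo_module,
                                                 proto);
  }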
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager_test.cc
index 5ea26c5..e0d42a3 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager_test.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
 
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 
 namespace xla {
@@ -30,7 +30,7 @@
     int unique_id;
     string id;
     std::shared_ptr<HloModule> module;
-    std::shared_ptr<BufferAssignment> buffer_assignment;
+    std::shared_ptr<BufferAssignmentProto> buffer_assignment;
   };
 
   // Return unique id of this module.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 1a0d1e0..eef078c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -54,31 +54,27 @@
 
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
-GpuExecutable::GpuExecutable(
-    const string& text, const std::vector<uint8>& binary,
-    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
-    std::shared_ptr<HloModule> hlo_module,
-    std::shared_ptr<const BufferAssignment> assignment,
-    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
-    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
-    std::vector<ConstantInfo> globals)
-    : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
-                 std::move(hlo_profile_index_map)),
-      text_(text),
-      binary_(binary),
-      gpu_version_(gpu_version),
-      thunk_schedule_(std::move(thunk_schedule)),
-      assignment_(std::move(assignment)),
-      constants_(std::move(globals)) {
-  CHECK(has_module() && assignment_);
+GpuExecutable::GpuExecutable(GpuExecutable::Params params)
+    : Executable(std::move(params.hlo_module),
+                 std::move(params.hlo_profile_printer_data),
+                 std::move(params.hlo_profile_index_map)),
+      text_(std::move(params.asm_text)),
+      binary_(std::move(params.binary)),
+      gpu_version_(params.gpu_version),
+      thunk_schedule_(std::move(params.thunk_schedule)),
+      allocations_(std::move(params.allocations)),
+      debug_buffer_assignment_(std::move(params.debug_buffer_assignment)),
+      constants_(std::move(params.constants)),
+      output_info_(std::move(params.output_info)) {
+  CHECK(has_module());
   GpuDebugInfoManager::Get()->RegisterModule(module().name(), shared_module(),
-                                             assignment_);
+                                             debug_buffer_assignment_);
 }
 
 GpuExecutable::~GpuExecutable() {
-  CHECK(has_module() && assignment_);
+  CHECK(has_module());
   GpuDebugInfoManager::Get()->UnregisterModule(module().name(), shared_module(),
-                                               assignment_);
+                                               debug_buffer_assignment_);
 
   {
     // We could have issued host->device mem copies in ResolveConstantGlobals.
@@ -381,11 +377,11 @@
       [&] { return std::string("Build buffer allocations"); },
       tensorflow::profiler::TraceMeLevel::kInfo);
 
-  const int64 num_buffers = assignment_->Allocations().size();
+  const int64 num_buffers = allocations_.size();
   std::vector<se::DeviceMemoryBase> buffers;
   buffers.reserve(num_buffers);
   for (int64 i = 0; i < num_buffers; ++i) {
-    const BufferAllocation& allocation = assignment_->GetAllocation(i);
+    const BufferAllocation& allocation = allocations_[i];
     TF_ASSIGN_OR_RETURN(
         se::DeviceMemoryBase buffer,
         BufferForAllocation(arguments, globals, allocation, memory_allocator,
@@ -396,31 +392,6 @@
   return {{buffers, executor->device_ordinal(), memory_allocator}};
 }
 
-// Returns `true` if the entire tuple contents is aliased.
-static bool EntireTupleContentsAliased(
-    const Shape& output_shape, const ShapeIndex& index,
-    const HloInputOutputAliasConfig& alias_config) {
-  const Shape& indexed_shape = ShapeUtil::GetSubshape(output_shape, index);
-  if (!indexed_shape.IsTuple()) {
-    return false;
-  }
-  bool all_aliased = true;
-  ShapeUtil::ForEachSubshape(
-      indexed_shape, [&](const Shape& subshape, const ShapeIndex& subindex) {
-        if (subindex.empty()) {
-          return;
-        }
-        std::vector<int64> full_index;
-        absl::c_copy(index, std::back_inserter(full_index));
-        absl::c_copy(subindex, std::back_inserter(full_index));
-        if (!alias_config.OutputHasAlias(
-                ShapeIndex(full_index.begin(), full_index.end()))) {
-          all_aliased = false;
-        }
-      });
-  return all_aliased;
-}
-
 StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     std::vector<ExecutionInput> arguments,
@@ -432,10 +403,6 @@
   const bool block_host_until_done =
       !memory_allocator->AllowsAsynchronousDeallocation();
 
-  if (GetRootValueSet().IsAmbiguous()) {
-    return Unimplemented("Points-to set of root instruction is ambiguous");
-  }
-
   const GpuExecutable::BufferAllocToDeviceMemoryMap* globals;
   {
     tensorflow::profiler::TraceMe hlo_module_activity(
@@ -458,33 +425,37 @@
                                                 memory_allocator, executor));
   VLOG(2) << buffer_allocations.ToString();
   std::set<se::DeviceMemoryBase> buffers_in_result;
+
+  const bool is_entire_tuple_contents_aliased = [&] {
+    for (auto& p : result.MutableResult()->buffers().leaves()) {
+      const OutputInfo& output_info = output_info_.at(p.first);
+      if (!output_info.alias_config.has_value()) {
+        return false;
+      }
+    }
+    return true;
+  }();
+
   for (auto& p : result.MutableResult()->buffers()) {
     const ShapeIndex& index = p.first;
+    const OutputInfo& output_info = output_info_.at(index);
+    const BufferAllocation* allocation =
+        &allocations_[output_info.allocation_index];
     se::DeviceMemoryBase& result_buffer = p.second;
-    const auto& sources = GetRootValueSet().element(index);
-    // The points-to set is unambiguous so the set should be a
-    // singleton. That is, we know exactly which instruction
-    // produced the array at this element.
-    CHECK_EQ(1, sources.values().size());
-    HloInstruction* src_hlo = sources.values()[0]->instruction();
 
-    VLOG(4) << "Looking at: " << src_hlo->ToString()
-            << "@ index: " << index.ToString();
+    VLOG(4) << "Looking at: allocation " << output_info.allocation_index
+            << " @ index: " << index.ToString();
 
-    const HloInputOutputAliasConfig& input_output_alias =
-        module().input_output_alias_config();
-    absl::optional<HloInputOutputAliasConfig::Alias> alias =
-        input_output_alias.GetAliasedParameter(index);
-    if (alias) {
-      CHECK_LT(alias->parameter_number, arguments.size());
-      ExecutionInput& input = arguments[alias->parameter_number];
+    if (output_info.alias_config) {
+      ExecutionInput& input = arguments[allocation->parameter_number()];
       MaybeOwningDeviceMemory* maybe_owning_memory =
-          input.MutableBuffer(alias->parameter_index);
-      if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
+          input.MutableBuffer(allocation->param_shape_index());
+      if (output_info.alias_config->must_alias() &&
+          !maybe_owning_memory->HasOwnership()) {
         return InvalidArgument(
             "An input was configured to be must-alias at "
-            "compile time but not donated at runtime: %s",
-            alias->ToString());
+            "compile time but not donated at runtime: allocation %d",
+            output_info.allocation_index);
       }
       if (absl::optional<se::OwningDeviceMemory> owning =
               maybe_owning_memory->Release()) {
@@ -504,7 +475,7 @@
         // the indices to drop the addresses from its own ScopedShapedBuffer
         // result, if the ExecutionOutput is not committed.
         result.AddAliasedIndex(index);
-      } else if (src_hlo->opcode() != HloOpcode::kParameter) {
+      } else if (!output_info.passthrough) {
         // The guard above is there to not insert copy-protection when aliasing
         // pass-through params, as we do not need to write into the output
         // buffer.
@@ -516,12 +487,9 @@
             se::OwningDeviceMemory allocated_buffer,
             memory_allocator->Allocate(device_ordinal, allocation_size));
         result_buffer = allocated_buffer.Release();
-        TF_ASSIGN_OR_RETURN(
-            const BufferAllocation::Slice slice,
-            assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
-        CHECK_EQ(slice.offset(), 0) << "Parameter should get its own slice";
         se::DeviceMemoryBase& aliased_buffer =
-            buffer_allocations.GetMutableDeviceAddress(slice.index());
+            buffer_allocations.GetMutableDeviceAddress(
+                output_info.allocation_index);
         CHECK_EQ(aliased_buffer.size(), result_buffer.size());
         run_options->stream()->ThenMemcpyD2D(&result_buffer, aliased_buffer,
                                              aliased_buffer.size());
@@ -532,15 +500,12 @@
     if (result_buffer.is_null()) {
       // The source instruction should have a non-parameter buffer
       // assigned.
-      TF_ASSIGN_OR_RETURN(
-          const BufferAllocation::Slice slice,
-          assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
-      result_buffer = buffer_allocations.GetDeviceAddress(slice.index());
+      result_buffer =
+          buffer_allocations.GetDeviceAddress(output_info.allocation_index);
 
       // If the entire tuple contents is aliased, the copy insertion will *not*
       // materialize a new tuple, so we mark it as aliased as well.
-      if (EntireTupleContentsAliased(root->shape(), index,
-                                     input_output_alias)) {
+      if (is_entire_tuple_contents_aliased) {
         result.AddAliasedIndex(index);
       }
     }
@@ -556,18 +521,13 @@
 
   // Free all temporary allocations.
   TF_RETURN_IF_ERROR(
-      buffer_allocations.TearDown(buffers_in_result, assignment_.get()));
+      buffer_allocations.TearDown(buffers_in_result, allocations_));
 
   // Free allocations for arguments.
   MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
   return std::move(result);
 }
 
-const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
-  return assignment_->dataflow_analysis().GetInstructionValueSet(
-      module().entry_computation()->root_instruction());
-}
-
 int64 GpuExecutable::SizeOfGeneratedCodeInBytes() const {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
@@ -575,9 +535,8 @@
     return -1;
   }
   int64 size = binary().size();
-  for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
-       ++i) {
-    const BufferAllocation& allocation = assignment_->GetAllocation(i);
+  for (BufferAllocation::Index i = 0; i < allocations_.size(); ++i) {
+    const BufferAllocation& allocation = allocations_[i];
     if (allocation.is_constant()) {
       size += allocation.size();
     }
@@ -585,5 +544,46 @@
   return size;
 }
 
+StatusOr<absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>>
+GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) {
+  const HloInstruction* root =
+      hlo_module.entry_computation()->root_instruction();
+
+  InstructionValueSet root_value_set =
+      assignment.dataflow_analysis().GetInstructionValueSet(root);
+
+  if (root_value_set.IsAmbiguous()) {
+    return Unimplemented("Points-to set of root instruction is ambiguous");
+  }
+
+  using OutputInfoMap =
+      absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
+  OutputInfoMap output;
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+      root->shape(),
+      [&](const Shape& /*sub_shape*/, const ShapeIndex& index) -> Status {
+        const auto& sources = root_value_set.element(index);
+        // The points-to set is unambiguous so the set should be a
+        // singleton. That is, we know exactly which instruction
+        // produced the array at this element.
+        CHECK_EQ(1, sources.values().size());
+        HloInstruction* src_hlo = sources.values()[0]->instruction();
+
+        GpuExecutable::OutputInfo& info = output[index];
+        info.passthrough = src_hlo->opcode() == HloOpcode::kParameter;
+        TF_ASSIGN_OR_RETURN(
+            const BufferAllocation::Slice slice,
+            assignment.GetUniqueSlice(src_hlo, sources.values()[0]->index()));
+        CHECK_EQ(slice.offset(), 0) << "Parameter should get its own slice";
+        info.allocation_index = slice.index();
+
+        output[index].alias_config =
+            hlo_module.input_output_alias_config().GetAliasedParameter(index);
+
+        return Status::OK();
+      }));
+  return output;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 613880f..23eb54f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -55,17 +55,36 @@
     int allocation_index = -1;
   };
 
+  struct OutputInfo {
+    // Output is passed-through from a parameter.
+    bool passthrough;
+
+    // Corresponding allocation index.
+    int allocation_index;
+
+    // Whether this output is hinted to alias a parameter (BufferAllocation*
+    // would indicate the aliased parameter), and what kind of alias it is.
+    absl::optional<HloInputOutputAliasConfig::Alias> alias_config;
+  };
+
+  struct Params {
+    std::string asm_text;
+    std::vector<uint8> binary;
+    GpuVersion gpu_version;
+    std::unique_ptr<const ThunkSchedule> thunk_schedule;
+    std::vector<ConstantInfo> constants;
+    absl::flat_hash_map<ShapeIndex, OutputInfo> output_info;
+    std::unique_ptr<HloModule> hlo_module;
+    std::vector<BufferAllocation> allocations;
+    std::unique_ptr<BufferAssignmentProto> debug_buffer_assignment;
+    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data = nullptr;
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map = nullptr;
+  };
+
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& text, const std::vector<uint8>& binary,
-                GpuVersion gpu_version,
-                std::unique_ptr<const ThunkSchedule> thunk_schedule,
-                std::shared_ptr<HloModule> hlo_module,
-                std::shared_ptr<const BufferAssignment> assignment,
-                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
-                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
-                std::vector<ConstantInfo> constants);
+  explicit GpuExecutable(Params params);
   ~GpuExecutable() override;
 
   int64 SizeOfGeneratedCodeInBytes() const override;
@@ -94,8 +113,8 @@
       std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
-  std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
-    return assignment_;
+  absl::Span<const BufferAllocation> GetAllocations() const {
+    return allocations_;
   }
 
  private:
@@ -109,10 +128,6 @@
                        bool block_host_until_done,
                        HloExecutionProfile* hlo_execution_profile);
 
-  // Returns the value set of the root instruction of the entry
-  // computation. Uses dataflow analysis from buffer assignment.
-  const InstructionValueSet& GetRootValueSet() const;
-
   using BufferAllocToDeviceMemoryMap =
       absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>;
 
@@ -166,7 +181,9 @@
 
   // Owns the buffer data at runtime. It provides information to allocate
   // memory for every output/temp buffers.
-  const std::shared_ptr<const BufferAssignment> assignment_;
+  const std::vector<BufferAllocation> allocations_;
+
+  std::shared_ptr<BufferAssignmentProto> debug_buffer_assignment_;
 
   // Cache of module handles and constant buffer allocation maps used by
   // `ResolveConstantGlobals`.
@@ -177,10 +194,14 @@
       module_globals_ TF_GUARDED_BY(module_handle_mutex_);
 
   std::vector<ConstantInfo> constants_;
+  const absl::flat_hash_map<ShapeIndex, OutputInfo> output_info_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
 };
 
+StatusOr<absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>>
+GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment);
+
 }  // namespace gpu
 }  // namespace xla
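A minimal sketch of how a caller might assemble the new Params. The helper name MakeGpuExecutable, the copy of assignment.Allocations(), and the assignment.ToProto() call are assumptions; the Params fields, constructor, and GetOutputInfo are the ones declared above.

  // Sketch only, inside namespace xla::gpu.
  StatusOr<std::unique_ptr<GpuExecutable>> MakeGpuExecutable(
      std::string asm_text, std::vector<uint8> binary, GpuVersion gpu_version,
      std::unique_ptr<const ThunkSchedule> thunk_schedule,
      std::unique_ptr<HloModule> hlo_module, const BufferAssignment& assignment,
      std::vector<GpuExecutable::ConstantInfo> constants) {
    GpuExecutable::Params params;
    params.asm_text = std::move(asm_text);
    params.binary = std::move(binary);
    params.gpu_version = gpu_version;
    params.thunk_schedule = std::move(thunk_schedule);
    params.constants = std::move(constants);

    // Precompute the per-output metadata that replaces the dataflow queries the
    // executable used to issue against the full BufferAssignment at run time.
    TF_ASSIGN_OR_RETURN(params.output_info,
                        GetOutputInfo(*hlo_module, assignment));

    // Keep only what the runtime needs: the allocations themselves, plus a
    // proto snapshot of the assignment for the debug-info manager.
    params.allocations = assignment.Allocations();  // copies the vector
    params.debug_buffer_assignment =
        absl::make_unique<BufferAssignmentProto>(assignment.ToProto());

    params.hlo_module = std::move(hlo_module);
    return absl::make_unique<GpuExecutable>(std::move(params));
  }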
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc
index 4f4409a..4c15a8f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc
@@ -19,8 +19,8 @@
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 4a55947..ce22ee7 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -21,7 +21,7 @@
 
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -647,16 +647,39 @@
   return true;
 }
 
+// Given an LMHLO op, returns the operand index of the first output operand.
+//
+// Notice that an operand aliased to an output isn't an output, even though
+// WritesMlirBuffer() returns true on that operand in that case.
+//
+// An operand either fails WritesMlirBuffer() or equals (aliases) a later
+// operand. An output is the opposite: it satisfies WritesMlirBuffer() and does
+// not equal any later operand.
+int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
+  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));
+
+  int i;
+  for (i = op->getOperands().size() - 1; i >= 0; i--) {
+    const bool aliased =
+        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
+                  op->getOperand(i)) != op->getOperands().end();
+    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
+      break;
+    }
+  }
+  return i + 1;
+}
+
 std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
     return ToStdVector(fusion.getInputBuffers());
   }
   if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
+    int output_start = PartitionLmhloOperandsAndOutputs(op);
     std::vector<mlir::Value> operands;
-    for (auto buffer : op->getOperands()) {
-      if (!WritesMlirBuffer(op, buffer)) {
-        operands.push_back(buffer);
-      }
+    operands.reserve(output_start);
+    for (int i = 0; i < output_start; i++) {
+      operands.push_back(op->getOperand(i));
     }
     return operands;
   }
@@ -672,11 +695,10 @@
     return ToStdVector(fusion.getOutputBuffers());
   }
   if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
+    int output_start = PartitionLmhloOperandsAndOutputs(op);
     std::vector<mlir::Value> outputs;
-    for (auto buffer : op->getOperands()) {
-      if (WritesMlirBuffer(op, buffer)) {
-        outputs.push_back(buffer);
-      }
+    for (int i = output_start; i < op->getNumOperands(); i++) {
+      outputs.push_back(op->getOperand(i));
     }
     return outputs;
   }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 614f15d..bc1d11d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -18,6 +18,7 @@
 
 #include <utility>
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Value.h"
 #include "mlir/IR/Operation.h"  // from @llvm-project
@@ -246,11 +247,17 @@
   return s;
 }
 
+int PartitionLmhloOperandsAndOutputs(mlir::Operation* op);
 std::vector<mlir::Value> GetHloOperands(mlir::Operation* op);
 std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op);
 
 bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand);
 
+template <typename T>
+std::vector<T> ToStdVector(const llvm::SmallVectorImpl<T>& v) {
+  return std::vector<T>(v.begin(), v.end());
+}
+
 }  // namespace gpu
 }  // namespace xla
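A small sketch of how the partition point relates the accessors above for an op in the lmhlo dialect, whose buffer operands are laid out inputs first, outputs last. DescribeOperands is a hypothetical helper, not part of the sources.

  // Sketch only: for "lmhlo.add"(%in0, %in1, %out) the partition index is 2.
  // The op must belong to the lmhlo dialect.
  void DescribeOperands(mlir::Operation* op) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> inputs = GetHloOperands(op);   // [0, output_start)
    std::vector<mlir::Value> outputs = GetHloOutputs(op);   // [output_start, N)
    VLOG(3) << "first output operand index: " << output_start << ", "
            << inputs.size() << " inputs, " << outputs.size() << " outputs";
  }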
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc
new file mode 100644
index 0000000..9eec224
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/Parser.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+TEST(IrEmissionUtilsTest, TestOperandPartitionNoAlias) {
+  mlir::MLIRContext context;
+  mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+
+  auto module = mlir::parseSourceString(R"(
+    func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
+      "lmhlo.add" (%arg0, %arg1, %arg2) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+      "lmhlo.terminator" () : () -> ()
+    }
+  )",
+                                        &context);
+  mlir::FuncOp func = mlir::cast<mlir::FuncOp>(module->lookupSymbol("foo"));
+  mlir::Operation* op = &func.body().front().front();
+  EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op));
+}
+
+TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias0) {
+  mlir::MLIRContext context;
+  mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+
+  auto module = mlir::parseSourceString(R"(
+    func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
+      "lmhlo.add" (%arg0, %arg1, %arg0) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+      "lmhlo.terminator" () : () -> ()
+    }
+  )",
+                                        &context);
+  mlir::FuncOp func = mlir::cast<mlir::FuncOp>(module->lookupSymbol("foo"));
+  mlir::Operation* op = &func.body().front().front();
+  EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op));
+}
+
+TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias1) {
+  mlir::MLIRContext context;
+  mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+
+  auto module = mlir::parseSourceString(R"(
+    func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
+      "lmhlo.add" (%arg0, %arg1, %arg1) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+      "lmhlo.terminator" () : () -> ()
+    }
+  )",
+                                        &context);
+  mlir::FuncOp func = mlir::cast<mlir::FuncOp>(module->lookupSymbol("foo"));
+  mlir::Operation* op = &func.body().front().front();
+  EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op));
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index d92884f..ba72087 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -31,6 +31,7 @@
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
@@ -40,12 +41,16 @@
 #include "llvm/IR/Module.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Verifier.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -58,6 +63,7 @@
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
@@ -71,6 +77,7 @@
 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
@@ -105,6 +112,7 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 
 #if GOOGLE_CUDA
@@ -255,6 +263,8 @@
                               case HloOpcode::kPower:
                               case HloOpcode::kAtan2:
                                 return true;
+                              case HloOpcode::kReduce:
+                                return !instr->shape().IsArray();
                               default:
                                 return false;
                             }
@@ -302,6 +312,11 @@
         case HloOpcode::kPower:
         case HloOpcode::kAtan2:
           return true;
+        case HloOpcode::kReduce:
+          if (instr.getNumResults() > 1) {
+            return true;
+          }
+          break;
         default:
           break;
       }
@@ -329,12 +344,22 @@
   return true;
 }
 
+std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
+  llvm::SetVector<mlir::Operation*> ops;
+  for (mlir::Value output_value : fusion.getFusionResults()) {
+    ops.insert(output_value.getDefiningOp());
+  }
+  return std::vector<mlir::Operation*>(ops.begin(), ops.end());
+}
+
 // Computes the maximum valid unroll factor for a given instruction.
 int ComputeMaxUnrollFactor(const Shape& shape,
                            const HloModuleConfig& hlo_module_config) {
   int max_unroll_factor =
       hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
 
+  // Find the largest possible power of two to unroll by.
+  // TODO(kramerb): Make this smarter.
   int64 num_elements = ShapeUtil::ElementsIn(shape);
   for (int i = max_unroll_factor; i > 1; i /= 2) {
     if (num_elements % i == 0) {
@@ -348,14 +373,39 @@
 
 // Computes the maximum valid unroll factor for a given instruction.
 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
-  // Find the largest possible power of two to unroll by.
-  // TODO(kramerb): Make this smarter.
   const Shape& element_shape = hlo->IsMultiOutputFusion()
                                    ? ShapeUtil::GetSubshape(hlo->shape(), {0})
                                    : hlo->shape();
   return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
 }
 
+// Computes the maximum valid unroll factor for a given MLIR operation.
+int ComputeMaxUnrollFactor(mlir::Operation* op,
+                           const HloModuleConfig& hlo_module_config) {
+  Shape element_shape = [&] {
+    std::vector<Shape> shapes;
+    // Detect multi-output fusion. Notice that for a reduce in the fusion that
+    // returns a tuple, we don't want to treat it as multi-output fusion. We
+    // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
+    // MOF, just pass the first element of the root tuple.
+    if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
+      std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
+      for (mlir::Value result : fusion_outputs[0]->getResults()) {
+        shapes.push_back(TypeToShape(result.getType()));
+      }
+    } else {
+      for (mlir::Value result : op->getResults()) {
+        shapes.push_back(TypeToShape(result.getType()));
+      }
+    }
+    if (shapes.size() > 1) {
+      return ShapeUtil::MakeTupleShape(shapes);
+    }
+    return shapes[0];
+  }();
+  return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
+}
+
 // Returns the llvm type for the indices used in the kernel that contains the
 // hlo instruction. Such indices include the index for the parallel loop and
 // the indices for the tensors accessed by the kernel. The return type is i32
@@ -612,10 +662,14 @@
 }
 
 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
+  if (hlo->IsElementwise()) {
+    TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+    return EmitUsingElementalIrEmitter(input);
+  }
   return IrEmitter::DefaultAction(hlo);
 }
 
-Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
+Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
   // Replace unnested op with a fused nested op.
   //
   // TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
@@ -669,19 +723,48 @@
       output_shape = ShapeUtil::MakeTupleShape(output_shapes);
     }
   } else {
-    LOG(FATAL) << "Unimplemented default action for mlir op: "
-               << MlirToString(input.op);
+    // Try to generically convert any LMHLO ops to LMHLO fusion + the
+    // corresponding MHLO op. Currently we've only looked at elementwise ops and
+    // they seem to be well covered.
+    //
+    // TODO(timshen): Moving forward, we should make it cover all ops if
+    // possible, and only special-case the ones it can't.
+    std::vector<mlir::Value> outputs;
+    mlir::Operation* new_op;
+    {
+      auto operands = GetHloOperands(input.op);
+      outputs = GetHloOutputs(input.op);
+      TF_RET_CHECK(outputs.size() == 1) << MlirToString(input.op);
+
+      std::vector<mlir::Value> loads = load_memrefs(operands);
+      std::string mhlo_op_name = mlir::hlo::LmhloToMhloOpName(
+          input.op->getName().getStringRef(), input.op->getContext());
+      TF_RET_CHECK(!mhlo_op_name.empty())
+          << "No corresponding MHLO op for given LMHLO op: "
+          << MlirToString(input.op);
+      mlir::OperationState op_state(loc, mhlo_op_name);
+
+      mlir::BlockAndValueMapping mapper;
+      for (mlir::Region& region : input.op->getRegions()) {
+        mlir::Region* new_region = op_state.addRegion();
+        region.cloneInto(new_region, mapper);
+      }
+
+      op_state.addOperands(loads);
+      op_state.addAttributes(input.op->getAttrs());
+      op_state.addTypes({mlir::RankedTensorType::get(
+          outputs[0].getType().cast<mlir::MemRefType>().getShape(),
+          outputs[0].getType().cast<mlir::MemRefType>().getElementType())});
+      new_op = b.createOperation(op_state);
+    }
+    TF_RET_CHECK(mlir::succeeded(mlir::verify(new_op)));
+    output_shape = TypeToShape(outputs[0].getType());
+    HloFunctionImporter::SetLayoutForMlir(new_op, output_shape);
+    b.create<mlir::TensorStoreOp>(loc, new_op->getResult(0), outputs[0]);
   }
   input.op->erase();
   input.op = fusion;
-  int unroll_factor = 1;
-  // TODO(timshen): Port MayPreventVectorization as we add more ops into this
-  // function.
-  if (output_shape.IsArray()) {
-    unroll_factor = ComputeMaxUnrollFactor(output_shape, hlo_module_config_);
-  }
-  auto ret = EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
-  return ret;
+  return EmitLoopFusionFromMlir(input, output_shape);
 }
 
 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
@@ -948,9 +1031,12 @@
 }
 
 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
+  using mlir::dyn_cast;
+  using mlir::isa;
+
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call));
 
-  if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
+  if (auto call = dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
     if (call.call_target_name() == "PadToStatic") {
       return EmitPadToStaticFromMlir(input);
     }
@@ -960,11 +1046,24 @@
     return ThunkEmitter(this).HandleCustomCall(custom_call);
   }
 
-  if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(
-          input.op)) {
+  if (isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
     return EmitGemmThunkFromMlir(input);
   }
 
+  if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
+                mlir::lmhlo_gpu::ConvForwardFusedOp,
+                mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
+                mlir::lmhlo_gpu::ConvBackwardFilterOp,
+                mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
+    return EmitConvolutionThunkFromMlir(input);
+  }
+
+  if (isa<mlir::lmhlo_gpu::BatchNormTrainingOp,
+          mlir::lmhlo_gpu::BatchNormInferenceOp,
+          mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) {
+    return ThunkEmitter(this).HandleCustomCall(custom_call);
+  }
+
 #if GOOGLE_CUDA
   if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) {
     return EmitCholeskyThunkFromMlir(input);
@@ -975,6 +1074,118 @@
                        custom_call->custom_call_target());
 }
 
+Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
+  using mlir::dyn_cast;
+  using mlir::lmhlo_gpu::Activation;
+  using mlir::lmhlo_gpu::ConvBackwardFilterOp;
+  using mlir::lmhlo_gpu::ConvBackwardInputOp;
+  using mlir::lmhlo_gpu::ConvForwardFusedOp;
+  using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
+  using mlir::lmhlo_gpu::ConvForwardOp;
+
+  // The last two operands of the convolution op are the result and scratch buffers.
+  std::vector<BufferAllocation::Slice> operand_slices;
+  int64 num_operands = input.op->getNumOperands();
+  operand_slices.reserve(num_operands - 2);
+  for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
+    TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
+    operand_slices.push_back(slice);
+  }
+
+  mlir::Value conv_result = input.op->getOperand(num_operands - 2);
+  mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
+  TF_ASSIGN_OR_RETURN(auto conv_result_slice,
+                      GetAllocationSliceForMlir(conv_result));
+  TF_ASSIGN_OR_RETURN(auto scratch_slice,
+                      GetAllocationSliceForMlir(scratch_result));
+
+  auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
+    mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
+        llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
+          return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
+        }));
+    return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
+                                          shape.dimensions(), minor_to_major);
+  };
+
+  GpuConvDescriptor descriptor;
+
+  auto fill_conv_descriptor = [&](auto op) {
+    descriptor.operand0_shape =
+        apply_layout(TypeToShape(input.op->getOperand(0).getType()),
+                     op.backend_config().operand_0_layout());
+    descriptor.operand1_shape =
+        apply_layout(TypeToShape(input.op->getOperand(1).getType()),
+                     op.backend_config().operand_1_layout());
+    descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
+                                           op.backend_config().result_layout());
+    descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
+    mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
+    mlir::DenseIntElementsAttr padding = op.padding().getValue();
+    mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
+    mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
+    mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
+    for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
+      WindowDimension* dim = descriptor.window.add_dimensions();
+      // The window size for a convolution is the same as the kernel size.
+      // The kernel size is given by operand1_shape; we look at the kernel
+      // spatial dimensions in the convolution dimension numbers to get the
+      // window size.
+      int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
+      dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
+      dim->set_stride(window_strides.getValue<int64>(index));
+      dim->set_padding_low(padding.getValue<int64>(index));
+      dim->set_padding_high(padding.getValue<int64>(index));
+      dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
+      dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
+      dim->set_window_reversal(window_reversal.getValue<bool>(index));
+    }
+    descriptor.feature_group_count = op.feature_group_count();
+    descriptor.backend_config.set_algorithm(
+        op.backend_config().algorithm().getInt());
+    descriptor.backend_config.set_tensor_ops_enabled(
+        op.backend_config().tensor_ops_enabled().getValue());
+    descriptor.backend_config.set_conv_result_scale(
+        op.result_scale().convertToDouble());
+  };
+
+  auto set_activation_mode = [&](auto op) -> Status {
+    TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
+                        ConvertConvActivationMode(op.activation_mode()));
+    descriptor.backend_config.set_activation_mode(
+        static_cast<int64>(activation_mode));
+    return Status::OK();
+  };
+
+  if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kForward;
+    fill_conv_descriptor(op);
+  } else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kBackwardInput;
+    fill_conv_descriptor(op);
+  } else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kBackwardFilter;
+    fill_conv_descriptor(op);
+  } else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kForwardActivation;
+    fill_conv_descriptor(op);
+    TF_RETURN_IF_ERROR(set_activation_mode(op));
+  } else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kForwardActivation;
+    fill_conv_descriptor(op);
+    TF_RETURN_IF_ERROR(set_activation_mode(op));
+    descriptor.backend_config.set_side_input_scale(
+        op.side_input_scale().convertToDouble());
+  } else {
+    return InternalError("Unexpected operation");
+  }
+  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
+  AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
+      input.thunk_info, std::move(config), std::move(operand_slices),
+      conv_result_slice, scratch_slice));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
   auto build_gemm_config = [](auto op) {
     GpuGemmConfig config;
@@ -1205,8 +1416,7 @@
 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
 // subclass. The logic is de-virtualized and less scattered.
 Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
-                                                 const Shape& output_shape,
-                                                 int unroll_factor) {
+                                                 const Shape& output_shape) {
   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(input.op);
   MlirEmitterContext context;
   context.SetOperation(fusion);
@@ -1253,6 +1463,11 @@
       auto element_generator,
       fused_emitter.GetGenerator(fused_computation->root_instruction()));
 
+  int unroll_factor = 1;
+  if (!MayPreventVectorization(fusion)) {
+    unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
+  }
+
   Shape element_shape = context.output_shapes[0];
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
@@ -1431,12 +1646,7 @@
     return Status::OK();
   }
 
-  int unroll_factor = 1;
-  if (!MayPreventVectorization(*fusion)) {
-    unroll_factor = ComputeMaxUnrollFactor(fusion);
-  }
-
-  return EmitLoopFusionFromMlir(mlir_input, fusion->shape(), unroll_factor);
+  return EmitLoopFusionFromMlir(mlir_input, fusion->shape());
 }
 
 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
@@ -1471,7 +1681,7 @@
     return Status::OK();
   }
 
-  return DefaultActionForMlir(input);
+  return EmitUsingElementalIrEmitter(input);
 }
 
 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
@@ -1502,7 +1712,7 @@
     return EmitReductionFromOrToContiguousDimensions(mlir_input);
   }
 
-  return DefaultActionForMlir(mlir_input);
+  return EmitUsingElementalIrEmitter(mlir_input);
 }
 
 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
@@ -1553,6 +1763,17 @@
   return Status::OK();
 }
 
+Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
+  if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
+    return Unimplemented(
+        "HLO instruction %s does not have a deterministic implementation, "
+        "but run-to-run determinism is required by "
+        "--xla_gpu_deterministic_ops.",
+        op_name);
+  }
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleSelectAndScatter(
     HloInstruction* select_and_scatter) {
   const Window& window = select_and_scatter->window();
@@ -1568,6 +1789,8 @@
         "Dilation for SelectAndScatter not implemented on GPU.");
   }
 
+  TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(select_and_scatter->name()));
+
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(select_and_scatter));
   return EmitSelectAndScatterFromMlir(input);
 }
@@ -1853,6 +2076,9 @@
 }
 
 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
+  if (!scatter->unique_indices()) {
+    TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(scatter->name()));
+  }
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(scatter));
   return EmitScatterFromMlir(input);
 }
@@ -1954,6 +2180,9 @@
 
 Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc,
                                       Thunk* thunk) {
+  if (!desc.unique_indices) {
+    TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
+  }
   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
     std::vector<llvm::Value*> raw_window_multidim;
     std::vector<llvm::Value*> input_scatter_multidim;
@@ -2185,7 +2414,18 @@
     for (HloComputation* computation : module->computations()) {
       for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
         if (instr->opcode() == HloOpcode::kConstant) {
-          instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(*instr));
+          // Notice that IR emitters use the names of constants as LLVM
+          // symbol names, so constants in the new module must not collide by
+          // name with constants in the original module. Make them unique by
+          // prepending the module name.
+          //
+          // TODO(timshen): A better solution would be to plumb the exact
+          // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
+          // hard because XLA builder doesn't support setting names. Revisit
+          // this once we get rid of this function, or don't rely on the op name
+          // (which shouldn't be the identity) to generate LLVM symbols.
+          instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
+              module->name() + "_" + instr->name()));
         }
         if (instr->shape().IsTuple() &&
             computation == module->entry_computation() &&
@@ -2431,6 +2671,103 @@
   return Status::OK();
 }
 
+Status IrEmitterUnnested::HandleAllGather(HloInstruction* hlo) {
+  VLOG(2) << "AllGather; replica count: " << hlo_module_config_.replica_count()
+          << "; operand count: " << hlo->operand_count()
+          << "; NCCL is enabled: " << NcclAllGatherThunk::NcclIsEnabled();
+
+  // Note the replica_count == 1 case is handled via device-to-device copy
+  // below.
+  bool should_use_nccl_thunk = hlo_module_config_.replica_count() > 1 &&
+                               NcclAllGatherThunk::CanImplement(hlo);
+
+  if (should_use_nccl_thunk) {
+    std::vector<NcclAllGatherThunk::Buffer> buffers;
+    std::vector<BufferAllocation::Slice> tuple_element_buffers;
+    buffers.resize(hlo->operand_count());
+    tuple_element_buffers.reserve(hlo->operand_count());
+    CHECK(hlo->shape().IsArray() && hlo->operand_count() == 1 ||
+          hlo->shape().IsTuple() &&
+              hlo->shape().tuple_shapes_size() == hlo->operand_count());
+    for (int i = 0; i < hlo->operand_count(); ++i) {
+      CHECK(hlo->operand(i)->shape().IsArray())
+          << "Operands to all-gather must be arrays: " << hlo->ToString();
+      buffers[i].element_count =
+          ShapeUtil::ElementsIn(hlo->operand(i)->shape());
+      buffers[i].source_buffer = GetAllocationSlice(*hlo->operand(i));
+      buffers[i].destination_buffer = GetAllocationSlice(
+          *hlo, hlo->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
+      tuple_element_buffers.push_back(buffers[i].destination_buffer);
+    }
+    NcclAllGatherConfig config =
+        GetNcclAllGatherConfig(hlo, hlo_module_config_.replica_count());
+    auto all_gather_thunk = absl::make_unique<NcclAllGatherThunk>(
+        GetThunkInfo(hlo), std::move(config),
+        /*buffers=*/std::move(buffers));
+    if (hlo->shape().IsTuple()) {
+      std::vector<std::unique_ptr<Thunk>> thunks;
+      thunks.push_back(std::move(all_gather_thunk));
+      thunks.push_back(absl::make_unique<TupleThunk>(
+          Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*hlo)));
+      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+          GetThunkInfo(hlo), std::move(thunks)));
+    } else {
+      AddThunkToThunkSequence(std::move(all_gather_thunk));
+    }
+
+    return Status::OK();
+  }
+
+  if (hlo_module_config_.replica_count() != 1) {
+    string message = absl::StrFormat(
+        "Requested AllGather not implemented on GPU; replica_count: %d; "
+        "operand_count: %d; NCCL support: %d",
+        hlo_module_config_.replica_count(), hlo->operand_count(),
+        NcclAllGatherThunk::NcclIsEnabled());
+    if (hlo->operand_count() > 0) {
+      absl::StrAppendFormat(
+          &message, "; first operand array element-type: %s",
+          PrimitiveType_Name(hlo->operand(0)->shape().element_type()));
+    }
+    return Unimplemented("%s", message);
+  }
+
+  // All-gather with one operand and one replica is simply the identity
+  // function. Buffer assignment expects a copy, so that's what we do.
+  if (hlo->operand_count() == 1) {
+    CHECK(hlo->operand(0)->shape().IsArray())
+        << "Operands to all-gather must be arrays: " << hlo->ToString();
+    AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
+        GetThunkInfo(hlo),
+        /*source_address=*/GetAllocationSlice(*hlo->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*hlo),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->shape())));
+    return Status::OK();
+  }
+
+  // One-replica all-gather with multiple operands produces a tuple of the
+  // inputs. Again, buffer assignment expects us to copy each.
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  std::vector<BufferAllocation::Slice> tuple_element_buffers;
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
+                                        .GetUniqueSlice(hlo, {i})
+                                        .ValueOrDie());
+    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
+        Thunk::ThunkInfo(),
+        /*source_address=*/GetAllocationSlice(*hlo->operand(i)),
+        /*destination_buffer=*/tuple_element_buffers.back(),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(i)->shape())));
+  }
+
+  // Output a tuple of the buffers above.
+  thunks.push_back(absl::make_unique<TupleThunk>(
+      Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*hlo)));
+  AddThunkToThunkSequence(
+      absl::make_unique<SequentialThunk>(GetThunkInfo(hlo), std::move(thunks)));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
   VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
           << "; operand count: " << crs->operand_count()
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 919f4ec..cc14fa9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -157,7 +157,7 @@
   }
 
   Status DefaultAction(HloInstruction* hlo) override;
-  Status DefaultActionForMlir(MlirEmitterInput input);
+  Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
 
   // IrEmitterUnnested handles the following instructions differently from
   // IrEmitter. It also mixes in some special handling for custom kernels
@@ -168,6 +168,7 @@
   Status HandleConditional(HloInstruction* conditional) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
   Status EmitGemmThunkFromMlir(MlirEmitterInput input);
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
   Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
@@ -175,7 +176,7 @@
   Status HandleFft(HloInstruction* fft) override;
   Status HandleFusion(HloInstruction* fusion) override;
   Status EmitLoopFusionFromMlir(MlirEmitterInput input,
-                                const Shape& output_shape, int unroll_factor);
+                                const Shape& output_shape);
   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
   Status HandleReduce(HloInstruction* reduce) override;
   Status HandleSelectAndScatter(HloInstruction* instruction) override;
@@ -192,6 +193,7 @@
   Status HandleSort(HloInstruction* sort) override;
   Status EmitSortFromMlir(MlirEmitterInput mlir_input);
   Status HandleTriangularSolve(HloInstruction* hlo) override;
+  Status HandleAllGather(HloInstruction* hlo) override;
   Status HandleAllReduce(HloInstruction* crs) override;
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleAfterAll(HloInstruction* after_all) override;
@@ -707,6 +709,8 @@
 
   Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;
 
+  Status AssertNonDeterminismIsOkay(const string& op_name);
+
   // The thunk sequence this IrEmitter generates for the input computation.
   ThunkSequence thunk_sequence_;
 
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index d6f1e57..338db13 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -838,11 +838,11 @@
   // Delete the first two lines, since they usually vary even when the rest of
   // the code is the same (but verify that they are what we expect).
   if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") {
-    auto pos = str.find("\n");
+    auto pos = str.find('\n');
     if (pos != std::string::npos) str = str.substr(pos + 1);
   }
   if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") {
-    auto pos = str.find("\n");
+    auto pos = str.find('\n');
     if (pos != std::string::npos) str = str.substr(pos + 1);
   }
   str += hlo_module_config.compilation_cache_key();
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index e60f3bc..c715d31 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -30,7 +30,63 @@
 namespace gpu {
 
 // Multi-output fusion of sibling and producer-consumer instructions for the
-// GPU backend.
+// GPU backend to reduce memory bandwidth requirements.
+//
+//   0) Before multi-    1) Sibling multi-    2) Producer-consumer
+//      output fusion       output fusion        multi-output fusion
+//
+//          p                    p                    p
+//          |                    |                    |
+//          v                    v                    v
+//          A                    A               +-fusion--+
+//        /   \                  |               |    A    |
+//       |     |            +-fusion--+          |   / \   |
+//       v     v            |   / \   |          |  B   |  |
+//       B     C            |  B   C  |          |  |   |  |
+//        \   /             |  |   |  |          |  v   v  |
+//         v v              |  v   v  |          |  tuple  |
+//        ROOT              |  tuple  |          +---------+
+//                          +---------+            /    \
+//                            /    \            gte_b  gte_a
+//                         gte_b  gte_c           |      |
+//                           |      |             |      v
+//                            \    /              |      C
+//                             v  v                \    /
+//                             ROOT                 v  v
+//                                                  ROOT
+//
+// Multi-output fusion ops have a tuple op at their root containing multiple
+// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
+// inserted to extract tuple elements for consumers.
+//
+// The two different flavors of multi-output fusion this pass performs are
+// depicted above.
+// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
+//    common input parameters have to be read only once.
+// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
+//    saving one read from memory. In the example above, B does not need to read
+//    the output of A from memory, while C still does (using gte_a).
+// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
+// combined.
+//
+// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs
+// before uses). First, it attempts to fuse the consumer ops of the current op,
+// which are siblings (1). Then it attempts to fuse the current op with one of
+// its consumers (2). This order avoids a phase-ordering issue (described in
+// go/fusionfusion). It ensures that all GetTupleElement ops inserted as a
+// by-product of multi-output fusion occur before the current op in the order
+// of traversal, and hence do not get in the way of subsequent fusion attempts.
+//
+// The GpuMultiOutputFusion pass ensures several conditions are met for fusion.
+// Some of them are relevant for correctness. In particular, no cycles must be
+// introduced into the HLO module. Moreover, the code emitters for multi-output
+// fusion must support the combination of ops and their shapes. Other
+// restrictions are rather arbitrary and lifting them could be beneficial.
+// * Sibling fusion (1) requires at least one op to be a kFusion.
+// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.,
+//   the fusion kinds must match.
+
 class GpuMultiOutputFusion : public HloModulePass {
  public:
   GpuMultiOutputFusion() = default;
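For context, a pass like GpuMultiOutputFusion is normally run as one stage of an
HloPassPipeline. The sketch below is a minimal, hypothetical illustration of that
wiring (the pipeline name, helper function, and error handling are assumptions and
not part of this change); it only shows that the pass is an HloModulePass whose
Run() reports whether the module changed.

// Minimal sketch, assuming the usual XLA pass-pipeline API; the helper name
// RunMultiOutputFusionForIllustration is hypothetical.
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

namespace xla {
namespace gpu {

StatusOr<bool> RunMultiOutputFusionForIllustration(HloModule* module) {
  HloPassPipeline pipeline("multi-output-fusion-example");
  // GpuMultiOutputFusion is an HloModulePass, so it can be appended like any
  // other pass; Run() returns whether the HLO module was changed.
  pipeline.AddPass<GpuMultiOutputFusion>();
  return pipeline.Run(module);
}

}  // namespace gpu
}  // namespace xla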
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
new file mode 100644
index 0000000..fa456ab
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
@@ -0,0 +1,109 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
+
+#include <chrono>  // NOLINT (required by TF interfaces)
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#if GOOGLE_CUDA
+#include "third_party/nccl/nccl.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/rccl/rccl.h"
+#endif
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+NcclAllGatherConfig GetNcclAllGatherConfig(const HloInstruction* hlo,
+                                           int64 replica_count) {
+  NcclAllGatherConfig config;
+  config.config = GetNcclCollectiveConfig(hlo, replica_count);
+  return config;
+}
+
+/*static*/ bool NcclAllGatherThunk::CanImplement(const HloInstruction* hlo) {
+  auto operands_are_supported = [hlo]() {
+    return absl::c_all_of(hlo->operands(), [](HloInstruction* operand) {
+      return LayoutUtil::IsDenseArray(operand->shape()) &&
+             ToNcclDataType(operand->shape().element_type()).ok();
+    });
+  };
+  return (Cast<HloAllGatherInstruction>(hlo)->all_gather_dimension() == 0) &&
+         operands_are_supported();
+}
+
+NcclAllGatherThunk::NcclAllGatherThunk(
+    ThunkInfo thunk_info, NcclAllGatherConfig config,
+    std::vector<NcclAllGatherThunk::Buffer> buffers)
+    : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
+      config_(std::move(config)),
+      buffers_(std::move(buffers)) {
+  CHECK_EQ(config_.config.operand_count, buffers_.size());
+}
+
+Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams& params,
+                                             ncclComm_t comm) {
+  int device_ordinal = params.stream->parent()->device_ordinal();
+  VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal;
+
+  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
+      params.stream->implementation()->GpuStreamMemberHack());
+
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
+  for (size_t i = 0; i < buffers_.size(); ++i) {
+    const Buffer& buffer = buffers_[i];
+    const void* send_buffer =
+        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
+            .opaque();
+    void* recv_buffer =
+        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
+            .opaque();
+
+    TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
+                        ToNcclDataType(config_.config.operand_element_type[i]));
+
+    VLOG(3) << absl::StreamFormat(
+        "Calling ncclAllGather(send_buffer=%p, recv_buffer=%p, count=%d, "
+        "comm=%p, stream=%p)",
+        send_buffer, recv_buffer, buffer.element_count,
+        static_cast<const void*>(comm), cu_stream);
+
+    XLA_CUDA_RETURN_IF_ERROR(ncclAllGather(send_buffer, recv_buffer,
+                                           buffer.element_count, datatype, comm,
+                                           *cu_stream));
+  }
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
+
+  VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal;
+  return Status::OK();
+}
+
+const NcclCollectiveConfig& NcclAllGatherThunk::config() const {
+  return config_.config;
+}
+
+}  // namespace gpu
+}  // namespace xla
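RunNcclCollective above issues one ncclAllGather per buffer inside an
ncclGroupStart()/ncclGroupEnd() pair. The standalone sketch below is independent of
XLA (the device count, element count, and the omission of error checking are
assumptions); it shows the same call pattern for two GPUs in one process, where each
rank contributes kCount elements and receives the rank-ordered concatenation.

#include <cuda_runtime.h>
#include <nccl.h>

int main() {
  const int kNumDevices = 2;
  const size_t kCount = 1024;  // elements contributed by each rank
  int devices[kNumDevices] = {0, 1};
  ncclComm_t comms[kNumDevices];
  ncclCommInitAll(comms, kNumDevices, devices);

  float* send[kNumDevices];
  float* recv[kNumDevices];
  cudaStream_t streams[kNumDevices];
  for (int i = 0; i < kNumDevices; ++i) {
    cudaSetDevice(i);
    cudaMalloc((void**)&send[i], kCount * sizeof(float));
    cudaMalloc((void**)&recv[i], kNumDevices * kCount * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Group the per-device calls so NCCL can launch them together, mirroring
  // the loop over buffers_ in NcclAllGatherThunk::RunNcclCollective.
  ncclGroupStart();
  for (int i = 0; i < kNumDevices; ++i) {
    ncclAllGather(send[i], recv[i], kCount, ncclFloat, comms[i], streams[i]);
  }
  ncclGroupEnd();

  for (int i = 0; i < kNumDevices; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}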
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h
new file mode 100644
index 0000000..fe57b84
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h
@@ -0,0 +1,66 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+struct NcclAllGatherConfig {
+  NcclCollectiveConfig config;
+};
+
+NcclAllGatherConfig GetNcclAllGatherConfig(const HloInstruction* hlo,
+                                           int64 replica_count);
+
+// Thunk that performs a NCCL-based All-Gather among CUDA GPU-based replicas.
+class NcclAllGatherThunk : public NcclCollectiveThunk {
+ public:
+  struct Buffer {
+    int64 element_count;
+    BufferAllocation::Slice source_buffer;
+    BufferAllocation::Slice destination_buffer;
+  };
+
+  NcclAllGatherThunk(ThunkInfo thunk_info, NcclAllGatherConfig config,
+                     std::vector<Buffer> buffers);
+
+  // Returns whether the given instruction can be lowered to a nccl all-gather
+  // call.
+  static bool CanImplement(const HloInstruction* hlo);
+
+ protected:
+  Status RunNcclCollective(const ExecuteParams& params,
+                           ncclComm_t comm) override;
+
+  const NcclCollectiveConfig& config() const override;
+
+ private:
+  const NcclAllGatherConfig config_;
+  const std::vector<Buffer> buffers_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_
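The Buffer struct above pairs a per-replica element count with the source and
destination slices for one operand. Below is a hedged sketch of how an emitter might
populate it, one entry per all-gather operand; GetSrcSlice and GetDstSlice are
hypothetical stand-ins for the emitter's buffer-assignment lookups, and the
surrounding variables (hlo, thunk_info, replica_count) are assumed to be in scope.

// Hedged sketch, not part of this change: one Buffer per operand.
std::vector<NcclAllGatherThunk::Buffer> buffers;
buffers.reserve(hlo->operand_count());
for (int64 i = 0; i < hlo->operand_count(); ++i) {
  const HloInstruction* operand = hlo->operand(i);
  NcclAllGatherThunk::Buffer buffer;
  // element_count is the number of elements contributed by one replica.
  buffer.element_count = ShapeUtil::ElementsIn(operand->shape());
  buffer.source_buffer = GetSrcSlice(operand);      // hypothetical helper
  buffer.destination_buffer = GetDstSlice(hlo, i);  // hypothetical helper
  buffers.push_back(buffer);
}
auto thunk = absl::make_unique<NcclAllGatherThunk>(
    thunk_info, GetNcclAllGatherConfig(hlo, replica_count),
    std::move(buffers));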
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc
new file mode 100644
index 0000000..221cc60
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc
@@ -0,0 +1,51 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+namespace gpu {
+
+NcclAllGatherConfig GetNcclAllGatherConfig(const HloInstruction* hlo,
+                                           int64 replica_count) {
+  return NcclAllGatherConfig();
+}
+
+NcclAllGatherThunk::NcclAllGatherThunk(
+    ThunkInfo thunk_info, NcclAllGatherConfig config,
+    std::vector<NcclAllGatherThunk::Buffer> buffers)
+    : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
+      config_(std::move(config)),
+      buffers_(std::move(buffers)) {}
+
+/* static */ bool NcclAllGatherThunk::CanImplement(const HloInstruction* hlo) {
+  return false;
+}
+
+Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) {
+  return Unimplemented(
+      "NCCL support is not available: this binary was not built with a CUDA "
+      "compiler, which is necessary to build the NCCL source library.");
+}
+
+const NcclCollectiveConfig& NcclAllGatherThunk::config() const {
+  // This function will never be called.
+  const NcclCollectiveConfig* config = nullptr;
+  return *config;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc
similarity index 100%
rename from tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc
rename to tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc
diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_to_all_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc
similarity index 100%
rename from tensorflow/compiler/xla/service/gpu/dummy_all_to_all_thunk.cc
rename to tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
index 03d289e..7174eef 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
@@ -24,12 +24,12 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
 
 namespace xla {
@@ -47,20 +47,6 @@
 //    GPUs are participating in the op, so we get or create a NcclClique
 //    containing those GPUs.
 //  - We perform the NCCL operation using the clique.
-//
-// Creating NCCL cliques is expensive, so we cache them.  Our policy is, a thunk
-// keeps alive all cliques it's ever used.  When the thunk is destroyed, it
-// releases its handle on the cliques, and cliques whose refcounts go to 0 are
-// destroyed.
-
-// Extra data stored in NcclCollectiveThunk that we didn't want to expose in the
-// header.  In particular, this stores the thunk's cache of all NcclCliques it's
-// ever used.  This causes those cliques to stay alive as long as the thunk
-// lives, which is how we avoid expensive reinitialization of NCCL cliques.
-struct NcclCollectiveConfig::AuxData {
-  tensorflow::mutex mu;
-  absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
-};
 
 NcclCollectiveConfig::NcclCollectiveConfig() = default;
 NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
@@ -87,7 +73,6 @@
     config.collective_op_kind = RendezvousKey::kCrossReplica;
     config.op_id = static_cast<int64>(hlo->GetModule()->unique_id());
   }
-  config.aux_data = std::make_unique<NcclCollectiveConfig::AuxData>();
   return config;
 }
 
@@ -137,12 +122,9 @@
 
   TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
 
-  // Keep the clique we used alive for as long as this Thunk lives.  Creating
-  // new NCCL cliques is expensive, and this is how we avoid thrashing them.
-  {
-    tensorflow::mutex_lock lock(config().aux_data->mu);
-    config().aux_data->cliques.insert(std::move(locked_clique.clique));
-  }
+  // Keep the clique we used alive for as long as this thunk lives.
+  absl::MutexLock lock(&mu_);
+  cliques_.insert(std::move(locked_clique.clique));
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
index 7f60c70..3343fc5 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -28,6 +29,8 @@
 namespace xla {
 namespace gpu {
 
+struct NcclClique;
+
 struct NcclCollectiveConfig {
   NcclCollectiveConfig();
   NcclCollectiveConfig(NcclCollectiveConfig&&);
@@ -41,12 +44,6 @@
   std::vector<ReplicaGroup> replica_groups;
   RendezvousKey::CollectiveOpKind collective_op_kind;
   int64 op_id;
-  // Extra data stored in NcclCollectiveConfig whose types we don't want exposed
-  // in the header file.  (This is mainly because the implementation of
-  // NcclCollectiveConfig is different depending on whether CUDA is enabled in
-  // the build, and we don't want to expose *that* mess in the header.)
-  struct AuxData;
-  std::unique_ptr<AuxData> aux_data;
 };
 
 NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
@@ -65,12 +62,19 @@
   // error.
   static bool NcclIsEnabled();
 
-  Status ExecuteOnStream(const ExecuteParams& params) override;
+  Status ExecuteOnStream(const ExecuteParams& params) override
+      ABSL_LOCKS_EXCLUDED(mu_);
 
  protected:
   virtual Status RunNcclCollective(const ExecuteParams& params,
                                    ncclComm_t comm) = 0;
   virtual const NcclCollectiveConfig& config() const = 0;
+
+ private:
+  // Creating NCCL cliques is expensive, so we cache them.
+  absl::Mutex mu_;
+  absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques_
+      ABSL_GUARDED_BY(mu_);
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc
similarity index 97%
rename from tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc
rename to tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc
index 0c49b2d..fc5ea04 100644
--- a/tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc
@@ -19,7 +19,7 @@
 namespace xla {
 namespace gpu {
 
-struct NcclCollectiveConfig::AuxData {};
+struct NcclClique {};
 
 NcclCollectiveConfig::NcclCollectiveConfig() = default;
 NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig &&) = default;
diff --git a/tensorflow/compiler/xla/service/gpu/dummy_nccl_test_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_test_utils_dummy.cc
similarity index 100%
rename from tensorflow/compiler/xla/service/gpu/dummy_nccl_test_utils.cc
rename to tensorflow/compiler/xla/service/gpu/nccl_test_utils_dummy.cc
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index 81f54ba..b240138 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -16,9 +16,11 @@
 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
 
 #include <memory>
+#include <utility>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
@@ -221,9 +223,10 @@
       : Rendezvous(rendezvous_key),
         key_(std::move(rendezvous_key.global_devices)),
         local_participants_(local_participants),
-        callback_(callback) {}
+        callback_(callback),
+        counter_(nullptr) {}
 
-  StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  StatusOr<LockedNcclClique> RunCollectiveOp(
       const NcclCliqueParticipantData&) override {
     tensorflow::mutex_lock lock(mu_);
     bool primary = !initialized_;
@@ -235,10 +238,13 @@
       initialized_ = true;
     }
     TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
+    std::unique_ptr<absl::MutexLock> clique_lock;
     if (primary) {
-      lock_ = std::make_shared<absl::MutexLock>(clique->mu());
+      clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
+      counter_ = new absl::BlockingCounter(local_participants_.size());
     }
-    return ParticipantImplOutput{primary, LockedNcclClique{clique, lock_}};
+    return LockedNcclClique(std::move(clique), std::move(clique_lock),
+                            counter_);
   }
 
  private:
@@ -247,7 +253,7 @@
   const NcclUniqueIdCallback* callback_;
 
   StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
-  std::shared_ptr<absl::MutexLock> lock_;
+  absl::BlockingCounter* counter_;
 };
 
 }  // namespace
@@ -282,6 +288,26 @@
   return local_participants;
 }
 
+LockedNcclClique::LockedNcclClique(std::shared_ptr<NcclClique> clique,
+                                   std::unique_ptr<absl::MutexLock> lock,
+                                   absl::BlockingCounter* counter)
+    : clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {}
+
+LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
+    : clique(std::move(other.clique)),
+      lock_(std::move(other.lock_)),
+      counter_(std::exchange(other.counter_, nullptr)) {}
+
+LockedNcclClique::~LockedNcclClique() {
+  if (counter_) {
+    counter_->DecrementCount();
+    if (lock_) {
+      counter_->Wait();  // Don't release lock until all threads are finished.
+      delete counter_;
+    }
+  }
+}
+
 StatusOr<LockedNcclClique> AcquireNcclClique(
     const RendezvousKey& rendezvous_key, int local_device_ordinal,
     se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index 4a045d0..f24231d 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -20,6 +20,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/mutex.h"
 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
@@ -110,13 +111,21 @@
     const std::vector<GlobalDeviceId>& participants,
     const std::vector<GlobalDeviceId>* local_devices);  // may be null
 
-struct LockedNcclClique {
+class LockedNcclClique {
+ public:
+  LockedNcclClique(std::shared_ptr<NcclClique> clique,
+                   std::unique_ptr<absl::MutexLock> lock,
+                   absl::BlockingCounter* counter);
+  LockedNcclClique(LockedNcclClique&&);
+  ~LockedNcclClique();
+
   std::shared_ptr<NcclClique> clique;
+
+ private:
   // Must come after clique, so it is destroyed first.
-  // This lock prevents other threads from using this clique. All of the threads
-  // involved should hold onto the lock until they have finished with their
-  // communicator.
-  std::shared_ptr<absl::MutexLock> lock;
+  // Only one thread (the primary participant) holds the lock; it is null in
+  // all other threads.
+  std::unique_ptr<absl::MutexLock> lock_;
+  absl::BlockingCounter* counter_;
 };
 
 // Acquires a locked NCCL clique for use in NCCL collective operations.
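The comments above describe the new ownership model: exactly one participant holds
the clique mutex, and every participant shares a heap-allocated BlockingCounter so
the lock holder can wait until all of them are done before releasing the lock and
deleting the counter. The sketch below reproduces that pattern with plain absl
primitives and std::thread; the function name and the choice of participant 0 as
the lock holder are illustrative assumptions.

#include <memory>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"

// Hedged sketch of the lock + blocking-counter pattern; not XLA code.
void RunParticipants(int num_participants) {
  absl::Mutex clique_mu;
  auto* counter = new absl::BlockingCounter(num_participants);

  std::vector<std::thread> threads;
  for (int i = 0; i < num_participants; ++i) {
    threads.emplace_back([&, i] {
      // Participant 0 plays the "primary" role and holds the clique lock.
      std::unique_ptr<absl::MutexLock> lock;
      if (i == 0) lock = std::make_unique<absl::MutexLock>(&clique_mu);

      // ... the collective operation would run here ...

      counter->DecrementCount();
      if (lock != nullptr) {
        counter->Wait();  // Keep the lock until every participant is done.
        delete counter;   // Safe: Wait() returns only after all decrements.
      }
    });
  }
  for (auto& t : threads) t.join();
}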
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 3225cd2..070b8a1 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -51,6 +51,7 @@
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#include "tensorflow/stream_executor/gpu/gpu_driver.h"
 
 namespace xla {
 namespace gpu {
@@ -299,7 +300,8 @@
 NVPTXCompiler::CompileTargetBinary(const HloModule* module,
                                    llvm::Module* llvm_module,
                                    GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+                                   se::StreamExecutor* stream_exec,
+                                   bool relocatable) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);
 
@@ -338,7 +340,7 @@
 
   std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
       stream_exec, ptx, compute_capability.first, compute_capability.second,
-      module->config());
+      module->config(), relocatable);
 
   return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
                                                     std::move(cubin));
@@ -346,7 +348,7 @@
 
 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
     se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
-    int cc_minor, const HloModuleConfig& hlo_module_config) {
+    int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
   XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
   tensorflow::profiler::TraceMe activity(
       "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
@@ -361,7 +363,7 @@
     tensorflow::mutex_lock lock(mutex_);
     std::tie(iter, inserted) = compilation_cache_.emplace(
         std::piecewise_construct,
-        std::forward_as_tuple(ptx, cc_major, cc_minor),
+        std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
         std::forward_as_tuple());
     cache_ptx = &iter->first.ptx;
     cache_value = &iter->second;
@@ -375,9 +377,13 @@
     if (inserted) {
       CHECK(!cache_value->compilation_done);
       if (!ptx.empty()) {
-        StatusOr<std::vector<uint8>> maybe_cubin =
-            se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(),
-                              PtxOptsFromConfig(hlo_module_config));
+        auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
+        if (relocatable) {
+          ptxas_config.extra_flags.push_back("-c");
+        }
+        StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
+            stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);
+
         if (maybe_cubin.ok()) {
           cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
           VLOG(2) << "Compiled PTX size:" << ptx.size()
@@ -445,5 +451,17 @@
   return cache_value->cubin_data;
 }
 
+StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
+    se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
+  std::vector<stream_executor::CubinOrPTXImage> images;
+  images.reserve(modules.size());
+  for (auto& module : modules) {
+    images.push_back({"", std::move(module)});
+  }
+  return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
+                        stream_exec->implementation()->GpuContextHack()),
+                    images);
+}
+
 }  // namespace gpu
 }  // namespace xla
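Taken together, the relocatable flag and LinkModules enable a compile-then-link
flow: each PTX module is assembled with ptxas -c into a relocatable cubin, and the
cubins are then linked into a single image on the device's GPU context. The sketch
below strings the two steps together for illustration; the helper name
CompileAndLinkForIllustration and the ptx_modules input are assumptions, while the
individual calls mirror the ones used above.

// Hedged sketch (xla::gpu namespace assumed); not part of this change.
StatusOr<std::vector<uint8>> CompileAndLinkForIllustration(
    se::StreamExecutor* stream_exec,
    const std::vector<std::string>& ptx_modules,  // assumed input
    const HloModuleConfig& hlo_module_config) {
  std::vector<stream_executor::CubinOrPTXImage> images;
  images.reserve(ptx_modules.size());
  for (const std::string& ptx : ptx_modules) {
    auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
    ptxas_config.extra_flags.push_back("-c");  // emit a relocatable cubin
    TF_ASSIGN_OR_RETURN(std::vector<uint8> cubin,
                        se::CompileGpuAsm(stream_exec->device_ordinal(),
                                          ptx.c_str(), ptxas_config));
    images.push_back({"", std::move(cubin)});
  }
  // Link all relocatable cubins into one image on this device's context.
  return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
                        stream_exec->implementation()->GpuContextHack()),
                    images);
}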
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index 3e19b35..5c78b48 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -52,9 +52,14 @@
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) override;
 
  private:
+  StatusOr<std::vector<uint8>> LinkModules(
+      se::StreamExecutor* stream_exec,
+      std::vector<std::vector<uint8>> modules) override;
+
   tensorflow::mutex mutex_;
 
   // When compiling an HLO module, we need to find a path to the nvvm libdevice
@@ -71,7 +76,7 @@
   // compiled cubin.  If compilation was unsuccessful, returns an empty vector.
   std::vector<uint8> CompileGpuAsmOrGetCachedResult(
       se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
-      int cc_minor, const HloModuleConfig& hlo_module_config);
+      int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable);
 
   // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor}
   // -> cubin so we don't recompile the same ptx twice.  This is important for
@@ -86,24 +91,32 @@
   // If compiling the ptx fails, we return an empty cubin, cross our fingers,
   // and leave compilation up to the driver.
   struct CompilationCacheKey {
-    CompilationCacheKey(std::string ptx, int cc_major, int cc_minor)
-        : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {}
+    CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
+                        bool relocatable)
+        : ptx(std::move(ptx)),
+          cc_major(cc_major),
+          cc_minor(cc_minor),
+          relocatable(relocatable) {}
     string ptx;
     int cc_major;
     int cc_minor;
+    bool relocatable;
   };
   struct CompilationCacheHash {
     size_t operator()(const CompilationCacheKey& key) const {
       return tensorflow::Hash64Combine(
-          tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major),
-          key.cc_minor);
+          tensorflow::Hash64Combine(
+              tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx),
+                                        key.cc_major),
+              key.cc_minor),
+          key.relocatable);
     }
   };
   struct CompilationCacheEq {
     size_t operator()(const CompilationCacheKey& a,
                       const CompilationCacheKey& b) const {
       return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
-             a.ptx == b.ptx;
+             a.ptx == b.ptx && a.relocatable == b.relocatable;
     }
   };
   struct CompilationCacheValue {
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 681e025..4e94176 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -172,7 +172,7 @@
     srcs = [
         "reduction_vectorization_test.cc",
     ],
-    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
@@ -410,7 +410,7 @@
 tf_cc_test(
     name = "gpu_unrolling_test",
     srcs = ["gpu_unrolling_test.cc"],
-    tags = tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:hlo_module_config",
@@ -441,7 +441,7 @@
 tf_cc_test(
     name = "gpu_atomic_test",
     srcs = ["gpu_atomic_test.cc"],
-    tags = tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/tests:filecheck",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/elementwise.hlo b/tensorflow/compiler/xla/service/gpu/tests/elementwise.hlo
index d4ed447..c54affb 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/elementwise.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/elementwise.hlo
@@ -7,9 +7,9 @@
 // CHECK:         %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 128
+// CHECK:         %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 256
 // CHECK:         %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]]
-// CHECK:         %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], 163840
+// CHECK:         %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_10]])
 // CHECK:         %[[VAL_11:.*]] = mul nuw nsw i32 %[[VAL_9]], 4
 // CHECK:         %[[VAL_12:.*]] = udiv i32 %[[VAL_11]], 1
@@ -32,32 +32,32 @@
 // CHECK:       r0.in_bounds-after:                               ; preds = %[[VAL_28]], %[[VAL_30:.*]]
 // CHECK:         ret void
 // CHECK:       r0.in_bounds-true:                                ; preds = %[[VAL_30]]
-// CHECK:         %[[VAL_31:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
+// CHECK:         %[[VAL_31:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
 // CHECK:         %[[VAL_32:.*]] = getelementptr inbounds float, float* %[[VAL_31]], i32 %[[VAL_11]]
 // CHECK:         %[[VAL_33:.*]] = load float, float* %[[VAL_32]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_34:.*]] = call float @llvm.fabs.f32(float %[[VAL_33]])
-// CHECK:         %[[VAL_35:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
+// CHECK:         %[[VAL_35:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
 // CHECK:         %[[VAL_36:.*]] = getelementptr inbounds float, float* %[[VAL_35]], i32 %[[VAL_11]]
 // CHECK:         store float %[[VAL_34]], float* %[[VAL_36]], align 4
-// CHECK:         %[[VAL_37:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
+// CHECK:         %[[VAL_37:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
 // CHECK:         %[[VAL_38:.*]] = getelementptr inbounds float, float* %[[VAL_37]], i32 %[[VAL_15]]
 // CHECK:         %[[VAL_39:.*]] = load float, float* %[[VAL_38]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_40:.*]] = call float @llvm.fabs.f32(float %[[VAL_39]])
-// CHECK:         %[[VAL_41:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
+// CHECK:         %[[VAL_41:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
 // CHECK:         %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_15]]
 // CHECK:         store float %[[VAL_40]], float* %[[VAL_42]], align 4
-// CHECK:         %[[VAL_43:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
+// CHECK:         %[[VAL_43:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
 // CHECK:         %[[VAL_44:.*]] = getelementptr inbounds float, float* %[[VAL_43]], i32 %[[VAL_19]]
 // CHECK:         %[[VAL_45:.*]] = load float, float* %[[VAL_44]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_46:.*]] = call float @llvm.fabs.f32(float %[[VAL_45]])
-// CHECK:         %[[VAL_47:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
+// CHECK:         %[[VAL_47:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
 // CHECK:         %[[VAL_48:.*]] = getelementptr inbounds float, float* %[[VAL_47]], i32 %[[VAL_19]]
 // CHECK:         store float %[[VAL_46]], float* %[[VAL_48]], align 4
-// CHECK:         %[[VAL_49:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
+// CHECK:         %[[VAL_49:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
 // CHECK:         %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]]
 // CHECK:         %[[VAL_51:.*]] = load float, float* %[[VAL_50]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_52:.*]] = call float @llvm.fabs.f32(float %[[VAL_51]])
-// CHECK:         %[[VAL_53:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2]] to float*
+// CHECK:         %[[VAL_53:.*]] = bitcast [100 x [200 x float]]* %[[VAL_5]] to float*
 // CHECK:         %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_23]]
 // CHECK:         store float %[[VAL_52]], float* %[[VAL_54]], align 4
 // CHECK:         br label %[[VAL_29]]
@@ -68,9 +68,9 @@
 // CHECK:         %[[VAL_60:.*]] = bitcast i8* %[[VAL_58]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_61:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_62:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_63:.*]] = mul nuw nsw i32 %[[VAL_61]], 128
+// CHECK:         %[[VAL_63:.*]] = mul nuw nsw i32 %[[VAL_61]], 256
 // CHECK:         %[[VAL_64:.*]] = add nuw nsw i32 %[[VAL_63]], %[[VAL_62]]
-// CHECK:         %[[VAL_65:.*]] = icmp ult i32 %[[VAL_64]], 163840
+// CHECK:         %[[VAL_65:.*]] = icmp ult i32 %[[VAL_64]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_65]])
 // CHECK:         %[[VAL_66:.*]] = mul nuw nsw i32 %[[VAL_64]], 4
 // CHECK:         %[[VAL_67:.*]] = udiv i32 %[[VAL_66]], 1
@@ -93,32 +93,32 @@
 // CHECK:       r1.in_bounds-after:                               ; preds = %[[VAL_83]], %[[VAL_85:.*]]
 // CHECK:         ret void
 // CHECK:       r1.in_bounds-true:                                ; preds = %[[VAL_85]]
-// CHECK:         %[[VAL_86:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
+// CHECK:         %[[VAL_86:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
 // CHECK:         %[[VAL_87:.*]] = getelementptr inbounds float, float* %[[VAL_86]], i32 %[[VAL_66]]
 // CHECK:         %[[VAL_88:.*]] = load float, float* %[[VAL_87]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_89:.*]] = call float @llvm.round.f32(float %[[VAL_88]])
-// CHECK:         %[[VAL_90:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
+// CHECK:         %[[VAL_90:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
 // CHECK:         %[[VAL_91:.*]] = getelementptr inbounds float, float* %[[VAL_90]], i32 %[[VAL_66]]
 // CHECK:         store float %[[VAL_89]], float* %[[VAL_91]], align 4
-// CHECK:         %[[VAL_92:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
+// CHECK:         %[[VAL_92:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
 // CHECK:         %[[VAL_93:.*]] = getelementptr inbounds float, float* %[[VAL_92]], i32 %[[VAL_70]]
 // CHECK:         %[[VAL_94:.*]] = load float, float* %[[VAL_93]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_95:.*]] = call float @llvm.round.f32(float %[[VAL_94]])
-// CHECK:         %[[VAL_96:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
+// CHECK:         %[[VAL_96:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
 // CHECK:         %[[VAL_97:.*]] = getelementptr inbounds float, float* %[[VAL_96]], i32 %[[VAL_70]]
 // CHECK:         store float %[[VAL_95]], float* %[[VAL_97]], align 4
-// CHECK:         %[[VAL_98:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
+// CHECK:         %[[VAL_98:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
 // CHECK:         %[[VAL_99:.*]] = getelementptr inbounds float, float* %[[VAL_98]], i32 %[[VAL_74]]
 // CHECK:         %[[VAL_100:.*]] = load float, float* %[[VAL_99]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_101:.*]] = call float @llvm.round.f32(float %[[VAL_100]])
-// CHECK:         %[[VAL_102:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
+// CHECK:         %[[VAL_102:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
 // CHECK:         %[[VAL_103:.*]] = getelementptr inbounds float, float* %[[VAL_102]], i32 %[[VAL_74]]
 // CHECK:         store float %[[VAL_101]], float* %[[VAL_103]], align 4
-// CHECK:         %[[VAL_104:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
+// CHECK:         %[[VAL_104:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
 // CHECK:         %[[VAL_105:.*]] = getelementptr inbounds float, float* %[[VAL_104]], i32 %[[VAL_78]]
 // CHECK:         %[[VAL_106:.*]] = load float, float* %[[VAL_105]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_107:.*]] = call float @llvm.round.f32(float %[[VAL_106]])
-// CHECK:         %[[VAL_108:.*]] = bitcast [100 x [200 x float]]* %[[VAL_57]] to float*
+// CHECK:         %[[VAL_108:.*]] = bitcast [100 x [200 x float]]* %[[VAL_60]] to float*
 // CHECK:         %[[VAL_109:.*]] = getelementptr inbounds float, float* %[[VAL_108]], i32 %[[VAL_78]]
 // CHECK:         store float %[[VAL_107]], float* %[[VAL_109]], align 4
 // CHECK:         br label %[[VAL_84]]
@@ -129,9 +129,9 @@
 // CHECK:         %[[VAL_115:.*]] = bitcast i8* %[[VAL_113]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_116:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_117:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_118:.*]] = mul nuw nsw i32 %[[VAL_116]], 128
+// CHECK:         %[[VAL_118:.*]] = mul nuw nsw i32 %[[VAL_116]], 256
 // CHECK:         %[[VAL_119:.*]] = add nuw nsw i32 %[[VAL_118]], %[[VAL_117]]
-// CHECK:         %[[VAL_120:.*]] = icmp ult i32 %[[VAL_119]], 163840
+// CHECK:         %[[VAL_120:.*]] = icmp ult i32 %[[VAL_119]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_120]])
 // CHECK:         %[[VAL_121:.*]] = mul nuw nsw i32 %[[VAL_119]], 4
 // CHECK:         %[[VAL_122:.*]] = udiv i32 %[[VAL_121]], 1
@@ -154,32 +154,32 @@
 // CHECK:       r2.in_bounds-after:                               ; preds = %[[VAL_138]], %[[VAL_140:.*]]
 // CHECK:         ret void
 // CHECK:       r2.in_bounds-true:                                ; preds = %[[VAL_140]]
-// CHECK:         %[[VAL_141:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
+// CHECK:         %[[VAL_141:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
 // CHECK:         %[[VAL_142:.*]] = getelementptr inbounds float, float* %[[VAL_141]], i32 %[[VAL_121]]
 // CHECK:         %[[VAL_143:.*]] = load float, float* %[[VAL_142]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_144:.*]] = call float @llvm.ceil.f32(float %[[VAL_143]])
-// CHECK:         %[[VAL_145:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
+// CHECK:         %[[VAL_145:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
 // CHECK:         %[[VAL_146:.*]] = getelementptr inbounds float, float* %[[VAL_145]], i32 %[[VAL_121]]
 // CHECK:         store float %[[VAL_144]], float* %[[VAL_146]], align 4
-// CHECK:         %[[VAL_147:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
+// CHECK:         %[[VAL_147:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
 // CHECK:         %[[VAL_148:.*]] = getelementptr inbounds float, float* %[[VAL_147]], i32 %[[VAL_125]]
 // CHECK:         %[[VAL_149:.*]] = load float, float* %[[VAL_148]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_150:.*]] = call float @llvm.ceil.f32(float %[[VAL_149]])
-// CHECK:         %[[VAL_151:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
+// CHECK:         %[[VAL_151:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
 // CHECK:         %[[VAL_152:.*]] = getelementptr inbounds float, float* %[[VAL_151]], i32 %[[VAL_125]]
 // CHECK:         store float %[[VAL_150]], float* %[[VAL_152]], align 4
-// CHECK:         %[[VAL_153:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
+// CHECK:         %[[VAL_153:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
 // CHECK:         %[[VAL_154:.*]] = getelementptr inbounds float, float* %[[VAL_153]], i32 %[[VAL_129]]
 // CHECK:         %[[VAL_155:.*]] = load float, float* %[[VAL_154]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_156:.*]] = call float @llvm.ceil.f32(float %[[VAL_155]])
-// CHECK:         %[[VAL_157:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
+// CHECK:         %[[VAL_157:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
 // CHECK:         %[[VAL_158:.*]] = getelementptr inbounds float, float* %[[VAL_157]], i32 %[[VAL_129]]
 // CHECK:         store float %[[VAL_156]], float* %[[VAL_158]], align 4
-// CHECK:         %[[VAL_159:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
+// CHECK:         %[[VAL_159:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
 // CHECK:         %[[VAL_160:.*]] = getelementptr inbounds float, float* %[[VAL_159]], i32 %[[VAL_133]]
 // CHECK:         %[[VAL_161:.*]] = load float, float* %[[VAL_160]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_162:.*]] = call float @llvm.ceil.f32(float %[[VAL_161]])
-// CHECK:         %[[VAL_163:.*]] = bitcast [100 x [200 x float]]* %[[VAL_112]] to float*
+// CHECK:         %[[VAL_163:.*]] = bitcast [100 x [200 x float]]* %[[VAL_115]] to float*
 // CHECK:         %[[VAL_164:.*]] = getelementptr inbounds float, float* %[[VAL_163]], i32 %[[VAL_133]]
 // CHECK:         store float %[[VAL_162]], float* %[[VAL_164]], align 4
 // CHECK:         br label %[[VAL_139]]
@@ -190,9 +190,9 @@
 // CHECK:         %[[VAL_170:.*]] = bitcast i8* %[[VAL_168]] to [100 x [200 x i32]]*
 // CHECK:         %[[VAL_171:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_172:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_173:.*]] = mul nuw nsw i32 %[[VAL_171]], 128
+// CHECK:         %[[VAL_173:.*]] = mul nuw nsw i32 %[[VAL_171]], 256
 // CHECK:         %[[VAL_174:.*]] = add nuw nsw i32 %[[VAL_173]], %[[VAL_172]]
-// CHECK:         %[[VAL_175:.*]] = icmp ult i32 %[[VAL_174]], 163840
+// CHECK:         %[[VAL_175:.*]] = icmp ult i32 %[[VAL_174]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_175]])
 // CHECK:         %[[VAL_176:.*]] = mul nuw nsw i32 %[[VAL_174]], 4
 // CHECK:         %[[VAL_177:.*]] = udiv i32 %[[VAL_176]], 1
@@ -215,32 +215,32 @@
 // CHECK:       r3.in_bounds-after:                               ; preds = %[[VAL_193]], %[[VAL_195:.*]]
 // CHECK:         ret void
 // CHECK:       r3.in_bounds-true:                                ; preds = %[[VAL_195]]
-// CHECK:         %[[VAL_196:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
+// CHECK:         %[[VAL_196:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
 // CHECK:         %[[VAL_197:.*]] = getelementptr inbounds i32, i32* %[[VAL_196]], i32 %[[VAL_176]]
 // CHECK:         %[[VAL_198:.*]] = load i32, i32* %[[VAL_197]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_199:.*]] = call i32 @llvm.ctlz.i32(i32 %[[VAL_198]], i1 false)
-// CHECK:         %[[VAL_200:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
+// CHECK:         %[[VAL_200:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
 // CHECK:         %[[VAL_201:.*]] = getelementptr inbounds i32, i32* %[[VAL_200]], i32 %[[VAL_176]]
 // CHECK:         store i32 %[[VAL_199]], i32* %[[VAL_201]], align 4
-// CHECK:         %[[VAL_202:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
+// CHECK:         %[[VAL_202:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
 // CHECK:         %[[VAL_203:.*]] = getelementptr inbounds i32, i32* %[[VAL_202]], i32 %[[VAL_180]]
 // CHECK:         %[[VAL_204:.*]] = load i32, i32* %[[VAL_203]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_205:.*]] = call i32 @llvm.ctlz.i32(i32 %[[VAL_204]], i1 false)
-// CHECK:         %[[VAL_206:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
+// CHECK:         %[[VAL_206:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
 // CHECK:         %[[VAL_207:.*]] = getelementptr inbounds i32, i32* %[[VAL_206]], i32 %[[VAL_180]]
 // CHECK:         store i32 %[[VAL_205]], i32* %[[VAL_207]], align 4
-// CHECK:         %[[VAL_208:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
+// CHECK:         %[[VAL_208:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
 // CHECK:         %[[VAL_209:.*]] = getelementptr inbounds i32, i32* %[[VAL_208]], i32 %[[VAL_184]]
 // CHECK:         %[[VAL_210:.*]] = load i32, i32* %[[VAL_209]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_211:.*]] = call i32 @llvm.ctlz.i32(i32 %[[VAL_210]], i1 false)
-// CHECK:         %[[VAL_212:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
+// CHECK:         %[[VAL_212:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
 // CHECK:         %[[VAL_213:.*]] = getelementptr inbounds i32, i32* %[[VAL_212]], i32 %[[VAL_184]]
 // CHECK:         store i32 %[[VAL_211]], i32* %[[VAL_213]], align 4
-// CHECK:         %[[VAL_214:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
+// CHECK:         %[[VAL_214:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
 // CHECK:         %[[VAL_215:.*]] = getelementptr inbounds i32, i32* %[[VAL_214]], i32 %[[VAL_188]]
 // CHECK:         %[[VAL_216:.*]] = load i32, i32* %[[VAL_215]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_217:.*]] = call i32 @llvm.ctlz.i32(i32 %[[VAL_216]], i1 false)
-// CHECK:         %[[VAL_218:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_167]] to i32*
+// CHECK:         %[[VAL_218:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_170]] to i32*
 // CHECK:         %[[VAL_219:.*]] = getelementptr inbounds i32, i32* %[[VAL_218]], i32 %[[VAL_188]]
 // CHECK:         store i32 %[[VAL_217]], i32* %[[VAL_219]], align 4
 // CHECK:         br label %[[VAL_194]]
@@ -251,9 +251,9 @@
 // CHECK:         %[[VAL_225:.*]] = bitcast i8* %[[VAL_223]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_226:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_227:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_228:.*]] = mul nuw nsw i32 %[[VAL_226]], 128
+// CHECK:         %[[VAL_228:.*]] = mul nuw nsw i32 %[[VAL_226]], 256
 // CHECK:         %[[VAL_229:.*]] = add nuw nsw i32 %[[VAL_228]], %[[VAL_227]]
-// CHECK:         %[[VAL_230:.*]] = icmp ult i32 %[[VAL_229]], 163840
+// CHECK:         %[[VAL_230:.*]] = icmp ult i32 %[[VAL_229]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_230]])
 // CHECK:         %[[VAL_231:.*]] = mul nuw nsw i32 %[[VAL_229]], 4
 // CHECK:         %[[VAL_232:.*]] = udiv i32 %[[VAL_231]], 1
@@ -276,28 +276,28 @@
 // CHECK:       r4.in_bounds-after:                               ; preds = %[[VAL_248]], %[[VAL_250:.*]]
 // CHECK:         ret void
 // CHECK:       r4.in_bounds-true:                                ; preds = %[[VAL_250]]
-// CHECK:         %[[VAL_251:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
+// CHECK:         %[[VAL_251:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
 // CHECK:         %[[VAL_252:.*]] = getelementptr inbounds float, float* %[[VAL_251]], i32 %[[VAL_231]]
 // CHECK:         %[[VAL_253:.*]] = load float, float* %[[VAL_252]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_254:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
+// CHECK:         %[[VAL_254:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
 // CHECK:         %[[VAL_255:.*]] = getelementptr inbounds float, float* %[[VAL_254]], i32 %[[VAL_231]]
 // CHECK:         store float %[[VAL_253]], float* %[[VAL_255]], align 4
-// CHECK:         %[[VAL_256:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
+// CHECK:         %[[VAL_256:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
 // CHECK:         %[[VAL_257:.*]] = getelementptr inbounds float, float* %[[VAL_256]], i32 %[[VAL_235]]
 // CHECK:         %[[VAL_258:.*]] = load float, float* %[[VAL_257]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_259:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
+// CHECK:         %[[VAL_259:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
 // CHECK:         %[[VAL_260:.*]] = getelementptr inbounds float, float* %[[VAL_259]], i32 %[[VAL_235]]
 // CHECK:         store float %[[VAL_258]], float* %[[VAL_260]], align 4
-// CHECK:         %[[VAL_261:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
+// CHECK:         %[[VAL_261:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
 // CHECK:         %[[VAL_262:.*]] = getelementptr inbounds float, float* %[[VAL_261]], i32 %[[VAL_239]]
 // CHECK:         %[[VAL_263:.*]] = load float, float* %[[VAL_262]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_264:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
+// CHECK:         %[[VAL_264:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
 // CHECK:         %[[VAL_265:.*]] = getelementptr inbounds float, float* %[[VAL_264]], i32 %[[VAL_239]]
 // CHECK:         store float %[[VAL_263]], float* %[[VAL_265]], align 4
-// CHECK:         %[[VAL_266:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
+// CHECK:         %[[VAL_266:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
 // CHECK:         %[[VAL_267:.*]] = getelementptr inbounds float, float* %[[VAL_266]], i32 %[[VAL_243]]
 // CHECK:         %[[VAL_268:.*]] = load float, float* %[[VAL_267]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_269:.*]] = bitcast [100 x [200 x float]]* %[[VAL_222]] to float*
+// CHECK:         %[[VAL_269:.*]] = bitcast [100 x [200 x float]]* %[[VAL_225]] to float*
 // CHECK:         %[[VAL_270:.*]] = getelementptr inbounds float, float* %[[VAL_269]], i32 %[[VAL_243]]
 // CHECK:         store float %[[VAL_268]], float* %[[VAL_270]], align 4
 // CHECK:         br label %[[VAL_249]]
@@ -308,9 +308,9 @@
 // CHECK:         %[[VAL_276:.*]] = bitcast i8* %[[VAL_274]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_277:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_278:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_279:.*]] = mul nuw nsw i32 %[[VAL_277]], 128
+// CHECK:         %[[VAL_279:.*]] = mul nuw nsw i32 %[[VAL_277]], 256
 // CHECK:         %[[VAL_280:.*]] = add nuw nsw i32 %[[VAL_279]], %[[VAL_278]]
-// CHECK:         %[[VAL_281:.*]] = icmp ult i32 %[[VAL_280]], 163840
+// CHECK:         %[[VAL_281:.*]] = icmp ult i32 %[[VAL_280]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_281]])
 // CHECK:         %[[VAL_282:.*]] = mul nuw nsw i32 %[[VAL_280]], 4
 // CHECK:         %[[VAL_283:.*]] = udiv i32 %[[VAL_282]], 1
@@ -333,28 +333,28 @@
 // CHECK:       r5.in_bounds-after:                               ; preds = %[[VAL_299]], %[[VAL_301:.*]]
 // CHECK:         ret void
 // CHECK:       r5.in_bounds-true:                                ; preds = %[[VAL_301]]
-// CHECK:         %[[VAL_302:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
+// CHECK:         %[[VAL_302:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
 // CHECK:         %[[VAL_303:.*]] = getelementptr inbounds float, float* %[[VAL_302]], i32 %[[VAL_282]]
 // CHECK:         %[[VAL_304:.*]] = load float, float* %[[VAL_303]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_305:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
+// CHECK:         %[[VAL_305:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
 // CHECK:         %[[VAL_306:.*]] = getelementptr inbounds float, float* %[[VAL_305]], i32 %[[VAL_282]]
 // CHECK:         store float %[[VAL_304]], float* %[[VAL_306]], align 4
-// CHECK:         %[[VAL_307:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
+// CHECK:         %[[VAL_307:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
 // CHECK:         %[[VAL_308:.*]] = getelementptr inbounds float, float* %[[VAL_307]], i32 %[[VAL_286]]
 // CHECK:         %[[VAL_309:.*]] = load float, float* %[[VAL_308]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_310:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
+// CHECK:         %[[VAL_310:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
 // CHECK:         %[[VAL_311:.*]] = getelementptr inbounds float, float* %[[VAL_310]], i32 %[[VAL_286]]
 // CHECK:         store float %[[VAL_309]], float* %[[VAL_311]], align 4
-// CHECK:         %[[VAL_312:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
+// CHECK:         %[[VAL_312:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
 // CHECK:         %[[VAL_313:.*]] = getelementptr inbounds float, float* %[[VAL_312]], i32 %[[VAL_290]]
 // CHECK:         %[[VAL_314:.*]] = load float, float* %[[VAL_313]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_315:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
+// CHECK:         %[[VAL_315:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
 // CHECK:         %[[VAL_316:.*]] = getelementptr inbounds float, float* %[[VAL_315]], i32 %[[VAL_290]]
 // CHECK:         store float %[[VAL_314]], float* %[[VAL_316]], align 4
-// CHECK:         %[[VAL_317:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
+// CHECK:         %[[VAL_317:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
 // CHECK:         %[[VAL_318:.*]] = getelementptr inbounds float, float* %[[VAL_317]], i32 %[[VAL_294]]
 // CHECK:         %[[VAL_319:.*]] = load float, float* %[[VAL_318]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_320:.*]] = bitcast [100 x [200 x float]]* %[[VAL_273]] to float*
+// CHECK:         %[[VAL_320:.*]] = bitcast [100 x [200 x float]]* %[[VAL_276]] to float*
 // CHECK:         %[[VAL_321:.*]] = getelementptr inbounds float, float* %[[VAL_320]], i32 %[[VAL_294]]
 // CHECK:         store float %[[VAL_319]], float* %[[VAL_321]], align 4
 // CHECK:         br label %[[VAL_300]]
@@ -364,10 +364,10 @@
 // CHECK:         %[[VAL_325:.*]] = getelementptr inbounds i8, i8* %[[VAL_326:.*]], i64 0
 // CHECK:         %[[VAL_327:.*]] = bitcast i8* %[[VAL_325]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_328:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
-// CHECK:         %[[VAL_329:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_330:.*]] = mul nuw nsw i32 %[[VAL_328]], 128
+// CHECK:         %[[VAL_329:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !93
+// CHECK:         %[[VAL_330:.*]] = mul nuw nsw i32 %[[VAL_328]], 1024
 // CHECK:         %[[VAL_331:.*]] = add nuw nsw i32 %[[VAL_330]], %[[VAL_329]]
-// CHECK:         %[[VAL_332:.*]] = icmp ult i32 %[[VAL_331]], 163840
+// CHECK:         %[[VAL_332:.*]] = icmp ult i32 %[[VAL_331]], 20480
 // CHECK:         call void @llvm.assume(i1 %[[VAL_332]])
 // CHECK:         %[[VAL_333:.*]] = udiv i32 %[[VAL_331]], 1
 // CHECK:         %[[VAL_334:.*]] = urem i32 %[[VAL_333]], 200
@@ -377,11 +377,11 @@
 // CHECK:       r7.in_bounds-after:                               ; preds = %[[VAL_337]], %[[VAL_339:.*]]
 // CHECK:         ret void
 // CHECK:       r7.in_bounds-true:                                ; preds = %[[VAL_339]]
-// CHECK:         %[[VAL_340:.*]] = bitcast [100 x [200 x float]]* %[[VAL_327]] to float*
+// CHECK:         %[[VAL_340:.*]] = bitcast [100 x [200 x float]]* %[[VAL_324]] to float*
 // CHECK:         %[[VAL_341:.*]] = getelementptr inbounds float, float* %[[VAL_340]], i32 %[[VAL_331]]
 // CHECK:         %[[VAL_342:.*]] = load float, float* %[[VAL_341]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_343:.*]] = call float @__nv_cosf(float %[[VAL_342]])
-// CHECK:         %[[VAL_344:.*]] = bitcast [100 x [200 x float]]* %[[VAL_324]] to float*
+// CHECK:         %[[VAL_344:.*]] = bitcast [100 x [200 x float]]* %[[VAL_327]] to float*
 // CHECK:         %[[VAL_345:.*]] = getelementptr inbounds float, float* %[[VAL_344]], i32 %[[VAL_331]]
 // CHECK:         store float %[[VAL_343]], float* %[[VAL_345]], align 4
 // CHECK:         br label %[[VAL_338]]
@@ -392,9 +392,9 @@
 // CHECK:         %[[VAL_351:.*]] = bitcast i8* %[[VAL_349]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_352:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_353:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_354:.*]] = mul nuw nsw i32 %[[VAL_352]], 128
+// CHECK:         %[[VAL_354:.*]] = mul nuw nsw i32 %[[VAL_352]], 256
 // CHECK:         %[[VAL_355:.*]] = add nuw nsw i32 %[[VAL_354]], %[[VAL_353]]
-// CHECK:         %[[VAL_356:.*]] = icmp ult i32 %[[VAL_355]], 163840
+// CHECK:         %[[VAL_356:.*]] = icmp ult i32 %[[VAL_355]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_356]])
 // CHECK:         %[[VAL_357:.*]] = mul nuw nsw i32 %[[VAL_355]], 4
 // CHECK:         %[[VAL_358:.*]] = udiv i32 %[[VAL_357]], 1
@@ -417,32 +417,32 @@
 // CHECK:       r8.in_bounds-after:                               ; preds = %[[VAL_374]], %[[VAL_376:.*]]
 // CHECK:         ret void
 // CHECK:       r8.in_bounds-true:                                ; preds = %[[VAL_376]]
-// CHECK:         %[[VAL_377:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
+// CHECK:         %[[VAL_377:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
 // CHECK:         %[[VAL_378:.*]] = getelementptr inbounds float, float* %[[VAL_377]], i32 %[[VAL_357]]
 // CHECK:         %[[VAL_379:.*]] = load float, float* %[[VAL_378]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_380:.*]] = call float @__nv_expf(float %[[VAL_379]])
-// CHECK:         %[[VAL_381:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
+// CHECK:         %[[VAL_381:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
 // CHECK:         %[[VAL_382:.*]] = getelementptr inbounds float, float* %[[VAL_381]], i32 %[[VAL_357]]
 // CHECK:         store float %[[VAL_380]], float* %[[VAL_382]], align 4
-// CHECK:         %[[VAL_383:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
+// CHECK:         %[[VAL_383:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
 // CHECK:         %[[VAL_384:.*]] = getelementptr inbounds float, float* %[[VAL_383]], i32 %[[VAL_361]]
 // CHECK:         %[[VAL_385:.*]] = load float, float* %[[VAL_384]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_386:.*]] = call float @__nv_expf(float %[[VAL_385]])
-// CHECK:         %[[VAL_387:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
+// CHECK:         %[[VAL_387:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
 // CHECK:         %[[VAL_388:.*]] = getelementptr inbounds float, float* %[[VAL_387]], i32 %[[VAL_361]]
 // CHECK:         store float %[[VAL_386]], float* %[[VAL_388]], align 4
-// CHECK:         %[[VAL_389:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
+// CHECK:         %[[VAL_389:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
 // CHECK:         %[[VAL_390:.*]] = getelementptr inbounds float, float* %[[VAL_389]], i32 %[[VAL_365]]
 // CHECK:         %[[VAL_391:.*]] = load float, float* %[[VAL_390]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_392:.*]] = call float @__nv_expf(float %[[VAL_391]])
-// CHECK:         %[[VAL_393:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
+// CHECK:         %[[VAL_393:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
 // CHECK:         %[[VAL_394:.*]] = getelementptr inbounds float, float* %[[VAL_393]], i32 %[[VAL_365]]
 // CHECK:         store float %[[VAL_392]], float* %[[VAL_394]], align 4
-// CHECK:         %[[VAL_395:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
+// CHECK:         %[[VAL_395:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
 // CHECK:         %[[VAL_396:.*]] = getelementptr inbounds float, float* %[[VAL_395]], i32 %[[VAL_369]]
 // CHECK:         %[[VAL_397:.*]] = load float, float* %[[VAL_396]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_398:.*]] = call float @__nv_expf(float %[[VAL_397]])
-// CHECK:         %[[VAL_399:.*]] = bitcast [100 x [200 x float]]* %[[VAL_348]] to float*
+// CHECK:         %[[VAL_399:.*]] = bitcast [100 x [200 x float]]* %[[VAL_351]] to float*
 // CHECK:         %[[VAL_400:.*]] = getelementptr inbounds float, float* %[[VAL_399]], i32 %[[VAL_369]]
 // CHECK:         store float %[[VAL_398]], float* %[[VAL_400]], align 4
 // CHECK:         br label %[[VAL_375]]
@@ -453,9 +453,9 @@
 // CHECK:         %[[VAL_406:.*]] = bitcast i8* %[[VAL_404]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_407:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_408:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_409:.*]] = mul nuw nsw i32 %[[VAL_407]], 128
+// CHECK:         %[[VAL_409:.*]] = mul nuw nsw i32 %[[VAL_407]], 256
 // CHECK:         %[[VAL_410:.*]] = add nuw nsw i32 %[[VAL_409]], %[[VAL_408]]
-// CHECK:         %[[VAL_411:.*]] = icmp ult i32 %[[VAL_410]], 163840
+// CHECK:         %[[VAL_411:.*]] = icmp ult i32 %[[VAL_410]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_411]])
 // CHECK:         %[[VAL_412:.*]] = mul nuw nsw i32 %[[VAL_410]], 4
 // CHECK:         %[[VAL_413:.*]] = udiv i32 %[[VAL_412]], 1
@@ -478,32 +478,32 @@
 // CHECK:       r9.in_bounds-after:                               ; preds = %[[VAL_429]], %[[VAL_431:.*]]
 // CHECK:         ret void
 // CHECK:       r9.in_bounds-true:                                ; preds = %[[VAL_431]]
-// CHECK:         %[[VAL_432:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
+// CHECK:         %[[VAL_432:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
 // CHECK:         %[[VAL_433:.*]] = getelementptr inbounds float, float* %[[VAL_432]], i32 %[[VAL_412]]
 // CHECK:         %[[VAL_434:.*]] = load float, float* %[[VAL_433]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_435:.*]] = call float @__nv_expm1f(float %[[VAL_434]])
-// CHECK:         %[[VAL_436:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
+// CHECK:         %[[VAL_436:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
 // CHECK:         %[[VAL_437:.*]] = getelementptr inbounds float, float* %[[VAL_436]], i32 %[[VAL_412]]
 // CHECK:         store float %[[VAL_435]], float* %[[VAL_437]], align 4
-// CHECK:         %[[VAL_438:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
+// CHECK:         %[[VAL_438:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
 // CHECK:         %[[VAL_439:.*]] = getelementptr inbounds float, float* %[[VAL_438]], i32 %[[VAL_416]]
 // CHECK:         %[[VAL_440:.*]] = load float, float* %[[VAL_439]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_441:.*]] = call float @__nv_expm1f(float %[[VAL_440]])
-// CHECK:         %[[VAL_442:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
+// CHECK:         %[[VAL_442:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
 // CHECK:         %[[VAL_443:.*]] = getelementptr inbounds float, float* %[[VAL_442]], i32 %[[VAL_416]]
 // CHECK:         store float %[[VAL_441]], float* %[[VAL_443]], align 4
-// CHECK:         %[[VAL_444:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
+// CHECK:         %[[VAL_444:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
 // CHECK:         %[[VAL_445:.*]] = getelementptr inbounds float, float* %[[VAL_444]], i32 %[[VAL_420]]
 // CHECK:         %[[VAL_446:.*]] = load float, float* %[[VAL_445]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_447:.*]] = call float @__nv_expm1f(float %[[VAL_446]])
-// CHECK:         %[[VAL_448:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
+// CHECK:         %[[VAL_448:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
 // CHECK:         %[[VAL_449:.*]] = getelementptr inbounds float, float* %[[VAL_448]], i32 %[[VAL_420]]
 // CHECK:         store float %[[VAL_447]], float* %[[VAL_449]], align 4
-// CHECK:         %[[VAL_450:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
+// CHECK:         %[[VAL_450:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
 // CHECK:         %[[VAL_451:.*]] = getelementptr inbounds float, float* %[[VAL_450]], i32 %[[VAL_424]]
 // CHECK:         %[[VAL_452:.*]] = load float, float* %[[VAL_451]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_453:.*]] = call float @__nv_expm1f(float %[[VAL_452]])
-// CHECK:         %[[VAL_454:.*]] = bitcast [100 x [200 x float]]* %[[VAL_403]] to float*
+// CHECK:         %[[VAL_454:.*]] = bitcast [100 x [200 x float]]* %[[VAL_406]] to float*
 // CHECK:         %[[VAL_455:.*]] = getelementptr inbounds float, float* %[[VAL_454]], i32 %[[VAL_424]]
 // CHECK:         store float %[[VAL_453]], float* %[[VAL_455]], align 4
 // CHECK:         br label %[[VAL_430]]
@@ -514,9 +514,9 @@
 // CHECK:         %[[VAL_461:.*]] = bitcast i8* %[[VAL_459]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_462:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_463:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_464:.*]] = mul nuw nsw i32 %[[VAL_462]], 128
+// CHECK:         %[[VAL_464:.*]] = mul nuw nsw i32 %[[VAL_462]], 256
 // CHECK:         %[[VAL_465:.*]] = add nuw nsw i32 %[[VAL_464]], %[[VAL_463]]
-// CHECK:         %[[VAL_466:.*]] = icmp ult i32 %[[VAL_465]], 163840
+// CHECK:         %[[VAL_466:.*]] = icmp ult i32 %[[VAL_465]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_466]])
 // CHECK:         %[[VAL_467:.*]] = mul nuw nsw i32 %[[VAL_465]], 4
 // CHECK:         %[[VAL_468:.*]] = udiv i32 %[[VAL_467]], 1
@@ -539,45 +539,45 @@
 // CHECK:       r10.in_bounds-after:                              ; preds = %[[VAL_484]], %[[VAL_486:.*]]
 // CHECK:         ret void
 // CHECK:       r10.in_bounds-true:                               ; preds = %[[VAL_486]]
-// CHECK:         %[[VAL_487:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
+// CHECK:         %[[VAL_487:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
 // CHECK:         %[[VAL_488:.*]] = getelementptr inbounds float, float* %[[VAL_487]], i32 %[[VAL_467]]
 // CHECK:         %[[VAL_489:.*]] = load float, float* %[[VAL_488]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_490:.*]] = call float @llvm.floor.f32(float %[[VAL_489]])
-// CHECK:         %[[VAL_491:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
+// CHECK:         %[[VAL_491:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
 // CHECK:         %[[VAL_492:.*]] = getelementptr inbounds float, float* %[[VAL_491]], i32 %[[VAL_467]]
 // CHECK:         store float %[[VAL_490]], float* %[[VAL_492]], align 4
-// CHECK:         %[[VAL_493:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
+// CHECK:         %[[VAL_493:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
 // CHECK:         %[[VAL_494:.*]] = getelementptr inbounds float, float* %[[VAL_493]], i32 %[[VAL_471]]
 // CHECK:         %[[VAL_495:.*]] = load float, float* %[[VAL_494]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_496:.*]] = call float @llvm.floor.f32(float %[[VAL_495]])
-// CHECK:         %[[VAL_497:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
+// CHECK:         %[[VAL_497:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
 // CHECK:         %[[VAL_498:.*]] = getelementptr inbounds float, float* %[[VAL_497]], i32 %[[VAL_471]]
 // CHECK:         store float %[[VAL_496]], float* %[[VAL_498]], align 4
-// CHECK:         %[[VAL_499:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
+// CHECK:         %[[VAL_499:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
 // CHECK:         %[[VAL_500:.*]] = getelementptr inbounds float, float* %[[VAL_499]], i32 %[[VAL_475]]
 // CHECK:         %[[VAL_501:.*]] = load float, float* %[[VAL_500]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_502:.*]] = call float @llvm.floor.f32(float %[[VAL_501]])
-// CHECK:         %[[VAL_503:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
+// CHECK:         %[[VAL_503:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
 // CHECK:         %[[VAL_504:.*]] = getelementptr inbounds float, float* %[[VAL_503]], i32 %[[VAL_475]]
 // CHECK:         store float %[[VAL_502]], float* %[[VAL_504]], align 4
-// CHECK:         %[[VAL_505:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
+// CHECK:         %[[VAL_505:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
 // CHECK:         %[[VAL_506:.*]] = getelementptr inbounds float, float* %[[VAL_505]], i32 %[[VAL_479]]
 // CHECK:         %[[VAL_507:.*]] = load float, float* %[[VAL_506]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_508:.*]] = call float @llvm.floor.f32(float %[[VAL_507]])
-// CHECK:         %[[VAL_509:.*]] = bitcast [100 x [200 x float]]* %[[VAL_458]] to float*
+// CHECK:         %[[VAL_509:.*]] = bitcast [100 x [200 x float]]* %[[VAL_461]] to float*
 // CHECK:         %[[VAL_510:.*]] = getelementptr inbounds float, float* %[[VAL_509]], i32 %[[VAL_479]]
 // CHECK:         store float %[[VAL_508]], float* %[[VAL_510]], align 4
 // CHECK:         br label %[[VAL_485]]
 // CHECK:       entry:
 // CHECK:         %[[VAL_511:.*]] = getelementptr inbounds i8, i8* %[[VAL_512:.*]], i64 0
-// CHECK:         %[[VAL_513:.*]] = bitcast i8* %[[VAL_511]] to [100 x [200 x float]]*
-// CHECK:         %[[VAL_514:.*]] = getelementptr inbounds i8, i8* %[[VAL_515:.*]], i64 0
-// CHECK:         %[[VAL_516:.*]] = bitcast i8* %[[VAL_514]] to [100 x [200 x %[[VAL_517:.*]]]]*
+// CHECK:         %[[VAL_513:.*]] = bitcast i8* %[[VAL_511]] to [100 x [200 x %[[VAL_514:.*]]]]*
+// CHECK:         %[[VAL_515:.*]] = getelementptr inbounds i8, i8* %[[VAL_516:.*]], i64 0
+// CHECK:         %[[VAL_517:.*]] = bitcast i8* %[[VAL_515]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_518:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_519:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_520:.*]] = mul nuw nsw i32 %[[VAL_518]], 128
+// CHECK:         %[[VAL_520:.*]] = mul nuw nsw i32 %[[VAL_518]], 256
 // CHECK:         %[[VAL_521:.*]] = add nuw nsw i32 %[[VAL_520]], %[[VAL_519]]
-// CHECK:         %[[VAL_522:.*]] = icmp ult i32 %[[VAL_521]], 163840
+// CHECK:         %[[VAL_522:.*]] = icmp ult i32 %[[VAL_521]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_522]])
 // CHECK:         %[[VAL_523:.*]] = mul nuw nsw i32 %[[VAL_521]], 4
 // CHECK:         %[[VAL_524:.*]] = udiv i32 %[[VAL_523]], 1
@@ -600,45 +600,45 @@
 // CHECK:       r11.in_bounds-after:                              ; preds = %[[VAL_540]], %[[VAL_542:.*]]
 // CHECK:         ret void
 // CHECK:       r11.in_bounds-true:                               ; preds = %[[VAL_542]]
-// CHECK:         %[[VAL_543:.*]] = bitcast [100 x [200 x %[[VAL_517]]]]* %[[VAL_516]] to %[[VAL_517]]*
-// CHECK:         %[[VAL_544:.*]] = getelementptr inbounds %[[VAL_517]], %[[VAL_517]]* %[[VAL_543]], i32 %[[VAL_523]]
-// CHECK:         %[[VAL_545:.*]] = load %[[VAL_517]], %[[VAL_517]]* %[[VAL_544]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_546:.*]] = extractvalue %[[VAL_517]] %[[VAL_545]], 1
-// CHECK:         %[[VAL_547:.*]] = bitcast [100 x [200 x float]]* %[[VAL_513]] to float*
+// CHECK:         %[[VAL_543:.*]] = bitcast [100 x [200 x %[[VAL_514]]]]* %[[VAL_513]] to %[[VAL_514]]*
+// CHECK:         %[[VAL_544:.*]] = getelementptr inbounds %[[VAL_514]], %[[VAL_514]]* %[[VAL_543]], i32 %[[VAL_523]]
+// CHECK:         %[[VAL_545:.*]] = load %[[VAL_514]], %[[VAL_514]]* %[[VAL_544]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_546:.*]] = extractvalue %[[VAL_514]] %[[VAL_545]], 1
+// CHECK:         %[[VAL_547:.*]] = bitcast [100 x [200 x float]]* %[[VAL_517]] to float*
 // CHECK:         %[[VAL_548:.*]] = getelementptr inbounds float, float* %[[VAL_547]], i32 %[[VAL_523]]
 // CHECK:         store float %[[VAL_546]], float* %[[VAL_548]], align 4
-// CHECK:         %[[VAL_549:.*]] = bitcast [100 x [200 x %[[VAL_517]]]]* %[[VAL_516]] to %[[VAL_517]]*
-// CHECK:         %[[VAL_550:.*]] = getelementptr inbounds %[[VAL_517]], %[[VAL_517]]* %[[VAL_549]], i32 %[[VAL_527]]
-// CHECK:         %[[VAL_551:.*]] = load %[[VAL_517]], %[[VAL_517]]* %[[VAL_550]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_552:.*]] = extractvalue %[[VAL_517]] %[[VAL_551]], 1
-// CHECK:         %[[VAL_553:.*]] = bitcast [100 x [200 x float]]* %[[VAL_513]] to float*
+// CHECK:         %[[VAL_549:.*]] = bitcast [100 x [200 x %[[VAL_514]]]]* %[[VAL_513]] to %[[VAL_514]]*
+// CHECK:         %[[VAL_550:.*]] = getelementptr inbounds %[[VAL_514]], %[[VAL_514]]* %[[VAL_549]], i32 %[[VAL_527]]
+// CHECK:         %[[VAL_551:.*]] = load %[[VAL_514]], %[[VAL_514]]* %[[VAL_550]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_552:.*]] = extractvalue %[[VAL_514]] %[[VAL_551]], 1
+// CHECK:         %[[VAL_553:.*]] = bitcast [100 x [200 x float]]* %[[VAL_517]] to float*
 // CHECK:         %[[VAL_554:.*]] = getelementptr inbounds float, float* %[[VAL_553]], i32 %[[VAL_527]]
 // CHECK:         store float %[[VAL_552]], float* %[[VAL_554]], align 4
-// CHECK:         %[[VAL_555:.*]] = bitcast [100 x [200 x %[[VAL_517]]]]* %[[VAL_516]] to %[[VAL_517]]*
-// CHECK:         %[[VAL_556:.*]] = getelementptr inbounds %[[VAL_517]], %[[VAL_517]]* %[[VAL_555]], i32 %[[VAL_531]]
-// CHECK:         %[[VAL_557:.*]] = load %[[VAL_517]], %[[VAL_517]]* %[[VAL_556]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_558:.*]] = extractvalue %[[VAL_517]] %[[VAL_557]], 1
-// CHECK:         %[[VAL_559:.*]] = bitcast [100 x [200 x float]]* %[[VAL_513]] to float*
+// CHECK:         %[[VAL_555:.*]] = bitcast [100 x [200 x %[[VAL_514]]]]* %[[VAL_513]] to %[[VAL_514]]*
+// CHECK:         %[[VAL_556:.*]] = getelementptr inbounds %[[VAL_514]], %[[VAL_514]]* %[[VAL_555]], i32 %[[VAL_531]]
+// CHECK:         %[[VAL_557:.*]] = load %[[VAL_514]], %[[VAL_514]]* %[[VAL_556]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_558:.*]] = extractvalue %[[VAL_514]] %[[VAL_557]], 1
+// CHECK:         %[[VAL_559:.*]] = bitcast [100 x [200 x float]]* %[[VAL_517]] to float*
 // CHECK:         %[[VAL_560:.*]] = getelementptr inbounds float, float* %[[VAL_559]], i32 %[[VAL_531]]
 // CHECK:         store float %[[VAL_558]], float* %[[VAL_560]], align 4
-// CHECK:         %[[VAL_561:.*]] = bitcast [100 x [200 x %[[VAL_517]]]]* %[[VAL_516]] to %[[VAL_517]]*
-// CHECK:         %[[VAL_562:.*]] = getelementptr inbounds %[[VAL_517]], %[[VAL_517]]* %[[VAL_561]], i32 %[[VAL_535]]
-// CHECK:         %[[VAL_563:.*]] = load %[[VAL_517]], %[[VAL_517]]* %[[VAL_562]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_564:.*]] = extractvalue %[[VAL_517]] %[[VAL_563]], 1
-// CHECK:         %[[VAL_565:.*]] = bitcast [100 x [200 x float]]* %[[VAL_513]] to float*
+// CHECK:         %[[VAL_561:.*]] = bitcast [100 x [200 x %[[VAL_514]]]]* %[[VAL_513]] to %[[VAL_514]]*
+// CHECK:         %[[VAL_562:.*]] = getelementptr inbounds %[[VAL_514]], %[[VAL_514]]* %[[VAL_561]], i32 %[[VAL_535]]
+// CHECK:         %[[VAL_563:.*]] = load %[[VAL_514]], %[[VAL_514]]* %[[VAL_562]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_564:.*]] = extractvalue %[[VAL_514]] %[[VAL_563]], 1
+// CHECK:         %[[VAL_565:.*]] = bitcast [100 x [200 x float]]* %[[VAL_517]] to float*
 // CHECK:         %[[VAL_566:.*]] = getelementptr inbounds float, float* %[[VAL_565]], i32 %[[VAL_535]]
 // CHECK:         store float %[[VAL_564]], float* %[[VAL_566]], align 4
 // CHECK:         br label %[[VAL_541]]
 // CHECK:       entry:
 // CHECK:         %[[VAL_567:.*]] = getelementptr inbounds i8, i8* %[[VAL_568:.*]], i64 0
-// CHECK:         %[[VAL_569:.*]] = bitcast i8* %[[VAL_567]] to [100 x [200 x i8]]*
+// CHECK:         %[[VAL_569:.*]] = bitcast i8* %[[VAL_567]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_570:.*]] = getelementptr inbounds i8, i8* %[[VAL_571:.*]], i64 0
-// CHECK:         %[[VAL_572:.*]] = bitcast i8* %[[VAL_570]] to [100 x [200 x float]]*
+// CHECK:         %[[VAL_572:.*]] = bitcast i8* %[[VAL_570]] to [100 x [200 x i8]]*
 // CHECK:         %[[VAL_573:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_574:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_575:.*]] = mul nuw nsw i32 %[[VAL_573]], 128
+// CHECK:         %[[VAL_575:.*]] = mul nuw nsw i32 %[[VAL_573]], 256
 // CHECK:         %[[VAL_576:.*]] = add nuw nsw i32 %[[VAL_575]], %[[VAL_574]]
-// CHECK:         %[[VAL_577:.*]] = icmp ult i32 %[[VAL_576]], 163840
+// CHECK:         %[[VAL_577:.*]] = icmp ult i32 %[[VAL_576]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_577]])
 // CHECK:         %[[VAL_578:.*]] = mul nuw nsw i32 %[[VAL_576]], 4
 // CHECK:         %[[VAL_579:.*]] = udiv i32 %[[VAL_578]], 1
@@ -661,40 +661,40 @@
 // CHECK:       r12.in_bounds-after:                              ; preds = %[[VAL_595]], %[[VAL_597:.*]]
 // CHECK:         ret void
 // CHECK:       r12.in_bounds-true:                               ; preds = %[[VAL_597]]
-// CHECK:         %[[VAL_598:.*]] = bitcast [100 x [200 x float]]* %[[VAL_572]] to float*
+// CHECK:         %[[VAL_598:.*]] = bitcast [100 x [200 x float]]* %[[VAL_569]] to float*
 // CHECK:         %[[VAL_599:.*]] = getelementptr inbounds float, float* %[[VAL_598]], i32 %[[VAL_578]]
 // CHECK:         %[[VAL_600:.*]] = load float, float* %[[VAL_599]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_601:.*]] = call float @llvm.fabs.f32(float %[[VAL_600]])
 // CHECK:         %[[VAL_602:.*]] = fcmp one float %[[VAL_601]], 0x7FF0000000000000
 // CHECK:         %[[VAL_603:.*]] = zext i1 %[[VAL_602]] to i8
-// CHECK:         %[[VAL_604:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_569]] to i8*
+// CHECK:         %[[VAL_604:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_572]] to i8*
 // CHECK:         %[[VAL_605:.*]] = getelementptr inbounds i8, i8* %[[VAL_604]], i32 %[[VAL_578]]
 // CHECK:         store i8 %[[VAL_603]], i8* %[[VAL_605]], align 1
-// CHECK:         %[[VAL_606:.*]] = bitcast [100 x [200 x float]]* %[[VAL_572]] to float*
+// CHECK:         %[[VAL_606:.*]] = bitcast [100 x [200 x float]]* %[[VAL_569]] to float*
 // CHECK:         %[[VAL_607:.*]] = getelementptr inbounds float, float* %[[VAL_606]], i32 %[[VAL_582]]
 // CHECK:         %[[VAL_608:.*]] = load float, float* %[[VAL_607]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_609:.*]] = call float @llvm.fabs.f32(float %[[VAL_608]])
 // CHECK:         %[[VAL_610:.*]] = fcmp one float %[[VAL_609]], 0x7FF0000000000000
 // CHECK:         %[[VAL_611:.*]] = zext i1 %[[VAL_610]] to i8
-// CHECK:         %[[VAL_612:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_569]] to i8*
+// CHECK:         %[[VAL_612:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_572]] to i8*
 // CHECK:         %[[VAL_613:.*]] = getelementptr inbounds i8, i8* %[[VAL_612]], i32 %[[VAL_582]]
 // CHECK:         store i8 %[[VAL_611]], i8* %[[VAL_613]], align 1
-// CHECK:         %[[VAL_614:.*]] = bitcast [100 x [200 x float]]* %[[VAL_572]] to float*
+// CHECK:         %[[VAL_614:.*]] = bitcast [100 x [200 x float]]* %[[VAL_569]] to float*
 // CHECK:         %[[VAL_615:.*]] = getelementptr inbounds float, float* %[[VAL_614]], i32 %[[VAL_586]]
 // CHECK:         %[[VAL_616:.*]] = load float, float* %[[VAL_615]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_617:.*]] = call float @llvm.fabs.f32(float %[[VAL_616]])
 // CHECK:         %[[VAL_618:.*]] = fcmp one float %[[VAL_617]], 0x7FF0000000000000
 // CHECK:         %[[VAL_619:.*]] = zext i1 %[[VAL_618]] to i8
-// CHECK:         %[[VAL_620:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_569]] to i8*
+// CHECK:         %[[VAL_620:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_572]] to i8*
 // CHECK:         %[[VAL_621:.*]] = getelementptr inbounds i8, i8* %[[VAL_620]], i32 %[[VAL_586]]
 // CHECK:         store i8 %[[VAL_619]], i8* %[[VAL_621]], align 1
-// CHECK:         %[[VAL_622:.*]] = bitcast [100 x [200 x float]]* %[[VAL_572]] to float*
+// CHECK:         %[[VAL_622:.*]] = bitcast [100 x [200 x float]]* %[[VAL_569]] to float*
 // CHECK:         %[[VAL_623:.*]] = getelementptr inbounds float, float* %[[VAL_622]], i32 %[[VAL_590]]
 // CHECK:         %[[VAL_624:.*]] = load float, float* %[[VAL_623]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_625:.*]] = call float @llvm.fabs.f32(float %[[VAL_624]])
 // CHECK:         %[[VAL_626:.*]] = fcmp one float %[[VAL_625]], 0x7FF0000000000000
 // CHECK:         %[[VAL_627:.*]] = zext i1 %[[VAL_626]] to i8
-// CHECK:         %[[VAL_628:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_569]] to i8*
+// CHECK:         %[[VAL_628:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_572]] to i8*
 // CHECK:         %[[VAL_629:.*]] = getelementptr inbounds i8, i8* %[[VAL_628]], i32 %[[VAL_590]]
 // CHECK:         store i8 %[[VAL_627]], i8* %[[VAL_629]], align 1
 // CHECK:         br label %[[VAL_596]]
@@ -705,9 +705,9 @@
 // CHECK:         %[[VAL_635:.*]] = bitcast i8* %[[VAL_633]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_636:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_637:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_638:.*]] = mul nuw nsw i32 %[[VAL_636]], 128
+// CHECK:         %[[VAL_638:.*]] = mul nuw nsw i32 %[[VAL_636]], 256
 // CHECK:         %[[VAL_639:.*]] = add nuw nsw i32 %[[VAL_638]], %[[VAL_637]]
-// CHECK:         %[[VAL_640:.*]] = icmp ult i32 %[[VAL_639]], 163840
+// CHECK:         %[[VAL_640:.*]] = icmp ult i32 %[[VAL_639]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_640]])
 // CHECK:         %[[VAL_641:.*]] = mul nuw nsw i32 %[[VAL_639]], 4
 // CHECK:         %[[VAL_642:.*]] = udiv i32 %[[VAL_641]], 1
@@ -730,32 +730,32 @@
 // CHECK:       r13.in_bounds-after:                              ; preds = %[[VAL_658]], %[[VAL_660:.*]]
 // CHECK:         ret void
 // CHECK:       r13.in_bounds-true:                               ; preds = %[[VAL_660]]
-// CHECK:         %[[VAL_661:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
+// CHECK:         %[[VAL_661:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
 // CHECK:         %[[VAL_662:.*]] = getelementptr inbounds float, float* %[[VAL_661]], i32 %[[VAL_641]]
 // CHECK:         %[[VAL_663:.*]] = load float, float* %[[VAL_662]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_664:.*]] = call float @__nv_logf(float %[[VAL_663]])
-// CHECK:         %[[VAL_665:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
+// CHECK:         %[[VAL_665:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
 // CHECK:         %[[VAL_666:.*]] = getelementptr inbounds float, float* %[[VAL_665]], i32 %[[VAL_641]]
 // CHECK:         store float %[[VAL_664]], float* %[[VAL_666]], align 4
-// CHECK:         %[[VAL_667:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
+// CHECK:         %[[VAL_667:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
 // CHECK:         %[[VAL_668:.*]] = getelementptr inbounds float, float* %[[VAL_667]], i32 %[[VAL_645]]
 // CHECK:         %[[VAL_669:.*]] = load float, float* %[[VAL_668]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_670:.*]] = call float @__nv_logf(float %[[VAL_669]])
-// CHECK:         %[[VAL_671:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
+// CHECK:         %[[VAL_671:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
 // CHECK:         %[[VAL_672:.*]] = getelementptr inbounds float, float* %[[VAL_671]], i32 %[[VAL_645]]
 // CHECK:         store float %[[VAL_670]], float* %[[VAL_672]], align 4
-// CHECK:         %[[VAL_673:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
+// CHECK:         %[[VAL_673:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
 // CHECK:         %[[VAL_674:.*]] = getelementptr inbounds float, float* %[[VAL_673]], i32 %[[VAL_649]]
 // CHECK:         %[[VAL_675:.*]] = load float, float* %[[VAL_674]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_676:.*]] = call float @__nv_logf(float %[[VAL_675]])
-// CHECK:         %[[VAL_677:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
+// CHECK:         %[[VAL_677:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
 // CHECK:         %[[VAL_678:.*]] = getelementptr inbounds float, float* %[[VAL_677]], i32 %[[VAL_649]]
 // CHECK:         store float %[[VAL_676]], float* %[[VAL_678]], align 4
-// CHECK:         %[[VAL_679:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
+// CHECK:         %[[VAL_679:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
 // CHECK:         %[[VAL_680:.*]] = getelementptr inbounds float, float* %[[VAL_679]], i32 %[[VAL_653]]
 // CHECK:         %[[VAL_681:.*]] = load float, float* %[[VAL_680]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_682:.*]] = call float @__nv_logf(float %[[VAL_681]])
-// CHECK:         %[[VAL_683:.*]] = bitcast [100 x [200 x float]]* %[[VAL_632]] to float*
+// CHECK:         %[[VAL_683:.*]] = bitcast [100 x [200 x float]]* %[[VAL_635]] to float*
 // CHECK:         %[[VAL_684:.*]] = getelementptr inbounds float, float* %[[VAL_683]], i32 %[[VAL_653]]
 // CHECK:         store float %[[VAL_682]], float* %[[VAL_684]], align 4
 // CHECK:         br label %[[VAL_659]]
@@ -766,9 +766,9 @@
 // CHECK:         %[[VAL_690:.*]] = bitcast i8* %[[VAL_688]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_691:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_692:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_693:.*]] = mul nuw nsw i32 %[[VAL_691]], 128
+// CHECK:         %[[VAL_693:.*]] = mul nuw nsw i32 %[[VAL_691]], 256
 // CHECK:         %[[VAL_694:.*]] = add nuw nsw i32 %[[VAL_693]], %[[VAL_692]]
-// CHECK:         %[[VAL_695:.*]] = icmp ult i32 %[[VAL_694]], 163840
+// CHECK:         %[[VAL_695:.*]] = icmp ult i32 %[[VAL_694]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_695]])
 // CHECK:         %[[VAL_696:.*]] = mul nuw nsw i32 %[[VAL_694]], 4
 // CHECK:         %[[VAL_697:.*]] = udiv i32 %[[VAL_696]], 1
@@ -791,32 +791,32 @@
 // CHECK:       r14.in_bounds-after:                              ; preds = %[[VAL_713]], %[[VAL_715:.*]]
 // CHECK:         ret void
 // CHECK:       r14.in_bounds-true:                               ; preds = %[[VAL_715]]
-// CHECK:         %[[VAL_716:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
+// CHECK:         %[[VAL_716:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
 // CHECK:         %[[VAL_717:.*]] = getelementptr inbounds float, float* %[[VAL_716]], i32 %[[VAL_696]]
 // CHECK:         %[[VAL_718:.*]] = load float, float* %[[VAL_717]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_719:.*]] = call float @__nv_log1pf(float %[[VAL_718]])
-// CHECK:         %[[VAL_720:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
+// CHECK:         %[[VAL_720:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
 // CHECK:         %[[VAL_721:.*]] = getelementptr inbounds float, float* %[[VAL_720]], i32 %[[VAL_696]]
 // CHECK:         store float %[[VAL_719]], float* %[[VAL_721]], align 4
-// CHECK:         %[[VAL_722:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
+// CHECK:         %[[VAL_722:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
 // CHECK:         %[[VAL_723:.*]] = getelementptr inbounds float, float* %[[VAL_722]], i32 %[[VAL_700]]
 // CHECK:         %[[VAL_724:.*]] = load float, float* %[[VAL_723]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_725:.*]] = call float @__nv_log1pf(float %[[VAL_724]])
-// CHECK:         %[[VAL_726:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
+// CHECK:         %[[VAL_726:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
 // CHECK:         %[[VAL_727:.*]] = getelementptr inbounds float, float* %[[VAL_726]], i32 %[[VAL_700]]
 // CHECK:         store float %[[VAL_725]], float* %[[VAL_727]], align 4
-// CHECK:         %[[VAL_728:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
+// CHECK:         %[[VAL_728:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
 // CHECK:         %[[VAL_729:.*]] = getelementptr inbounds float, float* %[[VAL_728]], i32 %[[VAL_704]]
 // CHECK:         %[[VAL_730:.*]] = load float, float* %[[VAL_729]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_731:.*]] = call float @__nv_log1pf(float %[[VAL_730]])
-// CHECK:         %[[VAL_732:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
+// CHECK:         %[[VAL_732:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
 // CHECK:         %[[VAL_733:.*]] = getelementptr inbounds float, float* %[[VAL_732]], i32 %[[VAL_704]]
 // CHECK:         store float %[[VAL_731]], float* %[[VAL_733]], align 4
-// CHECK:         %[[VAL_734:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
+// CHECK:         %[[VAL_734:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
 // CHECK:         %[[VAL_735:.*]] = getelementptr inbounds float, float* %[[VAL_734]], i32 %[[VAL_708]]
 // CHECK:         %[[VAL_736:.*]] = load float, float* %[[VAL_735]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_737:.*]] = call float @__nv_log1pf(float %[[VAL_736]])
-// CHECK:         %[[VAL_738:.*]] = bitcast [100 x [200 x float]]* %[[VAL_687]] to float*
+// CHECK:         %[[VAL_738:.*]] = bitcast [100 x [200 x float]]* %[[VAL_690]] to float*
 // CHECK:         %[[VAL_739:.*]] = getelementptr inbounds float, float* %[[VAL_738]], i32 %[[VAL_708]]
 // CHECK:         store float %[[VAL_737]], float* %[[VAL_739]], align 4
 // CHECK:         br label %[[VAL_714]]
@@ -827,9 +827,9 @@
 // CHECK:         %[[VAL_745:.*]] = bitcast i8* %[[VAL_743]] to [100 x [200 x i8]]*
 // CHECK:         %[[VAL_746:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_747:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_748:.*]] = mul nuw nsw i32 %[[VAL_746]], 128
+// CHECK:         %[[VAL_748:.*]] = mul nuw nsw i32 %[[VAL_746]], 256
 // CHECK:         %[[VAL_749:.*]] = add nuw nsw i32 %[[VAL_748]], %[[VAL_747]]
-// CHECK:         %[[VAL_750:.*]] = icmp ult i32 %[[VAL_749]], 163840
+// CHECK:         %[[VAL_750:.*]] = icmp ult i32 %[[VAL_749]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_750]])
 // CHECK:         %[[VAL_751:.*]] = mul nuw nsw i32 %[[VAL_749]], 4
 // CHECK:         %[[VAL_752:.*]] = udiv i32 %[[VAL_751]], 1
@@ -852,40 +852,40 @@
 // CHECK:       r15.in_bounds-after:                              ; preds = %[[VAL_768]], %[[VAL_770:.*]]
 // CHECK:         ret void
 // CHECK:       r15.in_bounds-true:                               ; preds = %[[VAL_770]]
-// CHECK:         %[[VAL_771:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
+// CHECK:         %[[VAL_771:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
 // CHECK:         %[[VAL_772:.*]] = getelementptr inbounds i8, i8* %[[VAL_771]], i32 %[[VAL_751]]
 // CHECK:         %[[VAL_773:.*]] = load i8, i8* %[[VAL_772]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_774:.*]] = trunc i8 %[[VAL_773]] to i1
 // CHECK:         %[[VAL_775:.*]] = xor i1 %[[VAL_774]], true
 // CHECK:         %[[VAL_776:.*]] = zext i1 %[[VAL_775]] to i8
-// CHECK:         %[[VAL_777:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
+// CHECK:         %[[VAL_777:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
 // CHECK:         %[[VAL_778:.*]] = getelementptr inbounds i8, i8* %[[VAL_777]], i32 %[[VAL_751]]
 // CHECK:         store i8 %[[VAL_776]], i8* %[[VAL_778]], align 1
-// CHECK:         %[[VAL_779:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
+// CHECK:         %[[VAL_779:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
 // CHECK:         %[[VAL_780:.*]] = getelementptr inbounds i8, i8* %[[VAL_779]], i32 %[[VAL_755]]
 // CHECK:         %[[VAL_781:.*]] = load i8, i8* %[[VAL_780]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_782:.*]] = trunc i8 %[[VAL_781]] to i1
 // CHECK:         %[[VAL_783:.*]] = xor i1 %[[VAL_782]], true
 // CHECK:         %[[VAL_784:.*]] = zext i1 %[[VAL_783]] to i8
-// CHECK:         %[[VAL_785:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
+// CHECK:         %[[VAL_785:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
 // CHECK:         %[[VAL_786:.*]] = getelementptr inbounds i8, i8* %[[VAL_785]], i32 %[[VAL_755]]
 // CHECK:         store i8 %[[VAL_784]], i8* %[[VAL_786]], align 1
-// CHECK:         %[[VAL_787:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
+// CHECK:         %[[VAL_787:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
 // CHECK:         %[[VAL_788:.*]] = getelementptr inbounds i8, i8* %[[VAL_787]], i32 %[[VAL_759]]
 // CHECK:         %[[VAL_789:.*]] = load i8, i8* %[[VAL_788]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_790:.*]] = trunc i8 %[[VAL_789]] to i1
 // CHECK:         %[[VAL_791:.*]] = xor i1 %[[VAL_790]], true
 // CHECK:         %[[VAL_792:.*]] = zext i1 %[[VAL_791]] to i8
-// CHECK:         %[[VAL_793:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
+// CHECK:         %[[VAL_793:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
 // CHECK:         %[[VAL_794:.*]] = getelementptr inbounds i8, i8* %[[VAL_793]], i32 %[[VAL_759]]
 // CHECK:         store i8 %[[VAL_792]], i8* %[[VAL_794]], align 1
-// CHECK:         %[[VAL_795:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
+// CHECK:         %[[VAL_795:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
 // CHECK:         %[[VAL_796:.*]] = getelementptr inbounds i8, i8* %[[VAL_795]], i32 %[[VAL_763]]
 // CHECK:         %[[VAL_797:.*]] = load i8, i8* %[[VAL_796]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_798:.*]] = trunc i8 %[[VAL_797]] to i1
 // CHECK:         %[[VAL_799:.*]] = xor i1 %[[VAL_798]], true
 // CHECK:         %[[VAL_800:.*]] = zext i1 %[[VAL_799]] to i8
-// CHECK:         %[[VAL_801:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_742]] to i8*
+// CHECK:         %[[VAL_801:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_745]] to i8*
 // CHECK:         %[[VAL_802:.*]] = getelementptr inbounds i8, i8* %[[VAL_801]], i32 %[[VAL_763]]
 // CHECK:         store i8 %[[VAL_800]], i8* %[[VAL_802]], align 1
 // CHECK:         br label %[[VAL_769]]
@@ -896,9 +896,9 @@
 // CHECK:         %[[VAL_808:.*]] = bitcast i8* %[[VAL_806]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_809:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_810:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_811:.*]] = mul nuw nsw i32 %[[VAL_809]], 128
+// CHECK:         %[[VAL_811:.*]] = mul nuw nsw i32 %[[VAL_809]], 256
 // CHECK:         %[[VAL_812:.*]] = add nuw nsw i32 %[[VAL_811]], %[[VAL_810]]
-// CHECK:         %[[VAL_813:.*]] = icmp ult i32 %[[VAL_812]], 163840
+// CHECK:         %[[VAL_813:.*]] = icmp ult i32 %[[VAL_812]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_813]])
 // CHECK:         %[[VAL_814:.*]] = mul nuw nsw i32 %[[VAL_812]], 4
 // CHECK:         %[[VAL_815:.*]] = udiv i32 %[[VAL_814]], 1
@@ -921,32 +921,32 @@
 // CHECK:       r16.in_bounds-after:                              ; preds = %[[VAL_831]], %[[VAL_833:.*]]
 // CHECK:         ret void
 // CHECK:       r16.in_bounds-true:                               ; preds = %[[VAL_833]]
-// CHECK:         %[[VAL_834:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
+// CHECK:         %[[VAL_834:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
 // CHECK:         %[[VAL_835:.*]] = getelementptr inbounds float, float* %[[VAL_834]], i32 %[[VAL_814]]
 // CHECK:         %[[VAL_836:.*]] = load float, float* %[[VAL_835]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_837:.*]] = fneg float %[[VAL_836]]
-// CHECK:         %[[VAL_838:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
+// CHECK:         %[[VAL_838:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
 // CHECK:         %[[VAL_839:.*]] = getelementptr inbounds float, float* %[[VAL_838]], i32 %[[VAL_814]]
 // CHECK:         store float %[[VAL_837]], float* %[[VAL_839]], align 4
-// CHECK:         %[[VAL_840:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
+// CHECK:         %[[VAL_840:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
 // CHECK:         %[[VAL_841:.*]] = getelementptr inbounds float, float* %[[VAL_840]], i32 %[[VAL_818]]
 // CHECK:         %[[VAL_842:.*]] = load float, float* %[[VAL_841]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_843:.*]] = fneg float %[[VAL_842]]
-// CHECK:         %[[VAL_844:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
+// CHECK:         %[[VAL_844:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
 // CHECK:         %[[VAL_845:.*]] = getelementptr inbounds float, float* %[[VAL_844]], i32 %[[VAL_818]]
 // CHECK:         store float %[[VAL_843]], float* %[[VAL_845]], align 4
-// CHECK:         %[[VAL_846:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
+// CHECK:         %[[VAL_846:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
 // CHECK:         %[[VAL_847:.*]] = getelementptr inbounds float, float* %[[VAL_846]], i32 %[[VAL_822]]
 // CHECK:         %[[VAL_848:.*]] = load float, float* %[[VAL_847]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_849:.*]] = fneg float %[[VAL_848]]
-// CHECK:         %[[VAL_850:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
+// CHECK:         %[[VAL_850:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
 // CHECK:         %[[VAL_851:.*]] = getelementptr inbounds float, float* %[[VAL_850]], i32 %[[VAL_822]]
 // CHECK:         store float %[[VAL_849]], float* %[[VAL_851]], align 4
-// CHECK:         %[[VAL_852:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
+// CHECK:         %[[VAL_852:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
 // CHECK:         %[[VAL_853:.*]] = getelementptr inbounds float, float* %[[VAL_852]], i32 %[[VAL_826]]
 // CHECK:         %[[VAL_854:.*]] = load float, float* %[[VAL_853]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_855:.*]] = fneg float %[[VAL_854]]
-// CHECK:         %[[VAL_856:.*]] = bitcast [100 x [200 x float]]* %[[VAL_805]] to float*
+// CHECK:         %[[VAL_856:.*]] = bitcast [100 x [200 x float]]* %[[VAL_808]] to float*
 // CHECK:         %[[VAL_857:.*]] = getelementptr inbounds float, float* %[[VAL_856]], i32 %[[VAL_826]]
 // CHECK:         store float %[[VAL_855]], float* %[[VAL_857]], align 4
 // CHECK:         br label %[[VAL_832]]
@@ -957,9 +957,9 @@
 // CHECK:         %[[VAL_863:.*]] = bitcast i8* %[[VAL_861]] to [100 x [200 x i32]]*
 // CHECK:         %[[VAL_864:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_865:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_866:.*]] = mul nuw nsw i32 %[[VAL_864]], 128
+// CHECK:         %[[VAL_866:.*]] = mul nuw nsw i32 %[[VAL_864]], 256
 // CHECK:         %[[VAL_867:.*]] = add nuw nsw i32 %[[VAL_866]], %[[VAL_865]]
-// CHECK:         %[[VAL_868:.*]] = icmp ult i32 %[[VAL_867]], 163840
+// CHECK:         %[[VAL_868:.*]] = icmp ult i32 %[[VAL_867]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_868]])
 // CHECK:         %[[VAL_869:.*]] = mul nuw nsw i32 %[[VAL_867]], 4
 // CHECK:         %[[VAL_870:.*]] = udiv i32 %[[VAL_869]], 1
@@ -982,45 +982,45 @@
 // CHECK:       r17.in_bounds-after:                              ; preds = %[[VAL_886]], %[[VAL_888:.*]]
 // CHECK:         ret void
 // CHECK:       r17.in_bounds-true:                               ; preds = %[[VAL_888]]
-// CHECK:         %[[VAL_889:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
+// CHECK:         %[[VAL_889:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
 // CHECK:         %[[VAL_890:.*]] = getelementptr inbounds i32, i32* %[[VAL_889]], i32 %[[VAL_869]]
 // CHECK:         %[[VAL_891:.*]] = load i32, i32* %[[VAL_890]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_892:.*]] = call i32 @llvm.ctpop.i32(i32 %[[VAL_891]])
-// CHECK:         %[[VAL_893:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
+// CHECK:         %[[VAL_893:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
 // CHECK:         %[[VAL_894:.*]] = getelementptr inbounds i32, i32* %[[VAL_893]], i32 %[[VAL_869]]
 // CHECK:         store i32 %[[VAL_892]], i32* %[[VAL_894]], align 4
-// CHECK:         %[[VAL_895:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
+// CHECK:         %[[VAL_895:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
 // CHECK:         %[[VAL_896:.*]] = getelementptr inbounds i32, i32* %[[VAL_895]], i32 %[[VAL_873]]
 // CHECK:         %[[VAL_897:.*]] = load i32, i32* %[[VAL_896]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_898:.*]] = call i32 @llvm.ctpop.i32(i32 %[[VAL_897]])
-// CHECK:         %[[VAL_899:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
+// CHECK:         %[[VAL_899:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
 // CHECK:         %[[VAL_900:.*]] = getelementptr inbounds i32, i32* %[[VAL_899]], i32 %[[VAL_873]]
 // CHECK:         store i32 %[[VAL_898]], i32* %[[VAL_900]], align 4
-// CHECK:         %[[VAL_901:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
+// CHECK:         %[[VAL_901:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
 // CHECK:         %[[VAL_902:.*]] = getelementptr inbounds i32, i32* %[[VAL_901]], i32 %[[VAL_877]]
 // CHECK:         %[[VAL_903:.*]] = load i32, i32* %[[VAL_902]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_904:.*]] = call i32 @llvm.ctpop.i32(i32 %[[VAL_903]])
-// CHECK:         %[[VAL_905:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
+// CHECK:         %[[VAL_905:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
 // CHECK:         %[[VAL_906:.*]] = getelementptr inbounds i32, i32* %[[VAL_905]], i32 %[[VAL_877]]
 // CHECK:         store i32 %[[VAL_904]], i32* %[[VAL_906]], align 4
-// CHECK:         %[[VAL_907:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
+// CHECK:         %[[VAL_907:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
 // CHECK:         %[[VAL_908:.*]] = getelementptr inbounds i32, i32* %[[VAL_907]], i32 %[[VAL_881]]
 // CHECK:         %[[VAL_909:.*]] = load i32, i32* %[[VAL_908]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_910:.*]] = call i32 @llvm.ctpop.i32(i32 %[[VAL_909]])
-// CHECK:         %[[VAL_911:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_860]] to i32*
+// CHECK:         %[[VAL_911:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_863]] to i32*
 // CHECK:         %[[VAL_912:.*]] = getelementptr inbounds i32, i32* %[[VAL_911]], i32 %[[VAL_881]]
 // CHECK:         store i32 %[[VAL_910]], i32* %[[VAL_912]], align 4
 // CHECK:         br label %[[VAL_887]]
 // CHECK:       entry:
 // CHECK:         %[[VAL_913:.*]] = getelementptr inbounds i8, i8* %[[VAL_914:.*]], i64 0
-// CHECK:         %[[VAL_915:.*]] = bitcast i8* %[[VAL_913]] to [100 x [200 x float]]*
-// CHECK:         %[[VAL_916:.*]] = getelementptr inbounds i8, i8* %[[VAL_917:.*]], i64 0
-// CHECK:         %[[VAL_918:.*]] = bitcast i8* %[[VAL_916]] to [100 x [200 x %[[VAL_919:.*]]]]*
+// CHECK:         %[[VAL_915:.*]] = bitcast i8* %[[VAL_913]] to [100 x [200 x %[[VAL_916:.*]]]]*
+// CHECK:         %[[VAL_917:.*]] = getelementptr inbounds i8, i8* %[[VAL_918:.*]], i64 0
+// CHECK:         %[[VAL_919:.*]] = bitcast i8* %[[VAL_917]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_920:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_921:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_922:.*]] = mul nuw nsw i32 %[[VAL_920]], 128
+// CHECK:         %[[VAL_922:.*]] = mul nuw nsw i32 %[[VAL_920]], 256
 // CHECK:         %[[VAL_923:.*]] = add nuw nsw i32 %[[VAL_922]], %[[VAL_921]]
-// CHECK:         %[[VAL_924:.*]] = icmp ult i32 %[[VAL_923]], 163840
+// CHECK:         %[[VAL_924:.*]] = icmp ult i32 %[[VAL_923]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_924]])
 // CHECK:         %[[VAL_925:.*]] = mul nuw nsw i32 %[[VAL_923]], 4
 // CHECK:         %[[VAL_926:.*]] = udiv i32 %[[VAL_925]], 1
@@ -1043,32 +1043,32 @@
 // CHECK:       r18.in_bounds-after:                              ; preds = %[[VAL_942]], %[[VAL_944:.*]]
 // CHECK:         ret void
 // CHECK:       r18.in_bounds-true:                               ; preds = %[[VAL_944]]
-// CHECK:         %[[VAL_945:.*]] = bitcast [100 x [200 x %[[VAL_919]]]]* %[[VAL_918]] to %[[VAL_919]]*
-// CHECK:         %[[VAL_946:.*]] = getelementptr inbounds %[[VAL_919]], %[[VAL_919]]* %[[VAL_945]], i32 %[[VAL_925]]
-// CHECK:         %[[VAL_947:.*]] = load %[[VAL_919]], %[[VAL_919]]* %[[VAL_946]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_948:.*]] = extractvalue %[[VAL_919]] %[[VAL_947]], 0
-// CHECK:         %[[VAL_949:.*]] = bitcast [100 x [200 x float]]* %[[VAL_915]] to float*
+// CHECK:         %[[VAL_945:.*]] = bitcast [100 x [200 x %[[VAL_916]]]]* %[[VAL_915]] to %[[VAL_916]]*
+// CHECK:         %[[VAL_946:.*]] = getelementptr inbounds %[[VAL_916]], %[[VAL_916]]* %[[VAL_945]], i32 %[[VAL_925]]
+// CHECK:         %[[VAL_947:.*]] = load %[[VAL_916]], %[[VAL_916]]* %[[VAL_946]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_948:.*]] = extractvalue %[[VAL_916]] %[[VAL_947]], 0
+// CHECK:         %[[VAL_949:.*]] = bitcast [100 x [200 x float]]* %[[VAL_919]] to float*
 // CHECK:         %[[VAL_950:.*]] = getelementptr inbounds float, float* %[[VAL_949]], i32 %[[VAL_925]]
 // CHECK:         store float %[[VAL_948]], float* %[[VAL_950]], align 4
-// CHECK:         %[[VAL_951:.*]] = bitcast [100 x [200 x %[[VAL_919]]]]* %[[VAL_918]] to %[[VAL_919]]*
-// CHECK:         %[[VAL_952:.*]] = getelementptr inbounds %[[VAL_919]], %[[VAL_919]]* %[[VAL_951]], i32 %[[VAL_929]]
-// CHECK:         %[[VAL_953:.*]] = load %[[VAL_919]], %[[VAL_919]]* %[[VAL_952]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_954:.*]] = extractvalue %[[VAL_919]] %[[VAL_953]], 0
-// CHECK:         %[[VAL_955:.*]] = bitcast [100 x [200 x float]]* %[[VAL_915]] to float*
+// CHECK:         %[[VAL_951:.*]] = bitcast [100 x [200 x %[[VAL_916]]]]* %[[VAL_915]] to %[[VAL_916]]*
+// CHECK:         %[[VAL_952:.*]] = getelementptr inbounds %[[VAL_916]], %[[VAL_916]]* %[[VAL_951]], i32 %[[VAL_929]]
+// CHECK:         %[[VAL_953:.*]] = load %[[VAL_916]], %[[VAL_916]]* %[[VAL_952]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_954:.*]] = extractvalue %[[VAL_916]] %[[VAL_953]], 0
+// CHECK:         %[[VAL_955:.*]] = bitcast [100 x [200 x float]]* %[[VAL_919]] to float*
 // CHECK:         %[[VAL_956:.*]] = getelementptr inbounds float, float* %[[VAL_955]], i32 %[[VAL_929]]
 // CHECK:         store float %[[VAL_954]], float* %[[VAL_956]], align 4
-// CHECK:         %[[VAL_957:.*]] = bitcast [100 x [200 x %[[VAL_919]]]]* %[[VAL_918]] to %[[VAL_919]]*
-// CHECK:         %[[VAL_958:.*]] = getelementptr inbounds %[[VAL_919]], %[[VAL_919]]* %[[VAL_957]], i32 %[[VAL_933]]
-// CHECK:         %[[VAL_959:.*]] = load %[[VAL_919]], %[[VAL_919]]* %[[VAL_958]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_960:.*]] = extractvalue %[[VAL_919]] %[[VAL_959]], 0
-// CHECK:         %[[VAL_961:.*]] = bitcast [100 x [200 x float]]* %[[VAL_915]] to float*
+// CHECK:         %[[VAL_957:.*]] = bitcast [100 x [200 x %[[VAL_916]]]]* %[[VAL_915]] to %[[VAL_916]]*
+// CHECK:         %[[VAL_958:.*]] = getelementptr inbounds %[[VAL_916]], %[[VAL_916]]* %[[VAL_957]], i32 %[[VAL_933]]
+// CHECK:         %[[VAL_959:.*]] = load %[[VAL_916]], %[[VAL_916]]* %[[VAL_958]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_960:.*]] = extractvalue %[[VAL_916]] %[[VAL_959]], 0
+// CHECK:         %[[VAL_961:.*]] = bitcast [100 x [200 x float]]* %[[VAL_919]] to float*
 // CHECK:         %[[VAL_962:.*]] = getelementptr inbounds float, float* %[[VAL_961]], i32 %[[VAL_933]]
 // CHECK:         store float %[[VAL_960]], float* %[[VAL_962]], align 4
-// CHECK:         %[[VAL_963:.*]] = bitcast [100 x [200 x %[[VAL_919]]]]* %[[VAL_918]] to %[[VAL_919]]*
-// CHECK:         %[[VAL_964:.*]] = getelementptr inbounds %[[VAL_919]], %[[VAL_919]]* %[[VAL_963]], i32 %[[VAL_937]]
-// CHECK:         %[[VAL_965:.*]] = load %[[VAL_919]], %[[VAL_919]]* %[[VAL_964]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_966:.*]] = extractvalue %[[VAL_919]] %[[VAL_965]], 0
-// CHECK:         %[[VAL_967:.*]] = bitcast [100 x [200 x float]]* %[[VAL_915]] to float*
+// CHECK:         %[[VAL_963:.*]] = bitcast [100 x [200 x %[[VAL_916]]]]* %[[VAL_915]] to %[[VAL_916]]*
+// CHECK:         %[[VAL_964:.*]] = getelementptr inbounds %[[VAL_916]], %[[VAL_916]]* %[[VAL_963]], i32 %[[VAL_937]]
+// CHECK:         %[[VAL_965:.*]] = load %[[VAL_916]], %[[VAL_916]]* %[[VAL_964]], align 1, !invariant.load !92
+// CHECK:         %[[VAL_966:.*]] = extractvalue %[[VAL_916]] %[[VAL_965]], 0
+// CHECK:         %[[VAL_967:.*]] = bitcast [100 x [200 x float]]* %[[VAL_919]] to float*
 // CHECK:         %[[VAL_968:.*]] = getelementptr inbounds float, float* %[[VAL_967]], i32 %[[VAL_937]]
 // CHECK:         store float %[[VAL_966]], float* %[[VAL_968]], align 4
 // CHECK:         br label %[[VAL_943]]
@@ -1079,9 +1079,9 @@
 // CHECK:         %[[VAL_974:.*]] = bitcast i8* %[[VAL_972]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_975:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_976:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_977:.*]] = mul nuw nsw i32 %[[VAL_975]], 128
+// CHECK:         %[[VAL_977:.*]] = mul nuw nsw i32 %[[VAL_975]], 256
 // CHECK:         %[[VAL_978:.*]] = add nuw nsw i32 %[[VAL_977]], %[[VAL_976]]
-// CHECK:         %[[VAL_979:.*]] = icmp ult i32 %[[VAL_978]], 163840
+// CHECK:         %[[VAL_979:.*]] = icmp ult i32 %[[VAL_978]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_979]])
 // CHECK:         %[[VAL_980:.*]] = mul nuw nsw i32 %[[VAL_978]], 4
 // CHECK:         %[[VAL_981:.*]] = udiv i32 %[[VAL_980]], 1
@@ -1104,7 +1104,7 @@
 // CHECK:       r19.in_bounds-after:                              ; preds = %[[VAL_997]], %[[VAL_999:.*]]
 // CHECK:         ret void
 // CHECK:       r19.in_bounds-true:                               ; preds = %[[VAL_999]]
-// CHECK:         %[[VAL_1000:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
+// CHECK:         %[[VAL_1000:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
 // CHECK:         %[[VAL_1001:.*]] = getelementptr inbounds float, float* %[[VAL_1000]], i32 %[[VAL_980]]
 // CHECK:         %[[VAL_1002:.*]] = load float, float* %[[VAL_1001]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1003:.*]] = bitcast float %[[VAL_1002]] to i32
@@ -1123,10 +1123,10 @@
 // CHECK:         %[[VAL_1016:.*]] = bitcast i32 %[[VAL_1015]] to float
 // CHECK:         %[[VAL_1017:.*]] = fcmp uno float %[[VAL_1002]], %[[VAL_1002]]
 // CHECK:         %[[VAL_1018:.*]] = select i1 %[[VAL_1017]], float %[[VAL_1002]], float %[[VAL_1016]]
-// CHECK:         %[[VAL_1019:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
+// CHECK:         %[[VAL_1019:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
 // CHECK:         %[[VAL_1020:.*]] = getelementptr inbounds float, float* %[[VAL_1019]], i32 %[[VAL_980]]
 // CHECK:         store float %[[VAL_1018]], float* %[[VAL_1020]], align 4
-// CHECK:         %[[VAL_1021:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
+// CHECK:         %[[VAL_1021:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
 // CHECK:         %[[VAL_1022:.*]] = getelementptr inbounds float, float* %[[VAL_1021]], i32 %[[VAL_984]]
 // CHECK:         %[[VAL_1023:.*]] = load float, float* %[[VAL_1022]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1024:.*]] = bitcast float %[[VAL_1023]] to i32
@@ -1145,10 +1145,10 @@
 // CHECK:         %[[VAL_1037:.*]] = bitcast i32 %[[VAL_1036]] to float
 // CHECK:         %[[VAL_1038:.*]] = fcmp uno float %[[VAL_1023]], %[[VAL_1023]]
 // CHECK:         %[[VAL_1039:.*]] = select i1 %[[VAL_1038]], float %[[VAL_1023]], float %[[VAL_1037]]
-// CHECK:         %[[VAL_1040:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
+// CHECK:         %[[VAL_1040:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
 // CHECK:         %[[VAL_1041:.*]] = getelementptr inbounds float, float* %[[VAL_1040]], i32 %[[VAL_984]]
 // CHECK:         store float %[[VAL_1039]], float* %[[VAL_1041]], align 4
-// CHECK:         %[[VAL_1042:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
+// CHECK:         %[[VAL_1042:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
 // CHECK:         %[[VAL_1043:.*]] = getelementptr inbounds float, float* %[[VAL_1042]], i32 %[[VAL_988]]
 // CHECK:         %[[VAL_1044:.*]] = load float, float* %[[VAL_1043]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1045:.*]] = bitcast float %[[VAL_1044]] to i32
@@ -1167,10 +1167,10 @@
 // CHECK:         %[[VAL_1058:.*]] = bitcast i32 %[[VAL_1057]] to float
 // CHECK:         %[[VAL_1059:.*]] = fcmp uno float %[[VAL_1044]], %[[VAL_1044]]
 // CHECK:         %[[VAL_1060:.*]] = select i1 %[[VAL_1059]], float %[[VAL_1044]], float %[[VAL_1058]]
-// CHECK:         %[[VAL_1061:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
+// CHECK:         %[[VAL_1061:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
 // CHECK:         %[[VAL_1062:.*]] = getelementptr inbounds float, float* %[[VAL_1061]], i32 %[[VAL_988]]
 // CHECK:         store float %[[VAL_1060]], float* %[[VAL_1062]], align 4
-// CHECK:         %[[VAL_1063:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
+// CHECK:         %[[VAL_1063:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
 // CHECK:         %[[VAL_1064:.*]] = getelementptr inbounds float, float* %[[VAL_1063]], i32 %[[VAL_992]]
 // CHECK:         %[[VAL_1065:.*]] = load float, float* %[[VAL_1064]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1066:.*]] = bitcast float %[[VAL_1065]] to i32
@@ -1189,7 +1189,7 @@
 // CHECK:         %[[VAL_1079:.*]] = bitcast i32 %[[VAL_1078]] to float
 // CHECK:         %[[VAL_1080:.*]] = fcmp uno float %[[VAL_1065]], %[[VAL_1065]]
 // CHECK:         %[[VAL_1081:.*]] = select i1 %[[VAL_1080]], float %[[VAL_1065]], float %[[VAL_1079]]
-// CHECK:         %[[VAL_1082:.*]] = bitcast [100 x [200 x float]]* %[[VAL_971]] to float*
+// CHECK:         %[[VAL_1082:.*]] = bitcast [100 x [200 x float]]* %[[VAL_974]] to float*
 // CHECK:         %[[VAL_1083:.*]] = getelementptr inbounds float, float* %[[VAL_1082]], i32 %[[VAL_992]]
 // CHECK:         store float %[[VAL_1081]], float* %[[VAL_1083]], align 4
 // CHECK:         br label %[[VAL_998]]
@@ -1200,9 +1200,9 @@
 // CHECK:         %[[VAL_1089:.*]] = bitcast i8* %[[VAL_1087]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1090:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1091:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1092:.*]] = mul nuw nsw i32 %[[VAL_1090]], 128
+// CHECK:         %[[VAL_1092:.*]] = mul nuw nsw i32 %[[VAL_1090]], 256
 // CHECK:         %[[VAL_1093:.*]] = add nuw nsw i32 %[[VAL_1092]], %[[VAL_1091]]
-// CHECK:         %[[VAL_1094:.*]] = icmp ult i32 %[[VAL_1093]], 163840
+// CHECK:         %[[VAL_1094:.*]] = icmp ult i32 %[[VAL_1093]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1094]])
 // CHECK:         %[[VAL_1095:.*]] = mul nuw nsw i32 %[[VAL_1093]], 4
 // CHECK:         %[[VAL_1096:.*]] = udiv i32 %[[VAL_1095]], 1
@@ -1225,32 +1225,32 @@
 // CHECK:       r20.in_bounds-after:                              ; preds = %[[VAL_1112]], %[[VAL_1114:.*]]
 // CHECK:         ret void
 // CHECK:       r20.in_bounds-true:                               ; preds = %[[VAL_1114]]
-// CHECK:         %[[VAL_1115:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
+// CHECK:         %[[VAL_1115:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
 // CHECK:         %[[VAL_1116:.*]] = getelementptr inbounds float, float* %[[VAL_1115]], i32 %[[VAL_1095]]
 // CHECK:         %[[VAL_1117:.*]] = load float, float* %[[VAL_1116]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1118:.*]] = call float @__nv_rsqrtf(float %[[VAL_1117]])
-// CHECK:         %[[VAL_1119:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
+// CHECK:         %[[VAL_1119:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
 // CHECK:         %[[VAL_1120:.*]] = getelementptr inbounds float, float* %[[VAL_1119]], i32 %[[VAL_1095]]
 // CHECK:         store float %[[VAL_1118]], float* %[[VAL_1120]], align 4
-// CHECK:         %[[VAL_1121:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
+// CHECK:         %[[VAL_1121:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
 // CHECK:         %[[VAL_1122:.*]] = getelementptr inbounds float, float* %[[VAL_1121]], i32 %[[VAL_1099]]
 // CHECK:         %[[VAL_1123:.*]] = load float, float* %[[VAL_1122]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1124:.*]] = call float @__nv_rsqrtf(float %[[VAL_1123]])
-// CHECK:         %[[VAL_1125:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
+// CHECK:         %[[VAL_1125:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
 // CHECK:         %[[VAL_1126:.*]] = getelementptr inbounds float, float* %[[VAL_1125]], i32 %[[VAL_1099]]
 // CHECK:         store float %[[VAL_1124]], float* %[[VAL_1126]], align 4
-// CHECK:         %[[VAL_1127:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
+// CHECK:         %[[VAL_1127:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
 // CHECK:         %[[VAL_1128:.*]] = getelementptr inbounds float, float* %[[VAL_1127]], i32 %[[VAL_1103]]
 // CHECK:         %[[VAL_1129:.*]] = load float, float* %[[VAL_1128]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1130:.*]] = call float @__nv_rsqrtf(float %[[VAL_1129]])
-// CHECK:         %[[VAL_1131:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
+// CHECK:         %[[VAL_1131:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
 // CHECK:         %[[VAL_1132:.*]] = getelementptr inbounds float, float* %[[VAL_1131]], i32 %[[VAL_1103]]
 // CHECK:         store float %[[VAL_1130]], float* %[[VAL_1132]], align 4
-// CHECK:         %[[VAL_1133:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
+// CHECK:         %[[VAL_1133:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
 // CHECK:         %[[VAL_1134:.*]] = getelementptr inbounds float, float* %[[VAL_1133]], i32 %[[VAL_1107]]
 // CHECK:         %[[VAL_1135:.*]] = load float, float* %[[VAL_1134]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1136:.*]] = call float @__nv_rsqrtf(float %[[VAL_1135]])
-// CHECK:         %[[VAL_1137:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1086]] to float*
+// CHECK:         %[[VAL_1137:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1089]] to float*
 // CHECK:         %[[VAL_1138:.*]] = getelementptr inbounds float, float* %[[VAL_1137]], i32 %[[VAL_1107]]
 // CHECK:         store float %[[VAL_1136]], float* %[[VAL_1138]], align 4
 // CHECK:         br label %[[VAL_1113]]
@@ -1261,9 +1261,9 @@
 // CHECK:         %[[VAL_1144:.*]] = bitcast i8* %[[VAL_1142]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1145:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1146:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1147:.*]] = mul nuw nsw i32 %[[VAL_1145]], 128
+// CHECK:         %[[VAL_1147:.*]] = mul nuw nsw i32 %[[VAL_1145]], 256
 // CHECK:         %[[VAL_1148:.*]] = add nuw nsw i32 %[[VAL_1147]], %[[VAL_1146]]
-// CHECK:         %[[VAL_1149:.*]] = icmp ult i32 %[[VAL_1148]], 163840
+// CHECK:         %[[VAL_1149:.*]] = icmp ult i32 %[[VAL_1148]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1149]])
 // CHECK:         %[[VAL_1150:.*]] = mul nuw nsw i32 %[[VAL_1148]], 4
 // CHECK:         %[[VAL_1151:.*]] = udiv i32 %[[VAL_1150]], 1
@@ -1286,7 +1286,7 @@
 // CHECK:       r22.in_bounds-after:                              ; preds = %[[VAL_1167]], %[[VAL_1169:.*]]
 // CHECK:         ret void
 // CHECK:       r22.in_bounds-true:                               ; preds = %[[VAL_1169]]
-// CHECK:         %[[VAL_1170:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
+// CHECK:         %[[VAL_1170:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
 // CHECK:         %[[VAL_1171:.*]] = getelementptr inbounds float, float* %[[VAL_1170]], i32 %[[VAL_1150]]
 // CHECK:         %[[VAL_1172:.*]] = load float, float* %[[VAL_1171]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1173:.*]] = fcmp one float %[[VAL_1172]], 0.000000e+00
@@ -1294,10 +1294,10 @@
 // CHECK:         %[[VAL_1175:.*]] = call float @llvm.copysign.f32(float %[[VAL_1174]], float %[[VAL_1172]])
 // CHECK:         %[[VAL_1176:.*]] = fcmp uno float %[[VAL_1172]], %[[VAL_1172]]
 // CHECK:         %[[VAL_1177:.*]] = select i1 %[[VAL_1176]], float %[[VAL_1172]], float %[[VAL_1175]]
-// CHECK:         %[[VAL_1178:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
+// CHECK:         %[[VAL_1178:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
 // CHECK:         %[[VAL_1179:.*]] = getelementptr inbounds float, float* %[[VAL_1178]], i32 %[[VAL_1150]]
 // CHECK:         store float %[[VAL_1177]], float* %[[VAL_1179]], align 4
-// CHECK:         %[[VAL_1180:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
+// CHECK:         %[[VAL_1180:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
 // CHECK:         %[[VAL_1181:.*]] = getelementptr inbounds float, float* %[[VAL_1180]], i32 %[[VAL_1154]]
 // CHECK:         %[[VAL_1182:.*]] = load float, float* %[[VAL_1181]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1183:.*]] = fcmp one float %[[VAL_1182]], 0.000000e+00
@@ -1305,10 +1305,10 @@
 // CHECK:         %[[VAL_1185:.*]] = call float @llvm.copysign.f32(float %[[VAL_1184]], float %[[VAL_1182]])
 // CHECK:         %[[VAL_1186:.*]] = fcmp uno float %[[VAL_1182]], %[[VAL_1182]]
 // CHECK:         %[[VAL_1187:.*]] = select i1 %[[VAL_1186]], float %[[VAL_1182]], float %[[VAL_1185]]
-// CHECK:         %[[VAL_1188:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
+// CHECK:         %[[VAL_1188:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
 // CHECK:         %[[VAL_1189:.*]] = getelementptr inbounds float, float* %[[VAL_1188]], i32 %[[VAL_1154]]
 // CHECK:         store float %[[VAL_1187]], float* %[[VAL_1189]], align 4
-// CHECK:         %[[VAL_1190:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
+// CHECK:         %[[VAL_1190:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
 // CHECK:         %[[VAL_1191:.*]] = getelementptr inbounds float, float* %[[VAL_1190]], i32 %[[VAL_1158]]
 // CHECK:         %[[VAL_1192:.*]] = load float, float* %[[VAL_1191]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1193:.*]] = fcmp one float %[[VAL_1192]], 0.000000e+00
@@ -1316,10 +1316,10 @@
 // CHECK:         %[[VAL_1195:.*]] = call float @llvm.copysign.f32(float %[[VAL_1194]], float %[[VAL_1192]])
 // CHECK:         %[[VAL_1196:.*]] = fcmp uno float %[[VAL_1192]], %[[VAL_1192]]
 // CHECK:         %[[VAL_1197:.*]] = select i1 %[[VAL_1196]], float %[[VAL_1192]], float %[[VAL_1195]]
-// CHECK:         %[[VAL_1198:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
+// CHECK:         %[[VAL_1198:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
 // CHECK:         %[[VAL_1199:.*]] = getelementptr inbounds float, float* %[[VAL_1198]], i32 %[[VAL_1158]]
 // CHECK:         store float %[[VAL_1197]], float* %[[VAL_1199]], align 4
-// CHECK:         %[[VAL_1200:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
+// CHECK:         %[[VAL_1200:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
 // CHECK:         %[[VAL_1201:.*]] = getelementptr inbounds float, float* %[[VAL_1200]], i32 %[[VAL_1162]]
 // CHECK:         %[[VAL_1202:.*]] = load float, float* %[[VAL_1201]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1203:.*]] = fcmp one float %[[VAL_1202]], 0.000000e+00
@@ -1327,7 +1327,7 @@
 // CHECK:         %[[VAL_1205:.*]] = call float @llvm.copysign.f32(float %[[VAL_1204]], float %[[VAL_1202]])
 // CHECK:         %[[VAL_1206:.*]] = fcmp uno float %[[VAL_1202]], %[[VAL_1202]]
 // CHECK:         %[[VAL_1207:.*]] = select i1 %[[VAL_1206]], float %[[VAL_1202]], float %[[VAL_1205]]
-// CHECK:         %[[VAL_1208:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1141]] to float*
+// CHECK:         %[[VAL_1208:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1144]] to float*
 // CHECK:         %[[VAL_1209:.*]] = getelementptr inbounds float, float* %[[VAL_1208]], i32 %[[VAL_1162]]
 // CHECK:         store float %[[VAL_1207]], float* %[[VAL_1209]], align 4
 // CHECK:         br label %[[VAL_1168]]
@@ -1337,10 +1337,10 @@
 // CHECK:         %[[VAL_1213:.*]] = getelementptr inbounds i8, i8* %[[VAL_1214:.*]], i64 0
 // CHECK:         %[[VAL_1215:.*]] = bitcast i8* %[[VAL_1213]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1216:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
-// CHECK:         %[[VAL_1217:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1218:.*]] = mul nuw nsw i32 %[[VAL_1216]], 128
+// CHECK:         %[[VAL_1217:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !93
+// CHECK:         %[[VAL_1218:.*]] = mul nuw nsw i32 %[[VAL_1216]], 1024
 // CHECK:         %[[VAL_1219:.*]] = add nuw nsw i32 %[[VAL_1218]], %[[VAL_1217]]
-// CHECK:         %[[VAL_1220:.*]] = icmp ult i32 %[[VAL_1219]], 163840
+// CHECK:         %[[VAL_1220:.*]] = icmp ult i32 %[[VAL_1219]], 20480
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1220]])
 // CHECK:         %[[VAL_1221:.*]] = udiv i32 %[[VAL_1219]], 1
 // CHECK:         %[[VAL_1222:.*]] = urem i32 %[[VAL_1221]], 200
@@ -1350,11 +1350,11 @@
 // CHECK:       r23.in_bounds-after:                              ; preds = %[[VAL_1225]], %[[VAL_1227:.*]]
 // CHECK:         ret void
 // CHECK:       r23.in_bounds-true:                               ; preds = %[[VAL_1227]]
-// CHECK:         %[[VAL_1228:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1215]] to float*
+// CHECK:         %[[VAL_1228:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1212]] to float*
 // CHECK:         %[[VAL_1229:.*]] = getelementptr inbounds float, float* %[[VAL_1228]], i32 %[[VAL_1219]]
 // CHECK:         %[[VAL_1230:.*]] = load float, float* %[[VAL_1229]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1231:.*]] = call float @__nv_sinf(float %[[VAL_1230]])
-// CHECK:         %[[VAL_1232:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1212]] to float*
+// CHECK:         %[[VAL_1232:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1215]] to float*
 // CHECK:         %[[VAL_1233:.*]] = getelementptr inbounds float, float* %[[VAL_1232]], i32 %[[VAL_1219]]
 // CHECK:         store float %[[VAL_1231]], float* %[[VAL_1233]], align 4
 // CHECK:         br label %[[VAL_1226]]
@@ -1365,9 +1365,9 @@
 // CHECK:         %[[VAL_1239:.*]] = bitcast i8* %[[VAL_1237]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1240:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1241:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1242:.*]] = mul nuw nsw i32 %[[VAL_1240]], 128
+// CHECK:         %[[VAL_1242:.*]] = mul nuw nsw i32 %[[VAL_1240]], 256
 // CHECK:         %[[VAL_1243:.*]] = add nuw nsw i32 %[[VAL_1242]], %[[VAL_1241]]
-// CHECK:         %[[VAL_1244:.*]] = icmp ult i32 %[[VAL_1243]], 163840
+// CHECK:         %[[VAL_1244:.*]] = icmp ult i32 %[[VAL_1243]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1244]])
 // CHECK:         %[[VAL_1245:.*]] = mul nuw nsw i32 %[[VAL_1243]], 4
 // CHECK:         %[[VAL_1246:.*]] = udiv i32 %[[VAL_1245]], 1
@@ -1390,32 +1390,32 @@
 // CHECK:       r24.in_bounds-after:                              ; preds = %[[VAL_1262]], %[[VAL_1264:.*]]
 // CHECK:         ret void
 // CHECK:       r24.in_bounds-true:                               ; preds = %[[VAL_1264]]
-// CHECK:         %[[VAL_1265:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
+// CHECK:         %[[VAL_1265:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
 // CHECK:         %[[VAL_1266:.*]] = getelementptr inbounds float, float* %[[VAL_1265]], i32 %[[VAL_1245]]
 // CHECK:         %[[VAL_1267:.*]] = load float, float* %[[VAL_1266]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1268:.*]] = call float @__nv_sqrtf(float %[[VAL_1267]])
-// CHECK:         %[[VAL_1269:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
+// CHECK:         %[[VAL_1269:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
 // CHECK:         %[[VAL_1270:.*]] = getelementptr inbounds float, float* %[[VAL_1269]], i32 %[[VAL_1245]]
 // CHECK:         store float %[[VAL_1268]], float* %[[VAL_1270]], align 4
-// CHECK:         %[[VAL_1271:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
+// CHECK:         %[[VAL_1271:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
 // CHECK:         %[[VAL_1272:.*]] = getelementptr inbounds float, float* %[[VAL_1271]], i32 %[[VAL_1249]]
 // CHECK:         %[[VAL_1273:.*]] = load float, float* %[[VAL_1272]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1274:.*]] = call float @__nv_sqrtf(float %[[VAL_1273]])
-// CHECK:         %[[VAL_1275:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
+// CHECK:         %[[VAL_1275:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
 // CHECK:         %[[VAL_1276:.*]] = getelementptr inbounds float, float* %[[VAL_1275]], i32 %[[VAL_1249]]
 // CHECK:         store float %[[VAL_1274]], float* %[[VAL_1276]], align 4
-// CHECK:         %[[VAL_1277:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
+// CHECK:         %[[VAL_1277:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
 // CHECK:         %[[VAL_1278:.*]] = getelementptr inbounds float, float* %[[VAL_1277]], i32 %[[VAL_1253]]
 // CHECK:         %[[VAL_1279:.*]] = load float, float* %[[VAL_1278]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1280:.*]] = call float @__nv_sqrtf(float %[[VAL_1279]])
-// CHECK:         %[[VAL_1281:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
+// CHECK:         %[[VAL_1281:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
 // CHECK:         %[[VAL_1282:.*]] = getelementptr inbounds float, float* %[[VAL_1281]], i32 %[[VAL_1253]]
 // CHECK:         store float %[[VAL_1280]], float* %[[VAL_1282]], align 4
-// CHECK:         %[[VAL_1283:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
+// CHECK:         %[[VAL_1283:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
 // CHECK:         %[[VAL_1284:.*]] = getelementptr inbounds float, float* %[[VAL_1283]], i32 %[[VAL_1257]]
 // CHECK:         %[[VAL_1285:.*]] = load float, float* %[[VAL_1284]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1286:.*]] = call float @__nv_sqrtf(float %[[VAL_1285]])
-// CHECK:         %[[VAL_1287:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1236]] to float*
+// CHECK:         %[[VAL_1287:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1239]] to float*
 // CHECK:         %[[VAL_1288:.*]] = getelementptr inbounds float, float* %[[VAL_1287]], i32 %[[VAL_1257]]
 // CHECK:         store float %[[VAL_1286]], float* %[[VAL_1288]], align 4
 // CHECK:         br label %[[VAL_1263]]
@@ -1426,9 +1426,9 @@
 // CHECK:         %[[VAL_1294:.*]] = bitcast i8* %[[VAL_1292]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1295:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1296:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1297:.*]] = mul nuw nsw i32 %[[VAL_1295]], 128
+// CHECK:         %[[VAL_1297:.*]] = mul nuw nsw i32 %[[VAL_1295]], 256
 // CHECK:         %[[VAL_1298:.*]] = add nuw nsw i32 %[[VAL_1297]], %[[VAL_1296]]
-// CHECK:         %[[VAL_1299:.*]] = icmp ult i32 %[[VAL_1298]], 163840
+// CHECK:         %[[VAL_1299:.*]] = icmp ult i32 %[[VAL_1298]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1299]])
 // CHECK:         %[[VAL_1300:.*]] = mul nuw nsw i32 %[[VAL_1298]], 4
 // CHECK:         %[[VAL_1301:.*]] = udiv i32 %[[VAL_1300]], 1
@@ -1451,40 +1451,40 @@
 // CHECK:       r25.in_bounds-after:                              ; preds = %[[VAL_1317]], %[[VAL_1319:.*]]
 // CHECK:         ret void
 // CHECK:       r25.in_bounds-true:                               ; preds = %[[VAL_1319]]
-// CHECK:         %[[VAL_1320:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
+// CHECK:         %[[VAL_1320:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
 // CHECK:         %[[VAL_1321:.*]] = getelementptr inbounds float, float* %[[VAL_1320]], i32 %[[VAL_1300]]
 // CHECK:         %[[VAL_1322:.*]] = load float, float* %[[VAL_1321]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1323:.*]] = call float @llvm.fabs.f32(float %[[VAL_1322]])
 // CHECK:         %[[VAL_1324:.*]] = call float @__nv_powf(float %[[VAL_1323]], float 0x3FD5555560000000)
 // CHECK:         %[[VAL_1325:.*]] = call float @llvm.copysign.f32(float %[[VAL_1324]], float %[[VAL_1322]])
-// CHECK:         %[[VAL_1326:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
+// CHECK:         %[[VAL_1326:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
 // CHECK:         %[[VAL_1327:.*]] = getelementptr inbounds float, float* %[[VAL_1326]], i32 %[[VAL_1300]]
 // CHECK:         store float %[[VAL_1325]], float* %[[VAL_1327]], align 4
-// CHECK:         %[[VAL_1328:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
+// CHECK:         %[[VAL_1328:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
 // CHECK:         %[[VAL_1329:.*]] = getelementptr inbounds float, float* %[[VAL_1328]], i32 %[[VAL_1304]]
 // CHECK:         %[[VAL_1330:.*]] = load float, float* %[[VAL_1329]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1331:.*]] = call float @llvm.fabs.f32(float %[[VAL_1330]])
 // CHECK:         %[[VAL_1332:.*]] = call float @__nv_powf(float %[[VAL_1331]], float 0x3FD5555560000000)
 // CHECK:         %[[VAL_1333:.*]] = call float @llvm.copysign.f32(float %[[VAL_1332]], float %[[VAL_1330]])
-// CHECK:         %[[VAL_1334:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
+// CHECK:         %[[VAL_1334:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
 // CHECK:         %[[VAL_1335:.*]] = getelementptr inbounds float, float* %[[VAL_1334]], i32 %[[VAL_1304]]
 // CHECK:         store float %[[VAL_1333]], float* %[[VAL_1335]], align 4
-// CHECK:         %[[VAL_1336:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
+// CHECK:         %[[VAL_1336:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
 // CHECK:         %[[VAL_1337:.*]] = getelementptr inbounds float, float* %[[VAL_1336]], i32 %[[VAL_1308]]
 // CHECK:         %[[VAL_1338:.*]] = load float, float* %[[VAL_1337]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1339:.*]] = call float @llvm.fabs.f32(float %[[VAL_1338]])
 // CHECK:         %[[VAL_1340:.*]] = call float @__nv_powf(float %[[VAL_1339]], float 0x3FD5555560000000)
 // CHECK:         %[[VAL_1341:.*]] = call float @llvm.copysign.f32(float %[[VAL_1340]], float %[[VAL_1338]])
-// CHECK:         %[[VAL_1342:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
+// CHECK:         %[[VAL_1342:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
 // CHECK:         %[[VAL_1343:.*]] = getelementptr inbounds float, float* %[[VAL_1342]], i32 %[[VAL_1308]]
 // CHECK:         store float %[[VAL_1341]], float* %[[VAL_1343]], align 4
-// CHECK:         %[[VAL_1344:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
+// CHECK:         %[[VAL_1344:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
 // CHECK:         %[[VAL_1345:.*]] = getelementptr inbounds float, float* %[[VAL_1344]], i32 %[[VAL_1312]]
 // CHECK:         %[[VAL_1346:.*]] = load float, float* %[[VAL_1345]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1347:.*]] = call float @llvm.fabs.f32(float %[[VAL_1346]])
 // CHECK:         %[[VAL_1348:.*]] = call float @__nv_powf(float %[[VAL_1347]], float 0x3FD5555560000000)
 // CHECK:         %[[VAL_1349:.*]] = call float @llvm.copysign.f32(float %[[VAL_1348]], float %[[VAL_1346]])
-// CHECK:         %[[VAL_1350:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1291]] to float*
+// CHECK:         %[[VAL_1350:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1294]] to float*
 // CHECK:         %[[VAL_1351:.*]] = getelementptr inbounds float, float* %[[VAL_1350]], i32 %[[VAL_1312]]
 // CHECK:         store float %[[VAL_1349]], float* %[[VAL_1351]], align 4
 // CHECK:         br label %[[VAL_1318]]
@@ -1495,9 +1495,9 @@
 // CHECK:         %[[VAL_1357:.*]] = bitcast i8* %[[VAL_1355]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1358:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1359:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1360:.*]] = mul nuw nsw i32 %[[VAL_1358]], 128
+// CHECK:         %[[VAL_1360:.*]] = mul nuw nsw i32 %[[VAL_1358]], 256
 // CHECK:         %[[VAL_1361:.*]] = add nuw nsw i32 %[[VAL_1360]], %[[VAL_1359]]
-// CHECK:         %[[VAL_1362:.*]] = icmp ult i32 %[[VAL_1361]], 163840
+// CHECK:         %[[VAL_1362:.*]] = icmp ult i32 %[[VAL_1361]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1362]])
 // CHECK:         %[[VAL_1363:.*]] = mul nuw nsw i32 %[[VAL_1361]], 4
 // CHECK:         %[[VAL_1364:.*]] = udiv i32 %[[VAL_1363]], 1
@@ -1520,7 +1520,7 @@
 // CHECK:       r26.in_bounds-after:                              ; preds = %[[VAL_1380]], %[[VAL_1382:.*]]
 // CHECK:         ret void
 // CHECK:       r26.in_bounds-true:                               ; preds = %[[VAL_1382]]
-// CHECK:         %[[VAL_1383:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
+// CHECK:         %[[VAL_1383:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
 // CHECK:         %[[VAL_1384:.*]] = getelementptr inbounds float, float* %[[VAL_1383]], i32 %[[VAL_1363]]
 // CHECK:         %[[VAL_1385:.*]] = load float, float* %[[VAL_1384]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1386:.*]] = call float @llvm.fabs.f32(float %[[VAL_1385]])
@@ -1555,10 +1555,10 @@
 // CHECK:         %[[VAL_1415:.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float %[[VAL_1385]])
 // CHECK:         %[[VAL_1416:.*]] = fcmp ult float %[[VAL_1386]], 2.000000e+01
 // CHECK:         %[[VAL_1417:.*]] = select i1 %[[VAL_1416]], float %[[VAL_1414]], float %[[VAL_1415]]
-// CHECK:         %[[VAL_1418:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
+// CHECK:         %[[VAL_1418:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
 // CHECK:         %[[VAL_1419:.*]] = getelementptr inbounds float, float* %[[VAL_1418]], i32 %[[VAL_1363]]
 // CHECK:         store float %[[VAL_1417]], float* %[[VAL_1419]], align 4
-// CHECK:         %[[VAL_1420:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
+// CHECK:         %[[VAL_1420:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
 // CHECK:         %[[VAL_1421:.*]] = getelementptr inbounds float, float* %[[VAL_1420]], i32 %[[VAL_1367]]
 // CHECK:         %[[VAL_1422:.*]] = load float, float* %[[VAL_1421]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1423:.*]] = call float @llvm.fabs.f32(float %[[VAL_1422]])
@@ -1593,10 +1593,10 @@
 // CHECK:         %[[VAL_1452:.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float %[[VAL_1422]])
 // CHECK:         %[[VAL_1453:.*]] = fcmp ult float %[[VAL_1423]], 2.000000e+01
 // CHECK:         %[[VAL_1454:.*]] = select i1 %[[VAL_1453]], float %[[VAL_1451]], float %[[VAL_1452]]
-// CHECK:         %[[VAL_1455:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
+// CHECK:         %[[VAL_1455:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
 // CHECK:         %[[VAL_1456:.*]] = getelementptr inbounds float, float* %[[VAL_1455]], i32 %[[VAL_1367]]
 // CHECK:         store float %[[VAL_1454]], float* %[[VAL_1456]], align 4
-// CHECK:         %[[VAL_1457:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
+// CHECK:         %[[VAL_1457:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
 // CHECK:         %[[VAL_1458:.*]] = getelementptr inbounds float, float* %[[VAL_1457]], i32 %[[VAL_1371]]
 // CHECK:         %[[VAL_1459:.*]] = load float, float* %[[VAL_1458]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1460:.*]] = call float @llvm.fabs.f32(float %[[VAL_1459]])
@@ -1631,10 +1631,10 @@
 // CHECK:         %[[VAL_1489:.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float %[[VAL_1459]])
 // CHECK:         %[[VAL_1490:.*]] = fcmp ult float %[[VAL_1460]], 2.000000e+01
 // CHECK:         %[[VAL_1491:.*]] = select i1 %[[VAL_1490]], float %[[VAL_1488]], float %[[VAL_1489]]
-// CHECK:         %[[VAL_1492:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
+// CHECK:         %[[VAL_1492:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
 // CHECK:         %[[VAL_1493:.*]] = getelementptr inbounds float, float* %[[VAL_1492]], i32 %[[VAL_1371]]
 // CHECK:         store float %[[VAL_1491]], float* %[[VAL_1493]], align 4
-// CHECK:         %[[VAL_1494:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
+// CHECK:         %[[VAL_1494:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
 // CHECK:         %[[VAL_1495:.*]] = getelementptr inbounds float, float* %[[VAL_1494]], i32 %[[VAL_1375]]
 // CHECK:         %[[VAL_1496:.*]] = load float, float* %[[VAL_1495]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1497:.*]] = call float @llvm.fabs.f32(float %[[VAL_1496]])
@@ -1669,7 +1669,7 @@
 // CHECK:         %[[VAL_1526:.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float %[[VAL_1496]])
 // CHECK:         %[[VAL_1527:.*]] = fcmp ult float %[[VAL_1497]], 2.000000e+01
 // CHECK:         %[[VAL_1528:.*]] = select i1 %[[VAL_1527]], float %[[VAL_1525]], float %[[VAL_1526]]
-// CHECK:         %[[VAL_1529:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1354]] to float*
+// CHECK:         %[[VAL_1529:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1357]] to float*
 // CHECK:         %[[VAL_1530:.*]] = getelementptr inbounds float, float* %[[VAL_1529]], i32 %[[VAL_1375]]
 // CHECK:         store float %[[VAL_1528]], float* %[[VAL_1530]], align 4
 // CHECK:         br label %[[VAL_1381]]
@@ -1682,9 +1682,9 @@
 // CHECK:         %[[VAL_1539:.*]] = bitcast i8* %[[VAL_1537]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1540:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1541:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1542:.*]] = mul nuw nsw i32 %[[VAL_1540]], 128
+// CHECK:         %[[VAL_1542:.*]] = mul nuw nsw i32 %[[VAL_1540]], 256
 // CHECK:         %[[VAL_1543:.*]] = add nuw nsw i32 %[[VAL_1542]], %[[VAL_1541]]
-// CHECK:         %[[VAL_1544:.*]] = icmp ult i32 %[[VAL_1543]], 163840
+// CHECK:         %[[VAL_1544:.*]] = icmp ult i32 %[[VAL_1543]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1544]])
 // CHECK:         %[[VAL_1545:.*]] = mul nuw nsw i32 %[[VAL_1543]], 4
 // CHECK:         %[[VAL_1546:.*]] = udiv i32 %[[VAL_1545]], 1
@@ -1707,44 +1707,44 @@
 // CHECK:       r27.in_bounds-after:                              ; preds = %[[VAL_1562]], %[[VAL_1564:.*]]
 // CHECK:         ret void
 // CHECK:       r27.in_bounds-true:                               ; preds = %[[VAL_1564]]
-// CHECK:         %[[VAL_1565:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
+// CHECK:         %[[VAL_1565:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
 // CHECK:         %[[VAL_1566:.*]] = getelementptr inbounds float, float* %[[VAL_1565]], i32 %[[VAL_1545]]
 // CHECK:         %[[VAL_1567:.*]] = load float, float* %[[VAL_1566]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1568:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
+// CHECK:         %[[VAL_1568:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
 // CHECK:         %[[VAL_1569:.*]] = getelementptr inbounds float, float* %[[VAL_1568]], i32 %[[VAL_1545]]
 // CHECK:         %[[VAL_1570:.*]] = load float, float* %[[VAL_1569]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1571:.*]] = fadd float %[[VAL_1567]], %[[VAL_1570]]
-// CHECK:         %[[VAL_1572:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
+// CHECK:         %[[VAL_1572:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
 // CHECK:         %[[VAL_1573:.*]] = getelementptr inbounds float, float* %[[VAL_1572]], i32 %[[VAL_1545]]
 // CHECK:         store float %[[VAL_1571]], float* %[[VAL_1573]], align 4
-// CHECK:         %[[VAL_1574:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
+// CHECK:         %[[VAL_1574:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
 // CHECK:         %[[VAL_1575:.*]] = getelementptr inbounds float, float* %[[VAL_1574]], i32 %[[VAL_1549]]
 // CHECK:         %[[VAL_1576:.*]] = load float, float* %[[VAL_1575]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1577:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
+// CHECK:         %[[VAL_1577:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
 // CHECK:         %[[VAL_1578:.*]] = getelementptr inbounds float, float* %[[VAL_1577]], i32 %[[VAL_1549]]
 // CHECK:         %[[VAL_1579:.*]] = load float, float* %[[VAL_1578]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1580:.*]] = fadd float %[[VAL_1576]], %[[VAL_1579]]
-// CHECK:         %[[VAL_1581:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
+// CHECK:         %[[VAL_1581:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
 // CHECK:         %[[VAL_1582:.*]] = getelementptr inbounds float, float* %[[VAL_1581]], i32 %[[VAL_1549]]
 // CHECK:         store float %[[VAL_1580]], float* %[[VAL_1582]], align 4
-// CHECK:         %[[VAL_1583:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
+// CHECK:         %[[VAL_1583:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
 // CHECK:         %[[VAL_1584:.*]] = getelementptr inbounds float, float* %[[VAL_1583]], i32 %[[VAL_1553]]
 // CHECK:         %[[VAL_1585:.*]] = load float, float* %[[VAL_1584]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1586:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
+// CHECK:         %[[VAL_1586:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
 // CHECK:         %[[VAL_1587:.*]] = getelementptr inbounds float, float* %[[VAL_1586]], i32 %[[VAL_1553]]
 // CHECK:         %[[VAL_1588:.*]] = load float, float* %[[VAL_1587]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1589:.*]] = fadd float %[[VAL_1585]], %[[VAL_1588]]
-// CHECK:         %[[VAL_1590:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
+// CHECK:         %[[VAL_1590:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
 // CHECK:         %[[VAL_1591:.*]] = getelementptr inbounds float, float* %[[VAL_1590]], i32 %[[VAL_1553]]
 // CHECK:         store float %[[VAL_1589]], float* %[[VAL_1591]], align 4
-// CHECK:         %[[VAL_1592:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
+// CHECK:         %[[VAL_1592:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
 // CHECK:         %[[VAL_1593:.*]] = getelementptr inbounds float, float* %[[VAL_1592]], i32 %[[VAL_1557]]
 // CHECK:         %[[VAL_1594:.*]] = load float, float* %[[VAL_1593]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1595:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
+// CHECK:         %[[VAL_1595:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1536]] to float*
 // CHECK:         %[[VAL_1596:.*]] = getelementptr inbounds float, float* %[[VAL_1595]], i32 %[[VAL_1557]]
 // CHECK:         %[[VAL_1597:.*]] = load float, float* %[[VAL_1596]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1598:.*]] = fadd float %[[VAL_1594]], %[[VAL_1597]]
-// CHECK:         %[[VAL_1599:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1533]] to float*
+// CHECK:         %[[VAL_1599:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1539]] to float*
 // CHECK:         %[[VAL_1600:.*]] = getelementptr inbounds float, float* %[[VAL_1599]], i32 %[[VAL_1557]]
 // CHECK:         store float %[[VAL_1598]], float* %[[VAL_1600]], align 4
 // CHECK:         br label %[[VAL_1563]]
@@ -1756,10 +1756,10 @@
 // CHECK:         %[[VAL_1607:.*]] = getelementptr inbounds i8, i8* %[[VAL_1608:.*]], i64 0
 // CHECK:         %[[VAL_1609:.*]] = bitcast i8* %[[VAL_1607]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1610:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
-// CHECK:         %[[VAL_1611:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1612:.*]] = mul nuw nsw i32 %[[VAL_1610]], 128
+// CHECK:         %[[VAL_1611:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !93
+// CHECK:         %[[VAL_1612:.*]] = mul nuw nsw i32 %[[VAL_1610]], 1024
 // CHECK:         %[[VAL_1613:.*]] = add nuw nsw i32 %[[VAL_1612]], %[[VAL_1611]]
-// CHECK:         %[[VAL_1614:.*]] = icmp ult i32 %[[VAL_1613]], 163840
+// CHECK:         %[[VAL_1614:.*]] = icmp ult i32 %[[VAL_1613]], 20480
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1614]])
 // CHECK:         %[[VAL_1615:.*]] = udiv i32 %[[VAL_1613]], 1
 // CHECK:         %[[VAL_1616:.*]] = urem i32 %[[VAL_1615]], 200
@@ -1769,29 +1769,29 @@
 // CHECK:       r28.in_bounds-after:                              ; preds = %[[VAL_1619]], %[[VAL_1621:.*]]
 // CHECK:         ret void
 // CHECK:       r28.in_bounds-true:                               ; preds = %[[VAL_1621]]
-// CHECK:         %[[VAL_1622:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1606]] to float*
+// CHECK:         %[[VAL_1622:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1603]] to float*
 // CHECK:         %[[VAL_1623:.*]] = getelementptr inbounds float, float* %[[VAL_1622]], i32 %[[VAL_1613]]
 // CHECK:         %[[VAL_1624:.*]] = load float, float* %[[VAL_1623]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1625:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1609]] to float*
+// CHECK:         %[[VAL_1625:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1606]] to float*
 // CHECK:         %[[VAL_1626:.*]] = getelementptr inbounds float, float* %[[VAL_1625]], i32 %[[VAL_1613]]
 // CHECK:         %[[VAL_1627:.*]] = load float, float* %[[VAL_1626]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1628:.*]] = call float @__nv_atan2f(float %[[VAL_1624]], float %[[VAL_1627]])
-// CHECK:         %[[VAL_1629:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1603]] to float*
+// CHECK:         %[[VAL_1629:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1609]] to float*
 // CHECK:         %[[VAL_1630:.*]] = getelementptr inbounds float, float* %[[VAL_1629]], i32 %[[VAL_1613]]
 // CHECK:         store float %[[VAL_1628]], float* %[[VAL_1630]], align 4
 // CHECK:         br label %[[VAL_1620]]
 // CHECK:       entry:
 // CHECK:         %[[VAL_1631:.*]] = getelementptr inbounds i8, i8* %[[VAL_1632:.*]], i64 0
-// CHECK:         %[[VAL_1633:.*]] = bitcast i8* %[[VAL_1631]] to [100 x [200 x i8]]*
+// CHECK:         %[[VAL_1633:.*]] = bitcast i8* %[[VAL_1631]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1634:.*]] = getelementptr inbounds i8, i8* %[[VAL_1635:.*]], i64 0
 // CHECK:         %[[VAL_1636:.*]] = bitcast i8* %[[VAL_1634]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1637:.*]] = getelementptr inbounds i8, i8* %[[VAL_1638:.*]], i64 0
-// CHECK:         %[[VAL_1639:.*]] = bitcast i8* %[[VAL_1637]] to [100 x [200 x float]]*
+// CHECK:         %[[VAL_1639:.*]] = bitcast i8* %[[VAL_1637]] to [100 x [200 x i8]]*
 // CHECK:         %[[VAL_1640:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1641:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1642:.*]] = mul nuw nsw i32 %[[VAL_1640]], 128
+// CHECK:         %[[VAL_1642:.*]] = mul nuw nsw i32 %[[VAL_1640]], 256
 // CHECK:         %[[VAL_1643:.*]] = add nuw nsw i32 %[[VAL_1642]], %[[VAL_1641]]
-// CHECK:         %[[VAL_1644:.*]] = icmp ult i32 %[[VAL_1643]], 163840
+// CHECK:         %[[VAL_1644:.*]] = icmp ult i32 %[[VAL_1643]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1644]])
 // CHECK:         %[[VAL_1645:.*]] = mul nuw nsw i32 %[[VAL_1643]], 4
 // CHECK:         %[[VAL_1646:.*]] = udiv i32 %[[VAL_1645]], 1
@@ -1814,63 +1814,63 @@
 // CHECK:       r29.in_bounds-after:                              ; preds = %[[VAL_1662]], %[[VAL_1664:.*]]
 // CHECK:         ret void
 // CHECK:       r29.in_bounds-true:                               ; preds = %[[VAL_1664]]
-// CHECK:         %[[VAL_1665:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
+// CHECK:         %[[VAL_1665:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1633]] to float*
 // CHECK:         %[[VAL_1666:.*]] = getelementptr inbounds float, float* %[[VAL_1665]], i32 %[[VAL_1645]]
 // CHECK:         %[[VAL_1667:.*]] = load float, float* %[[VAL_1666]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1668:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1639]] to float*
+// CHECK:         %[[VAL_1668:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
 // CHECK:         %[[VAL_1669:.*]] = getelementptr inbounds float, float* %[[VAL_1668]], i32 %[[VAL_1645]]
 // CHECK:         %[[VAL_1670:.*]] = load float, float* %[[VAL_1669]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1671:.*]] = fcmp oeq float %[[VAL_1667]], %[[VAL_1670]]
 // CHECK:         %[[VAL_1672:.*]] = zext i1 %[[VAL_1671]] to i8
-// CHECK:         %[[VAL_1673:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1633]] to i8*
+// CHECK:         %[[VAL_1673:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1639]] to i8*
 // CHECK:         %[[VAL_1674:.*]] = getelementptr inbounds i8, i8* %[[VAL_1673]], i32 %[[VAL_1645]]
 // CHECK:         store i8 %[[VAL_1672]], i8* %[[VAL_1674]], align 1
-// CHECK:         %[[VAL_1675:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
+// CHECK:         %[[VAL_1675:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1633]] to float*
 // CHECK:         %[[VAL_1676:.*]] = getelementptr inbounds float, float* %[[VAL_1675]], i32 %[[VAL_1649]]
 // CHECK:         %[[VAL_1677:.*]] = load float, float* %[[VAL_1676]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1678:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1639]] to float*
+// CHECK:         %[[VAL_1678:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
 // CHECK:         %[[VAL_1679:.*]] = getelementptr inbounds float, float* %[[VAL_1678]], i32 %[[VAL_1649]]
 // CHECK:         %[[VAL_1680:.*]] = load float, float* %[[VAL_1679]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1681:.*]] = fcmp oeq float %[[VAL_1677]], %[[VAL_1680]]
 // CHECK:         %[[VAL_1682:.*]] = zext i1 %[[VAL_1681]] to i8
-// CHECK:         %[[VAL_1683:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1633]] to i8*
+// CHECK:         %[[VAL_1683:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1639]] to i8*
 // CHECK:         %[[VAL_1684:.*]] = getelementptr inbounds i8, i8* %[[VAL_1683]], i32 %[[VAL_1649]]
 // CHECK:         store i8 %[[VAL_1682]], i8* %[[VAL_1684]], align 1
-// CHECK:         %[[VAL_1685:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
+// CHECK:         %[[VAL_1685:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1633]] to float*
 // CHECK:         %[[VAL_1686:.*]] = getelementptr inbounds float, float* %[[VAL_1685]], i32 %[[VAL_1653]]
 // CHECK:         %[[VAL_1687:.*]] = load float, float* %[[VAL_1686]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1688:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1639]] to float*
+// CHECK:         %[[VAL_1688:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
 // CHECK:         %[[VAL_1689:.*]] = getelementptr inbounds float, float* %[[VAL_1688]], i32 %[[VAL_1653]]
 // CHECK:         %[[VAL_1690:.*]] = load float, float* %[[VAL_1689]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1691:.*]] = fcmp oeq float %[[VAL_1687]], %[[VAL_1690]]
 // CHECK:         %[[VAL_1692:.*]] = zext i1 %[[VAL_1691]] to i8
-// CHECK:         %[[VAL_1693:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1633]] to i8*
+// CHECK:         %[[VAL_1693:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1639]] to i8*
 // CHECK:         %[[VAL_1694:.*]] = getelementptr inbounds i8, i8* %[[VAL_1693]], i32 %[[VAL_1653]]
 // CHECK:         store i8 %[[VAL_1692]], i8* %[[VAL_1694]], align 1
-// CHECK:         %[[VAL_1695:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
+// CHECK:         %[[VAL_1695:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1633]] to float*
 // CHECK:         %[[VAL_1696:.*]] = getelementptr inbounds float, float* %[[VAL_1695]], i32 %[[VAL_1657]]
 // CHECK:         %[[VAL_1697:.*]] = load float, float* %[[VAL_1696]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1698:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1639]] to float*
+// CHECK:         %[[VAL_1698:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1636]] to float*
 // CHECK:         %[[VAL_1699:.*]] = getelementptr inbounds float, float* %[[VAL_1698]], i32 %[[VAL_1657]]
 // CHECK:         %[[VAL_1700:.*]] = load float, float* %[[VAL_1699]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1701:.*]] = fcmp oeq float %[[VAL_1697]], %[[VAL_1700]]
 // CHECK:         %[[VAL_1702:.*]] = zext i1 %[[VAL_1701]] to i8
-// CHECK:         %[[VAL_1703:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1633]] to i8*
+// CHECK:         %[[VAL_1703:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_1639]] to i8*
 // CHECK:         %[[VAL_1704:.*]] = getelementptr inbounds i8, i8* %[[VAL_1703]], i32 %[[VAL_1657]]
 // CHECK:         store i8 %[[VAL_1702]], i8* %[[VAL_1704]], align 1
 // CHECK:         br label %[[VAL_1663]]
 // CHECK:       entry:
 // CHECK:         %[[VAL_1705:.*]] = getelementptr inbounds i8, i8* %[[VAL_1706:.*]], i64 0
-// CHECK:         %[[VAL_1707:.*]] = bitcast i8* %[[VAL_1705]] to [100 x [200 x %[[VAL_1708:.*]]]]*
-// CHECK:         %[[VAL_1709:.*]] = getelementptr inbounds i8, i8* %[[VAL_1710:.*]], i64 0
-// CHECK:         %[[VAL_1711:.*]] = bitcast i8* %[[VAL_1709]] to [100 x [200 x float]]*
-// CHECK:         %[[VAL_1712:.*]] = getelementptr inbounds i8, i8* %[[VAL_1713:.*]], i64 0
-// CHECK:         %[[VAL_1714:.*]] = bitcast i8* %[[VAL_1712]] to [100 x [200 x float]]*
+// CHECK:         %[[VAL_1707:.*]] = bitcast i8* %[[VAL_1705]] to [100 x [200 x float]]*
+// CHECK:         %[[VAL_1708:.*]] = getelementptr inbounds i8, i8* %[[VAL_1709:.*]], i64 0
+// CHECK:         %[[VAL_1710:.*]] = bitcast i8* %[[VAL_1708]] to [100 x [200 x float]]*
+// CHECK:         %[[VAL_1711:.*]] = getelementptr inbounds i8, i8* %[[VAL_1712:.*]], i64 0
+// CHECK:         %[[VAL_1713:.*]] = bitcast i8* %[[VAL_1711]] to [100 x [200 x %[[VAL_1714:.*]]]]*
 // CHECK:         %[[VAL_1715:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1716:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1717:.*]] = mul nuw nsw i32 %[[VAL_1715]], 128
+// CHECK:         %[[VAL_1717:.*]] = mul nuw nsw i32 %[[VAL_1715]], 256
 // CHECK:         %[[VAL_1718:.*]] = add nuw nsw i32 %[[VAL_1717]], %[[VAL_1716]]
-// CHECK:         %[[VAL_1719:.*]] = icmp ult i32 %[[VAL_1718]], 163840
+// CHECK:         %[[VAL_1719:.*]] = icmp ult i32 %[[VAL_1718]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1719]])
 // CHECK:         %[[VAL_1720:.*]] = mul nuw nsw i32 %[[VAL_1718]], 4
 // CHECK:         %[[VAL_1721:.*]] = udiv i32 %[[VAL_1720]], 1
@@ -1893,50 +1893,50 @@
 // CHECK:       r30.in_bounds-after:                              ; preds = %[[VAL_1737]], %[[VAL_1739:.*]]
 // CHECK:         ret void
 // CHECK:       r30.in_bounds-true:                               ; preds = %[[VAL_1739]]
-// CHECK:         %[[VAL_1740:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1711]] to float*
+// CHECK:         %[[VAL_1740:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1707]] to float*
 // CHECK:         %[[VAL_1741:.*]] = getelementptr inbounds float, float* %[[VAL_1740]], i32 %[[VAL_1720]]
 // CHECK:         %[[VAL_1742:.*]] = load float, float* %[[VAL_1741]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1743:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1714]] to float*
+// CHECK:         %[[VAL_1743:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1710]] to float*
 // CHECK:         %[[VAL_1744:.*]] = getelementptr inbounds float, float* %[[VAL_1743]], i32 %[[VAL_1720]]
 // CHECK:         %[[VAL_1745:.*]] = load float, float* %[[VAL_1744]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1746:.*]] = insertvalue %[[VAL_1708]] zeroinitializer, float %[[VAL_1742]], 0
-// CHECK:         %[[VAL_1747:.*]] = insertvalue %[[VAL_1708]] %[[VAL_1746]], float %[[VAL_1745]], 1
-// CHECK:         %[[VAL_1748:.*]] = bitcast [100 x [200 x %[[VAL_1708]]]]* %[[VAL_1707]] to %[[VAL_1708]]*
-// CHECK:         %[[VAL_1749:.*]] = getelementptr inbounds %[[VAL_1708]], %[[VAL_1708]]* %[[VAL_1748]], i32 %[[VAL_1720]]
-// CHECK:         store %[[VAL_1708]] %[[VAL_1747]], %[[VAL_1708]]* %[[VAL_1749]], align 1
-// CHECK:         %[[VAL_1750:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1711]] to float*
+// CHECK:         %[[VAL_1746:.*]] = insertvalue %[[VAL_1714]] zeroinitializer, float %[[VAL_1742]], 0
+// CHECK:         %[[VAL_1747:.*]] = insertvalue %[[VAL_1714]] %[[VAL_1746]], float %[[VAL_1745]], 1
+// CHECK:         %[[VAL_1748:.*]] = bitcast [100 x [200 x %[[VAL_1714]]]]* %[[VAL_1713]] to %[[VAL_1714]]*
+// CHECK:         %[[VAL_1749:.*]] = getelementptr inbounds %[[VAL_1714]], %[[VAL_1714]]* %[[VAL_1748]], i32 %[[VAL_1720]]
+// CHECK:         store %[[VAL_1714]] %[[VAL_1747]], %[[VAL_1714]]* %[[VAL_1749]], align 1
+// CHECK:         %[[VAL_1750:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1707]] to float*
 // CHECK:         %[[VAL_1751:.*]] = getelementptr inbounds float, float* %[[VAL_1750]], i32 %[[VAL_1724]]
 // CHECK:         %[[VAL_1752:.*]] = load float, float* %[[VAL_1751]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1753:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1714]] to float*
+// CHECK:         %[[VAL_1753:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1710]] to float*
 // CHECK:         %[[VAL_1754:.*]] = getelementptr inbounds float, float* %[[VAL_1753]], i32 %[[VAL_1724]]
 // CHECK:         %[[VAL_1755:.*]] = load float, float* %[[VAL_1754]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1756:.*]] = insertvalue %[[VAL_1708]] zeroinitializer, float %[[VAL_1752]], 0
-// CHECK:         %[[VAL_1757:.*]] = insertvalue %[[VAL_1708]] %[[VAL_1756]], float %[[VAL_1755]], 1
-// CHECK:         %[[VAL_1758:.*]] = bitcast [100 x [200 x %[[VAL_1708]]]]* %[[VAL_1707]] to %[[VAL_1708]]*
-// CHECK:         %[[VAL_1759:.*]] = getelementptr inbounds %[[VAL_1708]], %[[VAL_1708]]* %[[VAL_1758]], i32 %[[VAL_1724]]
-// CHECK:         store %[[VAL_1708]] %[[VAL_1757]], %[[VAL_1708]]* %[[VAL_1759]], align 1
-// CHECK:         %[[VAL_1760:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1711]] to float*
+// CHECK:         %[[VAL_1756:.*]] = insertvalue %[[VAL_1714]] zeroinitializer, float %[[VAL_1752]], 0
+// CHECK:         %[[VAL_1757:.*]] = insertvalue %[[VAL_1714]] %[[VAL_1756]], float %[[VAL_1755]], 1
+// CHECK:         %[[VAL_1758:.*]] = bitcast [100 x [200 x %[[VAL_1714]]]]* %[[VAL_1713]] to %[[VAL_1714]]*
+// CHECK:         %[[VAL_1759:.*]] = getelementptr inbounds %[[VAL_1714]], %[[VAL_1714]]* %[[VAL_1758]], i32 %[[VAL_1724]]
+// CHECK:         store %[[VAL_1714]] %[[VAL_1757]], %[[VAL_1714]]* %[[VAL_1759]], align 1
+// CHECK:         %[[VAL_1760:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1707]] to float*
 // CHECK:         %[[VAL_1761:.*]] = getelementptr inbounds float, float* %[[VAL_1760]], i32 %[[VAL_1728]]
 // CHECK:         %[[VAL_1762:.*]] = load float, float* %[[VAL_1761]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1763:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1714]] to float*
+// CHECK:         %[[VAL_1763:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1710]] to float*
 // CHECK:         %[[VAL_1764:.*]] = getelementptr inbounds float, float* %[[VAL_1763]], i32 %[[VAL_1728]]
 // CHECK:         %[[VAL_1765:.*]] = load float, float* %[[VAL_1764]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1766:.*]] = insertvalue %[[VAL_1708]] zeroinitializer, float %[[VAL_1762]], 0
-// CHECK:         %[[VAL_1767:.*]] = insertvalue %[[VAL_1708]] %[[VAL_1766]], float %[[VAL_1765]], 1
-// CHECK:         %[[VAL_1768:.*]] = bitcast [100 x [200 x %[[VAL_1708]]]]* %[[VAL_1707]] to %[[VAL_1708]]*
-// CHECK:         %[[VAL_1769:.*]] = getelementptr inbounds %[[VAL_1708]], %[[VAL_1708]]* %[[VAL_1768]], i32 %[[VAL_1728]]
-// CHECK:         store %[[VAL_1708]] %[[VAL_1767]], %[[VAL_1708]]* %[[VAL_1769]], align 1
-// CHECK:         %[[VAL_1770:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1711]] to float*
+// CHECK:         %[[VAL_1766:.*]] = insertvalue %[[VAL_1714]] zeroinitializer, float %[[VAL_1762]], 0
+// CHECK:         %[[VAL_1767:.*]] = insertvalue %[[VAL_1714]] %[[VAL_1766]], float %[[VAL_1765]], 1
+// CHECK:         %[[VAL_1768:.*]] = bitcast [100 x [200 x %[[VAL_1714]]]]* %[[VAL_1713]] to %[[VAL_1714]]*
+// CHECK:         %[[VAL_1769:.*]] = getelementptr inbounds %[[VAL_1714]], %[[VAL_1714]]* %[[VAL_1768]], i32 %[[VAL_1728]]
+// CHECK:         store %[[VAL_1714]] %[[VAL_1767]], %[[VAL_1714]]* %[[VAL_1769]], align 1
+// CHECK:         %[[VAL_1770:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1707]] to float*
 // CHECK:         %[[VAL_1771:.*]] = getelementptr inbounds float, float* %[[VAL_1770]], i32 %[[VAL_1732]]
 // CHECK:         %[[VAL_1772:.*]] = load float, float* %[[VAL_1771]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1773:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1714]] to float*
+// CHECK:         %[[VAL_1773:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1710]] to float*
 // CHECK:         %[[VAL_1774:.*]] = getelementptr inbounds float, float* %[[VAL_1773]], i32 %[[VAL_1732]]
 // CHECK:         %[[VAL_1775:.*]] = load float, float* %[[VAL_1774]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1776:.*]] = insertvalue %[[VAL_1708]] zeroinitializer, float %[[VAL_1772]], 0
-// CHECK:         %[[VAL_1777:.*]] = insertvalue %[[VAL_1708]] %[[VAL_1776]], float %[[VAL_1775]], 1
-// CHECK:         %[[VAL_1778:.*]] = bitcast [100 x [200 x %[[VAL_1708]]]]* %[[VAL_1707]] to %[[VAL_1708]]*
-// CHECK:         %[[VAL_1779:.*]] = getelementptr inbounds %[[VAL_1708]], %[[VAL_1708]]* %[[VAL_1778]], i32 %[[VAL_1732]]
-// CHECK:         store %[[VAL_1708]] %[[VAL_1777]], %[[VAL_1708]]* %[[VAL_1779]], align 1
+// CHECK:         %[[VAL_1776:.*]] = insertvalue %[[VAL_1714]] zeroinitializer, float %[[VAL_1772]], 0
+// CHECK:         %[[VAL_1777:.*]] = insertvalue %[[VAL_1714]] %[[VAL_1776]], float %[[VAL_1775]], 1
+// CHECK:         %[[VAL_1778:.*]] = bitcast [100 x [200 x %[[VAL_1714]]]]* %[[VAL_1713]] to %[[VAL_1714]]*
+// CHECK:         %[[VAL_1779:.*]] = getelementptr inbounds %[[VAL_1714]], %[[VAL_1714]]* %[[VAL_1778]], i32 %[[VAL_1732]]
+// CHECK:         store %[[VAL_1714]] %[[VAL_1777]], %[[VAL_1714]]* %[[VAL_1779]], align 1
 // CHECK:         br label %[[VAL_1738]]
 // CHECK:       entry:
 // CHECK:         %[[VAL_1780:.*]] = getelementptr inbounds i8, i8* %[[VAL_1781:.*]], i64 0
@@ -1947,9 +1947,9 @@
 // CHECK:         %[[VAL_1788:.*]] = bitcast i8* %[[VAL_1786]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1789:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1790:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1791:.*]] = mul nuw nsw i32 %[[VAL_1789]], 128
+// CHECK:         %[[VAL_1791:.*]] = mul nuw nsw i32 %[[VAL_1789]], 256
 // CHECK:         %[[VAL_1792:.*]] = add nuw nsw i32 %[[VAL_1791]], %[[VAL_1790]]
-// CHECK:         %[[VAL_1793:.*]] = icmp ult i32 %[[VAL_1792]], 163840
+// CHECK:         %[[VAL_1793:.*]] = icmp ult i32 %[[VAL_1792]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1793]])
 // CHECK:         %[[VAL_1794:.*]] = mul nuw nsw i32 %[[VAL_1792]], 4
 // CHECK:         %[[VAL_1795:.*]] = udiv i32 %[[VAL_1794]], 1
@@ -1972,44 +1972,44 @@
 // CHECK:       r31.in_bounds-after:                              ; preds = %[[VAL_1811]], %[[VAL_1813:.*]]
 // CHECK:         ret void
 // CHECK:       r31.in_bounds-true:                               ; preds = %[[VAL_1813]]
-// CHECK:         %[[VAL_1814:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
+// CHECK:         %[[VAL_1814:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
 // CHECK:         %[[VAL_1815:.*]] = getelementptr inbounds float, float* %[[VAL_1814]], i32 %[[VAL_1794]]
 // CHECK:         %[[VAL_1816:.*]] = load float, float* %[[VAL_1815]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1817:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
+// CHECK:         %[[VAL_1817:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
 // CHECK:         %[[VAL_1818:.*]] = getelementptr inbounds float, float* %[[VAL_1817]], i32 %[[VAL_1794]]
 // CHECK:         %[[VAL_1819:.*]] = load float, float* %[[VAL_1818]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1820:.*]] = fdiv float %[[VAL_1816]], %[[VAL_1819]]
-// CHECK:         %[[VAL_1821:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
+// CHECK:         %[[VAL_1821:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
 // CHECK:         %[[VAL_1822:.*]] = getelementptr inbounds float, float* %[[VAL_1821]], i32 %[[VAL_1794]]
 // CHECK:         store float %[[VAL_1820]], float* %[[VAL_1822]], align 4
-// CHECK:         %[[VAL_1823:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
+// CHECK:         %[[VAL_1823:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
 // CHECK:         %[[VAL_1824:.*]] = getelementptr inbounds float, float* %[[VAL_1823]], i32 %[[VAL_1798]]
 // CHECK:         %[[VAL_1825:.*]] = load float, float* %[[VAL_1824]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1826:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
+// CHECK:         %[[VAL_1826:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
 // CHECK:         %[[VAL_1827:.*]] = getelementptr inbounds float, float* %[[VAL_1826]], i32 %[[VAL_1798]]
 // CHECK:         %[[VAL_1828:.*]] = load float, float* %[[VAL_1827]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1829:.*]] = fdiv float %[[VAL_1825]], %[[VAL_1828]]
-// CHECK:         %[[VAL_1830:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
+// CHECK:         %[[VAL_1830:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
 // CHECK:         %[[VAL_1831:.*]] = getelementptr inbounds float, float* %[[VAL_1830]], i32 %[[VAL_1798]]
 // CHECK:         store float %[[VAL_1829]], float* %[[VAL_1831]], align 4
-// CHECK:         %[[VAL_1832:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
+// CHECK:         %[[VAL_1832:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
 // CHECK:         %[[VAL_1833:.*]] = getelementptr inbounds float, float* %[[VAL_1832]], i32 %[[VAL_1802]]
 // CHECK:         %[[VAL_1834:.*]] = load float, float* %[[VAL_1833]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1835:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
+// CHECK:         %[[VAL_1835:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
 // CHECK:         %[[VAL_1836:.*]] = getelementptr inbounds float, float* %[[VAL_1835]], i32 %[[VAL_1802]]
 // CHECK:         %[[VAL_1837:.*]] = load float, float* %[[VAL_1836]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1838:.*]] = fdiv float %[[VAL_1834]], %[[VAL_1837]]
-// CHECK:         %[[VAL_1839:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
+// CHECK:         %[[VAL_1839:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
 // CHECK:         %[[VAL_1840:.*]] = getelementptr inbounds float, float* %[[VAL_1839]], i32 %[[VAL_1802]]
 // CHECK:         store float %[[VAL_1838]], float* %[[VAL_1840]], align 4
-// CHECK:         %[[VAL_1841:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
+// CHECK:         %[[VAL_1841:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
 // CHECK:         %[[VAL_1842:.*]] = getelementptr inbounds float, float* %[[VAL_1841]], i32 %[[VAL_1806]]
 // CHECK:         %[[VAL_1843:.*]] = load float, float* %[[VAL_1842]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1844:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
+// CHECK:         %[[VAL_1844:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1785]] to float*
 // CHECK:         %[[VAL_1845:.*]] = getelementptr inbounds float, float* %[[VAL_1844]], i32 %[[VAL_1806]]
 // CHECK:         %[[VAL_1846:.*]] = load float, float* %[[VAL_1845]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1847:.*]] = fdiv float %[[VAL_1843]], %[[VAL_1846]]
-// CHECK:         %[[VAL_1848:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1782]] to float*
+// CHECK:         %[[VAL_1848:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1788]] to float*
 // CHECK:         %[[VAL_1849:.*]] = getelementptr inbounds float, float* %[[VAL_1848]], i32 %[[VAL_1806]]
 // CHECK:         store float %[[VAL_1847]], float* %[[VAL_1849]], align 4
 // CHECK:         br label %[[VAL_1812]]
@@ -2022,9 +2022,9 @@
 // CHECK:         %[[VAL_1858:.*]] = bitcast i8* %[[VAL_1856]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1859:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1860:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1861:.*]] = mul nuw nsw i32 %[[VAL_1859]], 128
+// CHECK:         %[[VAL_1861:.*]] = mul nuw nsw i32 %[[VAL_1859]], 256
 // CHECK:         %[[VAL_1862:.*]] = add nuw nsw i32 %[[VAL_1861]], %[[VAL_1860]]
-// CHECK:         %[[VAL_1863:.*]] = icmp ult i32 %[[VAL_1862]], 163840
+// CHECK:         %[[VAL_1863:.*]] = icmp ult i32 %[[VAL_1862]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1863]])
 // CHECK:         %[[VAL_1864:.*]] = mul nuw nsw i32 %[[VAL_1862]], 4
 // CHECK:         %[[VAL_1865:.*]] = udiv i32 %[[VAL_1864]], 1
@@ -2047,44 +2047,44 @@
 // CHECK:       r32.in_bounds-after:                              ; preds = %[[VAL_1881]], %[[VAL_1883:.*]]
 // CHECK:         ret void
 // CHECK:       r32.in_bounds-true:                               ; preds = %[[VAL_1883]]
-// CHECK:         %[[VAL_1884:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
+// CHECK:         %[[VAL_1884:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
 // CHECK:         %[[VAL_1885:.*]] = getelementptr inbounds float, float* %[[VAL_1884]], i32 %[[VAL_1864]]
 // CHECK:         %[[VAL_1886:.*]] = load float, float* %[[VAL_1885]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1887:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
+// CHECK:         %[[VAL_1887:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
 // CHECK:         %[[VAL_1888:.*]] = getelementptr inbounds float, float* %[[VAL_1887]], i32 %[[VAL_1864]]
 // CHECK:         %[[VAL_1889:.*]] = load float, float* %[[VAL_1888]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1890:.*]] = call float @llvm.maxnum.f32(float %[[VAL_1886]], float %[[VAL_1889]])
-// CHECK:         %[[VAL_1891:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
+// CHECK:         %[[VAL_1891:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
 // CHECK:         %[[VAL_1892:.*]] = getelementptr inbounds float, float* %[[VAL_1891]], i32 %[[VAL_1864]]
 // CHECK:         store float %[[VAL_1890]], float* %[[VAL_1892]], align 4
-// CHECK:         %[[VAL_1893:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
+// CHECK:         %[[VAL_1893:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
 // CHECK:         %[[VAL_1894:.*]] = getelementptr inbounds float, float* %[[VAL_1893]], i32 %[[VAL_1868]]
 // CHECK:         %[[VAL_1895:.*]] = load float, float* %[[VAL_1894]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1896:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
+// CHECK:         %[[VAL_1896:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
 // CHECK:         %[[VAL_1897:.*]] = getelementptr inbounds float, float* %[[VAL_1896]], i32 %[[VAL_1868]]
 // CHECK:         %[[VAL_1898:.*]] = load float, float* %[[VAL_1897]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1899:.*]] = call float @llvm.maxnum.f32(float %[[VAL_1895]], float %[[VAL_1898]])
-// CHECK:         %[[VAL_1900:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
+// CHECK:         %[[VAL_1900:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
 // CHECK:         %[[VAL_1901:.*]] = getelementptr inbounds float, float* %[[VAL_1900]], i32 %[[VAL_1868]]
 // CHECK:         store float %[[VAL_1899]], float* %[[VAL_1901]], align 4
-// CHECK:         %[[VAL_1902:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
+// CHECK:         %[[VAL_1902:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
 // CHECK:         %[[VAL_1903:.*]] = getelementptr inbounds float, float* %[[VAL_1902]], i32 %[[VAL_1872]]
 // CHECK:         %[[VAL_1904:.*]] = load float, float* %[[VAL_1903]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1905:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
+// CHECK:         %[[VAL_1905:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
 // CHECK:         %[[VAL_1906:.*]] = getelementptr inbounds float, float* %[[VAL_1905]], i32 %[[VAL_1872]]
 // CHECK:         %[[VAL_1907:.*]] = load float, float* %[[VAL_1906]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1908:.*]] = call float @llvm.maxnum.f32(float %[[VAL_1904]], float %[[VAL_1907]])
-// CHECK:         %[[VAL_1909:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
+// CHECK:         %[[VAL_1909:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
 // CHECK:         %[[VAL_1910:.*]] = getelementptr inbounds float, float* %[[VAL_1909]], i32 %[[VAL_1872]]
 // CHECK:         store float %[[VAL_1908]], float* %[[VAL_1910]], align 4
-// CHECK:         %[[VAL_1911:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
+// CHECK:         %[[VAL_1911:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
 // CHECK:         %[[VAL_1912:.*]] = getelementptr inbounds float, float* %[[VAL_1911]], i32 %[[VAL_1876]]
 // CHECK:         %[[VAL_1913:.*]] = load float, float* %[[VAL_1912]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1914:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
+// CHECK:         %[[VAL_1914:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1855]] to float*
 // CHECK:         %[[VAL_1915:.*]] = getelementptr inbounds float, float* %[[VAL_1914]], i32 %[[VAL_1876]]
 // CHECK:         %[[VAL_1916:.*]] = load float, float* %[[VAL_1915]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1917:.*]] = call float @llvm.maxnum.f32(float %[[VAL_1913]], float %[[VAL_1916]])
-// CHECK:         %[[VAL_1918:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1852]] to float*
+// CHECK:         %[[VAL_1918:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1858]] to float*
 // CHECK:         %[[VAL_1919:.*]] = getelementptr inbounds float, float* %[[VAL_1918]], i32 %[[VAL_1876]]
 // CHECK:         store float %[[VAL_1917]], float* %[[VAL_1919]], align 4
 // CHECK:         br label %[[VAL_1882]]
@@ -2097,9 +2097,9 @@
 // CHECK:         %[[VAL_1928:.*]] = bitcast i8* %[[VAL_1926]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1929:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_1930:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_1931:.*]] = mul nuw nsw i32 %[[VAL_1929]], 128
+// CHECK:         %[[VAL_1931:.*]] = mul nuw nsw i32 %[[VAL_1929]], 256
 // CHECK:         %[[VAL_1932:.*]] = add nuw nsw i32 %[[VAL_1931]], %[[VAL_1930]]
-// CHECK:         %[[VAL_1933:.*]] = icmp ult i32 %[[VAL_1932]], 163840
+// CHECK:         %[[VAL_1933:.*]] = icmp ult i32 %[[VAL_1932]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_1933]])
 // CHECK:         %[[VAL_1934:.*]] = mul nuw nsw i32 %[[VAL_1932]], 4
 // CHECK:         %[[VAL_1935:.*]] = udiv i32 %[[VAL_1934]], 1
@@ -2122,44 +2122,44 @@
 // CHECK:       r33.in_bounds-after:                              ; preds = %[[VAL_1951]], %[[VAL_1953:.*]]
 // CHECK:         ret void
 // CHECK:       r33.in_bounds-true:                               ; preds = %[[VAL_1953]]
-// CHECK:         %[[VAL_1954:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
+// CHECK:         %[[VAL_1954:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
 // CHECK:         %[[VAL_1955:.*]] = getelementptr inbounds float, float* %[[VAL_1954]], i32 %[[VAL_1934]]
 // CHECK:         %[[VAL_1956:.*]] = load float, float* %[[VAL_1955]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1957:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
+// CHECK:         %[[VAL_1957:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
 // CHECK:         %[[VAL_1958:.*]] = getelementptr inbounds float, float* %[[VAL_1957]], i32 %[[VAL_1934]]
 // CHECK:         %[[VAL_1959:.*]] = load float, float* %[[VAL_1958]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1960:.*]] = call float @llvm.minnum.f32(float %[[VAL_1956]], float %[[VAL_1959]])
-// CHECK:         %[[VAL_1961:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
+// CHECK:         %[[VAL_1961:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
 // CHECK:         %[[VAL_1962:.*]] = getelementptr inbounds float, float* %[[VAL_1961]], i32 %[[VAL_1934]]
 // CHECK:         store float %[[VAL_1960]], float* %[[VAL_1962]], align 4
-// CHECK:         %[[VAL_1963:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
+// CHECK:         %[[VAL_1963:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
 // CHECK:         %[[VAL_1964:.*]] = getelementptr inbounds float, float* %[[VAL_1963]], i32 %[[VAL_1938]]
 // CHECK:         %[[VAL_1965:.*]] = load float, float* %[[VAL_1964]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1966:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
+// CHECK:         %[[VAL_1966:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
 // CHECK:         %[[VAL_1967:.*]] = getelementptr inbounds float, float* %[[VAL_1966]], i32 %[[VAL_1938]]
 // CHECK:         %[[VAL_1968:.*]] = load float, float* %[[VAL_1967]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1969:.*]] = call float @llvm.minnum.f32(float %[[VAL_1965]], float %[[VAL_1968]])
-// CHECK:         %[[VAL_1970:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
+// CHECK:         %[[VAL_1970:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
 // CHECK:         %[[VAL_1971:.*]] = getelementptr inbounds float, float* %[[VAL_1970]], i32 %[[VAL_1938]]
 // CHECK:         store float %[[VAL_1969]], float* %[[VAL_1971]], align 4
-// CHECK:         %[[VAL_1972:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
+// CHECK:         %[[VAL_1972:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
 // CHECK:         %[[VAL_1973:.*]] = getelementptr inbounds float, float* %[[VAL_1972]], i32 %[[VAL_1942]]
 // CHECK:         %[[VAL_1974:.*]] = load float, float* %[[VAL_1973]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1975:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
+// CHECK:         %[[VAL_1975:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
 // CHECK:         %[[VAL_1976:.*]] = getelementptr inbounds float, float* %[[VAL_1975]], i32 %[[VAL_1942]]
 // CHECK:         %[[VAL_1977:.*]] = load float, float* %[[VAL_1976]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1978:.*]] = call float @llvm.minnum.f32(float %[[VAL_1974]], float %[[VAL_1977]])
-// CHECK:         %[[VAL_1979:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
+// CHECK:         %[[VAL_1979:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
 // CHECK:         %[[VAL_1980:.*]] = getelementptr inbounds float, float* %[[VAL_1979]], i32 %[[VAL_1942]]
 // CHECK:         store float %[[VAL_1978]], float* %[[VAL_1980]], align 4
-// CHECK:         %[[VAL_1981:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
+// CHECK:         %[[VAL_1981:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
 // CHECK:         %[[VAL_1982:.*]] = getelementptr inbounds float, float* %[[VAL_1981]], i32 %[[VAL_1946]]
 // CHECK:         %[[VAL_1983:.*]] = load float, float* %[[VAL_1982]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_1984:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
+// CHECK:         %[[VAL_1984:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1925]] to float*
 // CHECK:         %[[VAL_1985:.*]] = getelementptr inbounds float, float* %[[VAL_1984]], i32 %[[VAL_1946]]
 // CHECK:         %[[VAL_1986:.*]] = load float, float* %[[VAL_1985]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_1987:.*]] = call float @llvm.minnum.f32(float %[[VAL_1983]], float %[[VAL_1986]])
-// CHECK:         %[[VAL_1988:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1922]] to float*
+// CHECK:         %[[VAL_1988:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1928]] to float*
 // CHECK:         %[[VAL_1989:.*]] = getelementptr inbounds float, float* %[[VAL_1988]], i32 %[[VAL_1946]]
 // CHECK:         store float %[[VAL_1987]], float* %[[VAL_1989]], align 4
 // CHECK:         br label %[[VAL_1952]]
@@ -2172,9 +2172,9 @@
 // CHECK:         %[[VAL_1998:.*]] = bitcast i8* %[[VAL_1996]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_1999:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2000:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2001:.*]] = mul nuw nsw i32 %[[VAL_1999]], 128
+// CHECK:         %[[VAL_2001:.*]] = mul nuw nsw i32 %[[VAL_1999]], 256
 // CHECK:         %[[VAL_2002:.*]] = add nuw nsw i32 %[[VAL_2001]], %[[VAL_2000]]
-// CHECK:         %[[VAL_2003:.*]] = icmp ult i32 %[[VAL_2002]], 163840
+// CHECK:         %[[VAL_2003:.*]] = icmp ult i32 %[[VAL_2002]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2003]])
 // CHECK:         %[[VAL_2004:.*]] = mul nuw nsw i32 %[[VAL_2002]], 4
 // CHECK:         %[[VAL_2005:.*]] = udiv i32 %[[VAL_2004]], 1
@@ -2197,44 +2197,44 @@
 // CHECK:       r34.in_bounds-after:                              ; preds = %[[VAL_2021]], %[[VAL_2023:.*]]
 // CHECK:         ret void
 // CHECK:       r34.in_bounds-true:                               ; preds = %[[VAL_2023]]
-// CHECK:         %[[VAL_2024:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
+// CHECK:         %[[VAL_2024:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
 // CHECK:         %[[VAL_2025:.*]] = getelementptr inbounds float, float* %[[VAL_2024]], i32 %[[VAL_2004]]
 // CHECK:         %[[VAL_2026:.*]] = load float, float* %[[VAL_2025]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2027:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
+// CHECK:         %[[VAL_2027:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
 // CHECK:         %[[VAL_2028:.*]] = getelementptr inbounds float, float* %[[VAL_2027]], i32 %[[VAL_2004]]
 // CHECK:         %[[VAL_2029:.*]] = load float, float* %[[VAL_2028]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2030:.*]] = fmul float %[[VAL_2026]], %[[VAL_2029]]
-// CHECK:         %[[VAL_2031:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
+// CHECK:         %[[VAL_2031:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
 // CHECK:         %[[VAL_2032:.*]] = getelementptr inbounds float, float* %[[VAL_2031]], i32 %[[VAL_2004]]
 // CHECK:         store float %[[VAL_2030]], float* %[[VAL_2032]], align 4
-// CHECK:         %[[VAL_2033:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
+// CHECK:         %[[VAL_2033:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
 // CHECK:         %[[VAL_2034:.*]] = getelementptr inbounds float, float* %[[VAL_2033]], i32 %[[VAL_2008]]
 // CHECK:         %[[VAL_2035:.*]] = load float, float* %[[VAL_2034]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2036:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
+// CHECK:         %[[VAL_2036:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
 // CHECK:         %[[VAL_2037:.*]] = getelementptr inbounds float, float* %[[VAL_2036]], i32 %[[VAL_2008]]
 // CHECK:         %[[VAL_2038:.*]] = load float, float* %[[VAL_2037]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2039:.*]] = fmul float %[[VAL_2035]], %[[VAL_2038]]
-// CHECK:         %[[VAL_2040:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
+// CHECK:         %[[VAL_2040:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
 // CHECK:         %[[VAL_2041:.*]] = getelementptr inbounds float, float* %[[VAL_2040]], i32 %[[VAL_2008]]
 // CHECK:         store float %[[VAL_2039]], float* %[[VAL_2041]], align 4
-// CHECK:         %[[VAL_2042:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
+// CHECK:         %[[VAL_2042:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
 // CHECK:         %[[VAL_2043:.*]] = getelementptr inbounds float, float* %[[VAL_2042]], i32 %[[VAL_2012]]
 // CHECK:         %[[VAL_2044:.*]] = load float, float* %[[VAL_2043]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2045:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
+// CHECK:         %[[VAL_2045:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
 // CHECK:         %[[VAL_2046:.*]] = getelementptr inbounds float, float* %[[VAL_2045]], i32 %[[VAL_2012]]
 // CHECK:         %[[VAL_2047:.*]] = load float, float* %[[VAL_2046]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2048:.*]] = fmul float %[[VAL_2044]], %[[VAL_2047]]
-// CHECK:         %[[VAL_2049:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
+// CHECK:         %[[VAL_2049:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
 // CHECK:         %[[VAL_2050:.*]] = getelementptr inbounds float, float* %[[VAL_2049]], i32 %[[VAL_2012]]
 // CHECK:         store float %[[VAL_2048]], float* %[[VAL_2050]], align 4
-// CHECK:         %[[VAL_2051:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
+// CHECK:         %[[VAL_2051:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
 // CHECK:         %[[VAL_2052:.*]] = getelementptr inbounds float, float* %[[VAL_2051]], i32 %[[VAL_2016]]
 // CHECK:         %[[VAL_2053:.*]] = load float, float* %[[VAL_2052]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2054:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
+// CHECK:         %[[VAL_2054:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1995]] to float*
 // CHECK:         %[[VAL_2055:.*]] = getelementptr inbounds float, float* %[[VAL_2054]], i32 %[[VAL_2016]]
 // CHECK:         %[[VAL_2056:.*]] = load float, float* %[[VAL_2055]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2057:.*]] = fmul float %[[VAL_2053]], %[[VAL_2056]]
-// CHECK:         %[[VAL_2058:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1992]] to float*
+// CHECK:         %[[VAL_2058:.*]] = bitcast [100 x [200 x float]]* %[[VAL_1998]] to float*
 // CHECK:         %[[VAL_2059:.*]] = getelementptr inbounds float, float* %[[VAL_2058]], i32 %[[VAL_2016]]
 // CHECK:         store float %[[VAL_2057]], float* %[[VAL_2059]], align 4
 // CHECK:         br label %[[VAL_2022]]
@@ -2246,10 +2246,10 @@
 // CHECK:         %[[VAL_2066:.*]] = getelementptr inbounds i8, i8* %[[VAL_2067:.*]], i64 0
 // CHECK:         %[[VAL_2068:.*]] = bitcast i8* %[[VAL_2066]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_2069:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
-// CHECK:         %[[VAL_2070:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2071:.*]] = mul nuw nsw i32 %[[VAL_2069]], 128
+// CHECK:         %[[VAL_2070:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !93
+// CHECK:         %[[VAL_2071:.*]] = mul nuw nsw i32 %[[VAL_2069]], 1024
 // CHECK:         %[[VAL_2072:.*]] = add nuw nsw i32 %[[VAL_2071]], %[[VAL_2070]]
-// CHECK:         %[[VAL_2073:.*]] = icmp ult i32 %[[VAL_2072]], 163840
+// CHECK:         %[[VAL_2073:.*]] = icmp ult i32 %[[VAL_2072]], 20480
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2073]])
 // CHECK:         %[[VAL_2074:.*]] = udiv i32 %[[VAL_2072]], 1
 // CHECK:         %[[VAL_2075:.*]] = urem i32 %[[VAL_2074]], 200
@@ -2259,14 +2259,14 @@
 // CHECK:       r35.in_bounds-after:                              ; preds = %[[VAL_2078]], %[[VAL_2080:.*]]
 // CHECK:         ret void
 // CHECK:       r35.in_bounds-true:                               ; preds = %[[VAL_2080]]
-// CHECK:         %[[VAL_2081:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2065]] to float*
+// CHECK:         %[[VAL_2081:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2062]] to float*
 // CHECK:         %[[VAL_2082:.*]] = getelementptr inbounds float, float* %[[VAL_2081]], i32 %[[VAL_2072]]
 // CHECK:         %[[VAL_2083:.*]] = load float, float* %[[VAL_2082]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2084:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2068]] to float*
+// CHECK:         %[[VAL_2084:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2065]] to float*
 // CHECK:         %[[VAL_2085:.*]] = getelementptr inbounds float, float* %[[VAL_2084]], i32 %[[VAL_2072]]
 // CHECK:         %[[VAL_2086:.*]] = load float, float* %[[VAL_2085]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2087:.*]] = call float @__nv_powf(float %[[VAL_2083]], float %[[VAL_2086]])
-// CHECK:         %[[VAL_2088:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2062]] to float*
+// CHECK:         %[[VAL_2088:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2068]] to float*
 // CHECK:         %[[VAL_2089:.*]] = getelementptr inbounds float, float* %[[VAL_2088]], i32 %[[VAL_2072]]
 // CHECK:         store float %[[VAL_2087]], float* %[[VAL_2089]], align 4
 // CHECK:         br label %[[VAL_2079]]
@@ -2279,9 +2279,9 @@
 // CHECK:         %[[VAL_2098:.*]] = bitcast i8* %[[VAL_2096]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_2099:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2100:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2101:.*]] = mul nuw nsw i32 %[[VAL_2099]], 128
+// CHECK:         %[[VAL_2101:.*]] = mul nuw nsw i32 %[[VAL_2099]], 256
 // CHECK:         %[[VAL_2102:.*]] = add nuw nsw i32 %[[VAL_2101]], %[[VAL_2100]]
-// CHECK:         %[[VAL_2103:.*]] = icmp ult i32 %[[VAL_2102]], 163840
+// CHECK:         %[[VAL_2103:.*]] = icmp ult i32 %[[VAL_2102]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2103]])
 // CHECK:         %[[VAL_2104:.*]] = mul nuw nsw i32 %[[VAL_2102]], 4
 // CHECK:         %[[VAL_2105:.*]] = udiv i32 %[[VAL_2104]], 1
@@ -2304,44 +2304,44 @@
 // CHECK:       r36.in_bounds-after:                              ; preds = %[[VAL_2121]], %[[VAL_2123:.*]]
 // CHECK:         ret void
 // CHECK:       r36.in_bounds-true:                               ; preds = %[[VAL_2123]]
-// CHECK:         %[[VAL_2124:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
+// CHECK:         %[[VAL_2124:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
 // CHECK:         %[[VAL_2125:.*]] = getelementptr inbounds float, float* %[[VAL_2124]], i32 %[[VAL_2104]]
 // CHECK:         %[[VAL_2126:.*]] = load float, float* %[[VAL_2125]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2127:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
+// CHECK:         %[[VAL_2127:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
 // CHECK:         %[[VAL_2128:.*]] = getelementptr inbounds float, float* %[[VAL_2127]], i32 %[[VAL_2104]]
 // CHECK:         %[[VAL_2129:.*]] = load float, float* %[[VAL_2128]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2130:.*]] = call float @__nv_fmodf(float %[[VAL_2126]], float %[[VAL_2129]])
-// CHECK:         %[[VAL_2131:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
+// CHECK:         %[[VAL_2131:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
 // CHECK:         %[[VAL_2132:.*]] = getelementptr inbounds float, float* %[[VAL_2131]], i32 %[[VAL_2104]]
 // CHECK:         store float %[[VAL_2130]], float* %[[VAL_2132]], align 4
-// CHECK:         %[[VAL_2133:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
+// CHECK:         %[[VAL_2133:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
 // CHECK:         %[[VAL_2134:.*]] = getelementptr inbounds float, float* %[[VAL_2133]], i32 %[[VAL_2108]]
 // CHECK:         %[[VAL_2135:.*]] = load float, float* %[[VAL_2134]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2136:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
+// CHECK:         %[[VAL_2136:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
 // CHECK:         %[[VAL_2137:.*]] = getelementptr inbounds float, float* %[[VAL_2136]], i32 %[[VAL_2108]]
 // CHECK:         %[[VAL_2138:.*]] = load float, float* %[[VAL_2137]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2139:.*]] = call float @__nv_fmodf(float %[[VAL_2135]], float %[[VAL_2138]])
-// CHECK:         %[[VAL_2140:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
+// CHECK:         %[[VAL_2140:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
 // CHECK:         %[[VAL_2141:.*]] = getelementptr inbounds float, float* %[[VAL_2140]], i32 %[[VAL_2108]]
 // CHECK:         store float %[[VAL_2139]], float* %[[VAL_2141]], align 4
-// CHECK:         %[[VAL_2142:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
+// CHECK:         %[[VAL_2142:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
 // CHECK:         %[[VAL_2143:.*]] = getelementptr inbounds float, float* %[[VAL_2142]], i32 %[[VAL_2112]]
 // CHECK:         %[[VAL_2144:.*]] = load float, float* %[[VAL_2143]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2145:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
+// CHECK:         %[[VAL_2145:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
 // CHECK:         %[[VAL_2146:.*]] = getelementptr inbounds float, float* %[[VAL_2145]], i32 %[[VAL_2112]]
 // CHECK:         %[[VAL_2147:.*]] = load float, float* %[[VAL_2146]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2148:.*]] = call float @__nv_fmodf(float %[[VAL_2144]], float %[[VAL_2147]])
-// CHECK:         %[[VAL_2149:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
+// CHECK:         %[[VAL_2149:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
 // CHECK:         %[[VAL_2150:.*]] = getelementptr inbounds float, float* %[[VAL_2149]], i32 %[[VAL_2112]]
 // CHECK:         store float %[[VAL_2148]], float* %[[VAL_2150]], align 4
-// CHECK:         %[[VAL_2151:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
+// CHECK:         %[[VAL_2151:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
 // CHECK:         %[[VAL_2152:.*]] = getelementptr inbounds float, float* %[[VAL_2151]], i32 %[[VAL_2116]]
 // CHECK:         %[[VAL_2153:.*]] = load float, float* %[[VAL_2152]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2154:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
+// CHECK:         %[[VAL_2154:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2095]] to float*
 // CHECK:         %[[VAL_2155:.*]] = getelementptr inbounds float, float* %[[VAL_2154]], i32 %[[VAL_2116]]
 // CHECK:         %[[VAL_2156:.*]] = load float, float* %[[VAL_2155]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2157:.*]] = call float @__nv_fmodf(float %[[VAL_2153]], float %[[VAL_2156]])
-// CHECK:         %[[VAL_2158:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2092]] to float*
+// CHECK:         %[[VAL_2158:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2098]] to float*
 // CHECK:         %[[VAL_2159:.*]] = getelementptr inbounds float, float* %[[VAL_2158]], i32 %[[VAL_2116]]
 // CHECK:         store float %[[VAL_2157]], float* %[[VAL_2159]], align 4
 // CHECK:         br label %[[VAL_2122]]
@@ -2354,9 +2354,9 @@
 // CHECK:         %[[VAL_2168:.*]] = bitcast i8* %[[VAL_2166]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_2169:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2170:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2171:.*]] = mul nuw nsw i32 %[[VAL_2169]], 128
+// CHECK:         %[[VAL_2171:.*]] = mul nuw nsw i32 %[[VAL_2169]], 256
 // CHECK:         %[[VAL_2172:.*]] = add nuw nsw i32 %[[VAL_2171]], %[[VAL_2170]]
-// CHECK:         %[[VAL_2173:.*]] = icmp ult i32 %[[VAL_2172]], 163840
+// CHECK:         %[[VAL_2173:.*]] = icmp ult i32 %[[VAL_2172]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2173]])
 // CHECK:         %[[VAL_2174:.*]] = mul nuw nsw i32 %[[VAL_2172]], 4
 // CHECK:         %[[VAL_2175:.*]] = udiv i32 %[[VAL_2174]], 1
@@ -2379,44 +2379,44 @@
 // CHECK:       r37.in_bounds-after:                              ; preds = %[[VAL_2191]], %[[VAL_2193:.*]]
 // CHECK:         ret void
 // CHECK:       r37.in_bounds-true:                               ; preds = %[[VAL_2193]]
-// CHECK:         %[[VAL_2194:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
+// CHECK:         %[[VAL_2194:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
 // CHECK:         %[[VAL_2195:.*]] = getelementptr inbounds float, float* %[[VAL_2194]], i32 %[[VAL_2174]]
 // CHECK:         %[[VAL_2196:.*]] = load float, float* %[[VAL_2195]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2197:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
+// CHECK:         %[[VAL_2197:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
 // CHECK:         %[[VAL_2198:.*]] = getelementptr inbounds float, float* %[[VAL_2197]], i32 %[[VAL_2174]]
 // CHECK:         %[[VAL_2199:.*]] = load float, float* %[[VAL_2198]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2200:.*]] = fsub float %[[VAL_2196]], %[[VAL_2199]]
-// CHECK:         %[[VAL_2201:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
+// CHECK:         %[[VAL_2201:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
 // CHECK:         %[[VAL_2202:.*]] = getelementptr inbounds float, float* %[[VAL_2201]], i32 %[[VAL_2174]]
 // CHECK:         store float %[[VAL_2200]], float* %[[VAL_2202]], align 4
-// CHECK:         %[[VAL_2203:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
+// CHECK:         %[[VAL_2203:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
 // CHECK:         %[[VAL_2204:.*]] = getelementptr inbounds float, float* %[[VAL_2203]], i32 %[[VAL_2178]]
 // CHECK:         %[[VAL_2205:.*]] = load float, float* %[[VAL_2204]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2206:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
+// CHECK:         %[[VAL_2206:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
 // CHECK:         %[[VAL_2207:.*]] = getelementptr inbounds float, float* %[[VAL_2206]], i32 %[[VAL_2178]]
 // CHECK:         %[[VAL_2208:.*]] = load float, float* %[[VAL_2207]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2209:.*]] = fsub float %[[VAL_2205]], %[[VAL_2208]]
-// CHECK:         %[[VAL_2210:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
+// CHECK:         %[[VAL_2210:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
 // CHECK:         %[[VAL_2211:.*]] = getelementptr inbounds float, float* %[[VAL_2210]], i32 %[[VAL_2178]]
 // CHECK:         store float %[[VAL_2209]], float* %[[VAL_2211]], align 4
-// CHECK:         %[[VAL_2212:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
+// CHECK:         %[[VAL_2212:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
 // CHECK:         %[[VAL_2213:.*]] = getelementptr inbounds float, float* %[[VAL_2212]], i32 %[[VAL_2182]]
 // CHECK:         %[[VAL_2214:.*]] = load float, float* %[[VAL_2213]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2215:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
+// CHECK:         %[[VAL_2215:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
 // CHECK:         %[[VAL_2216:.*]] = getelementptr inbounds float, float* %[[VAL_2215]], i32 %[[VAL_2182]]
 // CHECK:         %[[VAL_2217:.*]] = load float, float* %[[VAL_2216]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2218:.*]] = fsub float %[[VAL_2214]], %[[VAL_2217]]
-// CHECK:         %[[VAL_2219:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
+// CHECK:         %[[VAL_2219:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
 // CHECK:         %[[VAL_2220:.*]] = getelementptr inbounds float, float* %[[VAL_2219]], i32 %[[VAL_2182]]
 // CHECK:         store float %[[VAL_2218]], float* %[[VAL_2220]], align 4
-// CHECK:         %[[VAL_2221:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
+// CHECK:         %[[VAL_2221:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
 // CHECK:         %[[VAL_2222:.*]] = getelementptr inbounds float, float* %[[VAL_2221]], i32 %[[VAL_2186]]
 // CHECK:         %[[VAL_2223:.*]] = load float, float* %[[VAL_2222]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2224:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
+// CHECK:         %[[VAL_2224:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2165]] to float*
 // CHECK:         %[[VAL_2225:.*]] = getelementptr inbounds float, float* %[[VAL_2224]], i32 %[[VAL_2186]]
 // CHECK:         %[[VAL_2226:.*]] = load float, float* %[[VAL_2225]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2227:.*]] = fsub float %[[VAL_2223]], %[[VAL_2226]]
-// CHECK:         %[[VAL_2228:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2162]] to float*
+// CHECK:         %[[VAL_2228:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2168]] to float*
 // CHECK:         %[[VAL_2229:.*]] = getelementptr inbounds float, float* %[[VAL_2228]], i32 %[[VAL_2186]]
 // CHECK:         store float %[[VAL_2227]], float* %[[VAL_2229]], align 4
 // CHECK:         br label %[[VAL_2192]]
@@ -2429,9 +2429,9 @@
 // CHECK:         %[[VAL_2238:.*]] = bitcast i8* %[[VAL_2236]] to [100 x [200 x i8]]*
 // CHECK:         %[[VAL_2239:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2240:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2241:.*]] = mul nuw nsw i32 %[[VAL_2239]], 128
+// CHECK:         %[[VAL_2241:.*]] = mul nuw nsw i32 %[[VAL_2239]], 256
 // CHECK:         %[[VAL_2242:.*]] = add nuw nsw i32 %[[VAL_2241]], %[[VAL_2240]]
-// CHECK:         %[[VAL_2243:.*]] = icmp ult i32 %[[VAL_2242]], 163840
+// CHECK:         %[[VAL_2243:.*]] = icmp ult i32 %[[VAL_2242]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2243]])
 // CHECK:         %[[VAL_2244:.*]] = mul nuw nsw i32 %[[VAL_2242]], 4
 // CHECK:         %[[VAL_2245:.*]] = udiv i32 %[[VAL_2244]], 1
@@ -2454,44 +2454,44 @@
 // CHECK:       r38.in_bounds-after:                              ; preds = %[[VAL_2261]], %[[VAL_2263:.*]]
 // CHECK:         ret void
 // CHECK:       r38.in_bounds-true:                               ; preds = %[[VAL_2263]]
-// CHECK:         %[[VAL_2264:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
+// CHECK:         %[[VAL_2264:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
 // CHECK:         %[[VAL_2265:.*]] = getelementptr inbounds i8, i8* %[[VAL_2264]], i32 %[[VAL_2244]]
 // CHECK:         %[[VAL_2266:.*]] = load i8, i8* %[[VAL_2265]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2267:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
+// CHECK:         %[[VAL_2267:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
 // CHECK:         %[[VAL_2268:.*]] = getelementptr inbounds i8, i8* %[[VAL_2267]], i32 %[[VAL_2244]]
 // CHECK:         %[[VAL_2269:.*]] = load i8, i8* %[[VAL_2268]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2270:.*]] = and i8 %[[VAL_2266]], %[[VAL_2269]]
-// CHECK:         %[[VAL_2271:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
+// CHECK:         %[[VAL_2271:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
 // CHECK:         %[[VAL_2272:.*]] = getelementptr inbounds i8, i8* %[[VAL_2271]], i32 %[[VAL_2244]]
 // CHECK:         store i8 %[[VAL_2270]], i8* %[[VAL_2272]], align 1
-// CHECK:         %[[VAL_2273:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
+// CHECK:         %[[VAL_2273:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
 // CHECK:         %[[VAL_2274:.*]] = getelementptr inbounds i8, i8* %[[VAL_2273]], i32 %[[VAL_2248]]
 // CHECK:         %[[VAL_2275:.*]] = load i8, i8* %[[VAL_2274]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2276:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
+// CHECK:         %[[VAL_2276:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
 // CHECK:         %[[VAL_2277:.*]] = getelementptr inbounds i8, i8* %[[VAL_2276]], i32 %[[VAL_2248]]
 // CHECK:         %[[VAL_2278:.*]] = load i8, i8* %[[VAL_2277]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2279:.*]] = and i8 %[[VAL_2275]], %[[VAL_2278]]
-// CHECK:         %[[VAL_2280:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
+// CHECK:         %[[VAL_2280:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
 // CHECK:         %[[VAL_2281:.*]] = getelementptr inbounds i8, i8* %[[VAL_2280]], i32 %[[VAL_2248]]
 // CHECK:         store i8 %[[VAL_2279]], i8* %[[VAL_2281]], align 1
-// CHECK:         %[[VAL_2282:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
+// CHECK:         %[[VAL_2282:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
 // CHECK:         %[[VAL_2283:.*]] = getelementptr inbounds i8, i8* %[[VAL_2282]], i32 %[[VAL_2252]]
 // CHECK:         %[[VAL_2284:.*]] = load i8, i8* %[[VAL_2283]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2285:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
+// CHECK:         %[[VAL_2285:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
 // CHECK:         %[[VAL_2286:.*]] = getelementptr inbounds i8, i8* %[[VAL_2285]], i32 %[[VAL_2252]]
 // CHECK:         %[[VAL_2287:.*]] = load i8, i8* %[[VAL_2286]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2288:.*]] = and i8 %[[VAL_2284]], %[[VAL_2287]]
-// CHECK:         %[[VAL_2289:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
+// CHECK:         %[[VAL_2289:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
 // CHECK:         %[[VAL_2290:.*]] = getelementptr inbounds i8, i8* %[[VAL_2289]], i32 %[[VAL_2252]]
 // CHECK:         store i8 %[[VAL_2288]], i8* %[[VAL_2290]], align 1
-// CHECK:         %[[VAL_2291:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
+// CHECK:         %[[VAL_2291:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
 // CHECK:         %[[VAL_2292:.*]] = getelementptr inbounds i8, i8* %[[VAL_2291]], i32 %[[VAL_2256]]
 // CHECK:         %[[VAL_2293:.*]] = load i8, i8* %[[VAL_2292]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2294:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
+// CHECK:         %[[VAL_2294:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2235]] to i8*
 // CHECK:         %[[VAL_2295:.*]] = getelementptr inbounds i8, i8* %[[VAL_2294]], i32 %[[VAL_2256]]
 // CHECK:         %[[VAL_2296:.*]] = load i8, i8* %[[VAL_2295]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2297:.*]] = and i8 %[[VAL_2293]], %[[VAL_2296]]
-// CHECK:         %[[VAL_2298:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2232]] to i8*
+// CHECK:         %[[VAL_2298:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2238]] to i8*
 // CHECK:         %[[VAL_2299:.*]] = getelementptr inbounds i8, i8* %[[VAL_2298]], i32 %[[VAL_2256]]
 // CHECK:         store i8 %[[VAL_2297]], i8* %[[VAL_2299]], align 1
 // CHECK:         br label %[[VAL_2262]]
@@ -2504,9 +2504,9 @@
 // CHECK:         %[[VAL_2308:.*]] = bitcast i8* %[[VAL_2306]] to [100 x [200 x i8]]*
 // CHECK:         %[[VAL_2309:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2310:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2311:.*]] = mul nuw nsw i32 %[[VAL_2309]], 128
+// CHECK:         %[[VAL_2311:.*]] = mul nuw nsw i32 %[[VAL_2309]], 256
 // CHECK:         %[[VAL_2312:.*]] = add nuw nsw i32 %[[VAL_2311]], %[[VAL_2310]]
-// CHECK:         %[[VAL_2313:.*]] = icmp ult i32 %[[VAL_2312]], 163840
+// CHECK:         %[[VAL_2313:.*]] = icmp ult i32 %[[VAL_2312]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2313]])
 // CHECK:         %[[VAL_2314:.*]] = mul nuw nsw i32 %[[VAL_2312]], 4
 // CHECK:         %[[VAL_2315:.*]] = udiv i32 %[[VAL_2314]], 1
@@ -2529,44 +2529,44 @@
 // CHECK:       r39.in_bounds-after:                              ; preds = %[[VAL_2331]], %[[VAL_2333:.*]]
 // CHECK:         ret void
 // CHECK:       r39.in_bounds-true:                               ; preds = %[[VAL_2333]]
-// CHECK:         %[[VAL_2334:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
+// CHECK:         %[[VAL_2334:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
 // CHECK:         %[[VAL_2335:.*]] = getelementptr inbounds i8, i8* %[[VAL_2334]], i32 %[[VAL_2314]]
 // CHECK:         %[[VAL_2336:.*]] = load i8, i8* %[[VAL_2335]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2337:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
+// CHECK:         %[[VAL_2337:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
 // CHECK:         %[[VAL_2338:.*]] = getelementptr inbounds i8, i8* %[[VAL_2337]], i32 %[[VAL_2314]]
 // CHECK:         %[[VAL_2339:.*]] = load i8, i8* %[[VAL_2338]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2340:.*]] = or i8 %[[VAL_2336]], %[[VAL_2339]]
-// CHECK:         %[[VAL_2341:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
+// CHECK:         %[[VAL_2341:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
 // CHECK:         %[[VAL_2342:.*]] = getelementptr inbounds i8, i8* %[[VAL_2341]], i32 %[[VAL_2314]]
 // CHECK:         store i8 %[[VAL_2340]], i8* %[[VAL_2342]], align 1
-// CHECK:         %[[VAL_2343:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
+// CHECK:         %[[VAL_2343:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
 // CHECK:         %[[VAL_2344:.*]] = getelementptr inbounds i8, i8* %[[VAL_2343]], i32 %[[VAL_2318]]
 // CHECK:         %[[VAL_2345:.*]] = load i8, i8* %[[VAL_2344]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2346:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
+// CHECK:         %[[VAL_2346:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
 // CHECK:         %[[VAL_2347:.*]] = getelementptr inbounds i8, i8* %[[VAL_2346]], i32 %[[VAL_2318]]
 // CHECK:         %[[VAL_2348:.*]] = load i8, i8* %[[VAL_2347]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2349:.*]] = or i8 %[[VAL_2345]], %[[VAL_2348]]
-// CHECK:         %[[VAL_2350:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
+// CHECK:         %[[VAL_2350:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
 // CHECK:         %[[VAL_2351:.*]] = getelementptr inbounds i8, i8* %[[VAL_2350]], i32 %[[VAL_2318]]
 // CHECK:         store i8 %[[VAL_2349]], i8* %[[VAL_2351]], align 1
-// CHECK:         %[[VAL_2352:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
+// CHECK:         %[[VAL_2352:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
 // CHECK:         %[[VAL_2353:.*]] = getelementptr inbounds i8, i8* %[[VAL_2352]], i32 %[[VAL_2322]]
 // CHECK:         %[[VAL_2354:.*]] = load i8, i8* %[[VAL_2353]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2355:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
+// CHECK:         %[[VAL_2355:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
 // CHECK:         %[[VAL_2356:.*]] = getelementptr inbounds i8, i8* %[[VAL_2355]], i32 %[[VAL_2322]]
 // CHECK:         %[[VAL_2357:.*]] = load i8, i8* %[[VAL_2356]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2358:.*]] = or i8 %[[VAL_2354]], %[[VAL_2357]]
-// CHECK:         %[[VAL_2359:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
+// CHECK:         %[[VAL_2359:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
 // CHECK:         %[[VAL_2360:.*]] = getelementptr inbounds i8, i8* %[[VAL_2359]], i32 %[[VAL_2322]]
 // CHECK:         store i8 %[[VAL_2358]], i8* %[[VAL_2360]], align 1
-// CHECK:         %[[VAL_2361:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
+// CHECK:         %[[VAL_2361:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
 // CHECK:         %[[VAL_2362:.*]] = getelementptr inbounds i8, i8* %[[VAL_2361]], i32 %[[VAL_2326]]
 // CHECK:         %[[VAL_2363:.*]] = load i8, i8* %[[VAL_2362]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2364:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
+// CHECK:         %[[VAL_2364:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2305]] to i8*
 // CHECK:         %[[VAL_2365:.*]] = getelementptr inbounds i8, i8* %[[VAL_2364]], i32 %[[VAL_2326]]
 // CHECK:         %[[VAL_2366:.*]] = load i8, i8* %[[VAL_2365]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2367:.*]] = or i8 %[[VAL_2363]], %[[VAL_2366]]
-// CHECK:         %[[VAL_2368:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2302]] to i8*
+// CHECK:         %[[VAL_2368:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2308]] to i8*
 // CHECK:         %[[VAL_2369:.*]] = getelementptr inbounds i8, i8* %[[VAL_2368]], i32 %[[VAL_2326]]
 // CHECK:         store i8 %[[VAL_2367]], i8* %[[VAL_2369]], align 1
 // CHECK:         br label %[[VAL_2332]]
@@ -2579,9 +2579,9 @@
 // CHECK:         %[[VAL_2378:.*]] = bitcast i8* %[[VAL_2376]] to [100 x [200 x i8]]*
 // CHECK:         %[[VAL_2379:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2380:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2381:.*]] = mul nuw nsw i32 %[[VAL_2379]], 128
+// CHECK:         %[[VAL_2381:.*]] = mul nuw nsw i32 %[[VAL_2379]], 256
 // CHECK:         %[[VAL_2382:.*]] = add nuw nsw i32 %[[VAL_2381]], %[[VAL_2380]]
-// CHECK:         %[[VAL_2383:.*]] = icmp ult i32 %[[VAL_2382]], 163840
+// CHECK:         %[[VAL_2383:.*]] = icmp ult i32 %[[VAL_2382]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2383]])
 // CHECK:         %[[VAL_2384:.*]] = mul nuw nsw i32 %[[VAL_2382]], 4
 // CHECK:         %[[VAL_2385:.*]] = udiv i32 %[[VAL_2384]], 1
@@ -2604,44 +2604,44 @@
 // CHECK:       r40.in_bounds-after:                              ; preds = %[[VAL_2401]], %[[VAL_2403:.*]]
 // CHECK:         ret void
 // CHECK:       r40.in_bounds-true:                               ; preds = %[[VAL_2403]]
-// CHECK:         %[[VAL_2404:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
+// CHECK:         %[[VAL_2404:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
 // CHECK:         %[[VAL_2405:.*]] = getelementptr inbounds i8, i8* %[[VAL_2404]], i32 %[[VAL_2384]]
 // CHECK:         %[[VAL_2406:.*]] = load i8, i8* %[[VAL_2405]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2407:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
+// CHECK:         %[[VAL_2407:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
 // CHECK:         %[[VAL_2408:.*]] = getelementptr inbounds i8, i8* %[[VAL_2407]], i32 %[[VAL_2384]]
 // CHECK:         %[[VAL_2409:.*]] = load i8, i8* %[[VAL_2408]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2410:.*]] = xor i8 %[[VAL_2406]], %[[VAL_2409]]
-// CHECK:         %[[VAL_2411:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
+// CHECK:         %[[VAL_2411:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
 // CHECK:         %[[VAL_2412:.*]] = getelementptr inbounds i8, i8* %[[VAL_2411]], i32 %[[VAL_2384]]
 // CHECK:         store i8 %[[VAL_2410]], i8* %[[VAL_2412]], align 1
-// CHECK:         %[[VAL_2413:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
+// CHECK:         %[[VAL_2413:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
 // CHECK:         %[[VAL_2414:.*]] = getelementptr inbounds i8, i8* %[[VAL_2413]], i32 %[[VAL_2388]]
 // CHECK:         %[[VAL_2415:.*]] = load i8, i8* %[[VAL_2414]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2416:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
+// CHECK:         %[[VAL_2416:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
 // CHECK:         %[[VAL_2417:.*]] = getelementptr inbounds i8, i8* %[[VAL_2416]], i32 %[[VAL_2388]]
 // CHECK:         %[[VAL_2418:.*]] = load i8, i8* %[[VAL_2417]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2419:.*]] = xor i8 %[[VAL_2415]], %[[VAL_2418]]
-// CHECK:         %[[VAL_2420:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
+// CHECK:         %[[VAL_2420:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
 // CHECK:         %[[VAL_2421:.*]] = getelementptr inbounds i8, i8* %[[VAL_2420]], i32 %[[VAL_2388]]
 // CHECK:         store i8 %[[VAL_2419]], i8* %[[VAL_2421]], align 1
-// CHECK:         %[[VAL_2422:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
+// CHECK:         %[[VAL_2422:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
 // CHECK:         %[[VAL_2423:.*]] = getelementptr inbounds i8, i8* %[[VAL_2422]], i32 %[[VAL_2392]]
 // CHECK:         %[[VAL_2424:.*]] = load i8, i8* %[[VAL_2423]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2425:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
+// CHECK:         %[[VAL_2425:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
 // CHECK:         %[[VAL_2426:.*]] = getelementptr inbounds i8, i8* %[[VAL_2425]], i32 %[[VAL_2392]]
 // CHECK:         %[[VAL_2427:.*]] = load i8, i8* %[[VAL_2426]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2428:.*]] = xor i8 %[[VAL_2424]], %[[VAL_2427]]
-// CHECK:         %[[VAL_2429:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
+// CHECK:         %[[VAL_2429:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
 // CHECK:         %[[VAL_2430:.*]] = getelementptr inbounds i8, i8* %[[VAL_2429]], i32 %[[VAL_2392]]
 // CHECK:         store i8 %[[VAL_2428]], i8* %[[VAL_2430]], align 1
-// CHECK:         %[[VAL_2431:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
+// CHECK:         %[[VAL_2431:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
 // CHECK:         %[[VAL_2432:.*]] = getelementptr inbounds i8, i8* %[[VAL_2431]], i32 %[[VAL_2396]]
 // CHECK:         %[[VAL_2433:.*]] = load i8, i8* %[[VAL_2432]], align 1, !invariant.load !92
-// CHECK:         %[[VAL_2434:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
+// CHECK:         %[[VAL_2434:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2375]] to i8*
 // CHECK:         %[[VAL_2435:.*]] = getelementptr inbounds i8, i8* %[[VAL_2434]], i32 %[[VAL_2396]]
 // CHECK:         %[[VAL_2436:.*]] = load i8, i8* %[[VAL_2435]], align 1, !invariant.load !92
 // CHECK:         %[[VAL_2437:.*]] = xor i8 %[[VAL_2433]], %[[VAL_2436]]
-// CHECK:         %[[VAL_2438:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2372]] to i8*
+// CHECK:         %[[VAL_2438:.*]] = bitcast [100 x [200 x i8]]* %[[VAL_2378]] to i8*
 // CHECK:         %[[VAL_2439:.*]] = getelementptr inbounds i8, i8* %[[VAL_2438]], i32 %[[VAL_2396]]
 // CHECK:         store i8 %[[VAL_2437]], i8* %[[VAL_2439]], align 1
 // CHECK:         br label %[[VAL_2402]]
@@ -2654,9 +2654,9 @@
 // CHECK:         %[[VAL_2448:.*]] = bitcast i8* %[[VAL_2446]] to [100 x [200 x i32]]*
 // CHECK:         %[[VAL_2449:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2450:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2451:.*]] = mul nuw nsw i32 %[[VAL_2449]], 128
+// CHECK:         %[[VAL_2451:.*]] = mul nuw nsw i32 %[[VAL_2449]], 256
 // CHECK:         %[[VAL_2452:.*]] = add nuw nsw i32 %[[VAL_2451]], %[[VAL_2450]]
-// CHECK:         %[[VAL_2453:.*]] = icmp ult i32 %[[VAL_2452]], 163840
+// CHECK:         %[[VAL_2453:.*]] = icmp ult i32 %[[VAL_2452]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2453]])
 // CHECK:         %[[VAL_2454:.*]] = mul nuw nsw i32 %[[VAL_2452]], 4
 // CHECK:         %[[VAL_2455:.*]] = udiv i32 %[[VAL_2454]], 1
@@ -2679,52 +2679,52 @@
 // CHECK:       r41.in_bounds-after:                              ; preds = %[[VAL_2471]], %[[VAL_2473:.*]]
 // CHECK:         ret void
 // CHECK:       r41.in_bounds-true:                               ; preds = %[[VAL_2473]]
-// CHECK:         %[[VAL_2474:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
+// CHECK:         %[[VAL_2474:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
 // CHECK:         %[[VAL_2475:.*]] = getelementptr inbounds i32, i32* %[[VAL_2474]], i32 %[[VAL_2454]]
 // CHECK:         %[[VAL_2476:.*]] = load i32, i32* %[[VAL_2475]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2477:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
+// CHECK:         %[[VAL_2477:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
 // CHECK:         %[[VAL_2478:.*]] = getelementptr inbounds i32, i32* %[[VAL_2477]], i32 %[[VAL_2454]]
 // CHECK:         %[[VAL_2479:.*]] = load i32, i32* %[[VAL_2478]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2480:.*]] = shl i32 %[[VAL_2476]], %[[VAL_2479]]
 // CHECK:         %[[VAL_2481:.*]] = icmp ult i32 %[[VAL_2479]], 32
 // CHECK:         %[[VAL_2482:.*]] = select i1 %[[VAL_2481]], i32 %[[VAL_2480]], i32 0
-// CHECK:         %[[VAL_2483:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
+// CHECK:         %[[VAL_2483:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
 // CHECK:         %[[VAL_2484:.*]] = getelementptr inbounds i32, i32* %[[VAL_2483]], i32 %[[VAL_2454]]
 // CHECK:         store i32 %[[VAL_2482]], i32* %[[VAL_2484]], align 4
-// CHECK:         %[[VAL_2485:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
+// CHECK:         %[[VAL_2485:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
 // CHECK:         %[[VAL_2486:.*]] = getelementptr inbounds i32, i32* %[[VAL_2485]], i32 %[[VAL_2458]]
 // CHECK:         %[[VAL_2487:.*]] = load i32, i32* %[[VAL_2486]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2488:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
+// CHECK:         %[[VAL_2488:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
 // CHECK:         %[[VAL_2489:.*]] = getelementptr inbounds i32, i32* %[[VAL_2488]], i32 %[[VAL_2458]]
 // CHECK:         %[[VAL_2490:.*]] = load i32, i32* %[[VAL_2489]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2491:.*]] = shl i32 %[[VAL_2487]], %[[VAL_2490]]
 // CHECK:         %[[VAL_2492:.*]] = icmp ult i32 %[[VAL_2490]], 32
 // CHECK:         %[[VAL_2493:.*]] = select i1 %[[VAL_2492]], i32 %[[VAL_2491]], i32 0
-// CHECK:         %[[VAL_2494:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
+// CHECK:         %[[VAL_2494:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
 // CHECK:         %[[VAL_2495:.*]] = getelementptr inbounds i32, i32* %[[VAL_2494]], i32 %[[VAL_2458]]
 // CHECK:         store i32 %[[VAL_2493]], i32* %[[VAL_2495]], align 4
-// CHECK:         %[[VAL_2496:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
+// CHECK:         %[[VAL_2496:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
 // CHECK:         %[[VAL_2497:.*]] = getelementptr inbounds i32, i32* %[[VAL_2496]], i32 %[[VAL_2462]]
 // CHECK:         %[[VAL_2498:.*]] = load i32, i32* %[[VAL_2497]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2499:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
+// CHECK:         %[[VAL_2499:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
 // CHECK:         %[[VAL_2500:.*]] = getelementptr inbounds i32, i32* %[[VAL_2499]], i32 %[[VAL_2462]]
 // CHECK:         %[[VAL_2501:.*]] = load i32, i32* %[[VAL_2500]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2502:.*]] = shl i32 %[[VAL_2498]], %[[VAL_2501]]
 // CHECK:         %[[VAL_2503:.*]] = icmp ult i32 %[[VAL_2501]], 32
 // CHECK:         %[[VAL_2504:.*]] = select i1 %[[VAL_2503]], i32 %[[VAL_2502]], i32 0
-// CHECK:         %[[VAL_2505:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
+// CHECK:         %[[VAL_2505:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
 // CHECK:         %[[VAL_2506:.*]] = getelementptr inbounds i32, i32* %[[VAL_2505]], i32 %[[VAL_2462]]
 // CHECK:         store i32 %[[VAL_2504]], i32* %[[VAL_2506]], align 4
-// CHECK:         %[[VAL_2507:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
+// CHECK:         %[[VAL_2507:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
 // CHECK:         %[[VAL_2508:.*]] = getelementptr inbounds i32, i32* %[[VAL_2507]], i32 %[[VAL_2466]]
 // CHECK:         %[[VAL_2509:.*]] = load i32, i32* %[[VAL_2508]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2510:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
+// CHECK:         %[[VAL_2510:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2445]] to i32*
 // CHECK:         %[[VAL_2511:.*]] = getelementptr inbounds i32, i32* %[[VAL_2510]], i32 %[[VAL_2466]]
 // CHECK:         %[[VAL_2512:.*]] = load i32, i32* %[[VAL_2511]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2513:.*]] = shl i32 %[[VAL_2509]], %[[VAL_2512]]
 // CHECK:         %[[VAL_2514:.*]] = icmp ult i32 %[[VAL_2512]], 32
 // CHECK:         %[[VAL_2515:.*]] = select i1 %[[VAL_2514]], i32 %[[VAL_2513]], i32 0
-// CHECK:         %[[VAL_2516:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2442]] to i32*
+// CHECK:         %[[VAL_2516:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2448]] to i32*
 // CHECK:         %[[VAL_2517:.*]] = getelementptr inbounds i32, i32* %[[VAL_2516]], i32 %[[VAL_2466]]
 // CHECK:         store i32 %[[VAL_2515]], i32* %[[VAL_2517]], align 4
 // CHECK:         br label %[[VAL_2472]]
@@ -2737,9 +2737,9 @@
 // CHECK:         %[[VAL_2526:.*]] = bitcast i8* %[[VAL_2524]] to [100 x [200 x i32]]*
 // CHECK:         %[[VAL_2527:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2528:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2529:.*]] = mul nuw nsw i32 %[[VAL_2527]], 128
+// CHECK:         %[[VAL_2529:.*]] = mul nuw nsw i32 %[[VAL_2527]], 256
 // CHECK:         %[[VAL_2530:.*]] = add nuw nsw i32 %[[VAL_2529]], %[[VAL_2528]]
-// CHECK:         %[[VAL_2531:.*]] = icmp ult i32 %[[VAL_2530]], 163840
+// CHECK:         %[[VAL_2531:.*]] = icmp ult i32 %[[VAL_2530]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2531]])
 // CHECK:         %[[VAL_2532:.*]] = mul nuw nsw i32 %[[VAL_2530]], 4
 // CHECK:         %[[VAL_2533:.*]] = udiv i32 %[[VAL_2532]], 1
@@ -2762,10 +2762,10 @@
 // CHECK:       r42.in_bounds-after:                              ; preds = %[[VAL_2549]], %[[VAL_2551:.*]]
 // CHECK:         ret void
 // CHECK:       r42.in_bounds-true:                               ; preds = %[[VAL_2551]]
-// CHECK:         %[[VAL_2552:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
+// CHECK:         %[[VAL_2552:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
 // CHECK:         %[[VAL_2553:.*]] = getelementptr inbounds i32, i32* %[[VAL_2552]], i32 %[[VAL_2532]]
 // CHECK:         %[[VAL_2554:.*]] = load i32, i32* %[[VAL_2553]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2555:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
+// CHECK:         %[[VAL_2555:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
 // CHECK:         %[[VAL_2556:.*]] = getelementptr inbounds i32, i32* %[[VAL_2555]], i32 %[[VAL_2532]]
 // CHECK:         %[[VAL_2557:.*]] = load i32, i32* %[[VAL_2556]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2558:.*]] = ashr i32 %[[VAL_2554]], %[[VAL_2557]]
@@ -2773,13 +2773,13 @@
 // CHECK:         %[[VAL_2560:.*]] = select i1 %[[VAL_2559]], i32 -1, i32 0
 // CHECK:         %[[VAL_2561:.*]] = icmp ult i32 %[[VAL_2557]], 32
 // CHECK:         %[[VAL_2562:.*]] = select i1 %[[VAL_2561]], i32 %[[VAL_2558]], i32 %[[VAL_2560]]
-// CHECK:         %[[VAL_2563:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
+// CHECK:         %[[VAL_2563:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
 // CHECK:         %[[VAL_2564:.*]] = getelementptr inbounds i32, i32* %[[VAL_2563]], i32 %[[VAL_2532]]
 // CHECK:         store i32 %[[VAL_2562]], i32* %[[VAL_2564]], align 4
-// CHECK:         %[[VAL_2565:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
+// CHECK:         %[[VAL_2565:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
 // CHECK:         %[[VAL_2566:.*]] = getelementptr inbounds i32, i32* %[[VAL_2565]], i32 %[[VAL_2536]]
 // CHECK:         %[[VAL_2567:.*]] = load i32, i32* %[[VAL_2566]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2568:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
+// CHECK:         %[[VAL_2568:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
 // CHECK:         %[[VAL_2569:.*]] = getelementptr inbounds i32, i32* %[[VAL_2568]], i32 %[[VAL_2536]]
 // CHECK:         %[[VAL_2570:.*]] = load i32, i32* %[[VAL_2569]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2571:.*]] = ashr i32 %[[VAL_2567]], %[[VAL_2570]]
@@ -2787,13 +2787,13 @@
 // CHECK:         %[[VAL_2573:.*]] = select i1 %[[VAL_2572]], i32 -1, i32 0
 // CHECK:         %[[VAL_2574:.*]] = icmp ult i32 %[[VAL_2570]], 32
 // CHECK:         %[[VAL_2575:.*]] = select i1 %[[VAL_2574]], i32 %[[VAL_2571]], i32 %[[VAL_2573]]
-// CHECK:         %[[VAL_2576:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
+// CHECK:         %[[VAL_2576:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
 // CHECK:         %[[VAL_2577:.*]] = getelementptr inbounds i32, i32* %[[VAL_2576]], i32 %[[VAL_2536]]
 // CHECK:         store i32 %[[VAL_2575]], i32* %[[VAL_2577]], align 4
-// CHECK:         %[[VAL_2578:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
+// CHECK:         %[[VAL_2578:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
 // CHECK:         %[[VAL_2579:.*]] = getelementptr inbounds i32, i32* %[[VAL_2578]], i32 %[[VAL_2540]]
 // CHECK:         %[[VAL_2580:.*]] = load i32, i32* %[[VAL_2579]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2581:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
+// CHECK:         %[[VAL_2581:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
 // CHECK:         %[[VAL_2582:.*]] = getelementptr inbounds i32, i32* %[[VAL_2581]], i32 %[[VAL_2540]]
 // CHECK:         %[[VAL_2583:.*]] = load i32, i32* %[[VAL_2582]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2584:.*]] = ashr i32 %[[VAL_2580]], %[[VAL_2583]]
@@ -2801,13 +2801,13 @@
 // CHECK:         %[[VAL_2586:.*]] = select i1 %[[VAL_2585]], i32 -1, i32 0
 // CHECK:         %[[VAL_2587:.*]] = icmp ult i32 %[[VAL_2583]], 32
 // CHECK:         %[[VAL_2588:.*]] = select i1 %[[VAL_2587]], i32 %[[VAL_2584]], i32 %[[VAL_2586]]
-// CHECK:         %[[VAL_2589:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
+// CHECK:         %[[VAL_2589:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
 // CHECK:         %[[VAL_2590:.*]] = getelementptr inbounds i32, i32* %[[VAL_2589]], i32 %[[VAL_2540]]
 // CHECK:         store i32 %[[VAL_2588]], i32* %[[VAL_2590]], align 4
-// CHECK:         %[[VAL_2591:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
+// CHECK:         %[[VAL_2591:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
 // CHECK:         %[[VAL_2592:.*]] = getelementptr inbounds i32, i32* %[[VAL_2591]], i32 %[[VAL_2544]]
 // CHECK:         %[[VAL_2593:.*]] = load i32, i32* %[[VAL_2592]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2594:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
+// CHECK:         %[[VAL_2594:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2523]] to i32*
 // CHECK:         %[[VAL_2595:.*]] = getelementptr inbounds i32, i32* %[[VAL_2594]], i32 %[[VAL_2544]]
 // CHECK:         %[[VAL_2596:.*]] = load i32, i32* %[[VAL_2595]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2597:.*]] = ashr i32 %[[VAL_2593]], %[[VAL_2596]]
@@ -2815,7 +2815,7 @@
 // CHECK:         %[[VAL_2599:.*]] = select i1 %[[VAL_2598]], i32 -1, i32 0
 // CHECK:         %[[VAL_2600:.*]] = icmp ult i32 %[[VAL_2596]], 32
 // CHECK:         %[[VAL_2601:.*]] = select i1 %[[VAL_2600]], i32 %[[VAL_2597]], i32 %[[VAL_2599]]
-// CHECK:         %[[VAL_2602:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2520]] to i32*
+// CHECK:         %[[VAL_2602:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2526]] to i32*
 // CHECK:         %[[VAL_2603:.*]] = getelementptr inbounds i32, i32* %[[VAL_2602]], i32 %[[VAL_2544]]
 // CHECK:         store i32 %[[VAL_2601]], i32* %[[VAL_2603]], align 4
 // CHECK:         br label %[[VAL_2550]]
@@ -2828,9 +2828,9 @@
 // CHECK:         %[[VAL_2612:.*]] = bitcast i8* %[[VAL_2610]] to [100 x [200 x i32]]*
 // CHECK:         %[[VAL_2613:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2614:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2615:.*]] = mul nuw nsw i32 %[[VAL_2613]], 128
+// CHECK:         %[[VAL_2615:.*]] = mul nuw nsw i32 %[[VAL_2613]], 256
 // CHECK:         %[[VAL_2616:.*]] = add nuw nsw i32 %[[VAL_2615]], %[[VAL_2614]]
-// CHECK:         %[[VAL_2617:.*]] = icmp ult i32 %[[VAL_2616]], 163840
+// CHECK:         %[[VAL_2617:.*]] = icmp ult i32 %[[VAL_2616]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2617]])
 // CHECK:         %[[VAL_2618:.*]] = mul nuw nsw i32 %[[VAL_2616]], 4
 // CHECK:         %[[VAL_2619:.*]] = udiv i32 %[[VAL_2618]], 1
@@ -2853,52 +2853,52 @@
 // CHECK:       r43.in_bounds-after:                              ; preds = %[[VAL_2635]], %[[VAL_2637:.*]]
 // CHECK:         ret void
 // CHECK:       r43.in_bounds-true:                               ; preds = %[[VAL_2637]]
-// CHECK:         %[[VAL_2638:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
+// CHECK:         %[[VAL_2638:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
 // CHECK:         %[[VAL_2639:.*]] = getelementptr inbounds i32, i32* %[[VAL_2638]], i32 %[[VAL_2618]]
 // CHECK:         %[[VAL_2640:.*]] = load i32, i32* %[[VAL_2639]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2641:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
+// CHECK:         %[[VAL_2641:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
 // CHECK:         %[[VAL_2642:.*]] = getelementptr inbounds i32, i32* %[[VAL_2641]], i32 %[[VAL_2618]]
 // CHECK:         %[[VAL_2643:.*]] = load i32, i32* %[[VAL_2642]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2644:.*]] = lshr i32 %[[VAL_2640]], %[[VAL_2643]]
 // CHECK:         %[[VAL_2645:.*]] = icmp ult i32 %[[VAL_2643]], 32
 // CHECK:         %[[VAL_2646:.*]] = select i1 %[[VAL_2645]], i32 %[[VAL_2644]], i32 0
-// CHECK:         %[[VAL_2647:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
+// CHECK:         %[[VAL_2647:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
 // CHECK:         %[[VAL_2648:.*]] = getelementptr inbounds i32, i32* %[[VAL_2647]], i32 %[[VAL_2618]]
 // CHECK:         store i32 %[[VAL_2646]], i32* %[[VAL_2648]], align 4
-// CHECK:         %[[VAL_2649:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
+// CHECK:         %[[VAL_2649:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
 // CHECK:         %[[VAL_2650:.*]] = getelementptr inbounds i32, i32* %[[VAL_2649]], i32 %[[VAL_2622]]
 // CHECK:         %[[VAL_2651:.*]] = load i32, i32* %[[VAL_2650]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2652:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
+// CHECK:         %[[VAL_2652:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
 // CHECK:         %[[VAL_2653:.*]] = getelementptr inbounds i32, i32* %[[VAL_2652]], i32 %[[VAL_2622]]
 // CHECK:         %[[VAL_2654:.*]] = load i32, i32* %[[VAL_2653]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2655:.*]] = lshr i32 %[[VAL_2651]], %[[VAL_2654]]
 // CHECK:         %[[VAL_2656:.*]] = icmp ult i32 %[[VAL_2654]], 32
 // CHECK:         %[[VAL_2657:.*]] = select i1 %[[VAL_2656]], i32 %[[VAL_2655]], i32 0
-// CHECK:         %[[VAL_2658:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
+// CHECK:         %[[VAL_2658:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
 // CHECK:         %[[VAL_2659:.*]] = getelementptr inbounds i32, i32* %[[VAL_2658]], i32 %[[VAL_2622]]
 // CHECK:         store i32 %[[VAL_2657]], i32* %[[VAL_2659]], align 4
-// CHECK:         %[[VAL_2660:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
+// CHECK:         %[[VAL_2660:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
 // CHECK:         %[[VAL_2661:.*]] = getelementptr inbounds i32, i32* %[[VAL_2660]], i32 %[[VAL_2626]]
 // CHECK:         %[[VAL_2662:.*]] = load i32, i32* %[[VAL_2661]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2663:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
+// CHECK:         %[[VAL_2663:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
 // CHECK:         %[[VAL_2664:.*]] = getelementptr inbounds i32, i32* %[[VAL_2663]], i32 %[[VAL_2626]]
 // CHECK:         %[[VAL_2665:.*]] = load i32, i32* %[[VAL_2664]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2666:.*]] = lshr i32 %[[VAL_2662]], %[[VAL_2665]]
 // CHECK:         %[[VAL_2667:.*]] = icmp ult i32 %[[VAL_2665]], 32
 // CHECK:         %[[VAL_2668:.*]] = select i1 %[[VAL_2667]], i32 %[[VAL_2666]], i32 0
-// CHECK:         %[[VAL_2669:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
+// CHECK:         %[[VAL_2669:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
 // CHECK:         %[[VAL_2670:.*]] = getelementptr inbounds i32, i32* %[[VAL_2669]], i32 %[[VAL_2626]]
 // CHECK:         store i32 %[[VAL_2668]], i32* %[[VAL_2670]], align 4
-// CHECK:         %[[VAL_2671:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
+// CHECK:         %[[VAL_2671:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
 // CHECK:         %[[VAL_2672:.*]] = getelementptr inbounds i32, i32* %[[VAL_2671]], i32 %[[VAL_2630]]
 // CHECK:         %[[VAL_2673:.*]] = load i32, i32* %[[VAL_2672]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2674:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
+// CHECK:         %[[VAL_2674:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2609]] to i32*
 // CHECK:         %[[VAL_2675:.*]] = getelementptr inbounds i32, i32* %[[VAL_2674]], i32 %[[VAL_2630]]
 // CHECK:         %[[VAL_2676:.*]] = load i32, i32* %[[VAL_2675]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2677:.*]] = lshr i32 %[[VAL_2673]], %[[VAL_2676]]
 // CHECK:         %[[VAL_2678:.*]] = icmp ult i32 %[[VAL_2676]], 32
 // CHECK:         %[[VAL_2679:.*]] = select i1 %[[VAL_2678]], i32 %[[VAL_2677]], i32 0
-// CHECK:         %[[VAL_2680:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2606]] to i32*
+// CHECK:         %[[VAL_2680:.*]] = bitcast [100 x [200 x i32]]* %[[VAL_2612]] to i32*
 // CHECK:         %[[VAL_2681:.*]] = getelementptr inbounds i32, i32* %[[VAL_2680]], i32 %[[VAL_2630]]
 // CHECK:         store i32 %[[VAL_2679]], i32* %[[VAL_2681]], align 4
 // CHECK:         br label %[[VAL_2636]]
@@ -2911,8 +2911,8 @@
 // CHECK:         %[[VAL_2690:.*]] = bitcast i8* %[[VAL_2688]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_2691:.*]] = getelementptr inbounds i8, i8* %[[VAL_2692:.*]], i64 0
 // CHECK:         %[[VAL_2693:.*]] = bitcast i8* %[[VAL_2691]] to [100 x [200 x float]]*
-// CHECK:         %[[VAL_2694:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
-// CHECK:         %[[VAL_2695:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
+// CHECK:         %[[VAL_2694:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !94
+// CHECK:         %[[VAL_2695:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !95
 // CHECK:         %[[VAL_2696:.*]] = mul nuw nsw i32 %[[VAL_2694]], 128
 // CHECK:         %[[VAL_2697:.*]] = add nuw nsw i32 %[[VAL_2696]], %[[VAL_2695]]
 // CHECK:         %[[VAL_2698:.*]] = icmp ult i32 %[[VAL_2697]], 163840
@@ -3006,9 +3006,9 @@
 // CHECK:         %[[VAL_2782:.*]] = bitcast i8* %[[VAL_2780]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_2783:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2784:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2785:.*]] = mul nuw nsw i32 %[[VAL_2783]], 128
+// CHECK:         %[[VAL_2785:.*]] = mul nuw nsw i32 %[[VAL_2783]], 256
 // CHECK:         %[[VAL_2786:.*]] = add nuw nsw i32 %[[VAL_2785]], %[[VAL_2784]]
-// CHECK:         %[[VAL_2787:.*]] = icmp ult i32 %[[VAL_2786]], 163840
+// CHECK:         %[[VAL_2787:.*]] = icmp ult i32 %[[VAL_2786]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2787]])
 // CHECK:         %[[VAL_2788:.*]] = mul nuw nsw i32 %[[VAL_2786]], 4
 // CHECK:         %[[VAL_2789:.*]] = udiv i32 %[[VAL_2788]], 1
@@ -3031,68 +3031,68 @@
 // CHECK:       r45.in_bounds-after:                              ; preds = %[[VAL_2805]], %[[VAL_2807:.*]]
 // CHECK:         ret void
 // CHECK:       r45.in_bounds-true:                               ; preds = %[[VAL_2807]]
-// CHECK:         %[[VAL_2808:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
+// CHECK:         %[[VAL_2808:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
 // CHECK:         %[[VAL_2809:.*]] = getelementptr inbounds float, float* %[[VAL_2808]], i32 %[[VAL_2788]]
 // CHECK:         %[[VAL_2810:.*]] = load float, float* %[[VAL_2809]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2811:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
+// CHECK:         %[[VAL_2811:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
 // CHECK:         %[[VAL_2812:.*]] = getelementptr inbounds float, float* %[[VAL_2811]], i32 %[[VAL_2788]]
 // CHECK:         %[[VAL_2813:.*]] = load float, float* %[[VAL_2812]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2814:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
+// CHECK:         %[[VAL_2814:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
 // CHECK:         %[[VAL_2815:.*]] = getelementptr inbounds float, float* %[[VAL_2814]], i32 %[[VAL_2788]]
 // CHECK:         %[[VAL_2816:.*]] = load float, float* %[[VAL_2815]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2817:.*]] = fcmp uge float %[[VAL_2810]], %[[VAL_2813]]
 // CHECK:         %[[VAL_2818:.*]] = select i1 %[[VAL_2817]], float %[[VAL_2810]], float %[[VAL_2813]]
 // CHECK:         %[[VAL_2819:.*]] = fcmp ule float %[[VAL_2816]], %[[VAL_2818]]
 // CHECK:         %[[VAL_2820:.*]] = select i1 %[[VAL_2819]], float %[[VAL_2816]], float %[[VAL_2818]]
-// CHECK:         %[[VAL_2821:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
+// CHECK:         %[[VAL_2821:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
 // CHECK:         %[[VAL_2822:.*]] = getelementptr inbounds float, float* %[[VAL_2821]], i32 %[[VAL_2788]]
 // CHECK:         store float %[[VAL_2820]], float* %[[VAL_2822]], align 4
-// CHECK:         %[[VAL_2823:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
+// CHECK:         %[[VAL_2823:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
 // CHECK:         %[[VAL_2824:.*]] = getelementptr inbounds float, float* %[[VAL_2823]], i32 %[[VAL_2792]]
 // CHECK:         %[[VAL_2825:.*]] = load float, float* %[[VAL_2824]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2826:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
+// CHECK:         %[[VAL_2826:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
 // CHECK:         %[[VAL_2827:.*]] = getelementptr inbounds float, float* %[[VAL_2826]], i32 %[[VAL_2792]]
 // CHECK:         %[[VAL_2828:.*]] = load float, float* %[[VAL_2827]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2829:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
+// CHECK:         %[[VAL_2829:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
 // CHECK:         %[[VAL_2830:.*]] = getelementptr inbounds float, float* %[[VAL_2829]], i32 %[[VAL_2792]]
 // CHECK:         %[[VAL_2831:.*]] = load float, float* %[[VAL_2830]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2832:.*]] = fcmp uge float %[[VAL_2825]], %[[VAL_2828]]
 // CHECK:         %[[VAL_2833:.*]] = select i1 %[[VAL_2832]], float %[[VAL_2825]], float %[[VAL_2828]]
 // CHECK:         %[[VAL_2834:.*]] = fcmp ule float %[[VAL_2831]], %[[VAL_2833]]
 // CHECK:         %[[VAL_2835:.*]] = select i1 %[[VAL_2834]], float %[[VAL_2831]], float %[[VAL_2833]]
-// CHECK:         %[[VAL_2836:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
+// CHECK:         %[[VAL_2836:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
 // CHECK:         %[[VAL_2837:.*]] = getelementptr inbounds float, float* %[[VAL_2836]], i32 %[[VAL_2792]]
 // CHECK:         store float %[[VAL_2835]], float* %[[VAL_2837]], align 4
-// CHECK:         %[[VAL_2838:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
+// CHECK:         %[[VAL_2838:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
 // CHECK:         %[[VAL_2839:.*]] = getelementptr inbounds float, float* %[[VAL_2838]], i32 %[[VAL_2796]]
 // CHECK:         %[[VAL_2840:.*]] = load float, float* %[[VAL_2839]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2841:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
+// CHECK:         %[[VAL_2841:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
 // CHECK:         %[[VAL_2842:.*]] = getelementptr inbounds float, float* %[[VAL_2841]], i32 %[[VAL_2796]]
 // CHECK:         %[[VAL_2843:.*]] = load float, float* %[[VAL_2842]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2844:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
+// CHECK:         %[[VAL_2844:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
 // CHECK:         %[[VAL_2845:.*]] = getelementptr inbounds float, float* %[[VAL_2844]], i32 %[[VAL_2796]]
 // CHECK:         %[[VAL_2846:.*]] = load float, float* %[[VAL_2845]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2847:.*]] = fcmp uge float %[[VAL_2840]], %[[VAL_2843]]
 // CHECK:         %[[VAL_2848:.*]] = select i1 %[[VAL_2847]], float %[[VAL_2840]], float %[[VAL_2843]]
 // CHECK:         %[[VAL_2849:.*]] = fcmp ule float %[[VAL_2846]], %[[VAL_2848]]
 // CHECK:         %[[VAL_2850:.*]] = select i1 %[[VAL_2849]], float %[[VAL_2846]], float %[[VAL_2848]]
-// CHECK:         %[[VAL_2851:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
+// CHECK:         %[[VAL_2851:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
 // CHECK:         %[[VAL_2852:.*]] = getelementptr inbounds float, float* %[[VAL_2851]], i32 %[[VAL_2796]]
 // CHECK:         store float %[[VAL_2850]], float* %[[VAL_2852]], align 4
-// CHECK:         %[[VAL_2853:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
+// CHECK:         %[[VAL_2853:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
 // CHECK:         %[[VAL_2854:.*]] = getelementptr inbounds float, float* %[[VAL_2853]], i32 %[[VAL_2800]]
 // CHECK:         %[[VAL_2855:.*]] = load float, float* %[[VAL_2854]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2856:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
+// CHECK:         %[[VAL_2856:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2776]] to float*
 // CHECK:         %[[VAL_2857:.*]] = getelementptr inbounds float, float* %[[VAL_2856]], i32 %[[VAL_2800]]
 // CHECK:         %[[VAL_2858:.*]] = load float, float* %[[VAL_2857]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2859:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
+// CHECK:         %[[VAL_2859:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2779]] to float*
 // CHECK:         %[[VAL_2860:.*]] = getelementptr inbounds float, float* %[[VAL_2859]], i32 %[[VAL_2800]]
 // CHECK:         %[[VAL_2861:.*]] = load float, float* %[[VAL_2860]], align 4, !invariant.load !92
 // CHECK:         %[[VAL_2862:.*]] = fcmp uge float %[[VAL_2855]], %[[VAL_2858]]
 // CHECK:         %[[VAL_2863:.*]] = select i1 %[[VAL_2862]], float %[[VAL_2855]], float %[[VAL_2858]]
 // CHECK:         %[[VAL_2864:.*]] = fcmp ule float %[[VAL_2861]], %[[VAL_2863]]
 // CHECK:         %[[VAL_2865:.*]] = select i1 %[[VAL_2864]], float %[[VAL_2861]], float %[[VAL_2863]]
-// CHECK:         %[[VAL_2866:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2773]] to float*
+// CHECK:         %[[VAL_2866:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2782]] to float*
 // CHECK:         %[[VAL_2867:.*]] = getelementptr inbounds float, float* %[[VAL_2866]], i32 %[[VAL_2800]]
 // CHECK:         store float %[[VAL_2865]], float* %[[VAL_2867]], align 4
 // CHECK:         br label %[[VAL_2806]]
@@ -3117,9 +3117,9 @@
 // CHECK:         %[[VAL_2888:.*]] = bitcast i8* %[[VAL_2886]] to [100 x [200 x float]]*
 // CHECK:         %[[VAL_2889:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !90
 // CHECK:         %[[VAL_2890:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !91
-// CHECK:         %[[VAL_2891:.*]] = mul nuw nsw i32 %[[VAL_2889]], 128
+// CHECK:         %[[VAL_2891:.*]] = mul nuw nsw i32 %[[VAL_2889]], 256
 // CHECK:         %[[VAL_2892:.*]] = add nuw nsw i32 %[[VAL_2891]], %[[VAL_2890]]
-// CHECK:         %[[VAL_2893:.*]] = icmp ult i32 %[[VAL_2892]], 163840
+// CHECK:         %[[VAL_2893:.*]] = icmp ult i32 %[[VAL_2892]], 5120
 // CHECK:         call void @llvm.assume(i1 %[[VAL_2893]])
 // CHECK:         %[[VAL_2894:.*]] = mul nuw nsw i32 %[[VAL_2892]], 4
 // CHECK:         %[[VAL_2895:.*]] = udiv i32 %[[VAL_2894]], 1
@@ -3142,56 +3142,56 @@
 // CHECK:       r46.in_bounds-after:                              ; preds = %[[VAL_2911]], %[[VAL_2913:.*]]
 // CHECK:         ret void
 // CHECK:       r46.in_bounds-true:                               ; preds = %[[VAL_2913]]
-// CHECK:         %[[VAL_2914:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
+// CHECK:         %[[VAL_2914:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
 // CHECK:         %[[VAL_2915:.*]] = getelementptr inbounds float, float* %[[VAL_2914]], i32 %[[VAL_2894]]
 // CHECK:         %[[VAL_2916:.*]] = load float, float* %[[VAL_2915]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2917:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
+// CHECK:         %[[VAL_2917:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
 // CHECK:         %[[VAL_2918:.*]] = getelementptr inbounds float, float* %[[VAL_2917]], i32 %[[VAL_2894]]
 // CHECK:         %[[VAL_2919:.*]] = load float, float* %[[VAL_2918]], align 4, !invariant.load !92
 // CHECK:         store float %[[VAL_2916]], float* %[[VAL_2878]], align 4
 // CHECK:         store float %[[VAL_2919]], float* %[[VAL_2877]], align 4
-// CHECK:         call void @add_F32(float* %[[VAL_2878]], float* %[[VAL_2877]], float* %[[VAL_2879]])
+// CHECK:         call void @region_1_3(float* %[[VAL_2878]], float* %[[VAL_2877]], float* %[[VAL_2879]])
 // CHECK:         %[[VAL_2920:.*]] = load float, float* %[[VAL_2879]], align 4
-// CHECK:         %[[VAL_2921:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
+// CHECK:         %[[VAL_2921:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
 // CHECK:         %[[VAL_2922:.*]] = getelementptr inbounds float, float* %[[VAL_2921]], i32 %[[VAL_2894]]
 // CHECK:         store float %[[VAL_2920]], float* %[[VAL_2922]], align 4
-// CHECK:         %[[VAL_2923:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
+// CHECK:         %[[VAL_2923:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
 // CHECK:         %[[VAL_2924:.*]] = getelementptr inbounds float, float* %[[VAL_2923]], i32 %[[VAL_2898]]
 // CHECK:         %[[VAL_2925:.*]] = load float, float* %[[VAL_2924]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2926:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
+// CHECK:         %[[VAL_2926:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
 // CHECK:         %[[VAL_2927:.*]] = getelementptr inbounds float, float* %[[VAL_2926]], i32 %[[VAL_2898]]
 // CHECK:         %[[VAL_2928:.*]] = load float, float* %[[VAL_2927]], align 4, !invariant.load !92
 // CHECK:         store float %[[VAL_2925]], float* %[[VAL_2875]], align 4
 // CHECK:         store float %[[VAL_2928]], float* %[[VAL_2874]], align 4
-// CHECK:         call void @add_F32(float* %[[VAL_2875]], float* %[[VAL_2874]], float* %[[VAL_2876]])
+// CHECK:         call void @region_1_3(float* %[[VAL_2875]], float* %[[VAL_2874]], float* %[[VAL_2876]])
 // CHECK:         %[[VAL_2929:.*]] = load float, float* %[[VAL_2876]], align 4
-// CHECK:         %[[VAL_2930:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
+// CHECK:         %[[VAL_2930:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
 // CHECK:         %[[VAL_2931:.*]] = getelementptr inbounds float, float* %[[VAL_2930]], i32 %[[VAL_2898]]
 // CHECK:         store float %[[VAL_2929]], float* %[[VAL_2931]], align 4
-// CHECK:         %[[VAL_2932:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
+// CHECK:         %[[VAL_2932:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
 // CHECK:         %[[VAL_2933:.*]] = getelementptr inbounds float, float* %[[VAL_2932]], i32 %[[VAL_2902]]
 // CHECK:         %[[VAL_2934:.*]] = load float, float* %[[VAL_2933]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2935:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
+// CHECK:         %[[VAL_2935:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
 // CHECK:         %[[VAL_2936:.*]] = getelementptr inbounds float, float* %[[VAL_2935]], i32 %[[VAL_2902]]
 // CHECK:         %[[VAL_2937:.*]] = load float, float* %[[VAL_2936]], align 4, !invariant.load !92
 // CHECK:         store float %[[VAL_2934]], float* %[[VAL_2872]], align 4
 // CHECK:         store float %[[VAL_2937]], float* %[[VAL_2871]], align 4
-// CHECK:         call void @add_F32(float* %[[VAL_2872]], float* %[[VAL_2871]], float* %[[VAL_2873]])
+// CHECK:         call void @region_1_3(float* %[[VAL_2872]], float* %[[VAL_2871]], float* %[[VAL_2873]])
 // CHECK:         %[[VAL_2938:.*]] = load float, float* %[[VAL_2873]], align 4
-// CHECK:         %[[VAL_2939:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
+// CHECK:         %[[VAL_2939:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
 // CHECK:         %[[VAL_2940:.*]] = getelementptr inbounds float, float* %[[VAL_2939]], i32 %[[VAL_2902]]
 // CHECK:         store float %[[VAL_2938]], float* %[[VAL_2940]], align 4
-// CHECK:         %[[VAL_2941:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
+// CHECK:         %[[VAL_2941:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
 // CHECK:         %[[VAL_2942:.*]] = getelementptr inbounds float, float* %[[VAL_2941]], i32 %[[VAL_2906]]
 // CHECK:         %[[VAL_2943:.*]] = load float, float* %[[VAL_2942]], align 4, !invariant.load !92
-// CHECK:         %[[VAL_2944:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
+// CHECK:         %[[VAL_2944:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2885]] to float*
 // CHECK:         %[[VAL_2945:.*]] = getelementptr inbounds float, float* %[[VAL_2944]], i32 %[[VAL_2906]]
 // CHECK:         %[[VAL_2946:.*]] = load float, float* %[[VAL_2945]], align 4, !invariant.load !92
 // CHECK:         store float %[[VAL_2943]], float* %[[VAL_2869]], align 4
 // CHECK:         store float %[[VAL_2946]], float* %[[VAL_2868]], align 4
-// CHECK:         call void @add_F32(float* %[[VAL_2869]], float* %[[VAL_2868]], float* %[[VAL_2870]])
+// CHECK:         call void @region_1_3(float* %[[VAL_2869]], float* %[[VAL_2868]], float* %[[VAL_2870]])
 // CHECK:         %[[VAL_2947:.*]] = load float, float* %[[VAL_2870]], align 4
-// CHECK:         %[[VAL_2948:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2882]] to float*
+// CHECK:         %[[VAL_2948:.*]] = bitcast [100 x [200 x float]]* %[[VAL_2888]] to float*
 // CHECK:         %[[VAL_2949:.*]] = getelementptr inbounds float, float* %[[VAL_2948]], i32 %[[VAL_2906]]
 // CHECK:         store float %[[VAL_2947]], float* %[[VAL_2949]], align 4
 // CHECK:         br label %[[VAL_2912]]
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index bc832b4..9581673 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -46,12 +46,9 @@
             backend().default_stream_executor()->GetAllocator()));
     GpuExecutable* gpu_executable =
         static_cast<GpuExecutable*>(executable.get());
-    std::shared_ptr<const BufferAssignment> buffer_assignment =
-        gpu_executable->GetBufferAssignment();
-    CHECK_EQ(buffer_assignment->Allocations().size(),
-             expected_number_of_allocations)
-        << "Unexpected buffer assignment. Was:\n"
-        << buffer_assignment->ToString();
+    absl::Span<const BufferAllocation> allocations =
+        gpu_executable->GetAllocations();
+    CHECK_EQ(allocations.size(), expected_number_of_allocations);
   }
 };
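Note: the hunk above replaces the BufferAssignment lookup with GpuExecutable::GetAllocations(). A minimal sketch of how a test helper might use the new accessor; the helper name and the include set are assumptions, not part of this change:

    #include "absl/types/span.h"
    #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
    #include "tensorflow/core/platform/logging.h"

    // Illustrative helper only: checks the allocation count straight from the
    // executable, without going through a BufferAssignment.
    void ExpectAllocationCount(xla::gpu::GpuExecutable* gpu_executable,
                               size_t expected_number_of_allocations) {
      absl::Span<const xla::BufferAllocation> allocations =
          gpu_executable->GetAllocations();
      CHECK_EQ(allocations.size(), expected_number_of_allocations);
    }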
 
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc
index b32da9c..db4b2ff 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk.cc
@@ -50,6 +50,8 @@
       return "kCudnnBatchNormForwardTraining";
     case Thunk::kCustomCall:
       return "kCustomCall";
+    case Thunk::kNcclAllGather:
+      return "kNcclAllGather";
     case Thunk::kNcclAllReduce:
       return "kNcclAllReduce";
     case Thunk::kNcclAllToAll:
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index dc6febf..ed79f1c 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -59,6 +59,7 @@
     kKernel,
     kMemset32BitValue,
     kMemzero,
+    kNcclAllGather,
     kNcclAllReduce,
     kNcclAllToAll,
     kOutfeed,
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
index d401f1d..49174dd 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
@@ -33,6 +33,10 @@
 
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
+#endif
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
+    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
 #endif
 
@@ -215,8 +219,8 @@
     CHECK(feature_index->IsConstant());
     int64 feature_index_value = feature_index->literal().Get<int64>({});
 
-    CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3);
-    CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
+    CHECK(custom_call->shape().IsArray());
+    CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape(),
                                            custom_call->operand(0)->shape()));
     CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
     CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
@@ -258,8 +262,7 @@
             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
             /*output_data=*/output_data,
             /*output_mean=*/output_mean,
-            /*output_inv_stddev=*/output_inv_stddev,
-            /*output_tuple=*/GetAllocationSlice(*custom_call)));
+            /*output_inv_stddev=*/output_inv_stddev));
     return Status::OK();
   }
 
@@ -295,88 +298,13 @@
         /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
         /*output_grad_data=*/output_grad_data,
         /*output_grad_scale=*/output_grad_scale,
-        /*output_grad_offset=*/output_grad_offset,
-        /*output_tuple=*/GetAllocationSlice(*custom_call)));
+        /*output_grad_offset=*/output_grad_offset));
     return Status::OK();
   }
 
-  if (IsCustomCallToDnnConvolution(*custom_call)) {
-    std::vector<BufferAllocation::Slice> operand_slices;
-    operand_slices.reserve(custom_call->operand_count());
-    for (const auto* operand : custom_call->operands()) {
-      operand_slices.push_back(GetAllocationSlice(*operand));
-    }
-    auto conv_result_slice = GetAllocationSlice(*custom_call, {0});
-    auto scratch_slice = GetAllocationSlice(*custom_call, {1});
 
-    // Assert that the tuple slice is not used by anyone directly. That is, all
-    // users of the tuple output are get-tuple-element. Also assert that the
-    // second element of the tuple (the scratch buffer) is not used by anyone.
-    for (const HloInstruction* user : custom_call->users()) {
-      TF_RET_CHECK(user->opcode() == HloOpcode::kGetTupleElement &&
-                   user->tuple_index() == 0);
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        GpuConvConfig config,
-        GetGpuConvConfig(Cast<HloCustomCallInstruction>(custom_call)));
-    AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
-        context_->GetThunkInfo(custom_call), std::move(config),
-        std::move(operand_slices), conv_result_slice, scratch_slice));
-    return Status::OK();
-  }
-
-  if (IsCublasGemm(*custom_call)) {
-    AddThunkToThunkSequence(BuildGemmThunk(custom_call));
-    return Status::OK();
-  }
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-  if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
-    TF_ASSIGN_OR_RETURN(CholeskyOptions options,
-                        custom_call->backend_config<CholeskyOptions>());
-
-    const Shape& shape = custom_call->operand(0)->shape();
-    int ndim = shape.dimensions_size();
-    CHECK_GE(ndim, 2);
-    int64 n = shape.dimensions(ndim - 1);
-
-    const auto& dims = shape.dimensions();
-    int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
-                                       [](int64 a, int64 b) { return a * b; });
-
-    auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
-
-    auto a_buffer = GetAllocationSlice(*custom_call, {0});
-    auto workspace_buffer = GetAllocationSlice(*custom_call, {1});
-    auto info_buffer = GetAllocationSlice(*custom_call, {2});
-
-    std::vector<std::unique_ptr<Thunk>> thunks;
-
-    if (operand_buffer != a_buffer) {
-      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
-          context_->GetThunkInfo(custom_call),
-          /*source_address=*/operand_buffer,
-          /*destination_buffer=*/a_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
-    }
-
-    thunks.push_back(absl::make_unique<CholeskyThunk>(
-        context_->GetThunkInfo(custom_call), options, a_buffer,
-        workspace_buffer, info_buffer,
-        custom_call->operand(0)->shape().element_type(), batch_size, n));
-
-    // Elide the sequential thunk if there's no copy.
-    if (thunks.size() == 1) {
-      AddThunkToThunkSequence(std::move(thunks[0]));
-    } else {
-      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
-          context_->GetThunkInfo(custom_call), std::move(thunks)));
-    }
-
-    return Status::OK();
-  }
-
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
+    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
   if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
           custom_call->custom_call_target(), std::string(platform_name()))) {
     auto get_slices_for_instr = [&](const HloInstruction* instr) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 6323d09..b3de1f0 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -918,7 +918,11 @@
   // function, and that they would be correlated to the same TF op. This might
   // not always be correct since HLO optimizations can cross TF op boundaries.
   // But still this seems to be better than nothing.
-  if (new_instruction->metadata().op_name().empty()) {
+  bool overwrite_dummy_name =
+      absl::StartsWith(new_instruction->metadata().op_name(), "DUMMY") &&
+      !old_instruction->metadata().op_name().empty() &&
+      !absl::StartsWith(old_instruction->metadata().op_name(), "DUMMY");
+  if (new_instruction->metadata().op_name().empty() || overwrite_dummy_name) {
     new_instruction->set_metadata(old_instruction->metadata());
   }
   if (new_instruction->frontend_attributes().map().empty()) {
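Note: for readability, a standalone sketch of the op-name rule introduced in the hunk above. ShouldCopyOldMetadata is a hypothetical helper that only mirrors the predicate; it does not exist in this change:

    #include <string>
    #include "absl/strings/match.h"

    // The old instruction's metadata is copied when the new op name is empty,
    // or when it is only a "DUMMY" placeholder while the old instruction
    // carries a real (non-DUMMY) name.
    bool ShouldCopyOldMetadata(const std::string& new_op_name,
                               const std::string& old_op_name) {
      bool overwrite_dummy_name = absl::StartsWith(new_op_name, "DUMMY") &&
                                  !old_op_name.empty() &&
                                  !absl::StartsWith(old_op_name, "DUMMY");
      return new_op_name.empty() || overwrite_dummy_name;
    }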
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 939c713..4ed89c4 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -396,8 +396,11 @@
   for (const auto& dimension : window.dimensions()) {
     window_element_count *= dimension.size();
   }
+
   const int64 output_element_count =
-      ShapeUtil::ElementsIn(reduce_window->shape());
+      ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
+                                ? reduce_window->shape()
+                                : reduce_window->shape().tuple_shapes(0));
   const int64 reduction_count =
       (window_element_count - 1) * output_element_count;
   for (const auto& property : sub_properties) {
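Note: as a rough cross-check of the tuple-shape handling above (a sketch only, using the shapes from the ReduceWindowVariadic test added further below): the output element count comes from tuple_shapes(0), so a 10x20 input reduced with a 4x5 window and 4x5 strides gives a 2x4 output per tuple element.

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Assumed shapes, for illustration only: 4x5 window, 2x4 output array.
      const int64_t window_element_count = 4 * 5;
      const int64_t output_element_count = 2 * 4;  // ElementsIn(tuple_shapes(0))
      const int64_t reduction_count =
          (window_element_count - 1) * output_element_count;  // 19 * 8 = 152
      // The variadic reducer performs two Min ops per reduction, so the total
      // flop count is 2 * 152 = 304, matching 2 * 4 * 2 * (4 * 5 - 1).
      std::printf("reductions=%lld flops=%lld\n",
                  static_cast<long long>(reduction_count),
                  static_cast<long long>(2 * reduction_count));
      return 0;
    }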
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index f101e38..ce94dc0 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -183,6 +183,12 @@
     return GetProperty(key, per_second_rates_);
   }
 
+  // Return the key that is used to index into Properties for the specified
+  // input/output at the shape index.
+  static std::string GetOperandBytesAccessedKey(int64 operand_num,
+                                                ShapeIndex index = {});
+  static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});
+
  protected:
   typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;
 
@@ -229,12 +235,6 @@
   void SetOutputBytesAccessed(float value);
   void SetOutputBytesAccessed(ShapeIndex index, float value);
 
-  // Return the key that is used to index into Properties for the specified
-  // input/output at the shape index.
-  static std::string GetOperandBytesAccessedKey(int64 operand_num,
-                                                ShapeIndex index = {});
-  static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});
-
   // Function which computes the size of the top-level of a given shape (not
   // including nested elements, if any). If null then bytes_accessed methods
   // return an error.
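Note: the two key helpers move from the protected section to the public interface. A small sketch of what that enables; the surrounding function is purely illustrative and not part of this change:

    #include <iostream>
    #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

    // With the accessors public, code outside HloCostAnalysis can build the
    // same property keys the analysis uses internally, e.g. to keep an
    // external Properties map keyed consistently with the analysis.
    void PrintBytesAccessedKeys() {
      std::cout << xla::HloCostAnalysis::GetOperandBytesAccessedKey(0) << "\n"
                << xla::HloCostAnalysis::GetOutputBytesAccessedKey() << "\n";
    }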
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 8f2b9a6..748eb40 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -449,6 +449,41 @@
   EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 4);
 }
 
+TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) {
+  XlaBuilder builder("reduce_window_variadic");
+  auto elem_shape = ShapeUtil::MakeShape(F32, {});
+  auto p2 = Parameter(&builder, 0, elem_shape, "x0");
+  auto p3 = Parameter(&builder, 1, elem_shape, "x1");
+  auto p4 = Parameter(&builder, 2, elem_shape, "y0");
+  auto p5 = Parameter(&builder, 3, elem_shape, "y1");
+  absl::InlinedVector<XlaOp, 2> compute_vec = {Min(p2, p4), Min(p3, p5)};
+  Tuple(&builder, compute_vec);
+  TF_ASSERT_OK_AND_ASSIGN(auto compute_tuple, builder.Build());
+  auto input1 =
+      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input1");
+  auto input2 =
+      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {10, 20}), "input2");
+  auto init = ConstantR0<float>(&builder, 0);
+  ReduceWindow({input1, input2}, {init, init}, compute_tuple, {4, 5}, {4, 5},
+               Padding::kValid);
+
+  // Run HLO cost analysis.
+  auto hlo_module = BuildHloGraph(&builder);
+  HloCostAnalysis analysis(ShapeSize);
+  ASSERT_IS_OK(
+      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // Each of the [2x4] output elements is generated by reducing [4x5] elements.
+  EXPECT_EQ(analysis.flop_count(), 2 * 4 * 2 * (4 * 5 - 1));
+
+  EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 * 2 + 2 * 3));
+
+  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
+  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 10 * 20);
+  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
+  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 4);
+}
+

 TEST_F(HloCostAnalysisTest, SelectAndScatter) {
   XlaBuilder builder("select_and_scatter");
   auto operand =
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8d33664..49c0608 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2024,15 +2024,11 @@
     return false;
   }
 
-  // Two AllReduces are Identical if they have the same channel_id.
-  // Their operands don't have to be Identical.
-  if (!IsCrossModuleAllReduce()) {
-    // Use an explicit loop rather than ContainerEquals, because copying
-    // around std::functions may be too expensive in some cases.
-    for (size_t i = 0; i < operands().size(); ++i) {
-      if (!eq_operands(operand(i), other.operand(i))) {
-        return false;
-      }
+  // Use an explicit loop rather than ContainerEquals, because copying around
+  // std::functions may be too expensive in some cases.
+  for (size_t i = 0; i < operands().size(); ++i) {
+    if (!eq_operands(operand(i), other.operand(i))) {
+      return false;
     }
   }
 
@@ -2542,7 +2538,7 @@
   if (print_ids) {
     return name;
   } else {
-    auto dot_position = name.find_first_of(".");
+    auto dot_position = name.find_first_of('.');
     return name.substr(0, dot_position);
   }
 }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index c6f6919..98c29af 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1601,8 +1601,19 @@
   const PrecisionConfig& precision_config() const;
   PrecisionConfig* mutable_precision_config();
 
-  // Sets the debug metadata for this instruction.
-  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
+  // Sets the debug metadata for this instruction, excluding creation_pass_id,
+  // which should never be copied anywhere.
+  void set_metadata(const OpMetadata& metadata) {
+    int64 creation_pass_id = metadata_.creation_pass_id();
+    metadata_ = metadata;
+    metadata_.set_creation_pass_id(creation_pass_id);
+  }
+  void set_creation_pass_id(int64 pass_id) {
+    metadata_.set_creation_pass_id(pass_id);
+  }
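+  // Sets only the op_name field of the debug metadata, leaving other fields
+  // untouched.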
+  void set_metadata_op_name(const std::string& name) {
+    metadata_.set_op_name(name);
+  }
   const OpMetadata& metadata() const { return metadata_; }
 
   // Set/get the computation containing this instruction. set_parent should only
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e43f68f..e203e63 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2285,9 +2285,13 @@
 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* context) const {
-  CHECK_EQ(new_operands.size(), 2);
+  CHECK_EQ(new_operands.size() % 2, 0);
+  int64 num_operands = new_operands.size() / 2;
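+  // The first half of new_operands are the input arrays; the second half are
+  // the matching init values.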
   return absl::make_unique<HloReduceWindowInstruction>(
-      shape, new_operands[0], new_operands[1], window(), to_apply());
+      shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
+      absl::MakeSpan(new_operands)
+          .subspan(num_operands, new_operands.size() / 2),
+      window(), to_apply());
 }
 
 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 961e945..e96f6d7 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -595,6 +595,22 @@
   return n;
 }
 
+std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
+    const absl::flat_hash_set<HloComputation*>& allow_list) const {
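+  // Walk the full post order and keep only the computations in the allow
+  // list, preserving their relative order.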
+  std::vector<HloComputation*> filtered_post_order(allow_list.size());
+  auto post_order = this->MakeComputationPostOrder();
+
+  int filtered_idx = 0;
+  for (auto& computation : post_order) {
+    if (allow_list.contains(computation)) {
+      filtered_post_order[filtered_idx] = computation;
+      filtered_idx += 1;
+    }
+  }
+
+  return filtered_post_order;
+}
+
 std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
   // First determine all root computations by building a set of nonroot
   // computations (computations which are called by an instruction in the
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 1e10e24..ed379e4 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -203,6 +203,11 @@
   // computation B, then A will appear after B in the sort.
   std::vector<HloComputation*> MakeComputationPostOrder() const;
 
+  // Same as MakeComputationPostOrder() but only returns the computations that
+  // are also present in the given allow_list.
+  std::vector<HloComputation*> MakeComputationPostOrder(
+      const absl::flat_hash_set<HloComputation*>& allow_list) const;
+
   // Same as MakeComputationPostOrder() but sorting the computations by their
   // contents. The order is no longer a post order.
   std::vector<HloComputation*> MakeComputationSorted() const;
diff --git a/tensorflow/compiler/xla/service/hlo_module_metadata.h b/tensorflow/compiler/xla/service/hlo_module_metadata.h
index 434e3bb..fcb7871 100644
--- a/tensorflow/compiler/xla/service/hlo_module_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_metadata.h
@@ -61,6 +61,12 @@
     module_metadata_.add_partitioned_module_ids(id);
   }
 
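+  // Returns the pass_id of the currently running HLO pass, as recorded in
+  // this module's metadata.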
+  StatusOr<int64> current_pass_id() {
+    TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
+                        GetCurrentHloPassMetadata());
+    return pass_metadata->pass_id();
+  }
+
   // Setters for the current HloPassMetadata.
   Status set_current_pass_name(const std::string& pass_name) {
     return MutateCurrentHloPassMetadata(
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 4c9fa9a..4a48f30 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -394,6 +394,86 @@
   EXPECT_EQ(root->to_apply(), new_comp);
 }
 
+TEST_F(HloModuleTest, OneComputationAllAllowed) {
+  // Create a module with a single computation and ensure it is returned when
+  // placed in the allow list.
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(CreateConstantComputation());
+
+  absl::flat_hash_set<HloComputation*> allowList = {computation};
+  EXPECT_THAT(module->MakeComputationPostOrder(allowList),
+              ::testing::ElementsAre(computation));
+}
+
+TEST_F(HloModuleTest, OneComputationAllFiltered) {
+  // Create a module with a single computation.
+  auto module = CreateNewVerifiedModule();
+  module->AddEntryComputation(CreateConstantComputation());
+
+  absl::flat_hash_set<HloComputation*> allowList = {};
+  EXPECT_THAT(module->MakeComputationPostOrder(allowList),
+              ::testing::IsEmpty());
+}
+
+TEST_F(HloModuleTest, DiamondComputationsPostOrderAllAllowed) {
+  // Create a module with a diamond call graph of computations.
+  auto module = CreateNewVerifiedModule();
+  auto computation1 =
+      module->AddEmbeddedComputation(CreateConstantComputation());
+  auto computation2 =
+      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+  auto computation3 =
+      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+  auto computation4 = module->AddEntryComputation(
+      CreateCallComputation({computation2, computation3}));
+
+  absl::flat_hash_set<HloComputation*> allowList = {computation1, computation2,
+                                                    computation3, computation4};
+  auto post_order = module->MakeComputationPostOrder(allowList);
+  EXPECT_THAT(post_order,
+              ::testing::UnorderedElementsAre(computation1, computation2,
+                                              computation3, computation4));
+  EXPECT_EQ(post_order.back(), computation4);
+  EXPECT_EQ(post_order.front(), computation1);
+}
+
+TEST_F(HloModuleTest, DiamondComputationsPostOrderMiddleFiltered) {
+  // Create a module with a diamond call graph of computations.
+  auto module = CreateNewVerifiedModule();
+  auto computation1 =
+      module->AddEmbeddedComputation(CreateConstantComputation());
+  auto computation2 =
+      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+  auto computation3 =
+      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+  auto computation4 = module->AddEntryComputation(
+      CreateCallComputation({computation2, computation3}));
+
+  absl::flat_hash_set<HloComputation*> allowList = {computation1, computation4};
+  auto post_order = module->MakeComputationPostOrder(allowList);
+  EXPECT_THAT(post_order,
+              ::testing::UnorderedElementsAre(computation1, computation4));
+}
+
+TEST_F(HloModuleTest, DiamondComputationsPostOrderAllFiltered) {
+  // Create a module with a diamond call graph of computations.
+  auto module = CreateNewVerifiedModule();
+  auto computation1 =
+      module->AddEmbeddedComputation(CreateConstantComputation());
+  auto computation2 =
+      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+  auto computation3 =
+      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+  module->AddEntryComputation(
+      CreateCallComputation({computation2, computation3}));
+
+  absl::flat_hash_set<HloComputation*> allowList = {};
+  auto post_order = module->MakeComputationPostOrder(allowList);
+  EXPECT_THAT(post_order, ::testing::IsEmpty());
+}
+
 }  // namespace
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_util.cc b/tensorflow/compiler/xla/service/hlo_module_util.cc
new file mode 100644
index 0000000..106c50c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_util.cc
@@ -0,0 +1,131 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_util.h"
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+namespace {
+
+Status ValidateResultShape(const Shape& client_shape,
+                           const Shape& result_shape) {
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
+  if (!ShapeUtil::Compatible(client_shape, result_shape)) {
+    return InvalidArgument(
+        "Shape used to set computation result layout %s is not compatible "
+        "with result shape %s",
+        ShapeUtil::HumanStringWithLayout(client_shape),
+        ShapeUtil::HumanString(result_shape));
+  }
+  return Status::OK();
+}
+}  // namespace
+
+StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+    const ProgramShape& program_shape,
+    absl::Span<const Shape* const> argument_shapes,
+    const ExecutionOptions* execution_options, int default_num_replicas,
+    absl::optional<int> num_threads, const AotCompilationOptions* aot_options) {
+  auto config = absl::make_unique<HloModuleConfig>(program_shape);
+  ComputationLayout* computation_layout =
+      config->mutable_entry_computation_layout();
+  const int64 argument_shapes_size = argument_shapes.size();
+  if (program_shape.parameters_size() != argument_shapes_size) {
+    return InvalidArgument("computation takes %d parameters, but %u given",
+                           program_shape.parameters_size(),
+                           argument_shapes.size());
+  }
+  for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
+    // Verify that shape of arguments matches the shape of the arguments in the
+    // ProgramShape.
+    if (!ShapeUtil::Compatible(*argument_shapes[i],
+                               program_shape.parameters(i))) {
+      return InvalidArgument(
+          "Argument does not match shape of computation parameter %d: want "
+          "%s, got %s",
+          i, ShapeUtil::HumanString(program_shape.parameters(i)),
+          ShapeUtil::HumanString(*argument_shapes[i]));
+    }
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+            *argument_shapes[i]));
+  }
+  if (execution_options != nullptr &&
+      execution_options->has_shape_with_output_layout()) {
+    const Shape shape_with_output_layout(
+        execution_options->shape_with_output_layout());
+    TF_RETURN_IF_ERROR(
+        ValidateResultShape(shape_with_output_layout, program_shape.result()));
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+            shape_with_output_layout));
+  } else {
+    // If the result layout is not set, then choose the default.
+    computation_layout->mutable_result_layout()->SetToDefaultLayout();
+  }
+
+  if (execution_options != nullptr) {
+    if (execution_options->num_replicas() > 0) {
+      config->set_replica_count(execution_options->num_replicas());
+    } else {
+      config->set_replica_count(default_num_replicas);
+    }
+    if (execution_options->num_partitions() > 0) {
+      config->set_num_partitions(execution_options->num_partitions());
+    }
+    config->set_use_spmd_partitioning(
+        execution_options->use_spmd_partitioning());
+    config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
+    config->set_seed(execution_options->seed());
+    config->set_launch_id(execution_options->launch_id());
+    config->set_debug_options(execution_options->debug_options());
+  } else {
+    config->set_replica_count(default_num_replicas);
+    config->set_debug_options(GetDebugOptionsFromFlags());
+  }
+
+  if (num_threads.has_value()) {
+    config->set_intra_op_parallelism_threads(*num_threads);
+  }
+
+  if (execution_options != nullptr &&
+      execution_options->has_device_assignment()) {
+    TF_ASSIGN_OR_RETURN(
+        auto device_assignment,
+        DeviceAssignment::Deserialize(execution_options->device_assignment()));
+    config->set_static_device_assignment(*device_assignment);
+  }
+  config->set_alias_passthrough_params(
+      execution_options->alias_passthrough_params());
+
+  if (aot_options != nullptr &&
+      aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
+    config->set_fusion_config_collection(
+        aot_options->fusion_config_collection());
+    *config->mutable_fusion_config() = aot_options->fusion_config();
+  }
+
+  return std::move(config);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_util.h b/tensorflow/compiler/xla/service/hlo_module_util.h
new file mode 100644
index 0000000..93d11ea
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
+
+#include <memory>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Creates an HloModuleConfig for a given program shape and arguments.
+// If execution_options does not set num_replicas, default_num_replicas is used.
+// num_threads is optional; if not given, intra_op_parallelism_threads is not
+// set. aot_options is optional; if not given, a default is used.
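+//
+// A minimal usage sketch (`program_shape` and `argument_shapes` here stand in
+// for the caller's values):
+//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> config,
+//                       CreateModuleConfig(program_shape, argument_shapes,
+//                                          /*execution_options=*/nullptr,
+//                                          /*default_num_replicas=*/1));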
+StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+    const ProgramShape& program_shape,
+    absl::Span<const Shape* const> argument_shapes,
+    const ExecutionOptions* execution_options, int default_num_replicas,
+    absl::optional<int> num_threads = absl::nullopt,
+    const AotCompilationOptions* aot_options = nullptr);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 558a502..675b60b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1366,16 +1366,25 @@
       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &reduce_computation};
-      if (!ParseOperands(&operands, /*expected_size=*/2) ||
-          !ParseAttributes(attrs)) {
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
       if (!window) {
         window.emplace();
       }
+      if (operands.size() % 2) {
+        auto loc = lexer_.GetLoc();
+        return Error(loc, StrCat("expects an even number of operands, but has ",
+                                 operands.size(), " operands"));
+      }
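+      // The first half of the parsed operands are the arrays being reduced;
+      // the second half are the matching init values.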
       instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
-          shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
-          *reduce_computation));
+          shape, /*operands=*/
+          absl::Span<HloInstruction* const>(operands).subspan(
+              0, operands.size() / 2),
+          /*init_values=*/
+          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
+                                                              2),
+          *window, *reduce_computation));
       break;
     }
     case HloOpcode::kConvolution: {
@@ -3585,7 +3594,7 @@
 }
 
 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
-// The string looks like "dim_labels=0bf_0io->0bf".
+// The string looks like "dim_labels=0bf_0io->0bf".
 bool HloParserImpl::ParseConvolutionDimensionNumbers(
     ConvolutionDimensionNumbers* dnums) {
   if (lexer_.GetKind() != TokKind::kDimLabels) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index dc94e30..27b0de5 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -439,6 +439,29 @@
 
 )"
 },
+// variadic reduce window on scalars
+{
+"ReduceWindowVariadic",
+R"(HloModule reduce_window_variadic
+
+%add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
+  %lhs1 = f32[] parameter(0)
+  %rhs1 = f32[] parameter(2)
+  %add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
+  %lhs2 = f32[] parameter(1)
+  %rhs2 = f32[] parameter(3)
+  %add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
+  ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
+}
+
+ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
+  %constant = f32[] constant(42)
+  %constant.1 = f32[] constant(1)
+  ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
+}
+
+)"
+},
 // convolution
 {
 "Convolution",
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 41f907f..6f25cb2 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -67,7 +67,7 @@
   Status status =
       AttemptRecordPassEndMetadata(module, pass_name, module_changed);
   if (!status.ok()) {
-    LOG(FATAL) << status.error_message();
+    LOG(FATAL) << status;
   }
 }
 
@@ -91,7 +91,30 @@
   Status status =
       AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
   if (!status.ok()) {
-    LOG(FATAL) << status.error_message();
+    LOG(FATAL) << status;
+  }
+}
+
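+// Stamps every instruction whose metadata does not yet have a
+// creation_pass_id or op_name with the id of the currently running pass and a
+// placeholder "DUMMY_<pass_id>" op name.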
+void SetInstructionMetadata(HloModule& module) {
+  StatusOr<int64> pass_id = module.metadata()->current_pass_id();
+  if (!pass_id.ok()) {
+    LOG(FATAL) << pass_id.status();
+  }
+  for (xla::HloComputation* computation : module.computations()) {
+    for (xla::HloInstruction* instruction : computation->instructions()) {
+      if (instruction->metadata().creation_pass_id() == 0) {
+        instruction->set_creation_pass_id(*pass_id);
+      }
+      if (instruction->metadata().op_name().empty()) {
+        instruction->set_metadata_op_name(absl::StrCat("DUMMY_", *pass_id));
+      }
+    }
+  }
+}
+
+void SetInstructionMetadata(HloModuleGroup& module_group) {
+  for (HloModule* module : module_group.modules()) {
+    SetInstructionMetadata(*module);
   }
 }
 
@@ -127,6 +150,7 @@
   TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
 
   RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
+  SetInstructionMetadata(*hlo);
   MaybeDumpHloAndSaveFilenames(*hlo,
                                /*after_pass_name=*/kPipelineStart,
                                /*before_pass_name=*/passes.empty()
@@ -147,6 +171,7 @@
     }
     RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
     TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+    SetInstructionMetadata(*hlo);
     MaybeDumpHloAndSaveFilenames(*hlo,
                                  /*after_pass_name=*/pass_name,
                                  /*before_pass_name=*/i + 1 >= passes.size()
@@ -216,7 +241,7 @@
            name(), before_pass_name, after_pass_name, module)) {
     Status status = module.metadata()->add_current_pass_dump_filename(filename);
     if (!status.ok()) {
-      LOG(FATAL) << status.error_message();
+      LOG(FATAL) << status;
     }
   }
 }
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 84e4fe6..80d2cd3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1086,6 +1086,7 @@
     case HloOpcode::kRecv:
     case HloOpcode::kRecvDone:
     case HloOpcode::kReducePrecision:
+    case HloOpcode::kReduceWindow:
     case HloOpcode::kTupleSelect:
     case HloOpcode::kSend:
     case HloOpcode::kSendDone:
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 3f3e74d..8b0a046 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -95,7 +95,7 @@
 
 StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
-    se::DeviceMemoryAllocator* /*device_allocator*/) {
+    const CompileOptions& /*options*/) {
   VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
   TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
   return std::move(hlo_module);
@@ -103,7 +103,7 @@
 
 StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* /*device_allocator*/) {
+    const CompileOptions& /*options*/) {
   TF_RET_CHECK(stream_exec != nullptr);
 
   VLOG(1) << "Run backend " << hlo_module->name();
@@ -128,7 +128,7 @@
 StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
     std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   if (module_group->empty()) {
     return std::vector<std::unique_ptr<Executable>>();
   }
@@ -141,12 +141,10 @@
         "Unexpected number of StreamExecutor's.");
   }
   auto hlo_modules = module_group->ConsumeModules();
-  TF_ASSIGN_OR_RETURN(auto module,
-                      RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0],
-                                   device_allocator));
-  TF_ASSIGN_OR_RETURN(
-      auto executable,
-      RunBackend(std::move(module), stream_exec[0][0], device_allocator));
+  TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
+                                                stream_exec[0][0], options));
+  TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
+                                                  stream_exec[0][0], options));
   std::vector<std::unique_ptr<Executable>> ret;
   ret.push_back(std::move(executable));
   return std::move(ret);
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h
index 824594d..2136bc9 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.h
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.h
@@ -45,14 +45,14 @@
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc
index aa759b2..8603cd5 100644
--- a/tensorflow/compiler/xla/service/llvm_compiler.cc
+++ b/tensorflow/compiler/xla/service/llvm_compiler.cc
@@ -24,7 +24,7 @@
 StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
     std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   // Tensorflow tries to enable the following behaviors in all its threads:
   //
   //  - Denormals are zero (DAZ): roughly, operations treat denormal floats as
@@ -48,10 +48,10 @@
 
     TF_ASSIGN_OR_RETURN(modules[i],
                         RunHloPasses(std::move(modules[i]), stream_execs[i][0],
-                                     device_allocator));
+                                     options.device_allocator));
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                         RunBackend(std::move(modules[i]), stream_execs[i][0],
-                                   device_allocator));
+                                   options.device_allocator));
     result.push_back(std::move(executable));
   }
 
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h
index bddda50..7f0c617 100644
--- a/tensorflow/compiler/xla/service/llvm_compiler.h
+++ b/tensorflow/compiler/xla/service/llvm_compiler.h
@@ -66,13 +66,14 @@
   //       std::unique_ptr<HloModule> module,
   //       se::StreamExecutor* stream_exec,
   //       se::DeviceMemoryAllocator* device_allocator)
+  using Compiler::Compile;
   using Compiler::RunBackend;
   using Compiler::RunHloPasses;
 
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
 
  protected:
   ModuleHook user_pre_optimization_hook_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
index 3312163..3313954 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
@@ -43,7 +43,11 @@
 
 string SanitizeConstantName(const HloInstruction& instr) {
   CHECK_EQ(instr.opcode(), HloOpcode::kConstant);
-  string instr_name = instr.name();
+  return SanitizeConstantName(instr.name());
+}
+
+string SanitizeConstantName(absl::string_view name) {
+  std::string instr_name(name);
   for (char& c : instr_name) {
     // Having a hyphen or a dot in a global variable name can crash the LLVM PTX
     // backend.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h
index 2e2d3bf..2702a61 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h
@@ -24,6 +24,7 @@
 // name of the corresponding constant buffer. In particular, it replaces . and
 // - with _.
 string SanitizeConstantName(const HloInstruction& instr);
+string SanitizeConstantName(absl::string_view name);
 
 string ConstantHloToGlobalName(const HloInstruction& instr);
 
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 0eff81c..e187675 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -32,6 +32,7 @@
 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module_util.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -94,36 +95,6 @@
   return absl::nullopt;
 }
 
-ExecutionOptions CreateExecutionOptions(
-    const ExecutableBuildOptions& build_options,
-    const ProgramShape* program_shape) {
-  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
-  if (build_options.has_debug_options()) {
-    *execution_options.mutable_debug_options() = build_options.debug_options();
-  }
-  if (build_options.result_layout() != nullptr) {
-    *execution_options.mutable_shape_with_output_layout() =
-        build_options.result_layout()->ToProto();
-  } else {
-    Shape result_shape(program_shape->result());
-    LayoutUtil::SetToDefaultLayout(&result_shape);
-    *execution_options.mutable_shape_with_output_layout() =
-        result_shape.ToProto();
-  }
-  execution_options.set_num_replicas(build_options.num_replicas());
-  execution_options.set_num_partitions(build_options.num_partitions());
-  execution_options.set_use_spmd_partitioning(
-      build_options.use_spmd_partitioning());
-  execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
-  if (build_options.has_device_assignment()) {
-    TF_CHECK_OK(build_options.device_assignment().Serialize(
-        execution_options.mutable_device_assignment()));
-  }
-  execution_options.set_alias_passthrough_params(
-      build_options.alias_passthrough_params());
-  return execution_options;
-}
-
 }  // namespace
 
 StatusOr<std::vector<std::unique_ptr<Executable>>>
@@ -190,11 +161,12 @@
   // single partition computations are built using `BuildExecutables`, fix it,
   // and remove this special case (provided the performance if similar).
   if (build_options.num_partitions() == 1) {
-    TF_ASSIGN_OR_RETURN(
-        std::unique_ptr<Executable> executable,
-        BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
-                        executor, build_options.device_allocator(),
-                        build_options.run_backend_only()));
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                        BuildExecutable(proto, std::move(module_config),
+                                        execute_backend_.get(), executor,
+                                        {build_options.device_allocator(),
+                                         build_options.compile_thread_pool()},
+                                        build_options.run_backend_only()));
     std::vector<std::unique_ptr<Executable>> executables;
     executables.push_back(std::move(executable));
     return executables;
@@ -206,10 +178,12 @@
     std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                                executor);
 
-    return BuildExecutables({&proto}, std::move(module_configs),
-                            execute_backend_.get(), {executors},
-                            build_options.device_allocator(),
-                            build_options.run_backend_only());
+    return BuildExecutables(
+        /*module_protos=*/{&proto}, std::move(module_configs),
+        execute_backend_.get(), {executors},
+        Compiler::CompileOptions{build_options.device_allocator(),
+                                 build_options.compile_thread_pool()},
+        build_options.run_backend_only());
   }
 }
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index c86abfd..473d153 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -893,6 +893,9 @@
 bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
     const AllocationValue& value, const HloUse& use) const {
   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
+  if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
+    return false;
+  }
   if (use.instruction->opcode() == HloOpcode::kWhile) {
     HloComputation* while_body = use.instruction->while_body();
 
@@ -2001,14 +2004,26 @@
   }
 
   if (required_assignment_at_start) {
-    if (!allocation_sequence->empty() &&
-        required_assignment_at_start->memory_space == MemorySpace::kAlternate) {
-      const auto& prev_allocation = allocation_sequence->back();
-      CHECK(prev_allocation->memory_space() ==
-            required_assignment_at_start->memory_space);
-      CHECK_EQ(GetAliasedOffset(*prev_allocation),
-               required_assignment_at_start->offset);
-      prev_allocation->Extend(request.start_time);
+    if (!allocation_sequence->empty()) {
+      // We shouldn't see a required assignment at start in the alternate
+      // memory space while there are already allocations in the allocation
+      // sequence. The only time the required assignment at start is in the
+      // alternate memory space is inside called computations (e.g., a while
+      // body), and in those cases no allocations have been added to the
+      // allocation sequence yet.
+      CHECK(required_assignment_at_start->memory_space ==
+            MemorySpace::kDefault);
+      // Find the previous allocation in default memory (might not be the very
+      // last one) and extend its lifetime to include the start time of this
+      // segment.
+      auto prev_allocation_in_default_mem_it = std::find_if(
+          allocation_sequence->rbegin(), allocation_sequence->rend(),
+          [&](const auto& allocation) {
+            return allocation->memory_space() == MemorySpace::kDefault &&
+                   allocation->defining_position() == defining_position;
+          });
+      CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
+      (*prev_allocation_in_default_mem_it)->Extend(request.start_time);
     } else {
       absl::optional<Chunk> aliased_chunk = absl::nullopt;
       if (required_assignment_at_start->memory_space ==
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 341bf7e..7bffcc2 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -400,6 +400,8 @@
       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
   using IsAllowedInAlternateMemoryFunction =
       std::function<bool(const HloValue&)>;
+  using IsUseAllowedInAlternateMemoryFunction =
+      std::function<bool(const HloUse&)>;
 
   // MemorySpaceAssignment uses a notion of a slow and large default memory
   // space and a fast and small alternate memory space.
@@ -434,6 +436,11 @@
     // the opcode) to be placed on the alternate memory.
     IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn;
 
+    // This function can be used to prevent certain HloUses (e.g., based on
+    // the opcode) from being placed in the alternate memory.
+    IsUseAllowedInAlternateMemoryFunction is_use_allowed_in_alternate_mem_fn =
+        [](const HloUse&) { return true; };
+
     // Specifies the upper bound for number of outstanding prefetches and
     // evictions, -1 for unlimited.
     int64 max_outstanding_prefetches = -1;
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index b2b8ebc..b314b5a 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -4152,6 +4152,142 @@
             find_schedule_index(cos->operand(0)));
 }
 
+TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
+  // When we have a disallowed use (in this case tanh), we aren't allowed to
+  // allocate this use in alternate memory. However, if we have another use
+  // after this on the same buffer (o), this use may refer to "a" instead of the
+  // evicted value, which is illegal because "a" will be allocated in the
+  // alternate memory space.
+  absl::string_view hlo_string = R"(
+  HloModule bug, is_scheduled=true
+
+  ENTRY Entry {
+    param0 = f32[8,3] parameter(0)
+    param1 = f32[2,4] parameter(1)
+    a = f32[8,3] cosine(param0)
+    b = f32[2,4] negate(param1)
+    d = f32[8,3] negate(a)
+    c = f32[2,4] negate(b)
+    e = f32[2,4] negate(c)
+    f = f32[8,3] tanh(a)
+    g = f32[2,4] negate(e)
+    h = f32[2,4] negate(g)
+    i = f32[2,4] negate(h)
+    j = f32[2,4] negate(i)
+    k = f32[2,4] negate(j)
+    l = f32[2,4] negate(k)
+    m = f32[2,4] negate(l)
+    n = f32[2,4] sine(m)
+    o = f32[8,3] negate(a)
+    p = f32[2,4] negate(n)
+    q = f32[8,3] add(o, f)
+    r = f32[8,3] add(q, d)
+    ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r)
+  }
+  )";
+
+  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
+      [](const MemorySpaceAssignment::BufferInterval& a,
+         const MemorySpaceAssignment::BufferInterval& b) {
+        auto get_opcode_priority = [](const HloOpcode& opcode) {
+          switch (opcode) {
+            case HloOpcode::kSin:
+              return 0;
+            case HloOpcode::kCos:
+              return 1;
+            case HloOpcode::kTanh:
+              return 2;
+            default:
+              return 3;
+          }
+        };
+
+        return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
+               get_opcode_priority(b.buffer->defining_instruction()->opcode());
+      };
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
+  MemorySpaceAssignment::Options options;
+  options.max_size_in_bytes = 128;
+  options.alignment_in_bytes = 8;
+  options.verify = true;
+  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
+    return use.instruction->opcode() != HloOpcode::kTanh;
+  };
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    buffer_interval_compare, &prefetch_interval_picker,
+                    options);
+}
+
+TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
+  // Test for situations where we disallow a use (tanh in this case) in the
+  // alternate memory space and there is a subsequent use that also requires the
+  // buffer to be in the default memory space. In this case, the allocation in
+  // the default memory space might not be the very last one, so we need to
+  // search the allocation sequence and find the one in the default memory
+  // space.
+  absl::string_view hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  while_cond {
+    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    ROOT gte = pred[] get-tuple-element(p0), index=3
+  }
+
+  while_body {
+    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    gte0 = f32[3]{0} get-tuple-element(p0), index=0
+    gte1 = f32[3]{0} get-tuple-element(p0), index=1
+    gte2 = f32[3]{0} get-tuple-element(p0), index=2
+    gte3 = pred[] get-tuple-element(p0), index=3
+    add = f32[3]{0} add(gte0, gte0)
+    negate0 = f32[3]{0} negate(add)
+    negate1 = f32[3]{0} negate(negate0)
+    negate2 = f32[3]{0} negate(negate1)
+    negate3 = f32[3]{0} negate(negate2)
+    negate4 = f32[3]{0} negate(negate3)
+    negate5 = f32[3]{0} negate(negate4)
+    negate6 = f32[3]{0} negate(negate5)
+    negate7 = f32[3]{0} negate(negate6)
+    negate8 = f32[3]{0} negate(negate7)
+    negate9 = f32[3]{0} negate(negate8)
+    negate10 = f32[3]{0} negate(negate9)
+    negate11 = f32[3]{0} negate(negate10)
+    negate12 = f32[3]{0} negate(negate11)
+    negate13 = f32[3]{0} negate(negate12)
+    negate14 = f32[3]{0} negate(negate13)
+    negate15 = f32[3]{0} negate(gte2)
+    tanh = f32[3]{0} tanh(gte2)
+    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy0 = f32[3]{0} copy(p0)
+    copy1 = f32[3]{0} copy(p0)
+    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
+    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
+    ROOT gte = f32[3]{0} get-tuple-element(while), index=2
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  MemorySpaceAssignment::Options options;
+  options.max_size_in_bytes = 128;
+  options.alignment_in_bytes = 8;
+  options.verify = true;
+  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
+    return use.instruction->opcode() != HloOpcode::kTanh;
+  };
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
+                    options);
+}
+
 TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
   // Tests against a bug where the root of entry computation is a bitcast
   // instruction and it ends up getting an allocation in the alternate memory.
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
index 4eaed3a..4b2c04d 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -9,10 +9,6 @@
 # buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
 load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
-load(
-    "//tensorflow/core/platform/default:cuda_build_defs.bzl",
-    "if_cuda_is_configured",
-)
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 
 package(
@@ -34,139 +30,6 @@
     ]),
 )
 
-cc_library(
-    name = "failover_compiler",
-    srcs = ["failover_compiler.cc"],
-    hdrs = ["failover_compiler.h"],
-    deps = [
-        "//tensorflow/compiler/xla/service:compiler",
-        "//tensorflow/core:lib",
-    ],
-)
-
-cc_library(
-    name = "emission_context",
-    srcs = ["emission_context.cc"],
-    hdrs = ["emission_context.h"],
-    deps = [
-        "//tensorflow/compiler/mlir/hlo",
-        "//tensorflow/compiler/mlir/hlo:lhlo",
-        "//tensorflow/compiler/xla/service:hlo",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:StandardOps",
-    ],
-)
-
-cc_library(
-    name = "inject_errors_pass",
-    srcs = ["inject_errors_pass.cc"],
-    hdrs = ["inject_errors_pass.h"],
-    deps = [
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:StandardOps",
-    ],
-)
-
-cc_library(
-    name = "mlir_compiler",
-    srcs = ["mlir_compiler.cc"],
-    hdrs = ["mlir_compiler.h"],
-    deps = [
-        ":emission_context",
-        "//tensorflow/compiler/xla/service:compiler",
-        "//tensorflow/compiler/xla/service/gpu:target_constants",
-        "//tensorflow/core/platform:stream_executor_no_cuda",
-        "@llvm-project//llvm:Core",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMDialect",
-    ],
-)
-
-cc_library(
-    name = "mlir_compiler_impl",
-    srcs = if_cuda_is_configured(["mlir_compiler_impl.cc"]),
-    deps = if_cuda_is_configured([
-        ":mlir_compiler",
-        ":failover_compiler",
-        ":emission_context",
-        ":kernel_lowering",
-        ":lhlo_dialect_emitter",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@llvm-project//llvm:Core",
-        "@llvm-project//mlir:GPUDialect",
-        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMDialect",
-        "@llvm-project//mlir:LLVMTransforms",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TargetNVVMIR",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla/service:buffer_assignment",
-        "//tensorflow/compiler/xla/service:dump",
-        "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service/gpu:gpu_constants",
-        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
-        "//tensorflow/compiler/xla/service/gpu:gpu_hlo_schedule",
-        "//tensorflow/compiler/xla/service/gpu:gpu_types",
-        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
-        "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
-        "//tensorflow/compiler/xla/service/gpu:launch_dimensions",
-        "//tensorflow/compiler/xla/service/gpu:stream_assignment",
-        "//tensorflow/compiler/xla/service/gpu:stream_executor_util",
-        "//tensorflow/compiler/xla/service/gpu:target_constants",
-        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
-        "//tensorflow/core/platform:cuda_libdevice_path",
-        "//tensorflow/core:lib",
-        "//tensorflow/stream_executor/gpu:asm_compiler",
-    ]),
-    alwayslink = True,  # Contains compiler registration
-)
-
-cc_library(
-    name = "hlo_dialect_emitter",
-    srcs = ["hlo_dialect_emitter.cc"],
-    hdrs = ["hlo_dialect_emitter.h"],
-    deps = [
-        ":emission_context",
-        "//tensorflow/compiler/mlir/hlo",
-        "//tensorflow/compiler/mlir/xla:hlo_utils",
-        "//tensorflow/compiler/xla:comparison_util",
-        "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla/service:hlo",
-        "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:StandardOps",
-    ],
-)
-
-cc_library(
-    name = "lhlo_dialect_emitter",
-    srcs = ["lhlo_dialect_emitter.cc"],
-    hdrs = ["lhlo_dialect_emitter.h"],
-    deps = [
-        ":emission_context",
-        ":hlo_dialect_emitter",
-        "//tensorflow/compiler/mlir/hlo:lhlo",
-        "//tensorflow/compiler/mlir/xla:hlo_utils",
-        "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla/service:buffer_assignment",
-        "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service/gpu:thunk",
-        "//tensorflow/compiler/xla/service/gpu:thunk_emitter",
-        "//tensorflow/core:lib",
-        "//tensorflow/stream_executor:stream_executor_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@llvm-project//llvm:Core",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMDialect",
-        "@llvm-project//mlir:StandardOps",
-    ],
-)
-
 gentbl(
     name = "passes_inc_gen",
     compatible_with = get_compatible_with_cloud(),
@@ -238,51 +101,6 @@
     ],
 )
 
-cc_library(
-    name = "xla_gpu_opt_lib",
-    testonly = True,
-    srcs = ["xla_gpu_opt.cc"],
-    hdrs = ["xla_gpu_opt.h"],
-    tags = ["no_pip"],
-    deps = [
-        ":failover_compiler",
-        ":inject_errors_pass",
-        ":mlir_compiler",
-        "//tensorflow/compiler/xla:debug_options_flags",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:backend",
-        "//tensorflow/compiler/xla/service:hlo_module_config",
-        "//tensorflow/compiler/xla/tests:verified_hlo_module",
-        "//tensorflow/core:lib",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-    ],
-)
-
-tf_cc_binary(
-    name = "xla-gpu-opt",
-    testonly = True,
-    srcs = ["xla_gpu_opt_main.cc"],
-    tags = ["no_pip"],
-    deps = [
-        ":mlir_compiler",
-        ":xla_gpu_opt_lib",
-        "//tensorflow/compiler/mlir:init_mlir",
-        "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla/service:gpu_plugin_mlir",
-        "//tensorflow/core:lib",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:SideEffects",
-        "@llvm-project//mlir:Support",
-    ],
-)
-
 tf_cc_binary(
     name = "xla-mlir-gpu-opt",
     srcs = ["xla_mlir_gpu_opt.cc"],
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
deleted file mode 100644
index 06c7ebd..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
-
-#include "absl/strings/substitute.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-EmissionContext::EmissionContext(std::unique_ptr<HloModule> module)
-    : module_(std::move(module)), context_() {
-  context_.loadDialect<mlir::mhlo::MhloDialect, mlir::lmhlo::LmhloDialect,
-                       mlir::StandardOpsDialect>();
-  error_handler_ = [](const ErrorMap& instructions_with_error,
-                      HloModule* module) {
-    std::set<const HloComputation*> computations_with_error;
-    for (const auto& err : instructions_with_error) {
-      computations_with_error.insert(err.first->parent());
-    }
-
-    LOG(ERROR) << module->ToString(
-        HloPrintOptions()
-            .set_print_instruction(
-                [&instructions_with_error](const HloInstruction* instr) {
-                  return instructions_with_error.count(instr);
-                })
-            .set_format_instruction(
-                // Returns the string representation of `instr` in the following
-                // format.
-                //
-                // ROOT? instr_name
-                //   FAILED: err_0
-                //   FAILED: err_1
-                //   ...
-                [&instructions_with_error](const HloInstruction* instr,
-                                           const string& instr_name, int indent,
-                                           bool is_root) {
-                  const string tab(2 * indent, ' ');
-                  if (!instructions_with_error.count(instr)) {
-                    return absl::StrCat(tab, is_root ? "ROOT " : "",
-                                        instr_name);
-                  }
-                  static constexpr char kStartBold[] = "\033[1m";
-                  static constexpr char kStartRed[] = "\033[31m";
-                  static constexpr char kBackToNormal[] = "\033[0m";
-
-                  string result =
-                      absl::StrCat(tab, kStartBold, is_root ? "ROOT " : "",
-                                   instr_name, kBackToNormal);
-
-                  for (const string& err : instructions_with_error.at(instr)) {
-                    absl::SubstituteAndAppend(
-                        &result, "\n$0  $1$2FAILED:$3 $4$5$6", tab, kStartBold,
-                        kStartRed, kBackToNormal, kStartBold, err,
-                        kBackToNormal);
-                  }
-                  return result;
-                })
-            .set_print_computation(
-                [&computations_with_error](const HloComputation* comp) {
-                  return computations_with_error.find(comp) !=
-                         computations_with_error.end();
-                }));
-  };
-  registerDiagnosticHandler();
-}
-
-EmissionContext::EmissionContext(
-    std::unique_ptr<HloModule> module,
-    std::function<void(const ErrorMap&, HloModule*)> callback)
-    : module_(std::move(module)), context_(), error_handler_(callback) {
-  registerDiagnosticHandler();
-}
-
-EmissionContext::~EmissionContext() { callErrorHandlerCallback(); }
-
-mlir::Location EmissionContext::getLocation(const HloInstruction* instr) {
-  return mlir::OpaqueLoc::get<const HloInstruction*>(instr, &context_);
-}
-
-void EmissionContext::addError(const HloInstruction* hlo_instruction,
-                               const string& str) {
-  instructions_with_error_[hlo_instruction].push_back(str);
-}
-
-void EmissionContext::setErrorHandler(
-    std::function<void(const ErrorMap&, HloModule*)> callback) {
-  error_handler_ = callback;
-}
-
-std::unique_ptr<HloModule> EmissionContext::releaseHloModule() {
-  callErrorHandlerCallback();
-  return std::move(module_);
-}
-
-HloModule* EmissionContext::getHloModule() const { return module_.get(); }
-
-mlir::MLIRContext* EmissionContext::getContext() { return &context_; }
-
-void EmissionContext::registerDiagnosticHandler() {
-  context_.getDiagEngine().registerHandler([&](mlir::Diagnostic& diag) {
-    const HloInstruction* hloInstruction =
-        mlir::OpaqueLoc::getUnderlyingLocationOrNull<const HloInstruction*>(
-            diag.getLocation());
-    assert(hloInstruction);
-    addError(hloInstruction, diag.str());
-    return mlir::success();
-  });
-}
-
-void EmissionContext::callErrorHandlerCallback() {
-  if (module_.get() && !instructions_with_error_.empty()) {
-    error_handler_(instructions_with_error_, module_.get());
-  }
-}
-
-}  // namespace mlir_gpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h
deleted file mode 100644
index 9550914..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EMISSION_CONTEXT_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EMISSION_CONTEXT_H_
-
-#include <memory>
-
-#include "mlir/IR/Diagnostics.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-// Registers a diagnostic handler and collects all the errors as a map from
-// HloInstruction* to a vector of string representations of all the errors that
-// occurred at that hlo instruction. Also, it takes a function that handles
-// those errors at the point when the instance gets destroyed or
-// `releaseHloModule()` is called.
-//
-// EmissionContext uses an RAII pattern: it owns its hlo module and mlir
-// context.
-class EmissionContext {
- public:
-  using ErrorMap =
-      std::unordered_map<const HloInstruction*, std::vector<std::string>>;
-
-  // Gets an hlo module and sets the default error handler which writes to the
-  // ERROR log and is executed when the instance gets destroyed or
-  // `releaseHloModule()` is called.
-  explicit EmissionContext(std::unique_ptr<HloModule> module);
-
-  // Gets an hlo module and an error handler function which is executed when the
-  // instance gets destroyed or `releaseHloModule()` is called.
-  EmissionContext(std::unique_ptr<HloModule> module,
-                  std::function<void(const ErrorMap&, HloModule*)> callback);
-
-  // Handles all the errors according to the error handler function before
-  // getting destroyed.
-  ~EmissionContext();
-
-  // Returns a location constructed from `instr` that is then used by
-  // the diagnostic handler to collect the errors.
-  mlir::Location getLocation(const HloInstruction* instr);
-
-  // Adds an error message associated with provided hlo instruction.
-  void addError(const HloInstruction* hlo_instruction, const string& str);
-
-  // Sets a function that handles the errors at the point when the instance
-  // gets destroyed or `releaseHloModule()` is called.
-  void setErrorHandler(
-      std::function<void(const ErrorMap&, HloModule*)> callback);
-
-  // Releases hlo module and handles all the errors according to the error
-  // handler function.
-  std::unique_ptr<HloModule> releaseHloModule();
-
-  HloModule* getHloModule() const;
-
-  mlir::MLIRContext* getContext();
-
- private:
-  void registerDiagnosticHandler();
-  void callErrorHandlerCallback();
-
-  std::unique_ptr<HloModule> module_;
-  ErrorMap instructions_with_error_;
-  mlir::MLIRContext context_;
-  std::function<void(const ErrorMap&, HloModule*)> error_handler_;
-};
-
-}  // namespace mlir_gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EMISSION_CONTEXT_H_
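Reviewer note: the removed EmissionContext tied MLIR diagnostics back to the HLO instruction that produced them by wrapping the instruction pointer in an OpaqueLoc and unwrapping it in a diagnostic handler. Below is a minimal, self-contained sketch of that pattern (not the deleted implementation); `MyKey` is a hypothetical stand-in for HloInstruction, and the MLIR calls are the same ones the deleted code used.

    #include <map>
    #include <string>
    #include <vector>
    #include "mlir/IR/Diagnostics.h"
    #include "mlir/IR/Location.h"
    #include "mlir/IR/MLIRContext.h"

    struct MyKey {};  // hypothetical placeholder for the real HloInstruction key

    void CaptureDiagnostics(
        mlir::MLIRContext& context,
        std::map<const MyKey*, std::vector<std::string>>& errors) {
      context.getDiagEngine().registerHandler([&](mlir::Diagnostic& diag) {
        // Recover the key that was attached via OpaqueLoc when the op was built.
        const MyKey* key =
            mlir::OpaqueLoc::getUnderlyingLocationOrNull<const MyKey*>(
                diag.getLocation());
        if (key) errors[key].push_back(diag.str());
        return mlir::success();  // mark the diagnostic as handled (not printed)
      });
    }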
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD
index 74eef71..f0197c7 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD
@@ -63,7 +63,6 @@
     srcs = ["conv_emitter_test.cc"],
     tags = [
         "no_oss",  # TODO(b/148143101): Test should pass in OSS.
-        "no_rocm",
     ],
     deps = [
         ":conv_emitter",
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
index 1bac9a1..f2b0371 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
@@ -35,7 +35,7 @@
 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc
deleted file mode 100644
index f712679..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc
+++ /dev/null
@@ -1,127 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
-
-#include <memory>
-
-#include "tensorflow/core/lib/core/errors.h"
-
-namespace xla {
-
-template <typename T>
-bool IsUnimplemented(StatusOr<T>& result) {
-  return result.status().code() == tensorflow::error::Code::UNIMPLEMENTED;
-}
-
-StatusOr<std::unique_ptr<HloModule>> FailoverCompiler::RunHloPasses(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  auto result =
-      primary_->RunHloPasses(module->Clone(), stream_exec, device_allocator);
-  if (IsUnimplemented(result)) {
-    VLOG(2) << "RunHloPasses resulted in " << result.status()
-            << ", falling back to secondary backend";
-    return secondary_->RunHloPasses(std::move(module), stream_exec,
-                                    device_allocator);
-  }
-  return result;
-}
-
-StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  auto result =
-      primary_->RunBackend(module->Clone(), stream_exec, device_allocator);
-  if (IsUnimplemented(result)) {
-    VLOG(2) << "RunBackend resulted in " << result.status()
-            << ", falling back to secondary backend";
-    return secondary_->RunBackend(std::move(module), stream_exec,
-                                  device_allocator);
-  }
-  return result;
-}
-
-StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
-    std::unique_ptr<HloModuleGroup> module_group,
-    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-    se::DeviceMemoryAllocator* device_allocator) {
-  std::vector<std::unique_ptr<Executable>> result;
-  std::vector<std::unique_ptr<HloModule>> modules =
-      module_group->ConsumeModules();
-  for (size_t i = 0; i < modules.size(); i++) {
-    if (stream_execs[i].size() != 1) {
-      // This is not supported by GPU compiler anyway.
-      return Unimplemented(
-          "Model partitioning not implemented for the failover compiler!");
-    }
-    auto executable = [stream_execs, device_allocator, i,
-                       this](std::unique_ptr<HloModule> module)
-        -> StatusOr<std::unique_ptr<Executable>> {
-      TF_ASSIGN_OR_RETURN(
-          auto processed_module,
-          primary_->RunHloPasses(std::move(module), stream_execs[i][0],
-                                 device_allocator));
-      TF_ASSIGN_OR_RETURN(
-          auto result,
-          primary_->RunBackend(std::move(processed_module), stream_execs[i][0],
-                               device_allocator));
-      return result;
-    }(modules[i]->Clone());
-
-    if (IsUnimplemented(executable)) {
-      VLOG(2) << "Compile resulted in " << executable.status()
-              << ", falling back to secondary backend";
-      TF_ASSIGN_OR_RETURN(
-          modules[i],
-          secondary_->RunHloPasses(std::move(modules[i]), stream_execs[i][0],
-                                   device_allocator));
-      TF_ASSIGN_OR_RETURN(
-          executable,
-          secondary_->RunBackend(std::move(modules[i]), stream_execs[i][0],
-                                 device_allocator));
-    }
-
-    if (!executable.ok()) {
-      return executable.status();
-    }
-
-    result.push_back(std::move(executable.ValueOrDie()));
-  }
-
-  return {std::move(result)};
-}
-
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-FailoverCompiler::CompileAheadOfTime(
-    std::unique_ptr<HloModuleGroup> module_group,
-    const AotCompilationOptions& options) {
-  // This is not supported by GPU compiler anyway.
-  return Unimplemented(
-      "CompileAheadOfTime not implemented in failover compiler!");
-}
-
-HloCostAnalysis::ShapeSizeFunction FailoverCompiler::ShapeSizeBytesFunction()
-    const {
-  auto prim_fun = primary_->ShapeSizeBytesFunction();
-  auto second_fun = secondary_->ShapeSizeBytesFunction();
-  return [prim_fun, second_fun](const Shape& shape) -> int64 {
-    int64 primary = prim_fun(shape);
-    assert(primary == second_fun(shape));
-    return primary;
-  };
-}
-
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h
deleted file mode 100644
index 05badaa..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h
+++ /dev/null
@@ -1,81 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_FAILOVER_COMPILER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_FAILOVER_COMPILER_H_
-
-#include <memory>
-
-#include "tensorflow/compiler/xla/service/compiler.h"
-
-namespace xla {
-
-// FailoverCompiler implements a compiler that fails over between a primary
-// and secondary compiler.
-//
-// For all methods, first the primary compiler is invoked. If that compiler's
-// implementation of the method fails with an unimplemented error, the
-// secondary's compiler method is invoked. In all other cases, the result of
-// the primary compiler's method is returned.
-//
-// The primary compiler is invoked on a clone of the supplied HloModule. This
-// ensures that partial updates to the module by one compiler do not leak into
-// the other compiler.
-//
-// The FailoverCompiler is used to layer a partial compiler implementation on
-// top of a full implementation.
-class FailoverCompiler final : public Compiler {
- public:
-  FailoverCompiler(std::unique_ptr<Compiler> primary,
-                   std::unique_ptr<Compiler> secondary)
-      : primary_(std::move(primary)), secondary_(std::move(secondary)) {
-    // Both compilers should serve the same platform id.
-    assert(primary_->PlatformId() == secondary_->PlatformId());
-  }
-
-  se::Platform::Id PlatformId() const override {
-    return primary_->PlatformId();
-  }
-
-  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
-  StatusOr<std::unique_ptr<Executable>> RunBackend(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
-  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::unique_ptr<HloModuleGroup> module_group,
-      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
-                     const AotCompilationOptions& options) override;
-
-  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
-
-  Compiler* GetPrimary() const { return primary_.get(); }
-  Compiler* GetSecondary() const { return secondary_.get(); }
-
- private:
-  std::unique_ptr<Compiler> primary_;
-  std::unique_ptr<Compiler> secondary_;
-};
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_FAILOVER_COMPILER_H_
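Reviewer note: the header comment above describes the failover behavior (try the primary, fall back to the secondary only on UNIMPLEMENTED). A hedged sketch of that control flow, written against a hypothetical `Result` type rather than xla::StatusOr so it stays self-contained:

    #include <functional>
    #include <string>

    struct Result {
      bool ok = false;
      bool unimplemented = false;  // primary backend does not support the input
      std::string value;
    };

    Result CompileWithFailover(
        const std::string& module,
        const std::function<Result(const std::string&)>& primary,
        const std::function<Result(const std::string&)>& secondary) {
      Result r = primary(module);   // always try the primary backend first
      if (r.unimplemented) {
        return secondary(module);   // fall back only on "unimplemented" errors
      }
      return r;                     // propagate success or any other failure
    }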
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
deleted file mode 100644
index 4b06cda..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h"
-
-#include <utility>
-
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/Attributes.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
-#include "mlir/IR/Types.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
-#include "tensorflow/compiler/xla/comparison_util.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_instructions.h"
-
-namespace xla {
-namespace mlir_gpu {
-namespace {
-
-using ::mlir::ArrayRef;
-using ::mlir::Attribute;
-using ::mlir::Identifier;
-using ::mlir::Location;
-using ::mlir::NamedAttribute;
-using ::mlir::OpBuilder;
-using ::mlir::RankedTensorType;
-using ::mlir::Type;
-using ::mlir::Value;
-
-namespace hlo = ::mlir::mhlo;
-
-// TODO(b/137624192) Use tablegen for this.
-StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
-                             Location loc, ArrayRef<Type> rets,
-                             ArrayRef<Value> args,
-                             ArrayRef<std::pair<Identifier, Attribute>> attrs) {
-  switch (opcode) {
-    case HloOpcode::kAbs:
-      return {func_builder.create<hlo::AbsOp>(loc, rets, args, attrs)};
-    case HloOpcode::kAdd:
-      return {func_builder.create<hlo::AddOp>(loc, rets, args, attrs)};
-    case HloOpcode::kAnd:
-      return {func_builder.create<hlo::AndOp>(loc, rets, args, attrs)};
-    case HloOpcode::kCeil:
-      return {func_builder.create<hlo::CeilOp>(loc, rets, args, attrs)};
-    case HloOpcode::kComplex:
-      return {func_builder.create<hlo::ComplexOp>(loc, rets, args, attrs)};
-    case HloOpcode::kCopy:
-      return {func_builder.create<hlo::CopyOp>(loc, rets, args, attrs)};
-    case HloOpcode::kCos:
-      return {func_builder.create<hlo::CosOp>(loc, rets, args, attrs)};
-    case HloOpcode::kDivide:
-      return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
-    case HloOpcode::kExp:
-      return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
-    case HloOpcode::kImag:
-      return {func_builder.create<hlo::ImagOp>(loc, rets, args, attrs)};
-    case HloOpcode::kLog:
-      return {func_builder.create<hlo::LogOp>(loc, rets, args, attrs)};
-    case HloOpcode::kMaximum:
-      return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
-    case HloOpcode::kMinimum:
-      return {func_builder.create<hlo::MinOp>(loc, rets, args, attrs)};
-    case HloOpcode::kMultiply:
-      return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
-    case HloOpcode::kNegate:
-      return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
-    case HloOpcode::kReal:
-      return {func_builder.create<hlo::RealOp>(loc, rets, args, attrs)};
-    case HloOpcode::kRemainder:
-      return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
-    case HloOpcode::kRsqrt:
-      return {func_builder.create<hlo::RsqrtOp>(loc, rets, args, attrs)};
-    case HloOpcode::kSelect:
-      return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
-    case HloOpcode::kSign:
-      return {func_builder.create<hlo::SignOp>(loc, rets, args, attrs)};
-    case HloOpcode::kSqrt:
-      return {func_builder.create<hlo::SqrtOp>(loc, rets, args, attrs)};
-    case HloOpcode::kSubtract:
-      return {func_builder.create<hlo::SubOp>(loc, rets, args, attrs)};
-    case HloOpcode::kTanh:
-      return {func_builder.create<hlo::TanhOp>(loc, rets, args, attrs)};
-    default:
-      return tensorflow::errors::Internal(absl::StrCat(
-          "HLO Opcode ", HloOpcodeString(opcode), " is not supported."));
-  }
-}
-
-}  // namespace
-
-mlir::Location HloDialectEmitter::getLocation(
-    const HloInstruction* instr) const {
-  return emission_context_->getLocation(instr);
-}
-
-StatusOr<Value> HloDialectEmitter::EmitComputation(
-    const HloComputation& computation) {
-  const auto root = computation.root_instruction();
-  TF_RETURN_IF_ERROR(root->Accept(this));
-  return instruction_to_values_[root];
-}
-
-Status HloDialectEmitter::DefaultAction(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(auto res_type, ConvertTensorShapeToType<RankedTensorType>(
-                                         instr->shape(), builder_));
-  llvm::SmallVector<Value, 4> arguments;
-  arguments.reserve(instr->operand_count());
-  for (auto operand : instr->operands()) {
-    arguments.push_back(instruction_to_values_[operand]);
-  }
-  TF_ASSIGN_OR_RETURN(
-      auto inserted, InsertMlirOp(instr->opcode(), builder_, getLocation(instr),
-                                  res_type, arguments, llvm::None));
-  instruction_to_values_[instr] = inserted;
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleBroadcast(HloInstruction* instr) {
-  mlir::DenseIntElementsAttr broadcast_dim =
-      CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_);
-  TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
-                                         instr->shape(), builder_));
-
-  instruction_to_values_[instr] = builder_.create<hlo::BroadcastInDimOp>(
-      getLocation(instr), llvm::makeArrayRef(res_type),
-      instruction_to_values_[instr->operand(0)], broadcast_dim);
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleConcatenate(HloInstruction* instr) {
-  int64 concatenate_dim = instr->concatenate_dimension();
-  TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
-                                         instr->shape(), builder_));
-
-  llvm::SmallVector<Value, 4> arguments;
-  arguments.reserve(instr->operand_count());
-  for (auto operand : instr->operands()) {
-    arguments.push_back(instruction_to_values_[operand]);
-  }
-
-  instruction_to_values_[instr] = builder_.create<hlo::ConcatenateOp>(
-      getLocation(instr), llvm::makeArrayRef(res_type), arguments,
-      builder_.getI64IntegerAttr(concatenate_dim));
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleParameter(HloInstruction* instr) {
-  auto argValue = arguments_[instr->parameter_number()];
-  instruction_to_values_[instr] = argValue;
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleConstant(HloInstruction* instr) {
-  auto shape = instr->shape();
-  if (!shape.IsArray() || shape.rank() != 0) {
-    return Unimplemented("non-scalar constants are not supported yet");
-  }
-  TF_ASSIGN_OR_RETURN(auto type, ConvertTensorShapeToType<RankedTensorType>(
-                                     instr->shape(), builder_));
-
-  TF_ASSIGN_OR_RETURN(auto value, CreateDenseElementsAttrFromLiteral(
-                                      instr->literal(), builder_));
-
-  auto const_value =
-      builder_.create<hlo::ConstOp>(getLocation(instr), type, value);
-  instruction_to_values_[instr] = const_value;
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleGather(HloInstruction* instr) {
-  HloGatherInstruction* gather = static_cast<HloGatherInstruction*>(instr);
-  mlir::mhlo::GatherDimensionNumbers dimension_numbers =
-      xla::CreateGatherDimensionNumbers(gather->gather_dimension_numbers(),
-                                        builder_);
-  mlir::DenseIntElementsAttr slice_sizes = CreateDenseIntElementsAttrFromVector(
-      llvm::SmallVector<int64, 4>{gather->gather_slice_sizes().begin(),
-                                  gather->gather_slice_sizes().end()},
-      builder_);
-  mlir::BoolAttr indices_are_sorted =
-      builder_.getBoolAttr(gather->indices_are_sorted());
-
-  TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
-                                         instr->shape(), builder_));
-
-  instruction_to_values_[instr] = builder_.create<hlo::GatherOp>(
-      getLocation(instr), res_type, instruction_to_values_[instr->operand(0)],
-      instruction_to_values_[instr->operand(1)], dimension_numbers, slice_sizes,
-      indices_are_sorted);
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleReduce(HloInstruction* instr) {
-  llvm::SmallVector<Value, 4> operands;
-  for (auto operand : instr->operands()) {
-    operands.push_back(instruction_to_values_.at(operand));
-  }
-  const unsigned num_inputs = operands.size() / 2;
-  TF_ASSIGN_OR_RETURN(
-      const auto return_type,
-      ConvertTensorShapeToType<RankedTensorType>(instr->shape(), builder_));
-  const auto dimensions_attr =
-      CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_);
-  auto reduceOp = builder_.create<hlo::ReduceOp>(
-      getLocation(instr), return_type,
-      llvm::makeArrayRef(operands).take_front(num_inputs),
-      llvm::makeArrayRef(operands).take_back(num_inputs), dimensions_attr);
-  {
-    auto computation = instr->to_apply();
-    auto block = new mlir::Block();
-    llvm::SmallVector<Value, 4> arguments;
-    arguments.reserve(computation->num_parameters());
-    for (auto parameter : computation->parameter_instructions()) {
-      TF_ASSIGN_OR_RETURN(auto param_type,
-                          ConvertTensorShapeToType<RankedTensorType>(
-                              parameter->shape(), builder_));
-      arguments.push_back(block->addArgument(param_type));
-    }
-    reduceOp.body().push_back(block);
-    HloDialectEmitter emitter(emission_context_, &reduceOp.body(), arguments);
-    TF_ASSIGN_OR_RETURN(auto result, emitter.EmitComputation(*computation));
-    OpBuilder body_builder = OpBuilder::atBlockEnd(block);
-    body_builder.setInsertionPointToEnd(block);
-    body_builder.create<hlo::ReturnOp>(getLocation(instr),
-                                       ArrayRef<Value>{result});
-  }
-  // TODO(b/137624192) Add support for multiple results.
-  instruction_to_values_[instr] = reduceOp.getResult(0);
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleCompare(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
-                                         instr->shape(), builder_));
-  auto comparison_direction_attr = builder_.getNamedAttr(
-      "comparison_direction",
-      builder_.getStringAttr(
-          ComparisonDirectionToString(instr->comparison_direction())));
-  llvm::SmallVector<Value, 4> arguments;
-  arguments.reserve(instr->operand_count());
-  for (auto operand : instr->operands()) {
-    arguments.push_back(instruction_to_values_[operand]);
-  }
-  instruction_to_values_[instr] = builder_.create<hlo::CompareOp>(
-      getLocation(instr), llvm::makeArrayRef(res_type), arguments,
-      comparison_direction_attr);
-  return Status::OK();
-}
-
-Status HloDialectEmitter::HandleIota(HloInstruction* instr) {
-  mlir::IntegerAttr iota_dim = builder_.getI64IntegerAttr(
-      static_cast<HloIotaInstruction*>(instr)->iota_dimension());
-  TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
-                                         instr->shape(), builder_));
-  instruction_to_values_[instr] =
-      builder_.create<hlo::IotaOp>(getLocation(instr), res_type, iota_dim);
-  return Status::OK();
-}
-
-}  // namespace mlir_gpu
-}  // namespace xla
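Reviewer note: the deleted InsertMlirOp maps HLO opcodes to dialect ops one switch case at a time, which is what the "use tablegen" TODO refers to. A small sketch of the table-driven alternative that comment alludes to, using plain std::function factories and string placeholders instead of real op builders so the example stays self-contained:

    #include <functional>
    #include <map>
    #include <optional>
    #include <string>

    using OpFactory = std::function<std::string(const std::string& operands)>;

    std::optional<std::string> EmitViaTable(const std::string& opcode,
                                            const std::string& operands) {
      // One entry per supported opcode instead of one switch case per opcode.
      static const std::map<std::string, OpFactory> kTable = {
          {"add", [](const std::string& a) { return "mhlo.add(" + a + ")"; }},
          {"multiply",
           [](const std::string& a) { return "mhlo.multiply(" + a + ")"; }},
      };
      auto it = kTable.find(opcode);
      if (it == kTable.end()) return std::nullopt;  // unsupported opcode
      return it->second(operands);
    }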
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
deleted file mode 100644
index 1ec3cf4..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_HLO_DIALECT_EMITTER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_HLO_DIALECT_EMITTER_H_
-
-#include <memory>
-
-#include "absl/types/span.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
-#include "tensorflow/compiler/xla/status.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-class HloDialectEmitter : public DfsHloVisitorWithDefault {
- public:
-  HloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context,
-                    ::mlir::Region* region,
-                    llvm::ArrayRef<::mlir::Value> arguments)
-      : emission_context_(emission_context),
-        builder_(region),
-        arguments_(arguments) {}
-
-  HloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context,
-                    ::mlir::OpBuilder builder,
-                    llvm::ArrayRef<::mlir::Value> arguments)
-      : emission_context_(emission_context),
-        builder_(builder),
-        arguments_(arguments) {}
-
-  StatusOr<mlir::Value> EmitComputation(const HloComputation& computation);
-
-  Status DefaultAction(HloInstruction* instr) override;
-  Status HandleBroadcast(HloInstruction* instr) override;
-  Status HandleCompare(HloInstruction* instr) override;
-  Status HandleConcatenate(HloInstruction* instr) override;
-  Status HandleConstant(HloInstruction* instr) override;
-  Status HandleGather(HloInstruction* instr) override;
-  Status HandleIota(HloInstruction* instr) override;
-  Status HandleParameter(HloInstruction* instr) override;
-  Status HandleReduce(HloInstruction* instr) override;
-
- private:
-  mlir::Location getLocation(const HloInstruction* instr) const;
-
-  xla::mlir_gpu::EmissionContext* emission_context_;
-  ::mlir::OpBuilder builder_;
-  llvm::ArrayRef<::mlir::Value> arguments_;
-  absl::flat_hash_map<const xla::HloInstruction*, ::mlir::Value>
-      instruction_to_values_;
-};
-
-}  // namespace mlir_gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_HLO_DIALECT_EMITTER_H_
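Reviewer note: the emitter deleted above is a DFS visitor that caches the value produced for each instruction so later instructions can look up their operands before emitting the corresponding op. A minimal sketch of that caching pattern; `Node` and `Value` are hypothetical stand-ins for HloInstruction and mlir::Value.

    #include <cassert>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Node {
      std::string op;                  // e.g. "add", "multiply"
      std::vector<const Node*> operands;
    };
    using Value = int;                 // stand-in for an SSA value handle

    Value EmitNode(const Node* node,
                   std::unordered_map<const Node*, Value>& emitted,
                   Value (*emit_op)(const std::string&,
                                    const std::vector<Value>&)) {
      std::vector<Value> args;
      for (const Node* operand : node->operands) {
        assert(emitted.count(operand));  // operands visited first (post-order DFS)
        args.push_back(emitted[operand]);
      }
      Value result = emit_op(node->op, args);
      emitted[node] = result;            // cache so users of this node can find it
      return result;
    }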
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc
deleted file mode 100644
index 7445ab5..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h"
-
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-
-namespace mlir {
-namespace {
-
-struct InjectErrorsForTestingPass
-    : public PassWrapper<InjectErrorsForTestingPass, FunctionPass> {
-  void runOnFunction() override {
-    getFunction().getBody().walk([&](Operation *op) {
-      op->emitError() << "failed for testing: " << op->getName();
-    });
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> createInjectErrorsForTestingPass() {
-  return std::make_unique<InjectErrorsForTestingPass>();
-}
-
-static PassRegistration<InjectErrorsForTestingPass> pass(
-    "inject-errors", "Emits errors from all operations");
-
-}  // namespace mlir
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h
deleted file mode 100644
index 9f0612c..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_
-
-#include "mlir/Pass/Pass.h"  // from @llvm-project
-
-namespace mlir {
-
-// Returns a function pass that emits errors from all operations inside the
-// function.
-std::unique_ptr<OperationPass<FuncOp>> createInjectErrorsForTestingPass();
-
-}  // namespace mlir
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_
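Reviewer note: before its removal, the pass declared above was typically driven from a test through a pass manager to verify that injected errors flow back to the right HLO instructions. A hedged usage sketch, assuming an already-built mlir::ModuleOp and the MLIR pass-manager API at this revision:

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h"

    bool InjectErrorsForTest(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Run the error-injecting pass on every function in the module.
      pm.addNestedPass<mlir::FuncOp>(mlir::createInjectErrorsForTestingPass());
      // Failure is expected: every operation emits an error.
      return mlir::failed(pm.run(module));
    }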
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index b738c40..22404cd 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -162,18 +162,11 @@
     ::mlir::LLVMTypeConverter converter(m.getContext());
     ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
     // TODO(b/145824979) Remove linalg once sliceop is in std.
-    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
-                                                   &getContext());
+    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns);
     ::mlir::populateGpuToNVVMConversionPatterns(converter, patterns);
     ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
     ::mlir::ConversionTarget target(getContext());
-    target.addIllegalDialect<::mlir::gpu::GPUDialect>();
-    target.addIllegalOp<::mlir::LLVM::ExpOp>();
-    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
-    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
-    // TODO(csigg): Remove once we support replacing non-root ops.
-    target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
-                      ::mlir::gpu::YieldOp>();
+    ::mlir::configureGpuToNVVMConversionLegality(target);
     if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
       signalPassFailure();
     }
@@ -229,23 +222,12 @@
     ::mlir::LLVMTypeConverter converter(m.getContext());
     ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
     // TODO(b/145824979) Remove linalg once sliceop is in std.
-    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
-                                                   &getContext());
+    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns);
     ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns);
     ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
 
     ::mlir::ConversionTarget target(getContext());
-    target.addIllegalDialect<::mlir::gpu::GPUDialect>();
-    target.addIllegalOp<mlir::LLVM::CosOp, mlir::LLVM::ExpOp,
-                        mlir::LLVM::FAbsOp, mlir::LLVM::FCeilOp,
-                        mlir::LLVM::LogOp, mlir::LLVM::Log10Op,
-                        mlir::LLVM::Log2Op, mlir::LLVM::SinOp>();
-    target.addIllegalOp<mlir::FuncOp>();
-    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
-    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
-    // TODO(csigg): Remove once we support replacing non-root ops.
-    target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
-                      ::mlir::gpu::YieldOp>();
+    ::mlir::configureGpuToROCDLConversionLegality(target);
     if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
       signalPassFailure();
     }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
deleted file mode 100644
index 6b2e53e..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ /dev/null
@@ -1,504 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
-
-#include <utility>
-
-#include "llvm/IR/DataLayout.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/Attributes.h"  // from @llvm-project
-#include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/Identifier.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
-#include "mlir/IR/Types.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_instructions.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/core/errors.h"
-
-namespace xla {
-namespace mlir_gpu {
-namespace {
-
-using ::mlir::ArrayRef;
-using ::mlir::Attribute;
-using ::mlir::Builder;
-using ::mlir::DenseIntElementsAttr;
-using ::mlir::FuncOp;
-using ::mlir::Identifier;
-using ::mlir::Location;
-using ::mlir::MemRefType;
-using ::mlir::ModuleOp;
-using ::mlir::OpBuilder;
-using ::mlir::Type;
-using ::mlir::Value;
-using ::mlir::LLVM::LLVMDialect;
-using ::xla::gpu::Thunk;
-using ::xla::gpu::ThunkEmitter;
-using ::xla::gpu::ThunkSequence;
-
-namespace lhlo = ::mlir::lmhlo;
-
-// TODO(b/137624192) Use tablegen for this.
-Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
-                    ArrayRef<Type> rets, ArrayRef<Value> args,
-                    ArrayRef<std::pair<Identifier, Attribute>> attrs) {
-  switch (opcode) {
-    case HloOpcode::kAbs:
-      func_builder.create<lhlo::AbsOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kAdd:
-      func_builder.create<lhlo::AddOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kAnd:
-      func_builder.create<lhlo::AndOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kCeil:
-      func_builder.create<lhlo::CeilOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kComplex:
-      func_builder.create<lhlo::ComplexOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kCopy:
-      func_builder.create<lhlo::CopyOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kCos:
-      func_builder.create<lhlo::CosOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kDivide:
-      func_builder.create<lhlo::DivOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kExp:
-      func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kImag:
-      func_builder.create<lhlo::ImagOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kLog:
-      func_builder.create<lhlo::LogOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kMaximum:
-      func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kMinimum:
-      func_builder.create<lhlo::MinOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kMultiply:
-      func_builder.create<lhlo::MulOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kNegate:
-      func_builder.create<lhlo::NegOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kReal:
-      func_builder.create<lhlo::RealOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kRemainder:
-      func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kRsqrt:
-      func_builder.create<lhlo::RsqrtOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kSelect:
-      func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kSign:
-      func_builder.create<lhlo::SignOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kSqrt:
-      func_builder.create<lhlo::SqrtOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kSubtract:
-      func_builder.create<lhlo::SubOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kTanh:
-      func_builder.create<lhlo::TanhOp>(loc, rets, args, attrs);
-      break;
-    default:
-      return tensorflow::errors::Internal(absl::StrCat(
-          "LHLO opcode ", HloOpcodeString(opcode), " is not supported."));
-  }
-  return Status::OK();
-}
-
-StatusOr<llvm::SmallVector<Type, 4>> GetInstructionArgTypes(
-    const HloInstruction& instruction, Builder builder) {
-  llvm::SmallVector<Type, 4> arg_types;
-  for (auto operand : instruction.operands()) {
-    TF_ASSIGN_OR_RETURN(auto operand_type, ConvertShapeToType<MemRefType>(
-                                               operand->shape(), builder));
-    arg_types.push_back(operand_type);
-  }
-  TF_ASSIGN_OR_RETURN(auto operand_type, ConvertShapeToType<MemRefType>(
-                                             instruction.shape(), builder));
-  arg_types.push_back(operand_type);
-  return arg_types;
-}
-
-// Converts HloComputation into a block with HLO dialect ops. The block gets
-// memref arguments corresponding to HloComputation arguments and results.
-Status SpliceHloComputation(OpBuilder builder, mlir::Location loc,
-                            const HloComputation& hlo_computation,
-                            xla::mlir_gpu::EmissionContext* emission_context) {
-  auto block = builder.getInsertionBlock();
-  builder.setInsertionPoint(block->getTerminator());
-  llvm::SmallVector<Value, 4> arg_values;
-  // First map parameters to memrefs on the operation.
-  for (auto param : hlo_computation.parameter_instructions()) {
-    TF_ASSIGN_OR_RETURN(
-        auto arg_type, ConvertShapeToType<MemRefType>(param->shape(), builder));
-    auto block_arg = block->addArgument(arg_type);
-    arg_values.push_back(builder.create<::mlir::TensorLoadOp>(loc, block_arg));
-  }
-  HloDialectEmitter hlo_emitter(emission_context, builder, arg_values);
-
-  TF_ASSIGN_OR_RETURN(auto result,
-                      hlo_emitter.EmitComputation(hlo_computation));
-
-  // Now add a block arg and store for the result.
-  builder.setInsertionPoint(block->getTerminator());
-  TF_ASSIGN_OR_RETURN(
-      auto result_type,
-      ConvertShapeToType<MemRefType>(
-          hlo_computation.root_instruction()->shape(), builder));
-  auto block_arg = block->addArgument(result_type);
-  builder.create<::mlir::TensorStoreOp>(loc, result, block_arg);
-
-  return Status::OK();
-}
-
-}  // namespace
-
-mlir::Location LhloDialectEmitter::getLocation(
-    const HloInstruction* instr) const {
-  return emission_context_->getLocation(instr);
-}
-
-LhloDialectEmitter::LhloDialectEmitter(
-    xla::mlir_gpu::EmissionContext* emission_context,
-    const BufferAssignment& assignment, const se::Platform* platform,
-    ModuleOp mlir_module)
-    : emission_context_(emission_context),
-      mlir_module_(mlir_module),
-      builder_(mlir_module_.getContext()),
-      buffer_assignment_(assignment),
-      platform_(platform) {
-  llvm::DataLayout data_layout("");
-  if (auto data_layout_attr = mlir_module.getAttrOfType<mlir::StringAttr>(
-          mlir::LLVM::LLVMDialect::getDataLayoutAttrName())) {
-    data_layout.reset(data_layout_attr.getValue());
-  }
-
-  pointer_size_ = data_layout.getPointerSize();
-}
-
-void LhloDialectEmitter::AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
-  thunk_sequence_.push_back(std::move(thunk));
-}
-
-StatusOr<BufferAllocation::Slice> LhloDialectEmitter::MaybeGetAllocationSlice(
-    const HloInstruction& hlo, const ShapeIndex& index) const {
-  return buffer_assignment_.GetUniqueSlice(&hlo, index);
-}
-
-int64 LhloDialectEmitter::ByteSizeOf(const Shape& shape) const {
-  return ShapeUtil::ByteSizeOf(shape, pointer_size_);
-}
-
-absl::string_view LhloDialectEmitter::platform_name() const {
-  return platform_->Name();
-}
-
-StatusOr<FuncOp> LhloDialectEmitter::CreateFunction(
-    const HloInstruction& instr) {
-  TF_ASSIGN_OR_RETURN(auto args, GetInstructionArgTypes(instr, builder_));
-  auto function_type = builder_.getFunctionType(args, {});
-  auto function =
-      FuncOp::create(getLocation(&instr), instr.name(), function_type);
-  mlir_module_.push_back(function);
-  function.addEntryBlock();
-  OpBuilder op_builder(function.getBody());
-  op_builder.create<::mlir::ReturnOp>(getLocation(&instr));
-  instruction_to_mlir_func_[&instr] = function;
-  return function;
-}
-
-Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
-                                         function.args_end()};
-  TF_RETURN_IF_ERROR(InsertMlirOp(instr->opcode(), func_builder,
-                                  getLocation(instr), ArrayRef<Type>{},
-                                  arg_values, llvm::None));
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleBroadcast(HloInstruction* instr) {
-  DenseIntElementsAttr broadcast_dim =
-      CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_);
-
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-  func_builder.create<lhlo::BroadcastInDimOp>(
-      getLocation(instr), function.getArgument(0), function.getArgument(1),
-      broadcast_dim);
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleConcatenate(HloInstruction* instr) {
-  mlir::IntegerAttr concatenate_dim = builder_.getI64IntegerAttr(
-      static_cast<HloConcatenateInstruction*>(instr)->concatenate_dimension());
-
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-  func_builder.create<lhlo::ConcatenateOp>(
-      getLocation(instr), function.getArguments().drop_back(),
-      function.getArguments().back(), concatenate_dim);
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-  auto fusion_op = func_builder.create<lhlo::FusionOp>(getLocation(instr));
-
-  // Load the HLO argument tensors from the corresponding buffers. The last
-  // argument is for the result, so no need to load it.
-  OpBuilder body_builder(fusion_op.region());
-  llvm::SmallVector<Value, 4> arg_values;
-  for (int i = 0, e = function.getNumArguments() - 1; i < e; ++i) {
-    arg_values.push_back(body_builder.create<::mlir::TensorLoadOp>(
-        getLocation(instr), function.getArgument(i)));
-  }
-  HloDialectEmitter hlo_emitter(emission_context_, body_builder, arg_values);
-
-  TF_ASSIGN_OR_RETURN(
-      auto result,
-      hlo_emitter.EmitComputation(*instr->fused_instructions_computation()));
-
-  // Insert the write-back from the HLO computation to the result argument
-  // buffer.
-  body_builder.setInsertionPoint(fusion_op.region().back().getTerminator());
-  Value result_memref = function.getArguments().back();
-  body_builder.create<::mlir::TensorStoreOp>(getLocation(instr), result,
-                                             result_memref);
-
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleGather(HloInstruction* instr) {
-  HloGatherInstruction* gather = static_cast<HloGatherInstruction*>(instr);
-  mlir::mhlo::GatherDimensionNumbers dim_numbers =
-      xla::CreateGatherDimensionNumbers(gather->gather_dimension_numbers(),
-                                        builder_);
-  mlir::DenseIntElementsAttr slice_sizes = CreateDenseIntElementsAttrFromVector(
-      llvm::SmallVector<int64, 4>{gather->gather_slice_sizes().begin(),
-                                  gather->gather_slice_sizes().end()},
-      builder_);
-
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-
-  func_builder.create<lhlo::GatherOp>(
-      getLocation(instr), function.getArgument(0), function.getArgument(1),
-      dim_numbers, slice_sizes, function.getArgument(2));
-
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleReduce(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
-                                         function.args_end()};
-  OpBuilder builder(function.getBody());
-  auto loc = getLocation(instr);
-  int input_count = instr->operand_count() / 3;
-  auto inputs = llvm::makeArrayRef(arg_values).slice(input_count);
-  auto init_values =
-      llvm::makeArrayRef(arg_values).slice(input_count, input_count);
-  auto results =
-      llvm::makeArrayRef(arg_values).slice(2 * input_count, input_count);
-  auto dimensions_attr =
-      CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_);
-  auto reduce_op = builder.create<lhlo::ReduceOp>(loc, inputs, init_values,
-                                                  results, dimensions_attr);
-  builder.createBlock(&reduce_op.body());
-  OpBuilder::atBlockEnd(&reduce_op.body().front())
-      .create<lhlo::TerminatorOp>(getLocation(instr));
-  return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc,
-                              *instr->to_apply(), emission_context_);
-}
-
-Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
-                                         function.args_end()};
-  OpBuilder builder(function.getBody());
-  auto loc = getLocation(instr);
-
-  // Collect attribute values.
-  llvm::SmallVector<int64, 2> window_dimensions, window_strides, base_dilations,
-      window_dilations;
-  llvm::SmallVector<int64, 4> padding;
-  int64 rank = instr->window().dimensions_size();
-  window_dimensions.reserve(rank);
-  window_strides.reserve(rank);
-  base_dilations.reserve(rank);
-  window_dilations.reserve(rank);
-  padding.reserve(2 * rank);
-  for (const auto& window : instr->window().dimensions()) {
-    window_dimensions.push_back(window.size());
-    window_strides.push_back(window.stride());
-    base_dilations.push_back(window.base_dilation());
-    window_dilations.push_back(window.window_dilation());
-    padding.push_back(window.padding_low());
-    padding.push_back(window.padding_high());
-  }
-
-  auto reduce_window_op = builder.create<lhlo::ReduceWindowOp>(
-      loc, /*operand=*/arg_values[0], /*init_value=*/arg_values[1],
-      /*out=*/arg_values[2],
-      CreateDenseIntElementsAttrFromVector(window_dimensions, builder),
-      CreateDenseIntElementsAttrFromVector(window_strides, builder),
-      CreateDenseIntElementsAttrFromVector(base_dilations, builder),
-      CreateDenseIntElementsAttrFromVector(window_dilations, builder),
-      CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2}));
-  reduce_window_op.ensureTerminator(reduce_window_op.body(), builder, loc);
-  return SpliceHloComputation(OpBuilder{&reduce_window_op.body()}, loc,
-                              *instr->to_apply(), emission_context_);
-}
-
-Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
-                                         function.args_end()};
-  OpBuilder builder(function.getBody());
-  auto loc = getLocation(instr);
-
-  // Collect attribute values.
-  llvm::SmallVector<int64, 2> window_dimensions, window_strides, padding;
-  int64 rank = instr->window().dimensions_size();
-  window_dimensions.reserve(rank);
-  window_strides.reserve(rank);
-  padding.reserve(2 * rank);
-  for (const auto& window : instr->window().dimensions()) {
-    window_dimensions.push_back(window.size());
-    window_strides.push_back(window.stride());
-    padding.push_back(window.padding_low());
-    padding.push_back(window.padding_high());
-  }
-
-  auto select_scatter_op = builder.create<lhlo::SelectAndScatterOp>(
-      loc, /*operand=*/arg_values[0], /*source=*/arg_values[1],
-      /*init_value=*/arg_values[2],
-      /*out=*/arg_values[3],
-      CreateDenseIntElementsAttrFromVector(window_dimensions, builder),
-      CreateDenseIntElementsAttrFromVector(window_strides, builder),
-      CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2}));
-
-  // Convert `select` computation.
-  builder.createBlock(&select_scatter_op.select());
-  OpBuilder select_builder{&select_scatter_op.select()};
-  select_builder.create<lhlo::TerminatorOp>(loc);
-  TF_RETURN_IF_ERROR(SpliceHloComputation(select_builder, loc, *instr->select(),
-                                          emission_context_));
-
-  // Convert `scatter` computation.
-  builder.createBlock(&select_scatter_op.scatter());
-  OpBuilder scatter_builder{&select_scatter_op.scatter()};
-  scatter_builder.create<lhlo::TerminatorOp>(loc);
-  TF_RETURN_IF_ERROR(SpliceHloComputation(
-      scatter_builder, loc, *instr->scatter(), emission_context_));
-
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleCustomCall(HloInstruction* instr) {
-  return ThunkEmitter(this).HandleCustomCall(instr);
-}
-
-Status LhloDialectEmitter::HandleParameter(HloInstruction* instr) {
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleCompare(HloInstruction* instr) {
-  auto comparison_direction_attr = builder_.getNamedAttr(
-      "comparison_direction",
-      builder_.getStringAttr(
-          ComparisonDirectionToString(instr->comparison_direction())));
-
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
-                                         function.args_end()};
-  func_builder.create<lhlo::CompareOp>(getLocation(instr), llvm::None,
-                                       arg_values, comparison_direction_attr);
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleConstant(HloInstruction* instr) {
-  auto shape = instr->shape();
-  if (!shape.IsArray() || shape.rank() != 0) {
-    return Unimplemented("non-scalar constants are not supported yet");
-  }
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-
-  TF_ASSIGN_OR_RETURN(auto value, CreateDenseElementsAttrFromLiteral(
-                                      instr->literal(), func_builder));
-  func_builder.create<lhlo::ConstOp>(getLocation(instr), value,
-                                     *function.args_begin());
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleIota(HloInstruction* instr) {
-  mlir::IntegerAttr iota_dim = builder_.getI64IntegerAttr(
-      static_cast<HloIotaInstruction*>(instr)->iota_dimension());
-
-  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
-  OpBuilder func_builder(function.getBody());
-  func_builder.create<lhlo::IotaOp>(getLocation(instr), iota_dim,
-                                    function.getArgument(0));
-  return Status::OK();
-}
-
-Status LhloDialectEmitter::HandleTuple(HloInstruction* instr) {
-  // For the root node of the entry computation we can elide writing the tuple
-  // buffer. We can always figure out the contents of the tuples from buffer
-  // assignment because we insert copies to ensure non-ambiguous output buffers.
-  // GpuExecutable never reads the tuple buffer.
-  if (instr ==
-      instr->parent()->parent()->entry_computation()->root_instruction()) {
-    return Status::OK();
-  }
-  return Unimplemented("handling of typles not yet implemented");
-}
-
-Status LhloDialectEmitter::FinishVisit(HloInstruction* root) {
-  return Status::OK();
-}
-
-}  // namespace mlir_gpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
deleted file mode 100644
index 91cc224..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_
-
-#include <memory>
-#include <utility>
-
-#include "absl/container/flat_hash_map.h"
-#include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/gpu/thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
-#include "tensorflow/compiler/xla/status.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-// Implementation for the translation of HLO instructions to a ThunkSequence
-// via MLIR using the LHLO dialect.
-// Implements the DfsHloVisitor interface, emits LHLO computations as MLIR IR
-// functions and transforms them into gpu::Thunk.
-class LhloDialectEmitter : public DfsHloVisitorWithDefault,
-                           private gpu::ThunkEmitter::EmissionContext {
- public:
-  LhloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context,
-                     const BufferAssignment& assignment,
-                     const se::Platform* platform,
-                     ::mlir::ModuleOp mlir_module);
-  ~LhloDialectEmitter() override = default;
-
-  // The following methods implement the DfsHloVisitor interface.
-  //
-  // Default action which emits code for most operations. Operations which are
-  // special in some way are handled explicitly in HandleFoo methods.
-  Status DefaultAction(HloInstruction* instr) override;
-  Status HandleBroadcast(HloInstruction* instr) override;
-  Status HandleCompare(HloInstruction* instr) override;
-  Status HandleConcatenate(HloInstruction* instr) override;
-  Status HandleConstant(HloInstruction* instr) override;
-  Status HandleCustomCall(HloInstruction* instr) override;
-  Status HandleFusion(HloInstruction* instr) override;
-  Status HandleGather(HloInstruction* instr) override;
-  Status HandleIota(HloInstruction* instr) override;
-  Status HandleParameter(HloInstruction* instr) override;
-  Status HandleReduce(HloInstruction* instr) override;
-  Status HandleReduceWindow(HloInstruction* instr) override;
-  Status HandleSelectAndScatter(HloInstruction* instr) override;
-  Status HandleTuple(HloInstruction* instr) override;
-
-  Status FinishVisit(HloInstruction* root) override;
-
-  // Transfers the ownership of thunk_sequence_ out.
-  gpu::ThunkSequence ConsumeThunkSequence() {
-    gpu::ThunkSequence result;
-    std::swap(result, thunk_sequence_);
-    return result;
-  }
-
-  const absl::flat_hash_map<const xla::HloInstruction*, ::mlir::FuncOp>&
-  InstructionToFunctionMap() const {
-    return instruction_to_mlir_func_;
-  }
-
- private:
-  StatusOr<::mlir::FuncOp> CreateFunction(const HloInstruction& instr);
-  // Interface required by ThunkEmitter
-  void AddThunkToThunkSequence(std::unique_ptr<gpu::Thunk> thunk) override;
-  StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
-      const HloInstruction& hlo, const ShapeIndex& index) const override;
-  int64 ByteSizeOf(const Shape& shape) const override;
-  absl::string_view platform_name() const override;
-
-  mlir::Location getLocation(const HloInstruction* instr) const;
-
-  xla::mlir_gpu::EmissionContext* emission_context_;
-  ::mlir::ModuleOp mlir_module_;
-  ::mlir::Builder builder_;
-  absl::flat_hash_map<const xla::HloInstruction*, ::mlir::FuncOp>
-      instruction_to_mlir_func_;
-  const BufferAssignment& buffer_assignment_;
-  const se::Platform* platform_;
-  // Cached pointer size extracted from the mlir module.
-  unsigned pointer_size_;
-  // The thunk sequence this IrEmitter generates for the input computation.
-  gpu::ThunkSequence thunk_sequence_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(LhloDialectEmitter);
-};
-
-}  // namespace mlir_gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_
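The emitter interface above follows the default-action visitor pattern: dedicated Handle* overrides for special opcodes, with everything else funneled through DefaultAction. Below is a minimal, self-contained sketch of that pattern in plain C++, with hypothetical names and no MLIR or XLA dependencies.

#include <iostream>
#include <string>

struct Instruction { std::string opcode; };

class VisitorWithDefault {
 public:
  virtual ~VisitorWithDefault() = default;
  // Fallback used for every opcode without a dedicated handler.
  virtual void DefaultAction(const Instruction& instr) = 0;
  virtual void HandleConstant(const Instruction& instr) { DefaultAction(instr); }

  void Visit(const Instruction& instr) {
    if (instr.opcode == "constant") return HandleConstant(instr);
    DefaultAction(instr);
  }
};

class PrintingEmitter : public VisitorWithDefault {
  void DefaultAction(const Instruction& instr) override {
    std::cout << "emit generic op for " << instr.opcode << "\n";
  }
  void HandleConstant(const Instruction&) override {
    std::cout << "emit constant op\n";
  }
};

int main() {
  PrintingEmitter emitter;
  emitter.Visit({"constant"});  // dedicated handler
  emitter.Visit({"add"});       // falls through to DefaultAction
}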
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
deleted file mode 100644
index 26c9e15..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
-
-#include <memory>
-
-#include "llvm/IR/Module.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-MlirCompiler::MlirCompiler() : data_layout_("") {}
-
-se::Platform::Id MlirCompiler::PlatformId() const {
-  return stream_executor::cuda::kCudaPlatformId;
-}
-
-void MlirCompiler::SetModuleHook(IRHook module_hook) {
-  module_hook_ = module_hook;
-}
-
-void MlirCompiler::RemoveModuleHook() {
-  module_hook_ = {nullptr, IRHook::LoweringStage::LHLO};
-}
-
-void MlirCompiler::SetErrorHandler(ErrorHandler error_handler) {
-  error_handler_ = error_handler;
-}
-
-void MlirCompiler::RemoveErrorHandler() { error_handler_ = nullptr; }
-
-}  // namespace mlir_gpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h
deleted file mode 100644
index 6c361f4..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
-
-#include "llvm/IR/DataLayout.h"
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/service/compiler.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-// A Compiler implementation that converts XLA's IR to a matching MLIR dialect,
-// performs all lowering on the MLIR IR, and finally converts MLIR to LLVM IR
-// for generation of a thunk suitable for XLA's runtime. MlirCompilerImpl
-// contains the implementation.
-class MlirCompiler : public Compiler {
-  using ErrorHandler =
-      std::function<void(const EmissionContext::ErrorMap&, HloModule*)>;
-
- public:
-  MlirCompiler();
-
-  se::Platform::Id PlatformId() const override;
-
-  struct IRHook {
-    enum class LoweringStage { LHLO, GPU, LLVM, KERNEL };
-
-    Status invoke(LoweringStage stage_, mlir::ModuleOp module) {
-      if (callback && stage == stage_) {
-        return callback(module);
-      }
-      return Status::OK();
-    }
-
-    std::function<Status(mlir::ModuleOp)> callback;
-    LoweringStage stage;
-  };
-
-  void SetModuleHook(IRHook module_hook);
-  void RemoveModuleHook();
-  void SetErrorHandler(ErrorHandler error_handler);
-  void RemoveErrorHandler();
-
- protected:
-  ::mlir::MLIRContext context_;
-  llvm::DataLayout data_layout_;
-  IRHook module_hook_;
-  ErrorHandler error_handler_;
-};
-
-}  // namespace mlir_gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
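The IRHook above gates a single callback on a lowering stage, so a caller can observe the module at exactly one point in the pipeline. Here is a self-contained sketch of that stage-gated hook, using a plain string in place of mlir::ModuleOp; all names are hypothetical.

#include <functional>
#include <iostream>
#include <string>

enum class LoweringStage { LHLO, GPU, LLVM, KERNEL };

struct IRHook {
  std::function<void(const std::string&)> callback;
  LoweringStage stage;

  // Only fires when the pipeline reaches the stage the hook was set for.
  void invoke(LoweringStage current, const std::string& module) const {
    if (callback && stage == current) callback(module);
  }
};

int main() {
  IRHook hook{[](const std::string& m) { std::cout << "GPU IR: " << m << "\n"; },
              LoweringStage::GPU};
  hook.invoke(LoweringStage::LHLO, "lhlo-module");  // ignored, wrong stage
  hook.invoke(LoweringStage::GPU, "gpu-module");    // printed
}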
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
deleted file mode 100644
index 79525a3..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
+++ /dev/null
@@ -1,622 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "llvm/IR/LLVMContext.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
-#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/Attributes.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/OperationSupport.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
-#include "mlir/IR/Value.h"  // from @llvm-project
-#include "mlir/Support/LLVM.h"  // from @llvm-project
-#include "mlir/Target/NVVMIR.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/dump.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
-#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
-#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
-#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
-#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
-#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
-#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/platform/cuda_libdevice_path.h"
-#include "tensorflow/stream_executor/gpu/asm_compiler.h"
-
-namespace xla {
-namespace mlir_gpu {
-namespace {
-
-using ::mlir::BlockArgument;
-using ::mlir::dyn_cast;
-using ::mlir::FuncOp;
-using ::mlir::ModuleOp;
-using ::mlir::OwningModuleRef;
-using ::mlir::UnknownLoc;
-using ::mlir::Value;
-using ::mlir::gpu::LaunchFuncOp;
-using ::mlir::LLVM::LLVMDialect;
-using ::mlir::LLVM::LLVMFuncOp;
-using ::mlir::LLVM::LLVMType;
-using ::xla::gpu::GpuExecutable;
-using ::xla::gpu::GpuHloSchedule;
-using ::xla::gpu::GpuVersion;
-using ::xla::gpu::StreamAssignment;
-using ::xla::gpu::ThunkSchedule;
-
-// A Compiler implementation that converts XLA's IR to a matching MLIR dialect,
-// performs all lowering on the MLIR IR, and finally converts MLIR to LLVM IR
-// for generation of a thunk suitable for XLA's runtime.
-class MlirCompilerImpl : public MlirCompiler {
- public:
-  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
-  StatusOr<std::unique_ptr<Executable>> RunBackend(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
-  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::unique_ptr<HloModuleGroup> module_group,
-      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
-                     const AotCompilationOptions& options) override;
-
-  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
-    int64 pointer_size = data_layout_.getPointerSize();
-    return [pointer_size](const Shape& shape) {
-      return ShapeUtil::ByteSizeOf(shape, pointer_size);
-    };
-  }
-};
-
-// TODO(b/137624192) Share with NVPTX compiler
-static std::vector<std::string> CandidateCudaRoots(
-    const HloModuleConfig& config) {
-  return tensorflow::CandidateCudaRoots(
-      config.debug_options().xla_gpu_cuda_data_dir());
-}
-
-void PrintCantFindCudaMessage(absl::string_view msg,
-                              const HloModuleConfig& hlo_module_config) {
-  LOG(WARNING) << msg;
-  LOG(WARNING) << "Searched for CUDA in the following directories:";
-
-  for (const auto& dir : CandidateCudaRoots(hlo_module_config)) {
-    LOG(WARNING) << "  " << dir;
-  }
-  LOG(WARNING)
-      << "You can choose the search directory by setting xla_gpu_cuda_data_dir "
-         "in HloModule's DebugOptions.  For most apps, setting the environment "
-         "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
-}
-
-// Returns the directory containing nvvm libdevice files.
-std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
-  for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) {
-    const std::string libdevice_dir =
-        tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
-    VLOG(2) << "Looking for libdevice at " << libdevice_dir;
-    if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
-      VLOG(2) << "Found libdevice dir " << libdevice_dir;
-      return libdevice_dir;
-    }
-  }
-  PrintCantFindCudaMessage(
-      "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may "
-      "result in compilation or runtime failures, if the program we try to run "
-      "uses routines from libdevice.",
-      hlo_module_config);
-
-  // CandidateCudaRoots always includes ".", but if everything fails, we
-  // return it anyway.  Better than returning the empty string.
-  return ".";
-}
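The search above walks each candidate CUDA root and returns the first <root>/nvvm/libdevice that is a directory, falling back to ".". A standalone sketch of the same walk, using std::filesystem instead of tensorflow::Env; the example paths are made up.

#include <filesystem>
#include <string>
#include <vector>

std::string FindLibdeviceDir(const std::vector<std::string>& cuda_roots) {
  for (const std::string& root : cuda_roots) {
    std::filesystem::path candidate =
        std::filesystem::path(root) / "nvvm" / "libdevice";
    if (std::filesystem::is_directory(candidate)) return candidate.string();
  }
  // Mirror the original fallback: "." is better than an empty string.
  return ".";
}

int main() {
  // Hypothetical candidate roots, e.g. a system install plus the cwd.
  return FindLibdeviceDir({"/usr/local/cuda", "."}).empty() ? 1 : 0;
}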
-
-StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  // Until we find a reason to do something different, run the same passes
-  // that the normal GPU backend runs.
-  gpu::NVPTXCompiler xla_compiler;
-  TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
-                                                    device_allocator));
-  TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));
-
-  return std::move(module);
-}
-
-// TODO(b/137624192): Move this to custom call handling and share.
-absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
-                                        const HloInstruction* operand,
-                                        const ShapeIndex& user_index) {
-  if (user->opcode() == HloOpcode::kCustomCall) {
-    // Share the bias buffer with the parent instruction.
-    if (user->custom_call_target() == xla::gpu::kGemmCallTarget) {
-      if (user->operand_count() == 3 && user->operand(2) == operand) {
-        return true;
-      }
-    }
-    // The operand of cholesky can be shared with the first output.
-    if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) {
-      return user_index.size() == 1 && user_index[0] == 0;
-    }
-  }
-  return absl::nullopt;
-}
-
-// TODO(b/137624192): Share this with nvptx backend.
-GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
-  int cc_major, cc_minor;
-  const auto& device_description = stream_exec->GetDeviceDescription();
-  if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) {
-    LOG(WARNING)
-        << "Couldn't get compute capability for device; assuming sm_20.";
-    cc_major = 2;
-    cc_minor = 0;
-  }
-  return std::make_pair(cc_major, cc_minor);
-}
-
-// Return the constant launch bound along the "x" dimension in "dim" if all the
-// other dimensions are 1.  Return nullopt otherwise or when any of the bounds
-// is not constant.
-static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
-  auto get_constant = [](mlir::Operation* op,
-                         mlir::StringRef name) -> absl::optional<int64> {
-    if (auto constant = llvm::dyn_cast_or_null<mlir::ConstantOp>(op)) {
-      return constant.value().cast<mlir::IntegerAttr>().getInt();
-    }
-    op->emitError() << "bound " << name << " is not constant";
-    return absl::nullopt;
-  };
-  auto y_op = dim.y.getDefiningOp();
-  auto dim_y = get_constant(y_op, "y");
-  if (!dim_y.has_value() || dim_y.value() != 1) {
-    y_op->emitError() << "bound 'y' is not constant 1";
-    return absl::nullopt;
-  }
-  auto z_op = dim.z.getDefiningOp();
-  auto dim_z = get_constant(z_op, "z");
-  if (!dim_z.has_value() || dim_z.value() != 1) {
-    z_op->emitError() << "bound 'z' is not constant 1";
-    return absl::nullopt;
-  }
-  return get_constant(dim.x.getDefiningOp(), "x");
-}
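getLaunchBound above only accepts launches whose y and z extents are the constant 1, and then reports the x extent. The same check on plain integers, as a minimal sketch where nullopt stands in for a non-constant bound.

#include <cstdint>
#include <optional>

struct Dim3 { std::optional<int64_t> x, y, z; };  // nullopt = not a constant

std::optional<int64_t> GetLaunchBound(const Dim3& dim) {
  // Rejects non-constant bounds as well as constants other than 1.
  if (dim.y != 1 || dim.z != 1) return std::nullopt;
  return dim.x;
}

int main() {
  return GetLaunchBound({256, 1, 1}).value_or(-1) == 256 ? 0 : 1;
}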
-
-// Indexes of a range of arguments in a GPU function. This is used to keep the
-// range of arguments that correspond to a lowered kernel argument of
-// (previously) memref type.
-struct LaunchFuncArgument {
-  int kernel_argument_begin;
-  int kernel_argument_size;
-};
-
-using OperandToValueMap =
-    absl::flat_hash_map<const HloInstruction*, std::vector<LaunchFuncArgument>>;
-
-static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
-    OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
-    LaunchFuncOp launchOp, LLVMFuncOp kernel) {
-  auto operands = instr->operands();
-  std::vector<const HloInstruction*> ordered_operands;
-  bool has_failed = false;
-  // A memref will expand into multiple kernel operands; accumulate their
-  // number so we can find them later.
-  int cur_operand_position = 0;
-
-  for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
-       ++kernel_index) {
-    auto launchop_operand =
-        launchOp.getKernelOperand(kernel_index).dyn_cast<BlockArgument>();
-    if (!launchop_operand) {
-      launchOp.emitError("argument to kernel is not a function input");
-      has_failed = true;
-      continue;
-    }
-    auto memref_type =
-        launchop_operand.getType().dyn_cast<::mlir::MemRefType>();
-    if (!memref_type) {
-      launchOp.emitError("only memref-typed arguments are supported");
-      has_failed = true;
-      break;
-    }
-    // host_index is the argument position to the surrounding function that
-    // contains the launch. This index corresponds to HLO operand indices
-    // by construction.
-    auto host_index = launchop_operand.getArgNumber();
-    // The trailing arguments to the outer function are the results.
-    auto operand =
-        (host_index < operands.size()) ? operands[host_index] : instr;
-    if (!operand_to_value_map->count(operand)) {
-      ordered_operands.push_back(operand);
-    }
-    // Associate the HLO operand with the argument values of the kernel
-    // function.
-    int num_unpacked =
-        mlir::MemRefDescriptor::getNumUnpackedValues(memref_type);
-    (*operand_to_value_map)[operand].push_back(
-        {cur_operand_position, num_unpacked});
-    cur_operand_position += num_unpacked;
-  }
-  if (has_failed) {
-    return InternalError("Mapping operands to kernel arguments has failed.");
-  }
-  return ordered_operands;
-}
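ComputeOperandToValueMap above records, for each HLO operand, the range of kernel arguments it expanded into, advancing a running position by the number of unpacked values per memref. A small sketch of that bookkeeping, with hypothetical operand names and sizes.

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct ArgRange { int begin; int size; };

std::map<std::string, std::vector<ArgRange>> MapOperandsToArgRanges(
    const std::vector<std::pair<std::string, int>>& unpacked_sizes) {
  std::map<std::string, std::vector<ArgRange>> result;
  int position = 0;
  for (const auto& [operand, size] : unpacked_sizes) {
    // An operand may appear in several argument positions, hence the vector.
    result[operand].push_back({position, size});
    position += size;
  }
  return result;
}

int main() {
  // Assume each rank-2 memref unpacks into 7 values (two pointers, an offset,
  // two sizes, two strides); the operand names are hypothetical.
  auto map = MapOperandsToArgRanges({{"%x", 7}, {"%y", 7}, {"%out", 7}});
  std::cout << map["%y"].front().begin << "\n";  // 7
}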
-
-Status InsertBufferLoadPreduleIntoKernel(
-    LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map,
-    const std::vector<const HloInstruction*>& ordered_operands,
-    BufferAssignment* assignment,
-    const std::vector<const BufferAllocation*>& buffers) {
-  mlir::OpBuilder builder(kernel.getBody());
-  auto* context = kernel.getContext();
-  auto offset_type = LLVMType::getInt64Ty(context);
-  auto ptr_type = LLVMType::getInt8PtrTy(context);
-  auto void_type = LLVMType::getVoidTy(context);
-  auto loc = kernel.getLoc();
-
-  auto num_original_args = kernel.getNumArguments();
-  std::vector<LLVMType> new_arg_types(buffers.size(), ptr_type);
-  kernel.setAttr(kernel.getTypeAttrName(),
-                 mlir::TypeAttr::get(LLVMType::getFunctionTy(
-                     void_type, new_arg_types, /*isVarArg=*/false)));
-  std::vector<Value> original_args(kernel.args_begin(), kernel.args_end());
-
-  std::vector<mlir::Type> as_mlir_types(new_arg_types.begin(),
-                                        new_arg_types.end());
-  auto new_args = kernel.front().addArguments(as_mlir_types);
-  std::vector<Value> buffer_args(new_args.begin(), new_args.end());
-
-  for (auto operand : ordered_operands) {
-    TF_ASSIGN_OR_RETURN(auto slice,
-                        assignment->GetUniqueTopLevelSlice(operand));
-    auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation());
-    auto index = buffer - buffers.begin();
-    auto offset = builder.create<mlir::LLVM::ConstantOp>(
-        loc, offset_type, builder.getI64IntegerAttr(slice.offset()));
-    auto ptr = buffer_args[index];
-
-    // Replace uses of function arguments pertaining to memref descriptors with
-    // values derived from HLO buffers. The instructions inserting these values
-    // into memref descriptors were already introduced during the lowering phase
-    // as per MLIR calling convention.
-    for (auto arg : operand_to_value_map.at(operand)) {
-      mlir::MemRefDescriptorView original(
-          mlir::ValueRange(original_args)
-              .slice(arg.kernel_argument_begin, arg.kernel_argument_size));
-
-      // Allocated and aligned pointers are the same.
-      auto casted = builder.create<mlir::LLVM::BitcastOp>(
-          loc, original.alignedPtr().getType().cast<LLVMType>(),
-          mlir::ValueRange(ptr));
-      original.alignedPtr().replaceAllUsesWith(casted);
-      original.allocatedPtr().replaceAllUsesWith(casted);
-
-      // Use the offset of the HLO buffer instead of the one expected in the
-      // function call.
-      original.offset().replaceAllUsesWith(offset);
-
-      // Fill the shape.
-      auto shape = operand->shape();
-      // Unless the operand is a scalar pointer, also fill shape and strides.
-      if (shape.dimensions().empty()) {
-        continue;
-      }
-
-      // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes.
-      assert(shape.IsArray() && shape.is_static());
-      for (auto extent : llvm::enumerate(shape.dimensions())) {
-        auto shape = builder.create<mlir::LLVM::ConstantOp>(
-            loc, original.size(extent.index()).getType(),
-            builder.getI64IntegerAttr(extent.value()));
-        original.size(extent.index()).replaceAllUsesWith(shape);
-      }
-      // Finally, fill the strides.
-      // TODO(b/137624192): Take assigned layout into account.
-      uint64_t accumulator = 0;
-      for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) {
-        if (accumulator == 0) {
-          accumulator = 1;
-        } else {
-          accumulator *= shape.dimensions(idx + 1);
-        }
-        auto stride = builder.create<mlir::LLVM::ConstantOp>(
-            loc, original.stride(idx).getType(),
-            builder.getI64IntegerAttr(accumulator));
-        original.stride(idx).replaceAllUsesWith(stride);
-      }
-    }
-  }
-
-  // Now we can remove the original arguments, as they should have no more
-  // users.
-  for (int i = 0; i < num_original_args; ++i) {
-    kernel.front().eraseArgument(0);
-  }
-
-  return Status::OK();
-}
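The stride-filling loop above computes row-major strides from the operand shape: the last dimension has stride 1, and each earlier stride is the product of all dimensions to its right. The same computation as a standalone sketch.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size());
  int64_t accumulator = 1;
  for (int64_t idx = static_cast<int64_t>(dims.size()) - 1; idx >= 0; --idx) {
    strides[idx] = accumulator;      // stride of dimension idx
    accumulator *= dims[idx];        // fold in this extent for the next one
  }
  return strides;
}

int main() {
  for (int64_t s : RowMajorStrides({2, 3, 4})) std::cout << s << " ";  // 12 4 1
}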
-
-StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
-    FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module,
-    BufferAssignment* assignment) {
-  // Find the single LaunchFuncOp and compute a mapping from operands of
-  // the HLO instruction to the corresponding values of the kernel
-  // function in the target module.
-  LaunchFuncOp launchOp;
-  auto walkResult = func.walk([&launchOp](LaunchFuncOp op) {
-    if (launchOp) {
-      op.emitError("multiple kernels for single top-level HLO");
-      return mlir::WalkResult::interrupt();
-    }
-    launchOp = op;
-    return mlir::WalkResult::advance();
-  });
-  if (walkResult.wasInterrupted()) {
-    return InternalError("Multiple kernels for single top-level HLO");
-  }
-  if (!launchOp) {
-    // If there was no launchOp, then no kernel was generated, so the lowering
-    // from the LHLO ops to the GPU dialect is not implemented yet.
-    return Unimplemented("No kernel was generated.");
-  }
-
-  auto kernel =
-      kernel_module.lookupSymbol<LLVMFuncOp>(launchOp.getKernelName());
-
-  // Store the assignment of operands to block arguments. Note that an operand
-  // might be used in multiple argument positions, hence the vector.
-  OperandToValueMap operand_to_value_map;
-  TF_ASSIGN_OR_RETURN(
-      auto ordered_operands,
-      ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel));
-
-  // Get the required buffers to support the inputs. Use a set and vector here
-  // to keep the order fixed. This is mostly useful for testing.
-  std::unordered_set<const BufferAllocation*> buffers_needed;
-  std::vector<const BufferAllocation*> buffers;
-  // TODO(b/137624192) Add support for tuples.
-  for (auto operand : ordered_operands) {
-    TF_ASSIGN_OR_RETURN(auto buffer,
-                        assignment->GetUniqueTopLevelSlice(operand));
-    if (buffers_needed.insert(buffer.allocation()).second) {
-      buffers.push_back(buffer.allocation());
-    }
-  }
-
-  // TODO(b/137624192) Add support for temp buffer.
-  // TODO(b/137624192) Add support for constant buffers.
-
-  // Change the signature to match what the XLA runtime expects from the
-  // kernel.
-  TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel(
-      kernel, operand_to_value_map, ordered_operands, assignment, buffers));
-
-  // Finally, create the thunk and set the launch dimensions.
-  gpu::Thunk::ThunkInfo info;
-  auto thunk = absl::make_unique<gpu::KernelThunk>(info, buffers,
-                                                   kernel.getName().str());
-
-  // Set launch bounds.
-  mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();
-  mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues();
-  absl::optional<int64> num_threads = getLaunchBound(block);
-  absl::optional<int64> num_blocks = getLaunchBound(grid);
-  if (!num_threads || !num_blocks) {
-    return Unimplemented("Unsupported launch bounds");
-  }
-  thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads));
-  return std::move(thunk);
-}
-
-StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  // Determine the HLO schedule, which is an ordering of HLO instructions. This
-  // is used by buffer assignment to enable buffer reuse, and the same ordering
-  // must also be used to determine the thunk launch schedule.
-  std::unique_ptr<StreamAssignment> stream_assignment =
-      xla::gpu::AssignStreams(*module);
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<GpuHloSchedule> hlo_schedule,
-                      GpuHloSchedule::Build(*module, *stream_assignment,
-                                            data_layout_.getPointerSize()));
-
-  // Run buffer analysis on the HLO graph. This analysis figures out which
-  // temporary buffers are required to run the computation.
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferAssignment> buffer_assignment,
-                      BufferAssigner::Run(
-                          module.get(), hlo_schedule->ConsumeHloOrdering(),
-                          BufferSizeBytesFunction(),
-                          /*color_alignment=*/
-                          [](LogicalBuffer::Color) {
-                            return xla::gpu::kXlaAllocatedBufferAlignBytes;
-                          },
-                          /*allocate_buffers_for_constants=*/true,
-                          /*colorer=*/BufferAssigner::DefaultColorer(),
-                          /*must_not_live_out=*/{}, &CanShareBufferHint));
-  DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
-
-  EmissionContext emission_context(std::move(module));
-  if (error_handler_) {
-    emission_context.setErrorHandler(error_handler_);
-  }
-
-  OwningModuleRef mlir_module =
-      ModuleOp::create(UnknownLoc::get(emission_context.getContext()));
-  LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment,
-                                  stream_exec->platform(), *mlir_module);
-
-  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<gpu::Thunk>>
-      hlo_to_thunk;
-  for (HloInstruction* instruction : hlo_schedule->ThunkLaunchOrder()) {
-    TF_RETURN_IF_ERROR(instruction->Visit(&lhlo_emitter));
-    gpu::ThunkSequence thunks = lhlo_emitter.ConsumeThunkSequence();
-    TF_RET_CHECK(thunks.size() <= 1) << instruction->ToString();
-    if (!thunks.empty()) {
-      auto thunk = std::move(thunks.front());
-      hlo_to_thunk[instruction] = std::move(thunk);
-    }
-  }
-
-  TF_RETURN_IF_ERROR(
-      module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
-
-  TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module));
-
-  TF_RETURN_IF_ERROR(
-      module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module));
-
-  TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module));
-
-  TF_RETURN_IF_ERROR(
-      module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module));
-
-  TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module,
-                      ExtractKernelModule(*mlir_module));
-
-  for (auto entry : lhlo_emitter.InstructionToFunctionMap()) {
-    TF_ASSIGN_OR_RETURN(
-        auto thunk,
-        TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module,
-                                  buffer_assignment.get()));
-    hlo_to_thunk[entry.first] = std::move(thunk);
-  }
-
-  absl::flat_hash_map<const gpu::Thunk*, const HloInstruction*> thunk_to_hlo;
-  gpu::ThunkSequence thunk_sequence;
-  {
-    for (HloInstruction* hlo : hlo_schedule->ThunkLaunchOrder()) {
-      auto it = hlo_to_thunk.find(hlo);
-      if (it != hlo_to_thunk.end()) {
-        const HloInstruction* hlo = it->first;
-        auto& thunk = it->second;
-        thunk_to_hlo[thunk.get()] = hlo;
-        thunk_sequence.push_back(std::move(thunk));
-      }
-    }
-  }
-
-  TF_RETURN_IF_ERROR(
-      module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module));
-
-  // Translate to LLVM IR in a fresh context. The module is further translated
-  // to textual PTX and a CUBIN blob so there is no need for the context to live
-  // longer than this function.
-  llvm::LLVMContext llvmContext;
-  auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);
-
-  if (!llvmModule) {
-    return InternalError("Translation to LLVM failed");
-  }
-
-  llvmModule->setModuleIdentifier(emission_context.getHloModule()->name());
-  // TODO(herhut): Why is this needed, and why does it not come from the template?
-  llvmModule->setDataLayout(gpu::nvptx::kDataLayout);
-
-  const auto& config = emission_context.getHloModule()->config();
-  TF_ASSIGN_OR_RETURN(
-      auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
-                                              GetGpuVersion(stream_exec),
-                                              config, GetLibdeviceDir(config)));
-  // Allow falling back to driver compilation when ptxas isn't able to
-  // compile.
-  StatusOr<std::vector<uint8>> maybe_cubin =
-      se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(),
-                        gpu::PtxOptsFromConfig(config));
-  std::vector<uint8> cubin;
-  if (maybe_cubin.ok()) {
-    cubin = std::move(maybe_cubin).ValueOrDie();
-  } else if (maybe_cubin.status().code() ==
-             tensorflow::error::Code::UNIMPLEMENTED) {
-    xla::gpu::WarnIfBadDriverJITVersion();
-  } else {
-    return maybe_cubin.status();
-  }
-
-  auto thunk_schedule = absl::make_unique<ThunkSchedule>(
-      std::make_unique<gpu::ThunkSequence>(std::move(thunk_sequence)),
-      std::move(stream_assignment), std::move(thunk_to_hlo));
-
-  if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
-    DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
-                            "thunk_schedule", thunk_schedule->ToString());
-  }
-
-  // TODO(b/137624192): Add profiling support.
-  return {absl::make_unique<GpuExecutable>(
-      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
-      emission_context.releaseHloModule(), std::move(buffer_assignment),
-      nullptr, nullptr, std::vector<GpuExecutable::ConstantInfo>())};
-}
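RunBackend above rebuilds the thunk sequence in HLO launch order by looking each scheduled instruction up in hlo_to_thunk, so only instructions that actually produced a thunk contribute. A toy sketch of that reordering step, with strings standing in for instructions and thunks.

#include <iostream>
#include <map>
#include <string>
#include <vector>

std::vector<std::string> OrderThunks(
    const std::vector<std::string>& launch_order,
    const std::map<std::string, std::string>& hlo_to_thunk) {
  std::vector<std::string> sequence;
  for (const std::string& hlo : launch_order) {
    auto it = hlo_to_thunk.find(hlo);
    if (it != hlo_to_thunk.end()) sequence.push_back(it->second);  // keep order
  }
  return sequence;
}

int main() {
  auto seq = OrderThunks({"%x", "%add", "%tuple"},
                         {{"%add", "add_kernel_thunk"}});
  std::cout << seq.size() << "\n";  // 1: only %add produced a thunk
}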
-
-StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
-    std::unique_ptr<HloModuleGroup> module_group,
-    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-    se::DeviceMemoryAllocator* device_allocator) {
-  return Unimplemented("Not yet implemented in MLIR compiler");
-}
-
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-MlirCompilerImpl::CompileAheadOfTime(
-    std::unique_ptr<HloModuleGroup> /*module_group*/,
-    const AotCompilationOptions& /*options*/) {
-  return Unimplemented("Not yet implemented in MLIR compiler");
-}
-
-}  // namespace
-}  // namespace mlir_gpu
-}  // namespace xla
-
-static bool InitModule() {
-  xla::Compiler::RegisterCompilerFactory(
-      stream_executor::cuda::kCudaPlatformId, []() {
-        return absl::make_unique<xla::FailoverCompiler>(
-            absl::make_unique<xla::mlir_gpu::MlirCompilerImpl>(),
-            absl::make_unique<xla::gpu::NVPTXCompiler>());
-      });
-  return true;
-}
-static bool module_initialized = InitModule();
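InitModule above registers a FailoverCompiler that wraps the MLIR path and falls back to the NVPTX compiler. A minimal sketch of the failover idea, where an empty optional from the primary path triggers the secondary; all names are hypothetical.

#include <functional>
#include <iostream>
#include <optional>
#include <string>

using CompileFn = std::function<std::optional<std::string>(const std::string&)>;

std::optional<std::string> CompileWithFailover(const std::string& module,
                                               CompileFn primary,
                                               CompileFn secondary) {
  // Try the primary backend first; fall back if it cannot handle the module.
  if (auto result = primary(module)) return result;
  return secondary(module);
}

int main() {
  auto mlir_path = [](const std::string&) { return std::optional<std::string>{}; };
  auto nvptx_path = [](const std::string& m) {
    return std::optional<std::string>{"cubin-for-" + m};
  };
  std::cout << *CompileWithFailover("module", mlir_path, nvptx_path) << "\n";
}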
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc
index 84751bc..2cabd92 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc
@@ -257,8 +257,8 @@
       auto new_kernel = kernel_builder.create<mlir::gpu::GPUFuncOp>(
           kernel.getLoc(), kernel.getName(),
           kernel_builder.getFunctionType(operand_types, {}));
-      new_kernel.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(),
-                         kernel_builder.getUnitAttr());
+      new_kernel->setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(),
+                          kernel_builder.getUnitAttr());
 
       // Create a map from old kernel argument to new one.
       mlir::BlockAndValueMapping old_kernel_to_new;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD
deleted file mode 100644
index 9bd5e33..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD
+++ /dev/null
@@ -1,48 +0,0 @@
-load("//tensorflow:tensorflow.bzl", "filegroup")
-load(
-    "//tensorflow/core/platform:build_config_root.bzl",
-    "tf_cuda_tests_tags",
-    "tf_exec_properties",
-)
-load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
-
-package(
-    default_visibility = [":friends"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-package_group(
-    name = "friends",
-    includes = [
-        "//tensorflow/compiler/xla:friends",
-    ],
-)
-
-glob_lit_tests(
-    data = [
-        ":test_utilities",
-        "@llvm-project//mlir:run_lit.sh",
-    ],
-    default_tags = tf_cuda_tests_tags() + [
-        "no_pip",
-        "config-cuda-only",
-        "no_rocm",
-    ],
-    driver = "//tensorflow/compiler/mlir:run_lit.sh",
-    exclude = [
-        # TODO(b/137624192): Reenable once we can fuse reductions.
-        "fused_reduce.hlo",
-    ],
-    exec_properties = tf_exec_properties({"tags": tf_cuda_tests_tags()}),
-    test_file_exts = ["hlo"],
-)
-
-# Bundle together all of the test utilities that are used by tests.
-filegroup(
-    name = "test_utilities",
-    testonly = True,
-    data = [
-        "//tensorflow/compiler/xla/service/mlir_gpu:xla-gpu-opt",
-        "@llvm-project//llvm:FileCheck",
-    ],
-)
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo
deleted file mode 100644
index ba29b0a..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Abs
-ENTRY %Abs (val: f32[2,2]) -> f32[2,2] {
-  %val = f32[2,2]{1,0} parameter(0)
-  ROOT %abs = f32[2,2]{1,0} abs(f32[2,2]{1,0} %val)
-}
-
-//  CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo
deleted file mode 100644
index 37c163e..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Add
-
-ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-}
-
-// CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
-// CHECK:   "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo
deleted file mode 100644
index 8d7930e..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo
+++ /dev/null
@@ -1,63 +0,0 @@
-// RUN: xla-gpu-opt -lowering-stage=KERNEL %s | FileCheck %s
-HloModule Add
-
-ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-}
-
-//  CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm\..*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]
-
-//
-//   Check that relevant sizes and strides are emitted.
-//
-//  CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm.ptr<i8> to !llvm.ptr<float>
-//  CHECK: %[[SIZE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-//  CHECK: %[[SIZE01:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-//  CHECK: %[[STRIDE01:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
-//  CHECK: %[[STRIDE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-
-//  CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm.ptr<i8> to !llvm.ptr<float>
-//  CHECK: %[[SIZE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-//  CHECK: %[[SIZE11:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-//  CHECK: %[[STRIDE11:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
-//  CHECK: %[[STRIDE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-
-//  CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm.ptr<i8> to !llvm.ptr<float>
-//  CHECK: %[[SIZE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-//  CHECK: %[[SIZE21:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-//  CHECK: %[[STRIDE21:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
-//  CHECK: %[[STRIDE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-
-//
-//   Check that the emitted sizes and strides, as well the pointers to HLO buffers,
-//   are inserted into the memref descriptors.
-//
-//  CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-
-//  CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], %[[DESC16]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-
-//  CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-
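The CHECK lines above spell out the five fields of a lowered rank-2 memref descriptor: allocated pointer, aligned pointer, offset, per-dimension sizes, and per-dimension strides. Purely for illustration, the same layout written as a plain struct.

#include <cstdint>

struct MemRefDescriptor2D {
  float* allocated;    // field 0: allocated pointer
  float* aligned;      // field 1: aligned pointer (same as allocated here)
  int64_t offset;      // field 2: offset into the buffer
  int64_t sizes[2];    // field 3: per-dimension sizes
  int64_t strides[2];  // field 4: per-dimension strides
};

int main() {
  float storage[4] = {};
  MemRefDescriptor2D desc{storage, storage, 0, {2, 2}, {2, 1}};
  return desc.strides[0] == 2 ? 0 : 1;
}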
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo
deleted file mode 100644
index db39919..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: xla-gpu-opt -lowering-stage=GPU %s | FileCheck %s
-HloModule Add
-
-ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-}
-
-// CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
-// CHECK: gpu.launch_func
-// CHECK-SAME: blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args
-// CHECK-SAME: (%[[ARG0]] : [[TYPE]], %[[ARG1]] : [[TYPE]], %[[ARG2]] : [[TYPE]])
-// CHECK: }
-// CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]]
-// CHECK-DAG: subview %[[ARG0]]{{\[}}[[INDEX:.*]]]
-// CHECK-DAG: subview %[[ARG1]]{{\[}}[[INDEX]]]
-// CHECK-DAG: subview %[[ARG2]]{{\[}}[[INDEX]]]
-// CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]]
-// CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]]
-// CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]]
-// CHECK: store %[[RES]], %{{.*\[}}[[INDEX]]]
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo
deleted file mode 100644
index 2603b92..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule AddMultiply
-
-ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  %z = f32[2,2]{1,0} parameter(2)
-  %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-  ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z)
-}
-
-//  CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]])
-//  CHECK: "lmhlo.fusion"() ( {
-//  CHECK:   %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]]
-//  CHECK:   %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]]
-//  CHECK:   %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]]
-//  CHECK:   %[[ADD:.*]] = mhlo.add %[[REF1]], %[[REF2]]
-//  CHECK:   %[[MUL:.*]] = mhlo.multiply %[[ADD]], %[[REF0]]
-//  CHECK:   tensor_store %[[MUL]], %[[RESULT]]
-//  CHECK:   "lmhlo.terminator"()
-//  CHECK-NEXT: }
-
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo
deleted file mode 100644
index 645175f..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: xla-gpu-opt -lowering-stage=GPU %s | FileCheck %s
-HloModule AddMultiply
-
-ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  %z = f32[2,2]{1,0} parameter(2)
-  %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-  ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z)
-}
-
-//  CHECK: func @fusion_kernel(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]])
-//  CHECK-DAG: subview %[[ARG0]]{{\[}}[[INDEX:.*]]]
-//  CHECK-DAG: subview %[[ARG1]]{{\[}}[[INDEX]]]
-//  CHECK-DAG: subview %[[ARG2]]{{\[}}[[INDEX]]]
-//  CHECK-DAG: subview %[[RESULT]]{{\[}}[[INDEX]]]
-//  CHECK:   %[[V0:.*]] = load %{{.*\[}}[[CSTIDX:.*]]]
-//  CHECK:   %[[V1:.*]] = load %{{.*\[}}[[CSTIDX:.*]]]
-//  CHECK:   %[[ADD:.*]] = addf %[[V0]], %[[V1]]
-//  CHECK:   %[[V2:.*]] = load %{{.*\[}}[[CSTIDX:.*]]]
-//  CHECK:   %[[MUL:.*]] = mulf %[[ADD]], %[[V2]]
-//  CHECK:   store %[[MUL]], %{{.*\[}}[[CSTIDX:.*]]]
-//  CHECK: return
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo
deleted file mode 100644
index a57f427..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule AddReduce
-
-%add (x: f32[], y: f32[]) -> f32[] {
-  %x = f32[] parameter(0)
-  %y = f32[] parameter(1)
-  ROOT %add = f32[] add(f32[] %x, f32[] %y)
-}
-
-ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] {
-  %x = f32[100,10]{1,0} parameter(0)
-  %c = f32[] parameter(1)
-  ROOT %reduce = f32[100]{0} reduce(f32[100,10]{1,0} %x, f32[] %c), dimensions={1}, to_apply=%add
-}
-
-//  CHECK: func @reduce(%[[ARG:.*]]: [[ARGT:.*]], %[[CST:.*]]: memref<f32>, %[[RES:.*]]: [[REST:.*]]) {
-//  CHECK:   "lmhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( {
-//  CHECK:   ^bb0(%[[FARG0:.*]]: memref<f32>, %[[FARG1:.*]]: memref<f32>, %[[FRES:.*]]: memref<f32>):
-//  CHECK:      %[[LHS:.*]] = tensor_load %[[FARG0]] : memref<f32>
-//  CHECK:      %[[RHS:.*]] = tensor_load %[[FARG1]] : memref<f32>
-//  CHECK:      %[[RES:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<f32>
-//  CHECK:      tensor_store %[[RES]], %[[FRES]] : memref<f32>
-//  CHECK:     "lmhlo.terminator"() : () -> ()
-//  CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : ([[ARGT]], memref<f32>, [[REST]]) -> ()
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo
deleted file mode 100644
index 366545c..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Broadcast
-
-ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] {
-  %x = f32[10]{0} parameter(0)
-  ROOT %broadcast = f32[10, 5]{1,0} broadcast(f32[10]{0} %x), dimensions={0}
-}
-
-//  CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]],  %[[OUT:.*]]: [[OUT_T:.*]]) {
-//  CHECK:   "lmhlo.broadcast_in_dim"(%[[IN]], %[[OUT]])
-//  CHECK:   {broadcast_dimensions = dense<0> : tensor<1xi64>}
-//  CHECK:   : ([[IN_T]], [[OUT_T]]) -> ()
-//  CHECK: }
-
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo
deleted file mode 100644
index 6bbddb6..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt -verify-errors %s | FileCheck %s
-HloModule Add
-
-ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] {
-  %x = f32[2,2,2]{2,1,0} parameter(0)
-  %y = f32[2,2,2]{2,1,0} parameter(1)
-  ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y)
-}
-
-// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: lmhlo.add; failed for testing: std.return]
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo
deleted file mode 100644
index f45fa1a..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Ceil
-ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] {
-  %val = f32[2,2]{1,0} parameter(0)
-  ROOT %ceil = f32[2,2]{1,0} ceil(f32[2,2]{1,0} %val)
-}
-
-//  CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo
deleted file mode 100644
index 2a34f49..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Compare
-
-ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  ROOT %compare = pred[2,2]{1,0} compare(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y), direction=EQ
-}
-
-// CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) {
-// CHECK:   "lmhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]])
-// CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo
deleted file mode 100644
index 99a4872..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Complex
-
-ENTRY %Complex (real: f32[2,2]{0,1}, imag: f32[2,2]{0,1}) -> c64[2,2] {
-  %real = f32[2,2]{0,1} parameter(0)
-  %imag = f32[2,2]{0,1} parameter(1)
-  ROOT %compl = c64[2,2]{0,1} complex(%real, %imag)
-}
-
-// CHECK: func @complex(%[[REAL:.*]]: [[BUF_F32:.*]], %[[IMAG:.*]]: [[BUF_F32]], %[[OUT:.*]]: [[BUF_C64:.*]]) {
-// CHECK:   "lmhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo
deleted file mode 100644
index 06f2918..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Concatenate
-
-ENTRY %Concatenate (x: f32[2,3], y: f32[2,2]) -> f32[2,5] {
-  %x = f32[2,3]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  ROOT %concatenate = f32[2,5]{1,0} concatenate(f32[2,3]{1,0} %x, f32[2,2]{1,0} %y), dimensions={1}
-}
-
-// CHECK: func @concatenate(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) {
-// CHECK:   "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[RESULT]])
-// CHECK:   {dimension = 1 : i64} : ([[TYPE0]], [[TYPE1]], [[RTYPE]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo
deleted file mode 100644
index e0745c4..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Const
-
-ENTRY %Const () -> s32[100] {
-  %const.0 = s32[] constant(10)
-  ROOT %broadcast.0 = s32[100]{0} broadcast(s32[] %const.0), dimensions={}
-}
-
-// CHECK: func @constant(%[[ARG0:.*]]: memref<i32>)
-// CHECK:   "lmhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor<i32>}
-// CHECK: func @broadcast(%[[ARG1:.*]]: memref<i32>, %[[ARG2:.*]]: memref<100xi32>)
-// CHECK:   "lmhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo
deleted file mode 100644
index b4058da..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Copy
-
-ENTRY %Copy (x: f32[2,4]) -> f32[2,4] {
-  %x = f32[2,4] parameter(0)
-  ROOT %copy = f32[2,4] copy(f32[2,4] %x)
-}
-
-// CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) {
-// CHECK:   "lmhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> ()
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo
deleted file mode 100644
index 8656b4e..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule CopyTranspose
-
-ENTRY %CopyTranspose (x: f32[2,4]) -> f32[2,4]{0,1} {
-  %x = f32[2,4] parameter(0)
-  ROOT %copy = f32[2,4]{0,1} copy(f32[2,4] %x)
-}
-
-// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2)>
-// CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>,
-// CHECK-SAME:       %[[RESULT:.*]]: memref<2x4xf32, #[[MAP0]]>) 
-// CHECK:   "lmhlo.copy"(%[[OPERAND]], %[[RESULT]])
-// CHECK-SAME: : (memref<2x4xf32>, memref<2x4xf32, #[[MAP0]]>)
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo
deleted file mode 100644
index 8a00a56..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Cos
-ENTRY %Cos (val: f32[2,2]) -> f32[2,2] {
-  %val = f32[2,2]{1,0} parameter(0)
-  ROOT %cos = f32[2,2]{1,0} cosine(f32[2,2]{1,0} %val)
-}
-
-//  CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.cosine"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo
deleted file mode 100644
index 42cc605..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Exp
-
-ENTRY %Exp (x: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  ROOT %exp = f32[2,2]{1,0} exponential(f32[2,2]{1,0} %x)
-}
-
-// CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-// CHECK:   "lmhlo.exponential"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-// CHECK: }
-
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo
deleted file mode 100644
index f74cdef..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule FusedReduce
-
-%add (x: f32[], y: f32[]) -> f32[] {
-  %x = f32[] parameter(0)
-  %y = f32[] parameter(1)
-  ROOT %add = f32[] add(f32[] %x, f32[] %y)
-}
-
-%fused_computation (param: f32[100,10]) -> f32[10] {
-  %param = f32[100,10] parameter(0)
-  %constant = f32[] constant(0)
-  ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
-      dimensions={0}, to_apply=%add
-}
-
-ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
-  %x = f32[100,10] parameter(0)
-  ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
-      calls=%fused_computation
-}
-
-//  CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
-//  CHECK: "lmhlo.fusion"() ( {
-//  CHECK:   %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
-//  CHECK:   %[[CT0:.*]] = mhlo.constant dense<0.000000e+00>
-//  CHECK:   %[[RED:.*]] = "mhlo.reduce"(%0, %1) ( {
-//  CHECK:     ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
-//  CHECK:       %[[ADD:.*]] = mhlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
-//  CHECK:       "mhlo.return"(%[[ADD]])
-//  CHECK:     })
-//  CHECK:   tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
-//  CHECK:   "lmhlo.terminator"()
-//  CHECK-NEXT: })
-
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo
deleted file mode 100644
index 6a4f020..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Gather
-
-ENTRY %Gather (x: f32[100,10], y: s64[4,6]) -> f32[4,6,10] {
-  %x = f32[100,10] parameter(0)
-  %y = s64[4,6] parameter(1)
-  ROOT %gather = f32[4,6,10]{2,1,0} gather(f32[100,10]{1,0} %x, s64[4,6]{1,0} %y),
-      collapsed_slice_dims={0}, index_vector_dim=2, offset_dims={2},
-      slice_sizes={1,10}, start_index_map={0}
-}
-
-// CHECK: func @gather(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]],
-// CHECK-SAME:         %[[RESULT:.*]]: [[RTYPE:.*]]) {
-// CHECK-NEXT: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) {
-// CHECK-SAME:   dimension_numbers = {
-// CHECK-SAME:     collapsed_slice_dims = dense<0> : tensor<1xi64>,
-// CHECK-SAME:     index_vector_dim = 2 : i64,
-// CHECK-SAME:     offset_dims = dense<2> : tensor<1xi64>,
-// CHECK-SAME:     start_index_map = dense<0> : tensor<1xi64>
-// CHECK-SAME:   },
-// CHECK-SAME:   slice_sizes = dense<[1, 10]> : tensor<2xi64>
-// CHECK-SAME: } : ([[TYPE0]], [[TYPE1]], [[RTYPE]]) -> ()
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo
deleted file mode 100644
index 50ff557..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Imag
-
-ENTRY %Imag (x: c64[2,2]{0,1}) -> f32[2,2] {
-  %x = c64[2,2]{0,1} parameter(0)
-  ROOT %imag = f32[2,2]{0,1} imag(%x)
-}
-
-// CHECK: func @imag(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) {
-// CHECK:   "lmhlo.imag"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo
deleted file mode 100644
index 1755e4b..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Iota
-
- ENTRY %Iota() -> s64[10, 5] {
-  ROOT %iota = s64[10, 5]{1,0} iota(), iota_dimension=0
-}
-
-//  CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) {
-//  CHECK:   "lmhlo.iota"(%[[OUT]])
-//  CHECK:   {iota_dimension = 0 : i64} : ([[OUT_T]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_subtract.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_subtract.hlo
deleted file mode 100644
index 6c019db..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_subtract.hlo
+++ /dev/null
@@ -1,16 +0,0 @@
-// RUN: xla-gpu-opt -lowering-stage=GPU %s | FileCheck %s
-HloModule AddSubtract
-
-ENTRY %AddSubtract (x: s32[2,2], y: s32[2,2]) -> s32[2,2] {
-  %x = s32[2,2]{1,0} parameter(0)
-  %y = s32[2,2]{1,0} parameter(1)
-
-  %add = s32[2,2]{1,0} add(s32[2,2]{1,0} %x, s32[2,2]{1,0} %y)
-  %iota = s32[2, 2]{1,0} iota(), iota_dimension=0
-
-  ROOT %sub = s32[2,2]{1,0} subtract(s32[2,2]{1,0} %add, s32[2,2]{1,0} %iota)
-}
-
-//  CHECK-NOT:  store
-//  CHECK:      [[RESULT:%.*]] = subi %{{.*}}, %{{.*}}
-//  CHECK:      store [[RESULT]]
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo
deleted file mode 100644
index 5f11564..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Log
-
-ENTRY %Log (x: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x)
-}
-
-// CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-// CHECK:   "lmhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo
deleted file mode 100644
index 30557f1..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Neg
-ENTRY %Neg (val: f32[2,2]) -> f32[2,2] {
-  %val = f32[2,2]{1,0} parameter(0)
-  ROOT %neg = f32[2,2]{1,0} negate(f32[2,2]{1,0} %val)
-}
-
-//  CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.negate"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo
deleted file mode 100644
index 559a4db..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Real
-
-ENTRY %Real (x: c64[2,2]{0,1}) -> f32[2,2] {
-  %x = c64[2,2]{0,1} parameter(0)
-  ROOT %real = f32[2,2]{0,1} real(%x)
-}
-
-// CHECK: func @real(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) {
-// CHECK:   "lmhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> ()
-// CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo
deleted file mode 100644
index 4c23a98..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule ReduceWindow
-
-%max (x: f32[], y: f32[]) -> f32[] {
-  %x = f32[] parameter(0)
-  %y = f32[] parameter(1)
-  ROOT %max = f32[] maximum(f32[] %x, f32[] %y)
-}
-
-ENTRY %ReduceWindow (x: f32[128,64,112,112], y: f32[]) -> f32[128,64,56,56] {
-  %x = f32[128,64,112,112] parameter(0)
-  %y = f32[] parameter(1)
-  ROOT %reduce-window = f32[128,64,56,56] reduce-window(
-    f32[128,64,112,112] %x,
-    f32[] %y
-  ),
-  window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1}, to_apply=%max
-}
-
-// CHECK: func @"reduce-window"(
-// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[CST:%.*]]: memref<f32>, [[RES:%.*]]: [[REST:.*]]) {
-// CHECK: "lmhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( {
-// CHECK:   ^bb0([[LHS:%.*]]: memref<f32>, [[RHS:%.*]]: memref<f32>, [[OUT:%.*]]: memref<f32>):
-// CHECK:     [[LHS_TENSOR:%.*]] = tensor_load [[LHS]]
-// CHECK:     [[RHS_TENSOR:%.*]] = tensor_load [[RHS]]
-// CHECK:     [[OUT_TENSOR:%.*]] = mhlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]]
-// CHECK:     tensor_store [[OUT_TENSOR]], [[OUT]]
-// CHECK:     "lmhlo.terminator"() : () -> ()
-// CHECK:   }) {
-// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
-// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]>
-// CHECK-SAME: window_dilations = dense<1> : tensor<4xi64>
-// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]>
-// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]>
-// CHECK: } : ([[ARGT]], memref<f32>, [[REST]]) -> ()
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo
deleted file mode 100644
index 6d3afb0..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Rem
-ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  %y = f32[2,2]{1,0} parameter(1)
-  ROOT %rem = f32[2,2]{1,0} remainder(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-}
-
-//  CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo
deleted file mode 100644
index 11d18e8..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Rsqrt
-
-ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x)
-}
-
-//  CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo
deleted file mode 100644
index bf25c69..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Select
-
-ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
-  %p = pred[2,2]{1,0} parameter(0)
-  %x = f32[2,2]{1,0} parameter(1)
-  %y = f32[2,2]{1,0} parameter(2)
-  ROOT %select = f32[2,2]{1,0} select(pred[2,2]{1,0} %p, f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
-}
-
-// CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
-// CHECK:   "lmhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> ()
-// CHECK: }
-
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo
deleted file mode 100644
index 46d2985..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo
+++ /dev/null
@@ -1,54 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule SelectAndScatter
-
-%ge (x: f32[], y: f32[]) -> pred[] {
-  %x = f32[] parameter(0)
-  %y = f32[] parameter(1)
-  ROOT %compare = pred[] compare(f32[] %x, f32[] %y), direction=GE
-}
-
-%add (x: f32[], y: f32[]) -> f32[] {
-  %x = f32[] parameter(0)
-  %y = f32[] parameter(1)
-  ROOT %add = f32[] add(f32[] %x, f32[] %y)
-}
-
-ENTRY %SelectAndScatter (x: f32[128,64,112,112],
-                         y: f32[128,64,56,56],
-                         z: f32[]) -> f32[128,64,112,112] {
-  %x = f32[128,64,112,112] parameter(0)
-  %y = f32[128,64,56,56] parameter(1)
-  %z = f32[] parameter(2)
-  ROOT %result = f32[128,64,112,112] select-and-scatter(
-    f32[128,64,112,112] %x,
-    f32[128,64,56,56] %y,
-    f32[] %z),
-  window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1},
-  select=%ge,
-  scatter=%add
-}
-
-// CHECK: func @"select-and-scatter"(
-// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[SRC:%.*]]: [[SRCT:.*]], [[CST:%.*]]: memref<f32>, [[RES:%.*]]: [[REST:.*]]) {
-// CHECK: "lmhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( {
-// CHECK:   ^bb0([[LHS:%.*]]: memref<f32>, [[RHS:%.*]]: memref<f32>,
-// CHECK-SAME:   [[OUT:%.*]]: memref<i1>):
-// CHECK:     [[LHS_TENSOR:%.*]] = tensor_load [[LHS]]
-// CHECK:     [[RHS_TENSOR:%.*]] = tensor_load [[RHS]]
-// CHECK:     [[OUT_TENSOR:%.*]] = "mhlo.compare"
-// CHECK-SAME:    ([[LHS_TENSOR]], [[RHS_TENSOR]]) {comparison_direction = "GE"}
-// CHECK:     tensor_store [[OUT_TENSOR]], [[OUT]]
-// CHECK:     lmhlo.terminator
-// CHECK:   },  {
-// CHECK:   ^bb0([[LHS_:%.*]]: memref<f32>, [[RHS_:%.*]]: memref<f32>,
-// CHECK-SAME:   [[OUT_:%.*]]: memref<f32>):
-// CHECK:     [[LHS_TENSOR_:%.*]] = tensor_load [[LHS_]]
-// CHECK:     [[RHS_TENSOR_:%.*]] = tensor_load [[RHS_]]
-// CHECK:     [[OUT_TENSOR_:%.*]] = mhlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]]
-// CHECK:     tensor_store [[OUT_TENSOR_]], [[OUT_]]
-// CHECK:     lmhlo.terminator
-// CHECK:   }) {
-// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]>
-// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]>
-// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]>
-// CHECK-SAME: } : ([[ARGT]], [[SRCT]], memref<f32>, [[REST]]) -> ()
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo
deleted file mode 100644
index 6acadb8..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Sign
-ENTRY %Sign (val: f32[2,2]) -> f32[2,2] {
-  %val = f32[2,2]{1,0} parameter(0)
-  ROOT %sign = f32[2,2]{1,0} sign(f32[2,2]{1,0} %val)
-}
-
-//  CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo
deleted file mode 100644
index 4e47229..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Sqrt
-
-ENTRY %Sqrt (x: f32[2,2]) -> f32[2,2] {
-  %x = f32[2,2]{1,0} parameter(0)
-  ROOT %sqrt = f32[2,2]{1,0} sqrt(f32[2,2]{1,0} %x)
-}
-
-// CHECK: func @sqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-// CHECK:   "lmhlo.sqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-// CHECK: }
-
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo
deleted file mode 100644
index 681c18a..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: xla-gpu-opt %s | FileCheck %s
-HloModule Tanh
-ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] {
-  %val = f32[2,2]{1,0} parameter(0)
-  ROOT %tanh = f32[2,2]{1,0} tanh(f32[2,2]{1,0} %val)
-}
-
-//  CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
-//  CHECK:   "lmhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
-//  CHECK: }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.cc b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.cc
deleted file mode 100644
index 775901b..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h"
-
-#include <memory>
-#include <string>
-
-#include "absl/strings/str_join.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "tensorflow/compiler/xla/debug_options_flags.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status.h"
-#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-Status XlaGpuOpt::CompileIr(std::unique_ptr<HloModule> hlo_module,
-                            const MlirCompiler::IRHook& ir_hook) {
-  MlirCompiler* compiler = GetMLIRCompiler();
-  compiler->SetModuleHook(ir_hook);
-  TF_ASSIGN_OR_RETURN(hlo_module, backend_->compiler()->RunHloPasses(
-                                      std::move(hlo_module),
-                                      backend_->default_stream_executor(),
-                                      /*device_allocator=*/nullptr));
-  Status status = backend_->compiler()
-                      ->RunBackend(std::move(hlo_module),
-                                   backend_->default_stream_executor(),
-                                   /*device_allocator=*/nullptr)
-                      .status();
-  compiler->RemoveModuleHook();
-  return status;
-}
-
-StatusOr<std::string> XlaGpuOpt::CompileIr(
-    std::unique_ptr<HloModule> hlo_module,
-    MlirCompiler::IRHook::LoweringStage printing_stage) {
-  std::string ir;
-  TF_RETURN_IF_ERROR(CompileIr(
-      std::move(hlo_module), {[&ir](mlir::ModuleOp module) -> Status {
-                                std::string buffer_string;
-                                llvm::raw_string_ostream ostream(buffer_string);
-                                module.print(ostream);
-                                ostream.flush();
-                                ir = buffer_string;
-                                return Status::OK();
-                              },
-                              printing_stage}));
-  return ir;
-}
-
-Status XlaGpuOpt::CompileAndOutputIr(std::unique_ptr<HloModule> hlo_module,
-                                     llvm::raw_ostream& os,
-                                     LoweringStage printing_stage) {
-  TF_ASSIGN_OR_RETURN(std::string ir,
-                      CompileIr(std::move(hlo_module), printing_stage));
-  os << ir;
-  return Status::OK();
-}
-
-Status XlaGpuOpt::CompileAndOutputIr(const std::string& hlo_text,
-                                     llvm::raw_ostream& os,
-                                     LoweringStage printing_stage) {
-  TF_ASSIGN_OR_RETURN(auto module, GetVerifiedHloModule(hlo_text));
-  return CompileAndOutputIr(std::move(module), os, printing_stage);
-}
-
-MlirCompiler::IRHook XlaGpuOpt::GetIRHookBreakingLoweringStage(
-    LoweringStage breaking_stage) {
-  return {[](mlir::ModuleOp module) -> Status {
-            mlir::PassManager pm(module.getContext());
-            pm.addNestedPass<::mlir::FuncOp>(
-                ::mlir::createInjectErrorsForTestingPass());
-            if (failed(pm.run(module))) {
-              return InternalError("InjectErrorsForTestingPass failed.");
-            }
-            return Status::OK();
-          },
-          breaking_stage};
-}
-
-StatusOr<string> XlaGpuOpt::CompileAndInjectErrors(
-    std::unique_ptr<HloModule> hlo_module, LoweringStage breaking_stage) {
-  std::string errors;
-  auto error_handler = [&errors](const EmissionContext::ErrorMap& error_map,
-                                 HloModule* hlo_module) {
-    errors = "ERRORS FOUND: ";
-    for (auto& err : error_map) {
-      errors += "[" + err.first->ToString() + ": " +
-                absl::StrJoin(err.second, "; ") + "]";
-    }
-  };
-
-  MlirCompiler* compiler = GetMLIRCompiler();
-  compiler->SetModuleHook(GetIRHookBreakingLoweringStage(breaking_stage));
-  compiler->SetErrorHandler(error_handler);
-  TF_ASSIGN_OR_RETURN(
-      hlo_module, compiler->RunHloPasses(std::move(hlo_module),
-                                         backend_->default_stream_executor(),
-                                         /*device_allocator=*/nullptr));
-  Status status = compiler
-                      ->RunBackend(std::move(hlo_module),
-                                   backend_->default_stream_executor(),
-                                   /*device_allocator=*/nullptr)
-                      .status();
-  compiler->RemoveModuleHook();
-  compiler->RemoveErrorHandler();
-  if (status.ok()) {
-    return errors;
-  }
-  return status;
-}
-
-Status XlaGpuOpt::CompileAndExpectErrors(const std::string& hlo_text,
-                                         llvm::raw_ostream& os,
-                                         LoweringStage breaking_stage) {
-  TF_ASSIGN_OR_RETURN(auto module, GetVerifiedHloModule(hlo_text));
-  TF_ASSIGN_OR_RETURN(
-      std::string errors,
-      CompileAndInjectErrors(std::move(module), breaking_stage));
-  os << errors;
-  return Status::OK();
-}
-
-StatusOr<std::unique_ptr<VerifiedHloModule>> XlaGpuOpt::GetVerifiedHloModule(
-    const std::string& hlo_text) {
-  HloModuleConfig config;
-  auto debug_options = GetDebugOptionsFromFlags();
-  debug_options.add_xla_disable_hlo_passes("constant_folding");
-  config.set_debug_options(debug_options);
-  auto module = absl::make_unique<VerifiedHloModule>(
-      "Module", config, /*verifier_layout_sensitive=*/true,
-      /*allow_mixed_precision_in_hlo_verifier=*/false,
-      /*shape_size_function=*/ShapeUtil::ByteSizeOfElements);
-  TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
-  return std::move(module);
-}
-
-MlirCompiler* XlaGpuOpt::GetMLIRCompiler() {
-  // TODO(b/137624192): Remove failover once no longer in place.
-  auto* failover = static_cast<FailoverCompiler*>(backend_->compiler());
-  return static_cast<MlirCompiler*>(failover->GetPrimary());
-}
-
-}  // namespace mlir_gpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h
deleted file mode 100644
index 6a46f92..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_XLA_GPU_OPT_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_XLA_GPU_OPT_H_
-
-#include <memory>
-#include <string>
-
-#include "llvm/Support/raw_ostream.h"
-#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
-#include "tensorflow/compiler/xla/status.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
-
-namespace xla {
-namespace mlir_gpu {
-
-// Prints the IR created by the MLIR GPU backend at a certain lowering stage.
-class XlaGpuOpt {
- public:
-  using LoweringStage = MlirCompiler::IRHook::LoweringStage;
-  XlaGpuOpt() {
-    backend_ = std::move(Backend::CreateDefaultBackend().ValueOrDie());
-  }
-
-  // Compiles the HLO module given in 'hlo_text' to a GpuExecutable and prints
-  // the IR at the lowering stage 'printing_stage' to the 'os' stream.
-  //
-  // This function invokes the JIT compiler.
-  Status CompileAndOutputIr(const std::string& hlo_text, llvm::raw_ostream& os,
-                            LoweringStage printing_stage = LoweringStage::LHLO);
-
-  // Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided
-  // lowering stage 'breaking_stage', parses and compiles `hlo_text`, and prints
-  // the resulting errors to the 'os' stream.
-  Status CompileAndExpectErrors(const std::string& hlo_text,
-                                llvm::raw_ostream& os,
-                                LoweringStage breaking_stage);
-
- private:
-  std::unique_ptr<Backend> backend_;
-  StatusOr<std::unique_ptr<VerifiedHloModule>> GetVerifiedHloModule(
-      const std::string& hlo_text_filename);
-
-  Status CompileAndOutputIr(std::unique_ptr<HloModule> hlo_module,
-                            llvm::raw_ostream& os,
-                            LoweringStage printing_stage);
-  Status CompileIr(std::unique_ptr<HloModule> hlo_module,
-                   const MlirCompiler::IRHook& ir_hook);
-  StatusOr<std::string> CompileIr(std::unique_ptr<HloModule> hlo_module,
-                                  LoweringStage printing_stage);
-  MlirCompiler::IRHook GetIRHookBreakingLoweringStage(
-      LoweringStage breaking_stage);
-  StatusOr<std::string> CompileAndInjectErrors(
-      std::unique_ptr<HloModule> hlo_module, LoweringStage breaking_stage);
-  MlirCompiler* GetMLIRCompiler();
-};
-
-}  // namespace mlir_gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_XLA_GPU_OPT_H_
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt_main.cc b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt_main.cc
deleted file mode 100644
index f60eea6..0000000
--- a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt_main.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright 2020 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <string>
-
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "mlir/Support/FileUtilities.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/init_mlir.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
-#include "tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h"
-#include "tensorflow/compiler/xla/status.h"
-#include "tensorflow/core/platform/logging.h"
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> input_filename(llvm::cl::Positional,
-                                                 llvm::cl::desc("<input file>"),
-                                                 llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> output_filename(
-    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
-    llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verify_errors(
-    "verify-errors",
-    llvm::cl::desc("Whether we expect errors which should be verified"),
-    llvm::cl::init(false));
-
-static llvm::cl::opt<xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage>
-    // NOLINTNEXTLINE
-    lowering_stage(
-        "lowering-stage",
-        llvm::cl::desc(
-            "The lowering stage up to which the compiler will be run"),
-        llvm::cl::values(
-            clEnumValN(xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::LHLO,
-                       "LHLO", "LHLO"),
-            clEnumValN(xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::GPU,
-                       "GPU", "GPU"),
-            clEnumValN(xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::LLVM,
-                       "LLVM", "LLVM"),
-            clEnumValN(
-                xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::KERNEL,
-                "KERNEL", "Kernel")),
-        llvm::cl::init(
-            xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::LHLO));
-
-int main(int argc, char **argv) {
-  tensorflow::InitMlir y(&argc, &argv);
-  mlir::registerPassManagerCLOptions();
-
-  llvm::cl::ParseCommandLineOptions(argc, argv,
-                                    "XLA GPU modular optimizer driver\n");
-
-  // Set up the input file.
-  std::string error_message;
-  auto file = mlir::openInputFile(input_filename, &error_message);
-  QCHECK(file) << error_message;
-
-  auto output = mlir::openOutputFile(output_filename, &error_message);
-  QCHECK(output) << error_message;
-
-  xla::mlir_gpu::XlaGpuOpt opt;
-  xla::Status status =
-      verify_errors ? opt.CompileAndExpectErrors(file->getBuffer().str(),
-                                                 output->os(), lowering_stage)
-                    : opt.CompileAndOutputIr(file->getBuffer().str(),
-                                             output->os(), lowering_stage);
-  if (!status.ok()) {
-    LOG(ERROR) << status.error_message();
-    return 1;
-  }
-  output->keep();
-  return 0;
-}
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index a6d23c1..c5d4d04 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -38,6 +38,7 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module_util.h"
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/service/source_map_util.h"
@@ -256,88 +257,16 @@
     absl::Span<const Shape* const> argument_shapes,
     const ExecutionOptions* execution_options,
     const AotCompilationOptions* aot_options) {
-  auto config = absl::make_unique<HloModuleConfig>(program_shape);
-  ComputationLayout* computation_layout =
-      config->mutable_entry_computation_layout();
-  const int64 argument_shapes_size = argument_shapes.size();
-  if (program_shape.parameters_size() != argument_shapes_size) {
-    return InvalidArgument("computation takes %d parameters, but %u given",
-                           program_shape.parameters_size(),
-                           argument_shapes.size());
-  }
-  for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
-    // Verify that shape of arguments matches the shape of the arguments in the
-    // ProgramShape.
-    if (!ShapeUtil::Compatible(*argument_shapes[i],
-                               program_shape.parameters(i))) {
-      return InvalidArgument(
-          "Argument does not match shape of computation parameter %d: want "
-          "%s, got %s",
-          i, ShapeUtil::HumanString(program_shape.parameters(i)),
-          ShapeUtil::HumanString(*argument_shapes[i]));
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-            *argument_shapes[i]));
-  }
-  if (execution_options != nullptr &&
-      execution_options->has_shape_with_output_layout()) {
-    const Shape shape_with_output_layout(
-        execution_options->shape_with_output_layout());
-    TF_RETURN_IF_ERROR(
-        ValidateResultShape(shape_with_output_layout, program_shape.result()));
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            shape_with_output_layout));
-  } else {
-    // If the result layout is not set, then choose the default.
-    computation_layout->mutable_result_layout()->SetToDefaultLayout();
-  }
-
-  if (execution_options != nullptr) {
-    if (execution_options->num_replicas() > 0) {
-      config->set_replica_count(execution_options->num_replicas());
-    } else {
-      config->set_replica_count(options_.number_of_replicas());
-    }
-    if (execution_options->num_partitions() > 0) {
-      config->set_num_partitions(execution_options->num_partitions());
-    }
-    config->set_use_spmd_partitioning(
-        execution_options->use_spmd_partitioning());
-    config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
-    config->set_seed(execution_options->seed());
-    config->set_launch_id(execution_options->launch_id());
-    config->set_debug_options(execution_options->debug_options());
-  } else {
-    config->set_replica_count(options_.number_of_replicas());
-    config->set_debug_options(GetDebugOptionsFromFlags());
-  }
-
+  int default_num_replicas = options_.number_of_replicas();
+  absl::optional<int> num_threads;
   if (execute_backend_ != nullptr &&
       execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
-    config->set_intra_op_parallelism_threads(
-        execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
+    num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
   }
 
-  if (execution_options != nullptr &&
-      execution_options->has_device_assignment()) {
-    TF_ASSIGN_OR_RETURN(
-        auto device_assignment,
-        DeviceAssignment::Deserialize(execution_options->device_assignment()));
-    config->set_static_device_assignment(*device_assignment);
-  }
-  config->set_alias_passthrough_params(
-      execution_options->alias_passthrough_params());
-
-  if (aot_options != nullptr &&
-      aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
-    config->set_fusion_config_collection(
-        aot_options->fusion_config_collection());
-    *config->mutable_fusion_config() = aot_options->fusion_config();
-  }
-
-  return std::move(config);
+  return xla::CreateModuleConfig(program_shape, argument_shapes,
+                                 execution_options, default_num_replicas,
+                                 num_threads, aot_options);
 }
 
 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
@@ -357,7 +286,7 @@
     const std::vector<const HloModuleProto*>& module_protos,
     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
-    se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
+    const Compiler::CompileOptions& options, bool run_backend_only) {
   VLOG(1) << StrFormat("BuildExecutable on service %p", this);
 
   // Dump computation proto state if flag is set.
@@ -387,17 +316,15 @@
 
   std::vector<std::unique_ptr<Executable>> executables;
   if (!run_backend_only) {
-    TF_ASSIGN_OR_RETURN(
-        executables,
-        backend->compiler()->Compile(std::move(module_group),
-                                     std::move(executors), device_allocator));
+    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
+                                         std::move(module_group),
+                                         std::move(executors), options));
   } else {
     auto modules = module_group->ConsumeModules();
     for (std::unique_ptr<HloModule>& module : modules) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<Executable> executable,
-          backend->compiler()->RunBackend(std::move(module), executors[0][0],
-                                          device_allocator));
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                          backend->compiler()->RunBackend(
+                              std::move(module), executors[0][0], options));
       executables.push_back(std::move(executable));
     }
   }
@@ -710,7 +637,7 @@
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                       BuildExecutables(module_protos, std::move(module_configs),
                                        execute_backend_.get(), all_executors,
-                                       /*device_allocator=*/nullptr));
+                                       {/*device_allocator=*/nullptr}));
   std::vector<Executable*> executable_ptrs;
   executable_ptrs.reserve(executables.size());
   for (const auto& executable : executables) {
@@ -744,20 +671,39 @@
   // basically the same thing.
   ExecutionProfile profile;
   std::vector<GlobalDataHandle> outputs;
+  Status execution_status = Status::OK();
+
   if (executable_ptrs.size() == 1) {
-    TF_ASSIGN_OR_RETURN(
-        auto output,
-        ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0],
-                                 execute_backend_.get(), device_handles[0],
-                                 computation_names[0], &profile));
-    outputs.push_back(std::move(output));
+    StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
+        executable_ptrs[0], all_arguments[0], execute_backend_.get(),
+        device_handles[0], computation_names[0], &profile);
+    if (output_or_status.ok()) {
+      outputs.push_back(std::move(output_or_status).ValueOrDie());
+    } else {
+      execution_status = output_or_status.status();
+    }
   } else {
-    TF_ASSIGN_OR_RETURN(
-        outputs, ExecuteParallelAndRegisterResult(
-                     executable_ptrs, all_arguments, execute_backend_.get(),
-                     device_handles, computation_names, &profile));
+    StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
+        ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
+                                         execute_backend_.get(), device_handles,
+                                         computation_names, &profile);
+    if (outputs_or_status.ok()) {
+      outputs = std::move(outputs_or_status).ValueOrDie();
+    } else {
+      execution_status = outputs_or_status.status();
+    }
   }
 
+  if (!execution_status.ok()) {
+    // Execution failed so we don't have the results.  Dump the HLO snapshot
+    // with just the program arguments.
+    for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
+      DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
+    }
+  }
+
+  TF_RETURN_IF_ERROR(execution_status);
+
   for (const GlobalDataHandle& output : outputs) {
     ExecuteResponse response;
     *response.mutable_output() = output;
@@ -810,7 +756,7 @@
 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const HloModuleProto& module_proto,
     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
-    se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
+    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
     bool run_backend_only) {
   VLOG(1) << StrFormat(
       "BuildExecutable on service %p with serialized module proto: %s", this,
@@ -822,14 +768,13 @@
   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
 
   if (!run_backend_only) {
-    TF_ASSIGN_OR_RETURN(
-        module, backend->compiler()->RunHloPasses(std::move(module), executor,
-                                                  device_allocator));
+    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
+                                    std::move(module), executor, options));
   }
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                      backend->compiler()->RunBackend(
-                          std::move(module), executor, device_allocator));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      backend->compiler()->RunBackend(std::move(module), executor, options));
 
   const auto& debug_opts = module_config->debug_options();
   if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
@@ -875,7 +820,7 @@
       BuildExecutable(arg->computation(), std::move(module_config),
                       execute_backend_.get(),
                       execute_backend_->default_stream_executor(),
-                      /*device_allocator=*/nullptr));
+                      {/*device_allocator=*/nullptr}));
 
   *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
 
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 712ccc4..02288bb 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -235,8 +235,7 @@
   StatusOr<std::unique_ptr<Executable>> BuildExecutable(
       const HloModuleProto& module_proto,
       std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
-      se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator = nullptr,
+      se::StreamExecutor* executor, const Compiler::CompileOptions& options,
       bool run_backend_only = false);
 
   // Same as BuildExecutable() above, but builds a list of Executables for the
@@ -245,8 +244,7 @@
       const std::vector<const HloModuleProto*>& module_protos,
       std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
       Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
-      se::DeviceMemoryAllocator* device_allocator,
-      bool run_backend_only = false);
+      const Compiler::CompileOptions& options, bool run_backend_only = false);
 
   // Runs the given executable with the given arguments and register the result
   // in the allocation tracker. The handle of the result from the tracker is
diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index 4dff4dc..af5471e 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -102,27 +102,8 @@
   }
 }
 
-// Returns a sharding where each tuple element is chosen as the more specific
-// one of the corresponding elements in a and b. Requires a and b to have the
-// same tuple nesting.
-HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
-                                         const HloSharding& b) {
-  if (a.IsTuple()) {
-    HloSharding result = a;
-    CHECK(b.IsTuple());
-    CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size());
-    for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
-      result.tuple_elements()[i] = MergeForMoreSpecificSharding(
-          a.tuple_elements()[i], b.tuple_elements()[i]);
-    }
-    return result;
-  }
-  return IsShardingMoreSpecific(a, b) ? a : b;
-}
-
 // Tries to refine `to_merge` by combining with `old`. Returns whether the final
-// `to_merge` is more specific than `old`. May combine partial sharding in
-// addition to MergeForMoreSpecificSharding().
+// `to_merge` is more specific than `old`.
 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
                    bool may_combine_partial_sharding) {
   if (old.IsTuple()) {
@@ -685,12 +666,13 @@
     return false;
   }
   // Propagate manual sharding. Avoid tuple shaped HLOs that group independent values
-  // together. Reduce and Sort can be tuples but the elements are correlated, so
-  // we propagate manual sharding through them.
+  // together. Reduce, ReduceWindow, and Sort can be tuples but the elements
+  // are correlated, so we propagate manual sharding through them.
   if (!instruction->has_sharding() &&
       (instruction->shape().IsArray() ||
        instruction->opcode() == HloOpcode::kReduce ||
-       instruction->opcode() == HloOpcode::kSort) &&
+       instruction->opcode() == HloOpcode::kSort ||
+       instruction->opcode() == HloOpcode::kReduceWindow) &&
       absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
         return op->has_sharding() && op->sharding().IsManual();
       })) {
@@ -887,6 +869,10 @@
                                              may_combine_partial_sharding);
     }
     case HloOpcode::kReduceWindow: {
+      if (instruction->shape().IsTuple()) {
+        // TODO(b/73062247): Variadic reduce window is not yet supported here.
+        return false;
+      }
       const HloInstruction* lhs = instruction->operand(0);
       if (!IsSpatiallyPartitioned(lhs)) {
         return false;
@@ -1093,8 +1079,8 @@
       }
       auto sharding = instruction->operand(0)->sharding();
       if (instruction->has_sharding()) {
-        sharding =
-            MergeForMoreSpecificSharding(sharding, instruction->sharding());
+        MergeSharding(instruction->sharding(), &sharding,
+                      may_combine_partial_sharding);
       }
       return MaybeImproveInstructionSharding(std::move(sharding), instruction,
                                              may_combine_partial_sharding);
@@ -1311,6 +1297,10 @@
       return user.sharding();
     }
     case HloOpcode::kReduceWindow: {
+      if (user.shape().IsTuple()) {
+        return user.sharding().GetSubSharding(
+            user.shape(), {user.operand_index(&instruction)});
+      }
       if (&instruction != user.operand(0)) {
         return absl::nullopt;
       }
@@ -1320,6 +1310,12 @@
       return hlo_sharding_util::ReshapeSharding(
           user.shape(), instruction.shape(), user.sharding());
     }
+    case HloOpcode::kPad: {
+      if (&instruction != user.operand(0)) {
+        return absl::nullopt;
+      }
+      return user.sharding();
+    }
     case HloOpcode::kSlice: {
       return user.sharding();
     }
@@ -1673,8 +1669,10 @@
   // If instruction is a while, or the root or a parameter of a while body,
   // then propagate its sharding to the while instruction, to its body root,
   // and to its condition parameter.
-  std::function<void(HloInstruction*)> maybe_computation_propagation =
-      [&](HloInstruction* instruction) {
+  std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
+      maybe_computation_propagation = [&](HloInstruction* instruction,
+                                          absl::flat_hash_set<HloInstruction*>*
+                                              changed) {
         auto propagate_to_instruction = [&](HloInstruction* search_inst) {
           auto related_instructions = get_related_instructions(search_inst);
           if (absl::c_count(related_instructions, instruction)) {
@@ -1683,7 +1681,8 @@
                   inst->sharding() != instruction->sharding()) {
                 VLOG(2) << "Add computation sharding: " << inst->name();
                 inst->set_sharding(instruction->sharding());
-                maybe_computation_propagation(inst);
+                changed->insert(inst);
+                maybe_computation_propagation(inst, changed);
               }
             }
           }
@@ -1785,6 +1784,14 @@
         for (const HloInstruction* instruction : instructions) {
           already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
         }
+        auto clear_cache = [&](HloInstruction* hlo) {
+          for (auto operand : hlo->operands()) {
+            already_inferred_from_users.erase(operand);
+          }
+          for (auto user : hlo->users()) {
+            already_inferred_from_operands.erase(user);
+          }
+        };
         // First iterate the HLO graph in post order taking shardings from
         // operands.
         for (HloInstruction* instruction : instructions) {
@@ -1799,12 +1806,11 @@
             any_changed = true;
             VLOG(2) << "Add sharding (forward-pass): "
                     << instruction->ToString();
-            maybe_computation_propagation(instruction);
-            for (auto operand : instruction->operands()) {
-              already_inferred_from_users.erase(operand);
-            }
-            for (auto user : instruction->users()) {
-              already_inferred_from_operands.erase(user);
+            absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
+            maybe_computation_propagation(instruction, &changed_in_comp_prop);
+            clear_cache(instruction);
+            for (auto hlo : changed_in_comp_prop) {
+              clear_cache(hlo);
             }
             changed_last_iter = true;
           }
@@ -1823,12 +1829,11 @@
             ++inferred_from_user_counter;
             any_changed = true;
             VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
-            maybe_computation_propagation(*it);
-            for (auto operand : (*it)->operands()) {
-              already_inferred_from_users.erase(operand);
-            }
-            for (auto user : (*it)->users()) {
-              already_inferred_from_operands.erase(user);
+            absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
+            maybe_computation_propagation(*it, &changed_in_comp_prop);
+            clear_cache(*it);
+            for (auto hlo : changed_in_comp_prop) {
+              clear_cache(hlo);
             }
             changed_last_iter = true;
           }
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
index ec83f99..1645e01 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -514,6 +514,26 @@
               op::Sharding("{devices=[2,2]0,1,2,3}"));
 }
 
+TEST_F(ShardingPropagationTest, PadBackwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+ENTRY %pad {
+  %input = f32[11,17]{1,0} parameter(0)
+  %copy = f32[11,17]{1,0} copy(%input)
+  %pad_value = f32[] parameter(1)
+  %pad = f32[27,51]{1,0} pad(%copy, %pad_value), padding=2_4_1x1_1_2,
+    sharding={devices=[2,2]0,1,2,3}
+  ROOT %result = f32[27,51]{1,0} copy(%pad)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "copy"),
+              op::Sharding("{devices=[2,2]0,1,2,3}"));
+}
+
 TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) {
   const char* const hlo_string = R"(
 HloModule module
@@ -856,40 +876,41 @@
 HloModule module
 
 %cond {
-  %vars.cond = (u32[], f32[10]{0}) parameter(0)
-  %count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0
+  %vars.cond = (u32[], f32[10,10]) parameter(0)
+  %count.cond = u32[] get-tuple-element((u32[], f32[10,10]) %vars.cond), index=0
   %limit = u32[] constant(10)
   ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT
 }
 
 %body {
-  %vars = (u32[], f32[10]{0}) parameter(0)
+  %vars = (u32[], f32[10,10]) parameter(0)
   %count = u32[] get-tuple-element(%vars), index=0
-  %acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1
+  %acc = f32[10,10] get-tuple-element((u32[], f32[10,10]) %vars), index=1
 
   %one = u32[] constant(1)
   %count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated}
-  %acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc)
-  ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1)
+  %acc.1 = f32[10,10] add(f32[10,10] %acc, f32[10,10] %acc)
+  ROOT %tuple = (u32[], f32[10,10]) tuple(u32[] %count.1, f32[10,10] %acc.1)
 }
 
 ENTRY %entry {
-  %p0 = f32[10]{0} parameter(0)
-  %p0.copy = f32[10]{0} copy(f32[10]{0} %p0)
-  %p1 = f32[10]{0} parameter(1)
+  %p0 = f32[10,10] parameter(0)
+  %p0.copy = f32[10,10] copy(f32[10,10] %p0)
+  %p1 = f32[10,10] parameter(1)
   %zero = u32[] constant(0)
-  %init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy)
-  %while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init),
+  %init = (u32[], f32[10,10]) tuple(u32[] %zero, f32[10,10] %p0.copy)
+  %while = (u32[], f32[10,10]) while((u32[], f32[10,10]) %init),
     body=%body, condition=%cond
-  %res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1
-  %prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1
-  %res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev)
-  ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1)
+  %res = f32[10,10] get-tuple-element((u32[], f32[10,10]) %while), index=1
+  %prev = f32[10,10] get-tuple-element((u32[], f32[10,10]) %init), index=1
+  %res.1 = f32[10,10] multiply(f32[10,10] %res, %prev)
+  ROOT %res_tuple = (f32[10,10]) tuple(f32[10,10] %res.1)
 })";
 
   auto while_is_sharded = [this](HloModule* module,
                                  const HloSharding& sharding) {
-    TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module));
+    TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                            ShardingPropagation(/*is_spmd=*/true).Run(module));
     EXPECT_TRUE(changed);
     auto while_instr = FindInstruction(module, "while");
     EXPECT_NE(nullptr, while_instr);
@@ -911,7 +932,7 @@
     auto body_root = FindInstruction(module.get(), "tuple");
     EXPECT_NE(nullptr, body_root);
     auto sharding =
-        ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie();
+        ParseSharding("{{replicated}, {devices=[2,1]0,1}}").ConsumeValueOrDie();
     body_root->set_sharding(sharding);
     while_is_sharded(module.get(), sharding);
   }
@@ -921,11 +942,30 @@
                             ParseAndReturnVerifiedModule(hlo_string));
     auto acc_1 = FindInstruction(module.get(), "acc.1");
     EXPECT_NE(nullptr, acc_1);
-    acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie());
+    acc_1->set_sharding(
+        ParseSharding("{devices=[2,1]0,1}").ConsumeValueOrDie());
 
-    while_is_sharded(
-        module.get(),
-        ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie());
+    while_is_sharded(module.get(),
+                     ParseSharding("{{replicated}, {devices=[2,1]0,1}}")
+                         .ConsumeValueOrDie());
+  }
+  {
+    // Merge partial sharding from operand and body.
+    TF_ASSERT_OK_AND_ASSIGN(auto module,
+                            ParseAndReturnVerifiedModule(hlo_string));
+    auto acc_1 = FindInstruction(module.get(), "acc.1");
+    EXPECT_NE(nullptr, acc_1);
+    acc_1->set_sharding(
+        ParseSharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")
+            .ConsumeValueOrDie());
+    auto p0 = FindInstruction(module.get(), "p0");
+    p0->set_sharding(
+        ParseSharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}")
+            .ConsumeValueOrDie());
+
+    while_is_sharded(module.get(),
+                     ParseSharding("{{replicated}, {devices=[2,2]0,1,2,3}}")
+                         .ConsumeValueOrDie());
   }
 }
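The device groups implied by the two partial shardings in the new case can be traced as follows (a worked illustration, not part of the test itself):

  // acc.1 : {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  //         rows of the f32[10,10] value split across {0,1} vs {2,3};
  //         columns replicated within each pair.
  // p0    : {devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}
  //         columns split across {0,2} vs {1,3}; rows replicated.
  // The two partial shardings are compatible, so propagation can merge them
  // into the fully tiled {{replicated}, {devices=[2,2]0,1,2,3}} on the while:
  // device 0 -> (row block 0, col block 0), 1 -> (0,1), 2 -> (1,0), 3 -> (1,1).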
 
diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.cc b/tensorflow/compiler/xla/service/slow_operation_alarm.cc
index 2ce66b2..13f6ac3 100644
--- a/tensorflow/compiler/xla/service/slow_operation_alarm.cc
+++ b/tensorflow/compiler/xla/service/slow_operation_alarm.cc
@@ -106,12 +106,19 @@
 
 SlowOperationAlarm::~SlowOperationAlarm() { UnscheduleAlarm(this); }
 
-std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() {
+std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(
+    absl::string_view msg) {
   // Pass a counter to these alarms so they only log once every power-of-two
   // occurrences.
   static auto* counter = new std::atomic<int64>(0);
 
   const char* separator = "\n********************************";
+
+  std::string msg_suffix;
+  if (!msg.empty()) {
+    msg_suffix = absl::StrCat("\n", msg);
+  }
+
 #if NDEBUG
   return absl::make_unique<SlowOperationAlarm>(
       absl::Duration(absl::Minutes(2)),
@@ -119,7 +126,7 @@
           separator,
           "\nVery slow compile?  If you want to file a bug, run with envvar "
           "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.",
-          separator),
+          msg_suffix, separator),
       counter);
 #else
   return absl::make_unique<SlowOperationAlarm>(
@@ -128,7 +135,7 @@
           separator,
           "\nSlow compile?  XLA was built without compiler optimizations, "
           "which can be slow.  Try rebuilding with -c opt.",
-          separator),
+          msg_suffix, separator),
       counter);
 #endif
 }
diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.h b/tensorflow/compiler/xla/service/slow_operation_alarm.h
index 20099bb..bd84591 100644
--- a/tensorflow/compiler/xla/service/slow_operation_alarm.h
+++ b/tensorflow/compiler/xla/service/slow_operation_alarm.h
@@ -22,6 +22,7 @@
 #include <tuple>
 
 #include "absl/base/attributes.h"
+#include "absl/strings/string_view.h"
 #include "absl/time/time.h"
 #include "tensorflow/compiler/xla/types.h"
 
@@ -64,7 +65,8 @@
 // In opt builds, recommends filing a bug.
 //
 // This is throttled to once-every-power-of-two occurrences, globally.
-ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm();
+ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(
+    absl::string_view msg = "");
 
 }  // namespace xla
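A minimal usage sketch of the new optional message parameter; the call site and message text below are hypothetical.

  // Hypothetical caller: attach extra context to the slow-compilation warning.
  {
    auto alarm = xla::SlowCompilationAlarm("while compiling the example cluster");
    // ... long-running compilation; if it outlives the alarm duration, the
    // logged warning now ends with the message passed above ...
  }  // The alarm is unscheduled when `alarm` is destroyed.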
 
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index e33888c..8f7cc1a 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -182,6 +182,10 @@
     return permute_dims[id];
   }
 
+  int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
+    return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
+  }
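A quick sketch of how ReverseDimLookUp relates to the existing DimLookUp; the permutation used here is illustrative.

  // DimLookUp(permute_dims, id)        returns permute_dims[id];
  // ReverseDimLookUp(permute_dims, id) returns the index at which `id` appears.
  std::vector<int64> permute_dims = {2, 0, 1};  // hypothetical permutation
  // DimLookUp(permute_dims, 1)        -> 0  (permute_dims[1] == 0)
  // ReverseDimLookUp(permute_dims, 1) -> 2  (the value 1 is stored at index 2)
  // i.e. ReverseDimLookUp computes the inverse of the mapping DimLookUp applies.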
+
   HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
       HloInstruction* instr, int64 depth);
 
@@ -215,9 +219,10 @@
   // Limit on batch size to apply this technique on.
   int64 limit_on_batch_size_;
 
-  // We choose the new batch size to be a constant so that space-to-batch
-  // propagation through several convolutional layers is consistent.
-  static constexpr int64 kNewBatchSize = 8;
+  // We choose the new batch size to be kNumSplits times that of the old batch
+  // so that space-to-batch propagation through several convolutional layers is
+  // consistent.
+  static constexpr int64 kNumSplits = 8;
 
   // Depth for searching reduce window
   static constexpr int64 kReduceWindowSearchDepth = 10;
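A worked numeric view of the kNumSplits constant above; the batch and spatial sizes are illustrative, not taken from a test.

  // With kNumSplits = 8, an old batch of 2 becomes a new batch of 2 * 8 = 16.
  // For a spatial size of 64, each split covers CeilOfRatio(64, 8) = 8 elements,
  // and the convolution is only treated as legal for space-to-batch when its
  // halo size does not exceed that per-split size.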
@@ -301,17 +306,12 @@
   if (old_batch_size > limit_on_batch_size_) {
     return false;
   }
-  // We currently only cater to evenly divisible cases.
-  if (kNewBatchSize % old_batch_size != 0) {
-    return false;
-  }
 
   VLOG(1) << "spatial size " << c.spatial_size;
 
-  const int64 num_splits = kNewBatchSize / old_batch_size;
   // If the ratio is not within the 2X range, we can't Halo Pad from the next
   // split.
-  if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
+  if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
     return false;
   }
   VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
@@ -323,6 +323,24 @@
     int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
     int64 high_padding, int64 halo_size, int64 original_split_dim_size,
     HloInstruction* pad_val) {
+  const int64 original_batch_size =
+      activations->shape().dimensions(activations_batch_dim) / kNumSplits;
+
+  if (original_batch_size > 1) {
+    std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
+                                      activations->shape().dimensions().end());
+    new_dimensions[activations_batch_dim] = kNumSplits;
+    new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
+                          original_batch_size);
+
+    // Split the combined batch dim into [original_batch_size, kNumSplits].
+    TF_ASSIGN_OR_RETURN(activations,
+                        MakeReshapeHlo(new_dimensions, activations));
+
+    spatial_dimension_to_split++;
+    activations_batch_dim++;
+  }
+
   const int64 rank = activations->shape().rank();
   const int64 spatial_split_size =
       activations->shape().dimensions(spatial_dimension_to_split);
@@ -415,6 +433,21 @@
     TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
                                                    spatial_dimension_to_split));
   }
+
+  if (original_batch_size > 1) {
+    std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
+                                      activations->shape().dimensions().end());
+    new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
+    new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);
+
+    // Merge the split batch dims back into [original_batch_size * kNumSplits].
+    TF_ASSIGN_OR_RETURN(activations,
+                        MakeReshapeHlo(new_dimensions, activations));
+
+    spatial_dimension_to_split++;
+    activations_batch_dim++;
+  }
+
   VLOG(1) << "HaloDuplicated activations " << activations->ToString();
   return activations;
 }
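At the shape level, the two reshapes added to HaloDuplicateWithSlice bracket the halo logic roughly as below; B, S and F are illustrative dimension sizes and the exact layout is a sketch, not taken from the change.

  // [B * kNumSplits, S, F]   combined batch, with original_batch_size B > 1
  //   -> reshape ->  [B, kNumSplits, S, F]   (batch/spatial dims shift by one)
  //   ... halo slices concatenated along the shifted spatial dimension ...
  //                  [B, kNumSplits, S + halo, F]
  //   -> reshape ->  [B * kNumSplits, S + halo, F]   batch dims merged back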
@@ -424,17 +457,20 @@
     HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
     int64& spatial_dimension_to_split, int64& activations_batch_dim,
     bool is_backprop) {
-  std::vector<int64> transpose_dims;
-  ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
-  if (spatial_dimension_to_split != activations_batch_dim + 1) {
+  std::vector<int64> transpose_dims(activations->shape().rank());
+  if (spatial_dimension_to_split == activations_batch_dim + 1) {
+    absl::c_iota(transpose_dims, 0);
+  } else {
+    ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
     int64 pushed_counter = 0;
     int64 new_batch_dim, new_spatial_dim;
+    int64 dim_counter = 0;
     for (int i = 0; i < activations->shape().rank(); ++i) {
       if (i == activations_batch_dim) {
         continue;
       }
       if (i == spatial_dimension_to_split) {
-        transpose_dims.push_back(activations_batch_dim);
+        transpose_dims[dim_counter++] = activations_batch_dim;
         new_batch_dim = pushed_counter;
         pushed_counter++;
         new_spatial_dim = pushed_counter;
@@ -452,7 +488,7 @@
           }
         }
       }
-      transpose_dims.push_back(i);
+      transpose_dims[dim_counter++] = i;
       pushed_counter++;
     }
 
@@ -460,14 +496,14 @@
     spatial_dimension_to_split = new_spatial_dim;
     TF_ASSIGN_OR_RETURN(activations,
                         MakeTransposeHlo(activations, transpose_dims));
-  }
 
-  if (is_backprop) {
-    new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
-  } else {
-    new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
+    if (is_backprop) {
+      new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
+    } else {
+      new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
+    }
+    dim_numbers = new_dim_numbers;
   }
-  dim_numbers = new_dim_numbers;
 
   return SpaceNextToBatchDetails{activations, transpose_dims};
 }
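One consequence of the rework above, spelled out for clarity: the already-adjacent case now returns an explicit identity permutation instead of an empty vector.

  // When spatial_dimension_to_split == activations_batch_dim + 1 already:
  //   transpose_dims = {0, 1, ..., rank - 1}   (absl::c_iota identity)
  // so the CHECK(!transpose_dims.empty()) calls added later in this file, and
  // any lookups over the returned permutation, stay well defined even when no
  // transpose is materialized.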
@@ -586,12 +622,23 @@
     VLOG(1) << "Checking if conv is supported for propagation "
             << consumer->ToString();
     if (IsConvSuitableForSpaceToBatch(consumer)) {
-      for (int64 i = 0; i < consumer->operand_count(); ++i) {
-        auto old_producer = consumer->mutable_operand(i);
-        if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
-          return false;
-        }
+      if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
+        return false;
       }
+      auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
+      // Make sure that the space dimension is the same across the producer
+      // and consumer.
+      if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
+              get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
+        return false;
+      }
+      // Make sure that the batch dimension is the same across the producer
+      // and consumer.
+      if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
+          dim_map_val_op_0.first) {
+        return false;
+      }
+
       return true;
     }
 
@@ -611,13 +658,35 @@
     VLOG(2) << "Checking for backprop filter conv operands "
             << consumer->operand_count();
 
-    if (!old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
+    auto activations = consumer->mutable_operand(0);
+    auto kernel = consumer->mutable_operand(1);
+
+    if (!old_to_new_instrs_.contains(kernel)) {
       VLOG(2) << "Backprop filter conv not ready for propagation because of "
                  "kernel is not space-to-batched";
       return false;
     }
 
-    if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
+    if (!old_to_new_instrs_.contains(activations)) {
+      const int64 lhs_batch = activations->shape().dimensions(
+          consumer->convolution_dimension_numbers().input_feature_dimension());
+      auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
+      const int64 old_batch_dim = dim_map_val_op_1.first;
+      auto second_operand = old_to_new_instrs_[kernel];
+      auto permute_dims_second_operand =
+          instr_to_dim_permute_map_[second_operand];
+      const int64 new_batch_dim =
+          DimLookUp(permute_dims_second_operand, old_batch_dim);
+      const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);
+
+      // Because we convert activations into a space-to-batched version only
+      // for backprop filter convolutions, make sure the batch dimensions
+      // (technically, feature dimensions) are the same size. Since the RHS is
+      // already space-to-batched, account for that as well.
+      if (rhs_batch != kNumSplits * lhs_batch) {
+        return false;
+      }
+
       // If activations have not been propagated through, we can do
       // space-to-batch on them provided kernel has been propagated.
       VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
@@ -625,10 +694,10 @@
       return true;
     }
 
-    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
-    auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
-    auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];
-    auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
+    auto first_operand = old_to_new_instrs_[activations];
+    auto dim_map_val_op_0 = instr_to_dim_map_[activations];
+    auto second_operand = old_to_new_instrs_[kernel];
+    auto dim_map_val_op_1 = instr_to_dim_map_[kernel];
 
     auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
     auto permute_dims_second_operand =
@@ -853,6 +922,11 @@
            absl::c_linear_search(reduce_dims, space_dim);
   }
 
+  if (consumer->opcode() == HloOpcode::kReduceWindow &&
+      consumer->shape().IsTuple()) {
+    // TODO(b/73062247): Variadic reduce window is not yet supported.
+    return false;
+  }
   if (consumer->opcode() == HloOpcode::kReduceWindow ||
       consumer->opcode() == HloOpcode::kSelectAndScatter) {
     auto first_operand = consumer->mutable_operand(0);
@@ -912,12 +986,14 @@
     auto dim_map_val = instr_to_dim_map_[producer];
     auto new_consumer = computation->AddInstruction(consumer->Clone());
 
+    bool is_pivot_producer_modified = false;
     // For elementwise binary ops, both of whose operands have been space-to-
     // batched, if their new spatial sizes don't match, choose the bigger one
     // as the producer.
     if (consumer->IsElementwiseBinary() &&
         old_to_new_instrs_.contains(consumer->mutable_operand(0)) &&
         old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
+      is_pivot_producer_modified = true;
       if (old_to_new_instrs_[consumer->mutable_operand(0)]
               ->shape()
               .dimensions() > old_to_new_instrs_[consumer->mutable_operand(1)]
@@ -969,7 +1045,8 @@
               pivot_new_instr->shape().dimensions(space_dim) * batch_size /
               old_batch_size;
 
-          CHECK_GT(pivot_space_size, new_dimensions[space_dim]);
+          CHECK(pivot_space_size > new_dimensions[space_dim] ||
+                !is_pivot_producer_modified);
 
           PaddingConfig padding_config =
               MakeNoPaddingConfig(reshape->shape().dimensions_size());
@@ -1119,7 +1196,7 @@
 
     Window new_win;
     for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
-      auto dim = DimLookUp(permute_dims, i);
+      auto dim = ReverseDimLookUp(permute_dims, i);
       new_win.add_dimensions();
       new_win.mutable_dimensions(i)->set_stride(
           consumer->window().dimensions(dim).stride());
@@ -1339,7 +1416,9 @@
   const int64 new_space_size = new_shape.dimensions(new_space_dim);
   const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
   const int64 old_space_size = old_shape.dimensions(old_space_dim);
-  CHECK_EQ(new_batch_size % old_batch_size, 0);
+  CHECK_EQ(new_batch_size % old_batch_size, 0)
+      << " New batch size " << new_batch_size << " old batch size "
+      << old_batch_size;
   const int64 num_splits = new_batch_size / old_batch_size;
   // Build a constant PRED to decide which elements in the split dimension
   // are from halo.
@@ -1394,8 +1473,10 @@
   CHECK(old_to_new_instrs_.contains(old_instr));
   auto new_instr = old_to_new_instrs_[old_instr];
   VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
-          << old_space_dim << " new_instr " << new_instr->ToString()
-          << " permute dims " << instr_to_dim_permute_map_.count(new_instr);
+          << old_space_dim << " old_instr " << old_instr->ToString()
+          << "\n new_instr " << new_instr->ToString() << " permute dims "
+          << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
+          << old_batch_size;
   CHECK(instr_to_dim_permute_map_.contains(new_instr));
   auto permute_dims = instr_to_dim_permute_map_[new_instr];
   const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
@@ -1565,6 +1646,7 @@
                           c.spatial_dimension_to_split, activations_batch_dim));
   activations_new = retval.instr;
   std::vector<int64> trans_dims = retval.transpose_dims;
+  CHECK(!trans_dims.empty());
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
 
@@ -1578,8 +1660,7 @@
 
   VLOG(1) << "spatial size " << c.spatial_size;
 
-  const int64 num_splits = kNewBatchSize / old_batch_size;
-
+  const int64 num_splits = kNumSplits;
   const int64 output_offsets = convolution->shape().dimensions(
       permuted_conv_dims_numbers.output_spatial_dimensions(
           get_chosen_spatial_dim(convolution)));
@@ -1614,6 +1695,8 @@
         activations_new->shape().dimensions().end());
     const int64 reshaped_space_size =
         new_space_size * new_batch_size / old_batch_size;
+    VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
+            << new_batch_size << " old_batch_size " << old_batch_size;
     new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
     new_dimensions[activations_batch_dim] = old_batch_size;
 
@@ -1621,10 +1704,12 @@
     TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
                         MakeReshapeHlo(new_dimensions, activations_new));
 
+    VLOG(3) << "First reshape done";
     PaddingConfig padding_config =
         MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
     padding_config.mutable_dimensions(c.spatial_dimension_to_split)
-        ->set_edge_padding_high(spatial_split_size * new_batch_size -
+        ->set_edge_padding_high(spatial_split_size * new_batch_size /
+                                    old_batch_size -
                                 reshaped_space_size);
     padding_config.mutable_dimensions(c.spatial_dimension_to_split)
         ->set_edge_padding_low(0);
@@ -1647,6 +1732,8 @@
         reshaped_activations,
         MakeReshapeHlo(reshape_back_dims, reshaped_activations));
 
+    VLOG(3) << "Second reshape done";
+
     TF_ASSIGN_OR_RETURN(
         activations_new,
         HaloDuplicateWithSlice(
@@ -1664,6 +1751,7 @@
     // additional space available, and adjust the required slice size (and
     // thereby the halo size).
     if (spatial_split_size < new_space_size) {
+      VLOG(3) << "Decreasing the spatial size while propagating";
       const int64 additional_space_present = spatial_split_size % c.stride;
       spatial_split_size = new_space_size;
       slice_size =
@@ -1758,6 +1846,7 @@
 
   activations = retval.instr;
   std::vector<int64> transpose_dims = retval.transpose_dims;
+  CHECK(!transpose_dims.empty());
   // Because we are splitting the spatial dimension, if convolution needed
   // padding in the spatial dimension, we materialize it.
   if (high_padding || low_padding) {
@@ -1774,7 +1863,9 @@
                         MakePadHlo(activations, padding, padding_config));
   }
   VLOG(1) << "Initial padded activations shape "
-          << activations->shape().ToString();
+          << activations->shape().ToString() << " old_batch_size "
+          << old_batch_size << " activations_batch_dim "
+          << activations_batch_dim;
 
   // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
   // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
@@ -1829,7 +1920,10 @@
   CHECK(old_to_new_instrs_.contains(kernel_old));
   auto kernel_new = old_to_new_instrs_[kernel_old];
 
+  auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+
   HloInstruction* activations_new = nullptr;
+  bool activations_locally_space_to_batched = false;
   // If activations were not space-to-batched, we space-to-batch them below.
   if (!old_to_new_instrs_.contains(activations_old)) {
     VLOG(1) << "Space-to-batching activations to enable space-to-depth";
@@ -1838,28 +1932,34 @@
     instr_to_dim_map_[activations_old] =
         std::make_pair(prev_feature_dim, prev_batch_dim);
 
-    int64 activations_batch_dim = original_conv_dims.input_feature_dimension();
-    const int64 old_batch_size =
-        activations_old->shape().dimensions(activations_batch_dim);
-    const int64 num_splits = kNewBatchSize / old_batch_size;
+    const int64 new_kernel_space_dim =
+        DimLookUp(permute_dims_kernel, kernel_space_dim);
+
     const int64 new_kernel_split_dim_size =
-        kernel_new->shape().dimensions(kernel_space_dim);
+        kernel_new->shape().dimensions(new_kernel_space_dim);
     const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
     const int64 pad_size =
-        needed_spatial_size * num_splits - old_split_dim_size;
+        needed_spatial_size * kNumSplits - old_split_dim_size;
     ConvolutionDimensionNumbers tmp_dim_numbers;
     tmp_dim_numbers = original_conv_dims;
     TF_ASSIGN_OR_RETURN(
         auto retval,
         SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
-                   activations_batch_dim,
+                   old_batch_dim,
                    /*high_padding=*/pad_size, /*low_padding=*/0,
-                   needed_spatial_size, num_splits, /*is_backprop=*/true));
+                   needed_spatial_size, kNumSplits, /*is_backprop=*/true));
 
     old_to_new_instrs_[activations_old] = retval.first;
-    instr_to_dim_permute_map_[retval.first] = retval.second;
 
-    VLOG(3) << "Edited conv dims " << original_conv_dims.DebugString();
+    std::vector<int64> reversed_transpose_dims(retval.second.size());
+    for (int64 i = 0; i < retval.second.size(); ++i) {
+      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
+    }
+    instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
+
+    VLOG(3) << "New Activations " << retval.first->ToString();
+
+    activations_locally_space_to_batched = true;
   }
 
   CHECK(old_to_new_instrs_.contains(activations_old));
@@ -1884,7 +1984,7 @@
         i, DimLookUp(permute_dims,
                      original_conv_dims.input_spatial_dimensions(i)));
     permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
-        i, DimLookUp(permute_dims,
+        i, DimLookUp(permute_dims_kernel,
                      original_conv_dims.kernel_spatial_dimensions(i)));
   }
 
@@ -1905,10 +2005,11 @@
       previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
 
   const int64 kernel_input_feature_dim = DimLookUp(
-      permute_dims, original_conv_dims.kernel_input_feature_dimension());
+      permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
 
-  const int64 kernel_output_feature_dim = DimLookUp(
-      permute_dims, original_conv_dims.kernel_output_feature_dimension());
+  const int64 kernel_output_feature_dim =
+      DimLookUp(permute_dims_kernel,
+                original_conv_dims.kernel_output_feature_dimension());
 
   permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
       kernel_input_feature_dim);
@@ -1931,7 +2032,8 @@
 
   VLOG(1) << "Propagating on conv activations_batch_dim "
           << activations_batch_dim << " spatial_dimension_to_split "
-          << spatial_dimension_to_split << " old_batch_size " << old_batch_size;
+          << spatial_dimension_to_split << " old_batch_size " << old_batch_size
+          << " new_split_dim_size " << new_split_dim_size;
 
   TF_ASSIGN_OR_RETURN(
       auto retval,
@@ -1939,6 +2041,7 @@
                             spatial_dimension_to_split, activations_batch_dim,
                             /*is_backprop=*/true));
   std::vector<int64> transpose_dims = retval.transpose_dims;
+  CHECK(!transpose_dims.empty());
   activations_new = retval.instr;
 
   VLOG(1) << "Activations_new post BringSpaceNextToBatch "
@@ -1949,13 +2052,15 @@
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
 
-  // Select activations correctly by masking additional space.
-  TF_ASSIGN_OR_RETURN(
-      activations_new,
-      SelectValidPortion(activations_new, activations_old, select_val,
-                         activations_batch_dim, spatial_dimension_to_split,
-                         old_batch_dim, old_space_dim));
-
+  if (!activations_locally_space_to_batched) {
+    // Select activations correctly by masking additional space.
+    TF_ASSIGN_OR_RETURN(
+        activations_new,
+        SelectValidPortion(activations_new, activations_old, select_val,
+                           activations_batch_dim, spatial_dimension_to_split,
+                           old_batch_dim, old_space_dim));
+  }
+  VLOG(3) << "Selecting the valid kernel area";
   // Select kernel correctly by masking additional space.
   TF_ASSIGN_OR_RETURN(
       kernel_new,
@@ -2238,7 +2343,6 @@
 
   VLOG(1) << "spatial size " << c.spatial_size;
 
-  const int64 num_splits = kNewBatchSize / old_batch_size;
   auto original_conv = convolution;
 
   const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
@@ -2246,13 +2350,13 @@
   const int64 output_offsets =
       convolution->shape().dimensions(output_spatial_dim);
   const int64 output_offsets_per_split =
-      CeilOfRatio(output_offsets, num_splits);
+      CeilOfRatio(output_offsets, kNumSplits);
 
   int64 spatial_split_size =
       CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
   // Keep increasing the split size so that overall size isn't smaller than the
   // original spatial dimension.
-  while (spatial_split_size * num_splits - c.spatial_size < 0) {
+  while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
     spatial_split_size += c.stride;
   }
 
@@ -2276,12 +2380,12 @@
   const int64 slice_size = spatial_split_size + c.halo_size;
 
   // Pad spatial dim.
-  const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
+  const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;
 
   VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
           << c.stride << " slice_size " << slice_size;
   VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
-          << " num_splits " << num_splits << " kernel_spatial_dim_size "
+          << " num_splits " << kNumSplits << " kernel_spatial_dim_size "
           << c.kernel_spatial_dim_size;
   int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
   TF_ASSIGN_OR_RETURN(
@@ -2292,7 +2396,7 @@
                  /*low_padding=*/c.base_dilation_factor == 1
                      ? c.inherent_low_padding
                      : 0,
-                 spatial_split_size, num_splits));
+                 spatial_split_size, kNumSplits));
   HloInstruction* batch_increased_reshape = retval.first;
   convolution->SetupDerivedInstruction(batch_increased_reshape);
 
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 3bca043..b36c646 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -417,6 +417,10 @@
     if (try_reshard.has_value()) {
       return try_reshard.value();
     }
+    try_reshard = ReshardPartialReplicateWithAllToAll(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
+    }
   }
 
   if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
@@ -424,6 +428,10 @@
     if (try_reshard.has_value()) {
       return try_reshard.value();
     }
+    try_reshard = ReshardPartialReplicateWithAllToAll(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
+    }
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -1216,6 +1224,92 @@
       .ReshardWithAllToAll(target, remaining_source_target_dims);
 }
 
+absl::optional<PartitionedHlo>
+PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
+  bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
+  const auto& partial_replicate_sharding =
+      source_is_partial_replicate ? sharding() : target;
+  // If neither the source nor the target is partial replicate, return null.
+  if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
+    return absl::nullopt;
+  }
+  const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
+  // If both source and target are partially replicated, this case should
+  // already be handled by the existing Reshard-with-AllToAll path.
+  if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
+    return absl::nullopt;
+  }
+
+  // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
+  // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
+  // the last tile dim will be replicated first, before the all-to-all.
+  // Or resharding from
+  // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+  // to sharding={devices=[2,3]0,1,2,3,4,5}, where
+  // the last tile dim will be sharded after all-to-all.
+  const int num_replicas =
+      partial_replicate_sharding.tile_assignment().dimensions().back();
+  if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
+       partial_replicate_sharding.tile_assignment().num_dimensions()) ||
+      (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
+    return absl::nullopt;
+  }
+  int to_replicate_dim = -1;
+  for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
+       --i) {
+    if (tile_sharding.tile_assignment().dim(i) > 1 &&
+        (to_replicate_dim == -1)) {
+      if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
+        return absl::nullopt;
+      }
+      to_replicate_dim = i;
+    }
+
+    if (tile_sharding.tile_assignment().dim(i) !=
+        partial_replicate_sharding.tile_assignment().dim(i + 1)) {
+      return absl::nullopt;
+    }
+  }
+
+  if (to_replicate_dim == -1) {
+    return absl::nullopt;
+  }
+
+  // Check if core assignments for source and the target are the same.
+  auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
+  reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
+  if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
+    return absl::nullopt;
+  }
+
+  auto tmp_tile_assignment = tile_sharding.tile_assignment();
+  auto tmp_tile_assignment_dimensions =
+      tile_sharding.tile_assignment().dimensions();
+  tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
+  tmp_tile_assignment_dimensions.push_back(num_replicas);
+  tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
+  auto tmp_partial_replicate_sharding =
+      HloSharding::PartialTile(tmp_tile_assignment);
+
+  if (source_is_partial_replicate) {
+    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
+            sharding(), tmp_partial_replicate_sharding)) {
+      auto partitioned_hlo =
+          ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
+      return partitioned_hlo.Reshard(target);
+    }
+  } else {
+    auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
+
+    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
+            partitioned_hlo.sharding(), target)) {
+      return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
+    }
+  }
+
+  return absl::nullopt;
+}
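A compact trace of the case described in the comments above, using the same device counts as the new partitioner tests; the intermediate sharding is derived from the code and is illustrative.

  // Source : {devices=[2,3]0,1,2,3,4,5}                           (tiled)
  // Target : {devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  //   num_replicas     = 3  (last dim of the partially replicated assignment)
  //   to_replicate_dim = 1  (last tiled dim whose size equals num_replicas)
  //   tmp sharding     = {devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}
  // The value is first resharded to the tmp sharding; an all-to-all then moves
  // the remaining sharded dimension (dim 0 in tmp vs dim 1 in the target).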
+
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
     const HloSharding& target) const {
   CHECK(CanReshardWithCollectivePermute(sharding(), target))
@@ -1422,9 +1516,25 @@
   // temp_output_shape is the output shape where the concatenate dimension
   // is changed to the full (and padded to shard count) dimension size.
   auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
+  auto last_operand_padded_shape =
+      MakePartitionedShape(hlo->operands().back()->shape(), sharding);
+  // If the last operand needs more padding than the temp_output, add extra
+  // padding so the dynamic-update-slice stays in bounds.
+  int last_operand_padding =
+      last_operand_padded_shape.dimensions(dimension) *
+          sharding.tile_assignment().dim(dimension) -
+      hlo->operands().back()->shape().dimensions(dimension);
+  int temp_output_padding = temp_output_shape.dimensions(dimension) *
+                                sharding.tile_assignment().dim(dimension) -
+                            hlo->shape().dimensions(dimension);
+  int padding_for_last_operand =
+      last_operand_padding < temp_output_padding
+          ? 0
+          : last_operand_padding - temp_output_padding;
   temp_output_shape.set_dimensions(
       dimension, temp_output_shape.dimensions(dimension) *
-                     sharding.tile_assignment().dim(dimension));
+                         sharding.tile_assignment().dim(dimension) +
+                     padding_for_last_operand);
   auto temp_output = CreateZero(temp_output_shape, &b_);
 
   // Offset of each operand along the concatenate dimension.
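A numeric sketch of the extra-padding computation above, assuming MakePartitionedShape rounds each dimension up to a multiple of the shard count; the sizes are illustrative.

  // 4 shards along the concatenate dimension, output size 8, last operand size 3:
  //   last_operand_padding     = CeilOfRatio(3, 4) * 4 - 3 = 1
  //   temp_output_padding      = CeilOfRatio(8, 4) * 4 - 8 = 0
  //   padding_for_last_operand = (1 < 0) ? 0 : 1 - 0       = 1
  // so the temp output dimension becomes 2 * 4 + 1 = 9, keeping the final
  // dynamic-update-slice of the last operand in bounds.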
@@ -3399,6 +3509,10 @@
 }
 
 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
+  // TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
+  if (hlo->shape().IsTuple()) {
+    return DefaultAction(hlo);
+  }
   auto& operand = GetPartitionedHlo(hlo->operand(0));
   if (hlo->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index d5a2efd..d77fd7e 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -338,6 +338,10 @@
   absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
       const HloSharding& target);
 
+  // Helper function to reshard from partial replicate using AllToAll.
+  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
+      const HloSharding& target);
+
   // SPMD instruction.
   HloInstruction* hlo_;
 
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index cac7694..52bd709 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -5542,6 +5542,55 @@
   EXPECT_THAT(root, partially_replicated);
 }
 
+TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[2,3]0,1,2,3,4,5}
+  ROOT %copy0 = f32[8,8] copy(%param0),
+    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/6));
+  VLOG(1) << module->ToString();
+  auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
+  auto partially_replicated = AllOf(
+      op::Shape("f32[8,4]"),
+      op::Copy(op::Reshape(
+          op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
+              op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%param0),
+    sharding={devices=[2,3]0,1,2,3,4,5}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/6));
+  VLOG(1) << module->ToString();
+  auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
+  auto tiled = AllOf(
+      op::Shape("f32[4,3]"),
+      op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
+                                            op::Reshape(partial_replicated)))),
+                                        _),
+                                _, _)));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
 TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
   const char* const hlo_string = R"(
 HloModule module
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7080e44..dffdeff 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -367,6 +367,7 @@
         "conv_depthwise_test.cc",
     ],
     shard_count = 50,
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":conv_depthwise_common",
         ":test_macros_header",
@@ -388,6 +389,7 @@
     timeout = "long",
     srcs = ["conv_depthwise_backprop_filter_test.cc"],
     shard_count = 40,
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:execution_options_util",
@@ -412,6 +414,7 @@
         "cpu",
     ],
     shard_count = 50,
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":client_library_test_base",
         ":hlo_test_base",
@@ -626,7 +629,6 @@
     name = "conditional_test",
     srcs = ["conditional_test.cc"],
     shard_count = 2,
-    tags = ["no_rocm"],
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -665,7 +667,6 @@
     name = "scalar_computations_test",
     srcs = ["scalar_computations_test.cc"],
     shard_count = 32,
-    tags = ["no_rocm"],
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:literal",
@@ -923,7 +924,7 @@
     srcs = ["dot_operation_test.cc"],
     shard_count = 20,
     tags = [
-        "no_rocm",
+        "no_rocm",  # ROCm 3.9 regression
         "optonly",
     ],
     deps = [
@@ -957,7 +958,7 @@
     backends = ["gpu"],
     shard_count = 20,
     tags = [
-        "no_rocm",
+        "no_rocm",  # ROCm 3.9 regression
         "optonly",
         # TODO(b/151340488): Timed out on 2020-03-12.
         "nozapfhahn",
@@ -1024,7 +1025,7 @@
     },
     shard_count = 20,
     tags = [
-        "no_rocm",
+        "no_rocm",  # ROCm 3.9 regression
         "optonly",
     ],
     deps = [
@@ -1252,6 +1253,7 @@
         "cpu": ["nomsan"],
     },
     shard_count = 30,
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:array3d",
@@ -1276,6 +1278,7 @@
     timeout = "long",
     srcs = ["convolution_dimension_numbers_test.cc"],
     shard_count = 20,
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:array4d",
@@ -1514,7 +1517,6 @@
     srcs = ["reduce_test.cc"],
     shard_count = 31,
     tags = [
-        "no_rocm",
         "optonly",
     ],
     deps = [
@@ -1594,7 +1596,6 @@
     timeout = "long",
     srcs = ["select_and_scatter_test.cc"],
     tags = [
-        "no_rocm",
         "nozapfhahn",
         "optonly",
     ],
@@ -2098,7 +2099,10 @@
     name = "dynamism_inference_test",
     srcs = ["dynamism_inference_test.cc"],
     deps = [
+        ":literal_test_util",
         ":test_macros_header",
+        ":test_utils",
+        ":xla_internal_test_main",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -2109,10 +2113,8 @@
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/client/lib:prng",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "@com_google_absl//absl/strings",
@@ -2320,6 +2322,7 @@
     name = "multioutput_fusion_test",
     srcs = ["multioutput_fusion_test.cc"],
     backends = ["gpu"],
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index 97c0f33..adeb83d 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -738,6 +738,33 @@
                                          results[3]);
 }
 
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    id = u32[] replica-id()
+    id2 = u32[1, 2] broadcast(id), dimensions={}
+    a0 = u32[1, 2] constant({{10, 15}})
+    a1 = u32[1, 2] add(id2, a0)
+    allgather = u32[4, 2] all-gather(a1), dimensions={0}
+    ROOT out = u32[8] reshape(allgather)
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (const Literal& result : results) {
+    LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
+                                           result);
+  }
+}
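For reference, a short trace of how the expected literal in the test above arises; this only restates the replica arithmetic already in the test.

  // Replica r computes a1 = [[10 + r, 15 + r]]                      : u32[1,2]
  // all-gather over dimension 0 across 4 replicas gives
  //   [[10,15], [11,16], [12,17], [13,18]]                          : u32[4,2]
  // and the reshape flattens it to {10, 15, 11, 16, 12, 17, 13, 18} : u32[8]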
+
 XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
   std::string hlo_string = R"(
     HloModule test
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 72f2708..690b657 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -1210,7 +1210,7 @@
           .ValueOrDie(),
       &builder);
   auto config = std::get<2>(GetParam());
-  if (config.find(",") == config.npos) {
+  if (config.find(',') == config.npos) {
     Einsum(x, config);
   } else {
     Einsum(x, y, config);
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 96ba73a..892fdb8 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -20,6 +20,7 @@
 #include "absl/strings/match.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -162,6 +163,21 @@
   }
 }
 
+TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
+    auto zero = ConstantR0<int32>(&b, 0);
+    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
+    auto reduce = Reduce(p, zero, add_s32, {0});
+    auto pred = Eq(c, reduce);
+    auto result = Select(pred, reduce, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
@@ -317,5 +333,27 @@
   }
 }
 
+TEST_F(DynamismInferenceTest, InferThroughPad) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    // Test the analysis on a pad.
+    auto operand1 = ConstantR1<int32>(&b, {1, 2});
+    auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
+    PaddingConfig padding_config;
+    padding_config.add_dimensions()->set_edge_padding_high(1);
+    // After pad the value is [constant, constant, parameter].
+    auto pad = Pad(operand1, parameter, padding_config);
+    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+    // The first two elements are constant; the padded-in value is dynamic.
+    EXPECT_FALSE(
+        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
+    EXPECT_FALSE(
+        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
+    EXPECT_TRUE(
+        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
+  }
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
index 1868159..9b397dc 100644
--- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
+++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
@@ -22,6 +22,9 @@
 namespace xla {
 namespace {
 
+using ::testing::StartsWith;
+using ::testing::StrEq;
+
 class HloMetadataTest : public LocalClientTestBase {
  protected:
   HloMetadataTest() {
@@ -79,9 +82,8 @@
                          ->module()
                          .entry_computation()
                          ->root_instruction();
-  // We expect these to be empty (no metadata set).
-  EXPECT_EQ("", instruction->metadata().op_type());
-  EXPECT_EQ("", instruction->metadata().op_name());
+  EXPECT_THAT(instruction->metadata().op_type(), StrEq(""));
+  EXPECT_THAT(instruction->metadata().op_name(), StartsWith("DUMMY"));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 0ddb01f..49e7560 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -57,7 +57,8 @@
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) {
     if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
     }
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
index c2dc912..b4d8d3c 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -25,6 +25,20 @@
 
 namespace xla {
 
+namespace {
+
+void RemoveDummyMetadataNames(HloModule* module) {
+  for (xla::HloComputation* computation : module->computations()) {
+    for (xla::HloInstruction* instruction : computation->instructions()) {
+      if (absl::StartsWith(instruction->metadata().op_name(), "DUMMY")) {
+        instruction->set_metadata_op_name("");
+      }
+    }
+  }
+}
+
+}  // namespace
+
 void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) {
   auto llvm_compiler = GetLLVMCompiler();
   using std::placeholders::_1;
@@ -56,7 +70,7 @@
 
   StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
   TF_ASSERT_OK(filecheck_result.status());
-  EXPECT_TRUE(filecheck_result.ValueOrDie());
+  EXPECT_TRUE(filecheck_result.ValueOrDie()) << "Full IR: " << ir_;
 }
 
 void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
@@ -80,7 +94,7 @@
 
   StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
   ASSERT_TRUE(filecheck_result.ok());
-  EXPECT_TRUE(filecheck_result.ValueOrDie());
+  EXPECT_TRUE(filecheck_result.ValueOrDie()) << "Full IR: " << ir_;
 }
 
 void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo,
@@ -88,6 +102,7 @@
                                           bool print_operand_shape) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
                           GetOptimizedModule(hlo));
+  RemoveDummyMetadataNames(optimized_module.get());
   HloPrintOptions print_opts;
   print_opts.set_print_operand_shape(print_operand_shape);
   StatusOr<bool> filecheck_result =
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 7e5b699..d86ebfa 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1704,5 +1704,114 @@
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
 }
 
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = f32[] parameter(1) 
+  b0 = f32[] parameter(2)
+  b1 = f32[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = f32[] add(a1, b1)
+  ROOT sum2 = (f32[], f32[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = f32[] constant(0)
+  reduce-window = (f32[2,2]{1,0}, f32[2,2]{1,0}) 
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+  ROOT copy = (f32[2,2]{1,0}, f32[2,2]{1,0}) copy(reduce-window)
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport2))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = s32[] parameter(1) 
+  b0 = f32[] parameter(2)
+  b1 = s32[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = s32[] add(a1, b1)
+  ROOT sum2 = (f32[], s32[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = s32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = s32[] constant(0)
+  ROOT reduce-window = (f32[2,2]{1,0}, s32[2,2]{1,0}) 
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport3))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = bf16[] parameter(1) 
+  b0 = f32[] parameter(2)
+  b1 = bf16[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = bf16[] add(a1, b1)
+  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = bf16[] constant(0)
+  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0}) 
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport4))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = bf16[] parameter(1) 
+  b0 = f32[] parameter(2)
+  b1 = bf16[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = bf16[] multiply(a1, b1)
+  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = bf16[] constant(1)
+  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0}) 
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index aa02deb..4b64ef4 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -94,7 +94,6 @@
     ],
 )
 
-# To run with MLIR GPU plugin enabled, pass --define=with_mlir_gpu_support=true.
 tf_cc_binary(
     name = "replay_computation_gpu",
     tags = ["gpu"],
@@ -328,7 +327,6 @@
     ],
 )
 
-# To run with MLIR GPU plugin enabled, pass --define=with_mlir_gpu_support=true.
 tf_cc_binary(
     name = "run_hlo_module",
     testonly = True,
diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.cc b/tensorflow/compiler/xla/tools/run_hlo_module.cc
index be9b23e..015a381 100644
--- a/tensorflow/compiler/xla/tools/run_hlo_module.cc
+++ b/tensorflow/compiler/xla/tools/run_hlo_module.cc
@@ -117,7 +117,8 @@
     const RunHloModuleOptions& options,
     std::function<Status(const HloModule&,
                          const ::stream_executor::Platform::Id&, HloModule*)>
-        reference_module_modifier_hook) {
+        reference_module_modifier_hook,
+    std::function<void(HloModuleConfig*)> config_modifier_hook) {
   se::Platform* test_platform =
       xla::PlatformUtil::GetPlatform(test_platform_name).ValueOrDie();
   se::Platform* reference_platform =
@@ -125,11 +126,15 @@
           ? nullptr
           : xla::PlatformUtil::GetPlatform(reference_platform_name)
                 .ValueOrDie();
-  auto config_modifier = [](HloModuleConfig* config) { config->set_seed(42); };
+  if (!config_modifier_hook) {
+    config_modifier_hook = [](HloModuleConfig* config) {
+      config->set_seed(42);
+    };
+  }
 
   std::unique_ptr<HloModule> test_module =
       LoadModuleFromFile(hlo_filename, hlo_module_loader_details::Config(),
-                         options.input_format, config_modifier)
+                         options.input_format, config_modifier_hook)
           .ValueOrDie();
   const HloModuleProto test_module_proto = test_module->ToProto();
 
@@ -148,10 +153,10 @@
   if (reference_platform != nullptr) {
     // PrepareReferenceModule needs to know the *test* platform, in order to
     // properly match the test platform's numerics.
-    reference_module =
-        PrepareReferenceModule(*test_module, test_platform->id(),
-                               config_modifier, reference_module_modifier_hook)
-            .ConsumeValueOrDie();
+    reference_module = PrepareReferenceModule(*test_module, test_platform->id(),
+                                              config_modifier_hook,
+                                              reference_module_modifier_hook)
+                           .ConsumeValueOrDie();
   }
 
   Literal test_result = ExecuteOnPlatform(
diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.h b/tensorflow/compiler/xla/tools/run_hlo_module.h
index 57f81cc..8dc720c 100644
--- a/tensorflow/compiler/xla/tools/run_hlo_module.h
+++ b/tensorflow/compiler/xla/tools/run_hlo_module.h
@@ -68,7 +68,8 @@
     const RunHloModuleOptions& options,
     std::function<Status(const HloModule&,
                          const ::stream_executor::Platform::Id&, HloModule*)>
-        reference_module_modifier_hook = {});
+        reference_module_modifier_hook = {},
+    std::function<void(HloModuleConfig*)> config_modifier_hook = {});
 
 }  // namespace xla
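A hedged sketch of a caller-supplied hook for the new parameter; the lambda body is hypothetical (set_seed mirrors the default that run_hlo_module installs when no hook is given).

  // Hypothetical hook passed by a caller that wants the default seeding:
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };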
 
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 13897c4..44a5bf4 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -32,7 +32,6 @@
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
-#include "llvm/ADT/SmallVector.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -579,11 +578,6 @@
 // range that is available in F32s (out of a total of 11 exponent bits in F64s).
 std::pair<float, float> SplitF64ToF32(double x);
 
-template <typename T>
-std::vector<T> ToStdVector(const llvm::SmallVectorImpl<T>& v) {
-  return std::vector<T>(v.begin(), v.end());
-}
-
 // MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
 // destructor. The easiest way to use MakeCleanup is with a lambda argument,
 // capturing the return value in an 'auto' local variable. Most users will not
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 9b4b386..eb67010 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -302,7 +302,16 @@
   // Enable detailed logging into vlog.
   bool xla_detailed_logging = 143;
 
-  // Next id: 146
+  // Overrides normal multi-threaded compilation setting to use this many
+  // threads. Setting to 0 (the default value) means no enforcement.
+  int32 xla_gpu_force_compilation_parallelism = 147;
+
+  // Guarantees run-to-run determinism. At present, the HLO ops Scatter and
+  // SelectAndScatter do not have deterministic XLA:GPU implementations.
+  // Compilation errors out if these ops are encountered.
+  bool xla_gpu_deterministic_ops = 148;
+
+  // Next id: 149
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
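
A hedged sketch of how a client might populate the two new DebugOptions fields programmatically, assuming the usual protoc-generated setters and the existing xla::GetDebugOptionsFromFlags() helper:

    #include "tensorflow/compiler/xla/debug_options_flags.h"

    // Sketch: request deterministic ops and cap compilation parallelism at 8 threads.
    xla::DebugOptions MakeDeterministicDebugOptions() {
      xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
      opts.set_xla_gpu_force_compilation_parallelism(8);  // 0 (default) means no enforcement
      opts.set_xla_gpu_deterministic_ops(true);  // Scatter/SelectAndScatter will fail to compile
      return opts;
    }
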
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 11b39be..844be33 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -271,6 +271,10 @@
   //
   // This name is often unique within a computation. Note: some frameworks
   // add auto-generated names if the user does not provide one.
+  //
+  // A dummy name may be assigned if op_name is empty in order to keep track of
+  // where op_name first became empty. Dummy names begin with "DUMMY_" and may
+  // include the current HloPassMetadata.pass_id.
   string op_name = 2;
   // Indicate a file and line that this op is associated to in a user's program.
   //
@@ -279,6 +283,11 @@
   int32 source_line = 4;
 
   repeated ProfileType profile_type = 5;
+
+  // HloPassMetadata.pass_id of the pass that created this HLO instruction
+  // object. Should never be copied between HLO instructions. Zero if unset and
+  // -1 if the instruction was created before HLO passes began.
+  int64 creation_pass_id = 6;
 }
 
 // Profile data from the execution of a computation.
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5833dd6..a10c88a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1246,7 +1246,7 @@
             "-lpthread",
         ],
     }),
-    visibility = ["//tensorflow/python:__pkg__"],
+    visibility = ["//tensorflow/python:__subpackages__"],
     deps = tf_additional_lib_deps() + [
         "@com_google_absl//absl/meta:type_traits",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt
new file mode 100644
index 0000000..7114888
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveBcastRecvV2"
+  summary: "Receives a tensor value broadcast from another device."
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt
new file mode 100644
index 0000000..8d0cced
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveBcastSendV2"
+  summary: "Broadcasts a tensor value to one or more other devices."
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt
index 3654286..d5fde05 100644
--- a/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt
@@ -20,9 +20,8 @@
         "A `Tensor` of type T. An alias of `x`. The content "
         "of `y` is undefined if there are duplicates in `i`."
   }
-  summary: <<END
-    Adds v into specified rows of x.
-
+  summary: "Adds v into specified rows of x."
+  description: <<END
     Computes y = x; y[i, :] += v; return y.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscBinaryArithmetic.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscBinaryArithmetic.pbtxt
new file mode 100644
index 0000000..eed8919
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscBinaryArithmetic.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscBinaryArithmetic"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscBinaryComparison.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscBinaryComparison.pbtxt
new file mode 100644
index 0000000..7ff2224
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscBinaryComparison.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscBinaryComparison"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscBitcast.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscBitcast.pbtxt
new file mode 100644
index 0000000..65e2496
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscBitcast.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscBitcast"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscCast.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscCast.pbtxt
new file mode 100644
index 0000000..c9a3de8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscCast.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscCast"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscCholesky.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscCholesky.pbtxt
new file mode 100644
index 0000000..7950053
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscCholesky.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscCholesky"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscCondition.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscCondition.pbtxt
new file mode 100644
index 0000000..feb75f4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscCondition.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscCondition"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscFft.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscFft.pbtxt
new file mode 100644
index 0000000..8b38f7f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscFft.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscFft"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscGather.pbtxt
new file mode 100644
index 0000000..28f4b03
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscGather.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscGather"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscIsFinite.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscIsFinite.pbtxt
new file mode 100644
index 0000000..f5334c2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscIsFinite.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscIsFinite"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscLogicalAnd.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscLogicalAnd.pbtxt
new file mode 100644
index 0000000..62683f8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscLogicalAnd.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscLogicalAnd"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscLogicalNot.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscLogicalNot.pbtxt
new file mode 100644
index 0000000..50f7df4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscLogicalNot.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscLogicalNot"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscLogicalOr.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscLogicalOr.pbtxt
new file mode 100644
index 0000000..804e337
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscLogicalOr.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscLogicalOr"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscRandomUniform.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscRandomUniform.pbtxt
new file mode 100644
index 0000000..da9671b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscRandomUniform.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscRandomUniform"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscReduce.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscReduce.pbtxt
new file mode 100644
index 0000000..6f909fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscReduce.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscReduce"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscReverse.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscReverse.pbtxt
new file mode 100644
index 0000000..94f4f3e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscReverse.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscReverse"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscScatter.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscScatter.pbtxt
new file mode 100644
index 0000000..fa29741
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscScatter.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscScatter"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscSort.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscSort.pbtxt
new file mode 100644
index 0000000..bdf9def
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscSort.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscSort"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscSqueeze.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscSqueeze.pbtxt
new file mode 100644
index 0000000..02424dd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscSqueeze.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscSqueeze"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscTranspose.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscTranspose.pbtxt
new file mode 100644
index 0000000..7b19e43
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscTranspose.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscTranspose"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscTriangularSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscTriangularSolve.pbtxt
new file mode 100644
index 0000000..d6f1e0f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscTriangularSolve.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscTriangularSolve"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscUnary.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscUnary.pbtxt
new file mode 100644
index 0000000..293e4c8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscUnary.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscUnary"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RiscWhile.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscWhile.pbtxt
new file mode 100644
index 0000000..cc2e6ee
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RiscWhile.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RiscWhile"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TopKUnique.pbtxt b/tensorflow/core/api_def/base_api/api_def_TopKUnique.pbtxt
index 5c9ccba..c594746 100644
--- a/tensorflow/core/api_def/base_api/api_def_TopKUnique.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_TopKUnique.pbtxt
@@ -1,8 +1,8 @@
 op {
   graph_op_name: "TopKUnique"
-  summary: "Returns the TopK unique values in the array in sorted order. The"
+  summary: "Returns the TopK unique values in the array in sorted order."
   description: <<END
-running time is proportional to the product of K and the input
+The running time is proportional to the product of K and the input
 size. Sorting the whole array is more efficient for sufficiently large
 values of K. The median-of-medians algorithm is probably faster, but
 difficult to implement efficiently in XLA. If there are fewer than K
diff --git a/tensorflow/core/api_def/base_api/api_def_TopKWithUnique.pbtxt b/tensorflow/core/api_def/base_api/api_def_TopKWithUnique.pbtxt
index ac73bad..8c90a08 100644
--- a/tensorflow/core/api_def/base_api/api_def_TopKWithUnique.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_TopKWithUnique.pbtxt
@@ -1,10 +1,11 @@
 op {
   graph_op_name: "TopKWithUnique"
-  summary: "Returns the TopK values in the array in sorted order. This is a combination"
+  summary: "Returns the TopK values in the array in sorted order."
   description: <<END
-of MakeUnique and TopKUnique. The returned top-K will have its lower bits
-replaced by iota, thus it will be close to the original value but not exactly
-the same. The running time is proportional to the product of K and the input
-size. NaNs are never returned. Subnormal numbers are flushed to zero.
+This is a combination of MakeUnique and TopKUnique. The returned top-K will
+have its lower bits replaced by iota, thus it will be close to the original
+value but not exactly the same. The running time is proportional to the product
+of K and the input size. NaNs are never returned. Subnormal numbers are flushed
+to zero.
 END
 }
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index aec0e16..5d68621 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2127,7 +2127,6 @@
     size = "small",
     srcs = ["process_function_library_runtime_test.cc"],
     linkstatic = tf_kernel_tests_linkstatic(),
-    tags = ["no_rocm"],
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
@@ -2513,13 +2512,10 @@
     ],
 )
 
-tf_cuda_cc_test(
+tf_cc_test(
     name = "lower_if_op_test",
     size = "small",
     srcs = ["lower_if_op_test.cc"],
-    tags = tf_cuda_tests_tags() + [
-        "no_cuda_asan",  # TODO(b/171575050): re-enable once fixed.
-    ],
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index 38368c4..2ad2267 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -631,13 +631,6 @@
   }
 }
 
-// Disabling the following test on the ROCm platform because it relies on the
-// "topK" operator being supported on the ROCm platform (which is currently not
-// the case)
-// TODO(rocm) :
-// re-enable this test once support for "topK" operator is available on ROCm
-
-#ifndef TENSORFLOW_USE_ROCM
 TEST_F(ConstantFoldingTest, NoReplacePartialOutput) {
   Graph g(OpRegistry::Global());
   {
@@ -662,7 +655,6 @@
       &g, &was_mutated));
   EXPECT_FALSE(was_mutated);
 }
-#endif  // TENSORFLOW_USE_ROCM
 
 namespace {
 
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 728aacb..cde26d7 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -77,6 +77,7 @@
     deps = [
         ":eager_executor",
         ":kernel_and_device",
+        ":custom_device",
         "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/c:tf_tensor_internal",
         "//tensorflow/c/eager:immediate_execution_context",
@@ -111,6 +112,39 @@
 )
 
 tf_cuda_library(
+    name = "custom_device",
+    srcs = ["custom_device.cc"],
+    hdrs = ["custom_device.h"],
+    visibility = ["//tensorflow:internal"],
+    deps = select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+            "//tensorflow/c/eager:immediate_execution_context",
+            "//tensorflow/c/eager:immediate_execution_tensor_handle",
+            "//tensorflow/core/lib/core:status",
+        ],
+    }),
+)
+
+tf_cc_test(
+    name = "custom_device_test",
+    srcs = ["custom_device_test.cc"],
+    deps = [
+        ":context",
+        ":core",
+        ":custom_device",
+        ":tensor_handle",
+        "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
+tf_cuda_library(
     name = "context_distributed_manager",
     srcs = [
         "context_distributed_manager.cc",
@@ -358,6 +392,7 @@
     "//tensorflow/core:protos_all_cc",
     "//tensorflow/core/profiler/lib:annotated_traceme",
     "//tensorflow/core/profiler/lib:traceme",
+    "//tensorflow/core/grappler:grappler_item",
     "//tensorflow/core/grappler/optimizers:meta_optimizer",
 ]
 
@@ -665,6 +700,7 @@
     srcs = [
         "attr_builder.h",
         "context.h",
+        "custom_device.h",
         "eager_executor.h",
         "eager_operation.h",
         "kernel_and_device.h",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index f9802db..813676a 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -705,10 +705,10 @@
   return Status::OK();
 }
 
-Status EagerContext::AddFunctionDefWithDebugInfo(
-    const FunctionDef& fdef, const Graph* graph_with_debug_info) {
+Status EagerContext::AddFunctionDefWithStackTraces(
+    const FunctionDef& fdef, const StackTracesMap& stack_traces) {
   return AddFunctionDef(fdef, FunctionDefLibrary(),
-                        /* add_to_local_only=*/false, graph_with_debug_info);
+                        /* add_to_local_only=*/false, stack_traces);
 }
 
 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
@@ -719,7 +719,7 @@
 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
                                     const FunctionDefLibrary& library,
                                     const bool add_to_local_only,
-                                    const Graph* graph_with_debug_info) {
+                                    const StackTracesMap& stack_traces) {
   bool is_first_ref = false;
   {
     mutex_lock l(cache_mu_);
@@ -753,8 +753,7 @@
     is_first_ref = registered_function->RefCountIsOne();
   }
   if (is_first_ref) {
-    TF_RETURN_IF_ERROR(
-        func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info));
+    TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
     TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
     if (!add_to_local_only) {
       return MaybeRegisterFunctionRemotely(fdef);
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 9e0d9fc..ef99dfa 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -30,6 +30,7 @@
 #include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eager/custom_device.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -81,26 +82,6 @@
 class TensorHandle;
 class EagerOperation;
 
-class CustomDevice {
- public:
-  virtual ~CustomDevice() {}
-  virtual const string& name() = 0;
-  virtual Status CopyTensorToDevice(TensorHandle* tensor,
-                                    TensorHandle** result) = 0;
-
-  virtual Status CopyTensorFromDevice(TensorHandle* tensor,
-                                      const string& target_device_name,
-                                      TensorHandle** result) = 0;
-
-  virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
-                         int* num_retvals) = 0;
-};
-
-// Custom devices do many of the same things as physical Devices, but have a
-// much more restricted interface. We pass around ambiguous pointers since
-// TensorHandles may be placed either on custom or physical devices.
-using VariantDevice = absl::variant<Device*, CustomDevice*>;
-
 class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
  public:
   static constexpr uint64 kInvalidContextId = 0;
@@ -234,8 +215,8 @@
   // entry to the KernelAndDevice cache for it if one does not exist.
   Status AddFunctionDef(const FunctionDef& fdef) override;
 
-  Status AddFunctionDefWithDebugInfo(
-      const FunctionDef& fdef, const Graph* graph_with_debug_info) override;
+  Status AddFunctionDefWithStackTraces(
+      const FunctionDef& fdef, const StackTracesMap& stack_traces) override;
 
   // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
   // it to the local FunctionLibraryDefinition as well, but no need to add it
@@ -244,7 +225,7 @@
   Status AddFunctionDef(const FunctionDef& fdef,
                         const FunctionDefLibrary& library,
                         const bool add_to_local_only = false,
-                        const Graph* graph_with_debug_info = nullptr);
+                        const StackTracesMap& stack_traces = {});
 
   const FunctionDef* GetFunctionDef(const string& function_name);
 
diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc
index 47864c7..d80952a 100644
--- a/tensorflow/core/common_runtime/eager/core.cc
+++ b/tensorflow/core/common_runtime/eager/core.cc
@@ -39,11 +39,15 @@
 // TODO(b/152902651): This should not depend on EagerContext. This can be
 // resolved by storing ctx->HostCPU() in the TensorHandle class.
 AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
+  *status = WaitUnknownDevice();
+  if (!status->ok()) {
+    return nullptr;
+  }
   if (VariantDeviceIsCustom(device())) {
     auto* custom_device = absl::get<CustomDevice*>(device());
     TensorHandle* copy;
-    *status = custom_device->CopyTensorFromDevice(
-        this, "/job:localhost/replica:0/task:0/device:CPU:0", &copy);
+    *status = custom_device->CopyTensorFromDevice(this, ctx_->HostCPU()->name(),
+                                                  &copy);
     if (status->ok()) {
       auto result = copy->Resolve(status);
       copy->Unref();
diff --git a/tensorflow/core/common_runtime/eager/custom_device.cc b/tensorflow/core/common_runtime/eager/custom_device.cc
new file mode 100644
index 0000000..8055e03
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/custom_device.cc
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/custom_device.h"
+
+namespace tensorflow {
+
+Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const {
+  int num_dims;
+  TF_RETURN_IF_ERROR(NumDims(&num_dims));
+  std::vector<int64> dims(num_dims);
+  for (int i = 0; i < num_dims; ++i) {
+    TF_RETURN_IF_ERROR(Dim(i, &dims[i]));
+  }
+  return PartialTensorShape::MakePartialShape(dims.data(), num_dims, shape);
+}
+
+Status CustomDeviceTensorHandle::NumElements(int64* num_elements) const {
+  *num_elements = 1;
+  int num_dims;
+  TF_RETURN_IF_ERROR(NumDims(&num_dims));
+  for (int i = 0; i < num_dims; ++i) {
+    int64 dim;
+    TF_RETURN_IF_ERROR(Dim(i, &dim));
+    *num_elements *= dim;
+  }
+  return Status::OK();
+}
+
+const char* CustomDeviceTensorHandle::DeviceType(Status* status) const {
+  const DeviceNameUtils::ParsedName* parsed = ParsedName(status);
+  if (!status->ok()) {
+    return "";
+  }
+  return parsed->type.c_str();
+}
+
+int CustomDeviceTensorHandle::DeviceId(Status* status) const {
+  const DeviceNameUtils::ParsedName* parsed = ParsedName(status);
+  if (!status->ok()) {
+    return 0;
+  }
+  return parsed->id;
+}
+
+AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) {
+  core::RefCountPtr<ImmediateExecutionTensorHandle> copied_off(
+      context_->CopyTensorHandleToDevice(
+          this,
+          DeviceNameUtils::ParsedNameToString(context_->HostCPUParsedName())
+              .c_str(),
+          status));
+  if (!status->ok()) {
+    return nullptr;
+  }
+  return copied_off->Resolve(status);
+}
+
+const DeviceNameUtils::ParsedName* CustomDeviceTensorHandle::ParsedName(
+    Status* status) const {
+  if (!parsed_name_.has_value()) {
+    DeviceNameUtils::ParsedName parsed_name;
+    if (!DeviceNameUtils::ParseFullOrLocalName(device_->name(), &parsed_name)) {
+      *status = errors::InvalidArgument(
+          absl::StrCat("Invalid custom device name ", device_->name()));
+      return nullptr;
+    }
+    parsed_name_.emplace(std::move(parsed_name));
+  }
+  return &*parsed_name_;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/custom_device.h b/tensorflow/core/common_runtime/eager/custom_device.h
new file mode 100644
index 0000000..e3168b6
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/custom_device.h
@@ -0,0 +1,107 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
+
+#include <string>
+
+#include "tensorflow/c/eager/immediate_execution_context.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+class TensorHandle;
+class EagerOperation;
+
+// Custom devices intercept the execution of operations (the `Execute` method),
+// typically implemented with one or more of the custom device's own executions.
+class CustomDevice {
+ public:
+  virtual ~CustomDevice() {}
+  virtual const string& name() = 0;
+  virtual Status CopyTensorToDevice(TensorHandle* tensor,
+                                    TensorHandle** result) = 0;
+
+  virtual Status CopyTensorFromDevice(TensorHandle* tensor,
+                                      const string& target_device_name,
+                                      TensorHandle** result) = 0;
+
+  virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
+                         int* num_retvals) = 0;
+};
+
+// Custom devices do many of the same things as physical Devices, but have a
+// much more restricted interface. We pass around ambiguous pointers since
+// operations may be placed either on custom or physical devices.
+using VariantDevice = absl::variant<Device*, CustomDevice*>;
+
+// A tensor handle produced by a custom device. Generally it can only be
+// consumed by executing an operation on the same custom device that produced
+// it, or by copying the handle off the custom device.
+//
+// TODO(allenl): Currently custom devices are tied to the eager C API. They
+// should be renamed op handlers and subclass AbstractTensorHandle instead so
+// they are eager/graph agnostic.
+class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
+ public:
+  CustomDeviceTensorHandle(ImmediateExecutionContext* context,
+                           CustomDevice* device, tensorflow::DataType dtype)
+      : ImmediateExecutionTensorHandle(kCustomDevice),
+        context_(context),
+        device_(device),
+        dtype_(dtype) {}
+
+  tensorflow::DataType DataType() const override { return dtype_; }
+  Status Shape(PartialTensorShape* shape) const override;
+  Status NumElements(int64* num_elements) const override;
+
+  const char* DeviceName(Status* status) const override {
+    return device_->name().c_str();
+  }
+  const char* BackingDeviceName(Status* status) const override {
+    return device_->name().c_str();
+  }
+  CustomDevice* device() const { return device_; }
+  const char* DeviceType(Status* status) const override;
+  int DeviceId(Status* status) const override;
+
+  AbstractTensorInterface* Resolve(Status* status) override;
+
+  ImmediateExecutionTensorHandle* Copy() override {
+    Ref();
+    return this;
+  }
+  void Release() override { Unref(); }
+
+  // For LLVM style RTTI.
+  static bool classof(const AbstractTensorHandle* ptr) {
+    return ptr->getKind() == kCustomDevice;
+  }
+
+ protected:
+  const DeviceNameUtils::ParsedName* ParsedName(Status* status) const;
+
+  ImmediateExecutionContext* const context_;
+  CustomDevice* const device_;
+  const tensorflow::DataType dtype_;
+
+  mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc
new file mode 100644
index 0000000..388ad81
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/custom_device.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/device_factory.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class TestCustomDevice : public CustomDevice {
+ public:
+  explicit TestCustomDevice(std::string name) : name_(name) {}
+  const std::string& name() override { return name_; }
+  Status CopyTensorToDevice(TensorHandle* tensor,
+                            TensorHandle** result) override {
+    tensor->Ref();
+    *result = tensor;
+    return Status::OK();
+  }
+  Status CopyTensorFromDevice(TensorHandle* tensor,
+                              const std::string& target_device_name,
+                              TensorHandle** result) override {
+    tensor->Ref();
+    *result = tensor;
+    return Status::OK();
+  }
+  Status Execute(const EagerOperation* op, TensorHandle** retvals,
+                 int* num_retvals) override {
+    return errors::Unimplemented("Not implemented");
+  }
+
+ private:
+  std::string name_;
+};
+
+class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
+ public:
+  TestCustomDeviceTensorHandle(ImmediateExecutionContext* context,
+                               TestCustomDevice* device,
+                               tensorflow::DataType dtype)
+      : CustomDeviceTensorHandle(context, device, dtype) {}
+
+  Status NumDims(int* num_dims) const override {
+    *num_dims = 1;
+    return Status::OK();
+  }
+  Status Dim(int dim_index, int64* dim) const override {
+    if (dim_index == 0) {
+      *dim = 3;
+      return Status::OK();
+    } else {
+      return errors::Internal("Dim out of bounds");
+    }
+  }
+};
+
+TEST(CustomDevice, TestTensorHandle) {
+  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
+      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
+  core::RefCountPtr<EagerContext> ctx(new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &device_mgr, false, nullptr, nullptr));
+  std::string device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:15";
+  TestCustomDevice device(device_name);
+  core::RefCountPtr<TestCustomDeviceTensorHandle> tensor(
+      new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT));
+  Status s;
+  std::string device_type = tensor->DeviceType(&s);
+  ASSERT_TRUE(s.ok()) << s.error_message();
+  EXPECT_EQ("CUSTOM", device_type);
+  int device_index = tensor->DeviceId(&s);
+  ASSERT_TRUE(s.ok()) << s.error_message();
+  EXPECT_EQ(15, device_index);
+  int64 num_elements = 0;
+  s = tensor->NumElements(&num_elements);
+  ASSERT_TRUE(s.ok()) << s.error_message();
+  EXPECT_EQ(3, num_elements);
+  EXPECT_EQ("TensorHandle(shape=[3], dtype=DT_FLOAT)", tensor->DebugString());
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 572615b..56a0617 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -920,11 +920,11 @@
         }
       }
       auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
-      // For a multi-device function, a remote RunComponentFunction request is
-      // not sent through StreamingEnqueueAsync. It could arrive at a remote
-      // worker before a remote execution request which produces an input of the
-      // component function. So we wait until the remote input is ready before
-      // serializing it.
+      // For a remote component function, a function execution request and an
+      // input generation request may come from different workers. We need to
+      // guarantee that the input generation request is processed before the
+      // function execution request, so wait until the remote input is ready
+      // before serializing it.
       const bool wait_until_ready = op->is_function();
       TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
           input, wait_until_ready, input_handle, input_device,
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3b73f01..79b9179 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -44,6 +44,7 @@
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
 #if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 #endif  // !IS_MOBILE_PLATFORM
 
@@ -189,20 +190,8 @@
           "Failed to parse config_proto attribute as tensorflow::ConfigProto "
           "proto.");
     }
-    grappler::GrapplerItem::OptimizationOptions optimization_options;
-
-    // Tensorflow 2.0 in eager mode with automatic control dependencies will
-    // prune all nodes that are not in the transitive fanin of the fetch nodes.
-    // However because the function will be executed via FunctionLibraryRuntime,
-    // and current function implementation does not prune stateful and dataset
-    // ops, we rely on Grappler to do the correct graph pruning.
-    optimization_options.allow_pruning_stateful_and_dataset_ops = true;
-
-    optimization_options.is_eager_mode = true;
-
-    // All the nested function calls will be executed and optimized via
-    // PartitionedCallOp, there is no need to optimize functions now.
-    optimization_options.optimize_function_library = false;
+    grappler::GrapplerItem::OptimizationOptions optimization_options =
+        grappler::CreateOptOptionsForEager();
 
     options.optimize_graph_fn = std::bind(
         grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
@@ -215,9 +204,10 @@
 
   // In Eager mode we always inline all functions into the top-level
   // function body graph, to get a single executable graph, that could be
-  // optimized across function boundaries (e.g. prune unused inputs and outputs
-  // in a function call chain). This is required to mimic graph mode execution,
-  // with aggressive pruning of nodes not in the transitive fanin of fetches.
+  // optimized across function boundaries (e.g. prune unused inputs and
+  // outputs in a function call chain). This is required to mimic graph mode
+  // execution, with aggressive pruning of nodes not in the transitive fanin
+  // of fetches.
   options.config_proto.mutable_graph_options()
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 03c23f3..fb33cb8 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -73,6 +73,7 @@
 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
 
 namespace tensorflow {
+
 namespace {
 
 // 1-D, 0 element tensor.
@@ -155,7 +156,7 @@
     KernelStats() = default;
 
     void Initialize(const GraphView& gview) {
-      is_expensive_ = absl::make_unique<std::atomic<bool>[]>(gview.num_nodes());
+      is_expensive_.resize(gview.num_nodes());
       cost_estimates_ =
           absl::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
       for (int32 i = 0; i < gview.num_nodes(); ++i) {
@@ -176,28 +177,26 @@
               kOpIsExpensiveThresholdCycles);
     }
 
+    // Returns the value of kernel->IsExpensive().
+    bool HasExpensiveMarker(const NodeItem& node) const {
+      return is_expensive_[node.node_id];
+    }
+
     // Updates the dynamic cost estimate, which is used to determine whether the
     // given node is expensive. The new cost estimate is a weighted average of
-    // the old cost estimate and the latest cost.
-    //
-    // NOTE: We currently only expect updates to the cost estimate when
-    // `is_expensive_[node.node_id]` is true (or at least, it *was* true, when
-    // we started to execute the kernel. As a result, we expect that a kernel
-    // can only ever transition from "expensive" to "inexpensive", but not vice
-    // versa.
+    // the old cost estimate and the latest cost. We only update cost estimates
+    // for kernels for which IsExpensive() returns true.
     void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
       // N.B. Updates to `cost_estimate` are atomic but unlocked.  Simultaneous
       // updates may result in one or more updates being ignored.  This does not
       // affect correctness but may slow down the update frequency.
       std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
-      uint64 new_estimate = (kCostDecay - 1) *
-                                cost_estimate.load(std::memory_order_relaxed) /
-                                kCostDecay +
-                            (elapsed_cycles / kCostDecay);
+      auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
+
+      uint64 new_estimate =
+          ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
+
       cost_estimate.store(new_estimate, std::memory_order_relaxed);
-      if (new_estimate < kOpIsExpensiveThresholdCycles) {
-        is_expensive_[node.node_id].store(false, std::memory_order_relaxed);
-      }
     }
 
    private:
@@ -205,10 +204,11 @@
     // determine whether an operation should be place in a threadpool.
     // Operations start out "expensive".
     static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
-    static constexpr uint64 kOpIsExpensiveThresholdCycles = 5000;
+    static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
     static constexpr uint64 kCostDecay = 10;
 
-    std::unique_ptr<std::atomic<bool>[]> is_expensive_;
+    std::vector<bool> is_expensive_;
     std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
   };
 
@@ -567,15 +567,19 @@
         },
         profiler::GetTFTraceMeLevel(is_expensive));
     device->Compute(op_kernel, &ctx);
-  } else {
-    // In the common case, avoid creating any tracing objects.
-    if (is_expensive) {
-      KernelTimer timer;
-      device->Compute(op_kernel, &ctx);
+  } else if (kernel_stats_->HasExpensiveMarker(item)) {
+    KernelTimer timer;
+    device->Compute(op_kernel, &ctx);
+    // For expensive kernels, always update the cost estimate. For inexpensive
+    // kernels, update the cost estimate with ~1/16 probability. This assumes
+    // that the last 4 bits of the CPU cycle count are uniformly distributed.
+    constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
+    if (is_expensive ||
+        timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
       kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
-    } else {
-      device->Compute(op_kernel, &ctx);
     }
+  } else {
+    device->Compute(op_kernel, &ctx);
   }
   nodestats::SetOpEnd(stats);
   if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
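
The rewritten UpdateCostEstimate above computes a decayed (exponential moving) average of the kernel's cost, and kernels that carry the static IsExpensive() marker but are currently estimated inexpensive still feed the estimate on roughly 1 in 16 invocations, keyed off the low bits of the cycle counter. A standalone sketch of the averaging math, with constants copied from the diff (illustrative only, not the TensorFlow implementation):

    #include <cstdint>
    #include <iostream>

    constexpr uint64_t kCostDecay = 10;

    // Each sample moves the estimate 1/kCostDecay of the way toward the latest cost.
    uint64_t UpdateCostEstimate(uint64_t prev_estimate, uint64_t elapsed_cycles) {
      return ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
    }

    int main() {
      uint64_t estimate = 100 * 1000 * 1000;  // kInitialCostEstimateCycles
      for (int i = 0; i < 5; ++i) {
        estimate = UpdateCostEstimate(estimate, /*elapsed_cycles=*/4000);
        std::cout << "estimate after sample " << i << ": " << estimate << "\n";
      }
      return 0;
    }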
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index d590ae0..d8ea85f 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -33,6 +33,15 @@
 
 class StepStatsCollector;
 
+// If this is called, we will sample execution cost for "inexpensive" kernels
+// and switch them to "expensive" when the estimated cost exceeds the
+// expensiveness threshold.
+// This is a temporary flag for validating the performance impact of
+// this feature. For simplicity, a global flag is used and once the flag
+// is turned on, it cannot be turned off. We will remove this flag once this
+// feature is validated.
+void EnableAlwaysTrackKernelExecutionCost();
+
 // Executor runs a graph computation.
 // Example:
 //   Graph* graph = ...;
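
For the EnableAlwaysTrackKernelExecutionCost() declaration added above, a hedged usage sketch: the flag is global and one-way, so a binary opting into cost sampling would flip it once during startup (the call site below is illustrative):

    #include "tensorflow/core/common_runtime/executor.h"

    int main(int argc, char** argv) {
      // Opt in before any executors run; per the comment above, the flag
      // cannot be turned off again.
      tensorflow::EnableAlwaysTrackKernelExecutionCost();
      // ... build sessions / run graphs as usual ...
      return 0;
    }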
diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc
index d5ada59..b9b679e 100644
--- a/tensorflow/core/common_runtime/function_def_utils.cc
+++ b/tensorflow/core/common_runtime/function_def_utils.cc
@@ -35,14 +35,32 @@
   InstantiationResult result;
   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
 
-  std::unique_ptr<Graph> graph(new Graph(lib_def));
-  graph->SetConstructionContext(ConstructionContext::kFunctionDef);
+  auto graph = absl::make_unique<Graph>(lib_def);
+
+  auto construction_context_iter = fdef.attr().find("_construction_context");
+  if (construction_context_iter != fdef.attr().end()) {
+    if (construction_context_iter->second.s() == "kEagerRuntime") {
+      graph->SetConstructionContext(ConstructionContext::kEagerRuntime);
+    } else {
+      DCHECK(false) << "Unknown _construction_context attribute: "
+                    << construction_context_iter->second.s();
+    }
+  }
 
   GraphConstructorOptions opts;
   opts.allow_internal_ops = true;
   opts.expect_device_spec = false;
   TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
 
+  const StackTracesMap& stack_traces =
+      lib_def->GetStackTraces(fdef.signature().name());
+  for (Node* n : graph->nodes()) {
+    if (n == nullptr) continue;
+    auto it = stack_traces.find(n->name());
+    if (it != stack_traces.end()) {
+      n->SetStackTrace(it->second);
+    }
+  }
+
   // Call BuildControlFlowInfo to validate that this function body has
   // well-formed control flow.
   std::vector<ControlFlowInfo> dummy;
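
For context on the "_construction_context" branch above, a hedged sketch of how a producer might tag a FunctionDef so that branch selects kEagerRuntime (the helper name is hypothetical; the proto accessors are the standard generated ones):

    #include "tensorflow/core/framework/function.pb.h"

    // Sketch: mark a FunctionDef as having been built by the eager runtime.
    void MarkEagerConstruction(tensorflow::FunctionDef* fdef) {
      (*fdef->mutable_attr())["_construction_context"].set_s("kEagerRuntime");
    }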
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index c52b4c5..f031ec5 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -393,6 +393,37 @@
   test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
 }
 
+TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) {
+  class DummyStackTrace : public AbstractStackTrace {
+    absl::Span<StackFrame const> ToFrames() const override { return {}; }
+
+    std::string ToString(const TracePrintingOptions& opts) const override {
+      return "DummyStackTrace";
+    }
+
+    StackFrame LastUserFrame() const override { return StackFrame{}; }
+  };
+
+  FunctionDef func = test::function::XTimesTwo();
+  Init({});
+
+  StackTracesMap stack_traces;
+  stack_traces["two"] = std::make_shared<DummyStackTrace>();
+
+  TF_CHECK_OK(lib_def_->AddFunctionDef(func, stack_traces));
+
+  FunctionLibraryRuntime::Handle handle;
+  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {}, &handle));
+
+  const FunctionBody* func_body = flr0_->GetFunctionBody(handle);
+  for (const Node* node : func_body->graph->nodes()) {
+    if (node->name() == "two") {
+      EXPECT_EQ(node->GetStackTrace()->ToString({}), "DummyStackTrace");
+    }
+  }
+  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
+}
+
 TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) {
   Init({test::function::XTimesTwo()});
   auto x = test::AsTensor<float>({1, 2, 3, 4});
@@ -1258,7 +1289,7 @@
     auto g = absl::make_unique<Graph>(OpRegistry::Global());
     TF_ASSERT_OK(construct_graph(&g));
 
-    const string merged_device = "/job:call/replica:0/task:1/device:CPU:*";
+    const string merged_device = "/job:body/replica:0/task:1/device:CPU:*";
 
     ExpandInlineFunctions(flr0_, g.get(), opts);
     GraphDef expected = expected_graph({/*a*/ arg_device,                //
diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD
index 5ca8ee2..fbd1421 100644
--- a/tensorflow/core/common_runtime/gpu/BUILD
+++ b/tensorflow/core/common_runtime/gpu/BUILD
@@ -114,6 +114,7 @@
         "gpu_managed_allocator.h",
         "gpu_process_state.h",
         "gpu_util.h",
+        "gpu_virtual_mem_allocator.h",
         "//tensorflow/core/common_runtime:gpu_runtime_headers",
         "//tensorflow/core/common_runtime/device:device_runtime_headers",
     ],
@@ -137,6 +138,7 @@
     cuda_deps = [
         "@local_config_cuda//cuda:cudnn_header",
         "//tensorflow/stream_executor/cuda:cuda_platform",
+        ":gpu_virtual_mem_allocator",
     ],
     deps = [
         ":gpu_bfc_allocator",
@@ -187,6 +189,7 @@
     features = ["parse_headers"],
     visibility = ["//visibility:public"],
     deps = [
+        ":gpu_virtual_mem_allocator",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
@@ -196,6 +199,31 @@
 )
 
 tf_cuda_library(
+    name = "gpu_virtual_mem_allocator",
+    srcs = [
+        "gpu_virtual_mem_allocator.cc",
+    ],
+    hdrs = [
+        "gpu_virtual_mem_allocator.h",
+    ],
+    copts = tf_copts(),
+    features = ["parse_headers"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":gpu_id",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/framework:allocator",
+        "//tensorflow/core/platform:stream_executor",
+        "//tensorflow/stream_executor:platform",
+        "//tensorflow/stream_executor:stream_executor_headers",
+        "//tensorflow/stream_executor/lib",
+    ],
+)
+
+tf_cuda_library(
     name = "gpu_init",
     hdrs = [
         "gpu_init.h",
@@ -403,3 +431,21 @@
         "//tensorflow/stream_executor:platform",
     ],
 )
+
+tf_cc_test(
+    name = "gpu_virtual_mem_allocator_test",
+    size = "small",
+    srcs = ["gpu_virtual_mem_allocator_test.cc"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":gpu_init",
+        ":gpu_virtual_mem_allocator",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/framework:allocator",
+        "//tensorflow/core/platform:stream_executor",
+        "//tensorflow/stream_executor/lib",
+    ],
+)
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 6d4ddce..50647b8 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -907,24 +907,25 @@
   return Status::OK();
 }
 
-int64 MinSystemMemory(int64 available_memory) {
+int64 MinSystemMemory(int64 available_memory, int cc_major) {
   // We use the following heuristic for now:
   //
   // If the available_memory is < 2GiB, we allocate 225MiB to system memory.
-  // Otherwise, allocate max(300MiB, kMinSystemMemoryFraction *
-  // available_memory) to system memory.
-  //
-  // In the future we could be more sophisticated by using a table of devices.
+  // Otherwise, depending on the compute capability version, assign
+  //  500MiB (for cuda_compute_capability <= 6.x),
+  // 1050MiB (for cuda_compute_capability == 7.x), or
+  // 1536MiB (for cuda_compute_capability >= 8.x).
   int64 min_system_memory;
-  constexpr float kMinSystemMemoryFraction = 0.06;
   if (available_memory < (1LL << 31)) {
-    // 225MiB
     min_system_memory = 225 * 1024 * 1024;
   } else {
-    // max(300 MiB, kMinSystemMemoryFraction * available_memory)
-    min_system_memory = std::max(
-        int64{314572800},
-        static_cast<int64>(available_memory * kMinSystemMemoryFraction));
+    if (cc_major <= 6) {
+      min_system_memory = 500 * 1024 * 1024;
+    } else if (cc_major <= 7) {
+      min_system_memory = 1050 * 1024 * 1024;
+    } else {
+      min_system_memory = 1536 * 1024 * 1024;
+    }
   }
 #if defined(__GNUC__) && defined(__OPTIMIZE__)
 // Do nothing
@@ -967,13 +968,13 @@
   int64 allocated_memory = 0;
   const double per_process_gpu_memory_fraction =
       gpu_options.per_process_gpu_memory_fraction();
+  int cc_major = 0, cc_minor = 0;
+  if (!se->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                          &cc_minor)) {
+    return errors::Internal("Failed to get compute capability for device.");
+  }
   if (per_process_gpu_memory_fraction > 1.0 ||
       gpu_options.experimental().use_unified_memory()) {
-    int cc_major = 0, cc_minor = 0;
-    if (!se->GetDeviceDescription().cuda_compute_capability(&cc_major,
-                                                            &cc_minor)) {
-      return errors::Internal("Failed to get compute capability for device.");
-    }
     if (cc_major < 6) {
       return errors::Internal(
           "Unified memory on GPUs with compute capability lower than 6.0 "
@@ -983,13 +984,45 @@
 
   if (per_process_gpu_memory_fraction == 0) {
     allocated_memory = available_memory;
-    const int64 min_system_memory = MinSystemMemory(available_memory);
+    const int64 min_system_memory = MinSystemMemory(available_memory, cc_major);
     if (min_system_memory < allocated_memory) {
       allocated_memory -= min_system_memory;
     }
   } else {
     allocated_memory = total_memory * per_process_gpu_memory_fraction;
   }
+
+  // Override the excluded memory when TF_DEVICE_MIN_SYS_MEMORY_IN_MB is set.
+  const char* force_device_reserved_bytes =
+      std::getenv("TF_DEVICE_MIN_SYS_MEMORY_IN_MB");
+  if (force_device_reserved_bytes != nullptr &&
+      strcmp(force_device_reserved_bytes, "") != 0) {
+    int32 reserved_mb;
+    if (!strings::safe_strto32(force_device_reserved_bytes, &reserved_mb) ||
+        reserved_mb < 0) {
+      LOG(WARNING) << "The requested reserved device memory "
+                   << force_device_reserved_bytes
+                   << " is invalid. The request will be ignored.";
+    } else {
+      // Convert MBytes to Bytes.
+      size_t allowable_reserved_memory = reserved_mb * 1024 * 1024;
+      // TF_DEVICE_MIN_SYS_MEMORY_IN_MB overrides
+      // per_process_gpu_memory_fraction.
+      if (allowable_reserved_memory <= available_memory) {
+        allocated_memory = available_memory - allowable_reserved_memory;
+        VLOG(1) << "Setting the GPU reserved bytes to "
+                << strings::HumanReadableNumBytes(allocated_memory)
+                << " MBytes";
+      } else {
+        LOG(WARNING) << "The requested reserved device memory "
+                     << strings::HumanReadableNumBytes(
+                            allowable_reserved_memory)
+                     << " is larger than the available memory of "
+                     << strings::HumanReadableNumBytes(available_memory)
+                     << ". The request is ignored.";
+      }
+    }
+  }
   *memory_limit = allocated_memory;
   return Status::OK();
 }
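
A compact sketch of the reservation logic introduced above, with the compute-capability table and the TF_DEVICE_MIN_SYS_MEMORY_IN_MB override mirroring the diff (illustrative only, not the TensorFlow implementation):

    #include <cstdint>
    #include <cstdlib>

    // Minimum system memory reserved on a GPU, keyed off the CUDA compute
    // capability major version.
    int64_t MinSystemMemoryBytes(int64_t available_memory, int cc_major) {
      if (available_memory < (1LL << 31)) return 225LL * 1024 * 1024;  // < 2GiB
      if (cc_major <= 6) return 500LL * 1024 * 1024;
      if (cc_major <= 7) return 1050LL * 1024 * 1024;
      return 1536LL * 1024 * 1024;
    }

    int64_t AllocatableBytes(int64_t available_memory, int cc_major) {
      int64_t allocated =
          available_memory - MinSystemMemoryBytes(available_memory, cc_major);
      // TF_DEVICE_MIN_SYS_MEMORY_IN_MB (in MiB) overrides the heuristic when it
      // parses to a non-negative number that fits within the available memory.
      if (const char* env = std::getenv("TF_DEVICE_MIN_SYS_MEMORY_IN_MB")) {
        int64_t reserved_bytes = std::atoll(env) * 1024 * 1024;
        if (reserved_bytes >= 0 && reserved_bytes <= available_memory) {
          allocated = available_memory - reserved_bytes;
        }
      }
      return allocated;
    }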
diff --git a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc
new file mode 100644
index 0000000..4e0f976
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc
@@ -0,0 +1,186 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h"
+
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#if CUDA_VERSION >= 10020
+
+namespace tensorflow {
+namespace {
+
+using ::stream_executor::gpu::GpuContext;
+using ::stream_executor::gpu::GpuDevicePtr;
+using ::stream_executor::gpu::GpuDriver;
+
+// Rounds value up to the specified power of two alignment.
+size_t AlignUp(size_t value, size_t alignment) {
+  DCHECK_EQ(alignment & (alignment - 1), 0)
+      << "Alignment must be a power of two; alignment=" << alignment;
+  return (value + alignment - 1) & ~(alignment - 1);
+}
+
+}  // namespace
+
+/* static */ stream_executor::port::StatusOr<
+    std::unique_ptr<GpuVirtualMemAllocator>>
+GpuVirtualMemAllocator::Create(const std::vector<Visitor>& alloc_visitors,
+                               const std::vector<Visitor>& free_visitors,
+                               GpuContext& gpu_context, PlatformGpuId gpu_id,
+                               size_t virtual_address_space_size,
+                               const std::vector<PlatformGpuId>& peer_gpu_ids) {
+  std::vector<int> access_gpu_ordinals;
+  access_gpu_ordinals.reserve(peer_gpu_ids.size() + 1);
+  access_gpu_ordinals.push_back(gpu_id.value());
+  for (const auto& peer_id : peer_gpu_ids) {
+    access_gpu_ordinals.push_back(peer_id.value());
+  }
+
+  // Find the min granularity for all devices that have access to this memory;
+  // that is, the maximum min granularity among all devices.
+  size_t max_granularity = 1;
+  for (const int device_ordinal : access_gpu_ordinals) {
+    TF_ASSIGN_OR_RETURN(size_t granularity,
+                        GpuDriver::GetMinAllocationGranularity(device_ordinal));
+    max_granularity = std::max(max_granularity, granularity);
+  }
+
+  // Create the virtual memory reservation. Must be aligned to system page size,
+  // and larger than the CUDA min granularity. Empirically, the granularity
+  // check is sufficient as the granularity is some multiple of the page size.
+  // TODO(imintz): Create OS agnostic page size utility for completeness.
+  TF_ASSIGN_OR_RETURN(
+      GpuDriver::VmemSpan vmem,
+      GpuDriver::ReserveVirtualMemory(
+          &gpu_context, AlignUp(virtual_address_space_size, max_granularity)));
+  VLOG(1) << "Reserved GPU virtual memory at " << vmem.base << " of size "
+          << strings::HumanReadableNumBytes(vmem.size_bytes);
+
+  return std::unique_ptr<GpuVirtualMemAllocator>(new GpuVirtualMemAllocator(
+      alloc_visitors, free_visitors, gpu_context, gpu_id,
+      std::move(access_gpu_ordinals), vmem, max_granularity));
+}
+
+GpuVirtualMemAllocator::GpuVirtualMemAllocator(
+    const std::vector<Visitor>& alloc_visitors,
+    const std::vector<Visitor>& free_visitors, GpuContext& gpu_context,
+    PlatformGpuId gpu_id, const std::vector<int> access_gpu_ordinals,
+    GpuDriver::VmemSpan vmem, size_t granularity)
+    : SubAllocator(alloc_visitors, free_visitors),
+      gpu_context_(gpu_context),
+      gpu_id_(gpu_id),
+      access_gpu_ordinals_(access_gpu_ordinals),
+      vmem_(vmem),
+      granularity_(granularity) {}
+
+GpuVirtualMemAllocator::~GpuVirtualMemAllocator() {
+  for (const auto mapping : mappings_) {
+    GpuDriver::UnmapMemory(&gpu_context_, mapping.va, mapping.physical.bytes);
+    GpuDriver::ReleaseMemoryHandle(&gpu_context_, std::move(mapping.physical));
+  }
+  GpuDriver::FreeVirtualMemory(&gpu_context_, vmem_);
+}
+
+void* GpuVirtualMemAllocator::Alloc(size_t alignment, size_t num_bytes,
+                                    size_t* bytes_received) {
+  if (num_bytes == 0) return nullptr;
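+  // Pad the request up to the allocation granularity; e.g. with a 2 MiB
+  // granularity a 256-byte request consumes a full 2 MiB chunk.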
+  size_t padded_bytes = (num_bytes + granularity_ - 1) & ~(granularity_ - 1);
+
+  GpuDevicePtr next_va = vmem_.base + next_alloc_offset_;
+
+  // TODO(imintz): Attempt to extend the vmem allocation by reserving additional
+  // virtual memory at the specific address at the end of the initial vmem
+  // reservation.
+  if (next_va + padded_bytes > vmem_.base + vmem_.size_bytes) {
+    LOG(ERROR) << "OOM in GPU virtual memory allocator when attempting to "
+                  "allocate {request: "
+               << strings::HumanReadableNumBytes(num_bytes)
+               << ", aligned: " << padded_bytes << "} bytes.";
+    return nullptr;
+  }
+
+  // Create physical memory backing allocation.
+  auto maybe_handle =
+      GpuDriver::CreateMemoryHandle(&gpu_context_, padded_bytes);
+  if (!maybe_handle.ok()) {
+    LOG(ERROR) << maybe_handle.status();
+    return nullptr;
+  }
+  GpuDriver::GenericMemoryHandle handle = std::move(maybe_handle).ValueOrDie();
+
+  // Map VAs for this physical memory.
+  auto status = GpuDriver::MapMemory(&gpu_context_, next_va, handle,
+                                     access_gpu_ordinals_);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+    GpuDriver::ReleaseMemoryHandle(&gpu_context_, std::move(handle));
+    return nullptr;
+  }
+  next_alloc_offset_ += handle.bytes;
+  mappings_.push_back({next_va, std::move(handle)});
+  VisitAlloc(reinterpret_cast<void*>(next_va), gpu_id_.value(), padded_bytes);
+  *bytes_received = padded_bytes;
+  return reinterpret_cast<void*>(next_va);
+}
+
+void GpuVirtualMemAllocator::Free(void* ptr, size_t num_bytes) {
+  if (ptr == nullptr) return;
+
+  auto mapping_it =
+      std::lower_bound(mappings_.begin(), mappings_.end(), ptr,
+                       [](const Mapping& mapping, const void* ptr) {
+                         return reinterpret_cast<const void*>(mapping.va) < ptr;
+                       });
+  if (mapping_it == mappings_.end() ||
+      (reinterpret_cast<void*>(mapping_it->va) != ptr)) {
+    LOG(ERROR) << "Could not find GPU vmem mapping for address at "
+               << reinterpret_cast<uintptr_t>(ptr);
+    return;
+  }
+
+  int num_mappings_to_free = 0;
+  int total_bytes = 0;
+  for (auto it = mapping_it; it != mappings_.end() && total_bytes < num_bytes;
+       ++it) {
+    ++num_mappings_to_free;
+    total_bytes += it->physical.bytes;
+  }
+  if (total_bytes != num_bytes) {
+    LOG(ERROR) << "Invalid size requested for freeing GPU vmem mapping. Got "
+               << strings::HumanReadableNumBytes(num_bytes) << " but expected "
+               << strings::HumanReadableNumBytes(mapping_it->physical.bytes);
+    return;
+  }
+
+  VLOG(1) << "Freeing " << num_mappings_to_free << " mappings for a total of "
+          << total_bytes << " bytes";
+  for (auto it = mapping_it; it < mapping_it + num_mappings_to_free; ++it) {
+    GpuDriver::UnmapMemory(&gpu_context_, it->va, it->physical.bytes);
+    GpuDriver::ReleaseMemoryHandle(&gpu_context_, std::move(it->physical));
+  }
+
+  // Move back the next_alloc_offset_ if this free was at the end.
+  if (mapping_it + num_mappings_to_free == mappings_.end()) {
+    next_alloc_offset_ = mapping_it->va - vmem_.base;
+  }
+
+  mappings_.erase(mapping_it, mapping_it + num_mappings_to_free);
+  VisitFree(ptr, gpu_id_.value(), num_bytes);
+}
+
+}  // namespace tensorflow
+
+#endif
diff --git a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h
new file mode 100644
index 0000000..86bb00f
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h
@@ -0,0 +1,113 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The CUDA virtual memory API is only available in CUDA 10.2 and later.
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VMEM_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VMEM_ALLOCATOR_H_
+
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/gpu/gpu_driver.h"
+#include "tensorflow/stream_executor/gpu/gpu_types.h"
+#endif
+
+#if CUDA_VERSION >= 10020
+
+namespace tensorflow {
+
+// GpuVirtualMemAllocator is a SubAllocator for use with BFCAllocator which
+// provides contiguous allocations with each call to Alloc. This is done by
+// reserving a large chunk of virtual addresses at construction and then mapping
+// physical memory pages to this virtual address range as requested.
+//
+// This class is not thread-safe.
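+//
+// A minimal usage sketch (illustrative only; the context, id, and sizes below
+// are placeholders):
+//
+//   TF_ASSIGN_OR_RETURN(
+//       auto allocator,
+//       GpuVirtualMemAllocator::Create(
+//           /*alloc_visitors=*/{}, /*free_visitors=*/{}, gpu_context, gpu_id,
+//           /*virtual_address_space_size=*/1ull << 32, /*peer_gpu_ids=*/{}));
+//   size_t bytes_received;
+//   void* ptr = allocator->Alloc(/*alignment=*/0, /*num_bytes=*/1 << 20,
+//                                &bytes_received);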
+class GpuVirtualMemAllocator : public SubAllocator {
+ public:
+  static stream_executor::port::StatusOr<
+      std::unique_ptr<GpuVirtualMemAllocator>>
+  Create(const std::vector<Visitor>& alloc_visitors,
+         const std::vector<Visitor>& free_visitors,
+         stream_executor::gpu::GpuContext& gpu_context, PlatformGpuId gpu_id,
+         size_t virtual_address_space_size,
+         const std::vector<PlatformGpuId>& peer_gpu_ids);
+  ~GpuVirtualMemAllocator() override;
+
+  // Allocates memory at least as large as requested by num_bytes. Will be
+  // aligned to the min allocation granularity (typically 2MiB).
+  // alignment is ignored by this allocator.
+  void* Alloc(size_t alignment, size_t num_bytes,
+              size_t* bytes_received) override;
+
+  // Frees are expected to happen at the end of the contiguously allocated
+  // region; freeing anywhere else introduces pointless fragmentation, but it is
+  // still supported. If the freed allocation is at the end, then
+  // next_alloc_offset_ is moved back; otherwise a hole is created.
+  //
+  // Holes are not reused; all allocations continue to be placed at
+  // next_alloc_offset_. To accommodate this, the virtual_address_space_size
+  // should be much larger than the maximum physical size of the allocator.
+  //
+  // In practice, since the BFC allocator coalesces adjacent AllocationRegions,
+  // this free function should never be invoked.
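+  //
+  // For example (illustrative): after Alloc(A) followed by Alloc(B), freeing B
+  // moves next_alloc_offset_ back so B's range is handed out again by the next
+  // Alloc, whereas freeing A instead leaves a hole at A that is never reused.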
+  void Free(void* ptr, size_t num_bytes) override;
+
+ private:
+  GpuVirtualMemAllocator(const std::vector<Visitor>& alloc_visitors,
+                         const std::vector<Visitor>& free_visitors,
+                         ::stream_executor::gpu::GpuContext& gpu_context,
+                         PlatformGpuId gpu_id,
+                         std::vector<int> access_gpu_ordinals,
+                         stream_executor::gpu::GpuDriver::VmemSpan vmem,
+                         size_t granularity);
+
+  stream_executor::gpu::GpuContext& gpu_context_;
+  PlatformGpuId gpu_id_;
+
+  // Peer access is configured at mmap time, so the allocator must be aware of
+  // all GPUs that may want to read the memory. This list also includes the
+  // above gpu_id_ to simplify calls to GpuDriver::MapMemory.
+  const std::vector<int> access_gpu_ordinals_;
+
+  // The virtual memory span held by this allocator.
+  stream_executor::gpu::GpuDriver::VmemSpan vmem_;
+  // The next offset from the vmem base address that will be allocated. This
+  // corresponds to the size of physically pinned memory if holes haven't been
+  // created with "free".
+  size_t next_alloc_offset_ = 0;
+
+  // Smallest allocation as determined by CUDA.
+  const size_t granularity_;
+
+  struct Mapping {
+    stream_executor::gpu::GpuDevicePtr va;
+    stream_executor::gpu::GpuDriver::GenericMemoryHandle physical;
+  };
+  // List of mappings, sorted by va.
+  std::vector<Mapping> mappings_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GpuVirtualMemAllocator);
+};
+
+}  // namespace tensorflow
+
+#endif  // CUDA_VERSION >= 10020
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VMEM_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator_test.cc
new file mode 100644
index 0000000..7a71c55
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator_test.cc
@@ -0,0 +1,185 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h"
+
+#if CUDA_VERSION >= 10020
+
+#include "tensorflow/core/common_runtime/device/device_id_utils.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using ::stream_executor::gpu::GpuContext;
+using ::stream_executor::gpu::GpuDevicePtr;
+using ::stream_executor::gpu::GpuDriver;
+
+// Empirically the min allocation granularity.
+constexpr size_t k2MiB{2 << 20};
+
+// Creates an allocator with 8 MiB of virtual address space.
+std::unique_ptr<GpuVirtualMemAllocator> CreateAllocator() {
+  PlatformGpuId gpu_id(0);
+  auto executor =
+      DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(), gpu_id)
+          .ValueOrDie();
+  GpuContext* gpu_context = reinterpret_cast<GpuContext*>(
+      executor->implementation()->GpuContextHack());
+  return GpuVirtualMemAllocator::Create(
+             {}, {}, *gpu_context, gpu_id,
+             /*virtual_address_space_size=*/4 * k2MiB, {})
+      .ValueOrDie();
+}
+
+TEST(GpuVirtualMemAllocatorTest, SimpleAlloc) {
+  PlatformGpuId gpu_id(0);
+  auto executor =
+      DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(), gpu_id)
+          .ValueOrDie();
+  GpuContext* gpu_context = reinterpret_cast<GpuContext*>(
+      executor->implementation()->GpuContextHack());
+  auto allocator = GpuVirtualMemAllocator::Create(
+                       {}, {}, *gpu_context, gpu_id,
+                       /*virtual_address_space_size=*/4 * k2MiB, {})
+                       .ValueOrDie();
+  size_t bytes_received;  // Ignored in this test.
+  void* gpu_block =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(gpu_block, nullptr);
+
+  constexpr size_t kBufSize{256};
+  void* host_mem[2] = {GpuDriver::HostAllocate(gpu_context, kBufSize),
+                       GpuDriver::HostAllocate(gpu_context, kBufSize)};
+  std::memset(host_mem[0], 'z', kBufSize);
+  std::memset(host_mem[1], 0, kBufSize);
+
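+  // Copy into an offset within the mapped block (rather than its base) so the
+  // round trip exercises an interior address of the mapping.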
+  GpuDevicePtr gpu_buf = reinterpret_cast<GpuDevicePtr>(gpu_block) + 2048;
+  ASSERT_TRUE(GpuDriver::SynchronousMemcpyH2D(gpu_context, gpu_buf, host_mem[0],
+                                              kBufSize)
+                  .ok());
+  ASSERT_TRUE(GpuDriver::SynchronousMemcpyD2H(gpu_context, host_mem[1], gpu_buf,
+                                              kBufSize)
+                  .ok());
+  for (int i = 0; i < kBufSize; ++i) {
+    ASSERT_EQ('z', reinterpret_cast<const char*>(host_mem[1])[i]);
+  }
+}
+
+TEST(GpuVirtualMemAllocatorTest, AllocPaddedUp) {
+  auto allocator = CreateAllocator();
+  size_t bytes_received;
+  void* gpu_block =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/256, &bytes_received);
+  ASSERT_NE(gpu_block, nullptr);
+  ASSERT_EQ(bytes_received, k2MiB);
+}
+
+TEST(GpuVirtualMemAllocatorTest, AllocsContiguous) {
+  auto allocator = CreateAllocator();
+  size_t bytes_received;  // Ignored in this test.
+  void* first_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(first_alloc, nullptr);
+  void* second_alloc = allocator->Alloc(
+      /*alignment=*/0, /*num_bytes=*/2 * k2MiB, &bytes_received);
+  ASSERT_NE(second_alloc, nullptr);
+
+  ASSERT_EQ(second_alloc, reinterpret_cast<const char*>(first_alloc) + k2MiB);
+
+  void* third_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(third_alloc, nullptr);
+
+  ASSERT_EQ(third_alloc,
+            reinterpret_cast<const char*>(second_alloc) + 2 * k2MiB);
+}
+
+TEST(GpuVirtualMemAllocator, OverAllocate) {
+  auto allocator = CreateAllocator();
+  size_t bytes_received;  // Ignored in this test.
+  void* first_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(first_alloc, nullptr);
+  void* over_alloc = allocator->Alloc(/*alignment=*/0, /*num_bytes=*/4 * k2MiB,
+                                      &bytes_received);
+  ASSERT_EQ(over_alloc, nullptr);
+}
+
+TEST(GpuVirtualMemAllocatorTest, FreeAtEnd) {
+  auto allocator = CreateAllocator();
+  size_t bytes_received;  // Ignored in this test.
+  void* first_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(first_alloc, nullptr);
+  void* second_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(second_alloc, nullptr);
+
+  allocator->Free(second_alloc, k2MiB);
+
+  void* re_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_EQ(re_alloc, second_alloc);
+}
+
+TEST(GpuVirtualMemAllocatorTest, FreeHole) {
+  auto allocator = CreateAllocator();
+  size_t bytes_received;  // Ignored in this test.
+  void* first_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(first_alloc, nullptr);
+  void* second_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(second_alloc, nullptr);
+
+  allocator->Free(first_alloc, k2MiB);
+
+  void* third_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(third_alloc, nullptr);
+
+  // Expect that allocation still happens at the end.
+  ASSERT_EQ(third_alloc, reinterpret_cast<const char*>(second_alloc) + k2MiB);
+}
+
+TEST(GpuVirtualMemAllocatorTest, FreeRange) {
+  auto allocator = CreateAllocator();
+  size_t bytes_received;  // Ignored in this test.
+  void* first_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(first_alloc, nullptr);
+  void* second_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(second_alloc, nullptr);
+  void* third_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(third_alloc, nullptr);
+
+  allocator->Free(first_alloc, 3 * k2MiB);
+
+  void* re_alloc =
+      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
+  ASSERT_NE(re_alloc, nullptr);
+  ASSERT_EQ(re_alloc, first_alloc);
+}
+
+}  // namespace
+}  // namespace tensorflow
+
+#endif
diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc
index 92b0768..639739e 100644
--- a/tensorflow/core/common_runtime/graph_constructor.cc
+++ b/tensorflow/core/common_runtime/graph_constructor.cc
@@ -44,6 +44,7 @@
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/strings/scanner.h"
 #include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/public/version.h"
@@ -1425,6 +1426,17 @@
 
 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
                                   int input_index) {
+  if (output_index >= src->num_outputs()) {
+    return errors::InvalidArgument(
+        "Output ", output_index, " of node ", src->name(),
+        " does not exist. Node only has ", src->num_outputs(), " outputs.");
+  }
+  if (input_index >= dst->num_inputs()) {
+    return errors::InvalidArgument(
+        "Input ", input_index, " of node ", dst->name(),
+        " does not exist. Node only has ", dst->num_inputs(), " inputs.");
+  }
+
   DataType src_out = src->output_type(output_index);
   DataType dst_in = dst->input_type(input_index);
   if (!TypesCompatible(dst_in, src_out)) {
diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc
index fc2b846..2ca6f74 100644
--- a/tensorflow/core/common_runtime/inline_function_utils.cc
+++ b/tensorflow/core/common_runtime/inline_function_utils.cc
@@ -231,17 +231,19 @@
     if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
       return ndef.device();
 
-    if (caller_parsed_device_.has_job) {
+    // Nodes with explicit device placements in the function body have those
+    // respected, but otherwise the function's placement provides a default.
+    if (caller_parsed_device_.has_job && !ndef_parsed_device.has_job) {
       ndef_parsed_device.has_job = caller_parsed_device_.has_job;
       ndef_parsed_device.job = caller_parsed_device_.job;
     }
 
-    if (caller_parsed_device_.has_replica) {
+    if (caller_parsed_device_.has_replica && !ndef_parsed_device.has_replica) {
       ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
       ndef_parsed_device.replica = caller_parsed_device_.replica;
     }
 
-    if (caller_parsed_device_.has_task) {
+    if (caller_parsed_device_.has_task && !ndef_parsed_device.has_task) {
       ndef_parsed_device.has_task = caller_parsed_device_.has_task;
       ndef_parsed_device.task = caller_parsed_device_.task;
     }
@@ -616,6 +618,7 @@
     Node* clone = g->AddNode(ndef, &added_node);
     TF_CHECK_OK(added_node);
     node_map[n->id()] = clone;
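+    // Preserve the original node's stack trace on its inlined clone.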
+    clone->SetStackTrace(n->GetStackTrace());
 
     // If there is an input control node, and one of:
     // a) the node has no data or control inputs, or
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index 2a0e5d3..ff010ad 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -148,22 +148,13 @@
 Status CondBuilder::CreatePivotNodes() {
   // Construct the basic cond body (consisting of feeding in the predicate to
   // create pivot nodes).
-
-  // This is a special pivot switch node for lowering. We mark this with a
-  // special _PivotSwitch attr on it as later on in the graph partitioner we
-  // do some special placement for Switch nodes and its necessary to distinguish
-  // between a "normal" Switch node and one of these pivot switches. We would
-  // like to place this node on the CPU always as the pred_ will be on the CPU
-  // as well (either a CPU op output or a GPU op with HostMemory annotation).
-  // TODO(b/171321391): Fix this for NUMA cases.
   Node* switch_pred;
   TF_RETURN_IF_ERROR(
       SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
                                            graph_->op_registry(), &debug_info_)
                                    .Input(NodeOut(pred_))
                                    .Input(NodeOut(pred_))
-                                   .Attr("_PivotSwitch", true)
-                                   .Device("/CPU:0"),
+                                   .Device(if_op_->requested_device()),
                                graph_, &switch_pred));
   control_predecessor_ = switch_pred;
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index b0304cf..cf7d354 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -147,115 +147,6 @@
   }
 }
 
-TEST(LowerIfOpTest, GPUPlacement) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-
-  // Add test functions for then and else branch.
-  FunctionDefLibrary f_lib_proto;
-  *(f_lib_proto.add_function()) = test::function::XTimesTwo();
-  *(f_lib_proto.add_function()) = test::function::XTimesFour();
-
-  // Construct simple conditional that switches on `pred` and operates only on
-  // single input `A`.
-  Scope root = Scope::NewRootScope().ExitOnError();
-  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
-  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
-  auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
-  auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
-  Node* pred;
-  TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
-                   .Input(x.node())
-                   .Input(y.node())
-                   .Device("/GPU:0")
-                   .Finalize(root.graph(), &pred));
-  Node* written_if;
-  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
-  TF_ASSERT_OK(
-      NodeBuilder("if", "If", &root.graph()->flib_def())
-          .Input(pred)
-          .Input(inputs)
-          .Attr("then_branch", FuncAttr("XTimesTwo"))
-          .Attr("else_branch", FuncAttr("XTimesFour"))
-          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
-          .Attr("Tout", {DT_INT32})
-          .Device("/GPU:0")
-          .Finalize(root.graph(), &written_if));
-  TF_ASSERT_OK(root.DoShapeInference(written_if));
-  TF_ASSERT_OK(root.ToGraph(graph.get()));
-
-  // The input graph has no switch or merge nodes.
-  int node_called_if_count = 0;
-  for (const auto* op : graph->op_nodes()) {
-    ASSERT_FALSE(op->IsSwitch());
-    ASSERT_FALSE(op->IsMerge());
-    if (op->name() == "if") {
-      ++node_called_if_count;
-    }
-  }
-  ASSERT_EQ(node_called_if_count, 1);
-
-  TF_ASSERT_OK(Rewrite(&graph));
-
-  // Verify the resultant graph has switch and merge nodes, and a node called
-  // `if` (but not If nodes).
-  int switch_count = 0;
-  int merge_count = 0;
-  node_called_if_count = 0;
-  for (const auto* op : graph->op_nodes()) {
-    if (op->IsSwitch()) {
-      ++switch_count;
-    }
-    if (op->IsMerge()) {
-      ++merge_count;
-    }
-    ASSERT_NE(op->type_string(), "If");
-    if (op->name() == "if") {
-      ++node_called_if_count;
-    }
-  }
-  // One switch for predicate and one for input (A).
-  ASSERT_EQ(switch_count, 2);
-  // One merge for the single output value of then and else, and one more merge
-  // to enforce then and else function call execution (`branch_executed` node).
-  ASSERT_EQ(merge_count, 2);
-  ASSERT_EQ(node_called_if_count, 1);
-
-  // Verify execution.
-  ClientSession session(root, SessionOptionsWithInlining());
-  {
-    RunMetadata metadata;
-    RunOptions options;
-    options.set_output_partition_graphs(true);
-    ClientSession::FeedType feeds;
-    feeds.emplace(Output(x.node()), Input::Initializer(5));
-    feeds.emplace(Output(y.node()), Input::Initializer(10));
-    feeds.emplace(Output(a.node()), Input::Initializer(10));
-    std::vector<Tensor> out_tensors;
-    TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
-                             &out_tensors, &metadata));
-    GraphDef cpu_graph = metadata.partition_graphs(1);
-    int num_cpu_switch = 0;
-    for (const auto& node : cpu_graph.node()) {
-      if (node.op() == "Switch") {
-        ++num_cpu_switch;
-      }
-    }
-    EXPECT_EQ(num_cpu_switch, 2);
-    EXPECT_EQ(out_tensors.size(), 1);
-    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
-  }
-  {
-    ClientSession::FeedType feeds;
-    feeds.emplace(Output(x.node()), Input::Initializer(10));
-    feeds.emplace(Output(y.node()), Input::Initializer(5));
-    feeds.emplace(Output(a.node()), Input::Initializer(10));
-    std::vector<Tensor> out_tensors;
-    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
-    EXPECT_EQ(out_tensors.size(), 1);
-    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
-  }
-}
-
 TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
   using ::tensorflow::test::function::GDef;
   using ::tensorflow::test::function::NDef;
diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc
index 6fb7526..6cdc970 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.cc
+++ b/tensorflow/core/common_runtime/partitioning_utils.cc
@@ -74,7 +74,7 @@
 }
 
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -84,7 +84,7 @@
 
   // Find the Arg and Retval nodes, along with their corresponding indices
   // in the original function.
-  for (Node* node : subgraph->op_nodes()) {
+  for (Node* node : graph->op_nodes()) {
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
@@ -124,31 +124,35 @@
     Node* arg = arg_nodes[i].first;
     arg->AddAttr("index", i);
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (arg_alloc_attrs != nullptr) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      arg_alloc_attrs->push_back(alloc_attr);
     }
-    arg_alloc_attrs->push_back(alloc_attr);
   }
   for (int i = 0; i < ret_nodes.size(); ++i) {
     Node* ret = ret_nodes[i].first;
     ret->AddAttr("index", i);
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (ret_alloc_attrs) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      ret_alloc_attrs->push_back(alloc_attr);
     }
-    ret_alloc_attrs->push_back(alloc_attr);
   }
 
   return Status::OK();
diff --git a/tensorflow/core/common_runtime/partitioning_utils.h b/tensorflow/core/common_runtime/partitioning_utils.h
index 1eb1742..32bc36b 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.h
+++ b/tensorflow/core/common_runtime/partitioning_utils.h
@@ -34,31 +34,34 @@
     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
 
-// Each subgraph produced by partitioning the function body contains a subset
-// of the original `Arg` and `Retval` nodes. This function performs
-// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
-// particular device / subgraph.
+// This function performs bookkeeping to track which `Arg` and `Retval` nodes
+// were placed on a particular device / graph.
 //
 // More specifically, this function
-//  (1) rewrites the indices of the `Arg` and `Retval` nodes placed
-//      on a particular device.  When a function is partitioned, each
-//      partition `subgraph` gets a subset of the arguments and
-//      return values. The `index` attributes of these _Arg and _Retval
-//      nodes reflect the indices of these parameters in the original
-//      function. To convert `subgraph` to a function, we need to replace
-//      there original indices with 0, 1, 2, ... .
 //
-//      The argument and return value order in the partitioned function is
-//      determined by the argument and return value order in the original
-//      function. This stability is important because it enables us to treat
-//      a single-partition function as having the same signature as the
-//      subgraph.
+//  (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
+//      consecutive.
+//
+//      These indices might not be consecutive after grappler's pruning
+//      optimization (e.g. removing redundant Args), or graph partitioning. In
+//      the latter case, the nodes in `graph` are placed on `device_type`, and
+//      each such graph partition gets a subset of the arguments and return
+//      values. The `index` attributes of these _Arg and _Retval nodes reflect
+//      the indices of these parameters in the original function. To convert
+//      `subgraph` to a function, we need to replace their original indices
+//      with 0, 1, 2, ... (see the example at the end of this comment).
+//
+//      The argument and return value order in `graph` is determined by the
+//      argument and return value order in the original function. This stability
+//      is important because it enables us to treat a single-partition function
+//      as having the same signature as the subgraph.
+//
 //  (2) records the subsets of `Arg` and `Retval` nodes assigned to the
 //      device in `*_indices`, and
 //  (3) records which `Arg` and `Retval` nodes live in host memory in
-//      `*_alloc_attrs`.
+//      `*_alloc_attrs`. If these vectors are null, this bookkeeping is
+//      skipped.
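+//
+//      For example (illustrative), if the `_Arg` nodes in `graph` arrive with
+//      original `index` attributes 0, 2, and 5, step (1) renumbers them to 0,
+//      1, and 2, while `arg_indices` records which of the original arguments
+//      ended up in this graph.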
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 50f3b52..60b8379 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -608,6 +608,8 @@
   return Status::OK();
 }
 
+}  // anonymous namespace
+
 Status GetGraphAndArgRets(
     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
@@ -644,8 +646,6 @@
   return Status::OK();
 }
 
-}  // anonymous namespace
-
 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     const string& function_name, AttrSlice attrs,
     const FunctionLibraryRuntime::InstantiateOptions& options,
@@ -1426,6 +1426,10 @@
                                       InternalArgs* comp_args) -> Status {
       // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
       for (const auto& it : comp_data.arg_indices) {
+        if (it.index >= args.size()) {
+          return errors::InvalidArgument(
+              "index ", it.index, " is out of range [0, ", args.size(), ")");
+        }
         if (it.sub_index >= 0) {
           const Tensor& t = args[it.index];
           if (t.dtype() != DT_RESOURCE) {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 3cac56d..832451f 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -327,7 +327,7 @@
       for (const auto& node_stats : dev_stats.hardware_stats->node_stats()) {
         string node_name = node_stats.node_name();
         // Remove the part of op name (e.g. :Conv2D) in the end of a node name.
-        size_t pos = node_name.find_first_of(":");
+        size_t pos = node_name.find_first_of(':');
         if (pos != std::string::npos) {
           node_name = node_name.substr(0, pos);
         }
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 19e7adf..c1cb858 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -504,6 +504,7 @@
         "//tensorflow/core/data:dataset_proto_cc",
         "//tensorflow/core/data:standalone",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         tf_grpc_cc_dependency(),
     ],
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index b7253e3..2f81d0d 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -191,8 +191,15 @@
     }
     auto it = tasks_.find(request->task_id());
     if (it == tasks_.end()) {
-      response->set_end_of_sequence(true);
-      return Status::OK();
+      if (finished_tasks_.contains(request->task_id())) {
+        VLOG(3) << "Task is already finished";
+        response->set_end_of_sequence(true);
+        return Status::OK();
+      } else {
+        // Perhaps the worker hasn't gotten the task from the dispatcher yet.
+        // Return Unavailable so that the client knows to continue retrying.
+        return errors::Unavailable("Task ", request->task_id(), " not found");
+      }
     }
     auto& task = it->second;
     TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
@@ -362,6 +369,7 @@
     VLOG(3) << "Deleting task " << task_id
             << " at the request of the dispatcher";
     tasks_.erase(task_id);
+    finished_tasks_.insert(task_id);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 47b883d..80eb5b7 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -16,6 +16,7 @@
 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/data_service.h"
 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
@@ -85,6 +86,8 @@
   mutex mu_;
   // Information about tasks, keyed by task ids.
   absl::flat_hash_map<int64, std::unique_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
+  // Ids of tasks that have finished.
+  absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_);
   // Completed tasks which haven't yet been communicated to the dispatcher.
   absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index c4e99ad..52018b3 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -378,9 +378,9 @@
       // Determine the path (if any) in the grpc:// URL, and add it as a field
       // of the JSON string.
       const string address = url.substr(strlen(DebugIO::kFileURLScheme));
-      const string path = address.find("/") == string::npos
+      const string path = address.find('/') == string::npos
                               ? ""
-                              : address.substr(address.find("/"));
+                              : address.substr(address.find('/'));
       grpc_event.set_wall_time(event.wall_time());
       LogMessage* log_message_grpc = grpc_event.mutable_log_message();
       log_message_grpc->set_message(
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index f673d2c..1edecb5 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -333,8 +333,13 @@
       const bool serialize_resource_dtype_and_shape =
           (i == 0) && (h->dtype == DT_RESOURCE) &&
           (!ctx->OnSameTask(src_device, target_device));
+      // For a remote component function, a function execution request and an
+      // input generation request may come from different workers. We need to
+      // guarantee that the input generation request is processed before the
+      // function execution request, so wait until the underlying remote handles
+      // are ready before sending a packed handle to the function device.
       TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-          h, /*wait_until_ready=*/false,
+          h, /*wait_until_ready=*/true,
           op->add_handles()->mutable_remote_handle(), src_device,
           absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
           serialize_resource_dtype_and_shape));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index 985b045..26d1ac9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -53,7 +53,7 @@
   uint32 port;
   auto colon_index = host_port.find_last_of(':');
   if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
-      host_port.substr(0, colon_index).find("/") != string::npos) {
+      host_port.substr(0, colon_index).find('/') != string::npos) {
     return errors::InvalidArgument("Could not interpret \"", host_port,
                                    "\" as a host-port pair.");
   }
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index dd675d7..eec4af3 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -67,7 +67,6 @@
         "model.h",
         "node_def_builder.h",
         "numeric_op.h",
-        "numeric_op_base.h",
         "op_kernel.h",
         "op_requires.h",
         "op_segment.h",
@@ -203,7 +202,6 @@
         "node_def_util.h",
         "node_properties.h",
         "numeric_op.h",
-        "numeric_op_base.h",
         "numeric_types.h",
         "op.h",
         "op_def_builder.h",
@@ -304,7 +302,6 @@
         "kernel_shape_util.h",
         "log_memory.cc",
         "log_memory.h",
-        "numeric_op_base.h",
         "numeric_types.h",
         "op_requires.h",
         "ops_util.cc",
@@ -1248,6 +1245,7 @@
     visibility = [
         "//tensorflow/core:__pkg__",
         "//tensorflow/python:__pkg__",
+        "//tensorflow/python/util:__pkg__",
     ],
 )
 
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 0e23c82..fd3004a 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -1061,6 +1061,12 @@
     }
   }
 
+  // Returns whether work is currently being recorded, i.e. whether we are
+  // currently between a `RecordStart` and a `RecordStop`.
+  bool IsRecording(IteratorContext* ctx) {
+    return collect_resource_usage(ctx) && node_->is_recording();
+  }
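+
+  // A typical (illustrative) use is to pause recording around a blocking wait:
+  //   bool was_recording = IsRecording(ctx);
+  //   if (was_recording) RecordStop(ctx);
+  //   ... wait ...
+  //   if (was_recording) RecordStart(ctx);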
+
  private:
   bool collect_resource_usage(IteratorContext* ctx) {
     auto model = ctx->model();
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 50d2b96..b37a9ca 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1174,13 +1174,13 @@
 
 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
     FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
-                                 const Graph* graph_with_debug_info)
+                                 const StackTracesMap& stack_traces)
     : fdef(fdef_in),
       // Exact shape inference for functions is handled by ShapeRefiner.
       // Here we pass a dummy shape inference function for legacy code paths.
       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                            true /* is_function */),
-      graph_with_debug_info(graph_with_debug_info) {}
+      stack_traces(stack_traces) {}
 
 FunctionLibraryDefinition::FunctionLibraryDefinition(
     const FunctionLibraryDefinition& other)
@@ -1233,14 +1233,14 @@
 }
 
 Status FunctionLibraryDefinition::AddFunctionDef(
-    const FunctionDef& fdef, const Graph* graph_with_debug_info) {
+    const FunctionDef& fdef, const StackTracesMap& stack_traces) {
   mutex_lock l(mu_);
   bool added;
-  return AddFunctionDefHelper(fdef, graph_with_debug_info, &added);
+  return AddFunctionDefHelper(fdef, stack_traces, &added);
 }
 
 Status FunctionLibraryDefinition::AddFunctionDefHelper(
-    const FunctionDef& fdef, const Graph* graph_with_debug_info, bool* added) {
+    const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
   *added = false;
   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
       function_defs_[fdef.signature().name()];
@@ -1260,8 +1260,7 @@
         "Cannot add function '", fdef.signature().name(),
         "' because an op with the same name already exists.");
   }
-  entry = std::make_shared<FunctionDefAndOpRegistration>(fdef,
-                                                         graph_with_debug_info);
+  entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
   *added = true;
   return Status::OK();
 }
@@ -1403,7 +1402,7 @@
   Status s;
   bool added;
   for (const FunctionDef& fdef : lib_def.function()) {
-    s = AddFunctionDefHelper(fdef, /*graph_with_debug_info=*/nullptr, &added);
+    s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
     if (!s.ok()) {
       Remove(funcs, funcs_with_grads);
       return s;
@@ -1430,8 +1429,7 @@
   mutex_lock l(mu_);
   bool added;
   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
-  TF_RETURN_IF_ERROR(
-      AddFunctionDefHelper(fdef, /*graph_with_debug_info=*/nullptr, &added));
+  TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added));
   return Status::OK();
 }
 
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 3951caa..dca235b 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -339,15 +339,26 @@
 
     // Drop the common largest prefix of all filenames in stack frames.
     bool filter_common_prefix = false;
+
+    // Do not show internal frames.
+    bool drop_internal_frames = false;
   };
 
   virtual ~AbstractStackTrace() {}
 
   // The returned span is alive as long as the AbstractStackTrace is alive.
   virtual absl::Span<StackFrame const> ToFrames() const = 0;
+
+  // Returns the last stack frame from user code, attempting to ignore the
+  // framework code. Returns an empty frame if no such stack frame was found.
+  virtual StackFrame LastUserFrame() const = 0;
   virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
 };
 
+using StackTracesMap =
+    std::unordered_map<std::string,
+                       std::shared_ptr<tensorflow::AbstractStackTrace>>;
+
 // Helper to maintain a map between function names in a given
 // FunctionDefLibrary and function definitions.
 //
@@ -397,7 +408,7 @@
   // Associates `graph` with a function `func_name`. Lifetime assumption:
   // `graph` has to outlive all instantiated graphs.
   Status AddFunctionDef(const FunctionDef& fdef,
-                        const Graph* graph_with_debug_info = nullptr)
+                        const StackTracesMap& stack_traces = {})
       TF_LOCKS_EXCLUDED(mu_);
 
   // Adds gradient definition 'grad' to this function library.
@@ -509,10 +520,14 @@
 
   // Returns graph with debug stack traces for the given function, or `nullptr`
   // if none found.
-  const Graph* GetGraphWithDebugInfo(const std::string& func_name) const {
+  const StackTracesMap& GetStackTraces(const std::string& func_name) const {
     tf_shared_lock l(mu_);
     std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
-    return entry ? entry->graph_with_debug_info : nullptr;
+    if (entry) {
+      return entry->stack_traces;
+    }
+    static const auto* empty_map = new StackTracesMap;
+    return *empty_map;
   }
 
  private:
@@ -520,12 +535,11 @@
 
   struct FunctionDefAndOpRegistration {
     explicit FunctionDefAndOpRegistration(
-        const FunctionDef& fdef_in,
-        const Graph* graph_with_debug_info = nullptr);
+        const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});
 
     const FunctionDef fdef;
     const OpRegistrationData op_registration_data;
-    const Graph* graph_with_debug_info;
+    const StackTracesMap stack_traces;
   };
 
   std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
@@ -539,7 +553,7 @@
   // Same as AddFunctionDef/AddGradientDef except these methods set
   // `added` to true if the `fdef`/`grad` were actually added to this.
   Status AddFunctionDefHelper(const FunctionDef& fdef,
-                              const Graph* graph_with_debug_info, bool* added)
+                              const StackTracesMap& stack_traces, bool* added)
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
   Status AddGradientDefHelper(const GradientDef& grad, bool* added)
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
diff --git a/tensorflow/core/framework/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h
index 834bf50..83e56ca 100644
--- a/tensorflow/core/framework/graph_to_functiondef.h
+++ b/tensorflow/core/framework/graph_to_functiondef.h
@@ -60,6 +60,13 @@
                           const std::vector<std::string>& output_names,
                           FunctionDef* fdef);
 
+Status GetGraphAndArgRets(
+    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
+    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
+    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
+    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
+    std::vector<string>* control_ret_node_names);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
diff --git a/tensorflow/core/framework/log_memory.cc b/tensorflow/core/framework/log_memory.cc
index ecdc3c4..7198376 100644
--- a/tensorflow/core/framework/log_memory.cc
+++ b/tensorflow/core/framework/log_memory.cc
@@ -29,7 +29,7 @@
 template <typename T>
 void OutputToLog(const T& proto) {
   string type_name = proto.GetTypeName();
-  const size_t index = type_name.find_last_of(".");
+  const size_t index = type_name.find_last_of('.');
   if (index != string::npos) type_name = type_name.substr(index + 1);
   LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { "
             << proto.ShortDebugString() << " }";
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index cc98528..4014c40 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -37,6 +37,88 @@
 // Wrapper for the square function to reduce verbosity.
 inline double Square(double x) { return x * x; }
 
+// Collects "essential" parallelism parameters and buffer size parameters in the
+// tree rooted in the given node. Which parallelism parameters are essential is
+// determined by the relative processing time spent in the corresponding
+// transformation. The collected parameters are returned via maps that map node
+// names to their respective parameters.
+inline void CollectParameters(
+    std::shared_ptr<Node> node,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
+    absl::flat_hash_map<string, std::shared_ptr<Parameter>>*
+        parallelism_parameters,
+    absl::flat_hash_map<string, std::shared_ptr<Parameter>>*
+        buffer_size_parameters) {
+  // A parallelism parameter is considered essential if the corresponding
+  // transformation's processing time is greater than the essential rate times
+  // the average transformation self-processing time.
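+  // For example (illustrative numbers): with four transformations whose self
+  // processing times sum to 100ms, the uniform share is 25ms, so a parallelism
+  // parameter is essential only if its transformation accounts for more than
+  // 0.3 * 25ms = 7.5ms.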
+  constexpr double kEssentialRate = 0.3L;
+
+  absl::flat_hash_map<string, double> processing_times;
+  double processing_time = node->TotalProcessingTime(&processing_times);
+  double uniform_share =
+      processing_time / static_cast<double>(processing_times.size());
+  for (auto& pair : parameters) {
+    if (pair.second->name == kParallelism &&
+        processing_times[pair.first] > kEssentialRate * uniform_share) {
+      parallelism_parameters->insert(pair);
+    } else if (pair.second->name == kBufferSize) {
+      buffer_size_parameters->insert(pair);
+    }
+  }
+}
+
+// Applies one step of gradient descent and updates the parameter values. If a
+// new value falls outside a parameter's feasible range, it is clamped to that
+// parameter's minimum or maximum value.
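+// Concretely, each parameter is updated as
+//   value -= kDescentStep * gradient / max_abs_derivative
+// before the clamping described above.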
+inline void UpdateParameterValues(
+    const absl::flat_hash_map<string, double>& gradients,
+    absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
+  // Gradient descent step size.
+  constexpr double kDescentStep = 0.1L;
+  double new_value;
+
+  double max_abs_derivative = 1.0;
+  for (auto& pair : *parameters) {
+    if (std::round(pair.second->value) != pair.second->max) {
+      auto* gradient = gtl::FindOrNull(gradients, pair.first);
+      if (gradient) {
+        max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
+      }
+    }
+  }
+  for (auto& pair : *parameters) {
+    auto* gradient = gtl::FindOrNull(gradients, pair.first);
+    if (gradient) {
+      new_value =
+          pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
+      // Projection on a feasible interval.
+      if (new_value > pair.second->max) {
+        pair.second->value = pair.second->max;
+      } else if (new_value < pair.second->min) {
+        pair.second->value = pair.second->min;
+      } else {
+        pair.second->value = new_value;
+      }
+    }
+  }
+}
+
+// Copies the tuned parameter values into the shared state values that the
+// input pipeline observes, and notifies any threads waiting on them.
+inline void UpdateStateValues(
+    absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
+  VLOG(2) << "Number of tunable parameters: " << parameters->size();
+  for (auto& pair : *parameters) {
+    auto& parameter = pair.second;
+    VLOG(2) << "Setting tunable parameter " << pair.first << " to "
+            << parameter->value;
+    mutex_lock l(*parameter->state->mu);
+    parameter->state->value = parameter->value;
+    parameter->state->cond_var->notify_all();
+  }
+}
+
 // The first input of InterleaveMany corresponds to the input dataset whose
 // elements are used to create the (derived) input datasets whose elements are
 // interleaved as output.
@@ -1406,27 +1488,37 @@
   return parameters;
 }
 
-absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-Model::CollectEssentialParallelism(
-    std::shared_ptr<Node> node,
-    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters) {
-  // Parallelism parameter is considered to be essential if the corresponding
-  // transformations's processing time is greater than essential rate times the
-  // average transformation self processing time.
-  constexpr double kEssentialRate = 0.3L;
+bool Model::ShouldStop(
+    int64 cpu_budget, int64 ram_budget,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+        parallelism_parameters,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+        buffer_size_parameters,
+    std::shared_ptr<Node> snapshot, bool* cpu_budget_reached) {
+  if (!(*cpu_budget_reached)) {
+    // If those essential transformations' parallelism reaches the CPU
+    // budget, we will only tune the buffer size parameters in future
+    // iterations.
+    int64 model_parallelism = 0;
+    for (auto& pair : parallelism_parameters) {
+      model_parallelism += std::round(pair.second->value);
+    }
+    *cpu_budget_reached = (model_parallelism > cpu_budget);
+  }
 
-  absl::flat_hash_map<string, double> processing_times;
-  double processing_time = node->TotalProcessingTime(&processing_times);
-  double uniform_share =
-      processing_time / static_cast<double>(processing_times.size());
-  absl::flat_hash_map<string, std::shared_ptr<Parameter>> essential_parameters;
-  for (auto& pair : parameters) {
-    if (pair.second->name == kParallelism &&
-        processing_times[pair.first] > kEssentialRate * uniform_share) {
-      essential_parameters.insert(pair);
+  bool all_max = true;
+  for (auto& pair :
+       (*cpu_budget_reached ? buffer_size_parameters : parameters)) {
+    if (std::round(pair.second->value) < pair.second->max) {
+      all_max = false;
+      break;
     }
   }
-  return essential_parameters;
+
+  // If all parameters have reached their maximum values or RAM budget is
+  // reached, we stop the iterations.
+  return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
 }
 
 void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget,
@@ -1438,12 +1530,16 @@
   }
   VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
   auto parameters = CollectTunableParameters(snapshot);
-  auto essential_parameters = CollectEssentialParallelism(snapshot, parameters);
+  // The maps of "essential" parallelism parameters and buffer size parameters.
+  absl::flat_hash_map<string, std::shared_ptr<Parameter>>
+      parallelism_parameters, buffer_size_parameters;
+  CollectParameters(snapshot, parameters, &parallelism_parameters,
+                    &buffer_size_parameters);
+
+  // Initialize the parameter values to their minimums before tuning.
   for (auto& pair : parameters) {
     pair.second->value = pair.second->min;
   }
-  // Gradient descent step size.
-  constexpr double kDescentStep = 0.1L;
 
   // Optimization is stopped once the `OutputTime` improvement is smaller than
   // this value.
@@ -1454,53 +1550,33 @@
 
   double output_time = 0;
   double new_output_time;
-  double new_value;
-  for (int i = 0; i < kMaxIterations; ++i) {
+
+  // When the CPU budget is reached, the parallelism parameter values are fixed
+  // and we only increase the buffer size parameters.
+  bool cpu_budget_reached = false;
+
+  for (int i = 0;
+       i < kMaxIterations &&
+       !ShouldStop(cpu_budget, ram_budget, parameters, parallelism_parameters,
+                   buffer_size_parameters, snapshot, &cpu_budget_reached);
+       ++i) {
     absl::flat_hash_map<string, double> gradients;
     new_output_time = OutputTime(snapshot, model_input_time, &gradients);
-    int64 model_parallelism = 0;
-    for (auto& pair : essential_parameters) {
-      model_parallelism += std::round(pair.second->value);
-    }
-    // We terminate once the improvement of the output latency is too small or
-    // the essential transformations' parallelism reaches the CPU budget or the
-    // worst-case total buffer size exceeds the memory budget.
-    if (std::abs(output_time - new_output_time) < kOptimizationPrecision ||
-        model_parallelism > cpu_budget ||
-        TotalMaximumBufferedBytes(snapshot) > ram_budget) {
+    // We also terminate once the improvement of the output latency is too
+    // small.
+    if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
       break;
     }
-    double max_abs_derivative = 1.0;
-    for (auto& pair : parameters) {
-      if (pair.second->value != pair.second->max) {
-        max_abs_derivative =
-            std::max(max_abs_derivative, std::abs(gradients[pair.first]));
-      }
-    }
-    for (auto& pair : parameters) {
-      new_value = pair.second->value -
-                  kDescentStep * gradients[pair.first] / max_abs_derivative;
-      // Projection on a feasible interval.
-      if (new_value > pair.second->max) {
-        pair.second->value = pair.second->max;
-      } else if (new_value < pair.second->min) {
-        pair.second->value = pair.second->min;
-      } else {
-        pair.second->value = new_value;
-      }
-    }
+
+    UpdateParameterValues(
+        gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
     output_time = new_output_time;
   }
-  VLOG(2) << "Number of tunable parameters: " << parameters.size();
+
   for (auto& pair : parameters) {
     pair.second->value = std::round(pair.second->value);
-    auto& parameter = pair.second;
-    VLOG(2) << "Setting tunable parameter " << pair.first << " to "
-            << parameter->value;
-    mutex_lock l(*parameter->state->mu);
-    parameter->state->value = parameter->value;
-    parameter->state->cond_var->notify_all();
   }
+  UpdateStateValues(&parameters);
 }
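The projected descent step that was previously inlined in this loop now lives in `UpdateParameterValues`, whose body is not part of this hunk. A minimal sketch reconstructed from the removed code above (an assumption, not the authoritative implementation):

// Sketch only: normalize gradients by the largest absolute derivative among
// non-saturated parameters, take a fixed-size step (kDescentStep = 0.1), and
// project the result back onto the feasible interval [min, max].
void UpdateParameterValuesSketch(
    const absl::flat_hash_map<string, double>& gradients,
    absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
  constexpr double kDescentStep = 0.1;
  double max_abs_derivative = 1.0;
  for (auto& pair : *parameters) {
    if (pair.second->value != pair.second->max) {
      auto it = gradients.find(pair.first);
      if (it != gradients.end()) {
        max_abs_derivative = std::max(max_abs_derivative, std::abs(it->second));
      }
    }
  }
  for (auto& pair : *parameters) {
    auto it = gradients.find(pair.first);
    const double gradient = (it == gradients.end()) ? 0.0 : it->second;
    const double new_value =
        pair.second->value - kDescentStep * gradient / max_abs_derivative;
    // Projection onto the feasible interval.
    pair.second->value =
        std::min(pair.second->max, std::max(pair.second->min, new_value));
  }
}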
 
 void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget,
@@ -1517,6 +1593,7 @@
   // improvement is greater than this constant.
   constexpr double kBufferSizeMinDelta = 1.0L;
 
+  // Initialize the parameter values to their minimums before tuning.
   for (auto& pair : parameters) {
     pair.second->value = pair.second->min;
   }
@@ -1560,15 +1637,7 @@
     }
     best_parameter->value++;
   }
-  VLOG(2) << "Number of tunable parameters: " << parameters.size();
-  for (auto& pair : parameters) {
-    auto& parameter = pair.second;
-    VLOG(2) << "Setting tunable parameter " << pair.first << " to "
-            << parameter->value;
-    mutex_lock l(*parameter->state->mu);
-    parameter->state->value = parameter->value;
-    parameter->state->cond_var->notify_all();
-  }
+  UpdateStateValues(&parameters);
 }
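Similarly, `UpdateStateValues` replaces the two removed publish loops; its body is not shown in this hunk. A sketch reconstructed from the removed lines (again an assumption, not the actual helper):

// Sketch only: publish the tuned values to the shared parameter state, as the
// removed loops did (log, lock, set, notify waiters).
void UpdateStateValuesSketch(
    absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
  VLOG(2) << "Number of tunable parameters: " << parameters->size();
  for (auto& pair : *parameters) {
    auto& parameter = pair.second;
    VLOG(2) << "Setting tunable parameter " << pair.first << " to "
            << parameter->value;
    mutex_lock l(*parameter->state->mu);
    parameter->state->value = parameter->value;
    parameter->state->cond_var->notify_all();
  }
}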
 
 double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 53365d2..2645d7b 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -296,6 +296,10 @@
     }
   }
 
+  // Returns whether work is currently being recorded, i.e. whether we are
+  // currently between a `record_start` and a `record_stop`.
+  bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; }
+
   // Removes an input.
   void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) {
     mutex_lock l(mu_);
@@ -644,16 +648,17 @@
   absl::flat_hash_map<string, std::shared_ptr<Parameter>>
   CollectTunableParameters(std::shared_ptr<Node> node);
 
-  // Collects "essential" parallelism parameters of transformations in the tree
-  // rooted in the given node. Which parameters are essential is determined by
-  // comparison the processing time spent in the corresponding transformation
-  // relative to other transformations. The collected parameters are returned
-  // as a mapping from a (unique) node name to a parallelism parameter.
-  absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-  CollectEssentialParallelism(
-      std::shared_ptr<Node> node,
+  // Determines if we should stop the gradient descent optimization iterations
+  // based on the number of increasable parameters, the CPU budget, the RAM
+  // budget, and current resource usage.
+  bool ShouldStop(
+      int64 cpu_budget, int64 ram_budget,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
       const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
-          parameters);
+          parallelism_parameters,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+          buffer_size_parameters,
+      std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
 
   // This optimization algorithm starts by setting all tunable parallelism
   // parameters to the minimum value. It then repeatedly identifies the
diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc
index 97eb720..d210e81 100644
--- a/tensorflow/core/framework/model_test.cc
+++ b/tensorflow/core/framework/model_test.cc
@@ -949,6 +949,15 @@
 INSTANTIATE_TEST_SUITE_P(Test, OptimizeZeroRamBudgetTest,
                          ::testing::Values(0, 1));
 
+TEST(RecordTimeTest, RecordTimeTest) {
+  std::shared_ptr<Node> source = model::MakeSourceNode({});
+  EXPECT_FALSE(source->is_recording());
+  source->record_start(100);
+  EXPECT_TRUE(source->is_recording());
+  source->record_stop(200);
+  EXPECT_FALSE(source->is_recording());
+}
+
 }  // namespace
 }  // namespace model
 }  // namespace data
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
index 9f8ceed..0167e21 100644
--- a/tensorflow/core/framework/numeric_op.h
+++ b/tensorflow/core/framework/numeric_op.h
@@ -12,22 +12,38 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+
 #ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
 #define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
 
-#include "tensorflow/core/framework/numeric_op_base.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
+// One input and one output, both the same type.
 template <class T>
-using UnaryOp = UnaryOpBase<T, OpKernel, OpKernelConstruction>;
+class UnaryOp : public OpKernel {
+ public:
+  explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) {
+    const DataType dt = DataTypeToEnum<T>::v();
+    OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt}));
+  }
+};
 
+// Two inputs and one output, all the same type.
 template <class T>
-using BinaryOp = BinaryOpBase<T, OpKernel, OpKernelConstruction>;
+class BinaryOp : public OpKernel {
+ public:
+  explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
+    const DataType dt = DataTypeToEnum<T>::v();
+    OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt}));
+  }
+};
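For illustration, a hypothetical kernel built on the de-templated UnaryOp<T> (this kernel is not part of the change; it only shows that the base constructor already enforces the one-input/one-output signature via MatchSignature, using the headers this file already includes):

// Hypothetical example only.
template <typename T>
class ExampleSquareOp : public UnaryOp<T> {
 public:
  explicit ExampleSquareOp(OpKernelConstruction* context) : UnaryOp<T>(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto in = input.flat<T>();
    auto out = output->flat<T>();
    for (int64 i = 0; i < in.size(); ++i) {
      out(i) = in(i) * in(i);  // Square each element.
    }
  }
};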
 
 // For operations where the input and output are the same shape.
 //
diff --git a/tensorflow/core/framework/numeric_op_base.h b/tensorflow/core/framework/numeric_op_base.h
deleted file mode 100644
index be7d3bf..0000000
--- a/tensorflow/core/framework/numeric_op_base.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_BASE_H_
-#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_BASE_H_
-
-#include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-
-namespace tensorflow {
-
-// One input and one output, both the same type.
-template <class T, class OpKernelT, class OpKernelConstructionT>
-class UnaryOpBase : public OpKernelT {
- public:
-  explicit UnaryOpBase(OpKernelConstructionT* construction) :
-      OpKernelT(construction) {
-    const DataType dt = DataTypeToEnum<T>::v();
-    OP_REQUIRES_OK(construction, construction->MatchSignature({dt}, {dt}));
-  }
-};
-
-// Two inputs and one output, all the same type.
-template <class T, class OpKernelT, class OpKernelConstructionT>
-class BinaryOpBase : public OpKernelT {
- public:
-  explicit BinaryOpBase(OpKernelConstructionT* construction) :
-      OpKernelT(construction) {
-    const DataType dt = DataTypeToEnum<T>::v();
-    OP_REQUIRES_OK(construction, construction->MatchSignature({dt, dt}, {dt}));
-  }
-};
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_BASE_H_
diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto
index ad109a3..756c8e4 100644
--- a/tensorflow/core/framework/op_def.proto
+++ b/tensorflow/core/framework/op_def.proto
@@ -8,6 +8,7 @@
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto";
 import "tensorflow/core/framework/attr_value.proto";
 import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/resource_handle.proto";
 
 // Defines an operation. A NodeDef in a GraphDef specifies an Op by
 // using the "op" field which should match the name of a OpDef.
@@ -42,6 +43,9 @@
     // type, type_attr, and number_attr may be specified.
     string type_list_attr = 6;
 
+    // The handle data for resource inputs.
+    repeated ResourceHandleProto.DtypeAndShape handle_data = 7;
+
     // For inputs: if true, the inputs are required to be refs.
     //   By default, inputs can be either refs or non-refs.
     // For outputs: if true, outputs are refs, otherwise they are not.
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 93f4eaf..2e7e7fb 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -454,6 +454,7 @@
     copy->MaybeCopyOnWrite();
     copy->props_->op_def = op_def;
   }
+  copy->SetStackTrace(node->GetStackTrace());
 
   return copy;
 }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 5dc9a55..cc5c5b2 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -74,7 +74,7 @@
 enum class ConstructionContext {
   kNotTracked,     // Not tracked.
   kDirectSession,  // From `tensorflow::DirectSession`, TF1 session API.
-  kFunctionDef,    // From `FunctionDef`, @tf.function.
+  kEagerRuntime,   // Registered from TF2 eager runtime.
 };
 
 class Node {
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 7680bca..bf57e26 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -371,13 +371,6 @@
 void OptimizeControlFlowColocation(Graph* graph) {
   auto visit = [](Node* node) {
     if (IsSwitch(node)) {
-      // Pivot Switch nodes (which are also of type Switch) are already placed
-      // on the CPU and colocated with its inputs that are also already on the
-      // CPU (or might be placed on GPU but in host memory).
-      if (HasNodeAttr(node->def(), "_PivotSwitch")) {
-        DCHECK(node->requested_device().find("CPU") != string::npos);
-        return;
-      }
       for (const Edge* in_edge : node->in_edges()) {
         if (in_edge->dst_input() == 0) {
           // Colocate with the data input.
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 8bda3f6..951a78d 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -334,6 +334,7 @@
         "@com_google_absl//absl/strings",
         "//third_party/eigen3",
         "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler/clusters:utils",
     ] + tf_protos_grappler(),
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 086f1e9..0b28dc3 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -26,10 +26,12 @@
 #include "tensorflow/core/grappler/clusters/utils.h"
 #include "tensorflow/core/grappler/costs/op_context.h"
 #include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace grappler {
 
+// TODO(dyoon): update the op-to-Predict-method map for ops with V2/V3 suffix.
 constexpr int kOpsPerMac = 2;
 constexpr char kGuaranteeConst[] = "GuaranteeConst";
 constexpr char kAddN[] = "AddN";
@@ -121,6 +123,7 @@
 constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";
 
 static const Costs::Duration kMinComputeTime(1);
+static const int64 kMinComputeOp = 1;
 
 namespace {
 
@@ -354,11 +357,12 @@
 OpLevelCostEstimator::OpLevelCostEstimator() {
   // Syntactic sugar to build and return a lambda that takes an OpInfo and
   // returns a cost.
-  typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context)
-      const;
-  auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpContext&)> {
-    return [this, impl](const OpContext& op_context) {
-      return (this->*impl)(op_context);
+  typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context,
+                                                   NodeCosts*) const;
+  auto wrap = [this](CostImpl impl)
+      -> std::function<Status(const OpContext&, NodeCosts*)> {
+    return [this, impl](const OpContext& op_context, NodeCosts* node_costs) {
+      return (this->*impl)(op_context, node_costs);
     };
   };
 
@@ -642,27 +646,72 @@
 }
 
 Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
+  Costs costs;
+  NodeCosts node_costs;
+  if (PredictNodeCosts(op_context, &node_costs).ok()) {
+    if (node_costs.has_costs) {
+      return node_costs.costs;
+    }
+    // Convert NodeCosts to Costs.
+    if (node_costs.minimum_cost_op) {
+      // Override to the minimum cost. Note that some ops with minimum cost
+      // may have a non-typical device (e.g., channel for _Send), which may
+      // fail in GetDeviceInfo(), called from PredictOpCountBasedCost(). Make
+      // sure we set the minimum values on Costs directly here instead of
+      // calling PredictOpCountBasedCost().
+      costs.compute_time = kMinComputeTime;
+      costs.execution_time = kMinComputeTime;
+      costs.memory_time = 0;
+      costs.intermediate_memory_time = 0;
+      costs.intermediate_memory_read_time = 0;
+      costs.intermediate_memory_write_time = 0;
+    } else {
+      // Convert NodeCosts to Costs.
+      costs = PredictOpCountBasedCost(
+          node_costs.num_compute_ops, node_costs.num_total_read_bytes(),
+          node_costs.num_total_write_bytes(), op_context.op_info);
+    }
+    VLOG(1) << "Operation " << op_context.op_info.op() << " takes "
+            << costs.execution_time.count() << " ns.";
+    // Copy additional stats from NodeCosts to Costs.
+    costs.max_memory = node_costs.max_memory;
+    costs.persistent_memory = node_costs.persistent_memory;
+    costs.temporary_memory = node_costs.temporary_memory;
+    costs.inaccurate = node_costs.inaccurate;
+    costs.num_ops_with_unknown_shapes =
+        node_costs.num_nodes_with_unknown_shapes;
+    costs.num_ops_total = node_costs.num_nodes;
+    return costs;
+  }
+  // Errors during node cost estimate.
+  LOG(WARNING) << "Error in PredictCost() for the op: "
+               << op_context.op_info.ShortDebugString();
+  costs = Costs::ZeroCosts(/*inaccurate=*/true);
+  costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes;
+  return costs;
+}
+
+Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
+                                              NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   auto it = device_cost_impl_.find(op_info.op());
   if (it != device_cost_impl_.end()) {
-    std::function<Costs(const OpContext&)> estimator = it->second;
-    Costs costs = estimator(op_context);
-    VLOG(1) << "Operation " << op_info.op() << " takes "
-            << costs.execution_time.count() << " ns.";
-    return costs;
+    std::function<Status(const OpContext&, NodeCosts*)> estimator = it->second;
+    return estimator(op_context, node_costs);
   }
 
   if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
-    return PredictVariable(op_context);
+    return PredictVariable(op_context, node_costs);
   }
 
   if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
-    return PredictCwiseOp(op_context);
+    return PredictCwiseOp(op_context, node_costs);
   }
 
   VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
 
-  return PredictCostOfAnUnknownOp(op_context);
+  node_costs->num_nodes_with_unknown_op_type = 1;
+  return PredictCostOfAnUnknownOp(op_context, node_costs);
 }
 
 DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
@@ -716,7 +765,8 @@
   return DeviceInfo(gflops, gb_per_sec);
 }
 
-Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context,
+                                            NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   bool found_unknown_shapes = false;
   // For element-wise operations, op count is the element count of any input. We
@@ -736,30 +786,25 @@
   }
 
   int op_cost = 1;
-  bool is_known_elementwise_op = false;
   auto it = elementwise_ops_.find(op_info.op());
   if (it != elementwise_ops_.end()) {
     op_cost = it->second;
-    is_known_elementwise_op = true;
   } else {
-    LOG(WARNING) << "Not a cwise op: " << op_info.op();
+    return errors::InvalidArgument("Not a cwise op: ", op_info.op());
   }
 
-  Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_info);
-  if (found_unknown_shapes || !is_known_elementwise_op) {
-    costs.inaccurate = true;
-  }
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  return PredictDefaultNodeCosts(op_count * op_cost, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictCostOfAnUnknownOp(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   // Don't assume the operation is cwise, return cost based on input/output size
   // and admit that it is inaccurate...
-  auto costs = PredictOpCountBasedCost(0, op_context.op_info);
-  costs.inaccurate = true;
-  return costs;
+  bool found_unknown_shapes = false;
+  node_costs->inaccurate = true;
+  return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
+                                 node_costs);
 }
 
 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
@@ -1509,6 +1554,17 @@
   return total_input_size;
 }
 
+std::vector<int64> OpLevelCostEstimator::CalculateInputTensorSize(
+    const OpInfo& op_info, bool* found_unknown_shapes) {
+  std::vector<int64> input_tensor_size;
+  input_tensor_size.reserve(op_info.inputs().size());
+  for (auto& input : op_info.inputs()) {
+    input_tensor_size.push_back(
+        CalculateTensorSize(input, found_unknown_shapes));
+  }
+  return input_tensor_size;
+}
+
 int64 OpLevelCostEstimator::CalculateLargestInputCount(
     const OpInfo& op_info, bool* found_unknown_shapes) {
   int64 largest_input_count = 0;
@@ -1527,7 +1583,7 @@
 int64 OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
                                                 bool* found_unknown_shapes) {
   int64 total_output_size = 0;
-  // use float as default for calculations
+  // Use float as default for calculations.
   for (const auto& output : op_info.outputs()) {
     DataType dt = output.dtype();
     const auto& original_output_shape = output.shape();
@@ -1545,6 +1601,43 @@
   return total_output_size;
 }
 
+std::vector<int64> OpLevelCostEstimator::CalculateOutputTensorSize(
+    const OpInfo& op_info, bool* found_unknown_shapes) {
+  std::vector<int64> output_tensor_size;
+  output_tensor_size.reserve(op_info.outputs().size());
+  // Use float as default for calculations.
+  for (const auto& output : op_info.outputs()) {
+    DataType dt = output.dtype();
+    const auto& original_output_shape = output.shape();
+    int64 output_size = DataTypeSize(BaseType(dt));
+    int num_dims = std::max(1, original_output_shape.dim_size());
+    auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
+                                             found_unknown_shapes);
+    for (const auto& dim : output_shape.dim()) {
+      output_size *= dim.size();
+    }
+    output_tensor_size.push_back(output_size);
+  }
+  return output_tensor_size;
+}
+
+Status OpLevelCostEstimator::PredictDefaultNodeCosts(
+    const int64 num_compute_ops, const OpContext& op_context,
+    bool* found_unknown_shapes, NodeCosts* node_costs) {
+  const auto& op_info = op_context.op_info;
+  node_costs->num_compute_ops = num_compute_ops;
+  node_costs->num_input_bytes_accessed =
+      CalculateInputTensorSize(op_info, found_unknown_shapes);
+  node_costs->num_output_bytes_accessed =
+      CalculateOutputTensorSize(op_info, found_unknown_shapes);
+  node_costs->max_memory = node_costs->num_total_output_bytes();
+  if (*found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
+}
+
 bool HasZeroDim(const OpInfo& op_info) {
   for (int i = 0; i < op_info.inputs_size(); ++i) {
     const auto& input = op_info.inputs(i);
@@ -1560,62 +1653,54 @@
   return false;
 }
 
-Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context,
+                                           NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   if (HasZeroDim(op_info)) {
-    Costs costs = Costs::ZeroCosts();
-    costs.inaccurate = true;
-    costs.num_ops_with_unknown_shapes = 1;
-    return costs;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+    return errors::InvalidArgument("Conv2D op includes zero dimension: ",
+                                   op_info.ShortDebugString());
   }
   bool found_unknown_shapes = false;
-  auto costs = PredictOpCountBasedCost(
-      CountConv2DOperations(op_info, &found_unknown_shapes), op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  int64 num_compute_ops = CountConv2DOperations(op_info, &found_unknown_shapes);
+  return PredictDefaultNodeCosts(num_compute_ops, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictConv2DBackpropInput(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   if (HasZeroDim(op_info)) {
-    Costs costs = Costs::ZeroCosts();
-    costs.inaccurate = true;
-    costs.num_ops_with_unknown_shapes = true;
-    return costs;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+    return errors::InvalidArgument(
+        "Conv2DBackpropInput op includes zero dimension",
+        op_info.ShortDebugString());
   }
   bool found_unknown_shapes = false;
-  auto costs =
-      PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
-                                  op_info, nullptr, &found_unknown_shapes),
-                              op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  int64 num_compute_ops = CountConv2DBackpropInputOperations(
+      op_info, nullptr, &found_unknown_shapes);
+  return PredictDefaultNodeCosts(num_compute_ops, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictConv2DBackpropFilter(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   if (HasZeroDim(op_info)) {
-    Costs costs = Costs::ZeroCosts();
-    costs.inaccurate = true;
-    costs.num_ops_with_unknown_shapes = true;
-    return costs;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+    return errors::InvalidArgument(
+        "Conv2DBackpropFilter op includes zero dimension",
+        op_info.ShortDebugString());
   }
   bool found_unknown_shapes = false;
-  auto costs =
-      PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
-                                  op_info, nullptr, &found_unknown_shapes),
-                              op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  int64 num_compute_ops = CountConv2DBackpropFilterOperations(
+      op_info, nullptr, &found_unknown_shapes);
+  return PredictDefaultNodeCosts(num_compute_ops, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   // FusedConv2DBiasActivation computes a fused kernel which implements:
   // 2D convolution, adds side input with separate scaling on convolution and
   // side inputs, then adds bias, and finally applies the ReLU activation
@@ -1639,18 +1724,16 @@
   std::string data_format = GetDataFormat(op_context.op_info);
   if (data_format != "NCHW" && data_format != "NHWC" &&
       data_format != "NCHW_VECT_C") {
-    LOG(WARNING) << "unsupported data format: " << data_format;
-    Costs cost = Costs::ZeroCosts();
-    cost.inaccurate = true;
-    return cost;
+    return errors::InvalidArgument(
+        "Unsupported data format (", data_format,
+        ") for op: ", op_context.op_info.ShortDebugString());
   }
   std::string filter_format = GetFilterFormat(op_context.op_info);
   if (filter_format != "HWIO" && filter_format != "OIHW" &&
       filter_format != "OIHW_VECT_I") {
-    LOG(WARNING) << "unsupported filter format: " << filter_format;
-    Costs cost = Costs::ZeroCosts();
-    cost.inaccurate = true;
-    return cost;
+    return errors::InvalidArgument(
+        "Unsupported filter format (", filter_format,
+        ") for op: ", op_context.op_info.ShortDebugString());
   }
 
   auto& conv_input = op_context.op_info.inputs(0);
@@ -1695,42 +1778,48 @@
   *op_context_with_output.op_info.mutable_outputs()->Add() = output;
 
   // Construct component operations and run the cost computation.
-  auto costs = PredictFusedOp(op_context_with_output, component_ops);
-  costs.inaccurate |= found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = costs.inaccurate;
-  return costs;
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return PredictFusedOp(op_context_with_output, component_ops, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context,
+                                           NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   bool found_unknown_shapes = false;
-  auto costs = PredictOpCountBasedCost(
-      CountMatMulOperations(op_info, &found_unknown_shapes), op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  int64 num_compute_ops = CountMatMulOperations(op_info, &found_unknown_shapes);
+  return PredictDefaultNodeCosts(num_compute_ops, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictEinsum(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context,
+                                           NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
 
   auto it = op_info.attr().find("equation");
-  if (it == op_info.attr().end()) return Costs::ZeroCosts(/*inaccurate=*/true);
+  if (it == op_info.attr().end()) {
+    return errors::InvalidArgument("Einsum op doesn't have equation attr: ",
+                                   op_info.ShortDebugString());
+  }
+
   OpContext batch_matmul_op_context;
   bool found_unknown_shapes = false;
   bool success = GenerateBatchMatmulContextFromEinsum(
       op_context, &batch_matmul_op_context, &found_unknown_shapes);
-  if (!success) {
-    return PredictCostOfAnUnknownOp(op_context);
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
   }
-  Costs costs = PredictCosts(batch_matmul_op_context);
-  costs.inaccurate = costs.inaccurate || found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  if (!success) {
+    return PredictCostOfAnUnknownOp(op_context, node_costs);
+  }
+  return PredictNodeCosts(batch_matmul_op_context, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   bool found_unknown_shapes = false;
   // input[0]: indices in sparse matrix a
@@ -1758,93 +1847,113 @@
       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
   int64 b_input_size =
       num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
-  double input_size = a_indices_input_size + a_values_input_size +
-                      a_shape_input_size + b_input_size;
+  int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
 
-  double output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
-
-  auto costs =
-      PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = output_size;
-
-  return costs;
+  node_costs->num_compute_ops = op_count;
+  node_costs->num_input_bytes_accessed = {a_indices_input_size,
+                                          a_values_input_size,
+                                          a_shape_input_size, b_input_size};
+  node_costs->num_output_bytes_accessed = {output_size};
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context,
+                                         NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
-  return Costs::ZeroCosts();
+  // By default, NodeCosts is initialized to zero ops and bytes.
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictPureMemoryOp(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context,
+                                                 NodeCosts* node_costs) const {
   // Each output element is a copy of some element from input, with no required
   // computation, so just compute memory costs.
-  return PredictOpCountBasedCost(0, op_context.op_info);
+  bool found_unknown_shapes = false;
+  node_costs->num_nodes_with_pure_memory_op = 1;
+  return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
+                                 node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context,
+                                             NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
-  VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
-  Costs result = Costs::ZeroCosts();
-  result.max_memory = CalculateOutputSize(op_info, &result.inaccurate);
-  result.num_ops_with_unknown_shapes = result.inaccurate;
-  // Assign the minimum amount of time we can represent to the identity op since
-  // it tends to be really cheap.
-  result.compute_time = kMinComputeTime;
-  result.execution_time = result.compute_time;
-  return result;
+  VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity";
+  node_costs->minimum_cost_op = true;
+  node_costs->num_compute_ops = kMinComputeOp;
+  // Identity op internally passes the input tensor buffer's pointer to the
+  // output tensor buffer; no actual memory operation.
+  node_costs->num_input_bytes_accessed = {0};
+  node_costs->num_output_bytes_accessed = {0};
+  bool inaccurate = false;
+  node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
+  if (inaccurate) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context,
+                                             NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
-  VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
-  Costs result = Costs::ZeroCosts();
-  result.persistent_memory = CalculateOutputSize(op_info, &result.inaccurate);
-  result.num_ops_with_unknown_shapes = result.inaccurate;
-
-  result.compute_time = kMinComputeTime;
-  result.execution_time = result.compute_time;
-  return result;
+  VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable";
+  node_costs->minimum_cost_op = true;
+  node_costs->num_compute_ops = kMinComputeOp;
+  // Variables are persistent ops; they are initialized before the step, so
+  // there is no memory cost.
+  node_costs->num_input_bytes_accessed = {0};
+  node_costs->num_output_bytes_accessed = {0};
+  bool inaccurate = false;
+  node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate);
+  if (inaccurate) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictBatchMatMul(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context,
+                                                NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   bool found_unknown_shapes = false;
-  Costs costs = PredictOpCountBasedCost(
-      CountBatchMatMulOperations(op_info, &found_unknown_shapes), op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  int64 num_compute_ops =
+      CountBatchMatMulOperations(op_info, &found_unknown_shapes);
+  return PredictDefaultNodeCosts(num_compute_ops, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context,
+                                             NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
-  Costs costs = Costs::ZeroCosts();
-  costs.max_memory = CalculateOutputSize(op_info, &costs.inaccurate);
-  costs.num_ops_with_unknown_shapes = costs.inaccurate;
-  // Metadata operations are so cheap we assume they take the minimum amount of
-  // time we can represent (1 ns).
-  costs.compute_time = kMinComputeTime;
-  costs.execution_time = costs.compute_time;
-
-  return costs;
+  node_costs->minimum_cost_op = true;
+  node_costs->num_compute_ops = kMinComputeOp;
+  node_costs->num_input_bytes_accessed = {0};
+  node_costs->num_output_bytes_accessed = {0};
+  bool inaccurate = false;
+  node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
+  if (inaccurate) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictGatherOrSlice(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context,
+                                                  NodeCosts* node_costs) const {
   // Gather & Slice ops can have a very large input, but only access a small
   // part of it. For these op the size of the output determines the memory cost.
   const auto& op_info = op_context.op_info;
 
   const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
   if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
-    Costs costs = Costs::ZeroCosts();
-    costs.inaccurate = true;
-    return costs;
+    return errors::InvalidArgument(
+        op_info.op(),
+        " Op doesn't have valid input / output: ", op_info.ShortDebugString());
   }
 
   bool unknown_shapes = false;
@@ -1853,10 +1962,19 @@
   // For roofline estimate we assume each copy has a unit cost.
   const int64 op_count =
       CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
+  node_costs->num_compute_ops = op_count;
 
-  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
-  double input_size = output_size;
-  int begin_input_index = 1, end_input_index;
+  const int64 output_size = CalculateOutputSize(op_info, &unknown_shapes);
+  node_costs->num_output_bytes_accessed = {output_size};
+
+  node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size());
+  int64 input_size = output_size;
+  // Note that the bytes accessed from input(0) do not equal the input(0)
+  // tensor size; they equal the output size, because the input access is an
+  // indexed gather or slice (duplicate indices are ignored).
+  node_costs->num_input_bytes_accessed.push_back(input_size);
+  int begin_input_index = 1;
+  int end_input_index;
   if (op_info.op() == "Slice") {
     // Slice: 'input' (omitted), 'begin', 'size'
     end_input_index = 3;
@@ -1868,20 +1986,18 @@
     end_input_index = 2;
   }
   for (int i = begin_input_index; i < end_input_index; ++i) {
-    input_size +=
-        CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes);
+    node_costs->num_input_bytes_accessed.push_back(
+        CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes));
   }
-
-  Costs costs =
-      PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
-  costs.inaccurate = unknown_shapes;
-  costs.num_ops_with_unknown_shapes = unknown_shapes;
-  costs.max_memory = output_size;
-
-  return costs;
+  if (unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
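For concreteness, a hypothetical example of the accounting above: a Gather over a float32 params tensor of shape [1000000, 128] with 64 indices produces an output of 64 * 128 * 4 = 32,768 bytes, so input(0) is charged those same 32,768 bytes rather than the full ~512 MB params tensor, and the indices input is then added via CalculateTensorElementCount.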
 
-Costs OpLevelCostEstimator::PredictScatter(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context,
+                                            NodeCosts* node_costs) const {
   // Scatter ops sparsely access a reference input and output tensor.
   const auto& op_info = op_context.op_info;
   bool found_unknown_shapes = false;
@@ -1904,6 +2020,7 @@
     num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
   }
   const int64 op_count = num_indices * num_elems_in_ref_per_index;
+  node_costs->num_compute_ops = op_count;
 
   // Sparsely access ref so input size depends on the number of operations
   int64 ref_input_size =
@@ -1912,44 +2029,50 @@
       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
   int64 updates_input_size =
       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
-
-  double total_input_size =
-      ref_input_size + indices_input_size + updates_input_size;
+  node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size,
+                                          updates_input_size};
 
   // Sparsely access ref so output size depends on the number of operations
-  double total_output_size =
+  int64 output_size =
       op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
+  node_costs->num_output_bytes_accessed = {output_size};
 
-  auto costs = PredictOpCountBasedCost(op_count, total_input_size,
-                                       total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-
-  return costs;
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictFusedOp(
+Status OpLevelCostEstimator::PredictFusedOp(
     const OpContext& op_context,
-    const std::vector<OpContext>& fused_op_contexts) const {
-  // Note that PredictOpCountBasedCost will get the correct memory_time from
+    const std::vector<OpContext>& fused_op_contexts,
+    NodeCosts* node_costs) const {
+  // Note that PredictDefaultNodeCosts will get the correct memory costs from
   // the node's inputs and outputs; but we don't want to have to re-implement
   // the logic for computing the operation count of each of our component
   // operations here; so we simply add the compute times of each component
-  // operation, then update the execution time.
-  Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
+  // operation, then update the cost.
+  bool found_unknown_shapes = false;
+  Status s =
+      PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs);
 
-  fused_cost.compute_time = 0;
-  fused_cost.inaccurate = false;
   for (auto& fused_op : fused_op_contexts) {
-    auto op_cost = PredictCosts(fused_op);
-
-    fused_cost.compute_time += op_cost.compute_time;
-    fused_cost.inaccurate |= op_cost.inaccurate;
-    fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
+    NodeCosts fused_node_costs;
+    s.Update(PredictNodeCosts(fused_op, &fused_node_costs));
+    node_costs->num_compute_ops += fused_node_costs.num_compute_ops;
+    node_costs->inaccurate |= fused_node_costs.inaccurate;
+    // Set, not increment. Note that we are predicting the cost of one fused
+    // node, not a function node composed of many nodes.
+    node_costs->num_nodes_with_unknown_shapes |=
+        fused_node_costs.num_nodes_with_unknown_shapes;
+    node_costs->num_nodes_with_unknown_op_type |=
+        fused_node_costs.num_nodes_with_unknown_op_type;
+    node_costs->num_nodes_with_pure_memory_op |=
+        fused_node_costs.num_nodes_with_pure_memory_op;
   }
 
-  CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &fused_cost);
-  return fused_cost;
+  return Status::OK();
 }
 
 /* static */
@@ -2040,7 +2163,8 @@
   return conv_dims;
 }
 
-Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
+                                            NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   // x: op_info.inputs(0)
@@ -2050,38 +2174,41 @@
   // or 1 copy per output (kx * k1 = 1).
   int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
   int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
+  node_costs->num_compute_ops = ops;
 
-  double total_input_size = 0;
+  int64 input_size = 0;
   if (dims.ky >= dims.sy) {
-    total_input_size =
-        CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+    input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
   } else {  // dims.ky < dims.sy
     // Vertical stride is larger than vertical kernel; assuming row-major
     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
     // others are not used for output).
     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
-    total_input_size =
-        data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
+    input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
   }
-  const double total_output_size =
-      CalculateOutputSize(op_info, &found_unknown_shapes);
-
-  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
-                                        total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = total_output_size;
-  return costs;
+  node_costs->num_input_bytes_accessed = {input_size};
+  const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
+  node_costs->num_output_bytes_accessed = {output_size};
+  node_costs->max_memory = output_size;
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictMaxPoolGrad(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
+                                                NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   // x: op_info.inputs(0)
   // y: op_info.inputs(1)
   // y_grad: op_info.inputs(2)
-  if (op_info.inputs_size() < 3) return Costs::ZeroCosts(/*inaccurate=*/true);
+  if (op_info.inputs_size() < 3) {
+    return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ",
+                                   op_info.ShortDebugString());
+  }
+
   ConvolutionDimensions dims = OpDimensionsFromInputs(
       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
 
@@ -2099,48 +2226,62 @@
     ops = dims.batch * dims.iz *
           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
   }
+  node_costs->num_compute_ops = ops;
 
   // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-run
   // MaxPool internally.
-  double total_input_size =
+  const int64 input0_size =
       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
-  total_input_size +=
+  const int64 input2_size =
       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
+  node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size};
   // Write x_grad; size equal to x.
-  const double total_output_size =
+  const int64 output_size =
       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+  node_costs->num_output_bytes_accessed = {output_size};
+  node_costs->max_memory = output_size;
 
-  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
-                                        total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = total_output_size;
-  return costs;
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
 /* This predict function handles three types of tensorflow ops
  * AssignVariableOp/AssignAddVariableOp/AssignSubVariableOp, broadcasting
  * was not possible for these ops, therefore the input tensor's shapes is
  * enough to compute the cost */
-Costs OpLevelCostEstimator::PredictAssignVariableOps(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictAssignVariableOps(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   /* First input of these ops are reference to the assignee. */
-  if (op_info.inputs_size() != 2) return Costs::ZeroCosts(true);
-  const double total_input_size =
-      CalculateInputSize(op_info, &found_unknown_shapes);
-  const double flops = op_info.op() == kAssignVariableOp
-                           ? 0.0
-                           : CalculateTensorElementCount(op_info.inputs(1),
-                                                         &found_unknown_shapes);
-  Costs costs = PredictOpCountBasedCost(flops, total_input_size, 0, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  if (op_info.inputs_size() != 2) {
+    return errors::InvalidArgument("AssignVariable op has invalid input: ",
+                                   op_info.ShortDebugString());
+  }
+
+  const int64 ops = op_info.op() == kAssignVariableOp
+                        ? 0
+                        : CalculateTensorElementCount(op_info.inputs(1),
+                                                      &found_unknown_shapes);
+  node_costs->num_compute_ops = ops;
+  const int64 input_size = CalculateInputSize(op_info, &found_unknown_shapes);
+  node_costs->num_input_bytes_accessed = {input_size};
+  // TODO(dyoon): check whether these ops write data; the op itself doesn't
+  // have an output tensor, but it may modify the input (ref or resource).
+  // Maybe use node_costs->internal_write_bytes.
+  node_costs->num_output_bytes_accessed = {0};
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
+                                            NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   // x: op_info.inputs(0)
@@ -2149,32 +2290,33 @@
 
   // kx * ky - 1 additions and 1 multiplication per output.
   int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
+  node_costs->num_compute_ops = ops;
 
-  double total_input_size = 0;
+  int64 input_size;
   if (dims.ky >= dims.sy) {
-    total_input_size =
-        CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+    input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
   } else {  // dims.ky < dims.sy
     // vertical stride is larger than vertical kernel; assuming row-major
     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
     // others are not used for output).
     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
-    total_input_size =
-        data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
+    input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
   }
-  const double total_output_size =
-      CalculateOutputSize(op_info, &found_unknown_shapes);
+  node_costs->num_input_bytes_accessed = {input_size};
 
-  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
-                                        total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = total_output_size;
-  return costs;
+  const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
+  node_costs->num_output_bytes_accessed = {output_size};
+  node_costs->max_memory = output_size;
+
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictAvgPoolGrad(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
+                                                NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   // x's shape: op_info.inputs(0)
@@ -2212,22 +2354,14 @@
     ops = dims.batch * dims.iz *
           (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
   }
-
-  const double total_input_size =
-      CalculateInputSize(op_info, &found_unknown_shapes);
-  const double total_output_size =
-      CalculateOutputSize(op_info, &found_unknown_shapes);
-
-  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
-                                        total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = total_output_size;
-  return costs;
+  auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
+                                   node_costs);
+  node_costs->max_memory = node_costs->num_total_output_bytes();
+  return s;
 }
 
-Costs OpLevelCostEstimator::PredictFusedBatchNorm(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictFusedBatchNorm(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   // x: op_info.inputs(0)
@@ -2247,34 +2381,37 @@
   } else {
     ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
   }
+  node_costs->num_compute_ops = ops;
 
-  const double size_nhwc =
+  const int64 size_nhwc =
       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
-  const double size_c =
+  const int64 size_c =
       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
-  double total_input_size = 0.0;
-  double total_internal_read_size = 0.0;
-  double total_output_size = 0.0;
   if (is_training) {
-    total_input_size = size_nhwc + size_c * 2;
-    total_output_size = size_nhwc + size_c * 4;
-    total_internal_read_size = size_nhwc;
+    node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c};
+    node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
+                                             size_c};
+    // FusedBatchNorm in training mode internally re-reads the input tensor:
+    // once for mean/variance, and a second time for the actual scaling.
+    // Assume small intermediate data such as mean / variance (size_c) can be
+    // cached on-chip.
+    node_costs->internal_read_bytes = size_nhwc;
   } else {
-    total_input_size = size_nhwc + size_c * 4;
-    total_output_size = size_nhwc;
+    node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
+                                            size_c};
+    node_costs->num_output_bytes_accessed = {size_nhwc};
   }
+  node_costs->max_memory = node_costs->num_total_output_bytes();
 
-  Costs costs =
-      PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
-                              total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = total_output_size;
-  return costs;
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
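As a worked example of the training-mode accounting (hypothetical shapes): for a float32 input of shape [32, 56, 56, 64], size_nhwc = 32 * 56 * 56 * 64 * 4 ≈ 25.7 MB and size_c = 64 * 4 = 256 bytes, so the node is charged size_nhwc + 2 * size_c bytes of input reads, size_nhwc + 4 * size_c bytes of output writes, plus another size_nhwc of internal reads for the second pass over the input.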
 
-Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const auto& op_info = op_context.op_info;
   // y_backprop: op_info.inputs(0)
@@ -2289,25 +2426,29 @@
   const auto rsqrt_cost = Eigen::internal::functor_traits<
       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
   ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
+  node_costs->num_compute_ops = ops;
 
-  const double size_nhwc =
+  const int64 size_nhwc =
       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
-  const double size_c =
+  const int64 size_c =
       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
-  double total_input_size = size_nhwc * 2 + size_c * 2;
-  double total_internal_read_size = size_nhwc;
-  double total_output_size = size_nhwc * 1 + size_c * 2;
+  // TODO(dyoon): fix missing memory cost for variance input (size_c) and
+  // yet another read of y_backprop (size_nhwc) internally.
+  node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c};
+  node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c};
+  // FusedBatchNormGrad has to read y_backprop internally.
+  node_costs->internal_read_bytes = size_nhwc;
+  node_costs->max_memory = node_costs->num_total_output_bytes();
 
-  Costs costs =
-      PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
-                              total_output_size, op_info);
-  costs.inaccurate = found_unknown_shapes;
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  costs.max_memory = total_output_size;
-  return costs;
+  if (found_unknown_shapes) {
+    node_costs->inaccurate = true;
+    node_costs->num_nodes_with_unknown_shapes = 1;
+  }
+  return Status::OK();
 }
 
-Costs OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context,
+                                           NodeCosts* node_costs) const {
   const auto& op_info = op_context.op_info;
   bool found_unknown_shapes = false;
   // Calculate the largest known tensor size across all inputs and output.
@@ -2331,21 +2472,22 @@
 
   const auto sum_cost = Eigen::internal::functor_traits<
       Eigen::internal::scalar_sum_op<float>>::Cost;
-  Costs costs = PredictOpCountBasedCost(op_count * sum_cost, op_info);
-  if (found_unknown_shapes) {
-    costs.inaccurate = true;
-  }
-  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
-  return costs;
+  return PredictDefaultNodeCosts(op_count * sum_cost, op_context,
+                                 &found_unknown_shapes, node_costs);
 }
 
 // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
-Costs OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context,
+                                            NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
   const int64 logits_size = CalculateTensorElementCount(
       op_context.op_info.inputs(0), &found_unknown_shapes);
-  TensorShapeProto logits_shape = MaybeGetMinimumShape(
-      op_context.op_info.inputs(0).shape(), 2, &found_unknown_shapes);
+  // Softmax input rank should be >=1.
+  TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape();
+  if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) {
+    return errors::InvalidArgument("Softmax op has invalid input: ",
+                                   op_context.op_info.ShortDebugString());
+  }
 
 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
 
@@ -2359,23 +2501,21 @@
       EIGEN_COST(scalar_inverse_op<float>) * logits_shape.dim(0).size();
 
 #undef EIGEN_COST
-
-  return PredictOpCountBasedCost(ops, op_context.op_info);
+  return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
+                                 node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictResizeBilinear(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictResizeBilinear(
+    const OpContext& op_context, NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
 
   if (op_context.op_info.outputs().empty() ||
       op_context.op_info.inputs().empty()) {
-    return Costs::ZeroCosts(/*inaccurate=*/true);
+    return errors::InvalidArgument(
+        "ResizeBilinear op has invalid input / output ",
+        op_context.op_info.ShortDebugString());
   }
 
-  const int64 input_size =
-      CalculateTensorSize(op_context.op_info.inputs(0), &found_unknown_shapes);
-  const int64 output_size =
-      CalculateTensorSize(op_context.op_info.outputs(0), &found_unknown_shapes);
   const int64 output_elements = CalculateTensorElementCount(
       op_context.op_info.outputs(0), &found_unknown_shapes);
 
@@ -2384,7 +2524,7 @@
   bool use_half_pixel_centers = false;
   if (half_pixel_centers == op_context.op_info.attr().end()) {
     LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear.";
-    return PredictCostOfAnUnknownOp(op_context);
+    return PredictCostOfAnUnknownOp(op_context, node_costs);
   } else {
     use_half_pixel_centers = half_pixel_centers->second.b();
   }
@@ -2454,12 +2594,12 @@
   //   return top + (bottom - top) * y_lerp;
   ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements;
 
-  return PredictOpCountBasedCost(ops, input_size, output_size,
-                                 op_context.op_info);
+  return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
+                                 node_costs);
 }
 
-Costs OpLevelCostEstimator::PredictCropAndResize(
-    const OpContext& op_context) const {
+Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context,
+                                                  NodeCosts* node_costs) const {
   bool found_unknown_shapes = false;
 
   const auto method = op_context.op_info.attr().find("method");
@@ -2472,14 +2612,9 @@
   } else {
     LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear "
                     "or nearest.";
-    return PredictCostOfAnUnknownOp(op_context);
+    return PredictCostOfAnUnknownOp(op_context, node_costs);
   }
 
-  const int input_size =
-      CalculateTensorSize(op_context.op_info.inputs(0), &found_unknown_shapes);
-  const int output_size =
-      CalculateOutputSize(op_context.op_info, &found_unknown_shapes);
-
   const int64 num_boxes = op_context.op_info.inputs(1).shape().dim(0).size();
   const auto crop_shape = MaybeGetMinimumShape(
       op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
@@ -2529,9 +2664,8 @@
     // Ops for innermost loop across depth.
     ops += cast_to_float_cost * output_elements;
   }
-
-  return PredictOpCountBasedCost(ops, input_size, output_size,
-                                 op_context.op_info);
+  return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
+                                 node_costs);
 }
 
 }  // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 238e159..5438292 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -16,9 +16,12 @@
 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
 #define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
 
+#include <numeric>
+
 #include "tensorflow/core/grappler/costs/cost_estimator.h"
 #include "tensorflow/core/grappler/costs/op_context.h"
 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/padding.h"
 
 namespace tensorflow {
@@ -29,6 +32,62 @@
 TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
                                       int rank, bool* found_unknown_shapes);
 
+// Node costs; an intermediate structure used within op level cost estimator.
+struct NodeCosts {
+  // If this FLAG is true, override calculated compute time with a minimum
+  // value, instead of calculating it from num_compute_ops and compute ops/sec.
+  // For example, PredictIdentity, PredictVariable, PredictMetadata set this
+  // FLAG.
+  bool minimum_cost_op = false;
+
+  // Compute ops.
+  int64 num_compute_ops = 0;
+
+  // Memory bytes accessed; note that these may differ from the sizes of the
+  // tensors.
+  std::vector<int64> num_input_bytes_accessed;   // ordered by input tensors.
+  std::vector<int64> num_output_bytes_accessed;  // ordered by output ports.
+  int64 internal_read_bytes = 0;
+  int64 internal_write_bytes = 0;
+
+  // Convenience functions.
+  int64 num_total_input_bytes() const {
+    return std::accumulate(num_input_bytes_accessed.begin(),
+                           num_input_bytes_accessed.end(), 0LL);
+  }
+  int64 num_total_read_bytes() const {
+    return num_total_input_bytes() + internal_read_bytes;
+  }
+  int64 num_total_output_bytes() const {
+    return std::accumulate(num_output_bytes_accessed.begin(),
+                           num_output_bytes_accessed.end(), 0LL);
+  }
+  int64 num_total_write_bytes() const {
+    return num_total_output_bytes() + internal_write_bytes;
+  }
+  int64 num_bytes_accessed() const {
+    return num_total_read_bytes() + num_total_write_bytes();
+  }
+
+  // Memory usage.
+  int64 max_memory = 0;
+  int64 persistent_memory = 0;
+  int64 temporary_memory = 0;
+
+  // Stats.
+  int64 num_nodes = 1;
+  int64 num_nodes_with_unknown_shapes = 0;
+  int64 num_nodes_with_unknown_op_type = 0;
+  int64 num_nodes_with_pure_memory_op = 0;
+  bool inaccurate = false;
+
+  // TODO(dyoon): these fields are added for compatibility, as some old code is
+  // hard to migrate; hence, we keep them as a backup. Once we clean up, we'll
+  // delete these fields. New code should not use these.
+  bool has_costs = false;
+  Costs costs;
+};
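For illustration, a minimal standalone sketch (not from the patch itself) of how the NodeCosts convenience accessors are meant to compose: total reads are the per-input bytes plus internal reads, total writes are the per-output bytes plus internal writes, and num_bytes_accessed() is their sum. The struct below is a pared-down mirror of NodeCosts, and the byte counts are invented.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Pared-down mirror of the NodeCosts accessors, for illustration only.
struct MiniNodeCosts {
  std::vector<int64_t> num_input_bytes_accessed;
  std::vector<int64_t> num_output_bytes_accessed;
  int64_t internal_read_bytes = 0;
  int64_t internal_write_bytes = 0;

  int64_t num_total_input_bytes() const {
    return std::accumulate(num_input_bytes_accessed.begin(),
                           num_input_bytes_accessed.end(), int64_t{0});
  }
  int64_t num_total_output_bytes() const {
    return std::accumulate(num_output_bytes_accessed.begin(),
                           num_output_bytes_accessed.end(), int64_t{0});
  }
  int64_t num_bytes_accessed() const {
    return num_total_input_bytes() + internal_read_bytes +
           num_total_output_bytes() + internal_write_bytes;
  }
};

int main() {
  // Invented numbers for a FusedBatchNorm-like op in training mode: one NHWC
  // input, two per-channel inputs, five outputs, one internal re-read.
  MiniNodeCosts costs;
  costs.num_input_bytes_accessed = {4096, 64, 64};
  costs.num_output_bytes_accessed = {4096, 64, 64, 64, 64};
  costs.internal_read_bytes = 4096;
  std::cout << costs.num_bytes_accessed() << "\n";  // 12672
  return 0;
}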
+
 class OpLevelCostEstimator {
  public:
   OpLevelCostEstimator();
@@ -40,9 +99,7 @@
   virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const;
 
  protected:
-  // Predict cost of an op for which no accurate estimator is defined.
-  Costs PredictCostOfAnUnknownOp(const OpContext& op_context) const;
-
+  // TODO(dyoon): Consider removing PredictOpCountBasedCosts() with OpInfo.
   // Naive cost estimate based on the given operations count and total
   // input/output tensor sizes of the given op_info combined.
   Costs PredictOpCountBasedCost(double operations, const OpInfo& op_info) const;
@@ -54,6 +111,16 @@
                                 double output_io_bytes,
                                 const OpInfo& op_info) const;
 
+  // Top-level cost prediction method (PredictCosts calls this method to get
+  // NodeCosts, and then converts it to Costs). PredictNodeCosts() calls other
+  // Predict methods depending on op types.
+  Status PredictNodeCosts(const OpContext& op_context,
+                          NodeCosts* node_costs) const;
+
+  // Predict cost of an op for which no accurate estimator is defined.
+  Status PredictCostOfAnUnknownOp(const OpContext& op_context,
+                                  NodeCosts* node_costs) const;
+
   // This family of routines predicts the costs to
   // perform the specified TensorFlow Op on the
   // device represented by a subclass. The default
@@ -64,37 +131,64 @@
   // Implementation of costs other than
   // execution_time is optional, depending on the
   // device.
-  Costs PredictNaryOp(const OpContext& op_context) const;
-  Costs PredictConv2D(const OpContext& op_context) const;
-  Costs PredictCwiseOp(const OpContext& op_context) const;
-  Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
-  Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
-  Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
-  Costs PredictMatMul(const OpContext& op_context) const;
-  Costs PredictSparseTensorDenseMatMul(const OpContext& op_context) const;
-  Costs PredictNoOp(const OpContext& op_context) const;
-  Costs PredictIdentity(const OpContext& op_context) const;
-  Costs PredictVariable(const OpContext& op_context) const;
-  Costs PredictBatchMatMul(const OpContext& op_context) const;
-  Costs PredictMetadata(const OpContext& op_context) const;
-  Costs PredictGatherOrSlice(const OpContext& op_context) const;
-  Costs PredictScatter(const OpContext& op_context) const;
-  Costs PredictMaxPool(const OpContext& op_context) const;
-  Costs PredictMaxPoolGrad(const OpContext& op_context) const;
-  Costs PredictAvgPool(const OpContext& op_context) const;
-  Costs PredictAvgPoolGrad(const OpContext& op_context) const;
-  Costs PredictFusedBatchNorm(const OpContext& op_context) const;
-  Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
-  Costs PredictEinsum(const OpContext& op_context) const;
-  Costs PredictAssignVariableOps(const OpContext& op_context) const;
-  Costs PredictPureMemoryOp(const OpContext& op_context) const;
-  Costs PredictSoftmax(const OpContext& op_context) const;
-  Costs PredictResizeBilinear(const OpContext& op_context) const;
-  Costs PredictCropAndResize(const OpContext& op_context) const;
+  Status PredictNaryOp(const OpContext& op_context,
+                       NodeCosts* node_costs) const;
+  Status PredictConv2D(const OpContext& op_context,
+                       NodeCosts* node_costs) const;
+  Status PredictCwiseOp(const OpContext& op_context,
+                        NodeCosts* node_costs) const;
+  Status PredictConv2DBackpropInput(const OpContext& op_context,
+                                    NodeCosts* node_costs) const;
+  Status PredictConv2DBackpropFilter(const OpContext& op_context,
+                                     NodeCosts* node_costs) const;
+  Status PredictFusedConv2DBiasActivation(const OpContext& op_context,
+                                          NodeCosts* node_costs) const;
+  Status PredictMatMul(const OpContext& op_context,
+                       NodeCosts* node_costs) const;
+  Status PredictSparseTensorDenseMatMul(const OpContext& op_context,
+                                        NodeCosts* node_costs) const;
+  Status PredictNoOp(const OpContext& op_context, NodeCosts* node_costs) const;
+  Status PredictIdentity(const OpContext& op_context,
+                         NodeCosts* node_costs) const;
+  Status PredictVariable(const OpContext& op_context,
+                         NodeCosts* node_costs) const;
+  Status PredictBatchMatMul(const OpContext& op_context,
+                            NodeCosts* node_costs) const;
+  Status PredictMetadata(const OpContext& op_context,
+                         NodeCosts* node_costs) const;
+  Status PredictGatherOrSlice(const OpContext& op_context,
+                              NodeCosts* node_costs) const;
+  Status PredictScatter(const OpContext& op_context,
+                        NodeCosts* node_costs) const;
+  Status PredictMaxPool(const OpContext& op_context,
+                        NodeCosts* node_costs) const;
+  Status PredictMaxPoolGrad(const OpContext& op_context,
+                            NodeCosts* node_costs) const;
+  Status PredictAvgPool(const OpContext& op_context,
+                        NodeCosts* node_costs) const;
+  Status PredictAvgPoolGrad(const OpContext& op_context,
+                            NodeCosts* node_costs) const;
+  Status PredictFusedBatchNorm(const OpContext& op_context,
+                               NodeCosts* node_costs) const;
+  Status PredictFusedBatchNormGrad(const OpContext& op_context,
+                                   NodeCosts* node_costs) const;
+  Status PredictEinsum(const OpContext& op_context,
+                       NodeCosts* node_costs) const;
+  Status PredictAssignVariableOps(const OpContext& op_context,
+                                  NodeCosts* node_costs) const;
+  Status PredictPureMemoryOp(const OpContext& op_context,
+                             NodeCosts* node_costs) const;
+  Status PredictSoftmax(const OpContext& op_context,
+                        NodeCosts* node_costs) const;
+  Status PredictResizeBilinear(const OpContext& op_context,
+                               NodeCosts* node_costs) const;
+  Status PredictCropAndResize(const OpContext& op_context,
+                              NodeCosts* node_costs) const;
 
   // Generic cost prediction method for fused operations.
-  Costs PredictFusedOp(const OpContext& op_context,
-                       const std::vector<OpContext>& fused_op_contexts) const;
+  Status PredictFusedOp(const OpContext& op_context,
+                        const std::vector<OpContext>& fused_op_contexts,
+                        NodeCosts* node_costs) const;
 
   // Utility function for safe division. Returns 0
   // if rhs is 0 or negative.
@@ -176,11 +270,19 @@
   static int64 CalculateInputSize(const OpInfo& op_info,
                                   bool* found_unknown_shapes);
 
+  // Same, but in vector format: one entry per input.
+  static std::vector<int64> CalculateInputTensorSize(
+      const OpInfo& op_info, bool* found_unknown_shapes);
+
   // Calculate the total size in bytes of the all
   // the outputs of specified TensorFlow op.
   static int64 CalculateOutputSize(const OpInfo& op_info,
                                    bool* found_unknown_shapes);
 
+  // Same, but in vector format: one entry per output.
+  static std::vector<int64> CalculateOutputTensorSize(
+      const OpInfo& op_info, bool* found_unknown_shapes);
+
   // For convolution and its grad ops.
   static ConvolutionDimensions ConvolutionDimensionsFromInputs(
       const TensorShapeProto& original_image_shape,
@@ -203,9 +305,16 @@
   static OpInfo::TensorProperties DescribeTensor(
       DataType type, const std::vector<int64>& dims);
 
+  // Helper method for building common case NodeCosts.
+  static Status PredictDefaultNodeCosts(const int64 num_compute_ops,
+                                        const OpContext& op_context,
+                                        bool* found_unknown_shapes,
+                                        NodeCosts* node_costs);
+
  protected:
   std::map<string, int> elementwise_ops_;
-  typedef std::function<Costs(const OpContext& op_context)> CostImpl;
+  typedef std::function<Status(const OpContext& op_context, NodeCosts*)>
+      CostImpl;
   std::map<string, CostImpl> device_cost_impl_;
   // If true, assume compute and memory overlap; hence, the op cost is max of
   // compute_time and memory_time, instead of sum of those two.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index fb2e445..23373d3 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -894,8 +894,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
       "NCHW", "HWIO"));
   EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355321037), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(356146382), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -908,8 +908,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NCHW", "HWIO"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -922,8 +922,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NCHW", "OIHW"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -936,8 +936,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NHWC", "HWIO"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -950,8 +950,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NHWC", "OIHW"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -964,8 +964,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NCHW_VECT_C", "OIHW"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -978,8 +978,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NCHW", "OIHW_VECT_I"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -993,8 +993,8 @@
       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
       "NCHW_VECT_C", "OIHW_VECT_I"));
   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
   EXPECT_EQ(cost.num_ops_total, 1);
   EXPECT_FALSE(cost.inaccurate);
   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
@@ -2255,9 +2255,14 @@
   DescribeTensor4D(kNumBoxes, kOutputImageDim, kOutputImageDim, kChannelSize,
                    op_context.op_info.add_outputs());
 
+  // Note this is a time [ns, the default Duration unit in Costs], not bytes;
+  // memory bandwidth from SetCpuDevice() is 10GB/s, i.e. 10 bytes per ns.
   const int kExpectedMemoryTime =
-      (kImageDim * kImageDim + kNumBoxes * kOutputImageDim * kOutputImageDim) *
-      4;
+      (kImageDim * kImageDim * 4 +  // input image in float.
+       kNumBoxes * 4 * 8 / 10 +     // boxes (kNumBoxes x 4) in int64.
+       kNumBoxes * kOutputImageDim * kOutputImageDim * 4);  // output in float.
+  // Note that the input image and output image have a kChannelSize dim, which
+  // is 10; hence, there is no need to divide them by 10 (bandwidth).
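As a quick sanity check of the bytes-to-nanoseconds conversion the comment relies on, here is an illustrative standalone sketch (not from the patch; the helper name and numbers are invented): at 10 GB/s a byte moves in 0.1 ns, so a byte count divides by 10 to give time in ns.

#include <cstdint>
#include <iostream>

// 1 GB/s equals 1 byte per nanosecond, so GB/s can be read directly as
// bytes/ns when converting an access size into a memory time.
int64_t MemoryTimeNs(int64_t bytes, double bandwidth_gb_per_s) {
  return static_cast<int64_t>(bytes / bandwidth_gb_per_s);
}

int main() {
  std::cout << MemoryTimeNs(1280, 10.0) << "\n";  // 128 ns at 10 GB/s
  return 0;
}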
 
   {
     // Cost of CropAndResize with bilinear interpolation.
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 4b58456..8989245 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -31,6 +31,24 @@
 namespace tensorflow {
 namespace grappler {
 
+GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
+  GrapplerItem::OptimizationOptions optimization_options;
+  // TensorFlow 2.0 in eager mode with automatic control dependencies will
+  // prune all nodes that are not in the transitive fanin of the fetch nodes.
+  // However, because the function will be executed via FunctionLibraryRuntime,
+  // and the current function implementation does not prune stateful and
+  // dataset ops, we rely on Grappler to do the correct graph pruning.
+  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
+
+  optimization_options.is_eager_mode = true;
+
+  // All the nested function calls will be executed and optimized via
+  // PartitionedCallOp, so there is no need to optimize functions now.
+  optimization_options.optimize_function_library = false;
+
+  return optimization_options;
+}
+
 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
   GrapplerItem item;
   item.id = id;
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 99d6d2c4..7a3900f 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -133,6 +133,8 @@
   OptimizationOptions optimization_options_;
 };
 
+GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
+
 }  // end namespace grappler
 }  // end namespace tensorflow
 
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index e09ea57..0cd0db3 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -429,11 +429,16 @@
     StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
     const SetInputFn& set_input, const SetOutputFn& set_output,
     const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
-  auto has_attrs = [](const FunctionDef& func) {
-    return !(func.attr_size() == 0 ||
-             (func.attr_size() == 1 && data::IsTFDataFunction(func)));
+  auto has_unknown_attrs = [](const FunctionDef& func) {
+    int known_attribute_size = 0;
+
+    if (data::IsTFDataFunction(func)) known_attribute_size += 1;
+    if (func.attr().contains("_construction_context"))
+      known_attribute_size += 1;
+
+    return func.attr_size() > known_attribute_size;
   };
-  if (has_attrs(first_function) || has_attrs(second_function)) {
+  if (has_unknown_attrs(first_function) || has_unknown_attrs(second_function)) {
     return nullptr;  // Functions with attributes are currently not supported.
   }
 
@@ -474,6 +479,28 @@
 
   set_nodes(first_function, setup_function, fused_function, library);
   (*fused_function->mutable_attr())[data::kTFDataFunction].set_b(true);
+
+  // Preserve `_construction_context` attribute in the fused function.
+  auto get_construction_context = [](const FunctionDef& func) {
+    auto iter = func.attr().find("_construction_context");
+    if (iter == func.attr().cend()) return std::string();
+    return iter->second.s();
+  };
+  std::string first_construction_context =
+      get_construction_context(first_function);
+  std::string second_construction_context =
+      get_construction_context(second_function);
+  if (first_construction_context != second_construction_context) {
+    LOG(ERROR) << "_construction_context attribute mismatch during fused "
+                  "function optimization pass. First function: "
+               << first_construction_context
+               << " Second function: " << first_construction_context;
+  }
+  if (!first_construction_context.empty()) {
+    (*fused_function->mutable_attr())["_construction_context"].set_s(
+        first_construction_context);
+  }
+
   return fused_function;
 }
 
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index ad9f3d5..96c5339 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -830,7 +830,8 @@
        // to run asynchronously to avoid deadlock.
        "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
        "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
-       "NcclAllReduce", "Send", "Recv",
+       "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
+       "Send", "Recv",
 
        // Legacy random ops.
        // See details in tensorflow/python/framework/auto_control_deps.py.
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
index f584b8d..8c26195 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
@@ -20,7 +20,7 @@
 namespace grappler {
 
 const NodeScopeAndName ParseNodeScopeAndName(const string& node_name) {
-  auto pos = node_name.find_last_of("/");
+  auto pos = node_name.find_last_of('/');
   if (pos == string::npos) {
     return {"", node_name};
   } else {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 6fcbe37..4009493 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -170,17 +170,22 @@
 #define MK_OPT(NAME, VALUE) \
   if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)
 
-bool MetaOptimizer::IsSingleThreadedExecutor() const {
-  return config_proto_.experimental().executor_type() ==
-         "SINGLE_THREADED_EXECUTOR";
+bool MetaOptimizer::LowerControlFlow() const {
+  if (config_proto_.experimental().executor_type() ==
+      "SINGLE_THREADED_EXECUTOR")
+    return false;
+
+  if (config_proto_.experimental().use_tfrt()) return false;
+
+  return true;
 }
 
 std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
     const string& optimizer) const {
   MK_OPT("pruning", new ModelPruner());
-  MK_OPT("function", new FunctionOptimizer(
-                         cfg_.function_optimization(),
-                         /*lower_control_flow=*/!IsSingleThreadedExecutor()));
+  MK_OPT("function",
+         new FunctionOptimizer(cfg_.function_optimization(),
+                               /*lower_control_flow=*/LowerControlFlow()));
   MK_OPT("constfold",
          new ConstantFolding(
              cpu_device_,
@@ -235,7 +240,7 @@
   if (cfg_.function_optimization() != RewriterConfig::OFF) {
     optimizers->push_back(MakeUnique<FunctionOptimizer>(
         cfg_.function_optimization(),
-        /*lower_control_flow=*/!IsSingleThreadedExecutor()));
+        /*lower_control_flow=*/LowerControlFlow()));
   }
   if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF &&
       cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index b21ea68..d3b489b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -61,7 +61,8 @@
   std::unique_ptr<GraphOptimizer> MakeNewOptimizer(
       const string& optimizer) const;
 
-  bool IsSingleThreadedExecutor() const;
+  // Whether grappler should lower control flow to V1 switch/merge style
+  // nodes.
+  bool LowerControlFlow() const;
 
   // Initialize active optimizers from RewriterConfig toggles.
   Status InitializeOptimizers(
diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
index f8574a4..e9270ff 100644
--- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
@@ -446,6 +446,84 @@
     }
   }
 }
+
+TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
+  using ::tensorflow::ops::Placeholder;
+
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto input_shape = ops::Placeholder::Shape({4, 32});
+  auto input_shape_add = ops::Placeholder::Shape({4, 8});
+  auto filter_shape = ops::Placeholder::Shape({32, 8});
+  auto bias_shape = ops::Placeholder::Shape({8});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto input_add =
+      Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
+  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
+  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+
+  auto fetch = s.WithOpName("fetch");
+  auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
+
+  ops::Identity(fetch, add);
+
+  auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape.shape_.dim_sizes()));
+  auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape_add.shape_.dim_sizes()));
+  auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(filter_shape.shape_.dim_sizes()));
+  auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(bias_shape.shape_.dim_sizes()));
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  item.feed = {{"input", input_tensor},
+               {"filter", filter_tensor},
+               {"bias", bias_tensor},
+               {"input_add", input_add_tensor}};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on CPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:CPU:0");
+  }
+
+  Remapper optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    auto fetch_node_name = "add";
+    if (node.name() == fetch_node_name) {
+      EXPECT_EQ("_FusedMatMul", node.op());
+      EXPECT_EQ("input", node.input(0));
+      EXPECT_EQ("filter", node.input(1));
+
+      EXPECT_EQ(2, node.attr().at("num_args").i());
+      EXPECT_EQ("bias", node.input(2));
+      EXPECT_EQ("input_add", node.input(3));
+
+      const auto fused_ops = node.attr().at("fused_ops").list().s();
+      EXPECT_EQ(2, fused_ops.size());
+      EXPECT_EQ("BiasAdd", fused_ops[0]);
+      EXPECT_EQ("Add", fused_ops[1]);
+      found++;
+    }
+  }
+  EXPECT_EQ(1, found);
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
+}
 #endif  // ENABLE_MKLDNN_V1
 
 }  // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index b9bd643..d7705e9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1283,28 +1283,36 @@
   const NodeDef& contraction = graph->node(matched.contraction);
   const NodeDef& bias_add = graph->node(matched.bias_add);
 
-  // MKL version only support fusion for Conv2D
-  DCHECK(IsConv2D(contraction));
+  // MKL version only supports fusion for Conv2D and MatMul
+  DCHECK(IsConv2D(contraction) || IsMatMul(contraction));
 
-  NodeDef fused_conv2d;
+  NodeDef contraction_node;
   const NodeDef& add = graph->node(matched.add);
-  fused_conv2d.set_name(add.name());
-  fused_conv2d.set_op(kFusedConv2D);
-  fused_conv2d.set_device(contraction.device());
-  fused_conv2d.add_input(contraction.input(0));  // 0: input
-  fused_conv2d.add_input(contraction.input(1));  // 1: filter
-  fused_conv2d.add_input(bias_add.input(1));     // 2: bias
+  contraction_node.set_name(add.name());
+  contraction_node.set_device(contraction.device());
+  contraction_node.add_input(
+      contraction.input(0));  // 0: input(conv) / a (matmul)
+  contraction_node.add_input(
+      contraction.input(1));  // 1: filter(conv) / b (matmul)
+  contraction_node.add_input(bias_add.input(1));  // 2: bias
 
-  // Add OP has two inputs, one is conv+bias pattern matched previously,
-  // the other input to add is fused here.
-  fused_conv2d.add_input(add.input(1 - matched.port_id));
+  // Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
+  // previously, the other input to add is fused here.
+  contraction_node.add_input(add.input(1 - matched.port_id));
 
-  CopyConv2DAttributes(contraction, &fused_conv2d);
-  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2);
+  if (IsConv2D(contraction)) {
+    contraction_node.set_op(kFusedConv2D);
+    CopyConv2DAttributes(contraction, &contraction_node);
+  } else if (IsMatMul(contraction)) {
+    contraction_node.set_op(kFusedMatMul);
+    CopyMatMulAttributes(contraction, &contraction_node);
+  }
+
+  SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
 
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
-  mutation->AddNode(std::move(fused_conv2d), &status);
+  mutation->AddNode(std::move(contraction_node), &status);
   TF_RETURN_IF_ERROR(status);
   TF_RETURN_IF_ERROR(mutation->Apply());
 
@@ -1621,19 +1629,25 @@
 }
 
 #ifdef INTEL_MKL
-bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
+bool IsConv2DOrMatMul(const NodeDef& node) {
+  return IsConv2D(node) || IsMatMul(node);
+}
+
+bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
   const auto* node_view = ctx.graph_view.GetNode(node_index);
   const auto* node_def = node_view->node();
 
   // Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
+  //               MatMul + Add or MatMul + BiasAdd + Add fusion.
   auto is_supported_add_input = [](const auto* node_view) -> bool {
-    if (IsConv2D(*node_view->node())) return true;
+    // Currently only Conv2D and MatMul are supported
+    if (IsConv2DOrMatMul(*node_view->node())) return true;
     if (IsBiasAdd(*node_view->node())) {
       if (node_view->NumRegularFanins() < 2) return false;
       const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
       const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
-      return IsConv2D(*bias_add_fanin_0.node_view()->node()) ||
-             IsConv2D(*bias_add_fanin_1.node_view()->node());
+      return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
+             IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
     }
     return false;
   };
@@ -1739,7 +1753,7 @@
 
 #ifdef INTEL_MKL
   return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
-         IsConv2DWithAdd(ctx, node_index);
+         IsContractionWithAdd(ctx, node_index);
 #else
   return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
          is_batch_norm_fusion_candidate();
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
index 11f9589..cf601c9 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -48,13 +48,13 @@
 // matches op_name, i.e. it looks from the name like this node is
 // of that op type.
 bool HasOpName(const string& node_name, const string& op_name) {
-  size_t begin = node_name.rfind("/");
+  size_t begin = node_name.rfind('/');
   if (begin == string::npos) {
     begin = 0;
   } else {
     ++begin;
   }
-  size_t end = node_name.rfind("_");
+  size_t end = node_name.rfind('_');
   if (end != string::npos) {
     size_t p = end + 1;
     while (p < node_name.size()) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 91c703e..2ea2a8a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -16,8 +16,8 @@
 )
 load(
     "//tensorflow/core/kernels/mlir_generated:build_defs.bzl",
+    "if_mlir_experimental_kernels_enabled",
     "if_mlir_generated_gpu_kernels_enabled",
-    "if_mlir_unranked_kernels_enabled",
 )
 
 # buildifier: disable=same-origin-load
@@ -632,6 +632,7 @@
     srcs = ["batch_kernels.cc"],
     deps = [
         ":ops_util_hdrs",
+        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -3229,6 +3230,7 @@
     "//tensorflow/core:lib_internal",
     "//tensorflow/core:math_grad",
     "//tensorflow/core/framework:bounds_check",
+    "//tensorflow/core/framework:op_requires",
     "//third_party/eigen3",
 ]
 
@@ -3280,65 +3282,6 @@
     deps = MATH_DEPS,
 )
 
-cc_library(
-    name = "aggregate_ops_headers",
-    hdrs = [
-        "aggregate_ops.h",
-        "aggregate_ops_cpu.h",
-    ],
-    deps = select({
-        "//tensorflow:android": [
-            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
-        ],
-        "//conditions:default": [
-            "//third_party/eigen3",
-            "//tensorflow/core:framework",
-        ],
-    }),
-)
-
-# TODO(annarev): conv_ops_3d_headers currently depends on android target build
-# from selected sources. We should switch to use granular dependencies instead.
-# Then, we can just depend on "conv3d".
-cc_library(
-    name = "conv_3d_mobile",
-    hdrs = [
-        "conv_3d.h",
-        "eigen_backward_cuboid_convolutions.h",
-        "eigen_convolution_helpers.h",
-        "eigen_cuboid_convolution.h",
-        "eigen_volume_patch.h",
-    ],
-    deps = [
-        ":eigen_spatial_convolutions-inl",
-    ] + select({
-        "//tensorflow:android": [
-            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
-        ],
-        "//conditions:default": [
-            "//tensorflow/core:framework",
-        ],
-    }),
-)
-
-cc_library(
-    name = "conv_ops_3d_headers",
-    hdrs = [
-        "conv_ops_3d.h",
-    ],
-    deps = select({
-        "//tensorflow:android": [
-            ":conv_3d_mobile",
-            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
-        ],
-        "//conditions:default": [
-            ":conv_3d",
-            "//third_party/eigen3",
-            "//tensorflow/core:framework",
-        ],
-    }),
-)
-
 tf_kernel_library(
     name = "argmax_op",
     prefix = "argmax_op",
@@ -3404,7 +3347,7 @@
 tf_kernel_library(
     name = "cwise_op",
     copts = if_mlir_generated_gpu_kernels_enabled(if_true = ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED=1"]) +
-            if_mlir_unranked_kernels_enabled(if_true = ["-DMLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED=1"]),
+            if_mlir_experimental_kernels_enabled(if_true = ["-DMLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED=1"]),
     prefix = "cwise_op",
     deps = MATH_DEPS + if_mlir_generated_gpu_kernels_enabled(if_true = ["//tensorflow/core/kernels/mlir_generated:cwise_op"]),
 )
@@ -3810,6 +3753,7 @@
         "deep_conv2d.h",
         "gemm_functors.h",
         "winograd_transform.h",
+        "conv_ops_fused_impl.h",
     ] + select({
         ":xsmm_convolutions": ["xsmm_conv2d.h"],
         "//conditions:default": [],
@@ -3824,7 +3768,6 @@
     prefix = "conv_ops",
     deps = [
         ":conv_grad_shape_utils",
-        ":conv_ops_3d_headers",
         ":conv_2d",
         ":conv_3d",
         ":eigen_contraction_kernel",
@@ -5948,7 +5891,6 @@
         "conv_2d.h",
         "conv_3d.h",
         "conv_ops.h",
-        "conv_ops_3d.h",
         "conv_ops_gpu.h",
         "data_format_ops.h",
         "depthtospace_op.h",
@@ -5994,6 +5936,7 @@
         "spectrogram.h",
         "stateless_random_ops.h",
         "stateless_random_ops_v2.h",
+        "sparse_fill_empty_rows_op.h",
         "string_util.h",
         "string_to_hash_bucket_op.h",
         "string_to_hash_bucket_fast_op.h",
@@ -6262,6 +6205,7 @@
         "stateless_random_ops.cc",
         "stateless_random_ops_v2.cc",
         "string_join_op.cc",
+        "string_length_op.cc",
         "string_lower_op.cc",
         "string_util.cc",
         "string_split_op.cc",
@@ -6444,7 +6388,6 @@
         "stateful_random_ops_cpu_gpu.h",
         # Allows conv_3d ops for android but excluded from *_3d* rule above.
         "conv_3d.h",
-        "conv_ops_3d.h",
         "conv_ops_3d.cc",
         "conv_ops_gpu.h",
     ],
@@ -7523,6 +7466,7 @@
     "cwise_op_gpu_sigmoid.cu.cc",
     "cwise_op_gpu_sin.cu.cc",
     "cwise_op_gpu_sqrt.cu.cc",
+    "cwise_op_gpu_square.cu.cc",
     "cwise_op_gpu_squared_difference.cu.cc",
     "cwise_op_gpu_sub.cu.cc",
     "cwise_op_gpu_tanh.cu.cc",
diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc
index 3b6f89a..1cf2f54 100644
--- a/tensorflow/core/kernels/aggregate_ops.cc
+++ b/tensorflow/core/kernels/aggregate_ops.cc
@@ -19,21 +19,238 @@
 
 #include "tensorflow/core/kernels/aggregate_ops.h"
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/kernels/aggregate_ops_cpu.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+template <typename Device, typename T>
+class AddNOp : public OpKernel {
+ public:
+  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    if (!ctx->ValidateInputsAreSameShape(this)) return;
+
+    const Tensor& input0 = ctx->input(0);
+    const int num = ctx->num_inputs();
+
+    if (num == 1) {
+      ctx->set_output(0, input0);
+      return;
+    }
+
+    // Try to forward and accumulate the result in one of the input buffers.
+    int reused_input = -1;
+    gtl::InlinedVector<int, 8> input_indices(num);
+    std::iota(input_indices.begin(), input_indices.end(), 0);
+    Tensor* output = nullptr;
+    for (int input_idx = 0; input_idx < num; ++input_idx) {
+      if (ctx->forward_input_to_output_with_shape(input_idx, 0, input0.shape(),
+                                                  &output)) {
+        reused_input = input_idx;
+        break;
+      }
+    }
+    if (reused_input == -1) {
+      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
+    } else if (reused_input > 0) {
+      // Move the forwarded buffer to the front so we don't double count
+      // anything if there are more than 8 inputs.
+      input_indices[0] = reused_input;
+      input_indices[reused_input] = 0;
+    }
+    auto To = output->flat<T>();
+
+#define I(IDX) ctx->input(input_indices[IDX]).template flat<T>()
+
+#if defined(__ANDROID_TYPES_SLIM__)
+    // On Android, by default we only support additions of two arguments, so we
+    // can reduce the number of template instantiations.
+    OP_REQUIRES(ctx, num == 2,
+                errors::InvalidArgument("Only additions of two arguments "
+                                        "supported. Num inputs: ",
+                                        num));
+    functor::Add2Functor<Device, T> functor2;
+    functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
+#else
+    static const int kWidth = 8;
+    int r = num % kWidth;
+
+    switch (r) {
+      case 2: {
+        functor::Add2Functor<Device, T> functor2;
+        functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
+        break;
+      }
+      case 3: {
+        functor::Add3Functor<Device, T> functor3;
+        functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
+        break;
+      }
+      case 4: {
+        functor::Add4Functor<Device, T> functor4;
+        functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+                 I(3));
+        break;
+      }
+      case 5: {
+        functor::Add5Functor<Device, T> functor5;
+        functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+                 I(3), I(4));
+        break;
+      }
+      case 6: {
+        functor::Add6Functor<Device, T> functor6;
+        functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+                 I(3), I(4), I(5));
+        break;
+      }
+      case 7: {
+        functor::Add7Functor<Device, T> functor7;
+        functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+                 I(3), I(4), I(5), I(6));
+        break;
+      }
+      case 0: {
+        functor::Add8Functor<Device, T> functor8;
+        functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+                 I(3), I(4), I(5), I(6), I(7));
+        r = 8;
+        break;
+      }
+      case 1: {
+        functor::Add9Functor<Device, T> functor9;
+        functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+                 I(3), I(4), I(5), I(6), I(7), I(8));
+        r = 9;
+        break;
+      }
+    }
+
+    for (; r < num; r += kWidth) {
+      functor::Add8pFunctor<Device, T> functor8p;
+      functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
+                I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
+    }
+#endif  // defined(__ANDROID_TYPES_SLIM__)
+
+#undef I
+  }
+};
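To illustrate the dispatch above with concrete numbers, here is a standalone sketch (not from the patch; the input count is invented): with 13 inputs, 13 % 8 == 5, so an Add5-style pass seeds the output from inputs 0..4 and a single Add8p pass accumulates inputs 5..12.

#include <iostream>

int main() {
  const int num = 13;    // invented input count, for illustration
  const int kWidth = 8;  // mirrors the kWidth constant above
  int r = num % kWidth;  // 5: seed the output with the first 5 inputs
  std::cout << "seed output from inputs [0, " << r << ")\n";
  for (; r < num; r += kWidth) {
    std::cout << "accumulate inputs [" << r << ", " << r + kWidth << ")\n";
  }
  return 0;
}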
+
+template <typename Device>
+class AddNOp<Device, Variant> : public OpKernel {
+ public:
+  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    if (!ctx->ValidateInputsAreSameShape(this)) return;
+
+    const Tensor& input0 = ctx->input(0);
+    const int num = ctx->num_inputs();
+
+    if (num == 1) {
+      ctx->set_output(0, input0);
+      return;
+    }
+
+    for (int i = 0; i < num; ++i) {
+      // Step 1: ensure unary variants.
+      OP_REQUIRES(
+          ctx, ctx->input(i).dims() == 0,
+          errors::InvalidArgument(
+              "AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
+              "supported; inputs[",
+              i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
+    }
+
+    // Step 2: Sum input variants in a tree-like structure using
+    //   BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
+    //   For the output create a default-constructed variant object.
+    //
+    // Pairwise summation provides better numerical precision by
+    // reducing round-off error:
+    //
+    //   https://en.wikipedia.org/wiki/Pairwise_summation
+    //
+    // These two vectors are used to store and mark intermediate sums.
+    gtl::InlinedVector<bool, 4> temp_filled(num, false);
+    gtl::InlinedVector<Variant, 4> temp(num);
+
+    // Tree-based summation.
+    int skip = 1;
+    int n = num;
+    while (skip < n) {
+      int i = skip;
+      while (i < n) {
+        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
+        // inner loop if the variants are "large".
+
+        // x[i - skip] += x[i]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i].clear();
+        i += 2 * skip;
+      }
+      if (i == n) {
+        // x[0] += x[i - skip]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i - skip].clear();
+        n -= skip;
+      }
+      skip *= 2;
+    }
+
+    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
+    out.scalar<Variant>()() = std::move(temp[0]);
+    ctx->set_output(0, out);
+  }
+
+ private:
+  // AddVariantTo efficiently performs:
+  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
+  // where array(ix) := (temp_filled[ix]
+  //                     ? temp[ix]
+  //                     : ctx->input(ix).scalar<Variant>()())
+  // This reduces (possibly expensive) copying of Variants from
+  // the inputs into temp at the lowest levels of the summation tree.
+  static inline Status AddVariantTo(OpKernelContext* ctx, const int lhs_ix,
+                                    const int rhs_ix,
+                                    gtl::InlinedVector<Variant, 4>* temp,
+                                    gtl::InlinedVector<bool, 4>* temp_filled) {
+    Variant tmp;
+    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
+    const Variant& a = temp_filled->at(lhs_ix)
+                           ? tmp
+                           : ctx->input(lhs_ix).template scalar<Variant>()();
+    const Variant& b = temp_filled->at(rhs_ix)
+                           ? temp->at(rhs_ix)
+                           : ctx->input(rhs_ix).template scalar<Variant>()();
+    Variant* c = &temp->at(lhs_ix);
+    TF_RETURN_IF_ERROR(
+        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
+    temp_filled->at(lhs_ix) = true;
+    return Status::OK();
+  }
+};
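For reference, a minimal standalone sketch (not from the patch itself) of pairwise summation, the scheme the tree-based accumulation above follows: splitting the sum into halves keeps the summation tree O(log n) deep, which bounds round-off growth compared with a strictly sequential running sum.

#include <cstddef>
#include <iostream>
#include <vector>

// Recursively sum x[lo, hi) by halves.
double PairwiseSum(const std::vector<double>& x, size_t lo, size_t hi) {
  if (hi - lo == 1) return x[lo];
  size_t mid = lo + (hi - lo) / 2;
  return PairwiseSum(x, lo, mid) + PairwiseSum(x, mid, hi);
}

int main() {
  std::vector<double> x(1000, 0.1);
  std::cout << PairwiseSum(x, 0, x.size()) << "\n";  // ~100, with small error
  return 0;
}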
+
 #define REGISTER_ADDN(type, dev)                                   \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
-      AddNOp<dev##Device, type, OpKernel, OpKernelConstruction,    \
-             OpKernelContext>)
+      AddNOp<dev##Device, type>)
 
 #define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
 
@@ -54,17 +271,15 @@
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
 // registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(
-    Name("AddN")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<int32>("T")
-        .HostMemory("inputs")
-        .HostMemory("sum"),
-    AddNOp<CPUDevice, int32, OpKernel, OpKernelConstruction, OpKernelContext>);
+REGISTER_KERNEL_BUILDER(Name("AddN")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int32>("T")
+                            .HostMemory("inputs")
+                            .HostMemory("sum"),
+                        AddNOp<CPUDevice, int32>);
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-
 #undef REGISTER_ADDN
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/aggregate_ops.h b/tensorflow/core/kernels/aggregate_ops.h
index f4351f3..38912ee 100644
--- a/tensorflow/core/kernels/aggregate_ops.h
+++ b/tensorflow/core/kernels/aggregate_ops.h
@@ -18,11 +18,8 @@
 
 #include <numeric>
 
-#include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/variant.h"
-#include "tensorflow/core/framework/variant_op_registry.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
 
 namespace tensorflow {
 namespace functor {
@@ -223,226 +220,7 @@
     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
   }
 };
-
 }  // namespace functor
-
-template <typename Device, typename T, class OpKernelT,
-          class OpKernelConstructionT, class OpKernelContextT>
-class AddNOp : public OpKernelT {
- public:
-  explicit AddNOp(OpKernelConstructionT* context) : OpKernelT(context) {}
-
-  void Compute(OpKernelContextT* ctx) override {
-    if (!ctx->ValidateInputsAreSameShape(this)) return;
-
-    const Tensor& input0 = ctx->input(0);
-    const int num = ctx->num_inputs();
-
-    if (num == 1) {
-      ctx->set_output(0, input0);
-      return;
-    }
-
-    // Try to forward and accumulate the result in one of the input buffers.
-    int reused_input = -1;
-    gtl::InlinedVector<int, 8> input_indices(num);
-    std::iota(input_indices.begin(), input_indices.end(), 0);
-    Tensor* output = nullptr;
-    for (int input_idx = 0; input_idx < num; ++input_idx) {
-      if (ctx->forward_input_to_output_with_shape(input_idx, 0, input0.shape(),
-                                                  &output)) {
-        reused_input = input_idx;
-        break;
-      }
-    }
-    if (reused_input == -1) {
-      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
-    } else if (reused_input > 0) {
-      // Move the forwarded buffer to the front so we don't double count
-      // anything if there are more than 8 inputs.
-      input_indices[0] = reused_input;
-      input_indices[reused_input] = 0;
-    }
-    auto To = output->flat<T>();
-
-#define I(IDX) ctx->input(input_indices[IDX]).template flat<T>()
-
-#if defined(__ANDROID_TYPES_SLIM__)
-    // On Android by default, we only support additions of two arguments, so we
-    // can reduce the number of template instantiations.
-    OP_REQUIRES(ctx, num == 2,
-                errors::InvalidArgument("Only additions of two arguments "
-                                        "supported. Num inputs: ",
-                                        num));
-    functor::Add2Functor<Device, T> functor2;
-    functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
-#else
-    static const int kWidth = 8;
-    int r = num % kWidth;
-
-    switch (r) {
-      case 2: {
-        functor::Add2Functor<Device, T> functor2;
-        functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
-        break;
-      }
-      case 3: {
-        functor::Add3Functor<Device, T> functor3;
-        functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
-        break;
-      }
-      case 4: {
-        functor::Add4Functor<Device, T> functor4;
-        functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
-                 I(3));
-        break;
-      }
-      case 5: {
-        functor::Add5Functor<Device, T> functor5;
-        functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
-                 I(3), I(4));
-        break;
-      }
-      case 6: {
-        functor::Add6Functor<Device, T> functor6;
-        functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
-                 I(3), I(4), I(5));
-        break;
-      }
-      case 7: {
-        functor::Add7Functor<Device, T> functor7;
-        functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
-                 I(3), I(4), I(5), I(6));
-        break;
-      }
-      case 0: {
-        functor::Add8Functor<Device, T> functor8;
-        functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
-                 I(3), I(4), I(5), I(6), I(7));
-        r = 8;
-        break;
-      }
-      case 1: {
-        functor::Add9Functor<Device, T> functor9;
-        functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
-                 I(3), I(4), I(5), I(6), I(7), I(8));
-        r = 9;
-        break;
-      }
-    }
-
-    for (; r < num; r += kWidth) {
-      functor::Add8pFunctor<Device, T> functor8p;
-      functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
-                I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
-    }
-#endif  // defined(__ANDROID_TYPES_SLIM__)
-
-#undef I
-  }
-};
-
-template <typename Device, class OpKernelT, class OpKernelConstructionT,
-          class OpKernelContextT>
-class AddNOp<Device, Variant, OpKernelT, OpKernelConstructionT,
-             OpKernelContextT> : public OpKernelT {
- public:
-  explicit AddNOp(OpKernelConstructionT* context) : OpKernelT(context) {}
-
-  void Compute(OpKernelContextT* ctx) override {
-    if (!ctx->ValidateInputsAreSameShape(this)) return;
-
-    const Tensor& input0 = ctx->input(0);
-    const int num = ctx->num_inputs();
-
-    if (num == 1) {
-      ctx->set_output(0, input0);
-      return;
-    }
-
-    for (int i = 0; i < num; ++i) {
-      // Step 1: ensure unary variants.
-      OP_REQUIRES(
-          ctx, ctx->input(i).dims() == 0,
-          errors::InvalidArgument(
-              "AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
-              "supported; inputs[",
-              i, "] has shape: ", ctx->input(i).shape().DebugString(), "."));
-    }
-
-    // Step 2: Sum input variants in a tree-like structure using
-    //   BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
-    //   For the output create a default-constructed variant object.
-    //
-    // Pairwise summation provides better numerical precision by
-    // reducing round-off error:
-    //
-    //   https://en.wikipedia.org/wiki/Pairwise_summation
-    //
-    // These two vectors are used to store and mark intermediate sums.
-    gtl::InlinedVector<bool, 4> temp_filled(num, false);
-    gtl::InlinedVector<Variant, 4> temp(num);
-
-    // Tree-based summation.
-    int skip = 1;
-    int n = num;
-    while (skip < n) {
-      int i = skip;
-      while (i < n) {
-        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
-        // inner loop if the variants are "large".
-
-        // x[i - skip] += x[i]
-        OP_REQUIRES_OK(ctx,
-                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
-        // We won't use this index again, recover its memory.
-        temp[i].clear();
-        i += 2 * skip;
-      }
-      if (i == n) {
-        // x[0] += x[i - skip]
-        OP_REQUIRES_OK(ctx,
-                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
-        // We won't use this index again, recover its memory.
-        temp[i - skip].clear();
-        n -= skip;
-      }
-      skip *= 2;
-    }
-
-    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
-    out.scalar<Variant>()() = std::move(temp[0]);
-    ctx->set_output(0, out);
-  }
-
- private:
-  // AddVariantTo efficiently performs:
-  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
-  // where array(ix) := (temp_filled[ix]
-  //                     ? temp[ix]
-  //                     : ctx->input(ix).scalar<Variant>()())
-  // This reduces (possibly expensive) copying of Variants from
-  // the inputs into temp at the lowest levels of the summation tree.
-  static inline Status AddVariantTo(OpKernelContextT* ctx, const int lhs_ix,
-                                    const int rhs_ix,
-                                    gtl::InlinedVector<Variant, 4>* temp,
-                                    gtl::InlinedVector<bool, 4>* temp_filled) {
-    Variant tmp;
-    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
-    const Variant& a = temp_filled->at(lhs_ix)
-                           ? tmp
-                           : ctx->input(lhs_ix).template scalar<Variant>()();
-    const Variant& b = temp_filled->at(rhs_ix)
-                           ? temp->at(rhs_ix)
-                           : ctx->input(rhs_ix).template scalar<Variant>()();
-    Variant* c = &temp->at(lhs_ix);
-    TF_RETURN_IF_ERROR(
-        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
-    temp_filled->at(lhs_ix) = true;
-    return Status::OK();
-  }
-};
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
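The dense (non-Variant) AddNOp removed from aggregate_ops.h above accumulates its inputs in chunks of eight after first folding the remainder, so only the Add2..Add9 and Add8p functors need to be instantiated. A rough sketch of that schedule on scalars, assuming at least two inputs (the single-input case is forwarded earlier in Compute); this is an illustration of the index schedule, not TensorFlow code:

#include <cassert>
#include <numeric>
#include <vector>

// Mirror of the kWidth = 8 schedule: fold num % 8 inputs first (mapping a
// remainder of 0 or 1 to 8 or 9 summands), then add eight more per step.
double AddNSchedule(const std::vector<double>& in) {
  constexpr int kWidth = 8;
  const int num = static_cast<int>(in.size());
  assert(num >= 2);
  int r = num % kWidth;
  if (r == 0) r = 8;  // Add8Functor case
  if (r == 1) r = 9;  // Add9Functor case
  double out = std::accumulate(in.begin(), in.begin() + r, 0.0);
  for (; r < num; r += kWidth) {  // Add8pFunctor: accumulate into `out`
    out = std::accumulate(in.begin() + r, in.begin() + r + kWidth, out);
  }
  return out;
}

int main() {
  std::vector<double> xs(20, 1.0);
  return AddNSchedule(xs) == 20.0 ? 0 : 1;
}
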
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 5447fac..4cbd214 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include "absl/strings/str_cat.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -187,12 +189,7 @@
                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
 
-    auto lib = c->function_library();
-    OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
-    NameAttrList func;
-    OP_REQUIRES_OK(c, c->GetAttr("f", &func));
-    OP_REQUIRES_OK(
-        c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
     if (num_batch_threads_ <= 0) {
       adaptive_batch_scheduler_options_ =
           absl::make_optional(AdaptiveBatchSchedulerOptions{
@@ -242,8 +239,11 @@
 
     std::function<Status(BatchResource**)> creator;
 
+    FunctionLibraryRuntime::Handle handle;
+    OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);
+
     if (adaptive_batch_scheduler_options_ != absl::nullopt) {
-      creator = [this](BatchResource** r) {
+      creator = [this, handle](BatchResource** r) {
         serving::AdaptiveSharedBatchScheduler<
             serving::BatchResourceBase::BatchTask>::Options
             adaptive_shared_batch_scheduler_options;
@@ -274,16 +274,16 @@
         TF_RETURN_IF_ERROR(BatchResource::Create(
             adaptive_shared_batch_scheduler_options, max_batch_size_,
             batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
-            fhandle_, &new_resource));
+            handle, &new_resource));
         *r = new_resource.release();
         return Status::OK();
       };
     } else {
-      creator = [this](BatchResource** r) {
+      creator = [this, handle](BatchResource** r) {
         std::unique_ptr<BatchResource> new_resource;
         TF_RETURN_IF_ERROR(BatchResource::Create(
             num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-            max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+            max_enqueued_batches_, allowed_batch_sizes_, handle,
             enable_large_batch_splitting_, &new_resource));
         *r = new_resource.release();
         return Status::OK();
@@ -302,6 +302,75 @@
     // Assume br calls done, so nothing to do here.
   }
 
+  Status InstantiateFunction(OpKernelContext* c,
+                             FunctionLibraryRuntime::Handle* handle) const {
+    // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
+    FunctionLibraryRuntime* lib = c->function_library();
+    if (!lib) {
+      return errors::Internal("No function library");
+    }
+
+    FunctionLibraryRuntime::InstantiateOptions opts;
+    opts.target = lib->device() == nullptr ? "" : lib->device()->name();
+    opts.is_multi_device_function = true;
+
+    Device* cpu_device;
+    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
+
+    const FunctionDef* fdef =
+        lib->GetFunctionLibraryDefinition()->Find(func_.name());
+    if (!fdef) {
+      return errors::NotFound("Failed to find definition for function \"",
+                              func_.name(), "\"");
+    }
+    OpInputList in_tensors;
+    TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
+    for (int i = 0; i < in_tensors.size(); i++) {
+      if (in_tensors[i].dtype() == DT_RESOURCE) {
+        return errors::InvalidArgument(
+            "BatchFunction cannot take resource inputs but input ", i,
+            " is a resource.");
+      } else {
+        // Currently, inputs are on CPU since they are concatenated on CPU
+        opts.input_devices.push_back(cpu_device->name());
+      }
+    }
+    OpInputList captured_tensors;
+    TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
+    for (const Tensor& t : captured_tensors) {
+      if (t.dtype() == DT_RESOURCE) {
+        const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
+        opts.input_devices.push_back(rhandle.device());
+      } else {
+        opts.input_devices.push_back(cpu_device->name());
+      }
+    }
+    const OpDef& signature = fdef->signature();
+    for (int i = 0; i < signature.output_arg_size(); i++) {
+      // Currently, outputs must be on CPU since they are split on CPU.
+      opts.output_devices.push_back(cpu_device->name());
+    }
+    if (opts.input_devices.size() != signature.input_arg_size()) {
+      return errors::InvalidArgument(
+          "Function takes ", signature.input_arg_size(), " argument(s) but ",
+          opts.input_devices.size(), " argument(s) were passed");
+    }
+    return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
+                            handle);
+  }
+
+  Status GetOrCreateFunctionHandle(OpKernelContext* c,
+                                   FunctionLibraryRuntime::Handle* handle) {
+    mutex_lock ml(mu_);
+    if (!fhandle_) {
+      TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
+      fhandle_ = *handle;
+    } else {
+      *handle = fhandle_.value();
+    }
+    return Status::OK();
+  }
+
   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
   // and the last one must equal 'max_batch_size_'.
   Status ValidateAllowedBatchSizes() const {
@@ -337,9 +406,11 @@
   int32 batch_timeout_micros_;
   int32 max_enqueued_batches_;
   std::vector<int32> allowed_batch_sizes_;
-  FunctionLibraryRuntime::Handle fhandle_;
+  NameAttrList func_;
+  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
   bool enable_large_batch_splitting_;
   bool has_attribute_enable_large_batch_splitting_;
+  mutex mu_;
 
   // Parameters for adaptive batch scheduler only.
   // Note 'num_batch_threads_' above is shared by two implementations of batch
@@ -355,6 +426,14 @@
 
 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                         BatchFunctionKernel);
+// Currently all inputs and outputs are on the host.
+// TODO(b/173748277): Accept inputs/outputs on the device.
+REGISTER_KERNEL_BUILDER(Name("BatchFunction")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("in_tensors")
+                            .HostMemory("captured_tensors")
+                            .HostMemory("out_tensors"),
+                        BatchFunctionKernel);
 
 class BatchKernel : public AsyncOpKernel {
  public:
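The batch_kernels.cc change above defers function instantiation from the kernel constructor to the first ComputeAsync call: GetOrCreateFunctionHandle caches the handle in an absl::optional guarded by mu_. A minimal, TensorFlow-free sketch of that lazy-initialization shape (standard-library types stand in for the TF ones):

#include <mutex>
#include <optional>

class LazyHandle {
 public:
  using Handle = int;  // stand-in for FunctionLibraryRuntime::Handle

  // `instantiate` stands in for InstantiateFunction(); it runs at most once.
  template <typename F>
  Handle GetOrCreate(F&& instantiate) {
    std::lock_guard<std::mutex> lock(mu_);
    if (!handle_) handle_ = instantiate();
    return *handle_;
  }

 private:
  std::mutex mu_;
  std::optional<Handle> handle_;
};

int main() {
  LazyHandle h;
  int calls = 0;
  auto make = [&calls] { return ++calls; };
  return (h.GetOrCreate(make) == 1 && h.GetOrCreate(make) == 1) ? 0 : 1;
}
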
diff --git a/tensorflow/core/kernels/batching_util/concat_split_util.h b/tensorflow/core/kernels/batching_util/concat_split_util.h
index 77c4463..914c793 100644
--- a/tensorflow/core/kernels/batching_util/concat_split_util.h
+++ b/tensorflow/core/kernels/batching_util/concat_split_util.h
@@ -71,8 +71,10 @@
 
   TensorShape output_shape(input_shape);
   output_shape.set_dim(0, output_dim0);
-  TF_RETURN_IF_ERROR(
-      context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
+                                            output_shape, output, attr));
   if (output->NumElements() > 0) {
     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
@@ -167,8 +169,10 @@
     TensorShape output_shape = input.shape();
     output_shape.set_dim(0, size);
     Tensor output;
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
     TF_RETURN_IF_ERROR(
-        context->allocate_temp(input.dtype(), output_shape, &output));
+        context->allocate_temp(input.dtype(), output_shape, &output, attr));
     auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});
 
     Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 357ae15..f8f1875 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -742,5 +743,261 @@
                             .HostMemory("instance_key"),
                         CollectiveGatherV2OpKernel);
 
+class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
+ public:
+  explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
+    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
+    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
+    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
+    const bool is_source = true;
+    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
+  }
+
+ protected:
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
+    const Tensor& input = c->input(0);
+    const Tensor& group_size = c->input(1);
+    const Tensor& group_key = c->input(2);
+    const Tensor& instance_key = c->input(3);
+    OP_REQUIRES_ASYNC(
+        c, group_size.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_size"), done);
+    OP_REQUIRES_ASYNC(
+        c, group_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_key"), done);
+    OP_REQUIRES_ASYNC(
+        c, instance_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input instance_key"), done);
+
+    auto col_params = new CollectiveParams();
+    col_params->name = name_;
+    col_params->group.device_type = device_type_;
+    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
+    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
+    col_params->instance.type = BROADCAST_COLLECTIVE;
+    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
+    col_params->instance.data_type = data_type_;
+    col_params->instance.impl_details.communication_hint = communication_hint_;
+    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
+    col_params->is_source = true;
+    // Add a default value for subdiv offsets, which is the same as the default
+    // value in the V1 op's attribute.
+    col_params->instance.impl_details.subdiv_offsets.push_back(0);
+    VLOG(1) << "CollectiveBcastSendV2 group_size "
+            << col_params->group.group_size << " group_key "
+            << col_params->group.group_key << " instance_key "
+            << col_params->instance.instance_key;
+
+    auto done_with_cleanup = [col_params, done = std::move(done)]() {
+      delete col_params;
+      done();
+    };
+
+    // Allocate the output tensor, trying to reuse the input.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
+        done_with_cleanup);
+    col_params->instance.shape = input.shape();
+
+    // Resolve the collective params.
+    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
+    // blocking work, because this call may block.
+    c->collective_executor()->RunClosure([c,
+                                          done = std::move(done_with_cleanup),
+                                          col_params, col_exec]() {
+      VLOG(1) << "CollectiveBcastSendV2 CompleteParams for collective "
+              << col_params->name << " device " << c->device()->name()
+              << " group " << col_params->group.group_key << " instance "
+              << col_params->instance.instance_key;
+      col_exec->CompleteParamsAsync(
+          c->device()->attributes(), col_params, c->cancellation_manager(),
+          [c, done = std::move(done), col_params, col_exec](const Status& s) {
+            if (s.ok()) {
+              auto actual_done = [c, group_key = col_params->group.group_key,
+                                  instance_key =
+                                      col_params->instance.instance_key,
+                                  done = std::move(done)](const Status& s) {
+                VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync done for "
+                           "collective "
+                        << c->op_kernel().name() << " device "
+                        << c->device()->name() << " group " << group_key
+                        << " instance " << instance_key << " status " << s;
+                OP_REQUIRES_OK_ASYNC(c, s, done);
+                done();
+              };
+              VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync start for "
+                         "collective "
+                      << col_params->name << " device " << c->device()->name()
+                      << " group " << col_params->group.group_key
+                      << " instance " << col_params->instance.instance_key;
+              col_exec->ExecuteAsync(
+                  c, *col_params,
+                  CollectiveKey(c, col_params->group.group_key,
+                                col_params->instance.instance_key),
+                  actual_done);
+            } else {
+              c->SetStatus(s);
+              done();
+            }
+          });
+    });
+  }
+
+ private:
+  DeviceType device_type_;
+  DataType data_type_ = DT_INVALID;
+  string communication_hint_;
+  float timeout_seconds_ = 0;
+  string name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
+                        CollectiveBcastSendV2OpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("group_size")
+                            .HostMemory("group_key")
+                            .HostMemory("instance_key"),
+                        CollectiveBcastSendV2OpKernel);
+
+class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
+ public:
+  explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
+    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
+    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
+    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
+    const bool is_source = false;
+    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
+  }
+
+ protected:
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
+    const Tensor& group_size = c->input(0);
+    const Tensor& group_key = c->input(1);
+    const Tensor& instance_key = c->input(2);
+    const Tensor& shape_tensor = c->input(3);
+    OP_REQUIRES_ASYNC(
+        c, group_size.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_size"), done);
+    OP_REQUIRES_ASYNC(
+        c, group_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_key"), done);
+    OP_REQUIRES_ASYNC(
+        c, instance_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input instance_key"), done);
+
+    auto col_params = new CollectiveParams();
+    auto done_with_cleanup = [col_params, done = std::move(done)]() {
+      delete col_params;
+      done();
+    };
+
+    OP_REQUIRES_OK_ASYNC(
+        c, tensor::MakeShape(shape_tensor, &col_params->instance.shape),
+        done_with_cleanup);
+    col_params->name = name_;
+    col_params->group.device_type = device_type_;
+    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
+    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
+    col_params->instance.type = BROADCAST_COLLECTIVE;
+    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
+    col_params->instance.data_type = data_type_;
+    col_params->instance.impl_details.communication_hint = communication_hint_;
+    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
+    col_params->is_source = false;
+    // Add a default value for subdiv offsets, which is the same as the default
+    // value in the V1 op's attribute.
+    col_params->instance.impl_details.subdiv_offsets.push_back(0);
+    VLOG(1) << "CollectiveBcastRecvV2 group_size "
+            << col_params->group.group_size << " group_key "
+            << col_params->group.group_key << " instance_key "
+            << col_params->instance.instance_key;
+
+    // Allocate the output tensor.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(c,
+                         c->forward_input_or_allocate_output(
+                             {0}, 0, col_params->instance.shape, &output),
+                         done_with_cleanup);
+
+    // Resolve the collective params.
+    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
+    // blocking work, because this call may block.
+    c->collective_executor()->RunClosure([c,
+                                          done = std::move(done_with_cleanup),
+                                          col_params, col_exec]() {
+      VLOG(1) << "CollectiveBcastRecvV2 CompleteParams for collective "
+              << col_params->name << " device " << c->device()->name()
+              << " group " << col_params->group.group_key << " instance "
+              << col_params->instance.instance_key;
+      col_exec->CompleteParamsAsync(
+          c->device()->attributes(), col_params, c->cancellation_manager(),
+          [c, done = std::move(done), col_params, col_exec](const Status& s) {
+            if (s.ok()) {
+              auto actual_done = [c, group_key = col_params->group.group_key,
+                                  instance_key =
+                                      col_params->instance.instance_key,
+                                  done = std::move(done)](const Status& s) {
+                VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync done for "
+                           "collective "
+                        << c->op_kernel().name() << " device "
+                        << c->device()->name() << " group " << group_key
+                        << " instance " << instance_key << " status " << s;
+                OP_REQUIRES_OK_ASYNC(c, s, done);
+                done();
+              };
+              VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync start for "
+                         "collective "
+                      << col_params->name << " device " << c->device()->name()
+                      << " group " << col_params->group.group_key
+                      << " instance " << col_params->instance.instance_key;
+              col_exec->ExecuteAsync(
+                  c, *col_params,
+                  CollectiveKey(c, col_params->group.group_key,
+                                col_params->instance.instance_key),
+                  actual_done);
+            } else {
+              c->SetStatus(s);
+              done();
+            }
+          });
+    });
+  }
+
+ private:
+  DeviceType device_type_;
+  DataType data_type_ = DT_INVALID;
+  string communication_hint_;
+  float timeout_seconds_ = 0;
+  string name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
+                        CollectiveBcastRecvV2OpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("group_size")
+                            .HostMemory("group_key")
+                            .HostMemory("instance_key")
+                            .HostMemory("shape"),
+                        CollectiveBcastRecvV2OpKernel);
+
 }  // namespace
 }  // namespace tensorflow
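Both V2 broadcast kernels above heap-allocate their CollectiveParams and wrap the caller's done callback so the params are deleted exactly once at the end of the asynchronous chain, whichever path (early OP_REQUIRES failure, CompleteParamsAsync error, or successful execution) finishes it. A stripped-down sketch of that ownership pattern, with illustrative names that are not TensorFlow's:

#include <functional>
#include <utility>

struct Params {  // stands in for CollectiveParams
  int group_key = 0;
};

// Kick off async work; `done` must run exactly once when everything finishes.
void RunAsync(std::function<void()> done) {
  auto* params = new Params();  // must outlive this stack frame
  auto done_with_cleanup = [params, done = std::move(done)]() {
    delete params;  // reclaimed only after the last async stage
    done();
  };
  // ... schedule work that eventually calls done_with_cleanup() ...
  done_with_cleanup();
}

int main() {
  bool finished = false;
  RunAsync([&finished] { finished = true; });
  return finished ? 0 : 1;
}
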
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 4ca5f51..f6b30ce 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -16,8 +16,7 @@
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/conv_ops_3d.h"
-
+#include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -51,11 +50,147 @@
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+template <typename Device, typename T>
+struct LaunchConvOp;
+
+template <typename T>
+struct LaunchConvOp<CPUDevice, T> {
+  static void launch(OpKernelContext* context, bool cudnn_use_autotune,
+                     const Tensor& input, const Tensor& filter,
+                     const std::array<int64, 3>& dilations,
+                     const std::array<int64, 3>& strides, const Padding padding,
+                     TensorFormat data_format, Tensor* output) {
+    OP_REQUIRES(context, data_format == FORMAT_NHWC,
+                errors::InvalidArgument("CPU implementation of Conv3D "
+                                        "currently only supports the NHWC "
+                                        "tensor format."));
+    OP_REQUIRES(context,
+                dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
+                errors::InvalidArgument("CPU implementation of Conv3D "
+                                        "currently only supports dilated rates "
+                                        "of 1."));
+    functor::CuboidConvolution<CPUDevice, T>()(
+        context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
+        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
+        strides[0], BrainPadding2EigenPadding(padding));
+  }
+};
+
+template <typename Device, typename T>
+class Conv3DOp : public BinaryOp<T> {
+ public:
+  explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+    OP_REQUIRES(context, stride_.size() == 5,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
+         GetTensorDim(stride_, data_format_, 'C') == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(stride_, data_format_, '0') > 0 &&
+         GetTensorDim(stride_, data_format_, '1') > 0 &&
+         GetTensorDim(stride_, data_format_, '2') > 0),
+        errors::InvalidArgument("Spatial strides should be larger than 0."));
+    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+    OP_REQUIRES(context, dilation_.size() == 5,
+                errors::InvalidArgument("Dilation rates field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(context,
+                (GetTensorDim(dilation_, data_format_, 'N') == 1 &&
+                 GetTensorDim(dilation_, data_format_, 'C') == 1),
+                errors::InvalidArgument(
+                    "Current implementation does not yet support "
+                    "dilation rates in the batch and depth dimensions."));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
+         GetTensorDim(dilation_, data_format_, '1') > 0 &&
+         GetTensorDim(dilation_, data_format_, '2') > 0),
+        errors::InvalidArgument("Dilated rates should be larger than 0."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+    cudnn_use_autotune_ = CudnnUseAutotune();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Input tensor is of the following dimensions:
+    // [ batch, in_z, in_y, in_x, in_channels ]
+    const Tensor& input = context->input(0);
+
+    // Input filter is of the following dimensions:
+    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
+    const Tensor& filter = context->input(1);
+
+    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
+    // kept consistent between input/filter/output.
+    OP_REQUIRES(context, input.dims() == 5,
+                errors::InvalidArgument("input must be 5-dimensional"));
+    OP_REQUIRES(context, filter.dims() == 5,
+                errors::InvalidArgument("filter must be 5-dimensional"));
+
+    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
+    const int64 in_batch = GetTensorDim(input, data_format_, 'N');
+
+    const int64 filter_depth = filter.dim_size(3);
+    const int64 out_depth = filter.dim_size(4);
+
+    OP_REQUIRES(context, in_depth % filter_depth == 0,
+                errors::InvalidArgument(
+                    "Input depth must be evenly divisible by filter depth: ",
+                    in_depth, " vs ", filter_depth));
+
+    // Dimension order for these arrays is: z, y, x.
+    std::array<int64, 3> input_size = {
+        {GetTensorDim(input, data_format_, '0'),
+         GetTensorDim(input, data_format_, '1'),
+         GetTensorDim(input, data_format_, '2')}};
+    std::array<int64, 3> filter_size = {
+        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
+    std::array<int64, 3> dilations = {
+        {GetTensorDim(dilation_, data_format_, '0'),
+         GetTensorDim(dilation_, data_format_, '1'),
+         GetTensorDim(dilation_, data_format_, '2')}};
+    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
+                                     GetTensorDim(stride_, data_format_, '1'),
+                                     GetTensorDim(stride_, data_format_, '2')}};
+    std::array<int64, 3> out, padding;
+
+    OP_REQUIRES_OK(
+        context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,
+                                   padding_, &out, &padding));
+    TensorShape out_shape = ShapeFromFormat(
+        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
+    Tensor* output;
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+    // Return early if nothing to do.
+    if (out_shape.num_elements() == 0) return;
+
+    LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter,
+                                    dilations, strides, padding_, data_format_,
+                                    output);
+  }
+
+ private:
+  std::vector<int32> dilation_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+  bool cudnn_use_autotune_;
+};
+
 #define REGISTER_CPU_KERNEL(T)                                  \
   REGISTER_KERNEL_BUILDER(                                      \
       Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      Conv3DOp<CPUDevice, T, OpKernel, OpKernelConstruction,    \
-               OpKernelContext>);
+      Conv3DOp<CPUDevice, T>);
 TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
@@ -73,7 +208,7 @@
 
 // TODO(mjanusz): Share logic with 2d implementation as much as possible.
 template <typename T>
-struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
+struct LaunchConvOp<GPUDevice, T> {
   static void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
                      const Tensor& input_param, const Tensor& filter,
                      const std::array<int64, 3>& dilations,
@@ -548,16 +683,13 @@
 // Registration of the GPU implementations.
 REGISTER_KERNEL_BUILDER(
     Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    Conv3DOp<GPUDevice, Eigen::half, OpKernel, OpKernelConstruction,
-             OpKernelContext>);
+    Conv3DOp<GPUDevice, Eigen::half>);
 REGISTER_KERNEL_BUILDER(
     Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    Conv3DOp<GPUDevice, float, OpKernel, OpKernelConstruction,
-             OpKernelContext>);
+    Conv3DOp<GPUDevice, float>);
 REGISTER_KERNEL_BUILDER(
     Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    Conv3DOp<GPUDevice, double, OpKernel, OpKernelConstruction,
-             OpKernelContext>);
+    Conv3DOp<GPUDevice, double>);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
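Conv3DOp::Compute above derives the output volume per spatial dimension via Get3dOutputSizeV2 before allocating the output. As a rough reference, the per-dimension arithmetic follows the usual convolution formulas; the sketch below is standard conv math under the assumption of positive sizes, strides, and dilations, not the exact TensorFlow helper:

#include <cstdio>

// Per-dimension output size for one of the z/y/x dimensions.
long long ConvOutputSize(long long input, long long filter, long long dilation,
                         long long stride, bool same_padding) {
  const long long effective_filter = (filter - 1) * dilation + 1;
  if (same_padding) {
    return (input + stride - 1) / stride;  // ceil(input / stride)
  }
  return (input - effective_filter + stride) / stride;  // VALID padding
}

int main() {
  // A 16x16x16 volume, 3x3x3 filter, no dilation, stride 2:
  std::printf("VALID: %lld  SAME: %lld\n",
              ConvOutputSize(16, 3, 1, 2, /*same_padding=*/false),  // 7
              ConvOutputSize(16, 3, 1, 2, /*same_padding=*/true));  // 8
  return 0;
}
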
diff --git a/tensorflow/core/kernels/conv_ops_3d.h b/tensorflow/core/kernels/conv_ops_3d.h
deleted file mode 100644
index 9dcdea5..0000000
--- a/tensorflow/core/kernels/conv_ops_3d.h
+++ /dev/null
@@ -1,187 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_3D_H_
-#define TENSORFLOW_CORE_KERNELS_CONV_OPS_3D_H_
-
-#include <vector>
-
-#define USE_EIGEN_TENSOR
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/framework/numeric_op_base.h"
-#include "tensorflow/core/framework/kernel_shape_util.h"
-#include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/framework/ops_util.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/kernels/conv_3d.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/util/padding.h"
-#include "tensorflow/core/util/tensor_format.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/util/use_cudnn.h"
-#endif
-
-namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-
-template <typename Device, typename T, class OpKernelContextT>
-struct LaunchConvOp;
-
-template <typename T, class OpKernelContextT>
-struct LaunchConvOp<CPUDevice, T, OpKernelContextT> {
-  static void launch(OpKernelContextT* context, bool cudnn_use_autotune,
-                     const Tensor& input, const Tensor& filter,
-                     const std::array<int64, 3>& dilations,
-                     const std::array<int64, 3>& strides, const Padding padding,
-                     TensorFormat data_format, Tensor* output) {
-    OP_REQUIRES(context, data_format == FORMAT_NHWC,
-                errors::InvalidArgument("CPU implementation of Conv3D "
-                                        "currently only supports the NHWC "
-                                        "tensor format."));
-    OP_REQUIRES(context,
-                dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
-                errors::InvalidArgument("CPU implementation of Conv3D "
-                                        "currently only supports dilated rates "
-                                        "of 1."));
-    functor::CuboidConvolution<CPUDevice, T>()(
-        context->template eigen_device<CPUDevice>(), output->tensor<T, 5>(),
-        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
-        strides[0], BrainPadding2EigenPadding(padding));
-  }
-};
-
-template <typename Device, typename T, class OpKernelT,
-          class OpKernelConstructionT, class OpKernelContextT>
-class Conv3DOp : public BinaryOpBase<T, OpKernelT, OpKernelConstructionT> {
- public:
-  explicit Conv3DOp(OpKernelConstructionT* context) :
-      BinaryOpBase<T, OpKernelT, OpKernelConstructionT>(context) {
-    string data_format;
-    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
-    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
-                errors::InvalidArgument("Invalid data format"));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
-    OP_REQUIRES(context, stride_.size() == 5,
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify 5 dimensions"));
-    OP_REQUIRES(
-        context,
-        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
-         GetTensorDim(stride_, data_format_, 'C') == 1),
-        errors::InvalidArgument("Current implementation does not yet support "
-                                "strides in the batch and depth dimensions."));
-    OP_REQUIRES(
-        context,
-        (GetTensorDim(stride_, data_format_, '0') > 0 &&
-         GetTensorDim(stride_, data_format_, '1') > 0 &&
-         GetTensorDim(stride_, data_format_, '2') > 0),
-        errors::InvalidArgument("Spatial strides should be larger than 0."));
-    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
-    OP_REQUIRES(context, dilation_.size() == 5,
-                errors::InvalidArgument("Dilation rates field must "
-                                        "specify 5 dimensions"));
-    OP_REQUIRES(context,
-                (GetTensorDim(dilation_, data_format_, 'N') == 1 &&
-                 GetTensorDim(dilation_, data_format_, 'C') == 1),
-                errors::InvalidArgument(
-                    "Current implementation does not yet support "
-                    "dilation rates in the batch and depth dimensions."));
-    OP_REQUIRES(
-        context,
-        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
-         GetTensorDim(dilation_, data_format_, '1') > 0 &&
-         GetTensorDim(dilation_, data_format_, '2') > 0),
-        errors::InvalidArgument("Dilated rates should be larger than 0."));
-    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-#if GOOGLE_CUDA
-    cudnn_use_autotune_ = CudnnUseAutotune();
-#else
-    cudnn_use_autotune_ = false;
-#endif
-  }
-
-  void Compute(OpKernelContextT* context) override {
-    // Input tensor is of the following dimensions:
-    // [ batch, in_z, in_y, in_x, in_channels ]
-    const Tensor& input = context->input(0);
-
-    // Input filter is of the following dimensions:
-    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
-    const Tensor& filter = context->input(1);
-
-    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
-    // kept consistent between input/filter/output.
-    OP_REQUIRES(context, input.dims() == 5,
-                errors::InvalidArgument("input must be 5-dimensional"));
-    OP_REQUIRES(context, filter.dims() == 5,
-                errors::InvalidArgument("filter must be 5-dimensional"));
-
-    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
-    const int64 in_batch = GetTensorDim(input, data_format_, 'N');
-
-    const int64 filter_depth = filter.dim_size(3);
-    const int64 out_depth = filter.dim_size(4);
-
-    OP_REQUIRES(context, in_depth % filter_depth == 0,
-                errors::InvalidArgument(
-                    "Input depth must be evenly divisible by filter depth: ",
-                    in_depth, " vs ", filter_depth));
-
-    // Dimension order for these arrays is: z, y, x.
-    std::array<int64, 3> input_size = {
-        {GetTensorDim(input, data_format_, '0'),
-         GetTensorDim(input, data_format_, '1'),
-         GetTensorDim(input, data_format_, '2')}};
-    std::array<int64, 3> filter_size = {
-        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
-    std::array<int64, 3> dilations = {
-        {GetTensorDim(dilation_, data_format_, '0'),
-         GetTensorDim(dilation_, data_format_, '1'),
-         GetTensorDim(dilation_, data_format_, '2')}};
-    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
-                                     GetTensorDim(stride_, data_format_, '1'),
-                                     GetTensorDim(stride_, data_format_, '2')}};
-    std::array<int64, 3> out, padding;
-
-    OP_REQUIRES_OK(
-        context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,
-                                   padding_, &out, &padding));
-    TensorShape out_shape = ShapeFromFormat(
-        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
-    Tensor* output;
-    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
-
-    // Return early if nothing to do.
-    if (out_shape.num_elements() == 0) return;
-
-    LaunchConvOp<Device, T, OpKernelContextT>::launch(
-        context, cudnn_use_autotune_, input, filter,
-        dilations, strides, padding_, data_format_,
-        output);
-  }
-
- private:
-  std::vector<int32> dilation_;
-  std::vector<int32> stride_;
-  Padding padding_;
-  TensorFormat data_format_;
-  bool cudnn_use_autotune_;
-};
-
-}  // namespace tensorflow
-
-
-#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_3D_H_
diff --git a/tensorflow/core/kernels/conv_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_ops_benchmark_test.cc
index 8408c8b..022bbce 100644
--- a/tensorflow/core/kernels/conv_ops_benchmark_test.cc
+++ b/tensorflow/core/kernels/conv_ops_benchmark_test.cc
@@ -325,7 +325,8 @@
         .Run(state);                                                    \
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                       \
   }                                                                     \
-  BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC));
+  BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC))           \
+      ->Arg(/*unused arg*/ 1);
 
 #define BM_Conv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)           \
   static void BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH,       \
@@ -336,32 +337,35 @@
         .Run(state);                                                     \
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                        \
   }                                                                      \
-  BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC));
+  BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC))    \
+      ->Arg(/*unused arg*/ 1);
 
-#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)      \
-  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH,  \
+#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)        \
+  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH,    \
+                      FC)(::testing::benchmark::State & state) {             \
+    test::Benchmark(                                                         \
+        #type,                                                               \
+        Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, FH, FC, "Relu")   \
+            .graph,                                                          \
+        /*old_benchmark_api=*/false)                                         \
+        .Run(state);                                                         \
+    BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                            \
+  }                                                                          \
+  BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC)) \
+      ->Arg(/*unused arg*/ 1);
+
+#define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)        \
+  static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH,    \
                       FC)(::testing::benchmark::State & state) {           \
     test::Benchmark(                                                       \
         #type,                                                             \
-        Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, FH, FC, "Relu") \
-            .graph,                                                        \
+        FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, {"BiasAdd"}),   \
         /*old_benchmark_api=*/false)                                       \
         .Run(state);                                                       \
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                          \
   }                                                                        \
-  BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
-
-#define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)      \
-  static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH,  \
-                      FC)(::testing::benchmark::State & state) {         \
-    test::Benchmark(                                                     \
-        #type,                                                           \
-        FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, {"BiasAdd"}), \
-        /*old_benchmark_api=*/false)                                     \
-        .Run(state);                                                     \
-    BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                        \
-  }                                                                      \
-  BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC));
+  BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC)) \
+      ->Arg(/*unused arg*/ 1);
 
 #define BM_FusedConv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)     \
   static void BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
@@ -374,7 +378,8 @@
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                              \
   }                                                                            \
   BENCHMARK(                                                                   \
-      BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
+      BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC))    \
+      ->Arg(/*unused arg*/ 1);
 
 #define BM_Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC, type, LABEL)           \
   static void BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH,       \
@@ -385,7 +390,8 @@
         .Run(state);                                                          \
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                             \
   }                                                                           \
-  BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
+  BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC))    \
+      ->Arg(/*unused arg*/ 1);
 
 #define BM_Conv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)     \
   static void BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, \
@@ -399,7 +405,8 @@
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                              \
   }                                                                            \
   BENCHMARK(                                                                   \
-      BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, FC));
+      BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, FC))    \
+      ->Arg(/*unused arg*/ 1);
 
 #define BM_FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, type, LABEL)     \
   static void BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
@@ -411,7 +418,9 @@
         .Run(state);                                                         \
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                            \
   }                                                                          \
-  BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
+  BENCHMARK(                                                                 \
+      BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC))    \
+      ->Arg(/*unused arg*/ 1);
 
 #define BM_FusedConv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC, type,      \
                                            LABEL)                             \
@@ -425,7 +434,8 @@
     BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D);                             \
   }                                                                           \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, FW, \
-                    FH, FC));
+                    FH, FC))                                                  \
+      ->Arg(/*unused arg*/ 1);
 
 // -------------------------------------------------------------------------- //
 // Pixel CNN convolutions.
@@ -584,7 +594,8 @@
         .Run(state);                                                          \
     BM_SET_INFO(N, H, W, C, type, "", Conv2D);                                \
   }                                                                           \
-  BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC));
+  BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC)) \
+      ->Arg(/*unused arg*/ 1);
 
 #if GOOGLE_CUDA
 using fp32 = float;
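Each benchmark macro above now appends ->Arg(/*unused arg*/ 1) when registering with the state-based benchmark API. For orientation, here is a minimal registration with the same call shape using the open-source Google Benchmark library, which ::testing::benchmark::State is modeled on; the benchmark body is purely illustrative:

#include "benchmark/benchmark.h"

static void BM_Noop(benchmark::State& state) {
  for (auto _ : state) {
    benchmark::DoNotOptimize(state.range(0));  // the single registered arg
  }
}
// Register with one (unused) argument, mirroring the macros above.
BENCHMARK(BM_Noop)->Arg(/*unused arg*/ 1);

BENCHMARK_MAIN();
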
diff --git a/tensorflow/core/kernels/cwise_op_add_1.cc b/tensorflow/core/kernels/cwise_op_add_1.cc
index deabeb9..e5ef14a 100644
--- a/tensorflow/core/kernels/cwise_op_add_1.cc
+++ b/tensorflow/core/kernels/cwise_op_add_1.cc
@@ -25,7 +25,7 @@
 REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
 
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(BinaryOp, GPU, "AddV2", functor::add, float, Eigen::half, double);
 #endif
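This file and the cwise files below all make the same mechanical change: the guard macro MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED is renamed to MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED. The guard's shape, as an illustrative skeleton only: the hand-written Eigen GPU registrations stay unless both macros are defined, and where an #else branch exists (the Bitwise* ops), it keeps only the unsigned types that the MLIR generator does not yet emit, per the TODO in those hunks.

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
// hand-written GPU kernels registered here (the common case)
#else
// only types the MLIR generator does not cover (e.g. unsigned ints) remain
#endif
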
 
diff --git a/tensorflow/core/kernels/cwise_op_add_2.cc b/tensorflow/core/kernels/cwise_op_add_2.cc
index e300552..c8f36d2 100644
--- a/tensorflow/core/kernels/cwise_op_add_2.cc
+++ b/tensorflow/core/kernels/cwise_op_add_2.cc
@@ -33,7 +33,7 @@
           complex128);
 
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER5(BinaryOp, GPU, "AddV2", functor::add, uint8, uint32, int64, complex64,
           complex128);
 #else
diff --git a/tensorflow/core/kernels/cwise_op_bitwise_and.cc b/tensorflow/core/kernels/cwise_op_bitwise_and.cc
index 5e557e7..63d47d1 100644
--- a/tensorflow/core/kernels/cwise_op_bitwise_and.cc
+++ b/tensorflow/core/kernels/cwise_op_bitwise_and.cc
@@ -21,8 +21,16 @@
 
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER8(BinaryOp, GPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32,
           int64, uint8, uint16, uint32, uint64);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER4(BinaryOp, GPU, "BitwiseAnd", functor::bitwise_and, uint8, uint16,
+          uint32, uint64);
+#endif  // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
+        // !MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_bitwise_or.cc b/tensorflow/core/kernels/cwise_op_bitwise_or.cc
index 3b371f9..7457845 100644
--- a/tensorflow/core/kernels/cwise_op_bitwise_or.cc
+++ b/tensorflow/core/kernels/cwise_op_bitwise_or.cc
@@ -21,8 +21,16 @@
 
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER8(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
           int64, uint8, uint16, uint32, uint64);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER4(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, uint8, uint16,
+          uint32, uint64);
+#endif  // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
+        // !MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_bitwise_xor.cc b/tensorflow/core/kernels/cwise_op_bitwise_xor.cc
index bb3c727..0569bfe 100644
--- a/tensorflow/core/kernels/cwise_op_bitwise_xor.cc
+++ b/tensorflow/core/kernels/cwise_op_bitwise_xor.cc
@@ -21,8 +21,16 @@
 
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER8(BinaryOp, GPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32,
           int64, uint8, uint16, uint32, uint64);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER4(BinaryOp, GPU, "BitwiseXor", functor::bitwise_xor, uint8, uint16,
+          uint32, uint64);
+#endif  // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
+        // !MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_ceil.cc b/tensorflow/core/kernels/cwise_op_ceil.cc
index b6748ea..e722b13 100644
--- a/tensorflow/core/kernels/cwise_op_ceil.cc
+++ b/tensorflow/core/kernels/cwise_op_ceil.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Ceil", functor::ceil, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_conj.cc b/tensorflow/core/kernels/cwise_op_conj.cc
index 5e137d6..8442da1 100644
--- a/tensorflow/core/kernels/cwise_op_conj.cc
+++ b/tensorflow/core/kernels/cwise_op_conj.cc
@@ -27,7 +27,7 @@
     Name("Conj").Device(DEVICE_GPU).TypeConstraint<Variant>("T"),
     UnaryVariantOp<GPUDevice, CONJ_VARIANT_UNARY_OP>);
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER_KERNEL_BUILDER(
     Name("Conj").Device(DEVICE_GPU).TypeConstraint<complex64>("T"),
     UnaryOp<GPUDevice, functor::conj<complex64>>);
diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc
index 5bf127f..16697c1 100644
--- a/tensorflow/core/kernels/cwise_op_cos.cc
+++ b/tensorflow/core/kernels/cwise_op_cos.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Cos", functor::cos, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
index af72aa3..b0dfe94 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
@@ -27,8 +27,13 @@
     Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<double>("T"),
     ApproximateEqualOp<CPUDevice, double>);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER4(BinaryOp, GPU, "Equal", functor::equal_to, float, Eigen::half, double,
           uint8);
+#else
+REGISTER(BinaryOp, GPU, "Equal", functor::equal_to, uint8);
+#endif
 REGISTER_KERNEL_BUILDER(
     Name("ApproximateEqual").Device(DEVICE_GPU).TypeConstraint<float>("T"),
     ApproximateEqualOp<GPUDevice, float>);
@@ -48,5 +53,4 @@
                         BinaryOp<CPUDevice, functor::equal_to<int32>>);
 #endif
 
-
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_equal_to_2.cc b/tensorflow/core/kernels/cwise_op_equal_to_2.cc
index 8bf53d8..95eb576 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to_2.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to_2.cc
@@ -25,8 +25,13 @@
 REGISTER6(BinaryOp, CPU, "Equal", functor::equal_to, int32, int64, complex64,
           complex128, tstring, bool);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER6(BinaryOp, GPU, "Equal", functor::equal_to, int8, int16, int64,
           complex64, complex128, bool);
+#else
+REGISTER2(BinaryOp, GPU, "Equal", functor::equal_to, complex64, complex128);
+#endif
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #endif  // !defined(__ANDROID_TYPES_SLIM__)
diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc
index 24d098a..71833e6 100644
--- a/tensorflow/core/kernels/cwise_op_exp.cc
+++ b/tensorflow/core/kernels/cwise_op_exp.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER5(UnaryOp, GPU, "Exp", functor::exp, float, Eigen::half, double,
           complex64, complex128);
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_floor.cc b/tensorflow/core/kernels/cwise_op_floor.cc
index 57296f9..72fea66 100644
--- a/tensorflow/core/kernels/cwise_op_floor.cc
+++ b/tensorflow/core/kernels/cwise_op_floor.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Floor", functor::floor, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
index a98eecd..4e12829 100644
--- a/tensorflow/core/kernels/cwise_op_floor_div.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -24,9 +24,12 @@
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
           int64);
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(BinaryOp, GPU, "FloorDiv", functor::floor_div_real, float,
           Eigen::half, double);
 #endif
+#endif
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/cwise_op_gpu_left_shift.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_left_shift.cu.cc
index ac4db97..9305a63 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_left_shift.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_left_shift.cu.cc
@@ -19,8 +19,13 @@
 
 namespace tensorflow {
 namespace functor {
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 DEFINE_BINARY8(left_shift, int8, int16, int32, int64, uint8, uint16, uint32,
                uint64);
+#else
+DEFINE_BINARY4(left_shift, uint8, uint16, uint32, uint64);
+#endif
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/cwise_op_gpu_right_shift.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_right_shift.cu.cc
index 55d8a88..6c02a6f 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_right_shift.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_right_shift.cu.cc
@@ -19,8 +19,13 @@
 
 namespace tensorflow {
 namespace functor {
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 DEFINE_BINARY8(right_shift, int8, int16, int32, int64, uint8, uint16, uint32,
                uint64);
+#else
+DEFINE_BINARY4(right_shift, uint8, uint16, uint32, uint64);
+#endif
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
index f9a2b8c..84fc462 100644
--- a/tensorflow/core/kernels/cwise_op_greater.cc
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -19,8 +19,14 @@
 REGISTER9(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
           double, int32, int64, uint8, int8, int16, bfloat16);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER7(BinaryOp, GPU, "Greater", functor::greater, float, Eigen::half,
           double, int64, uint8, int8, int16);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER(BinaryOp, GPU, "Greater", functor::greater, uint8);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
index d33adc2..0bb34a5 100644
--- a/tensorflow/core/kernels/cwise_op_greater_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -19,8 +19,14 @@
 REGISTER9(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
           Eigen::half, double, int32, int64, uint8, int8, int16, bfloat16);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER7(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float,
           Eigen::half, double, int64, uint8, int8, int16);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, uint8);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc
index fb76ec2..9e34c09 100644
--- a/tensorflow/core/kernels/cwise_op_imag.cc
+++ b/tensorflow/core/kernels/cwise_op_imag.cc
@@ -28,7 +28,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER_COMPLEX(GPU, float, complex64);
 REGISTER_COMPLEX(GPU, double, complex128);
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc
index 5f5a0ac..6d4777b 100644
--- a/tensorflow/core/kernels/cwise_op_isinf.cc
+++ b/tensorflow/core/kernels/cwise_op_isinf.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "IsInf", functor::isinf, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_left_shift.cc b/tensorflow/core/kernels/cwise_op_left_shift.cc
index ed65bea..f43c46b 100644
--- a/tensorflow/core/kernels/cwise_op_left_shift.cc
+++ b/tensorflow/core/kernels/cwise_op_left_shift.cc
@@ -21,8 +21,14 @@
 
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER8(BinaryOp, GPU, "LeftShift", functor::left_shift, int8, int16, int32,
           int64, uint8, uint16, uint32, uint64);
+#else
+REGISTER4(BinaryOp, GPU, "LeftShift", functor::left_shift, uint8, uint16,
+          uint32, uint64);
+#endif
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 817f07a..8819328 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -21,8 +21,14 @@
 REGISTER4(BinaryOp, CPU, "Less", functor::less, int64, uint8, int8, int16);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER7(BinaryOp, GPU, "Less", functor::less, float, Eigen::half, double,
           int64, uint8, int8, int16);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER(BinaryOp, GPU, "Less", functor::less, uint8);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 17b9915..669b96f 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -22,8 +22,14 @@
           int16);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER7(BinaryOp, GPU, "LessEqual", functor::less_equal, float, Eigen::half,
           double, int64, uint8, int8, int16);
+#else
+// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
+REGISTER(BinaryOp, GPU, "LessEqual", functor::less_equal, uint8);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc
index f0ece6c..6efce16 100644
--- a/tensorflow/core/kernels/cwise_op_log.cc
+++ b/tensorflow/core/kernels/cwise_op_log.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_logical_and.cc b/tensorflow/core/kernels/cwise_op_logical_and.cc
index 32a67c5..8cd8517 100644
--- a/tensorflow/core/kernels/cwise_op_logical_and.cc
+++ b/tensorflow/core/kernels/cwise_op_logical_and.cc
@@ -19,7 +19,11 @@
 REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_CPU),
                         BinaryOp<CPUDevice, functor::logical_and>);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_GPU),
                         BinaryOp<GPUDevice, functor::logical_and>);
+#endif  // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
+        // !MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED
 #endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_logical_not.cc b/tensorflow/core/kernels/cwise_op_logical_not.cc
index cd7a5cc..d388765 100644
--- a/tensorflow/core/kernels/cwise_op_logical_not.cc
+++ b/tensorflow/core/kernels/cwise_op_logical_not.cc
@@ -20,7 +20,7 @@
                         UnaryOp<CPUDevice, functor::logical_not>);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER_KERNEL_BUILDER(Name("LogicalNot").Device(DEVICE_GPU),
                         UnaryOp<GPUDevice, functor::logical_not>);
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_logical_or.cc b/tensorflow/core/kernels/cwise_op_logical_or.cc
index 9476393..052e29a 100644
--- a/tensorflow/core/kernels/cwise_op_logical_or.cc
+++ b/tensorflow/core/kernels/cwise_op_logical_or.cc
@@ -19,7 +19,11 @@
 REGISTER_KERNEL_BUILDER(Name("LogicalOr").Device(DEVICE_CPU),
                         BinaryOp<CPUDevice, functor::logical_or>);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER_KERNEL_BUILDER(Name("LogicalOr").Device(DEVICE_GPU),
                         BinaryOp<GPUDevice, functor::logical_or>);
+#endif  // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
+        // !MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED
 #endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_mul_1.cc b/tensorflow/core/kernels/cwise_op_mul_1.cc
index 5660f43..730e6c9 100644
--- a/tensorflow/core/kernels/cwise_op_mul_1.cc
+++ b/tensorflow/core/kernels/cwise_op_mul_1.cc
@@ -30,8 +30,13 @@
 #endif  // __ANDROID_TYPES_SLIM__
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER4(BinaryOp, GPU, "Mul", functor::mul, Eigen::half, float, double,
           uint8);
+#else
+REGISTER(BinaryOp, GPU, "Mul", functor::mul, uint8);
+#endif
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
 // registration requires all int32 inputs and outputs to be in host memory.
diff --git a/tensorflow/core/kernels/cwise_op_mul_2.cc b/tensorflow/core/kernels/cwise_op_mul_2.cc
index c4a2f63..995012f 100644
--- a/tensorflow/core/kernels/cwise_op_mul_2.cc
+++ b/tensorflow/core/kernels/cwise_op_mul_2.cc
@@ -25,8 +25,13 @@
 REGISTER6(BinaryOp, CPU, "Mul", functor::mul, int8, uint16, int16, int64,
           complex64, complex128);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER6(BinaryOp, GPU, "Mul", functor::mul, int8, uint16, int16, int64,
           complex64, complex128);
+#else
+REGISTER3(BinaryOp, GPU, "Mul", functor::mul, uint16, complex64, complex128);
+#endif
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
diff --git a/tensorflow/core/kernels/cwise_op_neg_1.cc b/tensorflow/core/kernels/cwise_op_neg_1.cc
index fde5fae..9551d84 100644
--- a/tensorflow/core/kernels/cwise_op_neg_1.cc
+++ b/tensorflow/core/kernels/cwise_op_neg_1.cc
@@ -20,7 +20,10 @@
 
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_neg_2.cc b/tensorflow/core/kernels/cwise_op_neg_2.cc
index 5ea78ad..a2857fa 100644
--- a/tensorflow/core/kernels/cwise_op_neg_2.cc
+++ b/tensorflow/core/kernels/cwise_op_neg_2.cc
@@ -20,7 +20,12 @@
           bfloat16, complex64, complex128);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
           bfloat16, complex64, complex128);
+#else
+REGISTER3(UnaryOp, GPU, "Neg", functor::neg, bfloat16, complex64, complex128);
+#endif
 #endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
index 49edd3f..22f30eb 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
@@ -21,8 +21,13 @@
 REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
           uint64, qint8, qint16, quint8, quint16);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
           double, uint8);
+#else
+REGISTER(BinaryOp, GPU, "NotEqual", functor::not_equal_to, uint8);
+#endif
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
 // registration requires all int32 inputs and outputs to be in host memory.
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
index 9b23960..dc0adc9 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
@@ -25,8 +25,14 @@
 REGISTER6(BinaryOp, CPU, "NotEqual", functor::not_equal_to, int32, int64,
           complex64, complex128, tstring, bool);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER6(BinaryOp, GPU, "NotEqual", functor::not_equal_to, int8, int16, int64,
           complex64, complex128, bool);
+#else
+REGISTER2(BinaryOp, GPU, "NotEqual", functor::not_equal_to, complex64,
+          complex128);
+#endif
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc
index cb84848..3513768 100644
--- a/tensorflow/core/kernels/cwise_op_real.cc
+++ b/tensorflow/core/kernels/cwise_op_real.cc
@@ -29,7 +29,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER_COMPLEX(GPU, float, complex64);
 REGISTER_COMPLEX(GPU, double, complex128);
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_right_shift.cc b/tensorflow/core/kernels/cwise_op_right_shift.cc
index 2bf819c..b72ec49 100644
--- a/tensorflow/core/kernels/cwise_op_right_shift.cc
+++ b/tensorflow/core/kernels/cwise_op_right_shift.cc
@@ -21,8 +21,14 @@
 
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER8(BinaryOp, GPU, "RightShift", functor::right_shift, int8, int16, int32,
           int64, uint8, uint16, uint32, uint64);
+#else
+REGISTER4(BinaryOp, GPU, "RightShift", functor::right_shift, uint8, uint16,
+          uint32, uint64);
+#endif
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc
index cb6c1ef..7d67388 100644
--- a/tensorflow/core/kernels/cwise_op_rsqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc
index c4ef05e..7c1e9fb 100644
--- a/tensorflow/core/kernels/cwise_op_sign.cc
+++ b/tensorflow/core/kernels/cwise_op_sign.cc
@@ -20,9 +20,12 @@
           complex64, Eigen::half, bfloat16, complex128);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER6(UnaryOp, GPU, "Sign", functor::sign, float, Eigen::half, double,
           int64, complex64, complex128);
+#else
+REGISTER2(UnaryOp, GPU, "Sign", functor::sign, complex64, complex128);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -34,6 +37,5 @@
                             .TypeConstraint<int32>("T"),
                         UnaryOp<CPUDevice, functor::sign<int32>>);
 #endif
-#endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc
index 3689f8b..b137e69 100644
--- a/tensorflow/core/kernels/cwise_op_sin.cc
+++ b/tensorflow/core/kernels/cwise_op_sin.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Sin", functor::sin, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc
index 32c78e4..dfa0b8e 100644
--- a/tensorflow/core/kernels/cwise_op_sqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_sqrt.cc
@@ -21,7 +21,7 @@
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
-    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
 #endif
 #endif
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
index 12ffd55..d3c5e4b 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
@@ -117,11 +117,19 @@
   for (const auto& tensor : tensors) {
     TensorProto proto;
     tensor.AsProtoTensorContent(&proto);
-#if defined(PLATFORM_GOOGLE)
-    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsCord()));
-#else   // PLATFORM_GOOGLE
+#if defined(TF_CORD_SUPPORT)
+    // Creating raw pointer here because std::move() in a releaser in OSS TF
+    // will result in a smart pointer being moved upon function creation, which
+    // will result in proto_buffer == nullptr when WriteRecord happens.
+    auto proto_buffer = new std::string();
+    proto.SerializeToString(proto_buffer);
+    absl::Cord proto_serialized = absl::MakeCordFromExternal(
+        *proto_buffer,
+        [proto_buffer](absl::string_view) { delete proto_buffer; });
+    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
+#else   // TF_CORD_SUPPORT
     TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
-#endif  // PLATFORM_GOOGLE
+#endif  // TF_CORD_SUPPORT
   }
   return Status::OK();
 }
@@ -197,16 +205,16 @@
       TensorProto* t = record.add_tensor();
       tensor.AsProtoTensorContent(t);
     }
-#if defined(PLATFORM_GOOGLE)
-    return WriteRecord(record.SerializeAsCord());
-#else   // PLATFORM_GOOGLE
+#if defined(TF_CORD_SUPPORT)
+    auto record_buffer = new std::string();
+    record.SerializeToString(record_buffer);
+    absl::Cord record_serialized = absl::MakeCordFromExternal(
+        *record_buffer,
+        [record_buffer](absl::string_view) { delete record_buffer; });
+    return WriteRecord(record_serialized);
+#else   // TF_CORD_SUPPORT
     return WriteRecord(record.SerializeAsString());
-#endif  // PLATFORM_GOOGLE
-  }
-
-  if (compression_type_ != io::compression::kSnappy) {
-    return errors::InvalidArgument("Compression ", compression_type_,
-                                   " is not supported.");
+#endif  // TF_CORD_SUPPORT
   }
 
   std::vector<const TensorBuffer*> tensor_buffers;
@@ -258,11 +266,16 @@
   if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
     return errors::Internal("Failed to compress using snappy.");
   }
-#if defined(PLATFORM_GOOGLE)
-  absl::Cord metadata_serialized = metadata.SerializeAsCord();
-#else   // PLATFORM_GOOGLE
+
+#if defined(TF_CORD_SUPPORT)
+  auto metadata_buffer = new std::string();
+  metadata.SerializeToString(metadata_buffer);
+  absl::Cord metadata_serialized = absl::MakeCordFromExternal(
+      *metadata_buffer,
+      [metadata_buffer](absl::string_view) { delete metadata_buffer; });
+#else
   std::string metadata_serialized = metadata.SerializeAsString();
-#endif  // PLATFORM_GOOGLE
+#endif  // TF_CORD_SUPPORT
   TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
   TF_RETURN_IF_ERROR(WriteRecord(output));
   return Status::OK();
@@ -296,14 +309,14 @@
   return dest_->Append(data);
 }
 
-#if defined(PLATFORM_GOOGLE)
+#if defined(TF_CORD_SUPPORT)
 Status CustomWriter::WriteRecord(const absl::Cord& data) {
   char header[kHeaderSize];
   core::EncodeFixed64(header, data.size());
   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
   return dest_->Append(data);
 }
-#endif  // PLATFORM_GOOGLE
+#endif  // TF_CORD_SUPPORT
 
 Status Reader::Create(Env* env, const std::string& filename,
                       const string& compression_type, int version,
@@ -722,19 +735,9 @@
       auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
       size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
       TensorProto tp;
-#if defined(PLATFORM_GOOGLE)
-      absl::string_view tensor_proto_view(tensor_proto_str.get(),
-                                          tensor_proto_size);
-      absl::Cord c = absl::MakeCordFromExternal(
-          tensor_proto_view, [s = std::move(tensor_proto_str)] {});
-      if (!tp.ParseFromCord(c)) {
-        return errors::Internal("Could not parse TensorProto");
-      }
-#else   // PLATFORM_GOOGLE
       if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
         return errors::Internal("Could not parse TensorProto");
       }
-#endif  // PLATFORM_GOOGLE
       Tensor t;
       if (!t.FromProto(tp)) {
         return errors::Internal("Could not parse Tensor");
@@ -824,7 +827,7 @@
   return input_stream_->ReadNBytes(length, record);
 }
 
-#if defined(PLATFORM_GOOGLE)
+#if defined(TF_CORD_SUPPORT)
 Status CustomReader::ReadRecord(absl::Cord* record) {
   tstring header;
   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
@@ -832,15 +835,15 @@
   if (compression_type_ == io::compression::kNone) {
     return input_stream_->ReadNBytes(length, record);
   } else {
-    auto tmp_str = absl::make_unique<tstring>();
-    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str.get()));
+    auto tmp_str = new tstring();
+    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
     absl::string_view tmp_str_view(*tmp_str);
-    record->Append(
-        absl::MakeCordFromExternal(tmp_str_view, [s = std::move(tmp_str)] {}));
+    record->Append(absl::MakeCordFromExternal(
+        tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
     return Status::OK();
   }
 }
-#endif
+#endif  // TF_CORD_SUPPORT
 
 Status WriteMetadataFile(Env* env, const string& dir,
                          const experimental::SnapshotMetadataRecord* metadata) {
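
The TF_CORD_SUPPORT branches above all use the same ownership pattern: serialize into a heap-allocated std::string and hand it to absl::MakeCordFromExternal, deleting the buffer in the releaser once the Cord no longer references the bytes. A compact standalone sketch of that pattern (assumes Abseil is on the include path; WrapSerialized is a made-up helper, not TF code):

    #include <iostream>
    #include <string>
    #include <utility>

    #include "absl/strings/cord.h"
    #include "absl/strings/string_view.h"

    // Wrap an already-serialized payload in a Cord without copying it; the
    // releaser frees the backing string when the Cord drops its reference.
    absl::Cord WrapSerialized(std::string serialized) {
      auto* buffer = new std::string(std::move(serialized));
      return absl::MakeCordFromExternal(
          *buffer, [buffer](absl::string_view) { delete buffer; });
    }

    int main() {
      absl::Cord c = WrapSerialized("example payload");
      std::cout << c.size() << " bytes\n";  // buffer is freed when c is destroyed
      return 0;
    }
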
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h
index 5b22846..35bd1f5 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.h
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h
@@ -146,9 +146,9 @@
  private:
   Status WriteRecord(const StringPiece& data);
 
-#if defined(PLATFORM_GOOGLE)
+#if defined(TF_CORD_SUPPORT)
   Status WriteRecord(const absl::Cord& data);
-#endif  // PLATFORM_GOOGLE
+#endif  // TF_CORD_SUPPORT
 
   std::unique_ptr<WritableFile> dest_;
   const std::string filename_;
@@ -265,7 +265,7 @@
 
   Status ReadRecord(tstring* record);
 
-#if defined(PLATFORM_GOOGLE)
+#if defined(TF_CORD_SUPPORT)
   Status ReadRecord(absl::Cord* record);
 #endif
 
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index e23759a..924f9a4 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -946,6 +946,17 @@
   return type_string() == "IteratorGetNextSync" ? nullptr : this;
 }
 
+void RecordElementSize(const std::vector<Tensor>& element,
+                       profiler::TraceMe* traceme) {
+  traceme->AppendMetadata([&]() {
+    int64 element_size = 0;
+    for (const auto& component : element) {
+      element_size += component.TotalBytes();
+    }
+    return profiler::TraceMeEncode({{"element_size", element_size}});
+  });
+}
+
 Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
   profiler::TraceMe traceme(
       [&] {
@@ -968,6 +979,7 @@
   }
   TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
   TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
+  RecordElementSize(components, &traceme);
   for (int i = 0; i < components.size(); ++i) {
     ctx->set_output(i, components[i]);
   }
@@ -995,6 +1007,7 @@
   if (end_of_sequence) {
     return WriteOptionalNoneToOutput(ctx, 0);
   } else {
+    RecordElementSize(components, &traceme);
     for (int i = 0; i < components.size(); ++i) {
       if (components[i].dtype() != output_types_[i]) {
         return errors::InvalidArgument(
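
RecordElementSize above defers the byte counting into TraceMe::AppendMetadata's callback, so the per-component TotalBytes() sum is only computed when the profiler is actually collecting a trace. A rough standalone analogue of that deferred-metadata idea, with FakeTensor and FakeTraceMe standing in for Tensor and profiler::TraceMe:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    struct FakeTensor {
      int64_t bytes;
      int64_t TotalBytes() const { return bytes; }
    };

    // Stand-in for profiler::TraceMe: evaluates the metadata callback only
    // when a trace is being collected.
    struct FakeTraceMe {
      bool collecting = true;
      void AppendMetadata(const std::function<std::string()>& make_metadata) {
        if (collecting) std::cout << make_metadata() << "\n";
      }
    };

    void RecordElementSize(const std::vector<FakeTensor>& element,
                           FakeTraceMe* traceme) {
      traceme->AppendMetadata([&] {
        int64_t element_size = 0;
        for (const auto& component : element) element_size += component.TotalBytes();
        return "element_size=" + std::to_string(element_size);
      });
    }

    int main() {
      FakeTraceMe traceme;
      RecordElementSize({{128}, {256}}, &traceme);  // prints element_size=384
      return 0;
    }
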
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 6c86222..5d85d14 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -85,7 +85,7 @@
     // clang-format off
     absl::flat_hash_map<string, uint64> live_experiments = {
         {"enable_gradient_descent", 0},
-        {"map_parallelization", 100}
+        {"map_parallelization", 0}
     };
     // clang-format on
     auto hash_func = [](const string& str) { return Hash64(str); };
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 42244ca..170e001 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -448,18 +448,20 @@
                   model_node());
             },
             std::move(input_element));
-        // `ctx->runner()` may execute its logic synchronously so we wrap it in
-        // `RecordStop` and `RecordStart` to prevent invalid nesting of
-        // `RecordStart` calls.
-        RecordStop(ctx.get());
         (*ctx->runner())(
             [this, ctx, fn = std::move(fn), done = std::move(done)]() {
-              RecordStart(ctx.get());
-              auto cleanup =
-                  gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
-              done(fn());
+              Status s;
+              // Check whether we are already recording to prevent invalid
+              // nesting of `RecordStart` calls.
+              if (IsRecording(ctx.get())) {
+                s = fn();
+              } else {
+                RecordStart(ctx.get());
+                s = fn();
+                RecordStop(ctx.get());
+              }
+              done(s);
             });
-        RecordStart(ctx.get());
       }
     }
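
The new callback above brackets fn() with RecordStart/RecordStop only when the thread is not already inside a recorded region, which is what the removed RecordStop/RecordStart dance around ctx->runner() was approximating for runners that execute synchronously. A toy illustration of that re-entrancy guard, with a thread_local flag standing in for IsRecording(ctx):

    #include <functional>
    #include <iostream>

    thread_local bool recording = false;  // stand-in for IsRecording(ctx)

    // Run fn, opening the recording bracket only if this thread is not already
    // recording, so a synchronous nested call never double-starts.
    void RunRecorded(const std::function<void()>& fn) {
      if (recording) {
        fn();
      } else {
        recording = true;
        fn();
        recording = false;
      }
    }

    int main() {
      RunRecorded([] {
        std::cout << "outer work\n";
        RunRecorded([] { std::cout << "nested work, no double start\n"; });
      });
      return 0;
    }
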
 
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 6d6a259..8a04de6 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -299,14 +299,22 @@
 
     data::TraceMeMetadata GetTraceMeMetadata() const override {
       int64 limit = -1, size = -1;
+      data::TraceMeMetadata result;
       // NOTE: We only set the parallelism value if the lock can be acquired
       // right away to avoid introducing tracing overhead.
       if (mu_->try_lock()) {
         limit = buffer_limit();
         size = buffer_.size();
+        if (!buffer_.empty()) {
+          std::vector<std::string> shapes;
+          for (const auto& component : buffer_.front().value) {
+            shapes.push_back(component.shape().DebugString());
+          }
+          result.push_back(std::make_pair("next_element_shapes",
+                                          absl::StrJoin(shapes, ",")));
+        }
         mu_->unlock();
       }
-      data::TraceMeMetadata result;
       result.push_back(std::make_pair(
           "buffer_limit",
           strings::Printf("%lld", static_cast<long long>(limit))));
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index b52d4d6..00724b6 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -18,16 +18,52 @@
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/data_format_ops.h"
+
+#include <map>
+
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+// Ensure that `src` and `dst` define a valid permutation.
+// Ops defined in this file assume that user specifies a permutation via two
+// string attributes. This check validates that these attributes properly define
+// it to prevent security vulnerabilities.
+static bool IsValidPermutation(const std::string& src, const std::string& dst) {
+  if (src.size() != dst.size()) {
+    return false;
+  }
+
+  std::map<char, bool> characters;
+
+  // Every character in `src` must be present only once
+  for (const auto c : src) {
+    if (characters[c]) {
+      return false;
+    }
+    characters[c] = true;
+  }
+
+  // Every character in `dst` must show up in `src` exactly once
+  for (const auto c : dst) {
+    if (!characters[c]) {
+      return false;
+    }
+    characters[c] = false;
+  }
+
+  // At this point, characters[] has been switched to true and false exactly
+  // once for all character in `src` (and `dst`) so we have a valid permutation
+  return true;
+}
+
 template <typename Device, typename T>
 class DataFormatDimMapOp : public OpKernel {
  public:
@@ -38,15 +74,19 @@
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
     OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
-                errors::InvalidArgument(strings::StrCat(
-                    "Source format must of length 4 or 5, received "
+                errors::InvalidArgument(
+                    "Source format must be of length 4 or 5, received "
                     "src_format = ",
-                    src_format)));
+                    src_format));
+    OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
+                errors::InvalidArgument("Destination format must be of length "
+                                        "4 or 5, received dst_format = ",
+                                        dst_format));
     OP_REQUIRES(
-        context, dst_format.size() == 4 || dst_format.size() == 5,
-        errors::InvalidArgument(strings::StrCat(
-            "Destination format must of length 4 or 5, received dst_format = ",
-            dst_format)));
+        context, IsValidPermutation(src_format, dst_format),
+        errors::InvalidArgument(
+            "Destination and source format must determine a permutation, got ",
+            src_format, " and ", dst_format));
     dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
     for (int i = 0; i < src_format.size(); ++i) {
       for (int j = 0; j < dst_format.size(); ++j) {
@@ -78,8 +118,22 @@
       : OpKernel(context) {
     string src_format;
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
+    OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
+                errors::InvalidArgument(
+                    "Source format must be of length 4 or 5, received "
+                    "src_format = ",
+                    src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
+    OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
+                errors::InvalidArgument("Destination format must be of length "
+                                        "4 or 5, received dst_format = ",
+                                        dst_format));
+    OP_REQUIRES(
+        context, IsValidPermutation(src_format, dst_format),
+        errors::InvalidArgument(
+            "Destination and source format must determine a permutation, got ",
+            src_format, " and ", dst_format));
     src_format_ = src_format;
     dst_format_ = dst_format;
   }
@@ -127,6 +181,10 @@
       };
       keep_only_spatial_dimensions(&src_format_str);
       keep_only_spatial_dimensions(&dst_format_str);
+      OP_REQUIRES(context,
+                  src_format_str.size() == 2 && dst_format_str.size() == 2,
+                  errors::InvalidArgument(
+                      "Format specifier must contain H and W for 2D case"));
     }
     ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
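
The IsValidPermutation helper added above rejects src_format/dst_format attribute pairs that are not exact rearrangements of each other before any index table is built, which is the malformed-attribute hole the new comment calls out. A standalone copy of the same logic with a few probe inputs (illustrative only; the function body mirrors the diff):

    #include <iostream>
    #include <map>
    #include <string>

    // Every character of src must appear exactly once, and dst must use
    // exactly those characters exactly once.
    static bool IsValidPermutation(const std::string& src, const std::string& dst) {
      if (src.size() != dst.size()) return false;
      std::map<char, bool> characters;
      for (const char c : src) {
        if (characters[c]) return false;
        characters[c] = true;
      }
      for (const char c : dst) {
        if (!characters[c]) return false;
        characters[c] = false;
      }
      return true;
    }

    int main() {
      std::cout << IsValidPermutation("NHWC", "NCHW") << "\n";  // 1: valid
      std::cout << IsValidPermutation("NHWC", "NCHH") << "\n";  // 0: repeated 'H'
      std::cout << IsValidPermutation("NHWC", "NCW") << "\n";   // 0: length mismatch
      return 0;
    }
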
 
diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h
index 094e2cf..568e8ab 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.h
+++ b/tensorflow/core/kernels/depthwise_conv_op.h
@@ -193,27 +193,19 @@
                   const int64 padded_filter_inner_dim_size, const int64 out_r,
                   const int64 out_c, const T* input, T* input_buffer) {
     typedef typename Eigen::internal::packet_traits<T>::type Packet;
-    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
+    static const int64 kPacketSize = Eigen::internal::packet_traits<T>::size;
 
+    const int64 kDepth = args.depth_multiplier;
     // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
     const int64 input_vectorized_size =
         (args.in_depth / kPacketSize) * kPacketSize;
-    const int64 input_scalar_size = args.in_depth % kPacketSize;
-
-    // Calculate vectorized and scalar (residual) lengths for
-    // 'depth_multiplier'. This is used to efficiently replicate data for
-    // when 'depth_multiplier' > kPacketSize.
-    const int64 dm_vectorized_size =
-        (args.depth_multiplier / kPacketSize) * kPacketSize;
-    const int64 dm_scalar_size = args.depth_multiplier % kPacketSize;
+    const int64 input_scalar_size = args.in_depth - input_vectorized_size;
 
     // Calculate output padding length.
     const int64 output_scalar_size = args.out_depth % kPacketSize;
     const int64 output_pad_size =
         output_scalar_size > 0 ? kPacketSize - output_scalar_size : 0;
 
-    const int64 replicated_packet_size = kPacketSize * args.depth_multiplier;
-
     // Iterate through all rows x cols reading 'in_depth' from 'input' and
     // replicating by 'depth_multiplier' into 'input_buffer' (otherwise
     // zero-padding input buffer as needed).
@@ -221,60 +213,126 @@
     const int64 in_r_start = out_r * args.stride - args.pad_rows;
     const int64 in_c_start = out_c * args.stride - args.pad_cols;
 
-    for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) {
-      const int64 in_r = in_r_start + f_r;
+    // TODO: add a ploaddup variant for depth == 2 if needed.
+    if (kDepth > 1 && kDepth <= kPacketSize) {
+      for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) {
+        const int64 in_r = in_r_start + f_r;
 
-      for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) {
-        const int64 in_c = in_c_start + f_c;
+        for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) {
+          const int64 in_c = in_c_start + f_c;
 
-        if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 &&
-            in_c < args.in_cols) {
-          auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
-          // Copy vectorized portion of inner dimension.
-          for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) {
-            auto v = Eigen::internal::ploadu<Packet>(in + d);
-            for (int dm = 0; dm < args.depth_multiplier; ++dm) {
-              Eigen::internal::pscatter<T, Packet>(in_buf + dm, v,
-                                                   args.depth_multiplier);
+          if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 &&
+              in_c < args.in_cols) {
+            const auto* in =
+                input + (in_r * args.in_cols + in_c) * args.in_depth;
+            int64 limit = args.in_depth;
+            // This will overwrite up to kPacketSize next elements. This is
+            // OK on all iterations except the last one, since we will write
+            // the correct values on the next iteration.
+            if (f_c == args.filter_cols - 1) {
+              limit -= (kPacketSize - kDepth) / kDepth + 1;
+              if (limit < 0) {
+                limit = 0;
+              }
             }
-            in_buf += replicated_packet_size;
-          }
+            // Copy vectorized portion of inner dimension.
+            for (int64 d = 0; d < limit; d++) {
+              const auto p = Eigen::internal::pset1<Packet>(in[d]);
+              Eigen::internal::pstoreu<T>(in_buf, p);
+              in_buf += kDepth;
+            }
 
-          // Copy scalar portion of inner dimension.
-          for (int64 d = 0; d < input_scalar_size; ++d) {
-            T v = in[input_vectorized_size + d];
-            const int64 base = d * args.depth_multiplier;
-            if (dm_vectorized_size > 0) {
-              // Copy vectorized portion of replicated output.
-              // This branch is only taken if 'args.depth_multiplier' is
-              // vectorizable (i.e. args.depth_multiplier >= register width).
-              auto p = Eigen::internal::pset1<Packet>(v);
+            // Copy the scalar portion.
+            for (int64 d = limit; d < args.in_depth; d++) {
+              const auto value = in[d];
+              for (int64 dm = 0; dm < kDepth; dm++) {
+                in_buf[dm] = value;
+              }
+              in_buf += kDepth;
+            }
+
+            // Pad the remainder of the output to vector register boundary.
+            for (int64 d = 0; d < output_pad_size; ++d) {
+              in_buf[d] = static_cast<T>(0);
+            }
+            in_buf += output_pad_size;
+          } else {
+            // Zero pad.
+            memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
+            in_buf += padded_filter_inner_dim_size;
+          }
+        }
+      }
+    } else if (kDepth > kPacketSize) {
+      // Calculate vectorized and scalar (residual) lengths for
+      // 'depth_multiplier'. This is used to efficiently replicate data for
+      // when 'depth_multiplier' > kPacketSize.
+      const int64 dm_vectorized_size = (kDepth / kPacketSize) * kPacketSize;
+
+      for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) {
+        const int64 in_r = in_r_start + f_r;
+
+        for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) {
+          const int64 in_c = in_c_start + f_c;
+
+          if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 &&
+              in_c < args.in_cols) {
+            const auto* in =
+                input + (in_r * args.in_cols + in_c) * args.in_depth;
+            // Copy vectorized portion of inner dimension.
+            for (int64 d = 0; d < args.in_depth; d++) {
+              const auto p = Eigen::internal::pset1<Packet>(in[d]);
               for (int64 dm = 0; dm < dm_vectorized_size; dm += kPacketSize) {
-                Eigen::internal::pstoreu<T>(in_buf + base + dm, p);
+                Eigen::internal::pstoreu<T>(in_buf + dm, p);
               }
-              // Copy scalar portion of replicated output.
-              for (int64 dm = 0; dm < dm_scalar_size; ++dm) {
-                in_buf[base + dm_vectorized_size + dm] = v;
-              }
-            } else {
-              // Depth multiplier is less than one packet: scalar copy.
-              for (int dm = 0; dm < args.depth_multiplier; ++dm) {
-                in_buf[base + dm] = v;
-              }
+              // Overlapping store for the remainder.
+              Eigen::internal::pstoreu<T>(in_buf + kDepth - kPacketSize, p);
+              in_buf += kDepth;
             }
+            // Pad the remainder of the output to vector register boundary.
+            for (int64 d = 0; d < output_pad_size; ++d) {
+              in_buf[d] = static_cast<T>(0);
+            }
+            in_buf += output_pad_size;
+          } else {
+            // Zero pad.
+            memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
+            in_buf += padded_filter_inner_dim_size;
           }
-          in_buf += input_scalar_size * args.depth_multiplier;
+        }
+      }
+    } else if (kDepth == 1) {
+      for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) {
+        const int64 in_r = in_r_start + f_r;
 
-          // Pad the remainder of the output to vector register boundary.
-          for (int64 d = 0; d < output_pad_size; ++d) {
-            in_buf[d] = static_cast<T>(0);
+        for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) {
+          const int64 in_c = in_c_start + f_c;
+
+          if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 &&
+              in_c < args.in_cols) {
+            const auto* in =
+                input + (in_r * args.in_cols + in_c) * args.in_depth;
+            for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) {
+              const auto p = Eigen::internal::ploadu<Packet>(in + d);
+              Eigen::internal::pstoreu<T>(in_buf, p);
+              in_buf += kPacketSize;
+            }
+            for (int64 d = 0; d < input_scalar_size; ++d) {
+              T v = in[input_vectorized_size + d];
+              in_buf[d] = v;
+            }
+            in_buf += input_scalar_size;
+
+            // Pad the remainder of the output to vector register boundary.
+            for (int64 d = 0; d < output_pad_size; ++d) {
+              in_buf[d] = static_cast<T>(0);
+            }
+            in_buf += output_pad_size;
+          } else {
+            // Zero pad.
+            memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
+            in_buf += padded_filter_inner_dim_size;
           }
-          in_buf += output_pad_size;
-
-        } else {
-          // Zero pad.
-          memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
-          in_buf += padded_filter_inner_dim_size;
         }
       }
     }
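
The rewritten copy loop above splits input replication by depth_multiplier: per-value broadcast stores when 1 < depth <= packet width, full packet stores plus one overlapping unaligned store when depth exceeds the packet width, and a plain vectorized copy when depth == 1. The overlapping store is the least obvious piece; here is a scalar stand-in for it, where a fixed-width memcpy plays the role of Eigen's pset1/pstoreu (purely illustrative, not the kernel's actual vector code):

    #include <cstdio>
    #include <cstring>

    constexpr int kPacketSize = 4;  // stand-in for Eigen's packet width

    // Replicate `value` depth times into out[0..depth-1] with packet-sized
    // stores and one overlapping store for the tail, mirroring
    // pstoreu(in_buf + kDepth - kPacketSize, p) in the diff. Requires
    // depth >= kPacketSize, which is the branch this trick is used in.
    void ReplicateValue(float value, int depth, float* out) {
      float packet[kPacketSize];
      for (int i = 0; i < kPacketSize; ++i) packet[i] = value;   // "pset1"
      const int vectorized = (depth / kPacketSize) * kPacketSize;
      for (int dm = 0; dm < vectorized; dm += kPacketSize)
        std::memcpy(out + dm, packet, sizeof(packet));           // "pstoreu"
      std::memcpy(out + depth - kPacketSize, packet, sizeof(packet));  // overlap
    }

    int main() {
      float buf[10] = {};
      ReplicateValue(2.5f, 10, buf);  // depth_multiplier = 10 > kPacketSize
      for (float v : buf) std::printf("%.1f ", v);  // ten copies of 2.5
      std::printf("\n");
      return 0;
    }
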
diff --git a/tensorflow/core/kernels/depthwise_conv_ops_test.cc b/tensorflow/core/kernels/depthwise_conv_ops_test.cc
index ba4b167..f47880a 100644
--- a/tensorflow/core/kernels/depthwise_conv_ops_test.cc
+++ b/tensorflow/core/kernels/depthwise_conv_ops_test.cc
@@ -102,7 +102,7 @@
   Run<Eigen::half>(Device::CPU);
 }
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TEST_F(DepthwiseConvOpTest, DepthwiseConvFloatGpu) { Run<float>(Device::GPU); }
 TEST_F(DepthwiseConvOpTest, DepthwiseConvDoubleGpu) {
   Run<double>(Device::GPU);
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index ac02d3b..6fd0b8a 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -364,8 +364,8 @@
 
     // Original input column and row after applying all non-standard strides and
     // dilations. Computed by padOrSkip{Row,Col}.
-    Index orig_c;
-    Index orig_r;
+    Index orig_c = 0;
+    Index orig_r = 0;
 
     for (StorageIndex col = 0; col < cols; ++col) {
       SubMapper lm = rhs.getLinearMapper(0, col);
diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h
index 46414a3..5de12f0 100644
--- a/tensorflow/core/kernels/gather_nd_op.h
+++ b/tensorflow/core/kernels/gather_nd_op.h
@@ -161,7 +161,8 @@
           "indices", SliceDebugString(shape, bad_i), " = [",
           str_util::Join(
               gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", "),
-          "] does not index into param shape ", params.shape().DebugString());
+          "] does not index into param shape ", params.shape().DebugString(),
+          ", node name: ", c->op_kernel().name());
     }
   }
   return Status::OK();
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index e9e6a93..e0b909a 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -88,29 +88,31 @@
       axis = params.dims() + axis;
     }
 
-    if (batch_dims_ != 0) {
-      OP_REQUIRES(
-          c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
-          errors::InvalidArgument("Expected batch_dims in the range [",
-                                  -indices.dims(), ", ", indices.dims(),
-                                  "], but got ", batch_dims_));
+    // Modify only a local copy of batch_dims_.
+    int32 batch_dims = batch_dims_;
+    if (batch_dims != 0) {
+      OP_REQUIRES(c,
+                  batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
+                  errors::InvalidArgument("Expected batch_dims in the range [",
+                                          -indices.dims(), ", ", indices.dims(),
+                                          "], but got ", batch_dims));
 
-      if (batch_dims_ < 0) {
-        batch_dims_ = indices.dims() + batch_dims_;
+      if (batch_dims < 0) {
+        batch_dims = indices.dims() + batch_dims;
       }
 
-      if (!axis_is_set) axis = batch_dims_;
+      if (!axis_is_set) axis = batch_dims;
 
-      OP_REQUIRES(c, batch_dims_ < params.dims(),
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
+      OP_REQUIRES(c, batch_dims < params.dims(),
+                  errors::InvalidArgument("batch_dims (", batch_dims,
                                           ") must be less than rank(params) (",
                                           params.dims(), ")."));
 
-      OP_REQUIRES(c, axis >= batch_dims_,
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
+      OP_REQUIRES(c, axis >= batch_dims,
+                  errors::InvalidArgument("batch_dims (", batch_dims,
                                           ") must be less than or equal to ",
                                           "axis (", axis, ")."));
-      for (int i = 0; i < batch_dims_; ++i) {
+      for (int i = 0; i < batch_dims; ++i) {
         OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
                     errors::InvalidArgument(
                         "params.shape[", i, "]: ", params.dim_size(i),
@@ -136,15 +138,15 @@
     int64 outer_size = 1;
     int64 inner_size = 1;
 
-    for (int i = 0; i < batch_dims_; ++i) {
+    for (int i = 0; i < batch_dims; ++i) {
       result_shape.AddDim(params.dim_size(i));
       batch_size *= params.dim_size(i);
     }
-    for (int i = batch_dims_; i < axis; ++i) {
+    for (int i = batch_dims; i < axis; ++i) {
       result_shape.AddDim(params.dim_size(i));
       outer_size *= params.dim_size(i);
     }
-    for (int i = batch_dims_; i < indices.dims(); ++i) {
+    for (int i = batch_dims; i < indices.dims(); ++i) {
       result_shape.AddDim(indices.dim_size(i));
     }
     for (int i = axis + 1; i < params.dims(); ++i) {
@@ -159,7 +161,7 @@
 
     int64 bad_i = -1;
     auto indices_flat = indices.flat<Index>();
-    if (batch_dims_ > 0) {
+    if (batch_dims > 0) {
       auto params_flat = params.shaped<T, 4>(
           {batch_size, outer_size, gather_dim_size, inner_size});
       auto out_flat = out->shaped<T, 4>(
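
The gather_op.cc change above stops Compute() from normalizing the negative batch_dims attribute in place: the member keeps the value parsed at construction and each call normalizes a local copy, so repeated or concurrent Compute() calls see the same attribute. A toy illustration of why mutating the member is fragile (hypothetical Kernel type, not TF code):

    #include <iostream>

    struct Kernel {
      int batch_dims_ = -1;  // attribute value as parsed at construction time

      // Old shape of the code: normalizes the member itself, so a later call
      // with different-rank indices (or a concurrent call) sees mutated state.
      int ComputeBad(int indices_dims) {
        if (batch_dims_ < 0) batch_dims_ += indices_dims;
        return batch_dims_;
      }

      // New shape of the code: normalize a local copy only.
      int ComputeGood(int indices_dims) const {
        int batch_dims = batch_dims_;
        if (batch_dims < 0) batch_dims += indices_dims;
        return batch_dims;
      }
    };

    int main() {
      Kernel k;
      std::cout << k.ComputeBad(3) << " " << k.ComputeBad(4) << "\n";    // 2 2 (stale)
      Kernel k2;
      std::cout << k2.ComputeGood(3) << " " << k2.ComputeGood(4) << "\n"; // 2 3
      return 0;
    }
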
diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD
index a94e98d..8a1ff9d 100644
--- a/tensorflow/core/kernels/image/BUILD
+++ b/tensorflow/core/kernels/image/BUILD
@@ -98,6 +98,7 @@
 # Public support libraries ----------------------------------------------------
 cc_library(
     name = "image",
+    visibility = ["//visibility:public"],
     deps = [
         ":adjust_contrast_op",
         ":adjust_hue_op",
diff --git a/tensorflow/core/kernels/image/non_max_suppression_op.cc b/tensorflow/core/kernels/image/non_max_suppression_op.cc
index 6aecb7f..c357154 100644
--- a/tensorflow/core/kernels/image/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/image/non_max_suppression_op.cc
@@ -194,7 +194,7 @@
   T scale = static_cast<T>(0.0);
   bool is_soft_nms = soft_nms_sigma > static_cast<T>(0.0);
   if (is_soft_nms) {
-    scale = static_cast<T>(-1.0) / soft_nms_sigma;
+    scale = static_cast<T>(-0.5) / soft_nms_sigma;
   }
 
   auto suppress_weight = [similarity_threshold, scale,
@@ -323,12 +323,6 @@
     }
   }
 
-  // Copy class_boxes_data to a tensor
-  TensorShape boxesShape({num_boxes, 4});
-  Tensor boxes(DT_FLOAT, boxesShape);
-  std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
-              boxes.unaligned_flat<float>().data());
-
   // Do NMS, get the candidate indices of form vector<int>
   // Data structure for selection candidate in NMS.
   struct Candidate {
@@ -350,9 +344,10 @@
   Candidate next_candidate;
 
   std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
-  const Tensor const_boxes = boxes;
-  typename TTypes<float, 2>::ConstTensor boxes_data_t =
-      const_boxes.tensor<float, 2>();
+  // View class_boxes_data as a 2D tensor (no copy)
+  Eigen::array<Eigen::DenseIndex, 2> boxesShape = {num_boxes, 4};
+  typename TTypes<float, 2>::ConstTensor boxes_data_t(class_boxes_data.data(),
+                                                      boxesShape);
   int candidate_idx = 0;
   float iou;
   while (selected.size() < size_per_class &&
diff --git a/tensorflow/core/kernels/image/non_max_suppression_op_test.cc b/tensorflow/core/kernels/image/non_max_suppression_op_test.cc
index b4a92a4..84754b3 100644
--- a/tensorflow/core/kernels/image/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/image/non_max_suppression_op_test.cc
@@ -694,7 +694,7 @@
   AddInputFromArray<int>(TensorShape({}), {6});
   AddInputFromArray<float>(TensorShape({}), {0.5f});
   AddInputFromArray<float>(TensorShape({}), {0.0f});
-  AddInputFromArray<float>(TensorShape({}), {1.0f});
+  AddInputFromArray<float>(TensorShape({}), {0.5f});
   TF_ASSERT_OK(RunOpKernel());
 
   Tensor expected(allocator(), DT_INT32, TensorShape({6}));
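
The kernel and test changes above are consistent: assuming the suppress_weight lambda below the changed line computes exp(scale * iou^2), the new scale of -0.5 / soft_nms_sigma gives the Gaussian penalty exp(-iou^2 / (2 * sigma)), and the test halves soft_nms_sigma (1.0 -> 0.5) so the effective scale stays -1.0 (-0.5 / 0.5 == -1.0 / 1.0) and the expected selections remain valid. A minimal sketch of the weight after the change (the helper name is illustrative):

    #include <cmath>

    // Gaussian Soft-NMS weight: exp(-iou^2 / (2 * sigma)).
    float SoftNmsWeight(float iou, float soft_nms_sigma) {
      const float scale = -0.5f / soft_nms_sigma;
      return std::exp(scale * iou * iou);
    }
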
diff --git a/tensorflow/core/kernels/image/resize_bilinear_op_test.cc b/tensorflow/core/kernels/image/resize_bilinear_op_test.cc
index df00ca2..fe0d4d1 100644
--- a/tensorflow/core/kernels/image/resize_bilinear_op_test.cc
+++ b/tensorflow/core/kernels/image/resize_bilinear_op_test.cc
@@ -533,7 +533,7 @@
 INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpAlignCornersTestCpu,
                          ResizeBilinearOpAlignCornersTest,
                          ::testing::Values(TestDevice::CPU));
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Instantiate tests for GPU.
 INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpTestGpu, ResizeBilinearOpTest,
                          ::testing::Values(TestDevice::GPU));
@@ -543,7 +543,7 @@
 INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpAlignCornersTestGpu,
                          ResizeBilinearOpAlignCornersTest,
                          ::testing::Values(TestDevice::GPU));
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 class ResizeBM : public ResizeBilinearOpTest {
  public:
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 22e640f..b60d545 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -19,7 +19,9 @@
 
 #include "tensorflow/core/kernels/maxpooling_op.h"
 
+#include <type_traits>
 #include <vector>
+
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/bounds_check.h"
@@ -56,7 +58,7 @@
 
 const int kInvalidMaxPoolingIndex = -1;
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Targmax>
 static void SpatialMaxPoolWithArgMaxHelper(
     OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
     Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
@@ -67,13 +69,17 @@
         errors::Internal(
             "SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index "
             "to be True when input_backprop != nullptr"));
+    OP_REQUIRES(
+        context, (std::is_same<Targmax, int64>::value),
+        errors::Internal("SpatialMaxPoolWithArgMaxHelper requires Targmax "
+                         "to be int64 when input_backprop != nullptr"));
   }
 
   typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
       ConstEigenMatrixMap;
   typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
       EigenMatrixMap;
-  typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
+  typedef Eigen::Map<Eigen::Matrix<Targmax, Eigen::Dynamic, Eigen::Dynamic>>
       EigenIndexMatrixMap;
 
   ConstEigenMatrixMap in_mat(
@@ -83,7 +89,7 @@
       output->flat<T>().data(), params.depth,
       params.out_width * params.out_height * params.tensor_in_batch);
   EigenIndexMatrixMap out_arg_max_mat(
-      output_arg_max->flat<int64>().data(), params.depth,
+      output_arg_max->flat<Targmax>().data(), params.depth,
       params.out_width * params.out_height * params.tensor_in_batch);
 
   const DeviceBase::CpuWorkerThreads& worker_threads =
@@ -150,7 +156,8 @@
               for (int d = 0; d < depth; ++d) {
                 const T& input_ref = in_mat.coeffRef(d, in_index);
                 T& output_ref = out_mat.coeffRef(d, out_index);
-                int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index);
+                Targmax& out_arg_max_ref =
+                    out_arg_max_mat.coeffRef(d, out_index);
                 if (output_ref < input_ref ||
                     out_arg_max_ref == kInvalidMaxPoolingIndex) {
                   output_ref = input_ref;
@@ -319,7 +326,7 @@
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, output_shape, &output));
 
-    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
+    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, int64>(
         context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
         out_backprop, params, true);
   }
@@ -900,22 +907,22 @@
   TensorFormat data_format_;
 };
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Targmax>
 struct LaunchMaxPoolingWithArgmax;
 
-template <typename T>
-struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
+template <typename T, typename Targmax>
+struct LaunchMaxPoolingWithArgmax<CPUDevice, T, Targmax> {
   static void launch(OpKernelContext* context, const PoolParameters& params,
                      const Tensor& input, Tensor* output, Tensor* argmax,
                      bool propagate_nans, bool include_batch_in_index) {
     Tensor unused;
-    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(context, output, argmax,
-                                                 nullptr, input, unused, params,
-                                                 include_batch_in_index);
+    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, Targmax>(
+        context, output, argmax, /*input_backprop=*/nullptr, input, unused,
+        params, include_batch_in_index);
   }
 };
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Targmax>
 class MaxPoolingWithArgmaxOp : public OpKernel {
  public:
   explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
@@ -959,7 +966,7 @@
     Tensor* argmax = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
 
-    LaunchMaxPoolingWithArgmax<Device, T>::launch(
+    LaunchMaxPoolingWithArgmax<Device, T, Targmax>::launch(
         context, params, tensor_in, output, argmax, propagate_nans_,
         include_batch_in_index_);
   }
@@ -1027,6 +1034,7 @@
   }
 };
 
+// TODO(b/175733711): Support int32 argmax type in MaxPoolGradWithArgmax op.
 template <typename Device, typename T>
 class MaxPoolingGradWithArgmaxOp : public OpKernel {
  public:
@@ -1363,7 +1371,7 @@
 };
 
 template <typename T>
-struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
+struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T, int64> {
   static void launch(OpKernelContext* context, const PoolParameters& params,
                      const Tensor& input, Tensor* output, Tensor* argmax,
                      bool propagate_nans, bool include_batch_in_index) {
@@ -1456,7 +1464,7 @@
                               .Device(DEVICE_##D)                        \
                               .TypeConstraint<int64>("Targmax")          \
                               .TypeConstraint<T>("T"),                   \
-                          MaxPoolingWithArgmaxOp<D##Device, T>);         \
+                          MaxPoolingWithArgmaxOp<D##Device, T, int64>);  \
   REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")                  \
                               .Device(DEVICE_##D)                        \
                               .TypeConstraint<T>("T")                    \
@@ -1470,7 +1478,12 @@
       MaxPoolingOp<CPUDevice, T>);                                 \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      MaxPoolingV2Op<CPUDevice, T>);
+      MaxPoolingV2Op<CPUDevice, T>);                               \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                \
+                              .Device(DEVICE_CPU)                  \
+                              .TypeConstraint<int32>("Targmax")    \
+                              .TypeConstraint<T>("T"),             \
+                          MaxPoolingWithArgmaxOp<CPUDevice, T, int32>);
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
 #undef REGISTER_CPU_ONLY_POOL_KERNELS
 
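
For context on the new int32 "Targmax" registration above: the argmax output stores flattened indices, and per the MaxPoolWithArgmax documentation the maximum at position [b, y, x, c] is flattened as in this illustrative helper (the function name is an assumption, not part of the patch). The gradient path still requires int64, which the new OP_REQUIRES check enforces and the TODO above tracks.

    #include <cstdint>

    // Flattened argmax index for an input of shape
    // [batch, height, width, channels].
    int64_t FlattenArgmax(int64_t b, int64_t y, int64_t x, int64_t c,
                          int64_t height, int64_t width, int64_t channels,
                          bool include_batch_in_index) {
      return include_batch_in_index
                 ? ((b * height + y) * width + x) * channels + c
                 : (y * width + x) * channels + c;
    }
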
diff --git a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc
index 7bd47e9..9bb2653 100644
--- a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc
+++ b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc
@@ -955,6 +955,71 @@
 // Testing fusion of MatMul and BiasAdd
 template <typename T>
 class MklFusedMatMulOpTest : public OpsTestBase {
+ private:
+  void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
+                           const std::vector<Tensor>& args,
+                           const std::vector<string>& fused_ops,
+                           Tensor* output) {
+    DataType dtype = DataTypeToEnum<T>::v();
+    const int num_args = args.size();
+    if (!NativeFormatEnabled()) {
+      TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(num_args, dtype))
+                       .Input(FakeInput(DT_UINT8))
+                       .Input(FakeInput(DT_UINT8))
+                       .Input(FakeInput(num_args, DT_UINT8))
+                       .Attr("T", dtype)
+                       .Attr("transpose_a", false)
+                       .Attr("transpose_b", false)
+                       .Attr("num_args", num_args)
+                       .Attr("fused_ops", fused_ops)
+                       .Attr("epsilon", 0.0001)
+                       .Attr("_kernel", "MklLayoutDependentOp")
+                       .Finalize(node_def()));
+    } else {
+      TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(num_args, dtype))
+                       .Attr("T", dtype)
+                       .Attr("transpose_a", false)
+                       .Attr("transpose_b", false)
+                       .Attr("num_args", num_args)
+                       .Attr("fused_ops", fused_ops)
+                       .Attr("epsilon", 0.0001)
+                       .Attr("_kernel", "MklNameChangeOp")
+                       .Finalize(node_def()));
+    }
+
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<T>(input.shape(), input.flat<T>());
+    AddInputFromArray<T>(weight.shape(), weight.flat<T>());
+    for (const Tensor& arg : args)
+      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
+    if (!NativeFormatEnabled()) {
+      // Add MKL meta inputs for input, weight and each extra arg.
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+      for (const Tensor& arg : args)
+        AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    }
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& output_tensor = *GetOutput(0);
+    if (!NativeFormatEnabled()) {
+      const Tensor& output_meta_tensor = *GetOutput(1);
+      CommonTestUtilities<T> test_util;
+      test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
+                                  output);
+    } else {
+      *output = output_tensor;
+    }
+  }
+
  protected:
   void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
                          const int kOutputChannel,
@@ -1002,70 +1067,24 @@
             next_op = ops::Tanh(root.WithOpName(last_op), next_op);
           }
 
+          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
+              fused_ops.end()) {
+            last_op = "with_add";
+            next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
+          }
+
           CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
         };
 
     const FusedGraphRunner run_fused =
         [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
                const std::vector<string>& fused_ops, Tensor* output) {
-          DataType dtype = DataTypeToEnum<T>::v();
-          const int num_args = 1;
-
-          if (!NativeFormatEnabled()) {
-            TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
-                             .Input(FakeInput(dtype))
-                             .Input(FakeInput(dtype))
-                             .Input(FakeInput(num_args, dtype))
-                             .Input(FakeInput(DT_UINT8))
-                             .Input(FakeInput(DT_UINT8))
-                             .Input(FakeInput(num_args, DT_UINT8))
-                             .Attr("T", dtype)
-                             .Attr("transpose_a", false)
-                             .Attr("transpose_b", false)
-                             .Attr("num_args", num_args)
-                             .Attr("fused_ops", fused_ops)
-                             .Attr("epsilon", 0.0001)
-                             .Attr("_kernel", "MklLayoutDependentOp")
-                             .Finalize(node_def()));
-          } else {
-            TF_EXPECT_OK(
-                NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
-                    .Input(FakeInput(dtype))
-                    .Input(FakeInput(dtype))
-                    .Input(FakeInput(num_args, dtype))
-                    .Attr("T", dtype)
-                    .Attr("transpose_a", false)
-                    .Attr("transpose_b", false)
-                    .Attr("num_args", num_args)
-                    .Attr("fused_ops", fused_ops)
-                    .Attr("epsilon", 0.0001)
-                    .Attr("_kernel", "MklNameChangeOp")
-                    .Finalize(node_def()));
+          std::vector<Tensor> fused_input = {bias};
+          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
+              fused_ops.end()) {
+            fused_input.push_back(input);
           }
-
-          TF_EXPECT_OK(InitOp());
-
-          AddInputFromArray<T>(input.shape(), input.flat<T>());
-          AddInputFromArray<T>(weight.shape(), weight.flat<T>());
-          AddInputFromArray<T>(bias.shape(), bias.flat<T>());
-          if (!NativeFormatEnabled()) {
-            // Add MKL meta input for input, filter and bias.
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-          }
-
-          TF_ASSERT_OK(RunOpKernel());
-
-          const Tensor& output_tensor = *GetOutput(0);
-          if (!NativeFormatEnabled()) {
-            const Tensor& output_meta_tensor = *GetOutput(1);
-            CommonTestUtilities<T> test_util;
-            test_util.PerformConversion(dtype, output_tensor,
-                                        output_meta_tensor, output);
-          } else {
-            *output = output_tensor;
-          }
+          RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
         };
 
     CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
@@ -1120,12 +1139,22 @@
                           {"BiasAdd", "Tanh"});
 }
 
+TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
+  const int batch = 3;
+  const int input_channel = 4;
+  const int output_channel = 4;
+
+  this->VerifyFusedMatMul(batch, input_channel, output_channel,
+                          {"BiasAdd", "Add"});
+}
+
 REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest,  //
                             WithBias,              //
                             WithBiasAndRelu,       //
                             WithBiasAndRelu6,      //
                             WithBiasAndElu,        //
-                            WithBiasAndTanh);
+                            WithBiasAndTanh,       //
+                            WithBiasAndAdd);
 
 using MklFusedMatMulDataTypes = ::testing::Types<float>;
 INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
index 905abbf..246efac 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
@@ -45,6 +45,7 @@
         ctx, fused_ops_[0] == "BiasAdd",
         errors::InvalidArgument(
             "The 1st post-argument of MklFusedMatMul must be BiasAdd."));
+    if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
     OP_REQUIRES(
         ctx, transpose_a_ == false,
         errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
@@ -114,7 +115,8 @@
     //   2. var, keep the original format to avoid reordering.
     MklDnnMatMulFwdParams matmul_params(
         src_dims, weight_dims, bias_dims, dst_dims, src_format,
-        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
+        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
+        MEMORY_FORMAT::nc);
 
     // Extend the basic parameters for data types and fusions.
     ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
@@ -126,15 +128,70 @@
     std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
         matmul_prim->GetPrimitiveDesc();
 
-    if (src_mkl_shape.IsMklTensor()) {
-      this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims,
-                                 MKL_TENSOR_FORMAT_NC, &dst_tensor);
+    // The output shape of MatMul is the same for both the MKL and TF
+    // versions: always NC format, regardless of the input's format. The
+    // shape of the Add operand matches the output shape as well.
+    auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
+
+    MklDnnShape output_mkl_shape;
+    output_mkl_shape.SetMklTensor(false);
+
+    TensorShape output_tf_shape({batch, channel});
+
+    if (fuse_add_) {
+      const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
+      MklDnnShape add_mkl_shape;
+      GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
+
+      // For native format, we do not need to set metadata.
+      if (native_format && ctx->forward_input_to_output_with_shape(
+                               kInputIndex_Add, kOutputIndex_Dst,
+                               output_tf_shape, &dst_tensor)) {
+        ;  // Forwarding succeeded; nothing more to do for native format
+      } else if (!native_format && ForwardMklTensorInToOutWithMklShape(
+                                       ctx, kInputIndex_Add, kOutputIndex_Dst,
+                                       &dst_tensor, output_mkl_shape, false)) {
+        ;  // Non-native format: the input was forwarded and its meta set
+      } else {
+        // If forwarding fails, use a reorder to copy the add tensor into
+        // the dst tensor.
+        AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
+                                  output_tf_shape, output_mkl_shape,
+                                  native_format);
+        auto output_format_tag =
+            MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
+        auto add_md =
+            add_mkl_shape.IsMklTensor()
+                ? add_mkl_shape.GetMklLayout()
+                : memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+        auto dst_md =
+            memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+
+        void* add_buf =
+            static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
+        void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
+
+        if (native_format) {
+          // We are simply deep copying add_tensor to dst_tensor without
+          // changing the memory layout, hence the same memory descriptor.
+          add_md = dst_md =
+              memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
+                           mkldnn::memory::format_tag::x);
+        }
+
+        auto fuse_add_src_ =
+            MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
+        auto fuse_add_dst_ =
+            MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+        auto reorder_desc =
+            REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+
+        CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
+                                this->cpu_engine_, ctx);
+      }
     } else {
-      TensorShape dst_tensor_shape({batch, channel});
-      MklDnnShape dst_mkl_shape;
-      dst_mkl_shape.SetMklTensor(false);
-      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
-                                dst_mkl_shape, native_format);
+      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
+                                output_mkl_shape, native_format);
     }
 
     // if there's nothing to compute, just return.
@@ -228,6 +285,8 @@
         params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
       } else if (post_op == "Tanh") {
         params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
+      } else if (post_op == "Add") {
+        params.post_op_params.push_back({"sum", {1.0}});
       } else {
         OP_REQUIRES_OK(
             ctx, errors::InvalidArgument(
@@ -237,10 +296,13 @@
   }
 
  private:
+  bool fuse_add_ = false;
   bool transpose_a_;
   bool transpose_b_;
   std::vector<string> fused_ops_;
-};
+  const int kInputIndex_Add = 3;
+  const int kOutputIndex_Dst = 0;
+};
 
 // Register mkl kernels for supported operations and types.
 #define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type)                \
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
index d1e82bf..375047d 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -48,8 +48,9 @@
   memory::dims weight_dims;
   memory::dims bias_dims;
   memory::dims dst_dims;
-  memory::format_tag src_format;
-  memory::format_tag weight_format;
+  MEMORY_FORMAT src_format;
+  MEMORY_FORMAT weight_format;
+  MEMORY_FORMAT dst_format;
   string dtypes = string("");
   struct PostOpParam {
     string name;
@@ -57,17 +58,18 @@
   };
   std::vector<PostOpParam> post_op_params;
 
-  MklDnnMatMulFwdParams(
-      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
-      memory::dims dst_dims,
-      memory::format_tag src_format = memory::format_tag::any,
-      memory::format_tag weight_format = memory::format_tag::any)
+  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
+                        memory::dims bias_dims, memory::dims dst_dims,
+                        MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
+                        MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
+                        MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
       : src_dims(src_dims),
         weight_dims(weight_dims),
         bias_dims(bias_dims),
         dst_dims(dst_dims),
         src_format(src_format),
-        weight_format(weight_format) {}
+        weight_format(weight_format),
+        dst_format(dst_format) {}
 };
 
 // With quantization, input, weight, bias, and output can have different types.
@@ -184,7 +186,7 @@
 
     context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                            MklDnnType<Toutput>(),
-                                           memory::format_tag::any));
+                                           matmul_fwd_params.dst_format));
 
     context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                             MklDnnType<Tbias>(),
@@ -236,11 +238,17 @@
           std::vector<float> scales;
           scales.push_back(post_op_param.param[0]);
           post_ops_attr.set_output_scales(0, scales);
+        } else if (post_op_param.name == "sum") {
+          DCHECK_EQ(post_op_param.param.size(), 1);
+          float op_scale = post_op_param.param[0];
+          post_ops.append_sum(op_scale);
+
         } else {
           DCHECK((post_op_param.name == "relu") ||
                  (post_op_param.name == "relu6") ||
                  (post_op_param.name == "elu") ||
                  (post_op_param.name == "tanh") ||
+                 (post_op_param.name == "sum") ||
                  (post_op_param.name == "output_scale"));
         }
       }
@@ -340,6 +348,10 @@
         key_creator.AddAsKey(post_op_param.param[0]);
         key_creator.AddAsKey(post_op_param.param[1]);
         key_creator.AddAsKey(post_op_param.param[2]);
+      } else if (post_op_param.name == "sum") {
+        DCHECK_EQ(post_op_param.param.size(), 1);
+        key_creator.AddAsKey(post_op_param.name);
+        key_creator.AddAsKey(post_op_param.param[0]);
       } else if (post_op_param.name == "output_scale") {
         DCHECK_EQ(post_op_param.param.size(), 1);
         key_creator.AddAsKey(post_op_param.name);
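
A plain C++ sketch (illustrative, not oneDNN code) of what the new "sum" post-op amounts to when fused_ops is {"BiasAdd", "Add"}: the Add operand is first forwarded or reordered into the destination buffer, and the primitive then accumulates the matmul-plus-bias result into it. The shapes and naive loops below are assumptions for illustration only.

    #include <vector>

    void FusedMatMulBiasAddAdd(const std::vector<float>& src,     // [m, k]
                               const std::vector<float>& weight,  // [k, n]
                               const std::vector<float>& bias,    // [n]
                               const std::vector<float>& add,     // [m, n]
                               std::vector<float>& dst,           // [m, n]
                               int m, int k, int n) {
      dst = add;  // stand-in for forwarding/reordering the Add input into dst
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
          float acc = bias[j];
          for (int p = 0; p < k; ++p) {
            acc += src[i * k + p] * weight[p * n + j];
          }
          dst[i * n + j] += acc;  // the "sum" post-op with scale 1.0
        }
      }
    }
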
diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD
index 047782f..fb32462 100644
--- a/tensorflow/core/kernels/mlir_generated/BUILD
+++ b/tensorflow/core/kernels/mlir_generated/BUILD
@@ -3,8 +3,8 @@
 load(
     "//tensorflow/core/kernels/mlir_generated:build_defs.bzl",
     "gen_kernel_library",
+    "if_mlir_experimental_kernels_enabled",
     "if_mlir_generated_gpu_kernels_enabled",
-    "if_mlir_unranked_kernels_enabled",
 )
 load(
     "//tensorflow:tensorflow.bzl",
@@ -27,60 +27,59 @@
 
 config_setting(
     name = "mlir_generated_gpu_kernels_disabled",
-    define_values = {
-        "tensorflow_enable_mlir_generated_gpu_kernels": "0",
-    },
+    define_values = {"tensorflow_enable_mlir_generated_gpu_kernels": "0"},
 )
 
 config_setting(
-    name = "mlir_use_unranked_kernels",
+    name = "mlir_experimental_kernels_enabled",
     define_values = {"enable_unranked_kernels": "1"},
 )
 
 filegroup(
-    name = "enabled_unary_unranked_kernel_srcs",
+    name = "enabled_unary_kernel_srcs",
     srcs = [
-        "unranked_op_gpu_abs.cc",
-        "unranked_op_gpu_tanh.cc",
+        "gpu_op_abs.cc",
+        "gpu_op_tanh.cc",
     ],
     compatible_with = get_compatible_with_cloud(),
 )
 
 filegroup(
-    name = "experimental_unary_unranked_kernel_srcs",
+    name = "experimental_unary_kernel_srcs",
     srcs = [
-        "unranked_op_gpu_ceil.cc",
-        "unranked_op_gpu_conj.cc",
-        "unranked_op_gpu_cos.cc",
-        "unranked_op_gpu_exp.cc",
-        "unranked_op_gpu_floor.cc",
-        "unranked_op_gpu_imag.cc",
-        "unranked_op_gpu_is_inf.cc",
-        "unranked_op_gpu_log.cc",
-        "unranked_op_gpu_logical_not.cc",
-        "unranked_op_gpu_real.cc",
-        "unranked_op_gpu_rsqrt.cc",
-        "unranked_op_gpu_sign.cc",
-        "unranked_op_gpu_sin.cc",
-        "unranked_op_gpu_sqrt.cc",
+        "gpu_op_ceil.cc",
+        "gpu_op_conj.cc",
+        "gpu_op_cos.cc",
+        "gpu_op_exp.cc",
+        "gpu_op_floor.cc",
+        "gpu_op_imag.cc",
+        "gpu_op_is_inf.cc",
+        "gpu_op_log.cc",
+        "gpu_op_logical_not.cc",
+        "gpu_op_neg.cc",
+        "gpu_op_real.cc",
+        "gpu_op_rsqrt.cc",
+        "gpu_op_sign.cc",
+        "gpu_op_sin.cc",
+        "gpu_op_sqrt.cc",
     ],
     compatible_with = get_compatible_with_cloud(),
 )
 
 filegroup(
-    name = "unary_unranked_kernel_srcs",
+    name = "unary_kernel_srcs",
     srcs = [
-        ":enabled_unary_unranked_kernel_srcs",
-    ] + if_mlir_unranked_kernels_enabled(
-        if_true = [":experimental_unary_unranked_kernel_srcs"],
+        ":enabled_unary_kernel_srcs",
+    ] + if_mlir_experimental_kernels_enabled(
+        if_true = [":experimental_unary_kernel_srcs"],
     ),
     compatible_with = get_compatible_with_cloud(),
 )
 
 cc_library(
-    name = "unranked_op_gpu_base",
-    srcs = ["unranked_op_gpu_base.cc"],
-    hdrs = ["unranked_op_gpu_base.h"],
+    name = "gpu_ops_base",
+    srcs = ["gpu_ops_base.cc"],
+    hdrs = ["gpu_ops_base.h"],
     compatible_with = get_compatible_with_cloud(),
     deps = [
         "//tensorflow/compiler/mlir/tools/kernel_gen:tf_framework_c_interface",
@@ -96,7 +95,7 @@
 
 tf_kernel_library(
     name = "cwise_unary_op",
-    srcs = [":unary_unranked_kernel_srcs"],
+    srcs = [":unary_kernel_srcs"],
     tags = [
         "manual",
     ],
@@ -106,36 +105,71 @@
         # make our BUILD target structure uglier. We already need to make
         # sure that those targets can be built, so it should not hurt to
         # link them in even if they are currently not needed yet.
-        ":abs_unranked_kernels",
-        ":ceil_unranked_kernels",
-        ":conj_unranked_kernels",
-        ":cos_unranked_kernels",
-        ":exp_unranked_kernels",
-        ":floor_unranked_kernels",
-        ":imag_unranked_kernels",
-        ":is_inf_unranked_kernels",
-        ":log_unranked_kernels",
-        ":logical_not_unranked_kernels",
-        ":real_unranked_kernels",
-        ":rsqrt_unranked_kernels",
-        ":sign_unranked_kernels",
-        ":sin_unranked_kernels",
-        ":sqrt_unranked_kernels",
-        ":tanh_unranked_kernels",
-        ":unranked_op_gpu_base",
+        ":abs_kernels",
+        ":ceil_kernels",
+        ":conj_kernels",
+        ":cos_kernels",
+        ":exp_kernels",
+        ":floor_kernels",
+        ":imag_kernels",
+        ":is_inf_kernels",
+        ":log_kernels",
+        ":logical_not_kernels",
+        ":neg_kernels",
+        ":real_kernels",
+        ":rsqrt_kernels",
+        ":sign_kernels",
+        ":sin_kernels",
+        ":sqrt_kernels",
+        ":tanh_kernels",
+        ":gpu_ops_base",
         "//third_party/eigen3",
     ],
 )
 
 tf_kernel_library(
     name = "cwise_binary_op",
-    srcs = ["unranked_gpu_add.cc"],
+    srcs = [
+        "gpu_op_add.cc",
+        "gpu_op_bitwise_and.cc",
+        "gpu_op_bitwise_or.cc",
+        "gpu_op_bitwise_xor.cc",
+        "gpu_op_equal.cc",
+        "gpu_op_floor_div.cc",
+        "gpu_op_greater.cc",
+        "gpu_op_greater_equal.cc",
+        "gpu_op_left_shift.cc",
+        "gpu_op_less.cc",
+        "gpu_op_less_equal.cc",
+        "gpu_op_logical_and.cc",
+        "gpu_op_logical_or.cc",
+        "gpu_op_mul.cc",
+        "gpu_op_not_equal.cc",
+        "gpu_op_right_shift.cc",
+    ],
     tags = [
         "manual",
     ],
     deps = [
-        ":add_v2_unranked_kernels",
-        ":unranked_op_gpu_base",
+        ":add_v2_kernels",
+        ":bitwise_and_kernels",
+        ":bitwise_or_kernels",
+        ":bitwise_xor_kernels",
+        ":equal_kernels",
+        ":floor_div_kernels",
+        ":gpu_ops_base",
+        ":greater_equal_kernels",
+        ":greater_kernels",
+        ":left_shift_kernels",
+        ":less_equal_kernels",
+        ":less_kernels",
+        ":logical_and_kernels",
+        ":logical_or_kernels",
+        ":maximum_kernels",
+        ":minimum_kernels",
+        ":mul_kernels",
+        ":not_equal_kernels",
+        ":right_shift_kernels",
         "//third_party/eigen3",
     ],
 )
@@ -150,7 +184,7 @@
     # but we want to avoid building them if they are not needed.
     deps = if_cuda_or_rocm([
         ":cwise_unary_op",
-    ]) + if_mlir_unranked_kernels_enabled(
+    ]) + if_mlir_experimental_kernels_enabled(
         [
             ":cwise_binary_op",
         ],
@@ -165,6 +199,7 @@
         "no_cuda_asan",  # TODO(b/171341759): re-enable.
     ],
     deps = [
+        ":gpu_ops_test_util",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:tensorflow",
@@ -180,13 +215,14 @@
 )
 
 tf_cuda_cc_test(
-    name = "gpu_add_test",
+    name = "gpu_binary_ops_test",
     size = "small",
-    srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_add_test.cc"]),
+    srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_binary_ops_test.cc"]),
     tags = tf_cuda_tests_tags() + [
         "no_cuda_asan",  # b/173033461
     ],
     deps = [
+        ":gpu_ops_test_util",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:tensorflow",
@@ -195,8 +231,32 @@
         "//tensorflow/core:testlib",
         "//tensorflow/core/common_runtime:device",
         "//tensorflow/core/common_runtime:device_factory",
+        "//tensorflow/core/framework:types_proto_cc",
         "//tensorflow/core/kernels:cwise_op",
         "//tensorflow/core/kernels:ops_testutil",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "gpu_ops_test_util",
+    testonly = 1,
+    srcs = [
+        "gpu_ops_test_util.cc",
+        "gpu_ops_test_util.h",
+    ],
+    hdrs = [
+        "gpu_ops_test_util.h",
+    ],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:tensorflow",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
     ],
 )
 
@@ -224,7 +284,6 @@
 
 gen_kernel_library(
     name = "abs",
-    generate_unranked = True,
     tile_size = "256",
     types = [
         "f16",
@@ -238,29 +297,25 @@
 
 gen_kernel_library(
     name = "conj",
-    generate_unranked = True,
     tile_size = "256",
     types = [
         "f32",
         "f64",
     ],
-    unroll_factors = "4",
+    unroll_factors = "2",
 )
 
 gen_kernel_library(
     name = "imag",
-    generate_unranked = True,
     tile_size = "256",
     types = [
         "f32",
         "f64",
     ],
-    unroll_factors = "1",
 )
 
 gen_kernel_library(
     name = "invert",
-    generate_unranked = True,
     tile_size = "256",
     types = [
         "i8",
@@ -273,8 +328,6 @@
 
 gen_kernel_library(
     name = "is_inf",
-    generate_ranked = False,
-    generate_unranked = True,
     tile_size = "256",
     types = [
         "f16",
@@ -286,26 +339,21 @@
 
 gen_kernel_library(
     name = "logical_not",
-    generate_unranked = True,
     tile_size = "256",
     types = ["i1"],
-    unroll_factors = "4",
 )
 
 gen_kernel_library(
     name = "real",
-    generate_unranked = True,
     tile_size = "256",
     types = [
         "f32",
         "f64",
     ],
-    unroll_factors = "1",
 )
 
 gen_kernel_library(
     name = "sign",
-    generate_unranked = True,
     tile_size = "256",
     types = [
         # TODO(b/162577610): Add bf16, c64 and c128.
@@ -320,8 +368,6 @@
 
 gen_kernel_library(
     name = "add_v2",
-    generate_ranked = False,
-    generate_unranked = True,
     tile_size = "256,1,1",
     types = [
         "f16",
@@ -329,25 +375,180 @@
         "f64",
         "i64",
     ],
-    # TODO(b/174543802): Enable once fusion heursitics is better.
+    # TODO(b/174543802): Enable once fusion heuristics is better.
     # unroll_factors = "4",
 )
 
 gen_kernel_library(
-    name = "equal",
-    generate_ranked = False,
-    generate_unranked = True,
+    name = "complex",
+    tile_size = "256,1,1",
+    types = [
+        "f32",
+        "f64",
+    ],
+    # TODO(b/174543802): Enable once fusion heuristics is better.
+    # unroll_factors = "2",
+)
+
+gen_kernel_library(
+    name = "mul",
     tile_size = "256,1,1",
     types = [
         "f16",
         "f32",
         "f64",
-        "i1",
         "i8",
         "i16",
-        "i32",
         "i64",
     ],
+    # TODO(b/174543802): Enable once fusion heuristics is better.
+    # unroll_factors = "4",
+)
+
+# Bitwise operations.
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "i8",
+            "i16",
+            "i32",
+            "i64",
+            # TODO(b/172804967): Enable once fixed.
+            # "ui8",
+            # "ui16",
+            # "ui32",
+            # "ui64",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "bitwise_and",
+        "bitwise_or",
+        "bitwise_xor",
+        "left_shift",
+        "right_shift",
+    ]
+]
+
+# Logical operations.
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "i1",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "logical_and",
+        "logical_or",
+    ]
+]
+
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "f16",
+            "f32",
+            "f64",
+            "i1",
+            "i8",
+            "i16",
+            "i32",
+            "i64",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "equal",
+        "not_equal",
+    ]
+]
+
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "f16",
+            "f32",
+            "f64",
+            "i8",
+            "i16",
+            "i32",
+            "i64",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "less",
+        "less_equal",
+        "greater",
+        "greater_equal",
+    ]
+]
+
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "f16",
+            "f32",
+            "f64",
+            "i16",
+            "i32",
+            "i64",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "maximum",
+        "minimum",
+    ]
+]
+
+# Kernels that support all floating-point and signed int types.
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256",
+        types = [
+            "f16",
+            "f32",
+            "f64",
+            "i8",
+            "i16",
+            "i32",
+            "i64",
+        ],
+        unroll_factors = "4",
+    )
+    for name in [
+        "neg",
+    ]
+]
+
+gen_kernel_library(
+    name = "floor_div",
+    tile_size = "256",
+    # TODO(b/172804967): Enable for integer types once unsigned integers
+    # are supported.
+    types = [
+        "f16",
+        "f32",
+        "f64",
+    ],
     # TODO(b/174543802): Enable once fusion heuristics is better.
     # unroll_factors = "4",
 )
@@ -356,7 +557,6 @@
 [
     gen_kernel_library(
         name = name,
-        generate_unranked = True,
         tile_size = "256",
         types = [
             "f16",
@@ -371,7 +571,6 @@
         "floor",
         "is_finite",
         "log",
-        "neg",
         "rsqrt",
         "sqrt",
         "tanh",
@@ -382,14 +581,12 @@
 [
     gen_kernel_library(
         name = name,
-        generate_unranked = True,
         tile_size = "256",
         types = [
             "f16",
             "f32",
             "f64",
         ],
-        unroll_factors = "1",
     )
     for name in [
         "cos",
diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
index e20257d..7808d1a 100644
--- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl
+++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
@@ -11,6 +11,7 @@
     "//tensorflow/stream_executor:build_defs.bzl",
     "if_gpu_is_configured",
 )
+load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
 
 def if_mlir_generated_gpu_kernels_enabled(if_true, if_false = []):
     return select({
@@ -30,168 +31,6 @@
     fields = ["gpu_bins"],
 )
 
-def _gen_kernel_gpu_bin_impl(ctx):
-    name = ctx.attr.name
-    tile_sizes = ctx.attr.tile_size.replace("x", ",")
-    cmd_args = []
-    if ctx.attr.unroll_factors:
-        cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
-
-    if ctx.attr.extra_args:
-        cmd_args.extend(ctx.attr.extra_args)
-
-    gpu_bins = []
-    for arch in ctx.attr.gpu_archs:
-        # TODO(b/170283783): 'compute_' should generate both SASS and PTX.
-        arch = arch.replace("compute_", "sm_")
-        filename = "%s.%s.bin" % (name, arch)
-        gpu_bin = ctx.actions.declare_file(filename)
-        ctx.actions.run(
-            inputs = [ctx.file.mlir_op, ctx.file._tfso],
-            outputs = [gpu_bin],
-            executable = ctx.executable._tool,
-            arguments = cmd_args + [
-                "--tile_sizes=%s" % tile_sizes,
-                "--arch=%s" % arch,
-                "--input=%s" % ctx.file.mlir_op.path,
-                "--output=%s" % gpu_bin.path,
-            ],
-            mnemonic = "compile",
-        )
-        gpu_bins.append(gpu_bin)
-    return [GpuBinaryInfo(gpu_bins = gpu_bins)]
-
-_gen_kernel_gpu_bin_rule = rule(
-    attrs = {
-        "mlir_op": attr.label(mandatory = True, allow_single_file = True),
-        "tile_size": attr.string(mandatory = True),
-        "unroll_factors": attr.string(),
-        "gpu_archs": attr.string_list(mandatory = True),
-        "extra_args": attr.string_list(),
-        "_tfso": attr.label(
-            default = Label("//tensorflow:libtensorflow_framework.so.2"),
-            cfg = "host",
-            allow_single_file = True,
-        ),
-        "_tool": attr.label(
-            executable = True,
-            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary"),
-            cfg = "host",
-        ),
-    },
-    output_to_genfiles = True,
-    implementation = _gen_kernel_gpu_bin_impl,
-)
-
-def _gen_kernel_image_hdr_impl_cuda(ctx):
-    images = []
-    for cubin in ctx.attr.input[GpuBinaryInfo].gpu_bins:
-        arch = cubin.path.split(".")[-2]
-        images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
-
-    # Generate fatbin file from all cubins.
-    fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
-    ctx.actions.run(
-        outputs = [fatbin],
-        inputs = ctx.attr.input[GpuBinaryInfo].gpu_bins,
-        executable = _lookup_file(ctx.attr._gpu_root, "bin/fatbinary"),
-        arguments = [
-            "--64",
-            "--cmdline=--compile-only",
-            "--link",
-            "--compress-all",
-            "--create=%s" % fatbin.path,
-        ] + images,
-        mnemonic = "fatbinary",
-    )
-
-    bin2c = _lookup_file(ctx.attr._gpu_root, "bin/bin2c")
-    ctx.actions.run_shell(
-        outputs = [ctx.outputs.out],
-        inputs = [fatbin],
-        tools = [bin2c],
-        command = "%s --static --const --type=char --name=%s %s 1> %s" %
-                  (bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
-        mnemonic = "bin2c",
-    )
-
-def _gen_kernel_image_hdr_impl_rocm(ctx):
-    hsaco_files = []
-    hsaco_targets = []
-
-    # Add a dummy host target triple...clang-offload-bundler requires 1 and only 1 host target triple
-    hsaco_files.append("/dev/null")
-    hsaco_targets.append("host-x86_64-unknown-linux")
-
-    hsacos = ctx.attr.input[GpuBinaryInfo].gpu_bins
-    for hsaco in hsacos:
-        gfx_arch = hsaco.path.split(".")[-2]
-        hsaco_files.append(hsaco.path)
-        hsaco_targets.append("hip-amdgcn-amd-amdhsa-%s" % gfx_arch)
-
-    # Generate fatbin file from all hsacos.
-    fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
-    ctx.actions.run(
-        outputs = [fatbin],
-        inputs = hsacos,
-        executable = _lookup_file(ctx.attr._gpu_root, "bin/clang-offload-bundler"),
-        arguments = [
-            "--inputs=%s" % ",".join(hsaco_files),
-            "--targets=%s" % ",".join(hsaco_targets),
-            "--type=o",
-            "--outputs=%s" % fatbin.path,
-        ],
-        mnemonic = "fatbinary",
-    )
-
-    ctx.actions.run_shell(
-        outputs = [ctx.outputs.out],
-        inputs = [fatbin],
-        command = (
-            ("hex=`hexdump -v -e \'/1 \"0x%%02x, \"\' %s` && " +
-             "len=`echo $hex | wc -c` && " +
-             "echo 'static const unsigned char %s['$len' + 1] = {' > %s && " +
-             "echo $hex | cat >> %s && " +
-             "echo '};' >> %s") % (
-                fatbin.path,
-                ctx.attr.symbol,
-                ctx.outputs.out.path,
-                ctx.outputs.out.path,
-                ctx.outputs.out.path,
-            )
-        ),
-    )
-
-_gen_kernel_image_hdr_rule = rule(
-    implementation = _gen_kernel_image_hdr_impl_rocm if rocm_is_configured() else _gen_kernel_image_hdr_impl_cuda,
-    output_to_genfiles = True,
-    attrs = {
-        "input": attr.label(mandatory = True, providers = [GpuBinaryInfo]),
-        "out": attr.output(mandatory = True),
-        "symbol": attr.string(mandatory = True),
-        "_gpu_root": attr.label(
-            default = Label("@local_config_rocm//rocm:rocm_root") if rocm_is_configured() else Label("@local_config_cuda//cuda:cuda_root"),
-        ),
-    },
-)
-
-def _gen_kernel_image_hdr(name, mlir_op, gpu_archs, tile_size, unroll_factors = None, extra_args = []):
-    """Generates a C header with fatbin data from a Tensorflow op."""
-    _gen_kernel_gpu_bin_rule(
-        name = name + "_cubin",
-        mlir_op = mlir_op,
-        tile_size = tile_size,
-        unroll_factors = unroll_factors,
-        gpu_archs = gpu_archs,
-        extra_args = extra_args,
-    )
-    _gen_kernel_image_hdr_rule(
-        name = name,
-        input = ":" + name + "_cubin",
-        out = "%s.h" % name,
-        symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
-    )
-
 type_to_mlir = {
     "c64": "complex<f32>",
     "c128": "complex<f64>",
@@ -203,18 +42,12 @@
     if mlir_type in type_to_mlir:
         mlir_type = type_to_mlir[mlir_type]
 
-    # In order to generate a ranked kernel we change *xelem_type to ?xelem_type
-    # and remove element type from the entry function name.
-    convert_to_ranked = ""
-    if ctx.attr.unranked == False:
-        convert_to_ranked = "sed s/*x/?x/g | sed s/_elem_type//g |"
     cmd = ctx.actions.run_shell(
         inputs = [ctx.file.template],
         outputs = [ctx.outputs.out],
         command = (
-            ("cat %s | %s sed 's/_elem_type/_%s/g' | sed 's/elem_type/%s/g' > %s") % (
+            ("cat %s | sed 's/_elem_type/_%s/g' | sed 's/elem_type/%s/g' > %s") % (
                 ctx.file.template.path,
-                convert_to_ranked,
                 ctx.attr.type,
                 mlir_type,
                 ctx.outputs.out.path,
@@ -229,21 +62,100 @@
         "template": attr.label(mandatory = True, allow_single_file = True),
         "type": attr.string(mandatory = True),
         "out": attr.output(mandatory = True),
-        "unranked": attr.bool(mandatory = True),
     },
 )
 
-def _gen_mlir_op(name, type, unranked):
-    tmpl_name = name.replace("_unranked", "") if unranked else name
+def _gen_mlir_op(name, type):
     _gen_mlir_op_rule(
         name = "generate_{name}_{type}_mlir".format(name = name, type = type),
-        template = "op_definitions/{name}.mlir.tmpl".format(name = tmpl_name),
+        template = "op_definitions/{name}.mlir.tmpl".format(name = name),
         type = type,
         out = "{name}_{type}.mlir".format(name = name, type = type),
-        unranked = unranked,
     )
 
-def gen_ranked_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = []):
+################################################################################
+# Kernels build rules.
+################################################################################
+
+def if_mlir_experimental_kernels_enabled(if_true, if_false = []):
+    return select({
+        "//tensorflow/core/kernels/mlir_generated:mlir_experimental_kernels_enabled": if_true,
+        "//conditions:default": if_false,
+    })
+
+def _gen_kernel_fatbin_impl(ctx):
+    cc_toolchain = find_cpp_toolchain(ctx)
+    feature_configuration = cc_common.configure_features(
+        ctx = ctx,
+        cc_toolchain = cc_toolchain,
+        requested_features = ctx.features,
+        unsupported_features = ctx.disabled_features,
+    )
+    name = ctx.attr.name
+    cmd_args = []
+    if ctx.attr.unroll_factors:
+        cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
+    if ctx.attr.extra_args:
+        cmd_args.extend(ctx.attr.extra_args)
+    tile_sizes = ctx.attr.tile_size.replace("x", ",")
+    arch_flag = ",".join(ctx.attr.gpu_archs)
+    gpu_bin = ctx.outputs.kernel
+
+    # cc_binary seems not to bring its dependencies with it, so do that explicitly here.
+    ctx.actions.run(
+        inputs = [ctx.file.mlir_op, ctx.file._tfso],
+        outputs = [gpu_bin],
+        executable = ctx.executable._tool,
+        arguments = cmd_args + [
+            "--tile_sizes=%s" % tile_sizes,
+            "--arch=%s" % arch_flag,
+            "--input=%s" % ctx.file.mlir_op.path,
+            "--output=%s" % gpu_bin.path,
+            "--enable_ftz=%s" % (ctx.attr.data_type == "f32"),
+        ],
+        mnemonic = "compile",
+    )
+    compilation_outputs = cc_common.create_compilation_outputs(
+        # We always produce PIC object files, so use the same object files for both.
+        objects = depset([gpu_bin]),
+        pic_objects = depset([gpu_bin]),
+    )
+    (linking_context, linking_outputs) = cc_common.create_linking_context_from_compilation_outputs(
+        name = ctx.label.name,
+        actions = ctx.actions,
+        feature_configuration = feature_configuration,
+        cc_toolchain = cc_toolchain,
+        compilation_outputs = compilation_outputs,
+    )
+    return [CcInfo(linking_context = linking_context)]
+
+_gen_kernel_fatbin_rule = rule(
+    attrs = {
+        "mlir_op": attr.label(mandatory = True, allow_single_file = True),
+        "data_type": attr.string(mandatory = True),
+        "tile_size": attr.string(mandatory = True),
+        "unroll_factors": attr.string(),
+        "gpu_archs": attr.string_list(mandatory = True),
+        "extra_args": attr.string_list(),
+        # cc_binary seems not to bring its dependencies with it, so do that explicitly here.
+        "_tfso": attr.label(
+            default = Label("//tensorflow:libtensorflow_framework.so.2"),
+            cfg = "host",
+            allow_single_file = True,
+        ),
+        "_tool": attr.label(
+            executable = True,
+            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel"),
+            cfg = "host",
+        ),
+        "_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"),
+    },
+    fragments = ["cpp"],
+    outputs = {"kernel": "%{name}_kernel.o"},
+    implementation = _gen_kernel_fatbin_impl,
+)
+
+def gen_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = []):
     """ Generate a library with kernels for a specific tensorflow op.
 
     Args:
@@ -260,111 +172,16 @@
             _gen_mlir_op(
                 name = name,
                 type = type,
-                unranked = False,
             )
-            _gen_kernel_image_hdr(
-                name = "{name}_{type}_kernel".format(name = name, type = type),
-                mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
-                gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
-                tile_size = tile_size,
-                unroll_factors = unroll_factors,
-                extra_args = extra_args,
-            )
-
-    native.cc_library(
-        name = name + "_kernels",
-        hdrs = if_gpu_is_configured([":{name}_{type}_kernel".format(name = name, type = type) for type in types]),
-        tags = tags,
-    )
-
-################################################################################
-# Unranked kernels build rules.
-################################################################################
-
-def if_mlir_unranked_kernels_enabled(if_true, if_false = []):
-    return select({
-        "//tensorflow/core/kernels/mlir_generated:mlir_use_unranked_kernels": if_true,
-        "//conditions:default": if_false,
-    })
-
-def _gen_unranked_kernel_fatbin_impl(ctx):
-    name = ctx.attr.name
-    cmd_args = []
-    if ctx.attr.unroll_factors:
-        cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
-    if ctx.attr.extra_args:
-        cmd_args.extend(ctx.attr.extra_args)
-    tile_sizes = ctx.attr.tile_size.replace("x", ",")
-    arch_flag = ",".join(ctx.attr.gpu_archs)
-    gpu_bin = ctx.outputs.output
-    ctx.actions.run(
-        inputs = [ctx.file.mlir_op, ctx.file._tfso],
-        outputs = [gpu_bin],
-        executable = ctx.executable._tool,
-        arguments = cmd_args + [
-            "--tile_sizes=%s" % tile_sizes,
-            "--arch=%s" % arch_flag,
-            "--input=%s" % ctx.file.mlir_op.path,
-            "--output=%s" % gpu_bin.path,
-        ],
-        mnemonic = "compile",
-    )
-
-_gen_unranked_kernel_fatbin_rule = rule(
-    attrs = {
-        "mlir_op": attr.label(mandatory = True, allow_single_file = True),
-        "output": attr.output(mandatory = True, doc = "The generated file"),
-        "tile_size": attr.string(mandatory = True),
-        "unroll_factors": attr.string(),
-        "gpu_archs": attr.string_list(mandatory = True),
-        "extra_args": attr.string_list(),
-        "_tfso": attr.label(
-            default = Label("//tensorflow:libtensorflow_framework.so.2"),
-            cfg = "host",
-            allow_single_file = True,
-        ),
-        "_tool": attr.label(
-            executable = True,
-            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel"),
-            cfg = "host",
-        ),
-    },
-    output_to_genfiles = True,
-    implementation = _gen_unranked_kernel_fatbin_impl,
-)
-
-def gen_unranked_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = []):
-    """ Generate a library with unranked kernels for a specific tensorflow op.
-
-    Args:
-      name: The name of the tensorflow op.
-      types: The types ("f16", "f32", "f64") for which a kernel should be generated.
-      tile_size: The tiling specification, e.g. "16x16".
-      unroll_factors: The unrolling specification, e.g. "4,4"
-      tags: The tags which should be added to the library.
-      extra_args: Extra arguments to pass to the generator tool.
-    """
-
-    if cuda_gpu_architectures() or rocm_gpu_architectures():
-        for type in types:
-            _gen_mlir_op(
-                name = name,
-                type = type,
-                unranked = True,
-            )
-            _gen_unranked_kernel_fatbin_rule(
+            _gen_kernel_fatbin_rule(
                 name = "{name}_{type}_kernel_generator".format(name = name, type = type),
                 mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
-                output = "{name}_{type}.a".format(name = name, type = type),
+                data_type = type,
                 gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
                 tile_size = tile_size,
                 unroll_factors = unroll_factors,
                 extra_args = extra_args,
             )
-            native.cc_import(
-                name = "{name}_{type}_kernel".format(name = name, type = type),
-                static_library = "{name}_{type}.a".format(name = name, type = type),
-            )
 
             # We have to use a sh_test instead of build_test because it doesn't properly find the dependent targets.
             native.sh_test(
@@ -385,27 +202,7 @@
     native.cc_library(
         name = name + "_kernels",
         compatible_with = get_compatible_with_cloud(),
-        deps = if_gpu_is_configured([":{name}_{type}_kernel".format(name = name, type = type) for type in types]),
+        deps = if_gpu_is_configured([":{name}_{type}_kernel_generator".format(name = name, type = type) for type in types]),
         linkstatic = 1,
         tags = tags,
     )
-
-def gen_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = [], generate_ranked = True, generate_unranked = False):
-    if (generate_ranked):
-        gen_ranked_kernel_library(
-            name = name,
-            types = types,
-            tile_size = tile_size,
-            tags = tags,
-            unroll_factors = unroll_factors,
-            extra_args = extra_args,
-        )
-    if (generate_unranked):
-        gen_unranked_kernel_library(
-            name = name + "_unranked",
-            types = types,
-            tile_size = tile_size,
-            tags = tags,
-            unroll_factors = unroll_factors,
-            extra_args = extra_args,
-        )
diff --git a/tensorflow/core/kernels/mlir_generated/build_test.sh b/tensorflow/core/kernels/mlir_generated/build_test.sh
index a0748a9..0fcb8a3 100755
--- a/tensorflow/core/kernels/mlir_generated/build_test.sh
+++ b/tensorflow/core/kernels/mlir_generated/build_test.sh
@@ -24,7 +24,7 @@
 INPUT="$2"
 
 # Do something
-${TF_TO_KERNEL} --input=${INPUT} --output=${OUTPUT_FILE} --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75  || die "Failed to generate kernel"
+${TF_TO_KERNEL} --input=${INPUT} --output=${OUTPUT_FILE} --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75 "${@:3}"  || die "Failed to generate kernel"
 
 # Check something
 [ -s ${OUTPUT_FILE} ] || die "output file was empty"
diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cc
deleted file mode 100644
index 948a7c0..0000000
--- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <string>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/mlir_generated/abs_f16_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/abs_f32_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/abs_f64_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/abs_i32_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/abs_i64_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h"
-
-namespace tensorflow {
-namespace {
-GENERATE_OP_KERNEL_BASE(Abs);
-}  // namespace
-
-GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, F16, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, F32, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, F64, double);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, I32, int32);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, I64, int64);
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cc
deleted file mode 100644
index c5fbb15..0000000
--- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"
-
-namespace tensorflow {
-namespace {
-Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
-                    absl::string_view ptx, absl::Span<const uint8_t> cubin_data,
-                    se::StreamExecutor* stream_exec,
-                    std::unique_ptr<se::KernelBase>& kernel_base) {
-  se::MultiKernelLoaderSpec loader_spec(num_args);
-
-  if (!cubin_data.empty()) {
-    loader_spec.AddCudaCubinInMemory(
-        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
-  }
-
-  kernel_base.reset(new se::KernelBase(stream_exec));
-  return stream_exec->GetKernel(loader_spec, kernel_base.get());
-}
-
-struct LaunchConfig {
-  se::BlockDim blockDim;
-  se::ThreadDim threadDim;
-};
-
-LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
-                                    std::vector<uint64> unrolling_factors,
-                                    std::vector<uint64> shape) {
-  LaunchConfig result;
-  // Ensure the vectors are length 3 and pad with ones.
-  tile_sizes.resize(3, 1);
-  unrolling_factors.resize(3, 1);
-  shape.resize(3, 1);
-  // The number of threads is given by the tiling size.
-  result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]);
-  // We know that the kernel was generated by mapping the three outer-most
-  // dimensions to x,y,z dimensions. So we only need to compute those.
-  std::vector<int> block_dims(3);
-  for (int i = 0; i < 3; ++i) {
-    // Compute the number of grids. We use ceildiv here as we have to allocate
-    // an extra thread/block if the division is not even. The kernel contains
-    // code to handle the boundaries.
-    uint64 number_of_threads = Eigen::divup(shape[i], unrolling_factors[i]);
-    int number_of_grids = Eigen::divup(number_of_threads, tile_sizes[i]);
-    block_dims[i] = number_of_grids;
-  }
-  result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]);
-  return result;
-}
-}  // namespace
-
-void MlirGeneratedUnaryOp::Compute(OpKernelContext* ctx) {
-  auto* stream = ctx->op_device_context()->stream();
-  se::KernelBase* kernel;
-  {
-    absl::MutexLock l(&mu_);
-    if (!kernel_) {
-      OP_REQUIRES_OK(ctx, CreateKernel(name_, 10, "", cubin_data_,
-                                       stream->parent(), kernel_));
-    }
-    kernel = kernel_.get();
-  }
-
-  const Tensor& inp = ctx->input(0);
-  Tensor* out = nullptr;
-  OP_REQUIRES_OK(
-      ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out));
-
-  if (inp.NumElements() == 0) {
-    return;
-  }
-
-  se::KernelArgsArray<10> args;
-
-  args.add_device_memory_argument(
-      stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
-  args.add_device_memory_argument(
-      stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
-  args.add_argument<int64_t>(0);
-  args.add_argument<int64_t>(inp.NumElements());
-  args.add_argument<int64_t>(1);
-
-  args.add_device_memory_argument(
-      stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
-  args.add_device_memory_argument(
-      stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
-  args.add_argument<int64_t>(0);
-  args.add_argument<int64_t>(inp.NumElements());
-  args.add_argument<int64_t>(1);
-
-  // This has to be aligned with the configuration that was used when building
-  // the kernels. See the corresponding build rules in the `BUILD` file.
-  LaunchConfig config = GetLaunchConfiguration(
-      {256}, {4}, {static_cast<uint64>(inp.NumElements())});
-  OP_REQUIRES_OK(ctx, stream->parent()->Launch(stream, config.threadDim,
-                                               config.blockDim, *kernel, args));
-}
-
-}  // namespace tensorflow
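For context on the ranked launch path deleted above: GetLaunchConfiguration assigned one thread per tile element and ceil-divided the unrolled element count to size the grid, so an uneven division still gets an extra block for the boundary elements. Below is a minimal standalone sketch of that arithmetic; it is illustrative only and not part of this change, with the 256/4 values mirroring the tile_sizes/unroll_factors hard-coded in the deleted Compute.

#include <cstdint>
#include <iostream>

// Illustrative sketch only (not part of this change): the deleted ranked
// launch path picked one thread per tile element and ceil-divided the
// (unrolled) element count to get the grid size, so boundary elements
// still get a thread even when the division is not even.
static uint64_t CeilDiv(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

int main() {
  const uint64_t num_elements = 1000;  // e.g. a 1000-element tensor
  const uint64_t tile_size = 256;      // threads per block (--tile_sizes=256)
  const uint64_t unroll = 4;           // elements per thread (--unroll_factors=4)

  uint64_t threads_needed = CeilDiv(num_elements, unroll);  // 250
  uint64_t blocks = CeilDiv(threads_needed, tile_size);     // 1
  std::cout << "threads/block=" << tile_size << " blocks=" << blocks << "\n";
  return 0;
}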
diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h
deleted file mode 100644
index 466bbea..0000000
--- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_H_
-#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_H_
-
-#include <memory>
-#include <string>
-
-#include "absl/strings/ascii.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/platform/stream_executor.h"
-
-namespace tensorflow {
-class MlirGeneratedUnaryOp : public OpKernel {
- public:
-  MlirGeneratedUnaryOp(OpKernelConstruction* ctx, std::string name,
-                       absl::Span<const uint8_t> cubin_data)
-      : OpKernel(ctx), name_(name), cubin_data_(cubin_data) {}
-
-  void Compute(OpKernelContext* ctx) override;
-
- private:
-  std::string name_;
-  absl::Span<const uint8_t> cubin_data_;
-  std::unique_ptr<se::KernelBase> kernel_;
-  absl::Mutex mu_;
-};
-
-#define GENERATE_OP_KERNEL_BASE(kernel_name)                               \
-  class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp {     \
-   public:                                                                 \
-    MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx,              \
-                                   absl::Span<const uint8_t> cubin_data)   \
-        : MlirGeneratedUnaryOp(ctx, #kernel_name "_kernel", cubin_data) {} \
-  };
-
-#define GENERATE_OP_KERNEL_FOR(kernel_name, data_type)    \
-  class MlirGenerated##kernel_name##data_type##Op         \
-      : public MlirGenerated##kernel_name##Op {           \
-   public:                                                \
-    explicit MlirGenerated##kernel_name##data_type##Op(   \
-        OpKernelConstruction* ctx)                        \
-        : MlirGenerated##kernel_name                      \
-          ##Op(ctx, k##kernel_name##data_type##Kernel) {} \
-  };
-
-#define GENERATE_AND_REGISTER_UNARY_KERNEL(kernel_name, data_type,    \
-                                           native_data_type)          \
-  namespace {                                                         \
-  GENERATE_OP_KERNEL_FOR(kernel_name, data_type)                      \
-  }                                                                   \
-  REGISTER_KERNEL_BUILDER(Name(#kernel_name)                          \
-                              .Device(DEVICE_GPU)                     \
-                              .TypeConstraint<native_data_type>("T"), \
-                          MlirGenerated##kernel_name##data_type##Op);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_H_
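The header deleted above stamped out per-op, per-type kernel classes by token-pasting names with ## (GENERATE_OP_KERNEL_BASE / GENERATE_OP_KERNEL_FOR) and then registering them with REGISTER_KERNEL_BUILDER. A toy sketch of the same preprocessor pattern, purely illustrative and not part of this change:

#include <iostream>
#include <string>

// Illustrative sketch (not part of this change): a toy macro that uses ##
// token pasting to generate one class per "op", the same mechanism the
// deleted header relied on.
#define DEFINE_TOY_OP(op_name)                               \
  struct Toy##op_name##Op {                                  \
    std::string name() const { return #op_name "_kernel"; }  \
  };

DEFINE_TOY_OP(Abs)
DEFINE_TOY_OP(Tanh)

int main() {
  std::cout << ToyAbsOp().name() << "\n";   // prints "Abs_kernel"
  std::cout << ToyTanhOp().name() << "\n";  // prints "Tanh_kernel"
  return 0;
}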
diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cc
deleted file mode 100644
index a9cc066..0000000
--- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <string>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h"
-#include "tensorflow/core/kernels/mlir_generated/tanh_f16_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/tanh_f32_kernel.h"
-#include "tensorflow/core/kernels/mlir_generated/tanh_f64_kernel.h"
-
-namespace tensorflow {
-namespace {
-GENERATE_OP_KERNEL_BASE(Tanh);
-}  // namespace
-
-GENERATE_AND_REGISTER_UNARY_KERNEL(Tanh, F16, Eigen::half)
-GENERATE_AND_REGISTER_UNARY_KERNEL(Tanh, F32, float)
-GENERATE_AND_REGISTER_UNARY_KERNEL(Tanh, F64, double)
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc
deleted file mode 100644
index b518aff..0000000
--- a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc
+++ /dev/null
@@ -1,270 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <cmath>
-#include <limits>
-#include <memory>
-#include <vector>
-
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/framework/fake_input.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/ops_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace {
-
-class GpuAddTest : public OpsTestBase {
- protected:
-  void SetUp() override {
-    std::unique_ptr<tensorflow::Device> device_gpu(
-        tensorflow::DeviceFactory::NewDevice("GPU", {},
-                                             "/job:a/replica:0/task:0"));
-    SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
-  }
-
-  template <typename T, typename BaselineType = T>
-  void SetAddOp(std::vector<T> input_1, TensorShape shape_1,
-                std::vector<T> input_2, TensorShape shape_2) {
-    TF_ASSERT_OK(NodeDefBuilder("add_op", "AddV2")
-                     .Input(FakeInput(DataTypeToEnum<T>::v()))
-                     .Input(FakeInput(DataTypeToEnum<T>::v()))
-                     .Attr("T", DataTypeToEnum<T>::v())
-                     .Finalize(node_def()));
-
-    TF_ASSERT_OK(InitOp());
-    inputs_.clear();
-    AddInputFromArray<T>(shape_1, input_1);
-    AddInputFromArray<T>(shape_2, input_2);
-  }
-
-  template <typename T, typename BaselineType = T>
-  void RunAndCompareAddOp(std::vector<T> input_1, TensorShape shape_1,
-                          std::vector<T> input_2, TensorShape shape_2,
-                          std::vector<T> output, TensorShape output_shape) {
-    SetAddOp<T>(input_1, shape_1, input_2, shape_2);
-    TF_ASSERT_OK(RunOpKernel());
-    Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, output_shape);
-    test::FillValues<T>(&expected_tensor, output);
-    test::ExpectEqual(expected_tensor, *GetOutput(0));
-  }
-
-  template <typename T, typename BaselineType = T>
-  void TestBroadcastingExpandAddOp() {
-    auto input_1 = {static_cast<T>(10)};
-    auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3),
-                    static_cast<T>(4), static_cast<T>(5), static_cast<T>(6)};
-    std::vector<T> expected{
-        static_cast<T>(11), static_cast<T>(12), static_cast<T>(13),
-        static_cast<T>(14), static_cast<T>(15), static_cast<T>(16),
-    };
-    auto expected_shape = TensorShape({6});
-    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({1}), input_2,
-                                        TensorShape({6}), expected,
-                                        expected_shape);
-  }
-
-  template <typename T, typename BaselineType = T>
-  void TestBroadcastingInDimAddOp() {
-    auto input_1 = {static_cast<T>(10), static_cast<T>(20), static_cast<T>(30)};
-    auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3),
-                    static_cast<T>(4), static_cast<T>(5), static_cast<T>(6)};
-    std::vector<T> expected{
-        static_cast<T>(11), static_cast<T>(22), static_cast<T>(33),
-        static_cast<T>(14), static_cast<T>(25), static_cast<T>(36),
-    };
-    auto expected_shape = TensorShape({2, 3});
-    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({3}), input_2,
-                                        TensorShape({2, 3}), expected,
-                                        expected_shape);
-  }
-
-  template <typename T, typename BaselineType = T>
-  void TestBroadcastingAddOp() {
-    auto input_1 = {static_cast<T>(10), static_cast<T>(20)};
-    auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3)};
-    std::vector<T> expected{
-        static_cast<T>(11), static_cast<T>(12), static_cast<T>(13),
-        static_cast<T>(21), static_cast<T>(22), static_cast<T>(23),
-    };
-    auto expected_shape = TensorShape({2, 3});
-    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({2, 1}), input_2,
-                                        TensorShape({3}), expected,
-                                        expected_shape);
-  }
-
-  template <typename T, typename BaselineType = T>
-  void RunAddOp() {
-    auto input_1 = {
-        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
-        static_cast<T>(-0.1),
-        static_cast<T>(-0.0),
-        static_cast<T>(0.0),
-        static_cast<T>(0.1),
-        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
-    auto input_2 = {
-        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
-        static_cast<T>(-0.1),
-        static_cast<T>(-0.0),
-        static_cast<T>(0.0),
-        static_cast<T>(0.1),
-        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
-    std::vector<T> expected;
-    for (const T& inp : input_2) {
-      expected.push_back(static_cast<T>(static_cast<BaselineType>(inp) +
-                                        static_cast<BaselineType>(inp)));
-    }
-    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape{2, 3}, input_2,
-                                        TensorShape{2, 3}, expected,
-                                        TensorShape{2, 3});
-  }
-
-  template <typename T, typename BaselineType = T>
-  void TestEqualShapesAddOp() {
-    auto input_1 = {
-        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
-        static_cast<T>(-0.1),
-        static_cast<T>(-0.0),
-        static_cast<T>(0.0),
-        static_cast<T>(0.1),
-        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
-    auto input_2 = {
-        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
-        static_cast<T>(-0.1),
-        static_cast<T>(-0.0),
-        static_cast<T>(0.0),
-        static_cast<T>(0.1),
-        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
-    std::vector<T> expected;
-    for (const T& inp : input_2) {
-      expected.push_back(static_cast<T>(static_cast<BaselineType>(inp) +
-                                        static_cast<BaselineType>(inp)));
-    }
-    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape{2, 3}, input_2,
-                                        TensorShape{2, 3}, expected,
-                                        TensorShape{2, 3});
-  }
-
-  template <typename T, typename BaselineType = T>
-  void TestOneIsScalarAddOp() {
-    auto input_1 = static_cast<T>(42);
-    auto input_2 = {
-        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
-        static_cast<T>(-0.1),
-        static_cast<T>(-0.0),
-        static_cast<T>(0.0),
-        static_cast<T>(0.1),
-        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
-    std::vector<T> expected;
-    for (const T& inp : input_2) {
-      expected.push_back(static_cast<T>(static_cast<BaselineType>(input_1) +
-                                        static_cast<BaselineType>(inp)));
-    }
-    RunAndCompareAddOp<T, BaselineType>({input_1}, TensorShape{}, input_2,
-                                        TensorShape{2, 3}, expected,
-                                        TensorShape{2, 3});
-  }
-
-  template <typename T, typename RT = T>
-  void TestIncompatibleShapes() {
-    auto input_1 = {static_cast<T>(-0.1), static_cast<T>(-0.0),
-                    static_cast<T>(0.0)};
-    auto input_2 = {static_cast<T>(-0.1), static_cast<T>(0.0)};
-
-    SetAddOp<T>(input_1, TensorShape{3}, input_2, TensorShape{2});
-    auto status = RunOpKernel();
-    EXPECT_FALSE(status.ok());
-    EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
-  }
-
-  template <typename T, typename BaselineType = T>
-  void TestEmptyShapeWithBroadcastingAddOp() {
-    TensorShape input_shape_a{2, 0, 1};
-    TensorShape input_shape_b{2, 0, 5};
-    TensorShape expected_shape{2, 0, 5};
-    std::vector<T> empty_input = {};
-    RunAndCompareAddOp<T, BaselineType>(empty_input, input_shape_a, empty_input,
-                                        input_shape_b, empty_input,
-                                        expected_shape);
-    RunAndCompareAddOp<T, BaselineType>(empty_input, input_shape_b, empty_input,
-                                        input_shape_a, empty_input,
-                                        expected_shape);
-  }
-};
-
-TEST_F(GpuAddTest, AddFloat) { RunAddOp<float>(); }
-TEST_F(GpuAddTest, AddDouble) { RunAddOp<double>(); }
-TEST_F(GpuAddTest, AddHalf) { RunAddOp<Eigen::half, float>(); }
-TEST_F(GpuAddTest, AddInt64) { RunAddOp<int64, int64>(); }
-
-TEST_F(GpuAddTest, AddEqShapesFloat) { TestEqualShapesAddOp<float>(); }
-TEST_F(GpuAddTest, AddEqShapesDouble) { TestEqualShapesAddOp<double>(); }
-TEST_F(GpuAddTest, AddEqShapesHalf) {
-  TestEqualShapesAddOp<Eigen::half, float>();
-}
-TEST_F(GpuAddTest, AddEqShapesInt64) { TestEqualShapesAddOp<int64>(); }
-
-TEST_F(GpuAddTest, AddScalarFloat) { TestOneIsScalarAddOp<float>(); }
-TEST_F(GpuAddTest, AddScalarDouble) { TestOneIsScalarAddOp<double>(); }
-TEST_F(GpuAddTest, AddScalarHalf) {
-  TestOneIsScalarAddOp<Eigen::half, float>();
-}
-TEST_F(GpuAddTest, AddScalarInt64) { TestOneIsScalarAddOp<int64>(); }
-
-TEST_F(GpuAddTest, BCastExpandAddFloat) {
-  TestBroadcastingExpandAddOp<float>();
-}
-TEST_F(GpuAddTest, BCastExpandAddDouble) {
-  TestBroadcastingExpandAddOp<double>();
-}
-TEST_F(GpuAddTest, BCastExpandAddHalf) {
-  TestBroadcastingExpandAddOp<Eigen::half, float>();
-}
-TEST_F(GpuAddTest, BCastExpandAddInt64) {
-  TestBroadcastingExpandAddOp<int64>();
-}
-
-TEST_F(GpuAddTest, BCastInDimAddFloat) { TestBroadcastingInDimAddOp<float>(); }
-TEST_F(GpuAddTest, BCastInDimAddDouble) {
-  TestBroadcastingInDimAddOp<double>();
-}
-TEST_F(GpuAddTest, BCastInDimAddHalf) {
-  TestBroadcastingInDimAddOp<Eigen::half, float>();
-}
-TEST_F(GpuAddTest, BCastInDimAddInt64) { TestBroadcastingInDimAddOp<int64>(); }
-
-TEST_F(GpuAddTest, BCastAddFloat) { TestBroadcastingAddOp<float>(); }
-TEST_F(GpuAddTest, BCastAddDouble) { TestBroadcastingAddOp<double>(); }
-TEST_F(GpuAddTest, BCastAddHalf) {
-  TestBroadcastingAddOp<Eigen::half, float>();
-}
-TEST_F(GpuAddTest, BCastAddInt64) { TestBroadcastingAddOp<int64>(); }
-
-TEST_F(GpuAddTest, IncompatibleShapes) { TestIncompatibleShapes<float>(); }
-
-TEST_F(GpuAddTest, EmptyShapeBCastAddFloat) {
-  TestEmptyShapeWithBroadcastingAddOp<float>();
-}
-TEST_F(GpuAddTest, EmptyShapeBCastAddDouble) {
-  TestEmptyShapeWithBroadcastingAddOp<double>();
-}
-
-// TEST_F(GpuAddTest, AddV2Half) { RunAddOp<Eigen::half, float>(); }
-}  // namespace
-}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc
new file mode 100644
index 0000000..919dad7
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc
@@ -0,0 +1,621 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/STLExtras.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class GpuBinaryOpTest : public OpsTestBase {
+ protected:
+  void SetUp() override {
+    std::unique_ptr<tensorflow::Device> device_gpu(
+        tensorflow::DeviceFactory::NewDevice("GPU", {},
+                                             "/job:a/replica:0/task:0"));
+    SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
+  }
+
+  template <typename T, typename OutT>
+  void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
+                   const absl::InlinedVector<T, 10>& lhs_input,
+                   const TensorShape& rhs_shape,
+                   const absl::InlinedVector<T, 10>& rhs_input,
+                   bool use_constraint) {
+    auto builder = NodeDefBuilder("some_name", op_name)
+                       .Input(FakeInput(DataTypeToEnum<T>::v()))
+                       .Input(FakeInput(DataTypeToEnum<T>::v()));
+    if (use_constraint) {
+      builder.Attr("T", DataTypeToEnum<T>::v());
+    }
+    TF_ASSERT_OK(builder.Finalize(node_def()));
+
+    TF_ASSERT_OK(InitOp());
+    AddInputFromArray<T>(lhs_shape, lhs_input);
+    AddInputFromArray<T>(rhs_shape, rhs_input);
+  }
+
+  // Run fully specified tests.
+
+  template <typename T, typename OutT>
+  void RunAndExpectResult(const std::string& op_name,
+                          const TensorShape& lhs_shape,
+                          const absl::InlinedVector<T, 10>& lhs_input,
+                          const TensorShape& rhs_shape,
+                          const absl::InlinedVector<T, 10>& rhs_input,
+                          const TensorShape& expected_shape,
+                          const absl::InlinedVector<OutT, 10>& expected_output,
+                          bool use_constraint = true) {
+    SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
+                         use_constraint);
+    TF_ASSERT_OK(RunOpKernel());
+
+    // Compare output to expectation.
+    Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value,
+                           expected_shape);
+    test::FillValues<OutT>(&expected_tensor, expected_output);
+    test::ExpectEqual(expected_tensor, *GetOutput(0));
+  }
+
+  template <typename T, typename OutT>
+  void RunAndExpectInvalidArgument(const std::string& op_name,
+                                   const TensorShape& lhs_shape,
+                                   const absl::InlinedVector<T, 10>& lhs_input,
+                                   const TensorShape& rhs_shape,
+                                   const absl::InlinedVector<T, 10>& rhs_input,
+                                   bool use_constraint = true) {
+    SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
+                         use_constraint);
+    auto status = RunOpKernel();
+    EXPECT_FALSE(status.ok());
+    EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
+  }
+
+  // Run common test cases.
+
+  template <typename T, typename OutT>
+  void TestIncompatibleShapes(const std::string& op_name,
+                              const absl::InlinedVector<T, 10>& lhs_input,
+                              const absl::InlinedVector<T, 10>& rhs_input,
+                              bool use_constraint = true) {
+    // Prepare incompatibly shaped inputs.
+    TensorShape lhs_shape{3};
+    TensorShape rhs_shape{2};
+    auto repeated_lhs_input =
+        test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
+    auto repeated_rhs_input =
+        test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
+
+    RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
+                                         rhs_shape, repeated_rhs_input,
+                                         use_constraint);
+  }
+
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void TestEqualShapes(const std::string& op_name, const TensorShape& shape,
+                       const absl::InlinedVector<T, 10>& lhs_input,
+                       const absl::InlinedVector<T, 10>& rhs_input,
+                       BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
+                       bool use_constraint = true) {
+    // Prepare inputs.
+    int input_size = shape.num_elements();
+    auto repeated_lhs_input =
+        test::RepeatInputToMatchShape(lhs_input, input_size);
+    auto repeated_rhs_input =
+        test::RepeatInputToMatchShape(rhs_input, input_size);
+
+    // Compute expected results.
+    absl::InlinedVector<OutT, 10> expected_output;
+    for (auto it_lhs = repeated_lhs_input.begin(),
+              it_rhs = repeated_rhs_input.begin(),
+              end = repeated_lhs_input.end();
+         it_lhs != end; ++it_lhs, ++it_rhs) {
+      auto lhs = static_cast<BaselineT>(*it_lhs);
+      auto rhs = static_cast<BaselineT>(*it_rhs);
+      auto result = static_cast<OutT>(baseline_callback(lhs, rhs));
+      expected_output.push_back(result);
+    }
+
+    RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape,
+                                repeated_rhs_input, shape, expected_output,
+                                use_constraint);
+  }
+
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void TestOneScalar(const std::string& op_name, T scalar_input,
+                     const TensorShape& other_shape,
+                     const absl::InlinedVector<T, 10>& other_input,
+                     BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
+                     bool use_constraint = true) {
+    // Prepare inputs.
+    TensorShape scalar_shape{};
+    auto repeated_other_input =
+        test::RepeatInputToMatchShape(other_input, other_shape.num_elements());
+
+    // Compute expected results.
+    absl::InlinedVector<OutT, 10> expected_output;
+    for (auto it = repeated_other_input.begin(),
+              end = repeated_other_input.end();
+         it != end; ++it) {
+      auto scalar = static_cast<BaselineT>(scalar_input);
+      auto other_value = static_cast<BaselineT>(*it);
+      auto result = static_cast<OutT>(baseline_callback(scalar, other_value));
+      expected_output.push_back(result);
+    }
+
+    auto scalar_input_vector = test::InputAsVector<T>({scalar_input});
+    RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector,
+                                other_shape, repeated_other_input,
+                                /*expected_shape=*/other_shape, expected_output,
+                                use_constraint);
+  }
+
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void TestBroadcastingExpand(const std::string& op_name,
+                              const absl::InlinedVector<T, 10>& lhs_input,
+                              const absl::InlinedVector<T, 10>& rhs_input,
+                              BaselineOutT (*baseline_callback)(BaselineT,
+                                                                BaselineT),
+                              bool use_constraint = true) {
+    // Prepare inputs.
+    TensorShape lhs_shape{1};
+    TensorShape rhs_shape{6};
+    auto repeated_lhs_input =
+        test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
+    auto repeated_rhs_input =
+        test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
+
+    // Compute expected results.
+    std::vector<int> lhs_indices = {0, 0, 0, 0, 0, 0};
+    std::vector<int> rhs_indices = {0, 1, 2, 3, 4, 5};
+    auto expected_output =
+        ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
+            lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
+            baseline_callback);
+
+    RunAndExpectResult<T, OutT>(
+        op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
+        /*expected_shape=*/rhs_shape, expected_output, use_constraint);
+  }
+
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void TestBroadcastingInDim(const std::string& op_name,
+                             const absl::InlinedVector<T, 10>& lhs_input,
+                             const absl::InlinedVector<T, 10>& rhs_input,
+                             BaselineOutT (*baseline_callback)(BaselineT,
+                                                               BaselineT),
+                             bool use_constraint = true) {
+    // Prepare inputs.
+    TensorShape lhs_shape{3};
+    TensorShape rhs_shape{2, 3};
+    auto repeated_lhs_input =
+        test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
+    auto repeated_rhs_input =
+        test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
+
+    // Compute expected results.
+    std::vector<int> lhs_indices = {0, 1, 2, 0, 1, 2};
+    std::vector<int> rhs_indices = {0, 1, 2, 3, 4, 5};
+    auto expected_output =
+        ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
+            lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
+            baseline_callback);
+
+    RunAndExpectResult<T, OutT>(
+        op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
+        /*expected_shape=*/rhs_shape, expected_output, use_constraint);
+  }
+
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void TestBroadcasting(const std::string& op_name,
+                        const absl::InlinedVector<T, 10>& lhs_input,
+                        const absl::InlinedVector<T, 10>& rhs_input,
+                        BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
+                        bool use_constraint = true) {
+    // Prepare inputs.
+    TensorShape lhs_shape{2, 1};
+    TensorShape rhs_shape{3};
+    auto repeated_lhs_input =
+        test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
+    auto repeated_rhs_input =
+        test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
+
+    // Compute expected results.
+    TensorShape expected_shape{2, 3};
+    std::vector<int> lhs_indices = {0, 0, 0, 1, 1, 1};
+    std::vector<int> rhs_indices = {0, 1, 2, 0, 1, 2};
+    auto expected_output =
+        ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
+            lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
+            baseline_callback);
+
+    RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
+                                rhs_shape, repeated_rhs_input, expected_shape,
+                                expected_output, use_constraint);
+  }
+
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void TestEmptyShapeBroadcasting(const std::string& op_name,
+                                  const absl::InlinedVector<T, 10>& lhs_input,
+                                  const absl::InlinedVector<T, 10>& rhs_input,
+                                  bool use_constraint = true) {
+    // Prepare inputs.
+    TensorShape lhs_shape{2, 0, 1};
+    TensorShape rhs_shape{2, 0, 5};
+    absl::InlinedVector<T, 10> empty_input = {};
+
+    // Define expected result.
+    TensorShape expected_shape{2, 0, 5};
+    absl::InlinedVector<OutT, 10> expected_output = {};
+
+    RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape,
+                                empty_input, expected_shape, expected_output,
+                                use_constraint);
+  }
+
+ private:
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
+      std::vector<int> lhs_indices, absl::InlinedVector<T, 10> lhs_input,
+      std::vector<int> rhs_indices, absl::InlinedVector<T, 10> rhs_input,
+      BaselineOutT (*baseline_callback)(BaselineT, BaselineT)) {
+    absl::InlinedVector<OutT, 10> expected_output;
+    for (int i = 0; i < lhs_indices.size(); i++) {
+      auto lhs = static_cast<BaselineT>(lhs_input[lhs_indices[i]]);
+      auto rhs = static_cast<BaselineT>(rhs_input[rhs_indices[i]]);
+      auto result = static_cast<OutT>(baseline_callback(lhs, rhs));
+      expected_output.push_back(result);
+    }
+    return expected_output;
+  }
+};
+
+// Macros to easily generate common test cases. For specific inputs, please
+// define your own test fixtures.
+
+#define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT,    \
+                                 BaselineOutT, baseline_callback,           \
+                                 use_constraint)                            \
+  TEST_F(GpuBinaryOpTest, op_name##EqShapes##test_name) {                   \
+    TestEqualShapes<T, BaselineT, OutT, BaselineOutT>(                      \
+        #op_name, /*shape=*/test::DefaultInputShape(),                      \
+        /*lhs_input=*/test::DefaultInput<T>(#op_name),                      \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback,   \
+        use_constraint);                                                    \
+  }                                                                         \
+                                                                            \
+  TEST_F(GpuBinaryOpTest, op_name##OneScalar##test_name) {                  \
+    TestOneScalar<T, BaselineT, OutT, BaselineOutT>(                        \
+        #op_name, /*scalar_input=*/test::DefaultScalarInput<T>(),           \
+        /*other_shape=*/test::DefaultInputShape(),                          \
+        /*other_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
+        use_constraint);                                                    \
+  }                                                                         \
+                                                                            \
+  TEST_F(GpuBinaryOpTest, op_name##IncompatibleShapes##test_name) {         \
+    TestIncompatibleShapes<T, OutT>(                                        \
+        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name),            \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint);     \
+  }                                                                         \
+                                                                            \
+  TEST_F(GpuBinaryOpTest, op_name##BroadcastingExpand##test_name) {         \
+    TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>(               \
+        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name),            \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback,   \
+        use_constraint);                                                    \
+  }                                                                         \
+                                                                            \
+  TEST_F(GpuBinaryOpTest, op_name##BroadcastingInDim##test_name) {          \
+    TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>(                \
+        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name),            \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback,   \
+        use_constraint);                                                    \
+  }                                                                         \
+                                                                            \
+  TEST_F(GpuBinaryOpTest, op_name##Broadcasting##test_name) {               \
+    TestBroadcasting<T, BaselineT, OutT, BaselineOutT>(                     \
+        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name),            \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback,   \
+        use_constraint);                                                    \
+  }                                                                         \
+                                                                            \
+  TEST_F(GpuBinaryOpTest, op_name##EmptyShapeBroadcasting##test_name) {     \
+    TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>(           \
+        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name),            \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint);     \
+  }
+
+#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \
+  GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT,               \
+                           baseline_callback, /*use_constraint=*/false)
+
+#define GENERATE_DEFAULT_TESTS_SAME_INPUT_AND_OUTPUT_TYPE( \
+    op_name, test_name, T, baseline_callback)              \
+  GENERATE_DEFAULT_TESTS(op_name, test_name, T, T, baseline_callback)
+
+/// Test `tf.AddV2`.
+
+template <typename T>
+T baseline_add(T lhs, T rhs) {
+  return lhs + rhs;
+}
+
+GENERATE_DEFAULT_TESTS(AddV2,
+                       /*test_name=*/Half, Eigen::half, Eigen::half,
+                       baseline_add)
+GENERATE_DEFAULT_TESTS(AddV2,
+                       /*test_name=*/Float, float, float, baseline_add)
+GENERATE_DEFAULT_TESTS(AddV2,
+                       /*test_name=*/Double, double, double, baseline_add)
+GENERATE_DEFAULT_TESTS(AddV2,
+                       /*test_name=*/Int64, int64, int64, baseline_add)
+
+/// Test `tf.BitwiseAnd`.
+
+template <typename T>
+T baseline_bitwise_and(T lhs, T rhs) {
+  return lhs & rhs;
+}
+
+GENERATE_DEFAULT_TESTS(BitwiseAnd,
+                       /*test_name=*/Int8, int8, int8, baseline_bitwise_and)
+GENERATE_DEFAULT_TESTS(BitwiseAnd,
+                       /*test_name=*/Int16, int16, int16, baseline_bitwise_and)
+GENERATE_DEFAULT_TESTS(BitwiseAnd,
+                       /*test_name=*/Int32, int32, int32, baseline_bitwise_and)
+GENERATE_DEFAULT_TESTS(BitwiseAnd,
+                       /*test_name=*/Int64, int64, int64, baseline_bitwise_and)
+
+/// Test `tf.BitwiseOr`.
+
+template <typename T>
+T baseline_bitwise_or(T lhs, T rhs) {
+  return lhs | rhs;
+}
+
+GENERATE_DEFAULT_TESTS(BitwiseOr,
+                       /*test_name=*/Int8, int8, int8, baseline_bitwise_or)
+GENERATE_DEFAULT_TESTS(BitwiseOr,
+                       /*test_name=*/Int16, int16, int16, baseline_bitwise_or)
+GENERATE_DEFAULT_TESTS(BitwiseOr,
+                       /*test_name=*/Int32, int32, int32, baseline_bitwise_or)
+GENERATE_DEFAULT_TESTS(BitwiseOr,
+                       /*test_name=*/Int64, int64, int64, baseline_bitwise_or)
+
+/// Test `tf.BitwiseXor`.
+
+template <typename T>
+T baseline_bitwise_xor(T lhs, T rhs) {
+  return lhs ^ rhs;
+}
+
+GENERATE_DEFAULT_TESTS(BitwiseXor,
+                       /*test_name=*/Int8, int8, int8, baseline_bitwise_xor)
+GENERATE_DEFAULT_TESTS(BitwiseXor,
+                       /*test_name=*/Int16, int16, int16, baseline_bitwise_xor)
+GENERATE_DEFAULT_TESTS(BitwiseXor,
+                       /*test_name=*/Int32, int32, int32, baseline_bitwise_xor)
+GENERATE_DEFAULT_TESTS(BitwiseXor,
+                       /*test_name=*/Int64, int64, int64, baseline_bitwise_xor)
+
+/// Test `tf.LeftShift`.
+
+template <typename T>
+T baseline_left_shift(T lhs, T rhs) {
+  return lhs << rhs;
+}
+
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
+                       baseline_left_shift)
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
+                       baseline_left_shift)
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
+                       baseline_left_shift)
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
+                       baseline_left_shift)
+
+/// Test `tf.RightShift`.
+
+template <typename T>
+T baseline_right_shift(T lhs, T rhs) {
+  return lhs >> rhs;
+}
+
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int8, int8, int8, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int16, int16, int16, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int32, int32, int32, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int64, int64, int64, baseline_right_shift)
+
+/// Test `tf.Equal`.
+
+template <typename T>
+bool baseline_equal(T lhs, T rhs) {
+  return lhs == rhs;
+}
+
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_equal)
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Float, float, bool, baseline_equal)
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Double, double, bool,
+                       baseline_equal)
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Bool, bool, bool, baseline_equal)
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
+GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)
+
+/// Test `tf.NotEqual`.
+
+template <typename T>
+bool baseline_not_equal(T lhs, T rhs) {
+  return lhs != rhs;
+}
+
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
+                       baseline_not_equal)
+
+/// Test `tf.Greater`.
+
+template <typename T>
+bool baseline_greater(T lhs, T rhs) {
+  return lhs > rhs;
+}
+
+GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_greater)
+GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Float, float, bool,
+                       baseline_greater)
+GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Double, double, bool,
+                       baseline_greater)
+GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Int8, int8, bool,
+                       baseline_greater)
+GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Int16, int16, bool,
+                       baseline_greater)
+GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Int64, int64, bool,
+                       baseline_greater)
+
+/// Test `tf.GreaterEqual`.
+
+template <typename T>
+bool baseline_greater_equal(T lhs, T rhs) {
+  return lhs >= rhs;
+}
+
+GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_greater_equal)
+GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Float, float, bool,
+                       baseline_greater_equal)
+GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Double, double, bool,
+                       baseline_greater_equal)
+GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int8, int8, bool,
+                       baseline_greater_equal)
+GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
+                       baseline_greater_equal)
+GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
+                       baseline_greater_equal)
+
+/// Test `tf.Less`.
+
+template <typename T>
+bool baseline_less(T lhs, T rhs) {
+  return lhs < rhs;
+}
+
+GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_less)
+GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Float, float, bool, baseline_less)
+GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Double, double, bool, baseline_less)
+GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Int8, int8, bool, baseline_less)
+GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Int16, int16, bool, baseline_less)
+GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Int64, int64, bool, baseline_less)
+
+/// Test `tf.LessEqual`.
+
+template <typename T>
+bool baseline_less_equal(T lhs, T rhs) {
+  return lhs <= rhs;
+}
+
+GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_less_equal)
+GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Float, float, bool,
+                       baseline_less_equal)
+GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Double, double, bool,
+                       baseline_less_equal)
+GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Int8, int8, bool,
+                       baseline_less_equal)
+GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Int16, int16, bool,
+                       baseline_less_equal)
+GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Int64, int64, bool,
+                       baseline_less_equal)
+
+/// Test `tf.LogicalAnd`.
+
+bool baseline_logical_and(bool lhs, bool rhs) { return lhs && rhs; }
+
+GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
+                         /*BaselineT=*/bool, /*OutT=*/bool,
+                         /*BaselineOutT=*/bool, baseline_logical_and,
+                         /*use_constraint=*/false)
+
+/// Test `tf.LogicalOr`.
+
+bool baseline_logical_or(bool lhs, bool rhs) { return lhs || rhs; }
+
+GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
+                         /*BaselineT=*/bool, /*OutT=*/bool,
+                         /*BaselineOutT=*/bool, baseline_logical_or,
+                         /*use_constraint=*/false)
+
+/// Test `tf.FloorDiv`.
+template <typename T>
+T baseline_floor_div(T lhs, T rhs) {
+  return std::floor(lhs / rhs);
+}
+
+template <>
+Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
+  return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
+}
+
+GENERATE_DEFAULT_TESTS(FloorDiv,
+                       /*test_name=*/Half, Eigen::half, Eigen::half,
+                       baseline_floor_div);
+GENERATE_DEFAULT_TESTS(FloorDiv,
+                       /*test_name=*/Float, float, float, baseline_floor_div);
+GENERATE_DEFAULT_TESTS(FloorDiv,
+                       /*test_name=*/Double, double, double,
+                       baseline_floor_div);
+
+}  // namespace
+}  // end namespace tensorflow
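The GENERATE_DEFAULT_TESTS macros in the new test file compare each generated GPU kernel against a plain C++ "baseline" callback, evaluated over explicit broadcast index maps (see ComputeExpectedOutput above). The following self-contained sketch shows that comparison pattern in isolation; it assumes nothing beyond the standard library, is illustrative only, and is not part of this change. The index maps match the shape-{2,1} x shape-{3} broadcasting case used by TestBroadcasting.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative sketch (not part of this change): precompute expected outputs
// by applying a scalar baseline callback to input pairs selected through
// explicit broadcast index maps, mirroring ComputeExpectedOutput.
template <typename T, typename OutT>
std::vector<OutT> ExpectedOutput(const std::vector<int>& lhs_idx,
                                 const std::vector<T>& lhs,
                                 const std::vector<int>& rhs_idx,
                                 const std::vector<T>& rhs,
                                 OutT (*baseline)(T, T)) {
  std::vector<OutT> out;
  for (std::size_t i = 0; i < lhs_idx.size(); ++i)
    out.push_back(baseline(lhs[lhs_idx[i]], rhs[rhs_idx[i]]));
  return out;
}

static int64_t baseline_add(int64_t lhs, int64_t rhs) { return lhs + rhs; }

int main() {
  // Broadcasting a shape-{2,1} lhs against a shape-{3} rhs to shape {2,3}.
  std::vector<int64_t> lhs = {10, 20}, rhs = {1, 2, 3};
  std::vector<int> lhs_idx = {0, 0, 0, 1, 1, 1}, rhs_idx = {0, 1, 2, 0, 1, 2};
  for (int64_t v : ExpectedOutput<int64_t, int64_t>(lhs_idx, lhs, rhs_idx, rhs,
                                                    baseline_add))
    std::cout << v << " ";  // 11 12 13 21 22 23
  std::cout << "\n";
  return 0;
}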
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_abs.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc
similarity index 93%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_abs.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc
index 43eb7bb..615f7f0 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_abs.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_gpu_add.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc
similarity index 93%
rename from tensorflow/core/kernels/mlir_generated/unranked_gpu_add.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_add.cc
index decfd99..674808e 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_gpu_add.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc
new file mode 100644
index 0000000..43343f4
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i32, DT_INT32, int32);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i64, DT_INT64, int64);
+
+// TODO(b/172804967): Enable once fixed.
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui8, DT_UINT8, uint8);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui16, DT_UINT16, uint16);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui32, DT_UINT32, uint32);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui64, DT_UINT64, uint64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc
new file mode 100644
index 0000000..364fb27
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i32, DT_INT32, int32);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i64, DT_INT64, int64);
+
+// TODO(b/172804967): Enable once fixed.
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc
new file mode 100644
index 0000000..5529067
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i32, DT_INT32, int32);
+GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i64, DT_INT64, int64);
+
+// TODO(b/172804967): Enable once fixed.
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui8, DT_UINT8, uint8);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui16, DT_UINT16, uint16);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui32, DT_UINT32, uint32);
+// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui64, DT_UINT64, uint64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_ceil.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_ceil.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc
index 41800d0..3a3561f 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_ceil.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_conj.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc
similarity index 67%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_conj.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc
index 52fe23f..d076e5c 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_conj.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc
@@ -16,13 +16,14 @@
 #include <complex>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Conj, f32, DT_COMPLEX64,
-                                   std::complex<float>);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Conj, f64, DT_COMPLEX128,
-                                   std::complex<double>);
+GENERATE_UNARY_KERNEL2(Complex, f32, DT_COMPLEX64, std::complex<float>, float);
+REGISTER_COMPLEX_KERNEL(Complex, f32, std::complex<float>, float);
+GENERATE_UNARY_KERNEL2(Complex, f64, DT_COMPLEX128, std::complex<double>,
+                       double);
+REGISTER_COMPLEX_KERNEL(Complex, f64, std::complex<double>, double);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_conj.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc
similarity index 93%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_conj.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc
index 52fe23f..44fc983 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_conj.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc
@@ -16,7 +16,7 @@
 #include <complex>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_cos.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_cos.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc
index ca3832b..8fd6f2b 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_cos.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc
new file mode 100644
index 0000000..c826599
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f16, DT_BOOL, bool, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f32, DT_BOOL, bool, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f64, DT_BOOL, bool, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i1, DT_BOOL, bool, bool);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i8, DT_BOOL, bool, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i16, DT_BOOL, bool, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i64, DT_BOOL, bool, int64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_exp.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_exp.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc
index b14b5ce..d456363 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_exp.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_floor.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_floor.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc
index faf7616..90d8306 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_floor.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc
similarity index 73%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc
index 2dd4a8d..b5371b94 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc
@@ -12,14 +12,13 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f64, DT_DOUBLE, double);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc
new file mode 100644
index 0000000..993f623
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f16, DT_BOOL, bool, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f32, DT_BOOL, bool, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f64, DT_BOOL, bool, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i8, DT_BOOL, bool, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i16, DT_BOOL, bool, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i64, DT_BOOL, bool, int64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc
new file mode 100644
index 0000000..aa66a96
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f16, DT_BOOL, bool,
+                                     Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f32, DT_BOOL, bool, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f64, DT_BOOL, bool, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i8, DT_BOOL, bool, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i16, DT_BOOL, bool, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i64, DT_BOOL, bool, int64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc
similarity index 93%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc
index dec53c9..8382f0d 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc
@@ -16,7 +16,7 @@
 #include <complex>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_is_inf.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc
similarity index 93%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_is_inf.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc
index 240d319..41c122a 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_is_inf.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc
@@ -16,7 +16,7 @@
 #include <complex>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc
similarity index 69%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc
index 2dd4a8d..3034cef 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc
@@ -12,14 +12,14 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i32, DT_INT32, int32);
+GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i64, DT_INT64, int64);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc
new file mode 100644
index 0000000..c3fc8a2
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f16, DT_BOOL, bool, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f32, DT_BOOL, bool, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f64, DT_BOOL, bool, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i8, DT_BOOL, bool, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i16, DT_BOOL, bool, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i64, DT_BOOL, bool, int64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc
new file mode 100644
index 0000000..8f0e18a
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f16, DT_BOOL, bool,
+                                     Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f32, DT_BOOL, bool, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f64, DT_BOOL, bool, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i8, DT_BOOL, bool, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i16, DT_BOOL, bool, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i64, DT_BOOL, bool, int64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_log.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_log.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_log.cc
index afd941b..83e7576 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_log.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc
similarity index 72%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc
index c60b799..4dcd779 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc
@@ -12,14 +12,14 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_UNARY_KERNEL(LogicalNot, i1, DT_BOOL, bool);
-// LogicalNot does not have a "T" attribute because it only works with type
+GENERATE_BINARY_KERNEL(LogicalAnd, i1, DT_BOOL, bool);
+// LogicalAnd does not have a "T" attribute because it only works with type
 // bool. So we need to register it without TypeConstraint<bool>("T").
-REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1);
+REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, i1);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc
index c60b799..944e466 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc
similarity index 72%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc
index c60b799..681a623 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_logical_not.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc
@@ -12,14 +12,14 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_UNARY_KERNEL(LogicalNot, i1, DT_BOOL, bool);
-// LogicalNot does not have a "T" attribute because it only works with type
+GENERATE_BINARY_KERNEL(LogicalOr, i1, DT_BOOL, bool);
+// LogicalOr does not have a "T" attribute because it only works with type
 // bool. So we need to register it without TypeConstraint<bool>("T").
-REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1);
+REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, i1);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc
similarity index 60%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc
index 2dd4a8d..28d875d 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc
@@ -14,12 +14,16 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i16, DT_INT16, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i64, DT_INT64, int64);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc
similarity index 60%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc
index 2dd4a8d..a487346 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc
@@ -14,12 +14,16 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i16, DT_INT16, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i64, DT_INT64, int64);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc
new file mode 100644
index 0000000..009062c
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f16, DT_BOOL, bool, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f32, DT_BOOL, bool, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f64, DT_BOOL, bool, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i1, DT_BOOL, bool, bool);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i8, DT_BOOL, bool, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i16, DT_BOOL, bool, int16);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
+GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i64, DT_BOOL, bool, int64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc
similarity index 93%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_real.cc
index 8567060..668a983 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc
@@ -16,7 +16,7 @@
 #include <complex>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc
similarity index 69%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc
index 2dd4a8d..2e49ba7 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc
@@ -12,14 +12,14 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i32, DT_INT32, int32);
+GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i64, DT_INT64, int64);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_rsqrt.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_rsqrt.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc
index f89e106..770eb3b 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_rsqrt.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc
similarity index 88%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc
index e49d640..e95369f 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc
@@ -14,14 +14,14 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f16, DT_HALF, Eigen::half);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f32, DT_FLOAT, float);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i32, DT_INT32, int32);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i64, DT_INT64, int64);
 // TODO(b/162577610): Register the kernel for complex types and bfloat.
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc
index 2dd4a8d..42ad713 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sqrt.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sqrt.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc
index 9b77735..60bdd9f 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sqrt.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_tanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc
similarity index 92%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_tanh.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc
index 5a703b9..7aad99a 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_tanh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.cc b/tensorflow/core/kernels/mlir_generated/gpu_ops_base.cc
similarity index 95%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.cc
rename to tensorflow/core/kernels/mlir_generated/gpu_ops_base.cc
index bc4a36f..d4960fb 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_ops_base.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
 
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/gpu_ops_base.h
similarity index 76%
rename from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h
rename to tensorflow/core/kernels/mlir_generated/gpu_ops_base.h
index 81d94e0..b9578cd 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h
+++ b/tensorflow/core/kernels/mlir_generated/gpu_ops_base.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_UNRANKED_OP_GPU_ABS_H_
-#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_UNRANKED_OP_GPU_ABS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
+#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -25,6 +25,18 @@
 
 namespace tensorflow {
 
+// A type-erased version of the UnrankedMemRefType to allow it to be used
+// as the return type of an extern "C" function on windows.
+struct UntypedUnrankedMemRefType {
+  int64_t rank;
+  void* descriptor;
+};
+
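+// Converts the type-erased descriptor returned by an MLIR-generated function
+// back into a typed unranked memref.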
+template <typename ElemType>
+UnrankedMemRefType<ElemType> ConvertToTyped(UntypedUnrankedMemRefType desc) {
+  return {desc.rank, desc.descriptor};
+}
+
 // Returns a pointer to an allocated MlirTensorBuffer that takes ownership of
 // pre-allocated memory.
 TensorBuffer* GetMlirTensorBuffer(const void* ptr, size_t size,
@@ -157,25 +169,37 @@
   GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type)         \
   REGISTER_KERNEL(tf_op, mlir_type, data_type)
 
-#define GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type)     \
-  extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(tf_op, mlir_type)( \
-      tensorflow::OpKernelContext * ctx,                                      \
-      const ::UnrankedMemRefType<data_type>* arg1,                            \
-      const ::UnrankedMemRefType<data_type>* arg2);                           \
-                                                                              \
-  namespace {                                                                 \
-  class MlirUnranked##tf_op##mlir_type##Op                                    \
-      : public MlirUnrankedOp<tf_data_type, data_type,                        \
-                              MlirUnranked##tf_op##mlir_type##Op> {           \
-   public:                                                                    \
-    using MlirUnrankedOp::MlirUnrankedOp;                                     \
-                                                                              \
-    static ::UnrankedMemRefType<data_type> Invoke(                            \
-        OpKernelContext* ctx,                                                 \
-        llvm::ArrayRef<::UnrankedMemRefType<data_type>> args) {               \
-      return MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0], &args[1]);        \
-    }                                                                         \
-  };                                                                          \
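+// Variant of GENERATE_AND_REGISTER_BINARY_KERNEL that allows the result
+// element type to differ from the input element type (e.g. comparison
+// kernels that take numeric inputs and produce DT_BOOL outputs).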
+#define GENERATE_AND_REGISTER_BINARY_KERNEL2(                               \
+    tf_op, mlir_type, tf_data_type, result_data_type, input_data_type)      \
+  GENERATE_BINARY_KERNEL2(tf_op, mlir_type, tf_data_type, result_data_type, \
+                          input_data_type)                                  \
+  REGISTER_KERNEL(tf_op, mlir_type, input_data_type)
+
+#define GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
+  GENERATE_BINARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, data_type)
+
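+// Declares the extern "C" entry point emitted by the MLIR kernel generator
+// (returning a type-erased unranked memref) and defines the MlirUnrankedOp
+// subclass whose Invoke converts the result back to a typed memref.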
+#define GENERATE_BINARY_KERNEL2(tf_op, mlir_type, tf_data_type,         \
+                                result_data_type, input_data_type)      \
+  extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \
+      tensorflow::OpKernelContext * ctx,                                \
+      const ::UnrankedMemRefType<input_data_type>* arg1,                \
+      const ::UnrankedMemRefType<input_data_type>* arg2);               \
+                                                                        \
+  namespace {                                                           \
+  class MlirUnranked##tf_op##mlir_type##Op                              \
+      : public MlirUnrankedOp<tf_data_type, result_data_type,           \
+                              MlirUnranked##tf_op##mlir_type##Op,       \
+                              input_data_type> {                        \
+   public:                                                              \
+    using MlirUnrankedOp::MlirUnrankedOp;                               \
+                                                                        \
+    static ::UnrankedMemRefType<result_data_type> Invoke(               \
+        OpKernelContext* ctx,                                           \
+        llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) {   \
+      return ConvertToTyped<result_data_type>(                          \
+          MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0], &args[1]));    \
+    }                                                                   \
+  };                                                                    \
   }
 
 #define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, \
@@ -186,28 +210,29 @@
 #define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
   GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, data_type)
 
-#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type,     \
-                               input_data_type)                               \
-  extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(tf_op, mlir_type)( \
-      tensorflow::OpKernelContext * ctx,                                      \
-      const ::UnrankedMemRefType<input_data_type>* arg);                      \
-                                                                              \
-  namespace {                                                                 \
-  class MlirUnranked##tf_op##mlir_type##Op                                    \
-      : public MlirUnrankedOp<tf_data_type, data_type,                        \
-                              MlirUnranked##tf_op##mlir_type##Op,             \
-                              input_data_type> {                              \
-   public:                                                                    \
-    using MlirUnrankedOp::MlirUnrankedOp;                                     \
-                                                                              \
-    static ::UnrankedMemRefType<data_type> Invoke(                            \
-        OpKernelContext* ctx,                                                 \
-        llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) {         \
-      return MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0]);                  \
-    }                                                                         \
-  };                                                                          \
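+// Unary counterpart of GENERATE_BINARY_KERNEL2: a single unranked input,
+// with result and input element types specified separately.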
+#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type,          \
+                               result_data_type, input_data_type)       \
+  extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \
+      tensorflow::OpKernelContext * ctx,                                \
+      const ::UnrankedMemRefType<input_data_type>* arg);                \
+                                                                        \
+  namespace {                                                           \
+  class MlirUnranked##tf_op##mlir_type##Op                              \
+      : public MlirUnrankedOp<tf_data_type, result_data_type,           \
+                              MlirUnranked##tf_op##mlir_type##Op,       \
+                              input_data_type> {                        \
+   public:                                                              \
+    using MlirUnrankedOp::MlirUnrankedOp;                               \
+                                                                        \
+    static ::UnrankedMemRefType<result_data_type> Invoke(               \
+        OpKernelContext* ctx,                                           \
+        llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) {   \
+      return ConvertToTyped<result_data_type>(                          \
+          MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0]));              \
+    }                                                                   \
+  };                                                                    \
   }
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_UNRANKED_OP_GPU_ABS_H_
+#endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.cc
similarity index 68%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.cc
index 2dd4a8d..39b23c7 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.cc
@@ -13,13 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
 
 namespace tensorflow {
+namespace test {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+TensorShape DefaultInputShape() { return TensorShape{3, 4}; }
 
+}  // namespace test
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h b/tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h
new file mode 100644
index 0000000..00d288b
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h
@@ -0,0 +1,159 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/STLExtras.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace test {
+
+/// Helper functions to create or derive inputs of the right type and size.
+
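+///
+/// For example, InputAsVector<float, double>({0.5, 1.0}) produces an
+/// absl::InlinedVector<float, 10> holding {0.5f, 1.0f}.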
+template <typename T, typename LiteralT>
+absl::InlinedVector<T, 10> InputAsVector(
+    std::initializer_list<LiteralT> input) {
+  absl::InlinedVector<T, 10> result;
+  result.reserve(input.size());
+  for (const LiteralT& value : input) {
+    result.push_back(static_cast<T>(value));
+  }
+  return result;
+}
+
+template <typename T>
+absl::InlinedVector<T, 10> RepeatInputToMatchShape(
+    absl::InlinedVector<T, 10> input, int size) {
+  absl::InlinedVector<T, 10> result;
+  for (int i = 0; i < size; i++) {
+    auto value = input[i % input.size()];
+    result.push_back(value);
+  }
+  return result;
+}
+
+/// Helper functions to get default input shapes.
+
+TensorShape DefaultInputShape();
+
+/// Helper functions to get default input data.
+
+template <typename T,
+          std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
+                           bool> = true>
+T DefaultScalarInput() {
+  return static_cast<T>(3);
+}
+
+template <typename T, std::enable_if_t<
+                          llvm::is_one_of<T, Eigen::half, float, double>::value,
+                          bool> = true>
+T DefaultScalarInput() {
+  return static_cast<T>(2.0);
+}
+
+template <typename T,
+          std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
+T DefaultScalarInput() {
+  return static_cast<T>(true);
+}
+
+template <typename T,
+          std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
+                           bool> = true>
+absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
+  // Only generate values less than the bitwidth of the data type.
+  if (op_name == "LeftShift" || op_name == "RightShift") {
+    auto max_shift = sizeof(T) * 8 - 1;
+    absl::InlinedVector<T, 10> v;
+    v.reserve(max_shift);
+    for (size_t i = 0; i < max_shift; ++i) v.push_back(static_cast<T>(i));
+    return v;
+  }
+  return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
+}
+
+template <typename T, std::enable_if_t<
+                          llvm::is_one_of<T, Eigen::half, float, double>::value,
+                          bool> = true>
+absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
+  if (op_name == "FloorDiv")
+    return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
+                                     0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
+  return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
+                                   0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
+}
+
+template <typename T,
+          std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
+absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
+  return InputAsVector<T, bool>({true, false, true, true, false});
+}
+
+/// Helper functions to get more specific input data.
+
+template <typename T, std::enable_if_t<
+                          llvm::is_one_of<T, Eigen::half, float, double>::value,
+                          bool> = true>
+absl::InlinedVector<std::complex<T>, 10> DefaultComplexInput() {
+  auto input = test::DefaultInput<T>();
+  absl::InlinedVector<std::complex<T>, 10> complex_input;
+  for (T value : input) {
+    complex_input.emplace_back(value, -value);
+  }
+  return complex_input;
+}
+
+template <typename T, std::enable_if_t<
+                          llvm::is_one_of<T, Eigen::half, float, double>::value,
+                          bool> = true>
+absl::InlinedVector<T, 10> NearZeroAndExtremeInput() {
+  return InputAsVector<T, double>({-std::numeric_limits<double>::infinity(),
+                                   -0.1, -0.0, 0.0, 0.1,
+                                   std::numeric_limits<double>::infinity()});
+}
+
+template <typename T,
+          std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
+                           bool> = true>
+absl::InlinedVector<T, 10> NearZeroAndExtremeInput() {
+  return InputAsVector<T, T>({std::numeric_limits<T>::min(),
+                              std::numeric_limits<T>::min() + 1, -1, 0, 1,
+                              std::numeric_limits<T>::max()});
+}
+
+template <typename T, std::enable_if_t<
+                          llvm::is_one_of<T, Eigen::half, float, double>::value,
+                          bool> = true>
+absl::InlinedVector<T, 10> DefaultInputGreaterThanZero() {
+  return test::InputAsVector<T, double>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1,
+                                         0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
+}
+
+template <typename T, std::enable_if_t<
+                          llvm::is_one_of<T, Eigen::half, float, double>::value,
+                          bool> = true>
+absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
+  return test::InputAsVector<T, double>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1,
+                                         0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
+}
+
+}  // namespace test
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
index 40d58a6..a3308b5 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
@@ -44,23 +45,15 @@
     SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
   }
 
-  // 'T' is the input type, 'RT' is the input type for the callback function,
-  // 'OutT' is the output type, 'ROutT' is the output type for the callback
-  // function. In most cases it is enough to just provide the input type,
-  // because all the types are the same.
-  template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
-  void Run(std::vector<int64> input_shape, std::vector<T> input,
-           const std::string op_name, ROutT (*expected_callback)(RT),
-           bool expect_equal = true, bool add_tout = false,
-           bool expect_buffer_reuse = true) {
-    assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
-                           std::multiplies<int64>()) == input.size() &&
-           "Expected input length to equal to shape's number of elements.");
-
-    TensorShape shape(input_shape);
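+  // Builds the op node for `op_name`, initializes the kernel, and feeds
+  // `input` as a tensor of the given `shape`.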
+  template <typename T, typename OutT>
+  void SetOpKernel(const std::string& op_name, const TensorShape& shape,
+                   const absl::InlinedVector<T, 10>& input, bool add_t,
+                   bool add_tout) {
     NodeDefBuilder builder("some_name", op_name);
-    builder.Input(FakeInput(DataTypeToEnum<T>::v()))
-        .Attr("T", DataTypeToEnum<T>::v());
+    builder.Input(FakeInput(DataTypeToEnum<T>::v()));
+    if (add_t) {
+      builder.Attr("T", DataTypeToEnum<T>::v());
+    }
     if (add_tout) {
       builder.Attr("Tout", DataTypeToEnum<OutT>::v());
     }
@@ -68,6 +61,15 @@
 
     TF_ASSERT_OK(InitOp());
     AddInputFromArray<T>(shape, input);
+  }
+
+  template <typename T, typename OutT>
+  void RunAndExpectResult(const std::string& op_name, const TensorShape& shape,
+                          const absl::InlinedVector<T, 10>& input,
+                          const absl::InlinedVector<OutT, 10>& expected_output,
+                          bool add_t, bool add_tout, bool expect_buffer_reuse,
+                          bool expect_equal) {
+    SetOpKernel<T, OutT>(op_name, shape, input, add_t, add_tout);
     TF_ASSERT_OK(RunOpKernel());
 
     // Assert buffer reuse if expected.
@@ -79,13 +81,7 @@
 
     // Assert expected results.
     Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
-    absl::InlinedVector<OutT, 14> expected;
-    expected.reserve(input.size());
-    for (const T& inp : input) {
-      expected.push_back(
-          static_cast<OutT>(expected_callback(static_cast<RT>(inp))));
-    }
-    test::FillValues<OutT>(&expected_tensor, expected);
+    test::FillValues<OutT>(&expected_tensor, expected_output);
     if (expect_equal) {
       test::ExpectEqual(expected_tensor, *GetOutput(0));
     } else {
@@ -93,240 +89,225 @@
     }
   }
 
-  // Some helper functions to get default input values.
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  void Test(const std::string& op_name, const TensorShape& shape,
+            absl::InlinedVector<T, 10> input,
+            BaselineOutT (*baseline_callback)(BaselineT),
+            bool expect_equal = true, bool add_tout = false,
+            bool expect_buffer_reuse = true, bool add_t = true) {
+    // Prepare inputs and compute expected results.
+    auto repeated_input =
+        test::RepeatInputToMatchShape(input, shape.num_elements());
+    absl::InlinedVector<OutT, 10> expected_output =
+        ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
+            repeated_input, baseline_callback);
 
-  std::vector<int64> DefaultInputShape() { return std::vector<int64>{2, 7}; }
-
-  template <typename T>
-  std::vector<T> DefaultInput() {
-    return InputAsVector<T>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3,
-                             0.5, 0.7, 0.9, 9.0, 18.0});
-  }
-
-  template <typename T>
-  std::vector<std::complex<T>> DefaultComplexInput() {
-    auto input = DefaultInput<T>();
-    std::vector<std::complex<T>> complex_input;
-    for (T value : input) {
-      complex_input.emplace_back(value, -value);
-    }
-    return complex_input;
-  }
-
-  template <typename T>
-  std::vector<T> DefaultInputGreaterThanZero() {
-    return InputAsVector<T>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
-                             0.5, 0.7, 0.9, 9.0, 18.0});
-  }
-
-  template <typename T>
-  std::vector<T> DefaultInputGreaterOrEqualToZero() {
-    return InputAsVector<T>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
-                             0.5, 0.7, 0.9, 9.0, 18.0});
+    RunAndExpectResult<T, OutT>(op_name, shape, repeated_input, expected_output,
+                                add_t, add_tout, expect_buffer_reuse,
+                                expect_equal);
   }
 
  private:
-  template <typename T>
-  std::vector<T> InputAsVector(std::initializer_list<double> input) {
-    std::vector<T> result;
-    result.reserve(input.size());
-    for (const auto& value : input) {
-      result.push_back(static_cast<T>(value));
+  template <typename T, typename BaselineT, typename OutT,
+            typename BaselineOutT>
+  absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
+      absl::InlinedVector<T, 10> input,
+      BaselineOutT (*baseline_callback)(BaselineT)) {
+    absl::InlinedVector<OutT, 10> expected_output;
+    for (int i = 0; i < input.size(); i++) {
+      auto arg = static_cast<BaselineT>(input[i]);
+      auto result = static_cast<OutT>(baseline_callback(arg));
+      expected_output.push_back(result);
     }
-    return result;
+    return expected_output;
   }
 };
 
 /// Test `tf.Abs`.
 
 TEST_F(GpuUnaryOpTest, AbsFloat) {
-  Run<float>(
-      /*input_shape=*/{2, 3},
-      /*input=*/
-      {-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f, 0.1f,
-       std::numeric_limits<float>::infinity()},
-      /*op_name=*/"Abs",
-      /*expected_callback=*/std::abs,
+  Test<float, float, float, float>(
+      /*op_name=*/"Abs", test::DefaultInputShape(),
+      test::NearZeroAndExtremeInput<float>(),
+      /*baseline_callback=*/std::abs,
       /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, AbsDouble) {
-  Run<double>(
-      /*input_shape=*/{2, 3},
-      /*input=*/
-      {-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0, 0.1,
-       std::numeric_limits<double>::infinity()},
-      /*op_name=*/"Abs",
-      /*expected_callback=*/std::abs,
+  Test<double, double, double, double>(
+      /*op_name=*/"Abs", test::DefaultInputShape(),
+      test::NearZeroAndExtremeInput<double>(),
+      /*baseline_callback=*/std::abs,
       /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, AbsHalf) {
-  Run<Eigen::half, float>(
-      /*input_shape=*/{2, 3},
-      /*input=*/
-      {static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
-       static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
-       static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
-       static_cast<Eigen::half>(std::numeric_limits<double>::infinity())},
-      /*op_name=*/"Abs",
-      /*expected_callback=*/std::abs,
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Abs", test::DefaultInputShape(),
+      test::NearZeroAndExtremeInput<Eigen::half>(),
+      /*baseline_callback=*/std::abs,
       /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, AbsInt32) {
-  Run<int32>(
-      /*input_shape=*/{2, 3},
-      /*input=*/
-      {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::min() + 1,
-       -1, 0, 1, std::numeric_limits<int32>::max()},
-      /*op_name=*/"Abs",
-      /*expected_callback=*/std::abs,
+  Test<int32, int32, int32, int32>(
+      /*op_name=*/"Abs", test::DefaultInputShape(),
+      test::NearZeroAndExtremeInput<int32>(),
+      /*baseline_callback=*/std::abs,
       /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, AbsInt64) {
-  Run<int64>(
-      /*input_shape=*/{2, 3},
-      /*input=*/
-      {std::numeric_limits<int64>::min(), std::numeric_limits<int64>::min() + 1,
-       -1, 0, 1, std::numeric_limits<int64>::max()},
-      /*op_name=*/"Abs",
-      /*expected_callback=*/std::abs,
+  Test<int64, int64, int64, int64>(
+      /*op_name=*/"Abs", test::DefaultInputShape(),
+      test::NearZeroAndExtremeInput<int64>(),
+      /*baseline_callback=*/std::abs,
       /*expect_equal=*/true);
 }
 
 /// Test `tf.Ceil`.
 
 TEST_F(GpuUnaryOpTest, CeilFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Ceil",
-             /*expected_callback=*/std::ceil,
-             /*expect_equal=*/true);
+  Test<float, float, float, float>(
+      /*op_name=*/"Ceil", test::DefaultInputShape(),
+      test::DefaultInput<float>("Ceil"),
+      /*baseline_callback=*/std::ceil,
+      /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, CeilDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Ceil",
-              /*expected_callback=*/std::ceil,
-              /*expect_equal=*/true);
+  Test<double, double, double, double>(
+      /*op_name=*/"Ceil", test::DefaultInputShape(),
+      test::DefaultInput<double>(),
+      /*baseline_callback=*/std::ceil,
+      /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, CeilHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Ceil",
-                          /*expected_callback=*/std::ceil,
-                          /*expect_equal=*/true);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Ceil", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/std::ceil,
+      /*expect_equal=*/true);
 }
 
 /// Test `tf.Conj`.
 
 TEST_F(GpuUnaryOpTest, ConjFloat) {
-  Run<std::complex<float>, const std::complex<float>&, std::complex<float>,
-      std::complex<float>>(DefaultInputShape(), DefaultComplexInput<float>(),
-                           /*op_name=*/"Conj",
-                           /*expected_callback=*/std::conj,
-                           /*expect_equal=*/false,
-                           /*add_tout=*/false,
-                           /*expect_buffer_reuse=*/false);
-}
-
-TEST_F(GpuUnaryOpTest, ConjDouble) {
-  Run<std::complex<double>, const std::complex<double>&, std::complex<double>,
-      std::complex<double>>(DefaultInputShape(), DefaultComplexInput<double>(),
-                            /*op_name=*/"Conj",
-                            /*expected_callback=*/std::conj,
+  Test<std::complex<float>, const std::complex<float>&, std::complex<float>,
+       std::complex<float>>(/*op_name=*/"Conj", test::DefaultInputShape(),
+                            test::DefaultComplexInput<float>(),
+                            /*baseline_callback=*/std::conj,
                             /*expect_equal=*/false,
                             /*add_tout=*/false,
                             /*expect_buffer_reuse=*/false);
 }
 
+TEST_F(GpuUnaryOpTest, ConjDouble) {
+  Test<std::complex<double>, const std::complex<double>&, std::complex<double>,
+       std::complex<double>>(
+      /*op_name=*/"Conj", test::DefaultInputShape(),
+      test::DefaultComplexInput<double>(),
+      /*baseline_callback=*/std::conj,
+      /*expect_equal=*/false,
+      /*add_tout=*/false,
+      /*expect_buffer_reuse=*/false);
+}
+
 /// Test `tf.Cos`.
 
 TEST_F(GpuUnaryOpTest, CosFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Cos",
-             /*expected_callback=*/std::cos,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(
+      /*op_name=*/"Cos", test::DefaultInputShape(), test::DefaultInput<float>(),
+      /*baseline_callback=*/std::cos,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, CosDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Cos",
-              /*expected_callback=*/std::cos,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(/*op_name=*/"Cos",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<double>(),
+                                       /*baseline_callback=*/std::cos,
+                                       /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, CosHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Cos",
-                          /*expected_callback=*/std::cos,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Cos", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/std::cos,
+      /*expect_equal=*/false);
 }
 
 /// Test `tf.Exp`.
 
 TEST_F(GpuUnaryOpTest, ExpFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Exp",
-             /*expected_callback=*/std::exp,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(/*op_name=*/"Exp", test::DefaultInputShape(),
+                                   test::DefaultInput<float>(),
+                                   /*baseline_callback=*/std::exp,
+                                   /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, ExpDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Exp",
-              /*expected_callback=*/std::exp,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(/*op_name=*/"Exp",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<double>(),
+                                       /*baseline_callback=*/std::exp,
+                                       /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, ExpHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Exp",
-                          /*expected_callback=*/std::exp,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Exp", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/std::exp,
+      /*expect_equal=*/false);
 }
 
 /// Test `tf.Floor`.
 
 TEST_F(GpuUnaryOpTest, FloorFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Floor",
-             /*expected_callback=*/std::floor,
-             /*expect_equal=*/true);
+  Test<float, float, float, float>(/*op_name=*/"Floor",
+                                   test::DefaultInputShape(),
+                                   test::DefaultInput<float>(),
+                                   /*baseline_callback=*/std::floor,
+                                   /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, FloorDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Floor",
-              /*expected_callback=*/std::floor,
-              /*expect_equal=*/true);
+  Test<double, double, double, double>(/*op_name=*/"Floor",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<double>(),
+                                       /*baseline_callback=*/std::floor,
+                                       /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, FloorHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Floor",
-                          /*expected_callback=*/std::floor,
-                          /*expect_equal=*/true);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Floor", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/std::floor,
+      /*expect_equal=*/true);
 }
 
 /// Test `tf.Imag`.
 
 TEST_F(GpuUnaryOpTest, ImagFloat) {
-  Run<std::complex<float>, const std::complex<float>&, float, float>(
-      DefaultInputShape(), DefaultComplexInput<float>(),
-      /*op_name=*/"Imag",
-      /*expected_callback=*/std::imag,
+  Test<std::complex<float>, const std::complex<float>&, float, float>(
+      /*op_name=*/"Imag", test::DefaultInputShape(),
+      test::DefaultComplexInput<float>(),
+      /*baseline_callback=*/std::imag,
       /*expect_equal=*/false,
       /*add_tout=*/true,
       /*expect_buffer_reuse=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, ImagDouble) {
-  Run<std::complex<double>, const std::complex<double>&, double, double>(
-      DefaultInputShape(), DefaultComplexInput<double>(),
-      /*op_name=*/"Imag",
-      /*expected_callback=*/std::imag,
+  Test<std::complex<double>, const std::complex<double>&, double, double>(
+      /*op_name=*/"Imag", test::DefaultInputShape(),
+      test::DefaultComplexInput<double>(),
+      /*baseline_callback=*/std::imag,
       /*expect_equal=*/false,
       /*add_tout=*/true,
       /*expect_buffer_reuse=*/false);
@@ -335,103 +316,140 @@
 /// Test `tf.IsInf`.
 
 // TODO(b/162575339): The tests currently still fails with CUDA_ILLEGAL_ADDRESS
-// when run with unranked kernels.
+// when running with unranked kernels.
 TEST_F(GpuUnaryOpTest, DISABLED_IsInfFloat) {
-  Run<float, float, bool, bool>(DefaultInputShape(), DefaultInput<float>(),
-                                /*op_name=*/"IsInf",
-                                /*expected_callback=*/std::isinf,
-                                /*expect_equal=*/true);
+  Test<float, float, bool, bool>(/*op_name=*/"IsInf", test::DefaultInputShape(),
+                                 test::DefaultInput<float>(),
+                                 /*baseline_callback=*/std::isinf,
+                                 /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, DISABLED_IsInfDouble) {
   // Workaround for gcc bug, it would fail with "unresolved overloaded function
   // type" if passing std::isinf with type double. So we use type float for
   // comparing expected values.
-  Run<double, float, bool, bool>(DefaultInputShape(), DefaultInput<double>(),
-                                 /*op_name=*/"IsInf",
-                                 /*expected_callback=*/std::isinf,
-                                 /*expect_equal=*/true);
+  Test<double, float, bool, bool>(/*op_name=*/"IsInf",
+                                  test::DefaultInputShape(),
+                                  test::DefaultInput<double>(),
+                                  /*baseline_callback=*/std::isinf,
+                                  /*expect_equal=*/true);
 }
 
 TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
-  Run<Eigen::half, float, bool, bool>(DefaultInputShape(),
-                                      DefaultInput<Eigen::half>(),
-                                      /*op_name=*/"IsInf",
-                                      /*expected_callback=*/std::isinf,
-                                      /*expect_equal=*/true);
+  Test<Eigen::half, float, bool, bool>(/*op_name=*/"IsInf",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<Eigen::half>(),
+                                       /*baseline_callback=*/std::isinf,
+                                       /*expect_equal=*/true);
 }
 
 /// Test `tf.Log`.
 
 TEST_F(GpuUnaryOpTest, LogFloat) {
-  Run<float>(DefaultInputShape(), DefaultInputGreaterThanZero<float>(),
-             /*op_name=*/"Log",
-             /*expected_callback=*/std::log,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(/*op_name=*/"Log", test::DefaultInputShape(),
+                                   test::DefaultInputGreaterThanZero<float>(),
+                                   /*baseline_callback=*/std::log,
+                                   /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, LogDouble) {
-  Run<double>(DefaultInputShape(), DefaultInputGreaterThanZero<double>(),
-              /*op_name=*/"Log",
-              /*expected_callback=*/std::log,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(
+      /*op_name=*/"Log", test::DefaultInputShape(),
+      test::DefaultInputGreaterThanZero<double>(),
+      /*baseline_callback=*/std::log,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, LogHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(),
-                          /*input=*/
-                          DefaultInputGreaterThanZero<Eigen::half>(),
-                          /*op_name=*/"Log",
-                          /*expected_callback=*/std::log,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Log", test::DefaultInputShape(),
+      test::DefaultInputGreaterThanZero<Eigen::half>(),
+      /*baseline_callback=*/std::log,
+      /*expect_equal=*/false);
+}
+
+/// Test `tf.LogicalNot`.
+
+TEST_F(GpuUnaryOpTest, LogicalNot) {
+  Test<bool, bool, bool, bool>(
+      /*op_name=*/"LogicalNot", test::DefaultInputShape(),
+      test::DefaultInput<bool>(),
+      /*baseline_callback=*/[](bool v) { return !v; },
+      /*expect_equal=*/true,
+      /*add_tout=*/false,
+      /*expect_buffer_reuse=*/true,
+      /*add_t=*/false);
 }
 
 /// Test `tf.Neg`.
 
 /// Reference implementation.
 template <typename T>
-T expected_neg(T x) {
+T baseline_neg(T x) {
   return -x;
 }
 
 TEST_F(GpuUnaryOpTest, NegFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Neg",
-             /*expected_callback=*/expected_neg,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(
+      /*op_name=*/"Neg", test::DefaultInputShape(), test::DefaultInput<float>(),
+      /*baseline_callback=*/baseline_neg,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, NegDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Neg",
-              /*expected_callback=*/expected_neg,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(
+      /*op_name=*/"Neg", test::DefaultInputShape(),
+      test::DefaultInput<double>(),
+      /*baseline_callback=*/baseline_neg,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, NegHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Neg",
-                          /*expected_callback=*/expected_neg,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Neg", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/baseline_neg,
+      /*expect_equal=*/false);
+}
+
+TEST_F(GpuUnaryOpTest, NegInt8) {
+  Test<int8, int8, int8, int8>(
+      /*op_name=*/"Neg", test::DefaultInputShape(), test::DefaultInput<int8>(),
+      /*baseline_callback=*/baseline_neg,
+      /*expect_equal=*/true);
+}
+
+TEST_F(GpuUnaryOpTest, NegInt16) {
+  Test<int16, int16, int16, int16>(/*op_name=*/"Neg", test::DefaultInputShape(),
+                                   test::DefaultInput<int16>(),
+                                   /*baseline_callback=*/baseline_neg,
+                                   /*expect_equal=*/true);
+}
+
+TEST_F(GpuUnaryOpTest, NegInt64) {
+  Test<int64, int64, int64, int64>(/*op_name=*/"Neg", test::DefaultInputShape(),
+                                   test::DefaultInput<int64>(),
+                                   /*baseline_callback=*/baseline_neg,
+                                   /*expect_equal=*/true);
 }
 
 /// Test `tf.Real`.
 
 TEST_F(GpuUnaryOpTest, RealFloat) {
-  Run<std::complex<float>, const std::complex<float>&, float, float>(
-      DefaultInputShape(), DefaultComplexInput<float>(),
-      /*op_name=*/"Real",
-      /*expected_callback=*/std::real,
+  Test<std::complex<float>, const std::complex<float>&, float, float>(
+      /*op_name=*/"Real", test::DefaultInputShape(),
+      test::DefaultComplexInput<float>(),
+      /*baseline_callback=*/std::real,
       /*expect_equal=*/false,
       /*add_tout=*/true,
       /*expect_buffer_reuse=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, RealDouble) {
-  Run<std::complex<double>, const std::complex<double>&, double, double>(
-      DefaultInputShape(), DefaultComplexInput<double>(),
-      /*op_name=*/"Real",
-      /*expected_callback=*/std::real,
+  Test<std::complex<double>, const std::complex<double>&, double, double>(
+      /*op_name=*/"Real", test::DefaultInputShape(),
+      test::DefaultComplexInput<double>(),
+      /*baseline_callback=*/std::real,
       /*expect_equal=*/false,
       /*add_tout=*/true,
       /*expect_buffer_reuse=*/false);
@@ -441,134 +459,153 @@
 
 /// Reference implementation.
 template <typename T>
-T expected_rsqrt(T x) {
+T baseline_rsqrt(T x) {
   return 1.0 / std::sqrt(x);
 }
 
 TEST_F(GpuUnaryOpTest, RsqrtFloat) {
-  Run<float>(DefaultInputShape(), DefaultInputGreaterThanZero<float>(),
-             /*op_name=*/"Rsqrt",
-             /*expected_callback=*/expected_rsqrt,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(/*op_name=*/"Rsqrt",
+                                   test::DefaultInputShape(),
+                                   test::DefaultInputGreaterThanZero<float>(),
+                                   /*baseline_callback=*/baseline_rsqrt,
+                                   /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, RsqrtDouble) {
-  Run<double>(DefaultInputShape(), DefaultInputGreaterThanZero<double>(),
-              /*op_name=*/"Rsqrt",
-              /*expected_callback=*/expected_rsqrt,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(
+      /*op_name=*/"Rsqrt", test::DefaultInputShape(),
+      test::DefaultInputGreaterThanZero<double>(),
+      /*baseline_callback=*/baseline_rsqrt,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, RsqrtHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(),
-                          /*input=*/
-                          DefaultInputGreaterThanZero<Eigen::half>(),
-                          /*op_name=*/"Rsqrt",
-                          /*expected_callback=*/expected_rsqrt,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Rsqrt", test::DefaultInputShape(),
+      test::DefaultInputGreaterThanZero<Eigen::half>(),
+      /*baseline_callback=*/baseline_rsqrt,
+      /*expect_equal=*/false);
 }
 
 /// Test `tf.Sign`.
 
 // Reference implementation
 template <typename T>
-T expected_sign(T x) {
+T baseline_sign(T x) {
   if (x == 0) return 0;
   if (x < 0) return -1;
   return 1;
 }
 
-// TODO(b/162577610): Enable these tests when our generated kernels handle 0.0
-// and -0.0 correctly.
-TEST_F(GpuUnaryOpTest, DISABLED_SignFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Sign",
-             /*expected_callback=*/expected_sign,
-             /*expect_equal=*/true);
+TEST_F(GpuUnaryOpTest, SignFloat) {
+  Test<float, float, float, float>(/*op_name=*/"Sign",
+                                   test::DefaultInputShape(),
+                                   test::DefaultInput<float>(),
+                                   /*baseline_callback=*/baseline_sign,
+                                   /*expect_equal=*/true);
 }
 
-TEST_F(GpuUnaryOpTest, DISABLED_SignDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Sign",
-              /*expected_callback=*/expected_sign,
-              /*expect_equal=*/true);
+TEST_F(GpuUnaryOpTest, SignDouble) {
+  Test<double, double, double, double>(/*op_name=*/"Sign",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<double>(),
+                                       /*baseline_callback=*/baseline_sign,
+                                       /*expect_equal=*/true);
 }
 
-TEST_F(GpuUnaryOpTest, DISABLED_SignHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Sign",
-                          /*expected_callback=*/expected_sign,
-                          /*expect_equal=*/true);
+TEST_F(GpuUnaryOpTest, SignHalf) {
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Sign", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/baseline_sign,
+      // TODO(b/162577610): We should actually use true
+      // here. This requires returning 0.0 for input -0.0.
+      /*expect_equal=*/false);
+}
+
+TEST_F(GpuUnaryOpTest, SignInt64) {
+  Test<int64, int64, int64, int64>(
+      /*op_name=*/"Sign", test::DefaultInputShape(),
+      test::DefaultInput<int64>(),
+      /*baseline_callback=*/baseline_sign,
+      /*expect_equal=*/true);
 }
 
 /// Test `tf.Sin`.
 
 TEST_F(GpuUnaryOpTest, SinFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Sin",
-             /*expected_callback=*/std::sin,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(/*op_name=*/"Sin", test::DefaultInputShape(),
+                                   test::DefaultInput<float>(),
+                                   /*baseline_callback=*/std::sin,
+                                   /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, SinDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Sin",
-              /*expected_callback=*/std::sin,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(/*op_name=*/"Sin",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<double>(),
+                                       /*baseline_callback=*/std::sin,
+                                       /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, SinHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Sin",
-                          /*expected_callback=*/std::sin,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Sin", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/std::sin,
+      /*expect_equal=*/false);
 }
 
 /// Test `tf.Sqrt`.
 
 TEST_F(GpuUnaryOpTest, SqrtFloat) {
-  Run<float>(DefaultInputShape(), DefaultInputGreaterOrEqualToZero<float>(),
-             /*op_name=*/"Sqrt",
-             /*expected_callback=*/std::sqrt,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(
+      /*op_name=*/"Sqrt", test::DefaultInputShape(),
+      test::DefaultInputGreaterOrEqualToZero<float>(),
+      /*baseline_callback=*/std::sqrt,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, SqrtDouble) {
-  Run<double>(DefaultInputShape(), DefaultInputGreaterOrEqualToZero<double>(),
-              /*op_name=*/"Sqrt",
-              /*expected_callback=*/std::sqrt,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(
+      /*op_name=*/"Sqrt", test::DefaultInputShape(),
+      test::DefaultInputGreaterOrEqualToZero<double>(),
+      /*baseline_callback=*/std::sqrt,
+      /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, SqrtHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(),
-                          DefaultInputGreaterOrEqualToZero<Eigen::half>(),
-                          /*op_name=*/"Sqrt",
-                          /*expected_callback=*/std::sqrt,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Sqrt", test::DefaultInputShape(),
+      test::DefaultInputGreaterOrEqualToZero<Eigen::half>(),
+      /*baseline_callback=*/std::sqrt,
+      /*expect_equal=*/false);
 }
 
 /// Test `tf.Tanh`.
 
 TEST_F(GpuUnaryOpTest, TanhFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Tanh",
-             /*expected_callback=*/std::tanh,
-             /*expect_equal=*/false);
+  Test<float, float, float, float>(/*op_name=*/"Tanh",
+                                   test::DefaultInputShape(),
+                                   test::DefaultInput<float>(),
+                                   /*baseline_callback=*/std::tanh,
+                                   /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, TanhDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Tanh",
-              /*expected_callback=*/std::tanh,
-              /*expect_equal=*/false);
+  Test<double, double, double, double>(/*op_name=*/"Tanh",
+                                       test::DefaultInputShape(),
+                                       test::DefaultInput<double>(),
+                                       /*baseline_callback=*/std::tanh,
+                                       /*expect_equal=*/false);
 }
 
 TEST_F(GpuUnaryOpTest, TanhHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Tanh",
-                          /*expected_callback=*/std::tanh,
-                          /*expect_equal=*/false);
+  Test<Eigen::half, float, Eigen::half, float>(
+      /*op_name=*/"Tanh", test::DefaultInputShape(),
+      test::DefaultInput<Eigen::half>(),
+      /*baseline_callback=*/std::tanh,
+      /*expect_equal=*/false);
 }
 
 }  // namespace
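
With the refactored harness above, adding coverage for a further unary op only requires choosing the four template types (input, baseline input, output, baseline output) and a baseline callback. The sketch below is hypothetical and not part of this change: the op name "Tan" is used purely as an illustration, and this diff does not imply that a generated GPU kernel for it exists.

// Hypothetical example of a test written against the new Test<> harness.
// "Tan" is an illustrative op name only.
TEST_F(GpuUnaryOpTest, TanFloat) {
  Test<float, float, float, float>(
      /*op_name=*/"Tan", test::DefaultInputShape(),
      test::DefaultInput<float>(),
      /*baseline_callback=*/std::tan,
      /*expect_equal=*/false);
}
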
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl
new file mode 100644
index 0000000..3d12a78
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl
@@ -0,0 +1,6 @@
+func @BitwiseAnd_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.BitwiseAnd"(%arg0, %arg1)
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl
new file mode 100644
index 0000000..d3f4bf7
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl
@@ -0,0 +1,6 @@
+func @BitwiseOr_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.BitwiseOr"(%arg0, %arg1)
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl
new file mode 100644
index 0000000..8aaf9e3
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl
@@ -0,0 +1,6 @@
+func @BitwiseXor_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.BitwiseXor"(%arg0, %arg1)
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl
new file mode 100644
index 0000000..3431893
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl
@@ -0,0 +1,6 @@
+func @Complex_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xcomplex<elem_type>> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Complex"(%arg0, %arg1) {}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xcomplex<elem_type>>
+  return %0 : tensor<*xcomplex<elem_type>>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl
new file mode 100644
index 0000000..be4841d
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl
@@ -0,0 +1,6 @@
+func @FloorDiv_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.FloorDiv"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl
new file mode 100644
index 0000000..47010ee
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl
@@ -0,0 +1,6 @@
+func @Greater_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Greater"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl
new file mode 100644
index 0000000..63c0ce9
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl
@@ -0,0 +1,6 @@
+func @GreaterEqual_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.GreaterEqual"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl
new file mode 100644
index 0000000..459378e
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl
@@ -0,0 +1,6 @@
+func @LeftShift_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.LeftShift"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl
new file mode 100644
index 0000000..59496dc
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl
@@ -0,0 +1,6 @@
+func @Less_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Less"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl
new file mode 100644
index 0000000..245f27a
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl
@@ -0,0 +1,6 @@
+func @LessEqual_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.LessEqual"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl
new file mode 100644
index 0000000..8e9d86c
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl
@@ -0,0 +1,6 @@
+func @LogicalAnd_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.LogicalAnd"(%arg0, %arg1)
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl
new file mode 100644
index 0000000..9888890
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl
@@ -0,0 +1,6 @@
+func @LogicalOr_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.LogicalOr"(%arg0, %arg1)
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl
new file mode 100644
index 0000000..c917b9a
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl
@@ -0,0 +1,6 @@
+func @Maximum_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Maximum"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl
new file mode 100644
index 0000000..6d8987b
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl
@@ -0,0 +1,6 @@
+func @Minimum_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Minimum"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl
new file mode 100644
index 0000000..c1903c2
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl
@@ -0,0 +1,6 @@
+func @Mul_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Mul"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl
new file mode 100644
index 0000000..8efef8b
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl
@@ -0,0 +1,6 @@
+func @NotEqual_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.NotEqual"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl
new file mode 100644
index 0000000..5f30b81
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl
@@ -0,0 +1,6 @@
+func @RightShift_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.RightShift"(%arg0, %arg1) {T = elem_type, device = ""}
+    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
diff --git a/tensorflow/core/kernels/risc/experimental/BUILD b/tensorflow/core/kernels/risc/experimental/BUILD
index 86e5e7c..4d9b3bd 100644
--- a/tensorflow/core/kernels/risc/experimental/BUILD
+++ b/tensorflow/core/kernels/risc/experimental/BUILD
@@ -18,6 +18,36 @@
 )
 
 tf_kernel_library(
+    name = "risc_binary_arithmetic_op",
+    srcs = ["risc_binary_arithmetic_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_binary_comparison_op",
+    srcs = ["risc_binary_comparison_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_bitcast_op",
+    srcs = ["risc_bitcast_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "risc_broadcast_op",
     srcs = ["risc_broadcast_op.cc"],
     deps = [
@@ -28,6 +58,26 @@
 )
 
 tf_kernel_library(
+    name = "risc_cast_op",
+    srcs = ["risc_cast_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_cholesky_op",
+    srcs = ["risc_cholesky_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "risc_concat_op",
     srcs = ["risc_concat_op.cc"],
     deps = [
@@ -38,6 +88,16 @@
 )
 
 tf_kernel_library(
+    name = "risc_condition_op",
+    srcs = ["risc_condition_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "risc_conv_op",
     srcs = ["risc_conv_op.cc"],
     deps = [
@@ -58,6 +118,66 @@
 )
 
 tf_kernel_library(
+    name = "risc_fft_op",
+    srcs = ["risc_fft_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_gather_op",
+    srcs = ["risc_gather_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_is_finite_op",
+    srcs = ["risc_is_finite_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_logical_and_op",
+    srcs = ["risc_logical_and_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_logical_not_op",
+    srcs = ["risc_logical_not_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_logical_or_op",
+    srcs = ["risc_logical_or_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "risc_max_op",
     srcs = ["risc_max_op.cc"],
     deps = [
@@ -88,6 +208,26 @@
 )
 
 tf_kernel_library(
+    name = "risc_random_uniform_op",
+    srcs = ["risc_random_uniform_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_reduce_op",
+    srcs = ["risc_reduce_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "risc_reshape_op",
     srcs = ["risc_reshape_op.cc"],
     deps = [
@@ -98,6 +238,26 @@
 )
 
 tf_kernel_library(
+    name = "risc_reverse_op",
+    srcs = ["risc_reverse_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_scatter_op",
+    srcs = ["risc_scatter_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "risc_shape_op",
     srcs = ["risc_shape_op.cc"],
     deps = [
@@ -118,17 +278,99 @@
 )
 
 tf_kernel_library(
+    name = "risc_sort_op",
+    srcs = ["risc_sort_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_squeeze_op",
+    srcs = ["risc_squeeze_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_transpose_op",
+    srcs = ["risc_transpose_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_triangular_solve_op",
+    srcs = ["risc_triangular_solve_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_unary_op",
+    srcs = ["risc_unary_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "risc_while_op",
+    srcs = ["risc_while_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "experimental",
     deps = [
         ":risc_add_op",
+        ":risc_binary_arithmetic_op",
+        ":risc_binary_comparison_op",
+        ":risc_bitcast_op",
         ":risc_broadcast_op",
+        ":risc_cast_op",
+        ":risc_cholesky_op",
+        ":risc_condition_op",
         ":risc_conv_op",
         ":risc_dot_op",
+        ":risc_fft_op",
+        ":risc_gather_op",
+        ":risc_is_finite_op",
+        ":risc_logical_and_op",
+        ":risc_logical_not_op",
+        ":risc_logical_or_op",
         ":risc_max_op",
         ":risc_pad_op",
         ":risc_pool_op",
+        ":risc_random_uniform_op",
+        ":risc_reduce_op",
         ":risc_reshape_op",
+        ":risc_reverse_op",
+        ":risc_scatter_op",
         ":risc_shape_op",
         ":risc_slice_op",
+        ":risc_sort_op",
+        ":risc_squeeze_op",
+        ":risc_transpose_op",
+        ":risc_triangular_solve_op",
+        ":risc_unary_op",
+        ":risc_while_op",
     ],
 )
diff --git a/tensorflow/core/kernels/risc/experimental/risc_binary_arithmetic_op.cc b/tensorflow/core/kernels/risc/experimental/risc_binary_arithmetic_op.cc
new file mode 100644
index 0000000..59da954
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_binary_arithmetic_op.cc
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+template <typename T>
+class RiscBinaryArithmeticOp : public OpKernel {
+ public:
+  explicit RiscBinaryArithmeticOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscBinaryArithmetic op.
+  }
+};
+
+#define REGISTER_CPU(T)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("RiscBinaryArithmetic").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      RiscBinaryArithmeticOp<T>);
+
+REGISTER_CPU(bfloat16);
+REGISTER_CPU(Eigen::half);
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_binary_comparison_op.cc b/tensorflow/core/kernels/risc/experimental/risc_binary_comparison_op.cc
new file mode 100644
index 0000000..a614a22
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_binary_comparison_op.cc
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+template <typename T>
+class RiscBinaryComparisonOp : public OpKernel {
+ public:
+  explicit RiscBinaryComparisonOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscBinaryComparison op.
+  }
+};
+
+#define REGISTER_CPU(T)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("RiscBinaryComparison").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      RiscBinaryComparisonOp<T>);
+
+REGISTER_CPU(bfloat16);
+REGISTER_CPU(Eigen::half);
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_bitcast_op.cc b/tensorflow/core/kernels/risc/experimental/risc_bitcast_op.cc
new file mode 100644
index 0000000..d7144dd
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_bitcast_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscBitcastOp : public OpKernel {
+ public:
+  explicit RiscBitcastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscBitcast op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscBitcast").Device(DEVICE_CPU), RiscBitcastOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_cast_op.cc b/tensorflow/core/kernels/risc/experimental/risc_cast_op.cc
new file mode 100644
index 0000000..bfbaa66
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_cast_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscCastOp : public OpKernel {
+ public:
+  explicit RiscCastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscCast op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscCast").Device(DEVICE_CPU), RiscCastOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_cholesky_op.cc b/tensorflow/core/kernels/risc/experimental/risc_cholesky_op.cc
new file mode 100644
index 0000000..05b2497
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_cholesky_op.cc
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+template <typename T>
+class RiscCholeskyOp : public OpKernel {
+ public:
+  explicit RiscCholeskyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscCholesky op.
+  }
+};
+
+#define REGISTER_CPU(T)                                               \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("RiscCholesky").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      RiscCholeskyOp<T>);
+
+REGISTER_CPU(bfloat16);
+REGISTER_CPU(Eigen::half);
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_condition_op.cc b/tensorflow/core/kernels/risc/experimental/risc_condition_op.cc
new file mode 100644
index 0000000..e76b217
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_condition_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscConditionOp : public OpKernel {
+ public:
+  explicit RiscConditionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscCondition op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscCondition").Device(DEVICE_CPU),
+                        RiscConditionOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_fft_op.cc b/tensorflow/core/kernels/risc/experimental/risc_fft_op.cc
new file mode 100644
index 0000000..d21aa20
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_fft_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscFftOp : public OpKernel {
+ public:
+  explicit RiscFftOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscFft op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscFft").Device(DEVICE_CPU), RiscFftOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_gather_op.cc b/tensorflow/core/kernels/risc/experimental/risc_gather_op.cc
new file mode 100644
index 0000000..424733f
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_gather_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscGatherOp : public OpKernel {
+ public:
+  explicit RiscGatherOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscGather op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscGather").Device(DEVICE_CPU), RiscGatherOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_is_finite_op.cc b/tensorflow/core/kernels/risc/experimental/risc_is_finite_op.cc
new file mode 100644
index 0000000..f223cb2
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_is_finite_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscIsFiniteOp : public OpKernel {
+ public:
+  explicit RiscIsFiniteOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscIsFinite op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscIsFinite").Device(DEVICE_CPU),
+                        RiscIsFiniteOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_logical_and_op.cc b/tensorflow/core/kernels/risc/experimental/risc_logical_and_op.cc
new file mode 100644
index 0000000..7c2f2cb
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_logical_and_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscLogicalAndOp : public OpKernel {
+ public:
+  explicit RiscLogicalAndOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscLogicalAnd op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscLogicalAnd").Device(DEVICE_CPU),
+                        RiscLogicalAndOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_logical_not_op.cc b/tensorflow/core/kernels/risc/experimental/risc_logical_not_op.cc
new file mode 100644
index 0000000..2f96c2b
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_logical_not_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscLogicalNotOp : public OpKernel {
+ public:
+  explicit RiscLogicalNotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscLogicalNot op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscLogicalNot").Device(DEVICE_CPU),
+                        RiscLogicalNotOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_logical_or_op.cc b/tensorflow/core/kernels/risc/experimental/risc_logical_or_op.cc
new file mode 100644
index 0000000..1e9ac0a
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_logical_or_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscLogicalOrOp : public OpKernel {
+ public:
+  explicit RiscLogicalOrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscLogicalOr op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscLogicalOr").Device(DEVICE_CPU),
+                        RiscLogicalOrOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_random_uniform_op.cc b/tensorflow/core/kernels/risc/experimental/risc_random_uniform_op.cc
new file mode 100644
index 0000000..8c326a4
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_random_uniform_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscRandomUniformOp : public OpKernel {
+ public:
+  explicit RiscRandomUniformOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscRandomUniform op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscRandomUniform").Device(DEVICE_CPU),
+                        RiscRandomUniformOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_reduce_op.cc b/tensorflow/core/kernels/risc/experimental/risc_reduce_op.cc
new file mode 100644
index 0000000..2a5cbc8
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_reduce_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscReduceOp : public OpKernel {
+ public:
+  explicit RiscReduceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscReduce op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscReduce").Device(DEVICE_CPU), RiscReduceOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_reverse_op.cc b/tensorflow/core/kernels/risc/experimental/risc_reverse_op.cc
new file mode 100644
index 0000000..815bc43
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_reverse_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscReverseOp : public OpKernel {
+ public:
+  explicit RiscReverseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscReverse op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscReverse").Device(DEVICE_CPU), RiscReverseOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_scatter_op.cc b/tensorflow/core/kernels/risc/experimental/risc_scatter_op.cc
new file mode 100644
index 0000000..55e6d18
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_scatter_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscScatterOp : public OpKernel {
+ public:
+  explicit RiscScatterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscScatter op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscScatter").Device(DEVICE_CPU), RiscScatterOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_sort_op.cc b/tensorflow/core/kernels/risc/experimental/risc_sort_op.cc
new file mode 100644
index 0000000..698ca2d
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_sort_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscSortOp : public OpKernel {
+ public:
+  explicit RiscSortOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscSort op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscSort").Device(DEVICE_CPU), RiscSortOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_squeeze_op.cc b/tensorflow/core/kernels/risc/experimental/risc_squeeze_op.cc
new file mode 100644
index 0000000..aca072d
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_squeeze_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscSqueezeOp : public OpKernel {
+ public:
+  explicit RiscSqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscSqueeze op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscSqueeze").Device(DEVICE_CPU), RiscSqueezeOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_transpose_op.cc b/tensorflow/core/kernels/risc/experimental/risc_transpose_op.cc
new file mode 100644
index 0000000..58e0e4a
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_transpose_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscTransposeOp : public OpKernel {
+ public:
+  explicit RiscTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscTranspose op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscTranspose").Device(DEVICE_CPU),
+                        RiscTransposeOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_triangular_solve_op.cc b/tensorflow/core/kernels/risc/experimental/risc_triangular_solve_op.cc
new file mode 100644
index 0000000..d6e0be8d
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_triangular_solve_op.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscTriangularSolveOp : public OpKernel {
+ public:
+  explicit RiscTriangularSolveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscTriangularSolve op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscTriangularSolve").Device(DEVICE_CPU),
+                        RiscTriangularSolveOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_unary_op.cc b/tensorflow/core/kernels/risc/experimental/risc_unary_op.cc
new file mode 100644
index 0000000..499686f
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_unary_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscUnaryOp : public OpKernel {
+ public:
+  explicit RiscUnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscUnary op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscUnary").Device(DEVICE_CPU), RiscUnaryOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/risc/experimental/risc_while_op.cc b/tensorflow/core/kernels/risc/experimental/risc_while_op.cc
new file mode 100644
index 0000000..165a41f
--- /dev/null
+++ b/tensorflow/core/kernels/risc/experimental/risc_while_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace risc {
+namespace experimental {
+
+class RiscWhileOp : public OpKernel {
+ public:
+  explicit RiscWhileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // TODO(b/171294012): Implement RiscWhile op.
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RiscWhile").Device(DEVICE_CPU), RiscWhileOp);
+
+}  // namespace experimental
+}  // namespace risc
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
index 6d9201d..90c38d1 100644
--- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
@@ -15,6 +15,8 @@
 
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/kernels/sparse_fill_empty_rows_op.h"
+
 #include <algorithm>
 #include <numeric>
 #include <unordered_map>
@@ -33,7 +35,157 @@
 
 using CPUDevice = Eigen::ThreadPoolDevice;
 
-template <typename T>
+namespace functor {
+
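+// CPU specialization of the SparseFillEmptyRows functor declared in
+// sparse_fill_empty_rows_op.h.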
+template <typename T, typename Tindex>
+struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
+  Status operator()(OpKernelContext* context, const Tensor& default_value_t,
+                    const Tensor& indices_t, const Tensor& values_t,
+                    const Tensor& dense_shape_t) {
+    const int kOutputIndicesOutput = 0;
+    const int kOutputValuesOutput = 1;
+    const int kEmptyRowIndicatorOutput = 2;
+    const int kReverseIndexMapOutput = 3;
+
+    const T& default_value = default_value_t.scalar<T>()();
+    const auto indices = indices_t.matrix<Tindex>();
+    const auto values = values_t.vec<T>();
+    const auto dense_shape = dense_shape_t.vec<Tindex>();
+
+    const Tindex N = indices_t.shape().dim_size(0);
+    const Tindex dense_rows = dense_shape(0);
+
+    bool* empty_row_indicator = nullptr;
+    if (context->output_required(kEmptyRowIndicatorOutput)) {
+      Tensor* empty_row_indicator_t = nullptr;
+      TF_RETURN_IF_ERROR(context->allocate_output(kEmptyRowIndicatorOutput,
+                                                  TensorShape({dense_rows}),
+                                                  &empty_row_indicator_t));
+      empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
+    }
+    Tindex* reverse_index_map = nullptr;
+    if (context->output_required(kReverseIndexMapOutput)) {
+      Tensor* reverse_index_map_t = nullptr;
+      TF_RETURN_IF_ERROR(context->allocate_output(
+          kReverseIndexMapOutput, TensorShape({N}), &reverse_index_map_t));
+      reverse_index_map = reverse_index_map_t->vec<Tindex>().data();
+    }
+
+    int rank = indices_t.shape().dim_size(1);
+
+    if (dense_rows == 0) {
+      if (N != 0) {
+        return errors::InvalidArgument(
+            "Received SparseTensor with dense_shape[0] = 0 but "
+            "indices.shape[0] = ",
+            N);
+      }
+      Tensor* output_indices_t;
+      TensorShape output_indices_shape({0, rank});
+      TF_RETURN_IF_ERROR(context->allocate_output(
+          kOutputIndicesOutput, output_indices_shape, &output_indices_t));
+      Tensor* output_values_t;
+      TF_RETURN_IF_ERROR(context->allocate_output(
+          kOutputValuesOutput, TensorShape({0}), &output_values_t));
+
+      // Exit early, nothing more to do.
+      return Status::OK();
+    }
+
+    bool rows_are_ordered = true;
+    Tindex last_indices_row = 0;
+    std::vector<Tindex> csr_offset(dense_rows, 0);
+    for (int i = 0; i < N; ++i) {
+      const Tindex row = indices(i, 0);
+      if (row < 0 || row >= dense_rows) {
+        return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row,
+                                       " >= ", dense_rows);
+      }
+      ++csr_offset[row];
+      rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
+      last_indices_row = row;
+    }
+    bool all_rows_full = true;
+    for (int row = 0; row < dense_rows; ++row) {
+      // csr_offset here describes the number of elements in this dense row
+      bool row_empty = (csr_offset[row] == 0);
+      if (empty_row_indicator) {
+        empty_row_indicator[row] = row_empty;
+      }
+      all_rows_full = all_rows_full & !row_empty;
+      // In filled version, each row has at least one element.
+      csr_offset[row] = std::max(csr_offset[row], Tindex{1});
+      // Update csr_offset to represent the number of elements up to and
+      // including dense_row + 1:
+      //  csr_offset(0) == #{elements of row 0}
+      //  csr_offset(1) == #{elements of row 1} + #{elements of row 0}
+      //  ..
+      //  csr_offset(i) == starting index for elements in row i + 1.
+      if (row > 0) {
+        csr_offset[row] += csr_offset[row - 1];
+      }
+    }
+
+    if (all_rows_full && rows_are_ordered) {
+      context->set_output(kOutputIndicesOutput, indices_t);
+      context->set_output(kOutputValuesOutput, values_t);
+      if (reverse_index_map) {
+        for (Tindex i = 0; i < N; ++i) {
+          reverse_index_map[i] = i;
+        }
+      }
+    } else {
+      Tensor* output_indices_t;
+      const Tindex N_full = csr_offset[dense_rows - 1];
+      TensorShape output_indices_shape({N_full, rank});
+      TF_RETURN_IF_ERROR(context->allocate_output(
+          kOutputIndicesOutput, output_indices_shape, &output_indices_t));
+      auto output_indices = output_indices_t->matrix<Tindex>();
+
+      Tensor* output_values_t;
+      TF_RETURN_IF_ERROR(context->allocate_output(
+          kOutputValuesOutput, TensorShape({N_full}), &output_values_t));
+      auto output_values = output_values_t->vec<T>();
+
+      std::vector<Tindex> filled_count(dense_rows, 0);
+
+      // Fill in values for rows that are not missing
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex row = indices(i, 0);
+        Tindex& offset = filled_count[row];
+        const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset;
+        offset++;  // Increment the filled count for this row.
+        std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
+        output_values(output_i) = values(i);
+        // We'll need this reverse index map to backprop correctly.
+        if (reverse_index_map) {
+          reverse_index_map[i] = output_i;
+        }
+      }
+
+      // Fill in values for rows that are missing
+      for (Tindex row = 0; row < dense_rows; ++row) {
+        const Tindex row_count = filled_count[row];
+        if (row_count == 0) {  // We haven't filled this row
+          const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1];
+          // Remaining index values were set to zero already.
+          // Just need to set the row index in the right location.
+          output_indices(starting_index, 0) = row;
+          for (Tindex col = 1; col < rank; ++col) {
+            output_indices(starting_index, col) = 0;
+          }
+          output_values(starting_index) = default_value;
+        }
+      }
+    }
+
+    return Status::OK();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T, typename Tindex>
 class SparseFillEmptyRowsOp : public OpKernel {
  public:
   explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
@@ -45,11 +197,6 @@
     const int kDenseShapeInput = 2;
     const int kDefaultValueInput = 3;
 
-    const int kOutputIndicesOutput = 0;
-    const int kOutputValuesOutput = 1;
-    const int kEmptyRowIndicatorOutput = 2;
-    const int kReverseIndexMapOutput = 3;
-
     const Tensor& indices_t = context->input(kIndicesInput);
     const Tensor& values_t = context->input(kValuesInput);
     const Tensor& dense_shape_t = context->input(kDenseShapeInput);
@@ -70,154 +217,75 @@
     // TODO(ebrevdo): add shape checks between values, indices,
     // dense_shape.  Also add check that dense rank > 0.
 
-    const T& default_value = default_value_t.scalar<T>()();
-    const auto indices = indices_t.matrix<int64>();
-    const auto values = values_t.vec<T>();
-    const auto dense_shape = dense_shape_t.vec<int64>();
-
-    const int64 N = indices_t.shape().dim_size(0);
-    const int64 dense_rows = dense_shape(0);
-
-    bool* empty_row_indicator = nullptr;
-    if (context->output_required(kEmptyRowIndicatorOutput)) {
-      Tensor* empty_row_indicator_t = nullptr;
-      OP_REQUIRES_OK(context,
-                     context->allocate_output(kEmptyRowIndicatorOutput,
-                                              TensorShape({dense_rows}),
-                                              &empty_row_indicator_t));
-      empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
-    }
-    int64* reverse_index_map = nullptr;
-    if (context->output_required(kReverseIndexMapOutput)) {
-      Tensor* reverse_index_map_t = nullptr;
-      OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput,
-                                                       TensorShape({N}),
-                                                       &reverse_index_map_t));
-      reverse_index_map = reverse_index_map_t->vec<int64>().data();
-    }
-
-    int rank = indices_t.shape().dim_size(1);
-
-    if (dense_rows == 0) {
-      OP_REQUIRES(
-          context, N == 0,
-          errors::InvalidArgument("Received SparseTensor with dense_shape[0] = "
-                                  "0 but indices.shape[0] = ",
-                                  N));
-      Tensor* output_indices_t;
-      TensorShape output_indices_shape({0, rank});
-      OP_REQUIRES_OK(context, context->allocate_output(kOutputIndicesOutput,
-                                                       output_indices_shape,
-                                                       &output_indices_t));
-      Tensor* output_values_t;
-      OP_REQUIRES_OK(context, context->allocate_output(kOutputValuesOutput,
-                                                       TensorShape({0}),
-                                                       &output_values_t));
-
-      // Exit early, nothing more to do.
-      return;
-    }
-
-    bool rows_are_ordered = true;
-    int64 last_indices_row = 0;
-    std::vector<int64> csr_offset(dense_rows, 0);
-    for (int i = 0; i < N; ++i) {
-      const int64 row = indices(i, 0);
-      OP_REQUIRES(context, row >= 0 && row < dense_rows,
-                  errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
-                                          row, " >= ", dense_rows));
-      ++csr_offset[row];
-      rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
-      last_indices_row = row;
-    }
-    bool all_rows_full = true;
-    for (int row = 0; row < dense_rows; ++row) {
-      // csr_offset here describes the number of elements in this dense row
-      bool row_empty = (csr_offset[row] == 0);
-      if (empty_row_indicator) {
-        empty_row_indicator[row] = row_empty;
-      }
-      all_rows_full = all_rows_full & !row_empty;
-      // In filled version, each row has at least one element.
-      csr_offset[row] = std::max(csr_offset[row], int64{1});
-      // Update csr_offset to represent the number of elements up to and
-      // including dense_row + 1:
-      //  csr_offset(0) == #{elements of row 0}
-      //  csr_offset(1) == #{elements of row 1} + #{elements of row 0}
-      //  ..
-      //  csr_offset(i) == starting index for elements in row i + 1.
-      if (row > 0) {
-        csr_offset[row] += csr_offset[row - 1];
-      }
-    }
-
-    if (all_rows_full && rows_are_ordered) {
-      context->set_output(kOutputIndicesOutput, indices_t);
-      context->set_output(kOutputValuesOutput, values_t);
-      if (reverse_index_map) {
-        for (int64 i = 0; i < N; ++i) {
-          reverse_index_map[i] = i;
-        }
-      }
-    } else {
-      Tensor* output_indices_t;
-      const int64 N_full = csr_offset[dense_rows - 1];
-      TensorShape output_indices_shape({N_full, rank});
-      OP_REQUIRES_OK(context, context->allocate_output(kOutputIndicesOutput,
-                                                       output_indices_shape,
-                                                       &output_indices_t));
-      auto output_indices = output_indices_t->matrix<int64>();
-
-      Tensor* output_values_t;
-      OP_REQUIRES_OK(context, context->allocate_output(kOutputValuesOutput,
-                                                       TensorShape({N_full}),
-                                                       &output_values_t));
-      auto output_values = output_values_t->vec<T>();
-
-      std::vector<int64> filled_count(dense_rows, 0);
-
-      // Fill in values for rows that are not missing
-      for (int64 i = 0; i < N; ++i) {
-        const int64 row = indices(i, 0);
-        int64& offset = filled_count[row];
-        const int64 output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset;
-        offset++;  // Increment the filled count for this row.
-        std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
-        output_values(output_i) = values(i);
-        // We'll need this reverse index map to backprop correctly.
-        if (reverse_index_map) {
-          reverse_index_map[i] = output_i;
-        }
-      }
-
-      // Fill in values for rows that are missing
-      for (int64 row = 0; row < dense_rows; ++row) {
-        const int64 row_count = filled_count[row];
-        if (row_count == 0) {  // We haven't filled this row
-          const int64 starting_index = (row == 0) ? 0 : csr_offset[row - 1];
-          // Remaining index values were set to zero already.
-          // Just need to set the row index in the right location.
-          output_indices(starting_index, 0) = row;
-          for (int64 col = 1; col < rank; ++col) {
-            output_indices(starting_index, col) = 0;
-          }
-          output_values(starting_index) = default_value;
-        }
-      }
-    }
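+    // All of the work is delegated to the device-specific functor.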
+    OP_REQUIRES_OK(context, functor::SparseFillEmptyRows<Device, T, Tindex>()(
+                                context, default_value_t, indices_t, values_t,
+                                dense_shape_t));
   }
 };
 
-#define REGISTER_KERNELS(type)                            \
-  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")     \
-                              .Device(DEVICE_CPU)         \
-                              .TypeConstraint<type>("T"), \
-                          SparseFillEmptyRowsOp<type>)
+#define REGISTER_KERNELS(D, T, Tindex)                   \
+  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")    \
+                              .Device(DEVICE_##D)        \
+                              .HostMemory("dense_shape") \
+                              .TypeConstraint<T>("T"),   \
+                          SparseFillEmptyRowsOp<D##Device, T, Tindex>)
 
-TF_CALL_ALL_TYPES(REGISTER_KERNELS);
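+// Only int64 indices are instantiated on CPU, matching the op's int64 indices
+// input; dense_shape stays in host memory so the kernel can read the dense
+// row count directly.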
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
+TF_CALL_ALL_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
 #undef REGISTER_KERNELS
 
-template <typename T>
+namespace functor {
+
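+// CPU specialization of SparseFillEmptyRowsGrad: maps gradients back to the
+// original values through reverse_index_map and accumulates the gradients of
+// the filled-in entries into d_default_value.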
+template <typename T, typename Tindex>
+struct SparseFillEmptyRowsGrad<CPUDevice, T, Tindex> {
+  Status operator()(OpKernelContext* context,
+                    typename TTypes<Tindex>::ConstVec reverse_index_map,
+                    typename TTypes<T>::ConstVec grad_values,
+                    typename TTypes<T>::Vec d_values,
+                    typename TTypes<T>::Scalar d_default_value) {
+    const CPUDevice& device = context->eigen_device<CPUDevice>();
+    const Tindex N = reverse_index_map.dimension(0);
+    const Tindex N_full = grad_values.dimension(0);
+
+    T& d_default_value_scalar = d_default_value();
+    d_default_value_scalar = T();
+
+    Tensor visited_t;
+    TF_RETURN_IF_ERROR(
+        context->allocate_temp(DT_BOOL, TensorShape({N_full}), &visited_t));
+    auto visited = visited_t.vec<bool>();
+    visited.device(device) = visited.constant(false);
+
+    for (int i = 0; i < N; ++i) {
+      // Locate the index of the output of the forward prop associated
+      // with this location in the input of the forward prop.  Copy
+      // the gradient into it.  Mark it as visited.
+      const Tindex reverse_index = reverse_index_map(i);
+      if (reverse_index < 0 || reverse_index >= N_full) {
+        return errors::InvalidArgument(
+            "Elements in reverse index must be in [0, ", N_full, ") but got ",
+            reverse_index);
+      }
+      d_values(i) = grad_values(reverse_index);
+      visited(reverse_index) = true;
+    }
+    for (int j = 0; j < N_full; ++j) {
+      // The default value gradient gets the accumulated remainder of
+      // the backprop values (since the default value was used to fill
+      // in these slots in the forward calculation).
+      if (!visited(j)) {
+        d_default_value_scalar += grad_values(j);
+      }
+    }
+    return Status::OK();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T, typename Tindex>
 class SparseFillEmptyRowsGradOp : public OpKernel {
  public:
   explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
@@ -230,8 +298,6 @@
                    context->input("reverse_index_map", &reverse_index_map_t));
     OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));
 
-    const CPUDevice& d = context->eigen_device<CPUDevice>();
-
     OP_REQUIRES(
         context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
         errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
@@ -240,11 +306,10 @@
                 errors::InvalidArgument("grad_values must be a vector, saw: ",
                                         grad_values_t->shape().DebugString()));
 
-    const auto reverse_index_map = reverse_index_map_t->vec<int64>();
+    const auto reverse_index_map = reverse_index_map_t->vec<Tindex>();
     const auto grad_values = grad_values_t->vec<T>();
 
-    const int64 N = reverse_index_map_t->shape().dim_size(0);
-    const int64 N_full = grad_values_t->shape().dim_size(0);
+    const Tindex N = reverse_index_map_t->shape().dim_size(0);
 
     Tensor* d_values_t;
     OP_REQUIRES_OK(context, context->allocate_output(
@@ -254,44 +319,24 @@
     OP_REQUIRES_OK(context,
                    context->allocate_output("d_default_value", TensorShape({}),
                                             &d_default_value_t));
-    T& d_default_value = d_default_value_t->scalar<T>()();
-    d_default_value = T();
+    auto d_default_value = d_default_value_t->scalar<T>();
 
-    Tensor visited_t;
-    OP_REQUIRES_OK(context, context->allocate_temp(
-                                DT_BOOL, TensorShape({N_full}), &visited_t));
-    auto visited = visited_t.vec<bool>();
-    visited.device(d) = visited.constant(false);
-
-    for (int i = 0; i < N; ++i) {
-      // Locate the index of the output of the forward prop associated
-      // with this location in the input of the forward prop.  Copy
-      // the gradient into it.  Mark it as visited.
-      int64 reverse_index = reverse_index_map(i);
-      OP_REQUIRES(
-          context, 0 <= reverse_index && reverse_index < N_full,
-          errors::InvalidArgument("Elements in reverse index must be in [0, ",
-                                  N_full, ") but got ", reverse_index));
-      d_values(i) = grad_values(reverse_index);
-      visited(reverse_index) = true;
-    }
-    for (int j = 0; j < N_full; ++j) {
-      // The default value gradient gets the accumulated remainder of
-      // the backprop values (since the default value was used to fill
-      // in these slots in the forward calculation).
-      if (!visited(j)) {
-        d_default_value += grad_values(j);
-      }
-    }
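+    // The gradient computation itself is handled by the device-specific
+    // functor.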
+    OP_REQUIRES_OK(context,
+                   functor::SparseFillEmptyRowsGrad<Device, T, Tindex>()(
+                       context, reverse_index_map, grad_values, d_values,
+                       d_default_value));
   }
 };
 
-#define REGISTER_KERNELS(type)                            \
+#define REGISTER_KERNELS(D, T, Tindex)                    \
   REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \
-                              .Device(DEVICE_CPU)         \
-                              .TypeConstraint<type>("T"), \
-                          SparseFillEmptyRowsGradOp<type>)
+                              .Device(DEVICE_##D)         \
+                              .TypeConstraint<T>("T"),    \
+                          SparseFillEmptyRowsGradOp<D##Device, T, Tindex>)
 
-TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
+TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
 #undef REGISTER_KERNELS
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.h b/tensorflow/core/kernels/sparse_fill_empty_rows_op.h
new file mode 100644
index 0000000..9d9bc29
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.h
@@ -0,0 +1,47 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_FILL_EMPTY_ROWS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_FILL_EMPTY_ROWS_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+namespace functor {
+
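+// Functor interfaces for SparseFillEmptyRows and its gradient; the CPU
+// specializations are defined in sparse_fill_empty_rows_op.cc.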
+template <typename Device, typename T, typename Tindex>
+struct SparseFillEmptyRows {
+  Status operator()(OpKernelContext* context, const Tensor& default_value_t,
+                    const Tensor& indices_t, const Tensor& values_t,
+                    const Tensor& dense_shape_t);
+};
+
+template <typename Device, typename T, typename Tindex>
+struct SparseFillEmptyRowsGrad {
+  Status operator()(OpKernelContext* context,
+                    typename TTypes<Tindex>::ConstVec reverse_index_map,
+                    typename TTypes<T>::ConstVec grad_values,
+                    typename TTypes<T>::Vec d_values,
+                    typename TTypes<T>::Scalar d_default_value);
+};
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_FILL_EMPTY_ROWS_OP_H_
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 44bcab4..6fab5f1 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -18,7 +18,6 @@
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/sparse_xent_op.h"
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -123,8 +122,6 @@
 REGISTER(CPU, float, int64)
 REGISTER(CPU, double, int32)
 REGISTER(CPU, double, int64)
-REGISTER(CPU, bfloat16, int32)
-REGISTER(CPU, bfloat16, int64)
 REGISTER(CPU, Eigen::half, int32)
 REGISTER(CPU, Eigen::half, int64)
 
diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc
index f095f2e..85a5cd3 100644
--- a/tensorflow/core/kernels/sparse_xent_op_test.cc
+++ b/tensorflow/core/kernels/sparse_xent_op_test.cc
@@ -23,9 +23,9 @@
 
 namespace tensorflow {
 
-static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
+static Graph* SparseXent(int batch_size, int num_classes) {
   Graph* g = new Graph(OpRegistry::Global());
-  Tensor logits(value_type, TensorShape({batch_size, num_classes}));
+  Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
   logits.flat<float>().setRandom();
   Tensor labels(DT_INT64, TensorShape({batch_size}));
   std::random_device rd;
@@ -41,45 +41,44 @@
   return g;
 }
 
-#define BM_SparseXentDev(BATCH, CLASS, DEVICE, DTYPE)                        \
-  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE(        \
+#define BM_SparseXentDev(BATCH, CLASS, DEVICE)                               \
+  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(                  \
       ::testing::benchmark::State& state) {                                  \
-    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS, DTYPE),                \
+    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS),                       \
                     /*old_benchmark_api*/ false)                             \
         .Run(state);                                                         \
     state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
                             CLASS);                                          \
   }                                                                          \
-  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE);
-
-#define BM_SPARSE_XENT_DEV_CPU(DTYPE)       \
-  BM_SparseXentDev(8, 1000000, cpu, DTYPE); \
-  BM_SparseXentDev(16, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(16, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(32, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(32, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(64, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(64, 100000, cpu, DTYPE);
-
-// CPU
-BM_SPARSE_XENT_DEV_CPU(DT_FLOAT);
-BM_SPARSE_XENT_DEV_CPU(DT_BFLOAT16);
+  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE);
 
 /// The representative tests for ptb_word on GPU
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-BM_SparseXentDev(8, 1000000, gpu, DT_FLOAT);
+BM_SparseXentDev(8, 1000000, gpu);
 
-BM_SparseXentDev(16, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(16, 10000, gpu);
+BM_SparseXentDev(16, 30000, gpu);
+BM_SparseXentDev(16, 100000, gpu);
 
-BM_SparseXentDev(32, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(32, 10000, gpu);
+BM_SparseXentDev(32, 30000, gpu);
+BM_SparseXentDev(32, 100000, gpu);
 
-BM_SparseXentDev(64, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(64, 10000, gpu);
+BM_SparseXentDev(64, 30000, gpu);
+BM_SparseXentDev(64, 100000, gpu);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+// CPU
+BM_SparseXentDev(8, 1000000, cpu);
+
+BM_SparseXentDev(16, 10000, cpu);
+BM_SparseXentDev(16, 100000, cpu);
+
+BM_SparseXentDev(32, 10000, cpu);
+BM_SparseXentDev(32, 100000, cpu);
+
+BM_SparseXentDev(64, 10000, cpu);
+BM_SparseXentDev(64, 100000, cpu);
+
 }  // end namespace tensorflow
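With the DTYPE parameter removed, the benchmark graph is always built with DT_FLOAT. For reference, BM_SparseXentDev(8, 1000000, cpu) now expands (roughly, modulo preprocessor formatting) to:

static void BM_SparseXent_8_1000000_cpu(::testing::benchmark::State& state) {
  test::Benchmark("cpu", SparseXent(8, 1000000), /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * 8 * 1000000);
}
BENCHMARK(BM_SparseXent_8_1000000_cpu);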
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 3113de5..12d626b 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -162,6 +162,86 @@
   }
 };
 
+template <typename T, typename Tindex, bool has_epsilon>
+struct SparseApplyAdagrad<CPUDevice, T, Tindex, has_epsilon> {
+  Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar epsilon,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
+                    bool update_slots) {
+    const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
+    const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
+    const T lr_scalar = lr();
+    const int in_bytes = inner_dim * sizeof(T) * 3;
+    const int out_bytes = inner_dim * sizeof(T) * 2;
+    const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                    Eigen::TensorOpCost::MulCost<T>() * 2);
+    const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
+
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          auto a = accum.template chip<0>(index);
+          auto g = grad.template chip<0>(i);
+          auto v = var.template chip<0>(index);
+          if (update_slots) {
+            a += g.square();
+          }
+          if (has_epsilon) {
+            v -= g.constant(lr_scalar) * g / (a.sqrt() + a.constant(epsilon()));
+          } else {
+            v -= g.constant(lr_scalar) * g * a.rsqrt();
+          }
+        }
+      };
+
+      d.parallelFor(N, cost, shard);
+    } else {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
+
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          T& a = accum(index);
+          const T& g = grad(i);
+          if (update_slots) {
+            a += g * g;
+          }
+          if (has_epsilon) {
+            var(index) -= lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
+          } else {
+            var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
+          }
+        }
+      };
+
+      d.parallelFor(N, cost, shard);
+    }
+
+    return Status::OK();
+  }
+};
+
 template <typename T>
 struct ApplyProximalAdagrad<CPUDevice, T> {
   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -190,6 +270,78 @@
   }
 };
 
+template <typename T, typename Tindex>
+struct SparseApplyProximalAdagrad<CPUDevice, T, Tindex> {
+  Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar l1,
+                    typename TTypes<T>::ConstScalar l2,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices,
+                    int64 inner_dim) {
+    const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
+    const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
+    const T lr_scalar = lr();
+    const T l1_scalar = l1();
+    const T l2_scalar = l2();
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        auto a = accum.template chip<0>(index);
+        auto g = grad.template chip<0>(i);
+        auto v = var.template chip<0>(index);
+        a += g.square();
+        // compute learning_rate for current step.
+        auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
+        auto prox_v = v;
+        // v = w - g * learning_rate.
+        prox_v -= g * learning_rate;
+        if (l1_scalar > 0) {
+          // v = sign(prox_v) * max(|prox_v| - lr*l1, 0) / (1 + l2*lr)
+          v = prox_v.sign() *
+              (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
+                  .cwiseMax(static_cast<T>(0.0)) /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        } else {
+          v = prox_v /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        }
+      }
+    } else {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        T& a = accum(index);
+        const T& g = grad(i);
+        a += g * g;
+        auto learning_rate = lr_scalar / std::sqrt(a);
+        auto prox_v = var(index);
+        prox_v -= learning_rate * g;
+        if (l1_scalar > 0) {
+          var(index) = sgn(prox_v) *
+                       std::max(std::abs(prox_v) - learning_rate * l1_scalar,
+                                static_cast<T>(0.0)) /
+                       (1.0 + l2_scalar * learning_rate);
+        } else {
+          var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
+        }
+      }
+    }
+    return Status::OK();
+  }
+};
+
 template <typename T>
 struct ApplyFtrlV2<CPUDevice, T> {
   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -1694,8 +1846,7 @@
 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS
 
-// Note, this op works on cpu only.
-template <typename T, typename Tindex>
+template <typename Device, typename T, typename Tindex>
 class SparseApplyAdagradOp : public OpKernel {
  public:
   explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1705,13 +1856,13 @@
 
   void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
     const bool sparse = true;
-    auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
+    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
         ctx, use_exclusive_lock_, sparse, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                             ctx, 0, use_exclusive_lock_, sparse, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                             ctx, 1, use_exclusive_lock_, sparse, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -1755,78 +1906,14 @@
                 errors::InvalidArgument(
                     "Inner dimension should be greater than zero."));
 
-    // This op is implemented only for CPU device.
-    const auto& d = ctx->eigen_cpu_device();
-
-    if (N > 0) {
-      const int in_bytes = inner_dim * sizeof(T) * 3;
-      const int out_bytes = inner_dim * sizeof(T) * 2;
-      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
-                                      Eigen::TensorOpCost::MulCost<T>() * 2);
-      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
-
-      if (inner_dim > 1) {
-        const Tindex first_dim_size = var.dim_size(0);
-        auto indices_vec = indices.vec<Tindex>();
-        auto var_flat = var.flat_outer_dims<T>();
-        auto accum_flat = accum.flat_outer_dims<T>();
-        auto grad_flat = grad.flat_outer_dims<T>();
-        T lr_scalar = lr.scalar<T>()();
-
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
-                      errors::InvalidArgument(
-                          strings::StrCat("Index ", index, " at offset ", i,
-                                          " in indices is out of range")));
-        }
-
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-            auto a = accum_flat.template chip<0>(index);
-            auto g = grad_flat.template chip<0>(i);
-            auto v = var_flat.template chip<0>(index);
-            if (update_slots_) {
-              a += g.square();
-            }
-            v -= g.constant(lr_scalar) * g * a.rsqrt();
-          }
-        };
-
-        d.parallelFor(N, cost, shard);
-
-      } else {
-        auto indices_vec = indices.vec<Tindex>();
-        auto var_flat = var.flat<T>();
-        auto accum_flat = accum.flat<T>();
-        auto grad_flat = grad.flat<T>();
-        T lr_scalar = lr.scalar<T>()();
-        const Tindex first_dim_size = accum_flat.size();
-
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
-                      errors::InvalidArgument(
-                          strings::StrCat("Index ", index, " at offset ", i,
-                                          " in indices is out of range")));
-        }
-
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-            T& a = accum_flat(index);
-            const T& g = grad_flat(i);
-            if (update_slots_) {
-              a += g * g;
-            }
-            var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
-          }
-        };
-
-        d.parallelFor(N, cost, shard);
-      }
-    }
+    const Device& device = ctx->template eigen_device<Device>();
+    OP_REQUIRES_OK(
+        ctx, functor::SparseApplyAdagrad<Device, T, Tindex,
+                                         /*has_epsilon = */ false>()(
+                 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
+                 // Note: Passing lr as a placeholder for unused epsilon.
+                 lr.scalar<T>(), lr.scalar<T>(), grad.flat_outer_dims<T>(),
+                 indices.vec<Tindex>(), inner_dim, update_slots_));
 
     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }
@@ -1836,20 +1923,20 @@
   bool update_slots_;
 };
 
-#define REGISTER_KERNELS(T, Tindices)                                \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad")                 \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
-                          SparseApplyAdagradOp<T, Tindices>);        \
-  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad")         \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
-                          SparseApplyAdagradOp<T, Tindices>);
-#define REGISTER_CPU_KERNELS(T) \
-  REGISTER_KERNELS(T, int32);   \
-  REGISTER_KERNELS(T, int64);
+#define REGISTER_KERNELS(D, T, Tindices)                                 \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad")                     \
+                              .Device(DEVICE_##D)                        \
+                              .TypeConstraint<T>("T")                    \
+                              .TypeConstraint<Tindices>("Tindices"),     \
+                          SparseApplyAdagradOp<D##Device, T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad")             \
+                              .Device(DEVICE_##D)                        \
+                              .TypeConstraint<T>("T")                    \
+                              .TypeConstraint<Tindices>("Tindices"),     \
+                          SparseApplyAdagradOp<D##Device, T, Tindices>);
+#define REGISTER_CPU_KERNELS(T)    \
+  REGISTER_KERNELS(CPU, T, int32); \
+  REGISTER_KERNELS(CPU, T, int64);
 
 TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
@@ -1857,8 +1944,7 @@
 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS
 
-// Note, this op works on cpu only.
-template <typename T, typename Tindex>
+template <typename Device, typename T, typename Tindex>
 class SparseApplyAdagradV2Op : public OpKernel {
  public:
   explicit SparseApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1868,13 +1954,13 @@
 
   void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
     const bool sparse = true;
-    auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
+    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
         ctx, use_exclusive_lock_, sparse, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                             ctx, 0, use_exclusive_lock_, sparse, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                             ctx, 1, use_exclusive_lock_, sparse, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -1922,82 +2008,13 @@
                 errors::InvalidArgument(
                     "Inner dimension should be greater than zero."));
 
-    // This op is implemented only for CPU device.
-    const auto& d = ctx->eigen_cpu_device();
-
-    if (N > 0) {
-      const int in_bytes = inner_dim * sizeof(T) * 3;
-      const int out_bytes = inner_dim * sizeof(T) * 2;
-      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
-                                      Eigen::TensorOpCost::MulCost<T>() * 2);
-      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
-
-      if (inner_dim > 1) {
-        const Tindex first_dim_size = var.dim_size(0);
-        auto indices_vec = indices.vec<Tindex>();
-        auto var_flat = var.flat_outer_dims<T>();
-        auto accum_flat = accum.flat_outer_dims<T>();
-        auto grad_flat = grad.flat_outer_dims<T>();
-        const T lr_scalar = lr.scalar<T>()();
-        const T epsilon_scalar = epsilon.scalar<T>()();
-
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
-                      errors::InvalidArgument(
-                          strings::StrCat("Index ", index, " at offset ", i,
-                                          " in indices is out of range")));
-        }
-
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-            auto a = accum_flat.template chip<0>(index);
-            auto g = grad_flat.template chip<0>(i);
-            auto v = var_flat.template chip<0>(index);
-            if (update_slots_) {
-              a += g.square();
-            }
-            v -= g.constant(lr_scalar) * g /
-                 (a.sqrt() + a.constant(epsilon_scalar));
-          }
-        };
-
-        d.parallelFor(N, cost, shard);
-
-      } else {
-        auto indices_vec = indices.vec<Tindex>();
-        auto var_flat = var.flat<T>();
-        auto accum_flat = accum.flat<T>();
-        auto grad_flat = grad.flat<T>();
-        T lr_scalar = lr.scalar<T>()();
-        const T epsilon_scalar = epsilon.scalar<T>()();
-        const Tindex first_dim_size = accum_flat.size();
-
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
-                      errors::InvalidArgument(
-                          strings::StrCat("Index ", index, " at offset ", i,
-                                          " in indices is out of range")));
-        }
-
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-            T& a = accum_flat(index);
-            const T& g = grad_flat(i);
-            if (update_slots_) {
-              a += g * g;
-            }
-            var_flat(index) -=
-                lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon_scalar);
-          }
-        };
-
-        d.parallelFor(N, cost, shard);
-      }
-    }
+    const Device& device = ctx->template eigen_device<Device>();
+    OP_REQUIRES_OK(
+        ctx, functor::SparseApplyAdagrad<Device, T, Tindex,
+                                         /*has_epsilon = */ true>()(
+                 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
+                 lr.scalar<T>(), epsilon.scalar<T>(), grad.flat_outer_dims<T>(),
+                 indices.vec<Tindex>(), inner_dim, update_slots_));
 
     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }
@@ -2007,20 +2024,20 @@
   bool update_slots_;
 };
 
-#define REGISTER_KERNELS(T, Tindices)                                \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradV2")               \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
-                          SparseApplyAdagradV2Op<T, Tindices>);      \
-  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradV2")       \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
-                          SparseApplyAdagradV2Op<T, Tindices>);
-#define REGISTER_CPU_KERNELS(T) \
-  REGISTER_KERNELS(T, int32);   \
-  REGISTER_KERNELS(T, int64);
+#define REGISTER_KERNELS(D, T, Tindices)                                   \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradV2")                     \
+                              .Device(DEVICE_##D)                          \
+                              .TypeConstraint<T>("T")                      \
+                              .TypeConstraint<Tindices>("Tindices"),       \
+                          SparseApplyAdagradV2Op<D##Device, T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradV2")             \
+                              .Device(DEVICE_##D)                          \
+                              .TypeConstraint<T>("T")                      \
+                              .TypeConstraint<Tindices>("Tindices"),       \
+                          SparseApplyAdagradV2Op<D##Device, T, Tindices>);
+#define REGISTER_CPU_KERNELS(T)    \
+  REGISTER_KERNELS(CPU, T, int32); \
+  REGISTER_KERNELS(CPU, T, int64);
 
 TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
@@ -2028,8 +2045,7 @@
 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS
 
-// Note, this op works on cpu only.
-template <typename T, typename Tindex>
+template <typename Device, typename T, typename Tindex>
 class SparseApplyProximalAdagradOp : public OpKernel {
  public:
   explicit SparseApplyProximalAdagradOp(OpKernelConstruction* ctx)
@@ -2039,13 +2055,13 @@
 
   void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
     const bool sparse = true;
-    auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
+    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
         ctx, use_exclusive_lock_, sparse, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                             ctx, 0, use_exclusive_lock_, sparse, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                             ctx, 1, use_exclusive_lock_, sparse, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2066,20 +2082,23 @@
     const Tensor& lr = ctx->input(2);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(lr.shape()) &&
-                    lr.scalar<T>()() > static_cast<T>(0),
+                    (!std::is_same<Device, CPUDevice>::value ||
+                     lr.scalar<T>()() > static_cast<T>(0)),
                 errors::InvalidArgument("lr is not a positive scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(3);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(l1.shape()) &&
-                    l1.scalar<T>()() >= static_cast<T>(0),
+                    (!std::is_same<Device, CPUDevice>::value ||
+                     l1.scalar<T>()() >= static_cast<T>(0)),
                 errors::InvalidArgument("l1 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l1.shape().DebugString()));
     const Tensor& l2 = ctx->input(4);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(l2.shape()) &&
-                    l2.scalar<T>()() >= static_cast<T>(0),
+                    (!std::is_same<Device, CPUDevice>::value ||
+                     l2.scalar<T>()() >= static_cast<T>(0)),
                 errors::InvalidArgument("l2 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l2.shape().DebugString()));
@@ -2106,77 +2125,12 @@
                 errors::InvalidArgument(
                     "Inner dimension should be greater than zero."));
 
-    if (N > 0) {
-      if (inner_dim > 1) {
-        const Tindex first_dim_size = var.dim_size(0);
-        auto indices_vec = indices.vec<Tindex>();
-        auto var_flat = var.flat_outer_dims<T>();
-        auto accum_flat = accum.flat_outer_dims<T>();
-        auto grad_flat = grad.flat_outer_dims<T>();
-        T lr_scalar = lr.scalar<T>()();
-        T l1_scalar = l1.scalar<T>()();
-        T l2_scalar = l2.scalar<T>()();
-
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
-                      errors::InvalidArgument(
-                          strings::StrCat("Index ", index, " at offset ", i,
-                                          " in indices is out of range")));
-          auto a = accum_flat.template chip<0>(index);
-          auto g = grad_flat.template chip<0>(i);
-          auto v = var_flat.template chip<0>(index);
-          a += g.square();
-          // compute learning_rate for current step.
-          auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
-          auto prox_v = v;
-          // v = w - g * learning_rate.
-          prox_v -= g * learning_rate;
-          if (l1_scalar > 0) {
-            // compute sign(v) * max(|v|, 0)
-            v = prox_v.sign() *
-                (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
-                    .cwiseMax(static_cast<T>(0.0)) /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          } else {
-            v = prox_v /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          }
-        }
-      } else {
-        auto indices_vec = indices.vec<Tindex>();
-        auto var_flat = var.flat<T>();
-        auto accum_flat = accum.flat<T>();
-        auto grad_flat = grad.flat<T>();
-        T lr_scalar = lr.scalar<T>()();
-        T l1_scalar = l1.scalar<T>()();
-        T l2_scalar = l2.scalar<T>()();
-        const Tindex first_dim_size = accum_flat.size();
-
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
-          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
-                      errors::InvalidArgument(
-                          strings::StrCat("Index ", index, " at offset ", i,
-                                          " in indices is out of range")));
-          T& a = accum_flat(index);
-          const T& g = grad_flat(i);
-          a += g * g;
-          auto learning_rate = lr_scalar / std::sqrt(a);
-          auto prox_v = var_flat(index);
-          prox_v -= learning_rate * g;
-          if (l1_scalar > 0) {
-            var_flat(index) =
-                sgn(prox_v) *
-                std::max(std::abs(prox_v) - learning_rate * l1_scalar,
-                         static_cast<T>(0.0)) /
-                (1.0 + l2_scalar * learning_rate);
-          } else {
-            var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
-          }
-        }
-      }
-    }
+    const Device& device = ctx->template eigen_device<Device>();
+    OP_REQUIRES_OK(
+        ctx, functor::SparseApplyProximalAdagrad<Device, T, Tindex>()(
+                 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
+                 lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(),
+                 grad.flat_outer_dims<T>(), indices.vec<Tindex>(), inner_dim));
 
     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }
@@ -2185,22 +2139,25 @@
   bool use_exclusive_lock_;
 };
 
-#define REGISTER_KERNELS(T, Tindices)                                 \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad")          \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<T>("T")                 \
-                              .TypeConstraint<Tindices>("Tindices"),  \
-                          SparseApplyProximalAdagradOp<T, Tindices>); \
-  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalAdagrad")  \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<T>("T")                 \
-                              .TypeConstraint<Tindices>("Tindices"),  \
-                          SparseApplyProximalAdagradOp<T, Tindices>);
+#define REGISTER_KERNELS(D, T, Tindices)                     \
+  REGISTER_KERNEL_BUILDER(                                   \
+      Name("SparseApplyProximalAdagrad")                     \
+          .Device(DEVICE_##D)                                \
+          .TypeConstraint<T>("T")                            \
+          .TypeConstraint<Tindices>("Tindices"),             \
+      SparseApplyProximalAdagradOp<D##Device, T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(                                   \
+      Name("ResourceSparseApplyProximalAdagrad")             \
+          .Device(DEVICE_##D)                                \
+          .TypeConstraint<T>("T")                            \
+          .TypeConstraint<Tindices>("Tindices"),             \
+      SparseApplyProximalAdagradOp<D##Device, T, Tindices>);
 
-REGISTER_KERNELS(float, int32);
-REGISTER_KERNELS(float, int64);
-REGISTER_KERNELS(double, int32);
-REGISTER_KERNELS(double, int64);
+REGISTER_KERNELS(CPU, float, int32);
+REGISTER_KERNELS(CPU, float, int64);
+REGISTER_KERNELS(CPU, double, int32);
+REGISTER_KERNELS(CPU, double, int64);
+
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
@@ -2714,7 +2671,6 @@
 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS
 
-// Note, this op works on cpu only.
 template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
 class SparseApplyFtrlOp : public OpKernel {
  public:
@@ -2767,11 +2723,16 @@
     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                 errors::InvalidArgument("indices must be one-dimensional"));
 
+    // Note: The range checks on lr, l1, l2, and lr_power below are disabled
+    // for non-CPU devices because their values cannot be accessed directly
+    // from the host. The GPU kernel will not crash if these conditions are
+    // not met; it will simply produce a bogus answer (possibly inf/nan).
     const Tensor& lr = ctx->input(5);
     OP_REQUIRES(
         ctx,
         TensorShapeUtils::IsScalar(lr.shape()) &&
-            (lr.scalar<T>()() > static_cast<T>(0) ||
+            (!std::is_same<Device, CPUDevice>::value ||
+             lr.scalar<T>()() > static_cast<T>(0) ||
              (multiply_linear_by_lr_ && lr.scalar<T>()() >= static_cast<T>(0))),
         errors::InvalidArgument("lr is not a positive scalar (or zero if "
                                 "multiply_linear_by_lr is set): ",
@@ -2780,14 +2741,16 @@
     const Tensor& l1 = ctx->input(6);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(l1.shape()) &&
-                    l1.scalar<T>()() >= static_cast<T>(0),
+                    (!std::is_same<Device, CPUDevice>::value ||
+                     l1.scalar<T>()() >= static_cast<T>(0)),
                 errors::InvalidArgument("l1 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l1.shape().DebugString()));
     const Tensor& l2 = ctx->input(7);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(l2.shape()) &&
-                    l2.scalar<T>()() >= static_cast<T>(0),
+                    (!std::is_same<Device, CPUDevice>::value ||
+                     l2.scalar<T>()() >= static_cast<T>(0)),
                 errors::InvalidArgument("l2 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l2.shape().DebugString()));
@@ -2795,7 +2758,8 @@
     const Tensor& lr_power = ctx->input(lr_power_index);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(lr_power.shape()) &&
-                    lr_power.scalar<T>()() <= static_cast<T>(0),
+                    (!std::is_same<Device, CPUDevice>::value ||
+                     lr_power.scalar<T>()() <= static_cast<T>(0)),
                 errors::InvalidArgument("lr_power is not a "
                                         "non-positive scalar: ",
                                         lr_power.shape().DebugString()));
@@ -2822,7 +2786,8 @@
       OP_REQUIRES(
           ctx,
           TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
-              l2_shrinkage->scalar<T>()() >= static_cast<T>(0),
+              (!std::is_same<Device, CPUDevice>::value ||
+               l2_shrinkage->scalar<T>()() >= static_cast<T>(0)),
           errors::InvalidArgument("l2 shrinkage regularization strength "
                                   "is not a non-negative scalar: ",
                                   l2_shrinkage->shape().DebugString()));
@@ -2849,22 +2814,22 @@
   bool multiply_linear_by_lr_;
 };
 
-#define REGISTER_KERNELS(T, Tindices)                                         \
+#define REGISTER_KERNELS(D, T, Tindices)                                      \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("SparseApplyFtrl")                                                 \
-          .Device(DEVICE_CPU)                                                 \
+          .Device(DEVICE_##D)                                                 \
           .TypeConstraint<T>("T")                                             \
           .TypeConstraint<Tindices>("Tindices"),                              \
-      SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/false>); \
+      SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/false>); \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("ResourceSparseApplyFtrl")                                         \
-          .Device(DEVICE_CPU)                                                 \
+          .Device(DEVICE_##D)                                                 \
           .TypeConstraint<T>("T")                                             \
           .TypeConstraint<Tindices>("Tindices"),                              \
-      SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/false>);
-#define REGISTER_CPU_KERNELS(T) \
-  REGISTER_KERNELS(T, int32);   \
-  REGISTER_KERNELS(T, int64);
+      SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/false>);
+#define REGISTER_CPU_KERNELS(T)    \
+  REGISTER_KERNELS(CPU, T, int32); \
+  REGISTER_KERNELS(CPU, T, int64);
 
 TF_CALL_half(REGISTER_CPU_KERNELS);
 TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
@@ -2872,24 +2837,59 @@
 TF_CALL_double(REGISTER_CPU_KERNELS);
 
 #undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status SparseApplyFtrl<GPUDevice, T, Tindex, /*has_l2_shrinkage=*/false>::  \
+  operator()(                                                                 \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::Matrix linear,    \
+      typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar l1, \
+      typename TTypes<T>::ConstScalar l2,                                     \
+      typename TTypes<T>::ConstScalar l2_shrinkage,                           \
+      typename TTypes<T>::ConstScalar lr_power,                               \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,             \
+      bool multiply_linear_by_lr);                                            \
+  extern template struct SparseApplyFtrl<GPUDevice, T, Tindex,                \
+                                         /*has_l2_shrinkage=*/false>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
-#define REGISTER_KERNELS(T, Tindices)                                        \
+#define REGISTER_KERNELS(D, T, Tindices)                                     \
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("SparseApplyFtrlV2")                                              \
-          .Device(DEVICE_CPU)                                                \
+          .Device(DEVICE_##D)                                                \
           .TypeConstraint<T>("T")                                            \
           .TypeConstraint<Tindices>("Tindices"),                             \
-      SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/true>); \
+      SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/true>); \
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("ResourceSparseApplyFtrlV2")                                      \
-          .Device(DEVICE_CPU)                                                \
+          .Device(DEVICE_##D)                                                \
           .TypeConstraint<T>("T")                                            \
           .TypeConstraint<Tindices>("Tindices"),                             \
-      SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/true>);
-#define REGISTER_CPU_KERNELS(T) \
-  REGISTER_KERNELS(T, int32);   \
-  REGISTER_KERNELS(T, int64);
+      SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/true>);
+#define REGISTER_CPU_KERNELS(T)    \
+  REGISTER_KERNELS(CPU, T, int32); \
+  REGISTER_KERNELS(CPU, T, int64);
 
 TF_CALL_half(REGISTER_CPU_KERNELS);
 TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
@@ -2897,6 +2897,41 @@
 TF_CALL_double(REGISTER_CPU_KERNELS);
 
 #undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status SparseApplyFtrl<GPUDevice, T, Tindex, /*has_l2_shrinkage=*/true>::   \
+  operator()(                                                                 \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::Matrix linear,    \
+      typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar l1, \
+      typename TTypes<T>::ConstScalar l2,                                     \
+      typename TTypes<T>::ConstScalar l2_shrinkage,                           \
+      typename TTypes<T>::ConstScalar lr_power,                               \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,             \
+      bool multiply_linear_by_lr);                                            \
+  extern template struct SparseApplyFtrl<GPUDevice, T, Tindex,                \
+                                         /*has_l2_shrinkage=*/true>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
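The registration macros now take the device as their first argument and paste it into both the DEVICE_ constant and the kernel's Device template parameter. For example, REGISTER_KERNELS(GPU, float, int32) as defined for SparseApplyFtrl above expands (roughly) to:

REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<int32>("Tindices"),
                        SparseApplyFtrlOp<GPUDevice, float, int32,
                                          /*has_l2_shrinkage=*/false>);
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyFtrl")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<int32>("Tindices"),
                        SparseApplyFtrlOp<GPUDevice, float, int32,
                                          /*has_l2_shrinkage=*/false>);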
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index 5af6030..886d676 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -92,6 +92,18 @@
                   typename TTypes<T>::ConstFlat grad);
 };
 
+template <typename Device, typename T, typename Tindex, bool has_epsilon>
+struct SparseApplyAdagrad {
+  // Note that epsilon is ignored if has_epsilon is false.
+  Status operator()(const Device& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar epsilon,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
+                    bool update_slots);
+};
+
 template <typename Device, typename T>
 struct ApplyProximalAdagrad {
   void operator()(const Device& d, typename TTypes<T>::Flat var,
@@ -102,6 +114,17 @@
                   typename TTypes<T>::ConstFlat grad);
 };
 
+template <typename Device, typename T, typename Tindex>
+struct SparseApplyProximalAdagrad {
+  Status operator()(const Device& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar l1,
+                    typename TTypes<T>::ConstScalar l2,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim);
+};
+
 template <typename Device, typename T>
 struct ApplyFtrl {
   void operator()(const Device& d, typename TTypes<T>::Flat var,
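For reference, the per-row update computed by the SparseApplyAdagrad functor declared above is, in scalar form (a sketch mirroring the CPU implementation in training_ops.cc; the helper name is hypothetical, not a new API):

// accum and var are the entries selected by indices(i); grad is the matching
// gradient entry. epsilon is ignored when has_epsilon is false.
template <typename T>
void AdagradRowUpdate(T& var, T& accum, const T& grad, const T& lr,
                      const T& epsilon, bool has_epsilon, bool update_slots) {
  if (update_slots) accum += grad * grad;
  if (has_epsilon) {
    var -= lr * grad / (Eigen::numext::sqrt(accum) + epsilon);
  } else {
    var -= lr * grad / Eigen::numext::sqrt(accum);
  }
}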
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 64df241..d0c3ba4 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -27,6 +27,72 @@
 
 namespace functor {
 
+template <typename T, typename Tindex, bool has_l2_shrinkage>
+__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
+                                      const T* l1, const T* l2,
+                                      const T* l2_shrinkage, const T* lr_power,
+                                      const T* grad, const Tindex* indices,
+                                      Tindex param_rows, Tindex updates_size,
+                                      Tindex indices_size,
+                                      bool multiply_linear_by_lr) {
+  const Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    const Tindex indices_row = grad_index / col_size;
+    const Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    const Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T linear_i = linear[param_index];
+    const T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T l1_t = *l1;
+    const T l2_t = *l2;
+    const T lr_power_t = *lr_power;
+
+    const T grad_shr_i =
+        has_l2_shrinkage ? grad_i + static_cast<T>(2) * (*l2_shrinkage) * var_i
+                         : grad_i;
+    const T new_accum_i = accum_i + grad_i * grad_i;
+    const bool lr_power_is_neg_half = lr_power_t == static_cast<T>(-0.5);
+    const T pow_new_accum = lr_power_is_neg_half
+                                ? sqrt(new_accum_i)
+                                : pow(new_accum_i, -lr_power_t);
+    const T pow_accum =
+        lr_power_is_neg_half ? sqrt(accum_i) : pow(accum_i, -lr_power_t);
+    T linear_change = grad_shr_i * lr_t - (pow_new_accum - pow_accum) * var_i;
+    if (!multiply_linear_by_lr) {
+      linear_change /= lr_t;
+    }
+    linear_i += linear_change;
+
+    T l1_mult = l1_t;
+    if (multiply_linear_by_lr) {
+      l1_mult *= lr_t;
+    }
+    const T l1_reg_adjust = max(min(linear_i, l1_mult), -l1_mult);
+    const T x = l1_reg_adjust - linear_i;
+    T y = pow_new_accum + static_cast<T>(2) * l2_t * lr_t;
+    if (!multiply_linear_by_lr) {
+      y /= lr_t;
+    }
+    var_i = x / y;
+    accum_i = new_accum_i;
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+    linear[param_index] = linear_i;
+  }
+}
+
 template <typename T>
 __global__ __launch_bounds__(1024) void ApplyAdamKernel(
     int32 data_dim, T* var, T* m, T* v, const T* const beta1_power_,
@@ -573,6 +639,37 @@
   }
 };
 
+template <typename T, typename Tindex, bool has_l2_shrinkage>
+struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::Matrix linear,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar l1,
+                    typename TTypes<T>::ConstScalar l2,
+                    typename TTypes<T>::ConstScalar l2_shrinkage,
+                    typename TTypes<T>::ConstScalar lr_power,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
+                    bool multiply_linear_by_lr) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(
+        SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>, config.block_count,
+        config.thread_per_block, 0, d.stream(), /*var=*/var.data(),
+        /*accum=*/accum.data(),
+        /*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
+        /*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
+        /*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
+        /*indices=*/indices.data(), /*param_rows=*/first_dim_size,
+        /*updates_size=*/grad_size,
+        /*indices_size=*/indices_size,
+        /*multiply_linear_by_lr=*/multiply_linear_by_lr);
+  }
+};
+
 template <typename T>
 struct ApplyMomentum<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -905,6 +1002,20 @@
 template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, float>;
 template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, double>;
 
+#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                               \
+  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
+                                           /*has_l2_shrinkage=*/false>; \
+  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
+                                           /*has_l2_shrinkage=*/false>; \
+  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
+                                           /*has_l2_shrinkage=*/true>;  \
+  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
+                                           /*has_l2_shrinkage=*/true>
+EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
+EXPLICITLY_INSTANTIATE_FUNCTOR(float);
+EXPLICITLY_INSTANTIATE_FUNCTOR(double);
+#undef EXPLICITLY_INSTANTIATE_FUNCTOR
+
 template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
 template struct functor::ApplyMomentum<GPUDevice, float>;
 template struct functor::ApplyMomentum<GPUDevice, double>;
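In the SparseApplyFtrlKernel added above, each GPU thread owns one element of grad and derives the matching var/accum/linear element from the indices vector. A small worked example of that index arithmetic (shapes assumed for illustration): with grad of shape [3, 4] (updates_size = 12, col_size = 4) and indices = {7, 0, 2}, the thread with grad_index = 9 computes indices_row = 9 / 4 = 2, param_row = indices[2] = 2, and param_index = 2 * 4 + (9 % 4) = 9, so it updates row 2, column 1 of var; out-of-range param_row values are skipped rather than trapped.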
diff --git a/tensorflow/core/lib/io/random_inputstream.cc b/tensorflow/core/lib/io/random_inputstream.cc
index 0f07b5f..6f931a8 100644
--- a/tensorflow/core/lib/io/random_inputstream.cc
+++ b/tensorflow/core/lib/io/random_inputstream.cc
@@ -55,9 +55,10 @@
   if (bytes_to_read < 0) {
     return errors::InvalidArgument("Cannot read negative number of bytes");
   }
+  int64 current_size = result->size();
   Status s = file_->Read(pos_, bytes_to_read, result);
   if (s.ok() || errors::IsOutOfRange(s)) {
-    pos_ += result->size();
+    pos_ += result->size() - current_size;
   }
   return s;
 }
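The fix above matters because the read may append to a result that already holds data (as with the absl::Cord overload exercised by the new test below), so the stream position must advance by the bytes actually added, not by the result's total size. For example, if the cord already holds 3 bytes and the read appends 5 more, result->size() is 8; the old code advanced pos_ by 8, while the new code advances it by 8 - 3 = 5.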
diff --git a/tensorflow/core/lib/io/random_inputstream_test.cc b/tensorflow/core/lib/io/random_inputstream_test.cc
index 2fb325b..58d4b9b 100644
--- a/tensorflow/core/lib/io/random_inputstream_test.cc
+++ b/tensorflow/core/lib/io/random_inputstream_test.cc
@@ -52,6 +52,39 @@
   EXPECT_EQ(10, in.Tell());
 }
 
+#if defined(TF_CORD_SUPPORT)
+TEST(RandomInputStream, ReadNBytesWithCords) {
+  Env* env = Env::Default();
+  string fname = testing::TmpDir() + "/random_inputbuffer_test";
+  TF_ASSERT_OK(WriteStringToFile(env, fname, "0123456789"));
+
+  std::unique_ptr<RandomAccessFile> file;
+  TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
+  absl::Cord read;
+  RandomAccessInputStream in(file.get());
+
+  // Reading into `absl::Cord`s does not clear existing data from the cord.
+  TF_ASSERT_OK(in.ReadNBytes(3, &read));
+  EXPECT_EQ(read, "012");
+  EXPECT_EQ(3, in.Tell());
+  TF_ASSERT_OK(in.ReadNBytes(0, &read));
+  EXPECT_EQ(read, "012");
+  EXPECT_EQ(3, in.Tell());
+  TF_ASSERT_OK(in.ReadNBytes(5, &read));
+  EXPECT_EQ(read, "01234567");
+  EXPECT_EQ(8, in.Tell());
+  TF_ASSERT_OK(in.ReadNBytes(0, &read));
+  EXPECT_EQ(read, "01234567");
+  EXPECT_EQ(8, in.Tell());
+  EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(20, &read)));
+  EXPECT_EQ(read, "0123456789");
+  EXPECT_EQ(10, in.Tell());
+  TF_ASSERT_OK(in.ReadNBytes(0, &read));
+  EXPECT_EQ(read, "0123456789");
+  EXPECT_EQ(10, in.Tell());
+}
+#endif
+
 TEST(RandomInputStream, SkipNBytes) {
   Env* env = Env::Default();
   string fname = testing::TmpDir() + "/random_inputbuffer_test";
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index b307a55..992c3a9 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -61,7 +61,7 @@
                            const RecordWriterOptions& options)
     : dest_(dest), options_(options) {
 #if defined(IS_SLIM_BUILD)
-  if (compression_type != compression::kNone) {
+  if (options.compression_type != RecordWriterOptions::NONE) {
     LOG(FATAL) << "Compression is unsupported on mobile platforms.";
   }
 #else
diff --git a/tensorflow/core/lib/io/zlib_inputstream.cc b/tensorflow/core/lib/io/zlib_inputstream.cc
index 7ea8508..2939346 100644
--- a/tensorflow/core/lib/io/zlib_inputstream.cc
+++ b/tensorflow/core/lib/io/zlib_inputstream.cc
@@ -228,6 +228,17 @@
   return Status::OK();
 }
 
+#if defined(TF_CORD_SUPPORT)
+Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, absl::Cord* result) {
+  // TODO(frankchn): Optimize this instead of bouncing through the buffer.
+  tstring buf;
+  TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf));
+  result->Clear();
+  result->Append(absl::string_view(buf.data(), buf.size()));
+  return Status::OK();
+}
+#endif
+
 int64 ZlibInputStream::Tell() const { return bytes_read_; }
 
 Status ZlibInputStream::Inflate() {
diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h
index 427daa7..da9c3de 100644
--- a/tensorflow/core/lib/io/zlib_inputstream.h
+++ b/tensorflow/core/lib/io/zlib_inputstream.h
@@ -68,6 +68,10 @@
   // others:       If reading from stream failed.
   Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
 
+#if defined(TF_CORD_SUPPORT)
+  Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override;
+#endif
+
   int64 Tell() const override;
 
   Status Reset() override;
diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD
index 9b1447f..70fcce9 100644
--- a/tensorflow/core/nccl/BUILD
+++ b/tensorflow/core/nccl/BUILD
@@ -64,8 +64,6 @@
         "manual",
         "multi_gpu",
         "no_oss",
-        # TODO(b/147451637): Replace 'no_rocm' with 'rocm_multi_gpu'.
-        "no_rocm",
         "notap",
     ],
     deps = [
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 2018f79..ad673d3 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1441,6 +1441,50 @@
 #endif  // INTEL_MKL
 
 // --------------------------------------------------------------------------
+namespace {
+Status UniqueIdxShapeFn(InferenceContext* c) {
+  ShapeHandle input = c->input(0);
+  const Tensor* axis_t = c->input_tensor(1);
+  if (axis_t == nullptr || !c->RankKnown(input)) {
+    c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
+    return Status::OK();
+  }
+
+  if (c->Rank(c->input(1)) != 1) {
+    return errors::InvalidArgument("axis expects a 1D vector.");
+  }
+
+  int32 n = axis_t->NumElements();
+  if (n == 0) {
+    if (c->Rank(input) != 1) {
+      return errors::InvalidArgument("x expects a 1D vector.");
+    }
+    c->set_output(1, input);
+    return Status::OK();
+  } else if (n == 1) {
+    int64 axis;
+    if (axis_t->dtype() == DT_INT32) {
+      axis = static_cast<int64>(axis_t->flat<int32>()(0));
+    } else {
+      axis = axis_t->flat<int64>()(0);
+    }
+
+    int64 input_rank = c->Rank(input);
+    if (axis < -input_rank || axis >= input_rank) {
+      return errors::InvalidArgument("axis expects to be in the range [",
+                                     -input_rank, ", ", input_rank, ")");
+    }
+    if (axis < 0) {
+      axis += input_rank;
+    }
+    c->set_output(1, c->Vector(c->Dim(input, axis)));
+    return Status::OK();
+  }
+  return errors::InvalidArgument(
+      "axis does not support input tensors larger than 1 elements.");
+}
+}  // namespace
+
 REGISTER_OP("Unique")
     .Input("x: T")
     .Output("y: T")
@@ -1465,7 +1509,7 @@
     .Attr("out_idx: {int32, int64} = DT_INT32")
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
-      c->set_output(1, c->input(0));
+      TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
       return Status::OK();
     });
 
@@ -1496,7 +1540,7 @@
     .Attr("out_idx: {int32, int64} = DT_INT32")
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
-      c->set_output(1, c->input(0));
+      TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
       return Status::OK();
     });
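The new UniqueIdxShapeFn gives the idx output a concrete length whenever the axis tensor is a compile-time constant: with an empty axis the idx vector matches the 1-D input, and with a single axis it matches the size of that dimension after normalizing a negative axis by adding the input rank. For example, for x of shape [2, 6, 3] and a constant axis = [-2], the normalized axis is -2 + 3 = 1 and idx is inferred as [6]; previously idx was simply given the full input shape.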
diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc
index 9033c71..e58db89 100644
--- a/tensorflow/core/ops/collective_ops.cc
+++ b/tensorflow/core/ops/collective_ops.cc
@@ -145,4 +145,35 @@
       return Status::OK();
     });
 
+REGISTER_OP("CollectiveBcastSendV2")
+    .Input("input: T")
+    .Output("data: T")
+    .Attr("T: {bool, float, float16, float64, int32, int64}")
+    .Input("group_size: int32")
+    .Input("group_key: int32")
+    .Input("instance_key: int32")
+    .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("CollectiveBcastRecvV2")
+    .Output("data: T")
+    .Attr("T: {bool, float, float16, float64, int32, int64}")
+    .Input("group_size: int32")
+    .Input("group_key: int32")
+    .Input("instance_key: int32")
+    .Input("shape: Tshape")
+    .Attr("Tshape: {int32, int64} = DT_INT32")
+    .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      // The output shape is given by the `shape` input at index 3.
+      shape_inference::ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
+      c->set_output(/*idx=*/0, out);
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
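
The V2 collective broadcast ops above take group_size, group_key and instance_key as runtime int32 inputs rather than attrs, and the receive side derives its output shape from the `shape` tensor at input index 3. A small standalone sketch (not part of the patch) of one consequence of that design: the receiver can only pre-compute its element count when the shape tensor is fully specified, with negative entries conventionally standing for unknown dimensions.

    #include <cstdint>
    #include <vector>

    // Number of elements implied by a shape tensor such as the `shape` input
    // of CollectiveBcastRecvV2; returns -1 if any dimension is unknown
    // (negative), in which case the full shape is only known at runtime.
    int64_t NumElementsFromShapeTensor(const std::vector<int64_t>& shape) {
      int64_t n = 1;
      for (int64_t d : shape) {
        if (d < 0) return -1;
        n *= d;
      }
      return n;
    }

    // Example: {2, 3, 8} gives 48 elements; {2, -1, 8} gives -1 (unknown).
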
diff --git a/tensorflow/core/ops/compat/ops_history_v2/CollectiveBcastRecvV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CollectiveBcastRecvV2.pbtxt
new file mode 100644
index 0000000..3d1c3de
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/CollectiveBcastRecvV2.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "CollectiveBcastRecvV2"
+  input_arg {
+    name: "group_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "group_key"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "instance_key"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tshape"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BOOL
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "communication_hint"
+    type: "string"
+    default_value {
+      s: "auto"
+    }
+  }
+  attr {
+    name: "timeout_seconds"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/CollectiveBcastSendV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CollectiveBcastSendV2.pbtxt
new file mode 100644
index 0000000..c9af70d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/CollectiveBcastSendV2.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "CollectiveBcastSendV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "group_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "group_key"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "instance_key"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BOOL
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "communication_hint"
+    type: "string"
+    default_value {
+      s: "auto"
+    }
+  }
+  attr {
+    name: "timeout_seconds"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscBinaryArithmetic.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscBinaryArithmetic.pbtxt
new file mode 100644
index 0000000..9d5f080
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscBinaryArithmetic.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "RiscBinaryArithmetic"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "op_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "ADD"
+        s: "SUB"
+        s: "MUL"
+        s: "DIV"
+        s: "REM"
+        s: "MIN"
+        s: "POW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscBinaryComparison.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscBinaryComparison.pbtxt
new file mode 100644
index 0000000..d131476
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscBinaryComparison.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "RiscBinaryComparison"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "op_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "EQ"
+        s: "NE"
+        s: "GE"
+        s: "GT"
+        s: "LE"
+        s: "LT"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscBitcast.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscBitcast.pbtxt
new file mode 100644
index 0000000..1d37369
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscBitcast.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "RiscBitcast"
+  input_arg {
+    name: "x"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscCast.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscCast.pbtxt
new file mode 100644
index 0000000..344d049
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscCast.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "RiscCast"
+  input_arg {
+    name: "x"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscCholesky.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscCholesky.pbtxt
new file mode 100644
index 0000000..c6b24d1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscCholesky.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "RiscCholesky"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscCondition.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscCondition.pbtxt
new file mode 100644
index 0000000..859814c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscCondition.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "RiscCondition"
+  input_arg {
+    name: "pred"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "input_true"
+    type_attr: "SrcT"
+  }
+  input_arg {
+    name: "input_false"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "func_true"
+    type: "func"
+  }
+  attr {
+    name: "func_false"
+    type: "func"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscFft.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscFft.pbtxt
new file mode 100644
index 0000000..605cd7e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscFft.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "RiscFft"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscGather.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscGather.pbtxt
new file mode 100644
index 0000000..18d4ba3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscGather.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "RiscGather"
+  input_arg {
+    name: "params"
+    type_attr: "Tparams"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tparams"
+  }
+  attr {
+    name: "batch_dims"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "Tparams"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscIsFinite.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscIsFinite.pbtxt
new file mode 100644
index 0000000..19a4ae6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscIsFinite.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "RiscIsFinite"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalAnd.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalAnd.pbtxt
new file mode 100644
index 0000000..8bd4410
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalAnd.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "RiscLogicalAnd"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalNot.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalNot.pbtxt
new file mode 100644
index 0000000..3496ef0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalNot.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "RiscLogicalNot"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalOr.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalOr.pbtxt
new file mode 100644
index 0000000..3cf3192
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscLogicalOr.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "RiscLogicalOr"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscRandomUniform.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscRandomUniform.pbtxt
new file mode 100644
index 0000000..2d3cd00
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscRandomUniform.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "RiscRandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscReduce.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscReduce.pbtxt
new file mode 100644
index 0000000..1dff780
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscReduce.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "RiscReduce"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Index"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "reduce_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscReverse.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscReverse.pbtxt
new file mode 100644
index 0000000..60dec4d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscReverse.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RiscReverse"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscScatter.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscScatter.pbtxt
new file mode 100644
index 0000000..5def9d1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscScatter.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "RiscScatter"
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscSort.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscSort.pbtxt
new file mode 100644
index 0000000..c49a695
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscSort.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "RiscSort"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Index"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    allowed_values {
+      list {
+        s: "ASCENDING"
+        s: "DESCENDING"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscSqueeze.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscSqueeze.pbtxt
new file mode 100644
index 0000000..bf4e712
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscSqueeze.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "RiscSqueeze"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "squeeze_dims"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscTranspose.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscTranspose.pbtxt
new file mode 100644
index 0000000..856b0d6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscTranspose.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "RiscTranspose"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "perm"
+    type_attr: "Tperm"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tperm"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscTriangularSolve.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscTriangularSolve.pbtxt
new file mode 100644
index 0000000..5b8518f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscTriangularSolve.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "RiscTriangularSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "lower"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscUnary.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscUnary.pbtxt
new file mode 100644
index 0000000..0a7af35
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscUnary.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "RiscUnary"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "op_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "ABL"
+        s: "CEIL"
+        s: "COS"
+        s: "EXP"
+        s: "FLOOR"
+        s: "IMAG"
+        s: "LOG"
+        s: "NEG"
+        s: "REAL"
+        s: "SIGN"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/RiscWhile.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RiscWhile.pbtxt
new file mode 100644
index 0000000..8bb4745
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/RiscWhile.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RiscWhile"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index cde668a..ca4f14c 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -7442,6 +7442,71 @@
   is_stateful: true
 }
 op {
+  name: "CollectiveBcastRecvV2"
+  input_arg {
+    name: "group_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "group_key"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "instance_key"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tshape"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BOOL
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "communication_hint"
+    type: "string"
+    default_value {
+      s: "auto"
+    }
+  }
+  attr {
+    name: "timeout_seconds"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "CollectiveBcastSend"
   input_arg {
     name: "input"
@@ -7498,6 +7563,58 @@
   is_stateful: true
 }
 op {
+  name: "CollectiveBcastSendV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "group_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "group_key"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "instance_key"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BOOL
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "communication_hint"
+    type: "string"
+    default_value {
+      s: "auto"
+    }
+  }
+  attr {
+    name: "timeout_seconds"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "CollectiveGather"
   input_arg {
     name: "input"
@@ -41518,6 +41635,108 @@
   is_commutative: true
 }
 op {
+  name: "RiscBinaryArithmetic"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "op_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "ADD"
+        s: "SUB"
+        s: "MUL"
+        s: "DIV"
+        s: "REM"
+        s: "MIN"
+        s: "POW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RiscBinaryComparison"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "op_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "EQ"
+        s: "NE"
+        s: "GE"
+        s: "GT"
+        s: "LE"
+        s: "LT"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RiscBitcast"
+  input_arg {
+    name: "x"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+  }
+}
+op {
   name: "RiscBroadcast"
   input_arg {
     name: "input"
@@ -41550,6 +41769,48 @@
   }
 }
 op {
+  name: "RiscCast"
+  input_arg {
+    name: "x"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+  }
+}
+op {
+  name: "RiscCholesky"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
   name: "RiscConcat"
   input_arg {
     name: "values"
@@ -41589,6 +41850,57 @@
   }
 }
 op {
+  name: "RiscCondition"
+  input_arg {
+    name: "pred"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "input_true"
+    type_attr: "SrcT"
+  }
+  input_arg {
+    name: "input_false"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "func_true"
+    type: "func"
+  }
+  attr {
+    name: "func_false"
+    type: "func"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
   name: "RiscConv"
   input_arg {
     name: "input"
@@ -41686,6 +41998,144 @@
   }
 }
 op {
+  name: "RiscFft"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "RiscGather"
+  input_arg {
+    name: "params"
+    type_attr: "Tparams"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tparams"
+  }
+  attr {
+    name: "batch_dims"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "Tparams"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "RiscIsFinite"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RiscLogicalAnd"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+}
+op {
+  name: "RiscLogicalNot"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+}
+op {
+  name: "RiscLogicalOr"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+}
+op {
   name: "RiscMax"
   input_arg {
     name: "x"
@@ -41815,6 +42265,84 @@
   }
 }
 op {
+  name: "RiscRandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "RiscReduce"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Index"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "reduce_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
   name: "RiscReshape"
   input_arg {
     name: "tensor"
@@ -41855,6 +42383,87 @@
   }
 }
 op {
+  name: "RiscReverse"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RiscScatter"
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
   name: "RiscShape"
   input_arg {
     name: "input"
@@ -41932,6 +42541,234 @@
   }
 }
 op {
+  name: "RiscSort"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Index"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    allowed_values {
+      list {
+        s: "ASCENDING"
+        s: "DESCENDING"
+      }
+    }
+  }
+}
+op {
+  name: "RiscSqueeze"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "squeeze_dims"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+}
+op {
+  name: "RiscTranspose"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "perm"
+    type_attr: "Tperm"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tperm"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "RiscTriangularSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "lower"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RiscUnary"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "op_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "ABL"
+        s: "CEIL"
+        s: "COS"
+        s: "EXP"
+        s: "FLOOR"
+        s: "IMAG"
+        s: "LOG"
+        s: "NEG"
+        s: "REAL"
+        s: "SIGN"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RiscWhile"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "RngReadAndSkip"
   input_arg {
     name: "resource"
diff --git a/tensorflow/core/ops/risc_ops.cc b/tensorflow/core/ops/risc_ops.cc
index 7b09702..5e122fc 100644
--- a/tensorflow/core/ops/risc_ops.cc
+++ b/tensorflow/core/ops/risc_ops.cc
@@ -30,6 +30,31 @@
     .SetIsAggregate()
     .SetIsCommutative();
 
+// TODO(b/171294012): include RiscMax here as well.
+REGISTER_OP("RiscBinaryArithmetic")
+    .Input("x: T")
+    .Input("y: T")
+    .Output("z: T")
+    .Attr("op_type: {'ADD', 'SUB', 'MUL', 'DIV', 'REM', 'MIN', 'POW'}")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("RiscBinaryComparison")
+    .Input("x: T")
+    .Input("y: T")
+    .Output("z: bool")
+    .Attr("op_type: {'EQ', 'NE', 'GE', 'GT', 'LE', 'LT'}")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscBitcast")
+    .Input("x: SrcT")
+    .Output("y: DstT")
+    .Attr("SrcT: type")
+    .Attr("DstT: type")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 // TODO(b/171294012): change shape function.
 REGISTER_OP("RiscBroadcast")
     .Input("input: T")
@@ -39,6 +64,20 @@
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("RiscCast")
+    .Input("x: SrcT")
+    .Output("y: DstT")
+    .Attr("SrcT: type")
+    .Attr("DstT: type")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscCholesky")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 REGISTER_OP("RiscConcat")
     .Input("values: N * T")
     .Input("axis: Tidx")
@@ -49,6 +88,18 @@
     .SetShapeFn(shape_inference::ConcatV2Shape);
 
 // TODO(b/171294012): change shape function.
+REGISTER_OP("RiscCondition")
+    .Input("pred: bool")
+    .Input("input_true: SrcT")
+    .Input("input_false: SrcT")
+    .Output("output: DstT")
+    .Attr("func_true: func")
+    .Attr("func_false: func")
+    .Attr("SrcT: {bfloat16, half, float, double}")
+    .Attr("DstT: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+// TODO(b/171294012): change shape function.
 REGISTER_OP("RiscConv")
     .Input("input: T")
     .Input("filter: T")
@@ -68,6 +119,48 @@
     .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::MatMulShape);
 
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscFft")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscGather")
+    .Input("params: Tparams")
+    .Input("indices: Tindices")
+    .Input("axis: Taxis")
+    .Attr("batch_dims: int = 0")
+    .Output("output: Tparams")
+    .Attr("Tparams: type")
+    .Attr("Tindices: {int32,int64}")
+    .Attr("Taxis: {int32,int64}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("RiscIsFinite")
+    .Input("x: T")
+    .Output("y: bool")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("RiscLogicalAnd")
+    .Input("x: bool")
+    .Input("y: bool")
+    .Output("z: bool")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("RiscLogicalNot")
+    .Input("x: bool")
+    .Output("z: bool")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("RiscLogicalOr")
+    .Input("x: bool")
+    .Input("y: bool")
+    .Output("z: bool")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
 REGISTER_OP("RiscMax")
     .Input("x: T")
     .Input("y: T")
@@ -96,6 +189,23 @@
     .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("RiscRandomUniform")
+    .Input("shape: T")
+    .Output("output: float")
+    .Attr("seed: int = 0")
+    .Attr("T: {int32, int64}")
+    .SetShapeFn(shape_inference::RandomShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscReduce")
+    .Input("tensor: T")
+    .Input("axis: Index")
+    .Output("output: T")
+    .Attr("reduce_type: {'MEAN', 'SUM'}")
+    .Attr("Index: {int32,int64} = DT_INT32")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 // TODO(b/171294012): change shape function.
 REGISTER_OP("RiscReshape")
     .Input("tensor: T")
@@ -105,6 +215,24 @@
     .Attr("Tshape: {int32, int64} = DT_INT32")
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("RiscReverse")
+    .Input("tensor: T")
+    .Input("axis: Tidx")
+    .Output("output: T")
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscScatter")
+    .Input("indices: Tindices")
+    .Input("updates: T")
+    .Input("shape: Tindices")
+    .Output("output: T")
+    .Attr("T: {bfloat16, half, float, double}")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 // TODO(b/171294012): change shape function.
 REGISTER_OP("RiscShape")
     .Input("input: T")
@@ -122,4 +250,61 @@
     .Attr("Index: {int32,int64}")
     .SetShapeFn(shape_inference::SliceShape);
 
+REGISTER_OP("RiscSort")
+    .Input("input: T")
+    .Input("axis: Index")
+    .Output("output: T")
+    .Attr("Index: {int32,int64} = DT_INT32")
+    .Attr("T: {bfloat16, half, float, double}")
+    .Attr("direction: {'ASCENDING', 'DESCENDING'}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscSqueeze")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .Attr("squeeze_dims: list(int) >= 0 = []")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscTranspose")
+    .Input("x: T")
+    .Input("perm: Tperm")
+    .Output("y: T")
+    .Attr("T: type")
+    .Attr("Tperm: {int32, int64} = DT_INT32")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscTriangularSolve")
+    .Input("matrix: T")
+    .Input("rhs: T")
+    .Output("output: T")
+    .Attr("lower: bool = True")
+    .Attr("adjoint: bool = False")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("RiscUnary")
+    .Input("x: T")
+    .Output("y: T")
+    .Attr(
+        "op_type: {'ABL', 'CEIL', 'COS', 'EXP', 'FLOOR', 'IMAG', 'LOG', 'NEG', "
+        "'REAL', 'SIGN'}")
+    .Attr("T: {bfloat16, half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+// TODO(b/171294012): change shape function.
+REGISTER_OP("RiscWhile")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: list(type) >= 0")
+    .Attr("cond: func")
+    .Attr("body: func")
+    .Attr("output_shapes: list(shape) = []")
+    .Attr("parallel_iterations: int = 10")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnknownShape);
+
 }  // namespace tensorflow
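
The RiscBinaryArithmetic registration above encodes the operation in a string-valued `op_type` attr instead of defining one op per operation. A hypothetical sketch (not part of the patch, and not TensorFlow's kernel code) of how such an attr could be turned into a callable once, at kernel-construction time:

    #include <cmath>
    #include <functional>
    #include <map>
    #include <stdexcept>
    #include <string>

    // Map the op_type attr value to a binary function; unknown values are
    // rejected up front so the per-element loop never branches on strings.
    std::function<double(double, double)> MakeBinaryFn(const std::string& op_type) {
      static const std::map<std::string, std::function<double(double, double)>>
          kTable = {
              {"ADD", [](double x, double y) { return x + y; }},
              {"SUB", [](double x, double y) { return x - y; }},
              {"MUL", [](double x, double y) { return x * y; }},
              {"DIV", [](double x, double y) { return x / y; }},
              {"REM", [](double x, double y) { return std::fmod(x, y); }},
              {"MIN", [](double x, double y) { return x < y ? x : y; }},
              {"POW", [](double x, double y) { return std::pow(x, y); }},
          };
      auto it = kTable.find(op_type);
      if (it == kTable.end()) {
        throw std::invalid_argument("unknown op_type: " + op_type);
      }
      return it->second;
    }
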
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 4c99382..a680ea6 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -94,10 +94,6 @@
         "mutex.h",
         "net.h",
         "numa.h",
-        "profile_utils/android_armv7a_cpu_utils_helper.h",
-        "profile_utils/cpu_utils.cc",
-        "profile_utils/cpu_utils.h",
-        "profile_utils/i_cpu_utils_helper.h",
         "ram_file_system.h",
         "resource_loader.h",
         "resource.h",
@@ -985,24 +981,9 @@
     alwayslink = 1,
 )
 
-cc_library(
+alias(
     name = "profile_utils_cpu_utils",
-    srcs = [
-        "profile_utils/android_armv7a_cpu_utils_helper.h",
-        "profile_utils/cpu_utils.cc",
-        "profile_utils/i_cpu_utils_helper.h",
-    ],
-    hdrs = [
-        "profile_utils/cpu_utils.h",
-    ],
-    copts = tf_copts(),
-    deps = [
-        ":logging",
-        ":macros",
-        ":types",
-        "@com_google_absl//absl/base",
-    ],
-    alwayslink = 1,
+    actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
 )
 
 filegroup(
@@ -1021,13 +1002,13 @@
         "mutex_test.cc",
         "net_test.cc",
         "port_test.cc",
-        "profile_utils/cpu_utils_test.cc",
         "scanner_test.cc",
         "str_util_test.cc",
         "strcat_test.cc",
         "stringpiece_test.cc",
         "stringprintf_test.cc",
         "vmodule_benchmark_test.cc",
+        "//tensorflow/core/platform/profile_utils:cpu_utils_test.cc",
     ],
     create_named_test_suite = True,
     deps = [
@@ -1370,10 +1351,10 @@
         "numa.h",
         "path.h",
         "prefetch.h",
-        "profile_utils/android_armv7a_cpu_utils_helper.h",
-        "profile_utils/clock_cycle_profiler.h",
-        "profile_utils/cpu_utils.h",
-        "profile_utils/i_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:android_armv7a_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:clock_cycle_profiler.h",
+        "//tensorflow/core/platform/profile_utils:cpu_utils.h",
+        "//tensorflow/core/platform/profile_utils:i_cpu_utils_helper.h",
         "protobuf.h",
         "ram_file_system.h",
         "random.h",
@@ -1661,11 +1642,11 @@
         "platform_strings.cc",
         "platform_strings.h",
         "platform_strings_computed.h",
-        "profile_utils/android_armv7a_cpu_utils_helper.cc",
-        "profile_utils/android_armv7a_cpu_utils_helper.h",
-        "profile_utils/cpu_utils.cc",
-        "profile_utils/cpu_utils.h",
-        "profile_utils/i_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:android_armv7a_cpu_utils_helper.cc",
+        "//tensorflow/core/platform/profile_utils:android_armv7a_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:cpu_utils.cc",
+        "//tensorflow/core/platform/profile_utils:cpu_utils.h",
+        "//tensorflow/core/platform/profile_utils:i_cpu_utils_helper.h",
         "protobuf_internal.h",
         "random.cc",
         "random.h",
@@ -1683,7 +1664,6 @@
     srcs = glob(
         [
             "*.h",
-            "profile_utils/**/*.h",
         ],
         exclude = [
             "dynamic_annotations.h",
@@ -1700,16 +1680,18 @@
             "**/rocm.h",
             "**/stream_executor.h",
         ],
-    ),
+    ) + [
+        "//tensorflow/core/platform/profile_utils:android_armv7a_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:cpu_utils.h",
+        "//tensorflow/core/platform/profile_utils:i_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:clock_cycle_profiler.h",
+    ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
 
-filegroup(
+alias(
     name = "legacy_lib_internal_srcs",
-    srcs = [
-        "profile_utils/android_armv7a_cpu_utils_helper.cc",
-        "profile_utils/clock_cycle_profiler.cc",
-    ],
+    actual = "//tensorflow/core/platform/profile_utils:legacy_lib_internal_srcs",
     visibility = ["//tensorflow/core:__pkg__"],
 )
 
@@ -1728,7 +1710,6 @@
         "//learning/brain/google/xla/tests:__pkg__",
         "//learning/brain/tfrc/runtime/tpu_driver:__subpackages__",
         "//nlp/deleuze:__pkg__",
-        "//nlp/projects/minmt:__pkg__",
         "//tensorflow:__subpackages__",
     ],
 )
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index babf249..17af900 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -177,7 +177,7 @@
   BIO_free_all(bio);
 
   // Now check the content of the header and the claim.
-  int dot = header_dot_claim.find_last_of(".");
+  int dot = header_dot_claim.find_last_of('.');
   string header_encoded = header_dot_claim.substr(0, dot);
   string claim_encoded = header_dot_claim.substr(dot + 1);
 
diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD
index 09ce2f2..7c1148b 100644
--- a/tensorflow/core/platform/default/BUILD
+++ b/tensorflow/core/platform/default/BUILD
@@ -283,8 +283,8 @@
         "//tensorflow/core/platform:init_main.h",
         "//tensorflow/core/platform:mem.h",
         "//tensorflow/core/platform:numa.h",
-        "//tensorflow/core/platform:profile_utils/cpu_utils.h",
         "//tensorflow/core/platform:snappy.h",
+        "//tensorflow/core/platform/profile_utils:cpu_utils.h",
     ],
     copts = tf_copts(),
     defines = ["TF_USE_SNAPPY"] + select({
@@ -292,6 +292,7 @@
         "//tensorflow:with_numa_support": ["TENSORFLOW_USE_NUMA"],
         "//conditions:default": [],
     }),
+    features = ["-layering_check"],
     tags = [
         "manual",
         "no_oss",
@@ -546,8 +547,8 @@
         "resource.cc",
         "stacktrace.h",
         "tracing_impl.h",
-        "//tensorflow/core/platform:profile_utils/cpu_utils.h",
-        "//tensorflow/core/platform:profile_utils/i_cpu_utils_helper.h",
+        "//tensorflow/core/platform/profile_utils:cpu_utils.h",
+        "//tensorflow/core/platform/profile_utils:i_cpu_utils_helper.h",
     ],
     visibility = ["//tensorflow/core/platform:__pkg__"],
 )
diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc
index b19c263..e0d2e5b 100644
--- a/tensorflow/core/platform/default/logging.cc
+++ b/tensorflow/core/platform/default/logging.cc
@@ -485,7 +485,8 @@
   __android_log_write(android_log_level, "native", ss.str().c_str());
 
   // Also log to stderr (for standalone Android apps).
-  std::cerr << "native : " << ss.str() << std::endl;
+  // Don't use 'std::cerr' since it crashes on Android.
+  fprintf(stderr, "native : %s\n", ss.str().c_str());
 
   // Android logging at level FATAL does not terminate execution, so abort()
   // is still required to stop the program.
diff --git a/tensorflow/core/platform/default/posix_file_system.cc b/tensorflow/core/platform/default/posix_file_system.cc
index 18fea3f..29f9bba 100644
--- a/tensorflow/core/platform/default/posix_file_system.cc
+++ b/tensorflow/core/platform/default/posix_file_system.cc
@@ -19,6 +19,7 @@
 #include <stdint.h>
 #include <stdio.h>
 #include <sys/mman.h>
+
 #if defined(__linux__)
 #include <sys/sendfile.h>
 #endif
@@ -31,6 +32,7 @@
 #include "tensorflow/core/platform/default/posix_file_system.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/error.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system_helper.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/status.h"
@@ -92,6 +94,34 @@
     *result = StringPiece(scratch, dst - scratch);
     return s;
   }
+
+#if defined(TF_CORD_SUPPORT)
+  Status Read(uint64 offset, size_t n, absl::Cord* cord) const override {
+    if (n == 0) {
+      return Status::OK();
+    }
+    if (n < 0) {
+      return errors::InvalidArgument(
+          "Attempting to read ", n,
+          " bytes. You cannot read a negative number of bytes.");
+    }
+
+    char* scratch = new char[n];
+    if (scratch == nullptr) {
+      return errors::ResourceExhausted("Unable to allocate ", n,
+                                       " bytes for file reading.");
+    }
+
+    StringPiece tmp;
+    Status s = Read(offset, n, &tmp, scratch);
+
+    absl::Cord tmp_cord = absl::MakeCordFromExternal(
+        absl::string_view(static_cast<char*>(scratch), tmp.size()),
+        [scratch](absl::string_view) { delete[] scratch; });
+    cord->Append(tmp_cord);
+    return s;
+  }
+#endif
 };
 
 class PosixWritableFile : public WritableFile {
@@ -118,6 +148,19 @@
     return Status::OK();
   }
 
+#if defined(TF_CORD_SUPPORT)
+  // \brief Append 'cord' to the file.
+  Status Append(const absl::Cord& cord) override {
+    for (const auto& chunk : cord.Chunks()) {
+      size_t r = fwrite(chunk.data(), 1, chunk.size(), file_);
+      if (r != chunk.size()) {
+        return IOError(filename_, errno);
+      }
+    }
+    return Status::OK();
+  }
+#endif
+
   Status Close() override {
     if (file_ == nullptr) {
       return IOError(filename_, EBADF);
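
The Cord overload of Read above avoids copying: the freshly allocated scratch buffer is handed to the Cord through absl::MakeCordFromExternal, and the releaser deletes it once the Cord drops its last reference. A minimal standalone sketch of that ownership pattern (not part of the patch; assumes Abseil is available):

    #include <cstddef>
    #include <cstring>

    #include "absl/strings/cord.h"
    #include "absl/strings/string_view.h"

    // Wrap a heap buffer in a Cord without copying; the releaser lambda frees
    // the buffer when the Cord no longer references it.
    absl::Cord WrapHeapBuffer(const char* src, size_t n) {
      char* buf = new char[n];
      std::memcpy(buf, src, n);
      return absl::MakeCordFromExternal(
          absl::string_view(buf, n),
          [buf](absl::string_view) { delete[] buf; });
    }
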
diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc
index 7bd34cd..f8ce2a9 100644
--- a/tensorflow/core/platform/file_system_helper.cc
+++ b/tensorflow/core/platform/file_system_helper.cc
@@ -52,115 +52,217 @@
 #endif
 }
 
+// The characters that mark the start of the globbing portion of a pattern:
+static const char kGlobbingChars[] = "*?[\\";
+
+static inline bool IsGlobbingPattern(const std::string& pattern) {
+  return (pattern.find_first_of(kGlobbingChars) != std::string::npos);
+}
+
+// Make sure that the first entry in `dirs` during glob expansion does not
+// contain a glob pattern. This is to prevent a corner-case bug where
+// `<pattern>` would be treated differently than `./<pattern>`.
+static std::string PatchPattern(const std::string& pattern) {
+  const std::string fixed_prefix =
+      pattern.substr(0, pattern.find_first_of(kGlobbingChars));
+
+  // Patching is needed when there is no directory part in `fixed_prefix`.
+  if (io::Dirname(fixed_prefix).empty()) {
+    return io::JoinPath(".", pattern);
+  }
+
+  // No patching needed
+  return pattern;
+}
+
+static std::vector<std::string> AllDirectoryPrefixes(const std::string& d) {
+  std::vector<std::string> dirs;
+  const std::string patched = PatchPattern(d);
+  StringPiece dir(patched);
+
+  // If the pattern ends with a `/` (or `\\` on Windows), we need to strip it;
+  // otherwise we would have one additional matching step and the result set
+  // would be empty.
+  bool is_directory = d[d.size() - 1] == '/';
+#ifdef PLATFORM_WINDOWS
+  is_directory = is_directory || (d[d.size() - 1] == '\\');
+#endif
+  if (is_directory) {
+    dir = io::Dirname(dir);
+  }
+
+  while (!dir.empty()) {
+    dirs.emplace_back(dir);
+    StringPiece new_dir(io::Dirname(dir));
+    // io::Dirname("/") returns "/" so we need to break the loop.
+    // On Windows, io::Dirname("C:\\") would return "C:\\", so we check for
+    // identity of the result instead of checking for dir[0] == `/`.
+    if (dir == new_dir) break;
+    dir = new_dir;
+  }
+
+  // Order the array from parent to ancestor (reverse order).
+  std::reverse(dirs.begin(), dirs.end());
+
+  return dirs;
+}
+
+static inline int GetFirstGlobbingEntry(const std::vector<std::string>& dirs) {
+  int i = 0;
+  for (const auto& d : dirs) {
+    if (IsGlobbingPattern(d)) {
+      break;
+    }
+    i++;
+  }
+  return i;
+}
+
 }  // namespace
 
 Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern,
                         std::vector<string>* results) {
+  // Check that `fs`, `env` and `results` are non-null.
+  if (fs == nullptr || env == nullptr || results == nullptr) {
+    return Status(tensorflow::error::INVALID_ARGUMENT,
+                  "Filesystem calls GetMatchingPaths with nullptr arguments");
+  }
+
+  // By design, we don't match anything on empty pattern
   results->clear();
   if (pattern.empty()) {
     return Status::OK();
   }
 
-  string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\"));
-  string eval_pattern = pattern;
-  string dir(io::Dirname(fixed_prefix));
-  // If dir is empty then we need to fix up fixed_prefix and eval_pattern to
-  // include . as the top level directory.
-  if (dir.empty()) {
-    dir = ".";
-    fixed_prefix = io::JoinPath(dir, fixed_prefix);
-    eval_pattern = io::JoinPath(dir, eval_pattern);
+  // The pattern can contain globbing characters at multiple levels, e.g.:
+  //
+  //   foo/ba?/baz/f*r
+  //
+  // To match the full pattern, we must match every prefix subpattern and then
+  // operate on the children for each match. Thus, we separate all subpatterns
+  // in the `dirs` vector below.
+  std::vector<std::string> dirs = AllDirectoryPrefixes(pattern);
+
+  // We can have patterns that have several parents where no globbing is being
+  // done, for example, `foo/bar/baz/*`. We don't need to expand the directories
+  // which don't contain the globbing characters.
+  int matching_index = GetFirstGlobbingEntry(dirs);
+
+  // If we don't have globbing characters in the pattern then it specifies a
+  // path in the filesystem. We add it to the result set if it exists.
+  if (matching_index == dirs.size()) {
+    if (fs->FileExists(pattern).ok()) {
+      results->emplace_back(pattern);
+    }
+    return Status::OK();
   }
-  bool is_directory = pattern[pattern.size() - 1] == '/';
-#ifdef PLATFORM_WINDOWS
-  is_directory = is_directory || pattern[pattern.size() - 1] == '\\';
-#endif
-  std::vector<string> dirs;
-  if (!is_directory) {
-    dirs.emplace_back(eval_pattern);
-  }
-  StringPiece tmp_dir(io::Dirname(eval_pattern));
-  while (tmp_dir.size() > dir.size()) {
-    dirs.emplace_back(string(tmp_dir));
-    tmp_dir = io::Dirname(tmp_dir);
-  }
-  dirs.emplace_back(dir);
-  std::reverse(dirs.begin(), dirs.end());
-  // Setup a parallel BFS to explore everything under dir.
-  std::deque<std::pair<string, int>> dir_q;
-  std::deque<std::pair<string, int>> next_dir_q;
-  dir_q.emplace_back(std::make_pair(dirs[0], 0));
-  Status ret;  // Status to return.
-  mutex results_mutex;
-  condition_variable results_cond;
-  mutex next_que_mutex;
-  condition_variable next_que_cond;
-  while (!dir_q.empty()) {
-    next_dir_q.clear();
-    std::vector<Status> new_rets(dir_q.size());
-    auto handle_level = [fs, &results, &dir_q, &next_dir_q, &new_rets,
-                         &is_directory, &dirs, &results_mutex, &results_cond,
-                         &next_que_mutex, &next_que_cond](int i) {
-      string current_dir = dir_q.at(i).first;
-      int dir_index = dir_q.at(i).second;
-      dir_index++;
-      std::vector<string> children;
-      Status s = fs->GetChildren(current_dir, &children);
-      // In case PERMISSION_DENIED is encountered, we bail here.
+
+  // To expand the globbing, we do a BFS from `dirs[matching_index-1]`.
+  // At every step, we work on a pair `{dir, ix}` such that `dir` is a real
+  // directory, `ix < dirs.size() - 1` and `dirs[ix+1]` is a globbing pattern.
+  // To expand the pattern, we select from all the children of `dir` only those
+  // that match against `dirs[ix+1]`.
+  // If there are more entries in `dirs` after `dirs[ix+1]`, this means we have
+  // more patterns to match. So, we add to the queue only those children that
+  // are also directories, paired with `ix+1`.
+  // If there are no more entries in `dirs`, we return all children as part of
+  // the answer.
+  // Since we can run into a combinatorial explosion (e.g., pattern `/*/*/*`),
+  // we process the queue in parallel. Each worker takes elements from
+  // `expand_queue` and adds new ones to `next_expand_queue`, after which the
+  // two queues are swapped (similar to double buffering algorithms).
+  // PRECONDITION: `IsGlobbingPattern(dirs[0]) == false`
+  // PRECONDITION: `matching_index > 0`
+  // INVARIANT: If `{d, ix}` is in queue, then `d` and `dirs[ix]` are at the
+  //            same level in the filesystem tree.
+  // INVARIANT: If `{d, _}` is in queue, then `IsGlobbingPattern(d) == false`.
+  // INVARIANT: If `{d, _}` is in queue, then `d` is a real directory.
+  // INVARIANT: If `{_, ix}` is in queue, then `ix < dirs.size() - 1`.
+  // INVARIANT: If `{_, ix}` is in queue, `IsGlobbingPattern(dirs[ix + 1])`.
+  std::deque<std::pair<string, int>> expand_queue;
+  std::deque<std::pair<string, int>> next_expand_queue;
+  expand_queue.emplace_back(dirs[matching_index - 1], matching_index - 1);
+
+  // Adding to `results` or `next_expand_queue` needs to be protected by
+  // mutexes since multiple threads write to them.
+  mutex result_mutex;
+  mutex queue_mutex;
+
+  while (!expand_queue.empty()) {
+    next_expand_queue.clear();
+
+    // The work item for every item in `expand_queue`. Since all entries at a
+    // level can be expanded independently, we process them in parallel.
+    auto handle_level = [&fs, &results, &dirs, &expand_queue,
+                         &next_expand_queue, &result_mutex,
+                         &queue_mutex](int i) {
+      // See invariants above, all of these are valid accesses.
+      const auto& queue_item = expand_queue.at(i);
+      const std::string& parent = queue_item.first;
+      const int index = queue_item.second + 1;
+      const std::string& match_pattern = dirs[index];
+
+      // Get all children of `parent`. If this fails, return early.
+      std::vector<std::string> children;
+      Status s = fs->GetChildren(parent, &children);
       if (s.code() == tensorflow::error::PERMISSION_DENIED) {
         return;
       }
-      new_rets[i] = s;
-      if (children.empty()) return;
 
-      // children_dir_status holds is_dir status for children. It can have three
-      // possible values: OK for true; FAILED_PRECONDITION for false; CANCELLED
-      // if we don't calculate IsDirectory (we might do that because there isn't
-      // any point in exploring that child path).
-      std::vector<Status> children_dir_status;
+      // Also return early if we don't have any children.
+      if (children.empty()) {
+        return;
+      }
 
-      // This IsDirectory call can be expensive for some FS. Parallelizing it.
-      children_dir_status.resize(children.size());
-      auto handle_children = [fs, &current_dir, &children, &dirs, dir_index,
-                              is_directory, &children_dir_status](int j) {
-        const string child_path = io::JoinPath(current_dir, children[j]);
-        if (!fs->Match(child_path, dirs[dir_index])) {
-          children_dir_status[j] =
+      // Since `parent` can have a very large number of children and
+      // `IsDirectory` is expensive on some filesystems, we process the
+      // children in parallel. The pattern match is also checked in parallel,
+      // for speed. The status of the match and of `IsDirectory` is stored in
+      // the `children_status` array, one element per child.
+      std::vector<Status> children_status(children.size());
+      auto handle_children = [&fs, &match_pattern, &parent, &children,
+                              &children_status](int j) {
+        const std::string path = io::JoinPath(parent, children[j]);
+        if (!fs->Match(path, match_pattern)) {
+          children_status[j] =
               Status(tensorflow::error::CANCELLED, "Operation not needed");
-        } else if (dir_index != dirs.size() - 1) {
-          children_dir_status[j] = fs->IsDirectory(child_path);
         } else {
-          children_dir_status[j] =
-              is_directory ? fs->IsDirectory(child_path) : Status::OK();
+          children_status[j] = fs->IsDirectory(path);
         }
       };
       ForEach(0, children.size(), handle_children);
 
-      for (size_t j = 0; j < children.size(); ++j) {
-        const string child_path = io::JoinPath(current_dir, children[j]);
-        // If the IsDirectory call was cancelled we bail.
-        if (children_dir_status[j].code() == tensorflow::error::CANCELLED) {
+      // At this point, pairing `children` with `children_status` tells us
+      // whether a child:
+      //   * does not match the pattern
+      //   * matches the pattern and is a directory
+      //   * matches the pattern and is not a directory
+      // The first case is ignored entirely.
+      // If we matched the last pattern (`index == dirs.size() - 1`), then all
+      // remaining children get added to the result.
+      // Otherwise, only the directories get added to the next queue.
+      for (size_t j = 0; j < children.size(); j++) {
+        if (children_status[j].code() == tensorflow::error::CANCELLED) {
           continue;
         }
-        if (children_dir_status[j].ok()) {
-          if (dir_index != dirs.size() - 1) {
-            mutex_lock lk(next_que_mutex);
-            next_dir_q.emplace_back(std::make_pair(child_path, dir_index));
-            next_que_cond.notify_one();
-          } else {
-            mutex_lock lk(results_mutex);
-            results->emplace_back(child_path);
-            results_cond.notify_one();
-          }
+
+        const std::string path = io::JoinPath(parent, children[j]);
+        if (index == dirs.size() - 1) {
+          mutex_lock l(result_mutex);
+          results->emplace_back(path);
+        } else if (children_status[j].ok()) {
+          mutex_lock l(queue_mutex);
+          next_expand_queue.emplace_back(path, index);
         }
       }
     };
-    ForEach(0, dir_q.size(), handle_level);
+    ForEach(0, expand_queue.size(), handle_level);
 
-    ret.Update(new_rets[dir_q.size() - 1]);
-    std::swap(dir_q, next_dir_q);
+    // After evaluating one level, swap the "buffers"
+    std::swap(expand_queue, next_expand_queue);
   }
-  return ret;
+
+  return Status::OK();
 }
 
 }  // namespace internal
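
The rewritten GetMatchingPaths above leans on three helpers defined elsewhere in this file (AllDirectoryPrefixes, GetFirstGlobbingEntry, IsGlobbingPattern) whose definitions are not shown in this hunk. The following is a minimal sketch, under the assumption that they behave as the comments describe, of what they would compute for the example pattern `foo/ba?/baz/f*r`:

#include <string>
#include <vector>

// Sketch only: splits a pattern into all of its directory prefixes, e.g.
// "foo/ba?/baz/f*r" -> {"foo", "foo/ba?", "foo/ba?/baz", "foo/ba?/baz/f*r"}.
std::vector<std::string> AllDirectoryPrefixesSketch(const std::string& pattern) {
  std::vector<std::string> dirs;
  for (size_t pos = pattern.find('/'); pos != std::string::npos;
       pos = pattern.find('/', pos + 1)) {
    dirs.push_back(pattern.substr(0, pos));
  }
  dirs.push_back(pattern);
  return dirs;
}

// Sketch only: a component is a globbing pattern if it contains *, ?, [ or \.
bool IsGlobbingPatternSketch(const std::string& dir) {
  return dir.find_first_of("*?[\\") != std::string::npos;
}

// Sketch only: index of the first globbing component, or dirs.size() if the
// pattern is a literal path. For the example above this returns 1 ("foo/ba?"),
// so the BFS starts from the literal directory dirs[0] == "foo".
int GetFirstGlobbingEntrySketch(const std::vector<std::string>& dirs) {
  for (int i = 0; i < static_cast<int>(dirs.size()); ++i) {
    if (IsGlobbingPatternSketch(dirs[i])) return i;
  }
  return static_cast<int>(dirs.size());
}
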
diff --git a/tensorflow/core/platform/profile_utils/BUILD b/tensorflow/core/platform/profile_utils/BUILD
new file mode 100644
index 0000000..5d900e3
--- /dev/null
+++ b/tensorflow/core/platform/profile_utils/BUILD
@@ -0,0 +1,60 @@
+# Description:
+# profile_utils targets.
+
+load("//tensorflow:tensorflow.bzl", "filegroup")
+load(
+    "//tensorflow/core/platform:rules_cc.bzl",
+    "cc_library",
+)
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_copts",  # @unused
+)
+
+package(
+    default_visibility = [
+        "//tensorflow/core:__pkg__",
+        "//tensorflow/core/default:__pkg__",
+        "//tensorflow/core/platform:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(srcs = [
+    "android_armv7a_cpu_utils_helper.cc",
+    "android_armv7a_cpu_utils_helper.h",
+    "clock_cycle_profiler.h",
+    "cpu_utils.cc",
+    "cpu_utils.h",
+    "cpu_utils_test.cc",
+    "i_cpu_utils_helper.h",
+])
+
+filegroup(
+    name = "legacy_lib_internal_srcs",
+    srcs = [
+        "android_armv7a_cpu_utils_helper.cc",
+        "clock_cycle_profiler.cc",
+    ],
+    visibility = ["//tensorflow/core/platform:__pkg__"],
+)
+
+cc_library(
+    name = "profile_utils_cpu_utils",
+    srcs = [
+        "android_armv7a_cpu_utils_helper.h",
+        "cpu_utils.cc",
+        "i_cpu_utils_helper.h",
+    ],
+    hdrs = [
+        "cpu_utils.h",
+    ],
+    copts = tf_copts(),
+    deps = [
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+        "@com_google_absl//absl/base",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 475f879..519ff72 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -147,6 +147,34 @@
     *result = StringPiece(scratch, dst - scratch);
     return s;
   }
+
+#if defined(TF_CORD_SUPPORT)
+  Status Read(uint64 offset, size_t n, absl::Cord* cord) const override {
+    if (n == 0) {
+      return Status::OK();
+    }
+    if (n < 0) {
+      return errors::InvalidArgument(
+          "Attempting to read ", n,
+          " bytes. You cannot read a negative number of bytes.");
+    }
+
+    char* scratch = new char[n];
+    if (scratch == nullptr) {
+      return errors::ResourceExhausted("Unable to allocate ", n,
+                                       " bytes for file reading.");
+    }
+
+    StringPiece tmp;
+    Status s = Read(offset, n, &tmp, scratch);
+
+    absl::Cord tmp_cord = absl::MakeCordFromExternal(
+        absl::string_view(static_cast<char*>(scratch), tmp.size()),
+        [scratch](absl::string_view) { delete[] scratch; });
+    cord->Append(tmp_cord);
+    return s;
+  }
+#endif
 };
 
 class WindowsWritableFile : public WritableFile {
@@ -177,6 +205,24 @@
     return Status::OK();
   }
 
+#if defined(TF_CORD_SUPPORT)
+  // \brief Append the chunks of 'cord' to the file.
+  Status Append(const absl::Cord& cord) override {
+    for (const auto& chunk : cord.Chunks()) {
+      DWORD bytes_written = 0;
+      DWORD data_size = static_cast<DWORD>(chunk.size());
+      BOOL write_result =
+          ::WriteFile(hfile_, chunk.data(), data_size, &bytes_written, NULL);
+      if (FALSE == write_result) {
+        return IOErrorFromWindowsError("Failed to WriteFile: " + filename_);
+      }
+
+      assert(size_t(bytes_written) == chunk.size());
+    }
+    return Status::OK();
+  }
+#endif
+
   Status Tell(int64* position) override {
     Status result = Flush();
     if (!result.ok()) {
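
The two Windows additions above follow the usual absl::Cord patterns: the read path wraps the freshly allocated scratch buffer in an external cord so the bytes are not copied a second time, and the write path walks cord.Chunks() and issues one WriteFile per chunk. Below is a minimal, self-contained sketch of the read-side pattern, with a hypothetical ReadBytes standing in for the platform call:

#include <cstring>
#include <iostream>

#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"

// Hypothetical stand-in for the underlying read call; here it just fills the
// buffer with 'x' so the sketch is runnable.
size_t ReadBytes(char* dst, size_t n) {
  std::memset(dst, 'x', n);
  return n;
}

absl::Cord ReadIntoCord(size_t n) {
  char* scratch = new char[n];
  size_t bytes_read = ReadBytes(scratch, n);
  // Ownership of `scratch` moves into the Cord; the releaser frees it once the
  // last chunk referencing the buffer is gone.
  return absl::MakeCordFromExternal(
      absl::string_view(scratch, bytes_read),
      [scratch](absl::string_view) { delete[] scratch; });
}

int main() {
  absl::Cord cord = ReadIntoCord(16);
  std::cout << cord.size() << "\n";  // prints 16
}
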
diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
index 7f9111d..0144e76 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
@@ -92,7 +92,7 @@
 dynamic_shared:16384
 grid:2,1,1
 block:32,1,1
-occ_pct:1.0)MULTI";
+occ_pct:100)MULTI";
 
   XSpace space;
   XPlaneBuilder device_plane(
diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc
index 4fe3ed5..accbe4e 100644
--- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc
+++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc
@@ -21,37 +21,28 @@
 
 namespace tensorflow {
 namespace profiler {
+namespace {
 
-void MergeHostPlanes(XSpace* space) {
-  const XPlane* cupti_driver_api_plane =
-      FindPlaneWithName(*space, kCuptiDriverApiPlaneName);
-  const XPlane* python_tracer_plane =
-      FindPlaneWithName(*space, kPythonTracerPlaneName);
-  if (cupti_driver_api_plane || python_tracer_plane) {
-    XPlane* host_plane =
-        FindOrAddMutablePlaneWithName(space, kHostThreadsPlaneName);
-    if (cupti_driver_api_plane) {
-      MergePlanes(*cupti_driver_api_plane, host_plane);
-    }
-    if (python_tracer_plane) {
-      MergePlanes(*python_tracer_plane, host_plane);
-    }
-    SortXLinesBy(host_plane, XLinesComparatorByName());
-    if (cupti_driver_api_plane) {
-      RemovePlane(space, cupti_driver_api_plane);
-    }
-    if (python_tracer_plane) {
-      RemovePlane(space, python_tracer_plane);
-    }
+// Merges XPlanes generated by TraceMe, CUPTI API trace and Python tracer.
+void MergeHostPlanesAndSortLines(XSpace* space) {
+  XPlane* host_plane =
+      FindOrAddMutablePlaneWithName(space, kHostThreadsPlaneName);
+  std::vector<const XPlane*> additional_host_planes = FindPlanesWithNames(
+      *space, {kCuptiDriverApiPlaneName, kPythonTracerPlaneName});
+  if (!additional_host_planes.empty()) {
+    MergePlanes(additional_host_planes, host_plane);
+    RemovePlanes(space, additional_host_planes);
   }
+  SortXLinesBy(host_plane, XLinesComparatorByName());
 }
 
+}  // namespace
+
 void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns) {
   VLOG(3) << "Post processing local profiler XSpace.";
  // Post processing the collected XSpace without holding the profiler lock.
-  // 1. Merge plane of host events with plane of CUPTI driver api.
-  MergeHostPlanes(space);
-
+  // 1. Merge all host planes and sort lines by name.
+  MergeHostPlanesAndSortLines(space);
   // 2. Normalize all timestamps by shifting timeline to profiling start time.
  // NOTE: this has to be done before sorting XSpace due to timestamp overflow.
   NormalizeTimestamps(space, start_time_ns);
diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.h b/tensorflow/core/profiler/convert/post_process_single_host_xplane.h
index 70c6785..31ebe28 100644
--- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.h
+++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.h
@@ -21,9 +21,6 @@
 namespace tensorflow {
 namespace profiler {
 
-// Merges XPlanes generated by TraceMe, CUPTI API trace and Python tracer.
-void MergeHostPlanes(XSpace* space);
-
 // Post process XSpaces collected locally from multiple profilers.
 void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns);
 
diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc
index 700f057..568c485 100644
--- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc
@@ -46,7 +46,7 @@
 dynamic_shared:0
 grid:1,1,1
 block:1,1,1
-occ_pct:0.5)MULTI"},
+occ_pct:50.0)MULTI"},
                 {StatType::kEquation, ""}});
 
   CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_middle",
@@ -57,7 +57,7 @@
 dynamic_shared:16384
 grid:2,1,1
 block:32,1,1
-occ_pct=0.13)MULTI"},
+occ_pct=13.0)MULTI"},
                 {StatType::kEquation, ""}});
 
   CreateXEvent(&device_trace_builder, &line_builder,
@@ -69,7 +69,7 @@
 dynamic_shared:16384
 grid:3,1,1
 block:64,1,1
-occ_pct:0.25)MULTI"},
+occ_pct:25.0)MULTI"},
                 {StatType::kEquation, ""}});
 
   KernelReportMap reports;
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
index e6b84b6..0e1ceea 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
@@ -29,7 +29,6 @@
 #include "tensorflow/core/profiler/convert/xplane_to_step_events.h"
 #include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
 #include "tensorflow/core/profiler/protobuf/diagnostics.pb.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
 #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
@@ -48,7 +47,6 @@
 
 namespace tensorflow {
 namespace profiler {
-namespace {
 
 DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) {
   DeviceCapabilities cap;
@@ -79,8 +77,6 @@
   return cap;
 }
 
-}  // namespace
-
 PerfEnv MakePerfEnv(double peak_tera_flops_per_second,
                     double peak_hbm_bw_giga_bytes_per_second) {
   PerfEnv result;
@@ -164,6 +160,8 @@
                     op_stats.mutable_run_environment());
 
   KernelReportMap reports;
+  absl::string_view gpu_model = "";
+
   // TODO(b/161942993) parallelize XPlane processing per thread.
   for (const XPlane* device_trace : device_planes) {
     if (options.generate_op_metrics_db) {
@@ -174,6 +172,9 @@
           ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace);
       op_metrics_db_combiner.Combine(device_op_metrics_db);
     }
+    if (gpu_model.empty()) {
+      gpu_model = GpuModelName(GetDeviceCapFromXPlane(*device_trace));
+    }
     if (options.generate_step_db) {
       CombineStepEvents(ConvertDeviceTraceXPlaneToStepEvents(*device_trace),
                         &step_events);
@@ -184,6 +185,11 @@
     }
   }
 
+  if (!gpu_model.empty()) {
+    // Overwrites the device type with the more specific GPU model name.
+    op_stats.mutable_run_environment()->set_device_type(std::string(gpu_model));
+  }
+
   // Combine into reports.
   if (options.generate_kernel_stats_db) {
     CopyTopKDurationKernelReportsToDb(reports,
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h
index 178f8c2..d327cfe 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h
@@ -18,6 +18,7 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 
@@ -39,6 +40,9 @@
 void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space,
                                          OpStats* op_stats);
 
+// Extracts DeviceCapabilities from XPlane stats.
+DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane);
+
 // Populates PerfEnv.
 PerfEnv MakePerfEnv(double peak_tera_flops_per_second,
                     double peak_hbm_bw_giga_bytes_per_second);
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
index a61c22f..e21a0ca 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
@@ -87,7 +87,7 @@
   OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions());
   const RunEnvironment& run_env = op_stats.run_environment();
 
-  EXPECT_EQ("GPU", run_env.device_type());
+  EXPECT_EQ("Nvidia GPU", run_env.device_type());
   EXPECT_EQ(1, run_env.host_count());
   EXPECT_EQ(1, run_env.task_count());
   EXPECT_EQ(2, run_env.device_core_count());
diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD
index 40773c6..681670a 100644
--- a/tensorflow/core/profiler/internal/gpu/BUILD
+++ b/tensorflow/core/profiler/internal/gpu/BUILD
@@ -127,8 +127,10 @@
         ":cupti_collector",
         ":cupti_interface",
         ":cupti_utils",
+        ":nvtx_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core/profiler/internal/cpu:annotation_stack",
+        "//tensorflow/core/profiler/lib:scoped_annotation",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/container:node_hash_set",
@@ -137,6 +139,16 @@
 )
 
 tf_cuda_library(
+    name = "nvtx_utils",
+    srcs = if_cuda_is_configured_compat(["nvtx_utils.cc"]),
+    hdrs = if_cuda_is_configured_compat(["nvtx_utils.h"]),
+    copts = tf_profiler_copts() + tf_copts(),
+    deps = [
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cuda_library(
     name = "cupti_collector",
     srcs = if_cuda_is_configured_compat(["cupti_collector.cc"]),
     hdrs = if_cuda_is_configured_compat(["cupti_collector.h"]),
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_collector.cc b/tensorflow/core/profiler/internal/gpu/cupti_collector.cc
index 684df4f..42d97d1 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_collector.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_collector.cc
@@ -113,7 +113,7 @@
     }
 
     stats.occupancy_pct =
-        occ_result.activeBlocksPerMultiprocessor * params.block_size;
+        occ_result.activeBlocksPerMultiprocessor * params.block_size * 100;
     stats.occupancy_pct /= device_properties.maxThreadsPerMultiprocessor;
 
     status = cudaOccMaxPotentialOccupancyBlockSize(
@@ -160,6 +160,11 @@
                               GetStatTypeStr(StatType::kKernelAnnotation)),
                           *plane->GetOrCreateStatMetadata(event.annotation));
     }
+    if (!event.nvtx_range.empty()) {
+      xevent.AddStatValue(
+          *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kNVTXRange)),
+          *plane->GetOrCreateStatMetadata(event.nvtx_range));
+    }
     if (event.context_id != CuptiTracerEvent::kInvalidContextId) {
       xevent.AddStatValue(
           *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)),
@@ -184,7 +189,7 @@
       params.dynamic_smem_size = event.kernel_info.dynamic_shared_memory_usage;
 
       OccupancyStats& occ_stats = occupancy_cache[params];
-      if (occ_stats.occupancy_pct == 0) {
+      if (occ_stats.occupancy_pct == 0.0) {
         occ_stats = GetOccupancy(params);
       }
       xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
@@ -224,27 +229,25 @@
 
     std::vector<Annotation> annotation_stack =
         ParseAnnotationStack(event.annotation);
-    // If multiple metadata have the same key name, show the values from the top
-    // of the stack (innermost annotation). Concatenate the values from
-    // "hlo_op".
-    absl::flat_hash_set<absl::string_view> key_set;
-    std::vector<absl::string_view> hlo_op_names;
-    for (auto annotation = annotation_stack.rbegin();
-         annotation != annotation_stack.rend(); ++annotation) {
-      for (const Annotation::Metadata& metadata : annotation->metadata) {
-        if (metadata.key == "tf_op") {
-          continue;  // ignored, obtained from HLO proto via DebugInfoMap
-        } else if (key_set.insert(metadata.key).second) {
-          xevent.ParseAndAddStatValue(
-              *plane->GetOrCreateStatMetadata(metadata.key), metadata.value);
-        }
-      }
-    }
     if (!annotation_stack.empty()) {
       xevent.AddStatValue(
           *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)),
           *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name));
     }
+    // If multiple metadata have the same key name, show the values from the top
+    // of the stack (innermost annotation). Concatenate the values from
+    // "hlo_op".
+    absl::flat_hash_set<absl::string_view> key_set;
+
+    for (auto annotation = annotation_stack.rbegin();
+         annotation != annotation_stack.rend(); ++annotation) {
+      for (const Annotation::Metadata& metadata : annotation->metadata) {
+        if (key_set.insert(metadata.key).second) {
+          xevent.ParseAndAddStatValue(
+              *plane->GetOrCreateStatMetadata(metadata.key), metadata.value);
+        }
+      }
+    }
   }
 
   absl::optional<int> GetDeviceAttribute(CUdevice device,
@@ -549,8 +552,9 @@
 }  // namespace
 
 void AnnotationMap::Add(uint32 device_id, uint32 correlation_id,
-                        const std::string& annotation) {
-  if (annotation.empty()) return;
+                        const absl::string_view annotation,
+                        const absl::string_view nvtx_range) {
+  if (annotation.empty() && nvtx_range.empty()) return;
   VLOG(3) << "Add annotation: device_id: " << device_id
           << " correlation_id: " << correlation_id
           << " annotation: " << annotation;
@@ -558,20 +562,22 @@
   auto& per_device_map = per_device_map_[device_id];
   absl::MutexLock lock(&per_device_map.mutex);
   if (per_device_map.annotations.size() < max_size_) {
-    absl::string_view annotation_str =
-        *per_device_map.annotations.insert(annotation).first;
-    per_device_map.correlation_map.emplace(correlation_id, annotation_str);
+    AnnotationInfo info;
+    info.annotation = *per_device_map.annotations.emplace(annotation).first;
+    if (!nvtx_range.empty())
+      info.nvtx_range = *per_device_map.nvtx_ranges.emplace(nvtx_range).first;
+    per_device_map.correlation_map.emplace(correlation_id, info);
   }
 }
 
-absl::string_view AnnotationMap::LookUp(uint32 device_id,
-                                        uint32 correlation_id) {
-  if (device_id >= per_device_map_.size()) return absl::string_view();
+AnnotationMap::AnnotationInfo AnnotationMap::LookUp(uint32 device_id,
+                                                    uint32 correlation_id) {
+  if (device_id >= per_device_map_.size()) return AnnotationInfo();
   auto& per_device_map = per_device_map_[device_id];
   absl::MutexLock lock(&per_device_map.mutex);
   auto it = per_device_map.correlation_map.find(correlation_id);
   return it != per_device_map.correlation_map.end() ? it->second
-                                                    : absl::string_view();
+                                                    : AnnotationInfo();
 }
 
 // CuptiTraceCollectorImpl store the CuptiTracerEvents from CuptiTracer and
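
The occupancy change above switches occ_pct from a fraction to a percentage (hence the * 100 in GetOccupancy and the updated test strings such as occ_pct:50.0). A worked sketch of the formula, with assumed device numbers:

// Sketch of the computation in GetOccupancy above. The inputs below are
// assumed values: 8 active blocks per SM, 128 threads per block, and a device
// limit of 2048 threads per SM.
double OccupancyPct(int active_blocks_per_sm, int block_size,
                    int max_threads_per_sm) {
  double pct = static_cast<double>(active_blocks_per_sm) * block_size * 100;
  pct /= max_threads_per_sm;
  return pct;
}
// OccupancyPct(8, 128, 2048) == 50.0, i.e. the stat now reads "occ_pct:50.0"
// rather than the old fractional "occ_pct:0.5".
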
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_collector.h b/tensorflow/core/profiler/internal/gpu/cupti_collector.h
index ada6cec..d303c58 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_collector.h
+++ b/tensorflow/core/profiler/internal/gpu/cupti_collector.h
@@ -127,6 +127,7 @@
   // This points to strings in AnnotationMap, which should outlive the point
   // where serialization happens.
   absl::string_view annotation;
+  absl::string_view nvtx_range;
   uint64 start_time_ns = 0;
   uint64 end_time_ns = 0;
   uint32 device_id = 0;
@@ -156,11 +157,17 @@
 
 class AnnotationMap {
  public:
+  struct AnnotationInfo {
+    absl::string_view annotation;
+    absl::string_view nvtx_range;
+  };
+
   explicit AnnotationMap(uint64 max_size, uint32 num_gpus)
       : max_size_(max_size), per_device_map_(num_gpus) {}
   void Add(uint32 device_id, uint32 correlation_id,
-           const std::string& annotation);
-  absl::string_view LookUp(uint32 device_id, uint32 correlation_id);
+           const absl::string_view annotation,
+           const absl::string_view nvtx_range);
+  AnnotationInfo LookUp(uint32 device_id, uint32 correlation_id);
 
  private:
   struct PerDeviceAnnotationMap {
@@ -170,7 +177,8 @@
    // Annotations tend to be repetitive; use a hash_set to store the strings,
    // and use the reference to the string in the map.
     absl::node_hash_set<std::string> annotations;
-    absl::flat_hash_map<uint32, absl::string_view> correlation_map;
+    absl::node_hash_set<std::string> nvtx_ranges;
+    absl::flat_hash_map<uint32, AnnotationInfo> correlation_map;
   };
   const uint64 max_size_;
   absl::FixedArray<PerDeviceAnnotationMap> per_device_map_;
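
An illustrative fragment of the extended AnnotationMap API (the strings and ids below are made up, and the header above is assumed to be included):

AnnotationMap annotation_map(/*max_size=*/1024, /*num_gpus=*/2);

// The API callback side records both the TF annotation and, when present, the
// innermost NVTX range for a correlation id.
annotation_map.Add(/*device_id=*/0, /*correlation_id=*/42,
                   /*annotation=*/"train/conv1/Conv2D",
                   /*nvtx_range=*/"TensorRT inference");

// The activity/event side looks both strings up again by correlation id.
AnnotationMap::AnnotationInfo info =
    annotation_map.LookUp(/*device_id=*/0, /*correlation_id=*/42);
// info.annotation == "train/conv1/Conv2D"
// info.nvtx_range == "TensorRT inference"
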
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
index 51a04af..6d04aeb 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
@@ -19,6 +19,7 @@
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/node_hash_map.h"
 #include "absl/container/node_hash_set.h"
+#include "third_party/gpus/cuda/extras/CUPTI/include/generated_nvtx_meta.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
@@ -26,6 +27,8 @@
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/profiler/internal/cpu/annotation_stack.h"
+#include "tensorflow/core/profiler/internal/gpu/cupti_collector.h"
+#include "tensorflow/core/profiler/internal/gpu/nvtx_utils.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -418,8 +421,10 @@
   event.context_id = kernel->contextId;
   event.stream_id = kernel->streamId;
   event.correlation_id = kernel->correlationId;
-  event.annotation = collector->annotation_map()->LookUp(event.device_id,
-                                                         event.correlation_id);
+  AnnotationMap::AnnotationInfo info = collector->annotation_map()->LookUp(
+      event.device_id, event.correlation_id);
+  event.annotation = info.annotation;
+  event.nvtx_range = info.nvtx_range;
   event.kernel_info.registers_per_thread = kernel->registersPerThread;
   event.kernel_info.static_shared_memory_usage = kernel->staticSharedMemory;
   event.kernel_info.dynamic_shared_memory_usage = kernel->dynamicSharedMemory;
@@ -464,8 +469,9 @@
   event.context_id = memcpy->contextId;
   event.stream_id = memcpy->streamId;
   event.correlation_id = memcpy->correlationId;
-  event.annotation = collector->annotation_map()->LookUp(event.device_id,
-                                                         event.correlation_id);
+  AnnotationMap::AnnotationInfo info = collector->annotation_map()->LookUp(
+      event.device_id, event.correlation_id);
+  event.annotation = info.annotation;
   event.memcpy_info.kind = memcpy->copyKind;
   event.memcpy_info.num_bytes = memcpy->bytes;
   event.memcpy_info.destination = memcpy->deviceId;
@@ -488,8 +494,9 @@
   event.context_id = memcpy2->contextId;
   event.stream_id = memcpy2->streamId;
   event.correlation_id = memcpy2->correlationId;
-  event.annotation = collector->annotation_map()->LookUp(event.device_id,
-                                                         event.correlation_id);
+  AnnotationMap::AnnotationInfo info = collector->annotation_map()->LookUp(
+      event.device_id, event.correlation_id);
+  event.annotation = info.annotation;
   event.memcpy_info.kind = CUPTI_ACTIVITY_MEMCPY_KIND_PTOP;
   event.memcpy_info.num_bytes = memcpy2->bytes;
   event.memcpy_info.destination = memcpy2->dstDeviceId;
@@ -946,8 +953,9 @@
     event.context_id = stream_info.ctx_info->context_id;
     event.stream_id = stream_info.stream_id;
     event.correlation_id = record.correlation_id;
-    event.annotation =
-        annotation_map->LookUp(event.device_id, event.correlation_id);
+    AnnotationMap::AnnotationInfo info = collector_->annotation_map()->LookUp(
+        event.device_id, event.correlation_id);
+    event.annotation = info.annotation;
     event.kernel_info = record.details;
     collector_->AddEvent(std::move(event));
     return Status::OK();
@@ -974,8 +982,9 @@
     event.context_id = stream_info.ctx_info->context_id;
     event.stream_id = stream_info.stream_id;
     event.correlation_id = record.correlation_id;
-    event.annotation =
-        annotation_map->LookUp(event.device_id, event.correlation_id);
+    AnnotationMap::AnnotationInfo info = collector_->annotation_map()->LookUp(
+        event.device_id, event.correlation_id);
+    event.annotation = info.annotation;
     event.memcpy_info.num_bytes = record.size_bytes;
     // TODO: support MemcpyD2D where destination != source;
     event.memcpy_info.destination = ordinal_;
@@ -1063,7 +1072,7 @@
           // Because annotation are per device, therefore we need to populate
           // annotation for each device involved.
           collector_->annotation_map()->Add(*dev_id, cbdata->correlationId,
-                                            annotation);
+                                            annotation, "");
           record_indices.push_back(
               cuda_event_recorders_[*dev_id]->StartKernel<CUDA_LAUNCH_PARAMS>(
                   "CooperativeKernelMultiDevice", *context,
@@ -1425,6 +1434,11 @@
     RETURN_IF_CUPTI_ERROR(cupti_interface_->EnableDomain(
         1 /* ENABLE */, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API));
   }
+
+  if (option_->enable_nvtx_tracking) {
+    RETURN_IF_CUPTI_ERROR(cupti_interface_->EnableDomain(
+        1 /* ENABLE */, subscriber_, CUPTI_CB_DOMAIN_NVTX));
+  }
   return Status::OK();
 }
 
@@ -1443,6 +1457,11 @@
         0 /* DISABLE */, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API));
   }
 
+  if (option_->enable_nvtx_tracking) {
+    RETURN_IF_CUPTI_ERROR(cupti_interface_->EnableDomain(
+        0 /* DISABLE */, subscriber_, CUPTI_CB_DOMAIN_NVTX));
+  }
+
   VLOG(1) << "Disable subscriber";
   RETURN_IF_CUPTI_ERROR(cupti_interface_->Unsubscribe(subscriber_));
   return Status::OK();
@@ -1510,11 +1529,31 @@
   return 0;
 }
 
+Status CuptiTracer::HandleNVTXCallback(CUpti_CallbackId cbid,
+                                       const CUpti_CallbackData *cbdata) {
+  const CUpti_NvtxData *pdata =
+      reinterpret_cast<const CUpti_NvtxData *>(cbdata);
+  if (cbid == CUPTI_CBID_NVTX_nvtxDomainRangePushEx) {
+    const nvtxDomainRangePushEx_params *params =
+        reinterpret_cast<const nvtxDomainRangePushEx_params *>(
+            pdata->functionParams);
+    // TODO(profiler): The messageType is actually NVTX_MESSAGE_TYPE_REGISTERED
+    // (which is 3). However, we cannot seem to get the registered string from
+    // nvtxDomainRegisterStringA_params; reinterpreting the payload as ascii
+    // happens to work.
+    NVTXRangeTracker::EnterRange(params->core.eventAttrib->message.ascii);
+  } else if (cbid == CUPTI_CBID_NVTX_nvtxDomainRangePop) {
+    NVTXRangeTracker::ExitRange();
+  }
+  return Status::OK();
+}
+
 Status CuptiTracer::HandleCallback(CUpti_CallbackDomain domain,
                                    CUpti_CallbackId cbid,
                                    const CUpti_CallbackData *cbdata) {
   if (!api_tracing_enabled_) return Status::OK();  // already unsubscribed.
   if (!cupti_driver_api_hook_) return Status::OK();  // already unsubscribed.
+  if (domain == CUPTI_CB_DOMAIN_NVTX) return HandleNVTXCallback(cbid, cbdata);
   if (domain != CUPTI_CB_DOMAIN_DRIVER_API) return Status::OK();
   if (internalCuCall) return Status::OK();
 
@@ -1546,11 +1585,12 @@
         // we need to populate per device annotation map respectively.
         for (int i = 0; i < num_gpus_; ++i) {
           collector_->annotation_map()->Add(i, cbdata->correlationId,
-                                            annotation);
+                                            annotation, "");
         }
       } else {
+        absl::string_view nvtx_range = NVTXRangeTracker::CurrentRange();
         collector_->annotation_map()->Add(device_id, cbdata->correlationId,
-                                          annotation);
+                                          annotation, nvtx_range);
       }
     }
 
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
index 3f7a2d4..970c4f9 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
@@ -18,6 +18,7 @@
 
 #include "absl/types/optional.h"
 #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/status.h"
@@ -50,6 +51,8 @@
   bool cupti_finalize = false;
   // Whether to call cuCtxSynchronize for each device before Stop().
   bool sync_devices_before_stop = false;
+  // Whether to enable NVTX tracking; we need this for TensorRT tracking.
+  bool enable_nvtx_tracking = false;
 };
 
 class CuptiDriverApiHook {
@@ -111,6 +114,8 @@
   Status DisableActivityTracing();
   Status Finalize();
   void ConfigureActivityUnifiedMemoryCounter(bool enable);
+  Status HandleNVTXCallback(CUpti_CallbackId cbid,
+                            const CUpti_CallbackData* cbdata);
 
   int num_gpus_;
   absl::optional<CuptiTracerOptions> option_;
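
The new enable_nvtx_tracking flag is off by default; a caller that wants the TensorRT NVTX ranges would flip it when building the options. A fragment (how the options object is then handed to the tracer is unchanged and not shown in this hunk):

CuptiTracerOptions options;
options.enable_nvtx_tracking = true;  // subscribe to CUPTI_CB_DOMAIN_NVTX callbacks
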
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/profiler/internal/gpu/nvtx_utils.cc
similarity index 65%
copy from tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
copy to tensorflow/core/profiler/internal/gpu/nvtx_utils.cc
index 2dd4a8d..ace1533 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
+++ b/tensorflow/core/profiler/internal/gpu/nvtx_utils.cc
@@ -13,13 +13,18 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+#include "tensorflow/core/profiler/internal/gpu/nvtx_utils.h"
+
+#include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h"
+#include "tensorflow/core/platform/platform.h"
 
 namespace tensorflow {
+namespace profiler {
 
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+/*static*/ std::stack<std::string> &NVTXRangeTracker::GetRangeStack() {
+  static thread_local std::stack<std::string> range_stack;
+  return range_stack;
+}
 
+}  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/gpu/nvtx_utils.h b/tensorflow/core/profiler/internal/gpu/nvtx_utils.h
new file mode 100644
index 0000000..b9085fa
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/nvtx_utils.h
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_NVTX_UTILS_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_NVTX_UTILS_H_
+
+#include <stack>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace profiler {
+
+/***
+ * We have no intention of using NVTX in TensorFlow right now; we use this
+ * class to track NVTX instrumentation inside NVIDIA libraries (such as
+ * TensorRT). For now this bears a lot of resemblance to ScopedAnnotation. In
+ * the future, we will use TraceMe to keep track of trace context within a
+ * thread.
+ */
+class NVTXRangeTracker {
+ public:
+  static void EnterRange(const std::string& range) {
+    auto& range_stack = GetRangeStack();
+    range_stack.push(range);
+  }
+  static void ExitRange() {
+    auto& range_stack = GetRangeStack();
+    if (!range_stack.empty()) range_stack.pop();
+  }
+  static const absl::string_view CurrentRange() {
+    auto& range_stack = GetRangeStack();
+    if (!range_stack.empty()) return range_stack.top();
+    return "";
+  }
+
+ private:
+  static std::stack<std::string>& GetRangeStack();
+
+  TF_DISALLOW_COPY_AND_ASSIGN(NVTXRangeTracker);
+};
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_NVTX_UTILS_H_
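
A sketch of the push/pop discipline that the CUPTI NVTX callbacks in cupti_tracer.cc follow (the range name below is illustrative):

// nvtxDomainRangePushEx -> remember the range for this thread.
NVTXRangeTracker::EnterRange("TensorRT ExecutionContext::enqueue");

// While the range is open, kernel launches on this thread pick it up via
// CurrentRange() and store it in the AnnotationMap alongside the annotation.
absl::string_view range = NVTXRangeTracker::CurrentRange();

// nvtxDomainRangePop -> drop the innermost range again.
NVTXRangeTracker::ExitRange();
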
diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc
index eb3501b..eb5e4a2 100644
--- a/tensorflow/core/profiler/internal/tfprof_code.cc
+++ b/tensorflow/core/profiler/internal/tfprof_code.cc
@@ -114,7 +114,7 @@
     func_pb->set_id(function_table_.size());
 
     string file_base(io::Basename(file_path));
-    file_base = file_base.substr(0, file_base.find_last_of("."));
+    file_base = file_base.substr(0, file_base.find_last_of('.'));
     func_pb->set_name(
         string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
     func_pb->set_filename(string_table_->GetIndex(file_path));
diff --git a/tensorflow/core/profiler/internal/tfprof_scope.cc b/tensorflow/core/profiler/internal/tfprof_scope.cc
index ba0bcd9..62da2f3 100644
--- a/tensorflow/core/profiler/internal/tfprof_scope.cc
+++ b/tensorflow/core/profiler/internal/tfprof_scope.cc
@@ -48,13 +48,13 @@
     nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node));
   }
 
-  auto last_slash = name.find_last_of("/");
+  auto last_slash = name.find_last_of('/');
   while (last_slash != name.npos) {
     name = name.substr(0, last_slash);
     if (nodes_map_.find(name) == nodes_map_.end()) {
       CHECK(CreateParentNode(name));
     }
-    last_slash = name.find_last_of("/");
+    last_slash = name.find_last_of('/');
   }
 }
 
@@ -65,7 +65,7 @@
   // Found roots, which are nodes without "/".
   for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
     ScopeNode* node = it->second.get();
-    auto last_slash = node->name().find_last_of("/");
+    auto last_slash = node->name().find_last_of('/');
     if (last_slash == string::npos) {
       roots.push_back(node);
     } else {
diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc
index bd10522..ad90c47 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats.cc
+++ b/tensorflow/core/profiler/internal/tfprof_stats.cc
@@ -212,7 +212,7 @@
       int output_idx = 0;
       // input name format can be: "^node:src_output"
       // if not :src_output, then it's the first one (further verify?)
-      auto prefix_pos = node_input.find(":");
+      auto prefix_pos = node_input.find(':');
       if (prefix_pos != node_input.npos) {
         std::vector<string> input_parts = absl::StrSplit(node_input, ':');
         DCHECK(input_parts.size() == 2)
@@ -287,7 +287,7 @@
     for (const NodeExecStats& node_stat : dev_stat.node_stats()) {
       string name = node_stat.node_name();
       // Sometimes the node_name is suffixed with unnecessary information.
-      auto split_pos = node_stat.node_name().find(":");
+      auto split_pos = node_stat.node_name().find(':');
       if (split_pos != node_stat.node_name().npos) {
         name = node_stat.node_name().substr(0, split_pos);
       }
diff --git a/tensorflow/core/profiler/internal/tpu/BUILD b/tensorflow/core/profiler/internal/tpu/BUILD
new file mode 100644
index 0000000..e76e7e9
--- /dev/null
+++ b/tensorflow/core/profiler/internal/tpu/BUILD
@@ -0,0 +1,29 @@
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts")
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "tpu_tracer",
+    srcs = ["tpu_tracer.cc"],
+    copts = tf_profiler_copts(),
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
+        "//tensorflow/core/profiler/lib:profiler_factory",
+        "//tensorflow/core/profiler/lib:profiler_interface",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler/utils:time_utils",
+        "//tensorflow/core/profiler/utils:xplane_schema",
+        "//tensorflow/core/profiler/utils:xplane_utils",
+        "//tensorflow/core/tpu:tpu_api",
+        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:status_helper",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = True,
+)
diff --git a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
new file mode 100644
index 0000000..e4cf245
--- /dev/null
+++ b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/profiler_factory.h"
+#include "tensorflow/core/profiler/lib/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_ops_c_api.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace {
+
+// Tpu implementation of ProfilerInterface.
+//
+// Thread-safety: This class is go/thread-compatible.
+class TpuTracer : public ProfilerInterface {
+ public:
+  explicit TpuTracer();
+  ~TpuTracer() override;
+
+  Status Start() override;
+
+  Status Stop() override;
+
+  // Unsupported.
+  Status CollectData(RunMetadata* run_metadata) override;
+
+  Status CollectData(XSpace* space) override;
+
+ private:
+  TpuProfiler* tpu_profiler_;
+};
+
+TpuTracer::TpuTracer() {
+  tpu_profiler_ = tpu::OpsApiFn()->TpuProfiler_CreateFn();
+}
+
+TpuTracer::~TpuTracer() { tpu::OpsApiFn()->TpuProfiler_FreeFn(tpu_profiler_); }
+
+Status TpuTracer::Start() {
+  StatusHelper status;
+  tpu::OpsApiFn()->TpuProfiler_StartFn(tpu_profiler_, status.c_status);
+  if (!status.ok()) {
+    VLOG(1) << "Run Start failed.";
+    return status.status();
+  }
+  return Status::OK();
+}
+
+Status TpuTracer::Stop() {
+  StatusHelper status;
+  tpu::OpsApiFn()->TpuProfiler_StopFn(tpu_profiler_, status.c_status);
+  if (!status.ok()) {
+    VLOG(1) << "Run Stop failed.";
+    return status.status();
+  }
+  return Status::OK();
+}
+
+Status TpuTracer::CollectData(RunMetadata* run_metadata) {
+  // Unsupported
+  return Status::OK();
+}
+
+Status TpuTracer::CollectData(XSpace* space) {
+  StatusHelper status;
+  tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
+                                             space);
+  if (!status.ok()) {
+    VLOG(1) << "Run CollectData failed.";
+    return status.status();
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+// Not in anonymous namespace for testing purposes.
+std::unique_ptr<ProfilerInterface> CreateTpuTracer(
+    const ProfileOptions& options) {
+  if (options.device_type() != ProfileOptions::TPU &&
+      options.device_type() != ProfileOptions::UNSPECIFIED) {
+    return nullptr;
+  }
+  return absl::make_unique<TpuTracer>();
+}
+
+auto register_host_tracer_factory = [] {
+  RegisterProfilerFactory(&CreateTpuTracer);
+  return 0;
+}();
+
+}  // namespace profiler
+}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/lib/traceme_encode.h b/tensorflow/core/profiler/lib/traceme_encode.h
index 4dcd6ea..de1046c 100644
--- a/tensorflow/core/profiler/lib/traceme_encode.h
+++ b/tensorflow/core/profiler/lib/traceme_encode.h
@@ -133,12 +133,29 @@
     absl::string_view op_name, absl::string_view op_type) {
   return absl::StrCat(op_name, ":", op_type);
 }
+
+TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOp(const char* op_name,
+                                                        const char* op_type) {
+  return absl::StrCat(op_name, ":", op_type);
+}
+
 TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOp(
     std::string&& op_name, absl::string_view op_type) {
   absl::StrAppend(&op_name, ":", op_type);
   return op_name;
 }
 
+// Encodes op_name and op_type as a "#tf_op=op_name:op_type#" metadata override.
+TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOpOverride(
+    absl::string_view op_name, absl::string_view op_type) {
+  return absl::StrCat("#tf_op=", op_name, ":", op_type, "#");
+}
+
+TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOpOverride(
+    const char* op_name, const char* op_type) {
+  return absl::StrCat("#tf_op=", op_name, ":", op_type, "#");
+}
+
 }  // namespace profiler
 }  // namespace tensorflow
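
For illustration, what the helpers above produce (the op names are made up):

std::string op = tensorflow::profiler::TraceMeOp("my_conv", "Conv2D");
// op == "my_conv:Conv2D"

std::string override_str =
    tensorflow::profiler::TraceMeOpOverride("my_conv", "Conv2D");
// override_str == "#tf_op=my_conv:Conv2D#"
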
 
diff --git a/tensorflow/core/profiler/profiler.cc b/tensorflow/core/profiler/profiler.cc
index 559d85a..6ddabbc 100644
--- a/tensorflow/core/profiler/profiler.cc
+++ b/tensorflow/core/profiler/profiler.cc
@@ -45,7 +45,7 @@
 namespace tfprof {
 void completion(const char* buf, linenoiseCompletions* lc) {
   string buf_str = buf;
-  if (buf_str.find(" ") == buf_str.npos) {
+  if (buf_str.find(' ') == buf_str.npos) {
     for (const char* opt : kCmds) {
       if (string(opt).find(buf_str) == 0) {
         linenoiseAddCompletion(lc, opt);
diff --git a/tensorflow/core/profiler/tfprof_options.cc b/tensorflow/core/profiler/tfprof_options.cc
index a482df4..b18433f 100644
--- a/tensorflow/core/profiler/tfprof_options.cc
+++ b/tensorflow/core/profiler/tfprof_options.cc
@@ -43,7 +43,7 @@
 
   std::set<string> output_types(kOutput,
                                 kOutput + sizeof(kOutput) / sizeof(*kOutput));
-  auto opt_split = output_opt.find(":");
+  auto opt_split = output_opt.find(':');
   std::vector<string> kv_split;
   if (opt_split == output_opt.npos) {
     if (output_types.find(output_opt) == output_types.end()) {
diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc
index e2cfc37..5b6af7d 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline.cc
@@ -278,10 +278,7 @@
         Timespan& group_span = group_launch_info.timespan;
         Timespan event_span = event.GetTimespan();
         if (group_launch_info.num_launches) {  // Existing group.
-          uint64 begin_ps =
-              std::min(group_span.begin_ps(), event_span.begin_ps());
-          uint64 end_ps = std::max(group_span.end_ps(), event_span.end_ps());
-          group_span = Timespan::FromEndPoints(begin_ps, end_ps);
+          group_span.ExpandToInclude(event_span);
         } else {
           group_span = event_span;
         }
diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h
index c68be9e..9641568 100644
--- a/tensorflow/core/profiler/utils/event_span.h
+++ b/tensorflow/core/profiler/utils/event_span.h
@@ -38,33 +38,41 @@
   UNKNOWN_TIME = 0,
   // Host is computing.
   HOST_COMPUTE = 10,
+  // Host is preprocessing the data before the execution on device.
+  HOST_PREPROCESS = 20,
+  // Host is postprocessing the data after the execution on device.
+  HOST_POSTPROCESS = 30,
+  // Host is batching data (for inference).
+  HOST_BATCH_FORMATION = 40,
+  // Host runtime, such as memory allocation.
+  HOST_RUNTIME = 50,
   // Host is compiling.
-  HOST_COMPILE = 20,
+  HOST_COMPILE = 60,
   // Host-to-host communication.
-  HOST_TO_HOST = 30,
+  HOST_TO_HOST = 70,
   // Host-to-device communication.
-  HOST_TO_DEVICE = 40,
+  HOST_TO_DEVICE = 80,
   // Host is preparing to launch a computation on device.
-  HOST_PREPARE = 50,
+  HOST_PREPARE = 90,
   // Assigns a smaller priority to DEVICE_COLLECTIVES than HOST_WAIT_INPUT,
  // because if an all-reduce event is overlapped with a host-wait-input event,
   // we want to count it as waiting for input.
   // Collective Ops such as All-Reduce.
-  DEVICE_COLLECTIVES = 60,
+  DEVICE_COLLECTIVES = 100,
   // Host is waiting for input.
-  HOST_WAIT_INPUT = 70,
+  HOST_WAIT_INPUT = 110,
   // Device-to-device communication.
-  DEVICE_TO_DEVICE = 80,
+  DEVICE_TO_DEVICE = 120,
   // Device-to-host communication.
-  DEVICE_TO_HOST = 90,
+  DEVICE_TO_HOST = 130,
   // Device is computing with 32-bit precision.
-  DEVICE_COMPUTE_32 = 100,
+  DEVICE_COMPUTE_32 = 140,
   // Device is computing with 16-bit precision.
-  DEVICE_COMPUTE_16 = 110,
+  DEVICE_COMPUTE_16 = 150,
   // Device is waiting for another device.
-  DEVICE_WAIT_DEVICE = 120,
+  DEVICE_WAIT_DEVICE = 160,
   // Device is waiting for host.
-  DEVICE_WAIT_HOST = 130,
+  DEVICE_WAIT_HOST = 170,
   LAST_EVENT_TYPE = DEVICE_WAIT_HOST
 };
 
diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc
index 3fc382f..a32a719 100644
--- a/tensorflow/core/profiler/utils/hardware_type_utils.cc
+++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/profiler/utils/hardware_type_utils.h"
 
+#include "absl/strings/match.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
@@ -85,9 +86,31 @@
          device_cap.clock_rate_in_ghz();
 }
 
+absl::string_view GpuModelName(const DeviceCapabilities& device_cap) {
+  switch (device_cap.compute_capability().major()) {
+    case 2:
+      return "Nvidia GPU (Fermi)";
+    case 3:
+      return "Nvidia GPU (Kepler)";
+    case 5:
+      return "Nvidia GPU (Maxwell)";
+    case 6:
+      return "Nvidia GPU (Pascal)";
+    case 7:
+      if (device_cap.compute_capability().minor() < 5) {
+        return "Nvidia GPU (Volta)";
+      } else {
+        return "Nvidia GPU (Turing)";
+      }
+    case 8:
+      return "Nvidia GPU (Ampere)";
+    default:
+      return "Nvidia GPU";
+  }
+}
+
 HardwareType ParseHardwareType(absl::string_view device_type) {
-  if (device_type == "GPU" || device_type == "Nvidia GPU")
-    return HardwareType::GPU;
+  if (absl::StrContains(device_type, "GPU")) return HardwareType::GPU;
   if (device_type == "CPU") return HardwareType::CPU_ONLY;
   if (device_type == "TPU") return HardwareType::TPU;
   return HardwareType::UNKNOWN_HARDWARE;
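
Illustrative inputs and outputs for the new GpuModelName, together with the relaxed ParseHardwareType. The capability values are assumptions, and the field setters follow the usual protobuf generated naming for the accessors used above:

DeviceCapabilities cap;
cap.mutable_compute_capability()->set_major(7);
cap.mutable_compute_capability()->set_minor(0);
// GpuModelName(cap) == "Nvidia GPU (Volta)"

cap.mutable_compute_capability()->set_minor(5);
// GpuModelName(cap) == "Nvidia GPU (Turing)"

// Because the model name still contains "GPU", the old equality check is no
// longer needed:
// ParseHardwareType("Nvidia GPU (Turing)") == HardwareType::GPU
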
diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.h b/tensorflow/core/profiler/utils/hardware_type_utils.h
index 4a1470a..56734af 100644
--- a/tensorflow/core/profiler/utils/hardware_type_utils.h
+++ b/tensorflow/core/profiler/utils/hardware_type_utils.h
@@ -26,6 +26,9 @@
 // streaming multiprocessor.
 double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap);
 
+// Returns the GPU model name from the given DeviceCapabilities.
+absl::string_view GpuModelName(const DeviceCapabilities& device_cap);
+
 HardwareType ParseHardwareType(absl::string_view device_type);
 
 // Returns true if the given hardware type has a device.
diff --git a/tensorflow/core/profiler/utils/timespan.h b/tensorflow/core/profiler/utils/timespan.h
index 82775af..7228273 100644
--- a/tensorflow/core/profiler/utils/timespan.h
+++ b/tensorflow/core/profiler/utils/timespan.h
@@ -79,6 +79,12 @@
            std::max(begin_ps(), other.begin_ps());
   }
 
+  // Expands the timespan to include other.
+  void ExpandToInclude(const Timespan& other) {
+    *this = FromEndPoints(std::min(begin_ps(), other.begin_ps()),
+                          std::max(end_ps(), other.end_ps()));
+  }
+
   // Compares timespans by their begin time (ascending), duration (descending)
   // so nested spans are sorted from outer to innermost.
   bool operator<(const Timespan& other) const {
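
Illustrative use of the new ExpandToInclude, matching how derived_timeline.cc grows a group's launch span (all values in picoseconds, chosen arbitrarily):

Timespan group_span = Timespan::FromEndPoints(/*begin_ps=*/100, /*end_ps=*/200);
Timespan event_span = Timespan::FromEndPoints(/*begin_ps=*/150, /*end_ps=*/300);
group_span.ExpandToInclude(event_span);
// group_span.begin_ps() == 100, group_span.end_ps() == 300
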
diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h
index e4ff439..bdc9575 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.h
+++ b/tensorflow/core/profiler/utils/xplane_builder.h
@@ -226,6 +226,7 @@
 
   int64 NumEvents() const { return line_->events_size(); }
 
+  absl::string_view Name() const { return line_->name(); }
   void SetName(absl::string_view name) { line_->set_name(std::string(name)); }
 
   void SetNameIfEmpty(absl::string_view name) {
@@ -271,6 +272,7 @@
   int64 Id() const { return plane_->id(); }
   void SetId(int64 id) { plane_->set_id(id); }
 
+  absl::string_view Name() const { return plane_->name(); }
   void SetName(absl::string_view name) { plane_->set_name(std::string(name)); }
 
   void ReserveLines(size_t num_lines) {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc
index 7dd00c4..f0ea78d 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.cc
+++ b/tensorflow/core/profiler/utils/xplane_schema.cc
@@ -176,6 +176,7 @@
       {"memalloc_details", kMemallocDetails},
       {"kernel_details", kKernelDetails},
       {"annotation", kKernelAnnotation},
+      {"nvtx_range", kNVTXRange},
       {"stream", kStream},
       // Stats added when processing traces.
       {"group_id", kGroupId},
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index ad4c100..ca51874 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -164,6 +164,7 @@
   kMemcpyDetails,
   kMemallocDetails,
   kKernelAnnotation,
+  kNVTXRange,
   kKernelDetails,
   kStream,
   // Stats added when processing traces.
diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc
index 2df0a5b..5b7d22c 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.cc
+++ b/tensorflow/core/profiler/utils/xplane_utils.cc
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/match.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/logging.h"
@@ -38,51 +39,84 @@
 // Returns the index of the first element in array for which pred is true.
 // Returns -1 if no such element is found.
 template <typename T, typename Pred>
-int FindIf(const protobuf::RepeatedPtrField<T>& array, Pred&& pred) {
+int Find(const protobuf::RepeatedPtrField<T>& array, const Pred& pred) {
   for (int i = 0; i < array.size(); ++i) {
     if (pred(&array.Get(i))) return i;
   }
   return -1;
 }
 
-// Removes the given element from array.
-template <typename T>
-void Remove(protobuf::RepeatedPtrField<T>* array, const T* elem) {
-  int i = FindIf(*array, [elem](const T* e) { return elem == e; });
-  if (i == -1) return;
-  for (; i < array->size() - 1; ++i) {
-    array->SwapElements(i + 1, i);
+// Returns the indices of all elements in array for which pred is true.
+template <typename T, typename Pred>
+std::vector<int> FindAll(const protobuf::RepeatedPtrField<T>& array,
+                         const Pred& pred) {
+  std::vector<int> indices;
+  for (int i = 0; i < array.size(); ++i) {
+    if (pred(&array.Get(i))) indices.push_back(i);
   }
-  array->RemoveLast();
+  return indices;
 }
 
-template <typename T, typename Pred>
-void RemoveIf(protobuf::RepeatedPtrField<T>* array, Pred&& pred) {
-  int i = FindIf(*array, pred);
-  if (i == -1) return;
+template <typename T>
+void RemoveAt(protobuf::RepeatedPtrField<T>* array,
+              const std::vector<int>& indices) {
+  if (indices.empty()) return;
+  if (array->size() == indices.size()) {
+    // Assumes that 'indices' consists of [0 ... N-1].
+    array->Clear();
+    return;
+  }
+  auto remove_iter = indices.begin();
+  int i = *(remove_iter++);
   for (int j = i + 1; j < array->size(); ++j) {
-    if (!pred(&array->Get(j))) array->SwapElements(j, i++);
+    if (remove_iter != indices.end() && *remove_iter == j) {
+      ++remove_iter;
+    } else {
+      array->SwapElements(j, i++);
+    }
   }
   array->DeleteSubrange(i, array->size() - i);
 }
 
-// Creates a Timespan from an XEvent.
-// WARNING: This should only be used when comparing events from the same XLine.
-Timespan XEventTimespan(const XEvent& event) {
-  return Timespan(event.offset_ps(), event.duration_ps());
+// Removes the given element from array.
+template <typename T>
+void Remove(protobuf::RepeatedPtrField<T>* array, const T* elem) {
+  int i = Find(*array, [elem](const T* e) { return elem == e; });
+  RemoveAt(array, {i});
+}
+
+template <typename T, typename Pred>
+void RemoveIf(protobuf::RepeatedPtrField<T>* array, Pred&& pred) {
+  std::vector<int> indices = FindAll(*array, pred);
+  RemoveAt(array, indices);
 }
 
 }  // namespace
 
 const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) {
-  int i = FindIf(space.planes(),
-                 [name](const XPlane* plane) { return plane->name() == name; });
+  int i = Find(space.planes(),
+               [name](const XPlane* plane) { return plane->name() == name; });
   return (i != -1) ? &space.planes(i) : nullptr;
 }
 
+std::vector<const XPlane*> FindPlanesWithNames(
+    const XSpace& space, const std::vector<absl::string_view>& names) {
+  absl::flat_hash_set<absl::string_view> names_set(names.begin(), names.end());
+  std::vector<int> indices =
+      FindAll(space.planes(), [&names_set](const XPlane* plane) {
+        return names_set.contains(plane->name());
+      });
+  std::vector<const XPlane*> planes;
+  planes.reserve(indices.size());
+  for (int i : indices) {
+    planes.push_back(&space.planes(i));
+  }
+  return planes;
+}
+
 XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name) {
-  int i = FindIf(space->planes(),
-                 [name](const XPlane* plane) { return plane->name() == name; });
+  int i = Find(space->planes(),
+               [name](const XPlane* plane) { return plane->name() == name; });
   return (i != -1) ? space->mutable_planes(i) : nullptr;
 }
 
@@ -114,8 +148,8 @@
 }
 
 const XLine* FindLineWithId(const XPlane& plane, int64 id) {
-  int i = FindIf(plane.lines(),
-                 [id](const XLine* line) { return line->id() == id; });
+  int i =
+      Find(plane.lines(), [id](const XLine* line) { return line->id() == id; });
   return (i != -1) ? &plane.lines(i) : nullptr;
 }
 
@@ -135,6 +169,13 @@
   Remove(space->mutable_planes(), plane);
 }
 
+void RemovePlanes(XSpace* space, const std::vector<const XPlane*>& planes) {
+  absl::flat_hash_set<const XPlane*> planes_set(planes.begin(), planes.end());
+  RemoveIf(space->mutable_planes(), [&planes_set](const XPlane* plane) {
+    return planes_set.contains(plane);
+  });
+}
+
 void RemoveLine(XPlane* plane, const XLine* line) {
   DCHECK(line != nullptr);
   Remove(plane->mutable_lines(), line);
@@ -251,6 +292,13 @@
   });
 }
 
+void MergePlanes(const std::vector<const XPlane*>& src_planes,
+                 XPlane* dst_plane) {
+  for (const XPlane* src_plane : src_planes) {
+    MergePlanes(*src_plane, dst_plane);
+  }
+}
+
 uint64 GetStartTimestampNs(const XPlane& plane) {
   int64 plane_timestamp = 0;
   for (const auto& line : plane.lines()) {
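Taken together, FindPlanesWithNames, the vector overload of MergePlanes, and RemovePlanes support a consolidate-then-drop pattern. A hedged caller-side sketch (the plane names and the surrounding XSpace are illustrative, not taken from the patch):

// Assumes an XSpace `space` is in scope and xplane_utils.h is included;
// the plane names below are placeholders.
std::vector<const XPlane*> gpu_planes =
    FindPlanesWithNames(space, {"/device:GPU:0", "/device:GPU:1"});
if (!gpu_planes.empty()) {
  XPlane* merged = FindOrAddMutablePlaneWithName(&space, "/device:GPU:merged");
  MergePlanes(gpu_planes, merged);   // Merge stats, lines and events.
  RemovePlanes(&space, gpu_planes);  // Drop the source planes in one pass.
}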
diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h
index d3b6c63..8358c5c 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.h
+++ b/tensorflow/core/profiler/utils/xplane_utils.h
@@ -22,15 +22,26 @@
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/trace_utils.h"
 
 namespace tensorflow {
 namespace profiler {
 
+// Returns a Timespan from an XEvent.
+// WARNING: This should only be used when comparing events from the same XLine.
+inline Timespan XEventTimespan(const XEvent& event) {
+  return Timespan(event.offset_ps(), event.duration_ps());
+}
+
 // Returns the plane with the given name or nullptr if not found.
 const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name);
 XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name);
 
+// Returns the planes with the given names, if found.
+std::vector<const XPlane*> FindPlanesWithNames(
+    const XSpace& space, const std::vector<absl::string_view>& names);
+
 // Returns the plane with the given name in the container. If necessary, adds a
 // new plane to the container.
 XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name);
@@ -47,6 +58,7 @@
 XStat* FindOrAddMutableStat(const XStatMetadata& stat_metadata, XEvent* event);
 
 void RemovePlane(XSpace* space, const XPlane* plane);
+void RemovePlanes(XSpace* space, const std::vector<const XPlane*>& planes);
 void RemoveLine(XPlane* plane, const XLine* line);
 void RemoveEvents(XLine* line,
                   const absl::flat_hash_set<const XEvent*>& events);
@@ -100,12 +112,16 @@
 void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns);
 void NormalizeTimestamps(XSpace* space, uint64 start_time_ns);
 
-// Merge Xplane src_plane into Xplane dst_plane, both plane level stats, lines,
-// events and event level stats are merged; If src_plane and dst_plane both have
-// the same line, which have different start timestamps, we will normalize the
-// events offset timestamp correspondingly.
+// Merges src_plane into dst_plane. Plane-level stats, lines, events and
+// event-level stats are all merged. If src_plane and dst_plane both have the
+// same line but with different start timestamps, the event offset timestamps
+// are normalized accordingly.
 void MergePlanes(const XPlane& src_plane, XPlane* dst_plane);
 
+// Merges each plane in src_planes into dst_plane.
+void MergePlanes(const std::vector<const XPlane*>& src_planes,
+                 XPlane* dst_plane);
+
 // Plane's start timestamp is defined as the minimum of all lines' start
 // timestamps. If no lines exist, returns 0.
 uint64 GetStartTimestampNs(const XPlane& plane);
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 29e3e8a..9b50d5e 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -593,6 +593,11 @@
       MLIR_BRIDGE_ROLLOUT_ENABLED = 1;
       // Disabling the MLIR bridge disables it for all graphs in this session.
       MLIR_BRIDGE_ROLLOUT_DISABLED = 2;
+      // Enable the MLIR bridge on a per-graph basis based on an analysis of
+      // the features used in the graph. If the features used by the graph are
+      // supported by the MLIR bridge, the MLIR bridge will be used to run the
+      // graph.
+      MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED = 3;
     }
     // This field is under development; for now, use enable_mlir_bridge
     // (b/166038521).
@@ -620,6 +625,11 @@
     // The XLA fusion autotuner can improve performance by executing a heuristic
     // search on the compiler parameters.
     int64 xla_fusion_autotuner_thresh = 15;
+
+    // Whether runtime execution uses TFRT.
+    bool use_tfrt = 18;
+
+    // Next: 19
   }
 
   Experimental experimental = 16;
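A hedged sketch of opting into the two new experimental options from C++ (the enum field is assumed to be named mlir_bridge_rollout based on the surrounding message; verify against the full config.proto before relying on it):

#include "tensorflow/core/protobuf/config.pb.h"

void ConfigureExperimentalOptions(tensorflow::ConfigProto* config) {
  auto* experimental = config->mutable_experimental();
  // New bool field 18: route runtime execution through TFRT.
  experimental->set_use_tfrt(true);
  // New enum value 3: enable the MLIR bridge only for graphs whose features
  // the bridge supports (field name assumed; see note above).
  experimental->set_mlir_bridge_rollout(
      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED);
}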
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 66572fa..4feca91 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 608  // Updated: 2020/12/7
+#define TF_GRAPH_DEF_VERSION 623  // Updated: 2020/12/22
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 8264ec3..cd288bf 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -8,6 +8,7 @@
 
 package(
     default_visibility = [
+        "//tensorflow/compiler/mlir/tensorflow:__subpackages__",
         "//tensorflow/compiler/tf2xla/kernels:__subpackages__",
         "//tensorflow/core/tpu:__subpackages__",
         "//tensorflow/stream_executor/tpu:__subpackages__",
@@ -115,6 +116,7 @@
     name = "tpu_api",
     srcs = ["tpu_api.cc"],
     hdrs = ["tpu_api.h"],
+    visibility = ["//visibility:public"],
     deps = [
         ":libtftpu_header",
         ":tpu_executor_api",
@@ -343,6 +345,7 @@
     visibility = ["//visibility:public"],
     deps = [
         ":libtftpu_header",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/stream_executor/tpu:c_api_decl",
         "//tensorflow/stream_executor/tpu:proto_helper",
     ],
diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
index b7c71dc..333b893 100644
--- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
@@ -54,6 +54,7 @@
 const char* const kTPUReplicatedInput = "TPUReplicatedInput";
 const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
 const char* const kPivotForClusterAttr = "_pivot_for_cluster";
+const char* const kTPUPartitionedInput = "TPUPartitionedInput";
 
 // Finds the `index` of an _Arg or _Retval node.
 Status GetIndexAttr(const Node& n, int num_args, int* index) {
@@ -1586,7 +1587,18 @@
         }
       }
       if (!has_output) {
+        // Remove any TPUPartitionedInput node from the source nodes of the
+        // to-be-removed TPUReplicatedInput node.
+        std::vector<Node*> to_be_removed_src_nodes;
+        for (const auto& e_in : n->in_edges()) {
+          if (!e_in->IsControlEdge() &&
+              e_in->src()->type_string() == kTPUPartitionedInput)
+            to_be_removed_src_nodes.push_back(e_in->src());
+        }
         graph->RemoveNode(n);
+        for (Node* node : to_be_removed_src_nodes) {
+          graph->RemoveNode(node);
+        }
       }
     }
   }
diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
index 6a9ef76..4d3f813 100644
--- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
@@ -52,7 +52,7 @@
     void (*initialize_fn)(bool init_library, int argc, char** argv);
     initialize_fn = reinterpret_cast<decltype(initialize_fn)>(
         dlsym(library_handle, "TfTpu_Initialize"));
-    (*initialize_fn)(/*init_library=*/false, /*argc=*/0, /*argv=*/nullptr);
+    (*initialize_fn)(/*init_library=*/true, /*argc=*/0, /*argv=*/nullptr);
 
     RegisterTpuPlatform();
     RegisterTpuSystemDevice();
diff --git a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
index 1444dd7..84f01c2 100644
--- a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
@@ -47,7 +47,7 @@
     void (*initialize_fn)(bool init_library, int argc, char** argv);
     initialize_fn = reinterpret_cast<decltype(initialize_fn)>(
         dlsym(library_handle, "TfTpu_Initialize"));
-    (*initialize_fn)(/*init_library=*/false, /*argc=*/0, /*argv=*/nullptr);
+    (*initialize_fn)(/*init_library=*/true, /*argc=*/0, /*argv=*/nullptr);
 
     RegisterTpuPlatform();
   }
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index ef4d7ba..0b984fa 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -70,6 +70,12 @@
   TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
   TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
   TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
+
+  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Create);
+  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Free);
+  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Start);
+  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Stop);
+  TFTPU_SET_FN(ops_api_fn, TpuProfiler_CollectData);
 
   return tensorflow::Status::OK();
 }
diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc
index c99808f..069cf0d 100644
--- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc
+++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc
@@ -251,7 +251,7 @@
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module,
       stream_executor::StreamExecutor* executor,
-      stream_executor::DeviceMemoryAllocator* device_allocator) override {
+      const CompileOptions& options) override {
     XLA_HloModule hlo_module;
     XLA_HloModule result;
     auto cleanup = xla::MakeCleanup([&hlo_module, &result]() {
@@ -261,7 +261,7 @@
     });
     hlo_module.module_config = HloModuleConfigToC(module->config());
     hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
-    auto allocator = ApiConverter::ToC(device_allocator);
+    auto allocator = ApiConverter::ToC(options.device_allocator);
     StatusHelper status;
     ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
         compiler_, &hlo_module,
@@ -279,7 +279,7 @@
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
       stream_executor::StreamExecutor* executor,
-      stream_executor::DeviceMemoryAllocator* device_allocator) override {
+      const CompileOptions& options) override {
     XLA_HloModule hlo_module;
     auto cleanup = xla::MakeCleanup([&hlo_module]() {
       stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
@@ -288,7 +288,7 @@
     SE_Executable* result;
     hlo_module.module_config = HloModuleConfigToC(module->config());
     hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
-    auto allocator = ApiConverter::ToC(device_allocator);
+    auto allocator = ApiConverter::ToC(options.device_allocator);
 
     StatusHelper status;
     ExecutorApiFn()->TpuCompiler_RunBackendFn(
@@ -308,7 +308,7 @@
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<stream_executor::StreamExecutor*>> stream_exec,
-      stream_executor::DeviceMemoryAllocator* device_allocator) override {
+      const CompileOptions& options) override {
     XLA_HloModuleGroup se_module_group;
     se_module_group.proto =
         stream_executor::tpu::SerializeProto(module_group->ToProto());
@@ -339,7 +339,8 @@
       }
     }
 
-    SE_DeviceMemoryAllocator allocator = ApiConverter::ToC(device_allocator);
+    SE_DeviceMemoryAllocator allocator =
+        ApiConverter::ToC(options.device_allocator);
 
     SE_Executable** se_executables = new SE_Executable*[module_group->size()];
 
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index 77e5ddb..f361110 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -19,6 +19,7 @@
 
 #include <cstdint>
 
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"
@@ -53,6 +54,8 @@
 
 typedef struct XLA_TpuMeshState XLA_TpuMeshState;
 
+typedef struct TpuProfiler TpuProfiler;
+
 typedef struct XLA_DeviceAssignment {
   const char* bytes;
   size_t size;
@@ -103,6 +106,21 @@
     TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
     XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
 
+// Creates a new TPU profiler object.
+TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Create();
+
+TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Free(TpuProfiler* tpu_profiler);
+
+TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
+                                         TF_Status* status);
+
+TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
+                                        TF_Status* status);
+
+TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(
+    TpuProfiler* tpu_profiler, TF_Status* status,
+    tensorflow::profiler::XSpace* space);
+
 // Creates a new TPU mesh state object.
 TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
 
@@ -397,6 +415,12 @@
   TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
   TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
 
+  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
+  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
+  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
+  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);
+
   TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
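For orientation, a hedged end-to-end sketch of the new profiler C API as declared above; error checking is elided and the include paths are assumptions rather than part of the patch:

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

void ProfileTpuWorkload() {
  TF_Status* status = TF_NewStatus();
  TpuProfiler* profiler = TpuProfiler_Create();
  TpuProfiler_Start(profiler, status);
  // ... run the TPU workload to be profiled ...
  TpuProfiler_Stop(profiler, status);
  tensorflow::profiler::XSpace space;  // Filled with the collected events.
  TpuProfiler_CollectData(profiler, status, &space);
  TpuProfiler_Free(profiler);
  TF_DeleteStatus(status);
}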
diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD
index e3e7d6e..fb40248 100644
--- a/tensorflow/core/util/BUILD
+++ b/tensorflow/core/util/BUILD
@@ -274,6 +274,7 @@
     visibility = [
         "//tensorflow/core:__pkg__",
         "//tensorflow/python:__pkg__",
+        "//tensorflow/python/util:__pkg__",
     ],
 )
 
@@ -460,6 +461,7 @@
     visibility = [
         "//tensorflow/core:__pkg__",
         "//tensorflow/python:__pkg__",
+        "//tensorflow/python/util:__pkg__",
     ],
     alwayslink = 1,
 )
@@ -474,6 +476,7 @@
         "//tensorflow/core/platform:__pkg__",
         "//tensorflow/python:__pkg__",
         "//tensorflow/python/eager:__pkg__",
+        "//tensorflow/python/util:__pkg__",
     ],
     deps = [
         "//tensorflow/core/platform:status",
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc
index f114491..61994d0 100644
--- a/tensorflow/core/util/stat_summarizer.cc
+++ b/tensorflow/core/util/stat_summarizer.cc
@@ -121,7 +121,7 @@
   std::string::size_type start = label.find(sep);
   if (start == std::string::npos) return "<>";
   start += sep.size();
-  std::string::size_type end = label.find("(", start);
+  std::string::size_type end = label.find('(', start);
   if (end == std::string::npos) return "<>";
   return label.substr(start, end - start);
 }
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 25eb889..484f811 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2931,27 +2931,6 @@
 	return op.Output(0)
 }
 
-// Makes a copy of `x`.
-//
-// Arguments:
-//	x: The source tensor of type `T`.
-//
-// Returns     y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
-//       is not an alias of `x`.
-func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "DeepCopy",
-		Input: []tf.Input{
-			x,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // PackAttr is an optional argument to Pack.
 type PackAttr func(optionalAttr)
 
@@ -4610,58 +4589,6 @@
 	return op.Output(0), op.Output(1)
 }
 
-// Sends the named tensor to another XLA computation. Wraps the XLA Send operator
-//
-// documented at
-//  https://www.tensorflow.org/performance/xla/operation_semantics#send .
-//
-// Arguments:
-//	tensor: The tensor to send.
-//	tensor_name: A string key that identifies the channel.
-//
-// Returns the created operation.
-func XlaSend(scope *Scope, tensor tf.Output, tensor_name string) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"tensor_name": tensor_name}
-	opspec := tf.OpSpec{
-		Type: "XlaSend",
-		Input: []tf.Input{
-			tensor,
-		},
-		Attrs: attrs,
-	}
-	return scope.AddOperation(opspec)
-}
-
-// Returns the index of a data point that should be added to the seed set.
-//
-// Entries in distances are assumed to be squared distances of candidate points to
-// the already sampled centers in the seed set. The op constructs one Markov chain
-// of the k-MC^2 algorithm and returns the index of one candidate point to be added
-// as an additional cluster center.
-//
-// Arguments:
-//	distances: Vector with squared distances to the closest previously sampled cluster center
-// for each candidate point.
-//	seed: Scalar. Seed for initializing the random number generator.
-//
-// Returns Scalar with the index of the sampled point.
-func KMC2ChainInitialization(scope *Scope, distances tf.Output, seed tf.Output) (index tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "KMC2ChainInitialization",
-		Input: []tf.Input{
-			distances, seed,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Selects num_to_sample rows of input using the KMeans++ criterion.
 //
 // Rows of points are assumed to be input points. One row is selected at random.
@@ -4693,6 +4620,45 @@
 	return op.Output(0)
 }
 
+// CollectiveBcastRecvV2Attr is an optional argument to CollectiveBcastRecvV2.
+type CollectiveBcastRecvV2Attr func(optionalAttr)
+
+// CollectiveBcastRecvV2CommunicationHint sets the optional communication_hint attribute to value.
+// If not specified, defaults to "auto"
+func CollectiveBcastRecvV2CommunicationHint(value string) CollectiveBcastRecvV2Attr {
+	return func(m optionalAttr) {
+		m["communication_hint"] = value
+	}
+}
+
+// CollectiveBcastRecvV2TimeoutSeconds sets the optional timeout_seconds attribute to value.
+// If not specified, defaults to 0
+func CollectiveBcastRecvV2TimeoutSeconds(value float32) CollectiveBcastRecvV2Attr {
+	return func(m optionalAttr) {
+		m["timeout_seconds"] = value
+	}
+}
+
+// Receives a tensor value broadcast from another device.
+func CollectiveBcastRecvV2(scope *Scope, group_size tf.Output, group_key tf.Output, instance_key tf.Output, shape tf.Output, T tf.DataType, optional ...CollectiveBcastRecvV2Attr) (data tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"T": T}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "CollectiveBcastRecvV2",
+		Input: []tf.Input{
+			group_size, group_key, instance_key, shape,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // AbortAttr is an optional argument to Abort.
 type AbortAttr func(optionalAttr)
 
@@ -15079,459 +15045,280 @@
 	return op.Output(0)
 }
 
-// DebugNumericSummaryAttr is an optional argument to DebugNumericSummary.
-type DebugNumericSummaryAttr func(optionalAttr)
-
-// DebugNumericSummaryDeviceName sets the optional device_name attribute to value.
-// If not specified, defaults to ""
-func DebugNumericSummaryDeviceName(value string) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["device_name"] = value
-	}
-}
-
-// DebugNumericSummaryTensorName sets the optional tensor_name attribute to value.
+// Calculates the product with a tridiagonal matrix.
 //
-// value: Name of the input tensor.
-// If not specified, defaults to ""
-func DebugNumericSummaryTensorName(value string) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["tensor_name"] = value
-	}
-}
-
-// DebugNumericSummaryDebugUrls sets the optional debug_urls attribute to value.
-//
-// value: List of URLs to debug targets, e.g.,
-//   file:///foo/tfdbg_dump, grpc:://localhost:11011.
-// If not specified, defaults to <>
-func DebugNumericSummaryDebugUrls(value []string) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["debug_urls"] = value
-	}
-}
-
-// DebugNumericSummaryLowerBound sets the optional lower_bound attribute to value.
-//
-// value: (float) The lower bound <= which values will be included in the
-//   generalized -inf count. Default: -inf.
-// If not specified, defaults to -inf
-func DebugNumericSummaryLowerBound(value float32) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["lower_bound"] = value
-	}
-}
-
-// DebugNumericSummaryUpperBound sets the optional upper_bound attribute to value.
-//
-// value: (float) The upper bound >= which values will be included in the
-//   generalized +inf count. Default: +inf.
-// If not specified, defaults to inf
-func DebugNumericSummaryUpperBound(value float32) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["upper_bound"] = value
-	}
-}
-
-// DebugNumericSummaryMuteIfHealthy sets the optional mute_if_healthy attribute to value.
-//
-// value: (bool) Do not send data to the debug URLs unless at least one
-//   of elements [2], [3] and [7] (i.e., the nan count and the generalized -inf and
-//   inf counts) is non-zero.
-// If not specified, defaults to false
-func DebugNumericSummaryMuteIfHealthy(value bool) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["mute_if_healthy"] = value
-	}
-}
-
-// DebugNumericSummaryGatedGrpc sets the optional gated_grpc attribute to value.
-//
-// value: Whether this op will be gated. If any of the debug_urls of this
-//   debug node is of the grpc:// scheme, when the value of this attribute is set
-//   to True, the data will not actually be sent via the grpc stream unless this
-//   debug op has been enabled at the debug_url. If all of the debug_urls of this
-//   debug node are of the grpc:// scheme and the debug op is enabled at none of
-//   them, the output will be an empty Tensor.
-// If not specified, defaults to false
-func DebugNumericSummaryGatedGrpc(value bool) DebugNumericSummaryAttr {
-	return func(m optionalAttr) {
-		m["gated_grpc"] = value
-	}
-}
-
-// Debug Numeric Summary Op.
-//
-// Provide a basic summary of numeric value types, range and distribution.
-//
-// output: A double tensor of shape [14 + nDimensions], where nDimensions is the
-//   number of dimensions of the tensor's shape. The elements of output are:
-//   [0]: is initialized (1.0) or not (0.0).
-//   [1]: total number of elements
-//   [2]: NaN element count
-//   [3]: generalized -inf count: elements <= lower_bound. lower_bound is -inf by
-//     default.
-//   [4]: negative element count (excluding -inf), if lower_bound is the default
-//     -inf. Otherwise, this is the count of elements > lower_bound and < 0.
-//   [5]: zero element count
-//   [6]: positive element count (excluding +inf), if upper_bound is the default
-//     +inf. Otherwise, this is the count of elements < upper_bound and > 0.
-//   [7]: generalized +inf count, elements >= upper_bound. upper_bound is +inf by
-//     default.
-// Output elements [1:8] are all zero, if the tensor is uninitialized.
-//   [8]: minimum of all non-inf and non-NaN elements.
-//        If uninitialized or no such element exists: +inf.
-//   [9]: maximum of all non-inf and non-NaN elements.
-//        If uninitialized or no such element exists: -inf.
-//   [10]: mean of all non-inf and non-NaN elements.
-//         If uninitialized or no such element exists: NaN.
-//   [11]: variance of all non-inf and non-NaN elements.
-//         If uninitialized or no such element exists: NaN.
-//   [12]: Data type of the tensor encoded as an enum integer. See the DataType
-//         proto for more details.
-//   [13]: Number of dimensions of the tensor (ndims).
-//   [14+]: Sizes of the dimensions.
-//
+// Calculates the product of two matrices, where the left matrix is tridiagonal.
 //
 // Arguments:
-//	input: Input tensor, non-Reference type.
-func DebugNumericSummary(scope *Scope, input tf.Output, optional ...DebugNumericSummaryAttr) (output tf.Output) {
+//	superdiag: Tensor of shape `[..., 1, M]`, representing superdiagonals of
+// tri-diagonal matrices to the left of multiplication. Last element is ignored.
+//	maindiag: Tensor of shape `[..., 1, M]`, representing main diagonals of tri-diagonal
+// matrices to the left of multiplication.
+//	subdiag: Tensor of shape `[..., 1, M]`, representing subdiagonals of tri-diagonal
+// matrices to the left of multiplication. First element is ignored.
+//	rhs: Tensor of shape `[..., M, N]`, representing MxN matrices to the right of
+// multiplication.
+//
+// Returns Tensor of shape `[..., M, N]` containing the product.
+func TridiagonalMatMul(scope *Scope, superdiag tf.Output, maindiag tf.Output, subdiag tf.Output, rhs tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
-	attrs := map[string]interface{}{}
+	opspec := tf.OpSpec{
+		Type: "TridiagonalMatMul",
+		Input: []tf.Input{
+			superdiag, maindiag, subdiag, rhs,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// CollectiveBcastRecvAttr is an optional argument to CollectiveBcastRecv.
+type CollectiveBcastRecvAttr func(optionalAttr)
+
+// CollectiveBcastRecvCommunicationHint sets the optional communication_hint attribute to value.
+// If not specified, defaults to "auto"
+func CollectiveBcastRecvCommunicationHint(value string) CollectiveBcastRecvAttr {
+	return func(m optionalAttr) {
+		m["communication_hint"] = value
+	}
+}
+
+// CollectiveBcastRecvTimeoutSeconds sets the optional timeout_seconds attribute to value.
+// If not specified, defaults to 0
+func CollectiveBcastRecvTimeoutSeconds(value float32) CollectiveBcastRecvAttr {
+	return func(m optionalAttr) {
+		m["timeout_seconds"] = value
+	}
+}
+
+// Receives a tensor value broadcast from another device.
+func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape, optional ...CollectiveBcastRecvAttr) (data tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
 	for _, a := range optional {
 		a(attrs)
 	}
 	opspec := tf.OpSpec{
-		Type: "DebugNumericSummary",
+		Type: "CollectiveBcastRecv",
+
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Scatter the data from the input value into specific TensorArray elements.
+//
+// `indices` must be a vector; its length must match the first dim of `value`.
+//
+// Arguments:
+//	handle: The handle to a TensorArray.
+//	indices: The locations at which to write the tensor elements.
+//	value: The concatenated tensor to write to the TensorArray.
+//	flow_in: A float scalar that enforces proper chaining of operations.
+//
+// Returns A float scalar that enforces proper chaining of operations.
+func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "TensorArrayScatterV3",
+		Input: []tf.Input{
+			handle, indices, value, flow_in,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the matrix square root of one or more square matrices:
+//
+// matmul(sqrtm(A), sqrtm(A)) = A
+//
+// The input matrix should be invertible. If the input matrix is real, it should
+// have no eigenvalues which are real and negative (pairs of complex conjugate
+// eigenvalues are allowed).
+//
+// The matrix square root is computed by first reducing the matrix to
+// quasi-triangular form with the real Schur decomposition. The square root
+// of the quasi-triangular matrix is then computed directly. Details of
+// the algorithm can be found in: Nicholas J. Higham, "Computing real
+// square roots of a real matrix", Linear Algebra Appl., 1987.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices. The output is a tensor of the same shape as the input
+// containing the matrix square root for all input submatrices `[..., :, :]`.
+//
+// Arguments:
+//	input: Shape is `[..., M, M]`.
+//
+// Returns Shape is `[..., M, M]`.
+//
+// @compatibility(scipy)
+// Equivalent to scipy.linalg.sqrtm
+// @end_compatibility
+func MatrixSquareRoot(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixSquareRoot",
 		Input: []tf.Input{
 			input,
 		},
-		Attrs: attrs,
 	}
 	op := scope.AddOperation(opspec)
 	return op.Output(0)
 }
 
-// Outputs random integers from a uniform distribution.
+// Pads a tensor with mirrored values.
 //
-// The generated values are uniform integers in the range `[minval, maxval)`.
-// The lower bound `minval` is included in the range, while the upper bound
-// `maxval` is excluded.
+// This operation pads `input` with mirrored values according to the `paddings`
+// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
+// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many values to add before the contents of `input` in that dimension, and
+// `paddings[D, 1]` indicates how many values to add after the contents of `input`
+// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater
+// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true
+// (if false, respectively).
 //
-// The random integers are slightly biased unless `maxval - minval` is an exact
-// power of two.  The bias is small for values of `maxval - minval` significantly
-// smaller than the range of the output (either `2^32` or `2^64`).
+// The padded size of each dimension D of the output is:
 //
-// Arguments:
-//	resource: The handle of the resource variable that stores the state of the RNG.
-//	algorithm: The RNG algorithm.
-//	shape: The shape of the output tensor.
-//	minval: Minimum value (inclusive, scalar).
-//	maxval: Maximum value (exclusive, scalar).
+// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
 //
-// Returns Random values with specified shape.
-func StatefulUniformInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "StatefulUniformInt",
-		Input: []tf.Input{
-			resource, algorithm, shape, minval, maxval,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// An Op to exchange data across TPU replicas.
+// For example:
 //
-// On each replica, the input is split into `split_count` blocks along
-// `split_dimension` and send to the other replicas given group_assignment. After
-// receiving `split_count` - 1 blocks from other replicas, we concatenate the
-// blocks along `concat_dimension` as the output.
-//
-// For example, suppose there are 2 TPU replicas:
-// replica 0 receives input: `[[A, B]]`
-// replica 1 receives input: `[[C, D]]`
-//
-// group_assignment=`[[0, 1]]`
-// concat_dimension=0
-// split_dimension=1
-// split_count=2
-//
-// replica 0's output: `[[A], [C]]`
-// replica 1's output: `[[B], [D]]`
-//
-// Arguments:
-//	input: The local input to the sum.
-//	group_assignment: An int32 tensor with shape
-// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
-// replica ids in the ith subgroup.
-//	concat_dimension: The dimension number to concatenate.
-//	split_dimension: The dimension number to split.
-//	split_count: The number of splits, this number must equal to the sub-group
-// size(group_assignment.get_shape()[1])
-//
-// Returns The exchanged result.
-func AllToAll(scope *Scope, input tf.Output, group_assignment tf.Output, concat_dimension int64, split_dimension int64, split_count int64) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"concat_dimension": concat_dimension, "split_dimension": split_dimension, "split_count": split_count}
-	opspec := tf.OpSpec{
-		Type: "AllToAll",
-		Input: []tf.Input{
-			input, group_assignment,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// TridiagonalSolveAttr is an optional argument to TridiagonalSolve.
-type TridiagonalSolveAttr func(optionalAttr)
-
-// TridiagonalSolvePartialPivoting sets the optional partial_pivoting attribute to value.
-//
-// value: Whether to apply partial pivoting. Partial pivoting makes the procedure more
-// stable, but slower.
-// If not specified, defaults to true
-func TridiagonalSolvePartialPivoting(value bool) TridiagonalSolveAttr {
-	return func(m optionalAttr) {
-		m["partial_pivoting"] = value
-	}
-}
-
-// Solves tridiagonal systems of equations.
-//
-//   Solves tridiagonal systems of equations.
-//   Supports batch dimensions and multiple right-hand sides per each left-hand
-//   side.
-//   On CPU, solution is computed via Gaussian elimination with or without partial
-//   pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
-//   library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
-//   Partial pivoting is not yet supported by XLA backends.
-//
-// Arguments:
-//	diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions represent the
-// tridiagonal matrices with three rows being the superdiagonal, diagonals, and
-// subdiagonals, in order. The last element of the superdiagonal and the first
-// element of the subdiagonal is ignored.
-//	rhs: Tensor of shape `[..., M, K]`, representing K right-hand sides per each
-// left-hand side.
-//
-// Returns Tensor of shape `[..., M, K]` containing the solutions
-func TridiagonalSolve(scope *Scope, diagonals tf.Output, rhs tf.Output, optional ...TridiagonalSolveAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "TridiagonalSolve",
-		Input: []tf.Input{
-			diagonals, rhs,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes gradients for SparseSegmentMean.
-//
-// Returns tensor "output" with same shape as grad, except for dimension 0 whose
-// value is output_dim0.
-//
-// Arguments:
-//	grad: gradient propagated to the SparseSegmentMean op.
-//	indices: indices passed to the corresponding SparseSegmentMean op.
-//	segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
-//	output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
-func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "SparseSegmentMeanGrad",
-		Input: []tf.Input{
-			grad, indices, segment_ids, output_dim0,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// SvdAttr is an optional argument to Svd.
-type SvdAttr func(optionalAttr)
-
-// SvdComputeUv sets the optional compute_uv attribute to value.
-//
-// value: If true, left and right singular vectors will be
-// computed and returned in `u` and `v`, respectively.
-// If false, `u` and `v` are not set and should never referenced.
-// If not specified, defaults to true
-func SvdComputeUv(value bool) SvdAttr {
-	return func(m optionalAttr) {
-		m["compute_uv"] = value
-	}
-}
-
-// SvdFullMatrices sets the optional full_matrices attribute to value.
-//
-// value: If true, compute full-sized `u` and `v`. If false
-// (the default), compute only the leading `P` singular vectors.
-// Ignored if `compute_uv` is `False`.
-// If not specified, defaults to false
-func SvdFullMatrices(value bool) SvdAttr {
-	return func(m optionalAttr) {
-		m["full_matrices"] = value
-	}
-}
-
-// Computes the singular value decompositions of one or more matrices.
-//
-// Computes the SVD of each inner matrix in `input` such that
-// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])`
-//
-// ```python
-// # a is a tensor containing a batch of matrices.
-// # s is a tensor of singular values for each matrix.
-// # u is the tensor containing the left singular vectors for each matrix.
-// # v is the tensor containing the right singular vectors for each matrix.
-// s, u, v = svd(a)
-// s, _, _ = svd(a, compute_uv=False)
+// ```
+// # 't' is [[1, 2, 3], [4, 5, 6]].
+// # 'paddings' is [[1, 1], [2, 2]].
+// # 'mode' is SYMMETRIC.
+// # rank of 't' is 2.
+// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
+//                       [2, 1, 1, 2, 3, 3, 2]
+//                       [5, 4, 4, 5, 6, 6, 5]
+//                       [5, 4, 4, 5, 6, 6, 5]]
 // ```
 //
 // Arguments:
-//	input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+//	input: The input tensor to be padded.
+//	paddings: A two-column matrix specifying the padding sizes. The number of
+// rows must be the same as the rank of `input`.
+//	mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions
+// do not include the borders, while in symmetric mode the padded regions
+// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings`
+// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and
+// it is `[1, 2, 3, 3, 2]` in symmetric mode.
 //
-// Returns:
-//	s: Singular values. Shape is `[..., P]`.
-//	u: Left singular vectors. If `full_matrices` is `False` then shape is
-// `[..., M, P]`; if `full_matrices` is `True` then shape is
-// `[..., M, M]`. Undefined if `compute_uv` is `False`.
-//	v: Left singular vectors. If `full_matrices` is `False` then shape is
-// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`.
-// Undefined if `compute_uv` is false.
-func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) {
+// Returns The padded tensor.
+func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
+	attrs := map[string]interface{}{"mode": mode}
 	opspec := tf.OpSpec{
-		Type: "Svd",
+		Type: "MirrorPad",
 		Input: []tf.Input{
-			input,
+			input, paddings,
 		},
 		Attrs: attrs,
 	}
 	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2)
+	return op.Output(0)
 }
 
-// Retrieve multiple values from the computation outfeed. Device ordinal is a
-// tensor allowing dynamic outfeed.
-//
-// This operation will block indefinitely until data is available. Output `i`
-// corresponds to XLA tuple element `i`.
-//
-// Arguments:
-//	device_ordinal: An int scalar tensor, representing the TPU device to use. This should be -1 when
-// the Op is running on a TPU device, and >= 0 when the Op is running on the CPU
-// device.
-//	dtypes: The element types of each element in `outputs`.
-//	shapes: The shapes of each tensor in `outputs`.
-//
-// Returns A list of tensors that will be read from the outfeed.
-func OutfeedDequeueTupleV2(scope *Scope, device_ordinal tf.Output, dtypes []tf.DataType, shapes []tf.Shape) (outputs []tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes}
-	opspec := tf.OpSpec{
-		Type: "OutfeedDequeueTupleV2",
-		Input: []tf.Input{
-			device_ordinal,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	if scope.Err() != nil {
-		return
-	}
-	var idx int
-	var err error
-	if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil {
-		scope.UpdateErr("OutfeedDequeueTupleV2", err)
-		return
-	}
-	return outputs
-}
+// TensorArrayV3Attr is an optional argument to TensorArrayV3.
+type TensorArrayV3Attr func(optionalAttr)
 
-// QrAttr is an optional argument to Qr.
-type QrAttr func(optionalAttr)
-
-// QrFullMatrices sets the optional full_matrices attribute to value.
+// TensorArrayV3ElementShape sets the optional element_shape attribute to value.
 //
-// value: If true, compute full-sized `q` and `r`. If false
-// (the default), compute only the leading `P` columns of `q`.
-// If not specified, defaults to false
-func QrFullMatrices(value bool) QrAttr {
+// value: The expected shape of an element, if known. Used to
+// validate the shapes of TensorArray elements. If this shape is not
+// fully specified, gathering zero-size TensorArrays is an error.
+// If not specified, defaults to <unknown_rank:true >
+func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr {
 	return func(m optionalAttr) {
-		m["full_matrices"] = value
+		m["element_shape"] = value
 	}
 }
 
-// Computes the QR decompositions of one or more matrices.
+// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value.
 //
-// Computes the QR decomposition of each inner matrix in `tensor` such that
-// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
+// value: A boolean that determines whether writes to the TensorArray
+// are allowed to grow the size.  By default, this is not allowed.
+// If not specified, defaults to false
+func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr {
+	return func(m optionalAttr) {
+		m["dynamic_size"] = value
+	}
+}
+
+// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value.
 //
-// Currently, the gradient for the QR decomposition is well-defined only when
-// the first `P` columns of the inner matrix are linearly independent, where
-// `P` is the minimum of `M` and `N`, the 2 inner-most dimmensions of `tensor`.
+// value: If true (default), Tensors in the TensorArray are cleared
+// after being read.  This disables multiple read semantics but allows early
+// release of memory.
+// If not specified, defaults to true
+func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr {
+	return func(m optionalAttr) {
+		m["clear_after_read"] = value
+	}
+}
+
+// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value.
 //
-// ```python
-// # a is a tensor.
-// # q is a tensor of orthonormal matrices.
-// # r is a tensor of upper triangular matrices.
-// q, r = qr(a)
-// q_full, r_full = qr(a, full_matrices=True)
-// ```
+// value: If true (default is false), then all
+// elements in the TensorArray will be expected to have identical shapes.
+// This allows certain behaviors, like dynamically checking for
+// consistent shapes on write, and being able to fill in properly
+// shaped zero tensors on stack -- even if the element_shape attribute
+// is not fully defined.
+// If not specified, defaults to false
+func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr {
+	return func(m optionalAttr) {
+		m["identical_element_shapes"] = value
+	}
+}
+
+// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value.
+//
+// value: Overrides the name used for the temporary tensor_array
+// resource. Default value is the name of the 'TensorArray' op (which
+// is guaranteed unique).
+// If not specified, defaults to ""
+func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr {
+	return func(m optionalAttr) {
+		m["tensor_array_name"] = value
+	}
+}
+
+// An array of Tensors of given size.
+//
+// Write data via Write and read via Read or Pack.
 //
 // Arguments:
-//	input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+//	size: The size of the array.
+//	dtype: The type of the elements on the tensor_array.
 //
 // Returns:
-//	q: Orthonormal basis for range of `a`. If `full_matrices` is `False` then
-// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
-// `[..., M, M]`.
-//	r: Triangular factor. If `full_matrices` is `False` then shape is
-// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
-func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) {
+//	handle: The handle to the TensorArray.
+//	flow: A scalar used to control gradient flow.
+func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
-	attrs := map[string]interface{}{}
+	attrs := map[string]interface{}{"dtype": dtype}
 	for _, a := range optional {
 		a(attrs)
 	}
 	opspec := tf.OpSpec{
-		Type: "Qr",
+		Type: "TensorArrayV3",
 		Input: []tf.Input{
-			input,
+			size,
 		},
 		Attrs: attrs,
 	}
@@ -15539,6 +15326,83 @@
 	return op.Output(0), op.Output(1)
 }
 
+// MatrixSolveLsAttr is an optional argument to MatrixSolveLs.
+type MatrixSolveLsAttr func(optionalAttr)
+
+// MatrixSolveLsFast sets the optional fast attribute to value.
+// If not specified, defaults to true
+func MatrixSolveLsFast(value bool) MatrixSolveLsAttr {
+	return func(m optionalAttr) {
+		m["fast"] = value
+	}
+}
+
+// Solves one or more linear least-squares problems.
+//
+// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same
+// type as `matrix` and shape `[..., M, K]`.
+// The output is a tensor of shape `[..., N, K]` where each output matrix solves
+// each of the equations
+// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
+// in the least squares sense.
+//
+// We use the following notation for (complex) matrix and right-hand sides
+// in the batch:
+//
+// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
+// `rhs`=\\(B  \in \mathbb{C}^{m \times k}\\),
+// `output`=\\(X  \in \mathbb{C}^{n \times k}\\),
+// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
+//
+// If `fast` is `True`, then the solution is computed by solving the normal
+// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
+// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\).
+// If \\(m \lt n\\) then `output` is computed as
+// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
+// minimum-norm solution to the under-determined linear system, i.e.
+// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
+// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
+// when \\(A\\) is numerically full rank and has a condition number
+// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
+// sufficiently large.
+//
+// If `fast` is `False` an algorithm based on the numerically robust complete
+// orthogonal decomposition is used. This computes the minimum-norm
+// least-squares solution, even when \\(A\\) is rank deficient. This path is
+// typically 6-7 times slower than the fast path. If `fast` is `False` then
+// `l2_regularizer` is ignored.
+//
+// Arguments:
+//	matrix: Shape is `[..., M, N]`.
+//	rhs: Shape is `[..., M, K]`.
+//	l2_regularizer: Scalar tensor.
+//
+// @compatibility(numpy)
+// Equivalent to np.linalg.lstsq
+// @end_compatibility
+//
+// Returns Shape is `[..., N, K]`.
+func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixSolveLs",
+		Input: []tf.Input{
+			matrix, rhs, l2_regularizer,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
 type MatrixTriangularSolveAttr func(optionalAttr)
 
@@ -15641,6 +15505,184 @@
 	return op.Output(0)
 }
 
+// Applies sparse addition to `input` using individual values or slices
+//
+// from `updates` according to indices `indices`.  The updates are non-aliasing:
+// `input` is only modified in-place if no other operations will use it.
+// Otherwise, a copy of `input` is made.  This operation has a gradient with
+// respect to both `input` and `updates`.
+//
+// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+//
+// `indices` must be an integer tensor containing indices into `input`.
+// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`.
+//
+// The innermost dimension of `indices` (with length `K`) corresponds to
+// indices into elements (if `K = P`) or `(P-K)`-dimensional slices
+// (if `K < P`) along the `K`th dimension of `input`.
+//
+// `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+//
+// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$
+//
+// For example, say we want to add 4 scattered elements to a rank-1 tensor with
+// 8 elements. In Python, that addition would look like this:
+//
+//     input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
+//     indices = tf.constant([[4], [3], [1], [7]])
+//     updates = tf.constant([9, 10, 11, 12])
+//     output = tf.scatter_nd_non_aliasing_add(input, indices, updates)
+//     with tf.Session() as sess:
+//       print(sess.run(output))
+//
+// The resulting value `output` would look like this:
+//
+//     [1, 13, 3, 14, 14, 6, 7, 20]
+//
+// See `tf.scatter_nd` for more details about how to make updates to slices.
+//
+// Arguments:
+//	input: A Tensor.
+//	indices: A Tensor. Must be one of the following types: `int32`, `int64`.
+// A tensor of indices into `input`.
+//	updates: A Tensor. Must have the same type as ref. A tensor of updated values
+// to add to `input`.
+//
+// Returns A `Tensor` with the same shape as `input`, containing values of `input`
+// updated with `updates`.
+func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "ScatterNdNonAliasingAdd",
+		Input: []tf.Input{
+			input, indices, updates,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// LuAttr is an optional argument to Lu.
+type LuAttr func(optionalAttr)
+
+// LuOutputIdxType sets the optional output_idx_type attribute to value.
+// If not specified, defaults to DT_INT32
+func LuOutputIdxType(value tf.DataType) LuAttr {
+	return func(m optionalAttr) {
+		m["output_idx_type"] = value
+	}
+}
+
+// Computes the LU decomposition of one or more square matrices.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices.
+//
+// The input has to be invertible.
+//
+// The output consists of two tensors LU and P containing the LU decomposition
+// of all input submatrices `[..., :, :]`. LU encodes the lower triangular and
+// upper triangular factors.
+//
+// For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of
+// shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower
+// triangular part of LU. U is an upper triangular matrix of shape `[M, M]` whose
+// entries correspond to the upper triangular part, including the diagonal, of LU.
+//
+// P represents a permutation matrix encoded as a list of indices each between `0`
+// and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to
+// P, then L, U and P satisfy P_mat * input = L * U.
+//
+// Arguments:
+//	input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of
+// size `[M, M]`.
+//
+// Returns:
+//	lu: A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the
+// lower triangular factor `L` with unit diagonal, and whose upper triangular part
+// denotes the upper triangular factor `U`.
+//	p: Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is
+// `[..., M]`.
+// @compatibility(scipy)
+// Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are
+// packed into a single tensor, the permutation is applied to `input` instead of
+// the right hand side and the permutation `P` is returned as a list of indices
+// instead of a permutation matrix.
+// @end_compatibility
+func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "Lu",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1)
+}
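+
+// A minimal usage sketch (illustrative, assuming the standard tf/op client packages):
+//
+//	s := op.NewScope()
+//	a := op.Const(s, [][]float32{{4, 3}, {6, 3}})
+//	lu, p := op.Lu(s, a)
+//	// For int64 permutation indices instead of the default int32:
+//	// lu, p = op.Lu(s, a, op.LuOutputIdxType(tf.Int64))
+//	// lu and p can then be fetched with a Session as usual.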
+
+// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2.
+type SelfAdjointEigV2Attr func(optionalAttr)
+
+// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value.
+//
+// value: If `True` then eigenvectors will be computed and returned in `v`.
+// Otherwise, only the eigenvalues will be computed.
+// If not specified, defaults to true
+func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr {
+	return func(m optionalAttr) {
+		m["compute_v"] = value
+	}
+}
+
+// Computes the eigen decomposition of one or more square self-adjoint matrices.
+//
+// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
+// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
+// are sorted in non-decreasing order.
+//
+// ```python
+// # a is a tensor.
+// # e is a tensor of eigenvalues.
+// # v is a tensor of eigenvectors.
+// e, v = self_adjoint_eig(a)
+// e = self_adjoint_eig(a, compute_v=False)
+// ```
+//
+// Arguments:
+//	input: `Tensor` input of shape `[N, N]`.
+//
+// Returns:
+//	e: Eigenvalues. Shape is `[N]`.
+//	v: Eigenvectors. Shape is `[N, N]`.
+func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "SelfAdjointEigV2",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1)
+}
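+
+// A minimal usage sketch (illustrative, assuming the standard tf/op client packages):
+//
+//	s := op.NewScope()
+//	a := op.Const(s, [][]float32{{2, 1}, {1, 2}}) // self-adjoint (symmetric) input
+//	e, v := op.SelfAdjointEigV2(s, a)
+//	// Eigenvalues only, skipping the eigenvector computation:
+//	// e, _ = op.SelfAdjointEigV2(s, a, op.SelfAdjointEigV2ComputeV(false))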
+
 // Computes the Eigen Decomposition of a batch of square self-adjoint matrices.
 //
 // DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead.
@@ -15671,6 +15713,51 @@
 	return op.Output(0)
 }
 
+// Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
+//
+// For an explanation see "Differentiation of the Cholesky algorithm" by
+// Iain Murray http://arxiv.org/abs/1602.07527.
+//
+// Arguments:
+//	l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`.
+// Algorithm depends only on lower triangular part of the innermost matrices of
+// this tensor.
+//	grad: df/dl where f is some scalar function. Shape is `[..., M, M]`.
+// Algorithm depends only on lower triangular part of the innermost matrices of
+// this tensor.
+//
+// Returns Symmetrized version of df/dA. Shape is `[..., M, M]`.
+func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "CholeskyGrad",
+		Input: []tf.Input{
+			l, grad,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Deprecated, use python implementation tf.linalg.matrix_exponential.
+//
+// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
+func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixExponential",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Creates a dataset that emits the key-value pairs in one or more LMDB files.
 //
 // The Lightning Memory-Mapped Database Manager, or LMDB, is an embedded binary
@@ -16603,65 +16690,6 @@
 	return op.Output(0)
 }
 
-// Applies sparse addition to `input` using individual values or slices
-//
-// from `updates` according to indices `indices`.  The updates are non-aliasing:
-// `input` is only modified in-place if no other operations will use it.
-// Otherwise, a copy of `input` is made.  This operation has a gradient with
-// respect to both `input` and `updates`.
-//
-// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
-//
-// `indices` must be integer tensor, containing indices into `input`.
-// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`.
-//
-// The innermost dimension of `indices` (with length `K`) corresponds to
-// indices into elements (if `K = P`) or `(P-K)`-dimensional slices
-// (if `K < P`) along the `K`th dimension of `input`.
-//
-// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
-//
-// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$
-//
-// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8
-// elements. In Python, that addition would look like this:
-//
-//     input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
-//     indices = tf.constant([[4], [3], [1], [7]])
-//     updates = tf.constant([9, 10, 11, 12])
-//     output = tf.scatter_nd_non_aliasing_add(input, indices, updates)
-//     with tf.Session() as sess:
-//       print(sess.run(output))
-//
-// The resulting value `output` would look like this:
-//
-//     [1, 13, 3, 14, 14, 6, 7, 20]
-//
-// See `tf.scatter_nd` for more details about how to make updates to slices.
-//
-// Arguments:
-//	input: A Tensor.
-//	indices: A Tensor. Must be one of the following types: `int32`, `int64`.
-// A tensor of indices into `input`.
-//	updates: A Tensor. Must have the same type as ref. A tensor of updated values
-// to add to `input`.
-//
-// Returns A `Tensor` with the same shape as `input`, containing values of `input`
-// updated with `updates`.
-func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "ScatterNdNonAliasingAdd",
-		Input: []tf.Input{
-			input, indices, updates,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2.
 type MutableHashTableOfTensorsV2Attr func(optionalAttr)
 
@@ -17234,73 +17262,6 @@
 	return op.Output(0), op.Output(1)
 }
 
-// Computes the sum along sparse segments of a tensor divided by the sqrt of N.
-//
-// N is the size of the segment being reduced.
-//
-// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
-// missing, the `output` tensor at that position will be zeroed.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
-// for an explanation of segments.
-//
-// Arguments:
-//
-//	indices: A 1-D tensor. Has same rank as `segment_ids`.
-//	segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-//	num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "SparseSegmentSqrtNWithNumSegments",
-		Input: []tf.Input{
-			data, indices, segment_ids, num_segments,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes the Cholesky decomposition of one or more square matrices.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices.
-//
-// The input has to be symmetric and positive definite. Only the lower-triangular
-// part of the input will be used for this operation. The upper-triangular part
-// will not be read.
-//
-// The output is a tensor of the same shape as the input
-// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
-//
-// **Note**: The gradient computation on GPU is faster for large matrices but
-// not for large batch dimensions when the submatrices are small. In this
-// case it might be faster to use the CPU.
-//
-// Arguments:
-//	input: Shape is `[..., M, M]`.
-//
-// Returns Shape is `[..., M, M]`.
-func Cholesky(scope *Scope, input tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "Cholesky",
-		Input: []tf.Input{
-			input,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Splits a tensor into a list.
 //
 // list[i] corresponds to lengths[i] tensors from the input tensor.
@@ -18548,86 +18509,6 @@
 	return op.Output(0)
 }
 
-// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent.
-type ResourceApplyGradientDescentAttr func(optionalAttr)
-
-// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, the subtraction will be protected by a lock;
-// otherwise the behavior is undefined, but may exhibit less contention.
-// If not specified, defaults to false
-func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr {
-	return func(m optionalAttr) {
-		m["use_locking"] = value
-	}
-}
-
-// Update '*var' by subtracting 'alpha' * 'delta' from it.
-//
-// Arguments:
-//	var_: Should be from a Variable().
-//	alpha: Scaling factor. Must be a scalar.
-//	delta: The change.
-//
-// Returns the created operation.
-func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "ResourceApplyGradientDescent",
-		Input: []tf.Input{
-			var_, alpha, delta,
-		},
-		Attrs: attrs,
-	}
-	return scope.AddOperation(opspec)
-}
-
-// Computes the matrix logarithm of one or more square matrices:
-//
-//
-// \\(log(exp(A)) = A\\)
-//
-// This op is only defined for complex matrices. If A is positive-definite and
-// real, then casting to a complex matrix, taking the logarithm and casting back
-// to a real matrix will give the correct result.
-//
-// This function computes the matrix logarithm using the Schur-Parlett algorithm.
-// Details of the algorithm can be found in Section 11.6.2 of:
-// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008.
-// ISBN 978-0-898716-46-7.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. The output is a tensor of the same shape as the input
-// containing the exponential for all input submatrices `[..., :, :]`.
-//
-// Arguments:
-//	input: Shape is `[..., M, M]`.
-//
-// Returns Shape is `[..., M, M]`.
-//
-// @compatibility(scipy)
-// Equivalent to scipy.linalg.logm
-// @end_compatibility
-func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "MatrixLogarithm",
-		Input: []tf.Input{
-			input,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // SparseBincountAttr is an optional argument to SparseBincount.
 type SparseBincountAttr func(optionalAttr)
 
@@ -18681,35 +18562,6 @@
 	return op.Output(0)
 }
 
-// Calculate product with tridiagonal matrix.
-//
-// Calculates product of two matrices, where left matrix is a tridiagonal matrix.
-//
-// Arguments:
-//	superdiag: Tensor of shape `[..., 1, M]`, representing superdiagonals of
-// tri-diagonal matrices to the left of multiplication. Last element is ignored.
-//	maindiag: Tensor of shape `[..., 1, M]`, representing main diagonals of tri-diagonal
-// matrices to the left of multiplication.
-//	subdiag: Tensor of shape `[..., 1, M]`, representing subdiagonals of tri-diagonal
-// matrices to the left of multiplication. First element is ignored.
-//	rhs: Tensor of shape `[..., M, N]`, representing MxN matrices to the right of
-// multiplication.
-//
-// Returns Tensor of shape `[..., M, N]` containing the product.
-func TridiagonalMatMul(scope *Scope, superdiag tf.Output, maindiag tf.Output, subdiag tf.Output, rhs tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "TridiagonalMatMul",
-		Input: []tf.Input{
-			superdiag, maindiag, subdiag, rhs,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
 //
 // if < 0, `scale * features` otherwise.
@@ -19738,34 +19590,6 @@
 	return op.Output(0)
 }
 
-// Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
-//
-// For an explanation see "Differentiation of the Cholesky algorithm" by
-// Iain Murray http://arxiv.org/abs/1602.07527.
-//
-// Arguments:
-//	l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`.
-// Algorithm depends only on lower triangular part of the innermost matrices of
-// this tensor.
-//	grad: df/dl where f is some scalar function. Shape is `[..., M, M]`.
-// Algorithm depends only on lower triangular part of the innermost matrices of
-// this tensor.
-//
-// Returns Symmetrized version of df/dA . Shape is `[..., M, M]`
-func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "CholeskyGrad",
-		Input: []tf.Input{
-			l, grad,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Reshapes a tensor.
 //
 // Given `tensor`, this operation returns a tensor that has the same values
@@ -20469,12 +20293,13 @@
 	return output
 }
 
-// Returns the TopK values in the array in sorted order. This is a combination
+// Returns the TopK values in the array in sorted order.
 //
-// of MakeUnique and TopKUnique. The returned top-K will have its lower bits
-// replaced by iota, thus it will be close to the original value but not exactly
-// the same. The running time is proportional to the product of K and the input
-// size. NaNs are never returned. Subnormal numbers are flushed to zero.
+// This is a combination of MakeUnique and TopKUnique. The returned top-K will
+// have its lower bits replaced by iota, thus it will be close to the original
+// value but not exactly the same. The running time is proportional to the product
+// of K and the input size. NaNs are never returned. Subnormal numbers are flushed
+// to zero.
 func TopKWithUnique(scope *Scope, input tf.Output, k int64) (topk tf.Output, topk_indices tf.Output) {
 	if scope.Err() != nil {
 		return
@@ -20774,94 +20599,6 @@
 	return op.Output(0)
 }
 
-// LowerBoundAttr is an optional argument to LowerBound.
-type LowerBoundAttr func(optionalAttr)
-
-// LowerBoundOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func LowerBoundOutType(value tf.DataType) LowerBoundAttr {
-	return func(m optionalAttr) {
-		m["out_type"] = value
-	}
-}
-
-// Applies lower_bound(sorted_search_values, values) along each row.
-//
-// Each set of rows with the same index in (sorted_inputs, values) is treated
-// independently.  The resulting row is the equivalent of calling
-// `np.searchsorted(sorted_inputs, values, side='left')`.
-//
-// The result is not a global index to the entire
-// `Tensor`, but rather just the index in the last dimension.
-//
-// A 2-D example:
-//   sorted_sequence = [[0, 3, 9, 9, 10],
-//                      [1, 2, 3, 4, 5]]
-//   values = [[2, 4, 9],
-//             [0, 2, 6]]
-//
-//   result = LowerBound(sorted_sequence, values)
-//
-//   result == [[1, 2, 2],
-//              [0, 1, 5]]
-//
-// Arguments:
-//	sorted_inputs: 2-D Tensor where each row is ordered.
-//	values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains
-// the values that will be searched for in `sorted_search_values`.
-//
-// Returns A `Tensor` with the same shape as `values`.  It contains the first scalar index
-// into the last dimension where values can be inserted without changing the
-// ordered property.
-func LowerBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...LowerBoundAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "LowerBound",
-		Input: []tf.Input{
-			sorted_inputs, values,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Returns the truth value of (x > y) element-wise.
-//
-// *NOTE*: `Greater` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-//
-// Example:
-//
-// ```python
-// x = tf.constant([5, 4, 6])
-// y = tf.constant([5, 2, 5])
-// tf.math.greater(x, y) ==> [False, True, True]
-//
-// x = tf.constant([5, 4, 6])
-// y = tf.constant([5])
-// tf.math.greater(x, y) ==> [False, False, True]
-// ```
-func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "Greater",
-		Input: []tf.Input{
-			x, y,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Compute the polygamma function \\(\psi^{(n)}(x)\\).
 //
 // The polygamma function is defined as:
@@ -21010,27 +20747,6 @@
 
 // Returns element-wise remainder of division. This emulates C semantics in that
 //
-// the result here is consistent with a truncating divide. E.g. `truncate(x / y) *
-// y + truncate_mod(x, y) = x`.
-//
-// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "TruncateMod",
-		Input: []tf.Input{
-			x, y,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Returns element-wise remainder of division. This emulates C semantics in that
-//
 // the result here is consistent with a truncating divide. E.g.
 // `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`.
 //
@@ -21515,89 +21231,6 @@
 	return op.Output(0), op.Output(1)
 }
 
-// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2.
-type QueueDequeueManyV2Attr func(optionalAttr)
-
-// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value.
-//
-// value: If the queue has fewer than n elements, this operation
-// will block for up to timeout_ms milliseconds.
-// Note: This option is not supported yet.
-// If not specified, defaults to -1
-func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr {
-	return func(m optionalAttr) {
-		m["timeout_ms"] = value
-	}
-}
-
-// Dequeues `n` tuples of one or more tensors from the given queue.
-//
-// If the queue is closed and there are fewer than `n` elements, then an
-// OutOfRange error is returned.
-//
-// This operation concatenates queue-element component tensors along the
-// 0th dimension to make a single component tensor.  All of the components
-// in the dequeued tuple will have size `n` in the 0th dimension.
-//
-// This operation has `k` outputs, where `k` is the number of components in
-// the tuples stored in the given queue, and output `i` is the ith
-// component of the dequeued tuple.
-//
-// N.B. If the queue is empty, this operation will block until `n` elements
-// have been dequeued (or 'timeout_ms' elapses, if specified).
-//
-// Arguments:
-//	handle: The handle to a queue.
-//	n: The number of tuples to dequeue.
-//	component_types: The type of each component in a tuple.
-//
-// Returns One or more tensors that were dequeued as a tuple.
-func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"component_types": component_types}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "QueueDequeueManyV2",
-		Input: []tf.Input{
-			handle, n,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	if scope.Err() != nil {
-		return
-	}
-	var idx int
-	var err error
-	if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
-		scope.UpdateErr("QueueDequeueManyV2", err)
-		return
-	}
-	return components
-}
-
-// Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or NaN.
-//
-// *NOTE*: `MulNoNan` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func MulNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "MulNoNan",
-		Input: []tf.Input{
-			x, y,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // AsStringAttr is an optional argument to AsString.
 type AsStringAttr func(optionalAttr)
 
@@ -23556,17 +23189,96 @@
 	return op.Output(0)
 }
 
-// Deprecated, use python implementation tf.linalg.matrix_exponential.
+// SvdAttr is an optional argument to Svd.
+type SvdAttr func(optionalAttr)
+
+// SvdComputeUv sets the optional compute_uv attribute to value.
 //
-// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
-func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
+// value: If true, left and right singular vectors will be
+// computed and returned in `u` and `v`, respectively.
+// If false, `u` and `v` are not set and should never be referenced.
+// If not specified, defaults to true
+func SvdComputeUv(value bool) SvdAttr {
+	return func(m optionalAttr) {
+		m["compute_uv"] = value
+	}
+}
+
+// SvdFullMatrices sets the optional full_matrices attribute to value.
+//
+// value: If true, compute full-sized `u` and `v`. If false
+// (the default), compute only the leading `P` singular vectors.
+// Ignored if `compute_uv` is `False`.
+// If not specified, defaults to false
+func SvdFullMatrices(value bool) SvdAttr {
+	return func(m optionalAttr) {
+		m["full_matrices"] = value
+	}
+}
+
+// Computes the singular value decompositions of one or more matrices.
+//
+// Computes the SVD of each inner matrix in `input` such that
+// `input[..., :, :] = u[..., :, :] * diag(s[..., :]) * transpose(v[..., :, :])`
+//
+// ```python
+// # a is a tensor containing a batch of matrices.
+// # s is a tensor of singular values for each matrix.
+// # u is the tensor containing the left singular vectors for each matrix.
+// # v is the tensor containing the right singular vectors for each matrix.
+// s, u, v = svd(a)
+// s, _, _ = svd(a, compute_uv=False)
+// ```
+//
+// Arguments:
+//	input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+//
+// Returns:
+//	s: Singular values. Shape is `[..., P]`.
+//	u: Left singular vectors. If `full_matrices` is `False` then shape is
+// `[..., M, P]`; if `full_matrices` is `True` then shape is
+// `[..., M, M]`. Undefined if `compute_uv` is `False`.
+//	v: Right singular vectors. If `full_matrices` is `False` then shape is
+// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`.
+// Undefined if `compute_uv` is false.
+func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "Svd",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2)
+}
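+
+// A minimal usage sketch (illustrative, assuming the standard tf/op client packages):
+//
+//	s := op.NewScope()
+//	a := op.Const(s, [][]float32{{1, 0}, {0, 2}, {0, 0}}) // shape [3, 2], so P = 2
+//	sv, u, v := op.Svd(s, a)
+//	// Singular values only:
+//	// sv, _, _ = op.Svd(s, a, op.SvdComputeUv(false))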
+
+// Computes gradients for SparseSegmentMean.
+//
+// Returns tensor "output" with same shape as grad, except for dimension 0 whose
+// value is output_dim0.
+//
+// Arguments:
+//	grad: gradient propagated to the SparseSegmentMean op.
+//	indices: indices passed to the corresponding SparseSegmentMean op.
+//	segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
+//	output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
+func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
 	opspec := tf.OpSpec{
-		Type: "MatrixExponential",
+		Type: "SparseSegmentMeanGrad",
 		Input: []tf.Input{
-			input,
+			grad, indices, segment_ids, output_dim0,
 		},
 	}
 	op := scope.AddOperation(opspec)
@@ -29883,7 +29595,7 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
-//     Adds v into specified rows of x.
+// Adds v into specified rows of x.
 //
 //     Computes y = x; y[i, :] += v; return y.
 //
@@ -29907,41 +29619,116 @@
 	return op.Output(0)
 }
 
-// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2.
-type SelfAdjointEigV2Attr func(optionalAttr)
+// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2.
+type QueueDequeueManyV2Attr func(optionalAttr)
 
-// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value.
+// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value.
 //
-// value: If `True` then eigenvectors will be computed and returned in `v`.
-// Otherwise, only the eigenvalues will be computed.
-// If not specified, defaults to true
-func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr {
+// value: If the queue has fewer than n elements, this operation
+// will block for up to timeout_ms milliseconds.
+// Note: This option is not supported yet.
+// If not specified, defaults to -1
+func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr {
 	return func(m optionalAttr) {
-		m["compute_v"] = value
+		m["timeout_ms"] = value
 	}
 }
 
-// Computes the eigen decomposition of one or more square self-adjoint matrices.
+// Dequeues `n` tuples of one or more tensors from the given queue.
 //
-// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
-// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
-// are sorted in non-decreasing order.
+// If the queue is closed and there are fewer than `n` elements, then an
+// OutOfRange error is returned.
 //
-// ```python
-// # a is a tensor.
-// # e is a tensor of eigenvalues.
-// # v is a tensor of eigenvectors.
-// e, v = self_adjoint_eig(a)
-// e = self_adjoint_eig(a, compute_v=False)
-// ```
+// This operation concatenates queue-element component tensors along the
+// 0th dimension to make a single component tensor.  All of the components
+// in the dequeued tuple will have size `n` in the 0th dimension.
+//
+// This operation has `k` outputs, where `k` is the number of components in
+// the tuples stored in the given queue, and output `i` is the ith
+// component of the dequeued tuple.
+//
+// N.B. If the queue is empty, this operation will block until `n` elements
+// have been dequeued (or 'timeout_ms' elapses, if specified).
 //
 // Arguments:
-//	input: `Tensor` input of shape `[N, N]`.
+//	handle: The handle to a queue.
+//	n: The number of tuples to dequeue.
+//	component_types: The type of each component in a tuple.
 //
-// Returns:
-//	e: Eigenvalues. Shape is `[N]`.
-//	v: Eigenvectors. Shape is `[N, N]`.
-func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) {
+// Returns One or more tensors that were dequeued as a tuple.
+func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"component_types": component_types}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "QueueDequeueManyV2",
+		Input: []tf.Input{
+			handle, n,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+		scope.UpdateErr("QueueDequeueManyV2", err)
+		return
+	}
+	return components
+}
+
+// Returns x * y element-wise. Returns zero if y is zero, even if x is infinite or NaN.
+//
+// *NOTE*: `MulNoNan` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func MulNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MulNoNan",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// StatelessRandomUniformFullIntV2Attr is an optional argument to StatelessRandomUniformFullIntV2.
+type StatelessRandomUniformFullIntV2Attr func(optionalAttr)
+
+// StatelessRandomUniformFullIntV2Dtype sets the optional dtype attribute to value.
+//
+// value: The type of the output.
+// If not specified, defaults to DT_UINT64
+func StatelessRandomUniformFullIntV2Dtype(value tf.DataType) StatelessRandomUniformFullIntV2Attr {
+	return func(m optionalAttr) {
+		m["dtype"] = value
+	}
+}
+
+// Outputs deterministic pseudorandom integers from a uniform distribution.
+//
+// The generated values are uniform integers covering the whole range of `dtype`.
+//
+// The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
+//
+// Arguments:
+//	shape: The shape of the output tensor.
+//	key: Key for the counter-based RNG algorithm (shape uint64[1]).
+//	counter: Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
+//	alg: The RNG algorithm (shape int32[]).
+//
+// Returns Random values with specified shape.
+func StatelessRandomUniformFullIntV2(scope *Scope, shape tf.Output, key tf.Output, counter tf.Output, alg tf.Output, optional ...StatelessRandomUniformFullIntV2Attr) (output tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
@@ -29950,14 +29737,106 @@
 		a(attrs)
 	}
 	opspec := tf.OpSpec{
-		Type: "SelfAdjointEigV2",
+		Type: "StatelessRandomUniformFullIntV2",
 		Input: []tf.Input{
-			input,
+			shape, key, counter, alg,
 		},
 		Attrs: attrs,
 	}
 	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
+	return op.Output(0)
+}
+
+// Returns a batched diagonal tensor with given batched diagonal values.
+//
+// Given a `diagonal`, this operation returns a tensor with the `diagonal` and
+// everything else padded with zeros. The diagonal is computed as follows:
+//
+// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a
+// tensor of rank `k+1` with dimensions `[I, J, K, ..., N, N]` where:
+//
+// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`.
+//
+// For example:
+//
+// ```
+// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]]
+//
+// and diagonal.shape = (2, 4)
+//
+// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0]
+//                                      [0, 2, 0, 0]
+//                                      [0, 0, 3, 0]
+//                                      [0, 0, 0, 4]],
+//                                     [[5, 0, 0, 0]
+//                                      [0, 6, 0, 0]
+//                                      [0, 0, 7, 0]
+//                                      [0, 0, 0, 8]]]
+//
+// which has shape (2, 4, 4)
+// ```
+//
+// Arguments:
+//	diagonal: Rank `k`, where `k >= 1`.
+//
+// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`.
+func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixDiag",
+		Input: []tf.Input{
+			diagonal,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
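+
+// A minimal Go sketch of the example above (illustrative, assuming the standard
+// tf/op client packages):
+//
+//	s := op.NewScope()
+//	d := op.Const(s, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}}) // shape (2, 4)
+//	m := op.MatrixDiag(s, d)                                // shape (2, 4, 4)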
+
+// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal.
+type StatelessTruncatedNormalAttr func(optionalAttr)
+
+// StatelessTruncatedNormalDtype sets the optional dtype attribute to value.
+//
+// value: The type of the output.
+// If not specified, defaults to DT_FLOAT
+func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr {
+	return func(m optionalAttr) {
+		m["dtype"] = value
+	}
+}
+
+// Outputs deterministic pseudorandom values from a truncated normal distribution.
+//
+// The generated values follow a normal distribution with mean 0 and standard
+// deviation 1, except that values whose magnitude is more than 2 standard
+// deviations from the mean are dropped and re-picked.
+//
+// The outputs are a deterministic function of `shape` and `seed`.
+//
+// Arguments:
+//	shape: The shape of the output tensor.
+//	seed: 2 seeds (shape [2]).
+//
+// Returns Random values with specified shape.
+func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "StatelessTruncatedNormal",
+		Input: []tf.Input{
+			shape, seed,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
 }
 
 // BiasAddGradAttr is an optional argument to BiasAddGrad.
@@ -30065,6 +29944,386 @@
 	return op.Output(0)
 }
 
+// FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient.
+type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr)
+
+// FakeQuantWithMinMaxVarsPerChannelGradientNumBits sets the optional num_bits attribute to value.
+//
+// value: The bitwidth of the quantization; between 2 and 16, inclusive.
+// If not specified, defaults to 8
+func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelGradientAttr {
+	return func(m optionalAttr) {
+		m["num_bits"] = value
+	}
+}
+
+// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value.
+//
+// value: Whether to quantize into 2^num_bits - 1 distinct values.
+// If not specified, defaults to false
+func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr {
+	return func(m optionalAttr) {
+		m["narrow_range"] = value
+	}
+}
+
+// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
+//
+// Arguments:
+//	gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation,
+// shape one of: `[d]`, `[b, d]`,  `[b, h, w, d]`.
+//	inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape
+//   same as `gradients`.
+// min, max: Quantization interval, floats of shape `[d]`.
+//
+//
+//
+// Returns:
+//	backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as
+// `inputs`:
+//   `gradients * (inputs >= min && inputs <= max)`.
+//	backprop_wrt_min: Backpropagated gradients w.r.t. min parameter, shape `[d]`:
+// `sum_per_d(gradients * (inputs < min))`.
+//	backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
+// `sum_per_d(gradients * (inputs > max))`.
+func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "FakeQuantWithMinMaxVarsPerChannelGradient",
+		Input: []tf.Input{
+			gradients, inputs, min, max,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// PrintV2Attr is an optional argument to PrintV2.
+type PrintV2Attr func(optionalAttr)
+
+// PrintV2OutputStream sets the optional output_stream attribute to value.
+//
+// value: A string specifying the output stream or logging level to print to.
+// If not specified, defaults to "stderr"
+func PrintV2OutputStream(value string) PrintV2Attr {
+	return func(m optionalAttr) {
+		m["output_stream"] = value
+	}
+}
+
+// PrintV2End sets the optional end attribute to value.
+// If not specified, defaults to "\n"
+func PrintV2End(value string) PrintV2Attr {
+	return func(m optionalAttr) {
+		m["end"] = value
+	}
+}
+
+// Prints a string scalar.
+//
+// Prints a string scalar to the desired output_stream.
+//
+// Arguments:
+//	input: The string scalar to print.
+//
+// Returns the created operation.
+func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "PrintV2",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
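+
+// A minimal usage sketch (illustrative, assuming the standard tf/op client packages):
+//
+//	s := op.NewScope()
+//	msg := op.Const(s, "hello from PrintV2")
+//	pr := op.PrintV2(s, msg, op.PrintV2OutputStream("stdout"))
+//	graph, _ := s.Finalize()
+//	sess, _ := tf.NewSession(graph, nil)
+//	_, _ = sess.Run(nil, nil, []*tf.Operation{pr}) // run the print as a target operation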
+
+// LowerBoundAttr is an optional argument to LowerBound.
+type LowerBoundAttr func(optionalAttr)
+
+// LowerBoundOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func LowerBoundOutType(value tf.DataType) LowerBoundAttr {
+	return func(m optionalAttr) {
+		m["out_type"] = value
+	}
+}
+
+// Applies lower_bound(sorted_search_values, values) along each row.
+//
+// Each set of rows with the same index in (sorted_inputs, values) is treated
+// independently.  The resulting row is the equivalent of calling
+// `np.searchsorted(sorted_inputs, values, side='left')`.
+//
+// The result is not a global index to the entire
+// `Tensor`, but rather just the index in the last dimension.
+//
+// A 2-D example:
+//   sorted_sequence = [[0, 3, 9, 9, 10],
+//                      [1, 2, 3, 4, 5]]
+//   values = [[2, 4, 9],
+//             [0, 2, 6]]
+//
+//   result = LowerBound(sorted_sequence, values)
+//
+//   result == [[1, 2, 2],
+//              [0, 1, 5]]
+//
+// Arguments:
+//	sorted_inputs: 2-D Tensor where each row is ordered.
+//	values: 2-D Tensor with the same number of rows as `sorted_search_values`. Contains
+// the values that will be searched for in `sorted_search_values`.
+//
+// Returns A `Tensor` with the same shape as `values`.  It contains the first scalar index
+// into the last dimension where values can be inserted without changing the
+// ordered property.
+func LowerBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...LowerBoundAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "LowerBound",
+		Input: []tf.Input{
+			sorted_inputs, values,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
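+
+// A minimal Go sketch of the 2-D example above (illustrative, assuming the
+// standard tf/op client packages):
+//
+//	s := op.NewScope()
+//	sorted := op.Const(s, [][]int32{{0, 3, 9, 9, 10}, {1, 2, 3, 4, 5}})
+//	vals := op.Const(s, [][]int32{{2, 4, 9}, {0, 2, 6}})
+//	idx := op.LowerBound(s, sorted, vals) // fetches to [[1, 2, 2], [0, 1, 5]]
+//	// For int64 indices: op.LowerBound(s, sorted, vals, op.LowerBoundOutType(tf.Int64))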
+
+// Returns the truth value of (x > y) element-wise.
+//
+// *NOTE*: `Greater` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+//
+// Example:
+//
+// ```python
+// x = tf.constant([5, 4, 6])
+// y = tf.constant([5, 2, 5])
+// tf.math.greater(x, y) ==> [False, True, True]
+//
+// x = tf.constant([5, 4, 6])
+// y = tf.constant([5])
+// tf.math.greater(x, y) ==> [False, False, True]
+// ```
+func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Greater",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// StatelessRandomUniformFullIntAttr is an optional argument to StatelessRandomUniformFullInt.
+type StatelessRandomUniformFullIntAttr func(optionalAttr)
+
+// StatelessRandomUniformFullIntDtype sets the optional dtype attribute to value.
+//
+// value: The type of the output.
+// If not specified, defaults to DT_UINT64
+func StatelessRandomUniformFullIntDtype(value tf.DataType) StatelessRandomUniformFullIntAttr {
+	return func(m optionalAttr) {
+		m["dtype"] = value
+	}
+}
+
+// Outputs deterministic pseudorandom integers from a uniform distribution.
+//
+// The generated values are uniform integers covering the whole range of `dtype`.
+//
+// The outputs are a deterministic function of `shape` and `seed`.
+//
+// Arguments:
+//	shape: The shape of the output tensor.
+//	seed: 2 seeds (shape [2]).
+//
+// Returns Random values with specified shape.
+func StatelessRandomUniformFullInt(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformFullIntAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "StatelessRandomUniformFullInt",
+		Input: []tf.Input{
+			shape, seed,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
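+
+// A minimal usage sketch (illustrative, assuming the standard tf/op client packages):
+//
+//	s := op.NewScope()
+//	shape := op.Const(s, []int32{2, 3})
+//	seed := op.Const(s, []int64{7, 11}) // 2 seeds, shape [2]
+//	r := op.StatelessRandomUniformFullInt(s, shape, seed) // uint64 values by default
+//	// r = op.StatelessRandomUniformFullInt(s, shape, seed, op.StatelessRandomUniformFullIntDtype(tf.Int32))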
+
+// Returns element-wise remainder of division. This emulates C semantics in that
+//
+// the result here is consistent with a truncating divide. E.g. `truncate(x / y) *
+// y + truncate_mod(x, y) = x`.
+//
+// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "TruncateMod",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Calculates the gradient of the SparseMatrixSoftmax op.
+//
+// Arguments:
+//	softmax: A CSRSparseMatrix.
+//	grad_softmax: The gradient of `softmax`.
+//
+//
+// Returns The output gradient.
+func SparseMatrixSoftmaxGrad(scope *Scope, softmax tf.Output, grad_softmax tf.Output, type_ tf.DataType) (gradient tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"type": type_}
+	opspec := tf.OpSpec{
+		Type: "SparseMatrixSoftmaxGrad",
+		Input: []tf.Input{
+			softmax, grad_softmax,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the GRU cell back-propagation for 1 time step.
+//
+// Args
+//     x: Input to the GRU cell.
+//     h_prev: State input from the previous GRU cell.
+//     w_ru: Weight matrix for the reset and update gate.
+//     w_c: Weight matrix for the cell connection gate.
+//     b_ru: Bias vector for the reset and update gate.
+//     b_c: Bias vector for the cell connection gate.
+//     r: Output of the reset gate.
+//     u: Output of the update gate.
+//     c: Output of the cell connection gate.
+//     d_h: Gradients of the h_new wrt the objective function.
+//
+// Returns
+//     d_x: Gradients of the x wrt the objective function.
+//     d_h_prev: Gradients of the h wrt the objective function.
+//     d_c_bar: Gradients of the c_bar wrt the objective function.
+//     d_r_bar_u_bar: Gradients of the r_bar & u_bar wrt the objective function.
+//
+// This kernel op implements the following mathematical equations:
+//
+// Note on notation of the variables:
+//
+// Concatenation of a and b is represented by a_b
+// Element-wise dot product of a and b is represented by ab
+// Element-wise dot product is represented by \circ
+// Matrix multiplication is represented by *
+//
+// Additional notes for clarity:
+//
+// `w_ru` can be segmented into 4 different matrices.
+// ```
+// w_ru = [w_r_x w_u_x
+//         w_r_h_prev w_u_h_prev]
+// ```
+// Similarly, `w_c` can be segmented into 2 different matrices.
+// ```
+// w_c = [w_c_x w_c_h_prevr]
+// ```
+// Same goes for biases.
+// ```
+// b_ru = [b_ru_x b_ru_h]
+// b_c = [b_c_x b_c_h]
+// ```
+// Another note on notation:
+// ```
+// d_x = d_x_component_1 + d_x_component_2
+//
+// where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
+// and d_x_component_2 = d_c_bar * w_c_x^T
+//
+// d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
+// where d_h_prev_component_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
+// ```
+//
+// Mathematics behind the Gradients below:
+// ```
+// d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
+// d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
+//
+// d_r_bar_u_bar = [d_r_bar d_u_bar]
+//
+// [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
+//
+// [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
+//
+// d_x = d_x_component_1 + d_x_component_2
+//
+// d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
+// ```
+// Below calculation is performed in the python wrapper for the Gradients
+// (not in the gradient kernel.)
+// ```
+// d_w_ru = x_h_prevr^T * d_c_bar
+//
+// d_w_c = x_h_prev^T * d_r_bar_u_bar
+//
+// d_b_ru = sum of d_r_bar_u_bar along axis = 0
+//
+// d_b_c = sum of d_c_bar along axis = 0
+// ```
+func GRUBlockCellGrad(scope *Scope, x tf.Output, h_prev tf.Output, w_ru tf.Output, w_c tf.Output, b_ru tf.Output, b_c tf.Output, r tf.Output, u tf.Output, c tf.Output, d_h tf.Output) (d_x tf.Output, d_h_prev tf.Output, d_c_bar tf.Output, d_r_bar_u_bar tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "GRUBlockCellGrad",
+		Input: []tf.Input{
+			x, h_prev, w_ru, w_c, b_ru, b_c, r, u, c, d_h,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
+}
+
 // FractionalMaxPoolAttr is an optional argument to FractionalMaxPool.
 type FractionalMaxPoolAttr func(optionalAttr)
 
@@ -31081,126 +31340,6 @@
 	return op.Output(0)
 }
 
-// Calculates the gradient of the SparseMatrixSoftmax op.
-//
-// Arguments:
-//	softmax: A CSRSparseMatrix.
-//	grad_softmax: The gradient of `softmax`.
-//
-//
-// Returns The output gradient.
-func SparseMatrixSoftmaxGrad(scope *Scope, softmax tf.Output, grad_softmax tf.Output, type_ tf.DataType) (gradient tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"type": type_}
-	opspec := tf.OpSpec{
-		Type: "SparseMatrixSoftmaxGrad",
-		Input: []tf.Input{
-			softmax, grad_softmax,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes the GRU cell back-propagation for 1 time step.
-//
-// Args
-//     x: Input to the GRU cell.
-//     h_prev: State input from the previous GRU cell.
-//     w_ru: Weight matrix for the reset and update gate.
-//     w_c: Weight matrix for the cell connection gate.
-//     b_ru: Bias vector for the reset and update gate.
-//     b_c: Bias vector for the cell connection gate.
-//     r: Output of the reset gate.
-//     u: Output of the update gate.
-//     c: Output of the cell connection gate.
-//     d_h: Gradients of the h_new wrt to objective function.
-//
-// Returns
-//     d_x: Gradients of the x wrt to objective function.
-//     d_h_prev: Gradients of the h wrt to objective function.
-//     d_c_bar Gradients of the c_bar wrt to objective function.
-//     d_r_bar_u_bar Gradients of the r_bar & u_bar wrt to objective function.
-//
-// This kernel op implements the following mathematical equations:
-//
-// Note on notation of the variables:
-//
-// Concatenation of a and b is represented by a_b
-// Element-wise dot product of a and b is represented by ab
-// Element-wise dot product is represented by \circ
-// Matrix multiplication is represented by *
-//
-// Additional notes for clarity:
-//
-// `w_ru` can be segmented into 4 different matrices.
-// ```
-// w_ru = [w_r_x w_u_x
-//         w_r_h_prev w_u_h_prev]
-// ```
-// Similarly, `w_c` can be segmented into 2 different matrices.
-// ```
-// w_c = [w_c_x w_c_h_prevr]
-// ```
-// Same goes for biases.
-// ```
-// b_ru = [b_ru_x b_ru_h]
-// b_c = [b_c_x b_c_h]
-// ```
-// Another note on notation:
-// ```
-// d_x = d_x_component_1 + d_x_component_2
-//
-// where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
-// and d_x_component_2 = d_c_bar * w_c_x^T
-//
-// d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
-// where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
-// ```
-//
-// Mathematics behind the Gradients below:
-// ```
-// d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
-// d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
-//
-// d_r_bar_u_bar = [d_r_bar d_u_bar]
-//
-// [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
-//
-// [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
-//
-// d_x = d_x_component_1 + d_x_component_2
-//
-// d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
-// ```
-// Below calculation is performed in the python wrapper for the Gradients
-// (not in the gradient kernel.)
-// ```
-// d_w_ru = x_h_prevr^T * d_c_bar
-//
-// d_w_c = x_h_prev^T * d_r_bar_u_bar
-//
-// d_b_ru = sum of d_r_bar_u_bar along axis = 0
-//
-// d_b_c = sum of d_c_bar along axis = 0
-// ```
-func GRUBlockCellGrad(scope *Scope, x tf.Output, h_prev tf.Output, w_ru tf.Output, w_c tf.Output, b_ru tf.Output, b_c tf.Output, r tf.Output, u tf.Output, c tf.Output, d_h tf.Output) (d_x tf.Output, d_h_prev tf.Output, d_c_bar tf.Output, d_r_bar_u_bar tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "GRUBlockCellGrad",
-		Input: []tf.Input{
-			x, h_prev, w_ru, w_c, b_ru, b_c, r, u, c, d_h,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
-}
-
 // Encode audio data using the WAV file format.
 //
 // This operation will generate a string suitable to be saved out to create a .wav
@@ -32291,6 +32430,70 @@
 	return op.Output(0)
 }
 
+// Deserialize and concatenate `SparseTensors` from a serialized minibatch.
+//
+// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where
+// `N` is the minibatch size and the rows correspond to packed outputs of
+// `SerializeSparse`.  The ranks of the original `SparseTensor` objects
+// must all match.  When the final `SparseTensor` is created, it has rank one
+// higher than the ranks of the incoming `SparseTensor` objects
+// (they have been concatenated along a new row dimension).
+//
+// The output `SparseTensor` object's shape values for all dimensions but the
+// first are the max across the input `SparseTensor` objects' shape values
+// for the corresponding dimensions.  Its first shape value is `N`, the minibatch
+// size.
+//
+// The input `SparseTensor` objects' indices are assumed ordered in
+// standard lexicographic order.  If this is not the case, after this
+// step run `SparseReorder` to restore index ordering.
+//
+// For example, if the serialized input is a `[2 x 3]` matrix representing two
+// original `SparseTensor` objects:
+//
+//     index = [ 0]
+//             [10]
+//             [20]
+//     values = [1, 2, 3]
+//     shape = [50]
+//
+// and
+//
+//     index = [ 2]
+//             [10]
+//     values = [4, 5]
+//     shape = [30]
+//
+// then the final deserialized `SparseTensor` will be:
+//
+//     index = [0  0]
+//             [0 10]
+//             [0 20]
+//             [1  2]
+//             [1 10]
+//     values = [1, 2, 3, 4, 5]
+//     shape = [2 50]
+//
+// Arguments:
+//	serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects.
+// Must have 3 columns.
+//	dtype: The `dtype` of the serialized `SparseTensor` objects.
+func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"dtype": dtype}
+	opspec := tf.OpSpec{
+		Type: "DeserializeManySparse",
+		Input: []tf.Input{
+			serialized_sparse,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2)
+}
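+
+// A minimal usage sketch (illustrative, assuming the standard tf/op client
+// packages; the [N, 3] string matrix would normally come from SerializeSparse):
+//
+//	s := op.NewScope()
+//	serialized := op.Placeholder(s, tf.String, op.PlaceholderShape(tf.MakeShape(-1, 3)))
+//	indices, values, shape := op.DeserializeManySparse(s, serialized, tf.Int64) // dtype must match the serialized values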
+
 // Get the number of nodes in a tree
 //
 // Arguments:
@@ -33357,108 +33560,6 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
-// CollectiveBcastRecvAttr is an optional argument to CollectiveBcastRecv.
-type CollectiveBcastRecvAttr func(optionalAttr)
-
-// CollectiveBcastRecvCommunicationHint sets the optional communication_hint attribute to value.
-// If not specified, defaults to "auto"
-func CollectiveBcastRecvCommunicationHint(value string) CollectiveBcastRecvAttr {
-	return func(m optionalAttr) {
-		m["communication_hint"] = value
-	}
-}
-
-// CollectiveBcastRecvTimeoutSeconds sets the optional timeout_seconds attribute to value.
-// If not specified, defaults to 0
-func CollectiveBcastRecvTimeoutSeconds(value float32) CollectiveBcastRecvAttr {
-	return func(m optionalAttr) {
-		m["timeout_seconds"] = value
-	}
-}
-
-// Receives a tensor value broadcast from another device.
-func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape, optional ...CollectiveBcastRecvAttr) (data tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "CollectiveBcastRecv",
-
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Scatter the data from the input value into specific TensorArray elements.
-//
-// `indices` must be a vector, its length must match the first dim of `value`.
-//
-// Arguments:
-//	handle: The handle to a TensorArray.
-//	indices: The locations at which to write the tensor elements.
-//	value: The concatenated tensor to write to the TensorArray.
-//	flow_in: A float scalar that enforces proper chaining of operations.
-//
-// Returns A float scalar that enforces proper chaining of operations.
-func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "TensorArrayScatterV3",
-		Input: []tf.Input{
-			handle, indices, value, flow_in,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes the matrix square root of one or more square matrices:
-//
-// matmul(sqrtm(A), sqrtm(A)) = A
-//
-// The input matrix should be invertible. If the input matrix is real, it should
-// have no eigenvalues which are real and negative (pairs of complex conjugate
-// eigenvalues are allowed).
-//
-// The matrix square root is computed by first reducing the matrix to
-// quasi-triangular form with the real Schur decomposition. The square root
-// of the quasi-triangular matrix is then computed directly. Details of
-// the algorithm can be found in: Nicholas J. Higham, "Computing real
-// square roots of a real matrix", Linear Algebra Appl., 1987.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. The output is a tensor of the same shape as the input
-// containing the matrix square root for all input submatrices `[..., :, :]`.
-//
-// Arguments:
-//	input: Shape is `[..., M, M]`.
-//
-// Returns Shape is `[..., M, M]`.
-//
-// @compatibility(scipy)
-// Equivalent to scipy.linalg.sqrtm
-// @end_compatibility
-func MatrixSquareRoot(scope *Scope, input tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "MatrixSquareRoot",
-		Input: []tf.Input{
-			input,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // MutexV2Attr is an optional argument to MutexV2.
 type MutexV2Attr func(optionalAttr)
 
@@ -34871,9 +34972,9 @@
 	return scope.AddOperation(opspec)
 }
 
-// Returns the TopK unique values in the array in sorted order. The
+// Returns the TopK unique values in the array in sorted order.
 //
-// running time is proportional to the product of K and the input
+// The running time is proportional to the product of K and the input
 // size. Sorting the whole array is more efficient for sufficiently large
 // values of K. The median-of-medians algorithm is probably faster, but
 // difficult to implement efficiently in XLA. If there are fewer than K
@@ -36029,233 +36130,6 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
-// Pads a tensor with mirrored values.
-//
-// This operation pads a `input` with mirrored values according to the `paddings`
-// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
-// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-// how many values to add before the contents of `input` in that dimension, and
-// `paddings[D, 1]` indicates how many values to add after the contents of `input`
-// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater
-// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true
-// (if false, respectively).
-//
-// The padded size of each dimension D of the output is:
-//
-// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-//
-// For example:
-//
-// ```
-// # 't' is [[1, 2, 3], [4, 5, 6]].
-// # 'paddings' is [[1, 1]], [2, 2]].
-// # 'mode' is SYMMETRIC.
-// # rank of 't' is 2.
-// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
-//                       [2, 1, 1, 2, 3, 3, 2]
-//                       [5, 4, 4, 5, 6, 6, 5]
-//                       [5, 4, 4, 5, 6, 6, 5]]
-// ```
-//
-// Arguments:
-//	input: The input tensor to be padded.
-//	paddings: A two-column matrix specifying the padding sizes. The number of
-// rows must be the same as the rank of `input`.
-//	mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions
-// do not include the borders, while in symmetric mode the padded regions
-// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings`
-// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and
-// it is `[1, 2, 3, 3, 2]` in symmetric mode.
-//
-// Returns The padded tensor.
-func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"mode": mode}
-	opspec := tf.OpSpec{
-		Type: "MirrorPad",
-		Input: []tf.Input{
-			input, paddings,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// TensorArrayV3Attr is an optional argument to TensorArrayV3.
-type TensorArrayV3Attr func(optionalAttr)
-
-// TensorArrayV3ElementShape sets the optional element_shape attribute to value.
-//
-// value: The expected shape of an element, if known. Used to
-// validate the shapes of TensorArray elements. If this shape is not
-// fully specified, gathering zero-size TensorArrays is an error.
-// If not specified, defaults to <unknown_rank:true >
-func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr {
-	return func(m optionalAttr) {
-		m["element_shape"] = value
-	}
-}
-
-// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value.
-//
-// value: A boolean that determines whether writes to the TensorArray
-// are allowed to grow the size.  By default, this is not allowed.
-// If not specified, defaults to false
-func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr {
-	return func(m optionalAttr) {
-		m["dynamic_size"] = value
-	}
-}
-
-// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value.
-//
-// value: If true (default), Tensors in the TensorArray are cleared
-// after being read.  This disables multiple read semantics but allows early
-// release of memory.
-// If not specified, defaults to true
-func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr {
-	return func(m optionalAttr) {
-		m["clear_after_read"] = value
-	}
-}
-
-// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value.
-//
-// value: If true (default is false), then all
-// elements in the TensorArray will be expected to have identical shapes.
-// This allows certain behaviors, like dynamically checking for
-// consistent shapes on write, and being able to fill in properly
-// shaped zero tensors on stack -- even if the element_shape attribute
-// is not fully defined.
-// If not specified, defaults to false
-func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr {
-	return func(m optionalAttr) {
-		m["identical_element_shapes"] = value
-	}
-}
-
-// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value.
-//
-// value: Overrides the name used for the temporary tensor_array
-// resource. Default value is the name of the 'TensorArray' op (which
-// is guaranteed unique).
-// If not specified, defaults to ""
-func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr {
-	return func(m optionalAttr) {
-		m["tensor_array_name"] = value
-	}
-}
-
-// An array of Tensors of given size.
-//
-// Write data via Write and read via Read or Pack.
-//
-// Arguments:
-//	size: The size of the array.
-//	dtype: The type of the elements on the tensor_array.
-//
-// Returns:
-//	handle: The handle to the TensorArray.
-//	flow: A scalar used to control gradient flow.
-func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"dtype": dtype}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "TensorArrayV3",
-		Input: []tf.Input{
-			size,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
-}
-
-// MatrixSolveLsAttr is an optional argument to MatrixSolveLs.
-type MatrixSolveLsAttr func(optionalAttr)
-
-// MatrixSolveLsFast sets the optional fast attribute to value.
-// If not specified, defaults to true
-func MatrixSolveLsFast(value bool) MatrixSolveLsAttr {
-	return func(m optionalAttr) {
-		m["fast"] = value
-	}
-}
-
-// Solves one or more linear least-squares problems.
-//
-// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same
-// type as `matrix` and shape `[..., M, K]`.
-// The output is a tensor shape `[..., N, K]` where each output matrix solves
-// each of the equations
-// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
-// in the least squares sense.
-//
-// We use the following notation for (complex) matrix and right-hand sides
-// in the batch:
-//
-// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
-// `rhs`=\\(B  \in \mathbb{C}^{m \times k}\\),
-// `output`=\\(X  \in \mathbb{C}^{n \times k}\\),
-// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
-//
-// If `fast` is `True`, then the solution is computed by solving the normal
-// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
-// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
-// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\).
-// If \\(m \lt n\\) then `output` is computed as
-// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
-// minimum-norm solution to the under-determined linear system, i.e.
-// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
-// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
-// when \\(A\\) is numerically full rank and has a condition number
-// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
-// sufficiently large.
-//
-// If `fast` is `False` an algorithm based on the numerically robust complete
-// orthogonal decomposition is used. This computes the minimum-norm
-// least-squares solution, even when \\(A\\) is rank deficient. This path is
-// typically 6-7 times slower than the fast path. If `fast` is `False` then
-// `l2_regularizer` is ignored.
-//
-// Arguments:
-//	matrix: Shape is `[..., M, N]`.
-//	rhs: Shape is `[..., M, K]`.
-//	l2_regularizer: Scalar tensor.
-//
-// @compatibility(numpy)
-// Equivalent to np.linalg.lstsq
-// @end_compatibility
-//
-// Returns Shape is `[..., N, K]`.
-func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "MatrixSolveLs",
-		Input: []tf.Input{
-			matrix, rhs, l2_regularizer,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Generates sparse cross from a list of sparse and dense tensors.
 //
 // The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each
@@ -37139,70 +37013,6 @@
 	return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
 }
 
-// Deserialize and concatenate `SparseTensors` from a serialized minibatch.
-//
-// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where
-// `N` is the minibatch size and the rows correspond to packed outputs of
-// `SerializeSparse`.  The ranks of the original `SparseTensor` objects
-// must all match.  When the final `SparseTensor` is created, it has rank one
-// higher than the ranks of the incoming `SparseTensor` objects
-// (they have been concatenated along a new row dimension).
-//
-// The output `SparseTensor` object's shape values for all dimensions but the
-// first are the max across the input `SparseTensor` objects' shape values
-// for the corresponding dimensions.  Its first shape value is `N`, the minibatch
-// size.
-//
-// The input `SparseTensor` objects' indices are assumed ordered in
-// standard lexicographic order.  If this is not the case, after this
-// step run `SparseReorder` to restore index ordering.
-//
-// For example, if the serialized input is a `[2 x 3]` matrix representing two
-// original `SparseTensor` objects:
-//
-//     index = [ 0]
-//             [10]
-//             [20]
-//     values = [1, 2, 3]
-//     shape = [50]
-//
-// and
-//
-//     index = [ 2]
-//             [10]
-//     values = [4, 5]
-//     shape = [30]
-//
-// then the final deserialized `SparseTensor` will be:
-//
-//     index = [0  0]
-//             [0 10]
-//             [0 20]
-//             [1  2]
-//             [1 10]
-//     values = [1, 2, 3, 4, 5]
-//     shape = [2 50]
-//
-// Arguments:
-//	serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects.
-// Must have 3 columns.
-//	dtype: The `dtype` of the serialized `SparseTensor` objects.
-func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"dtype": dtype}
-	opspec := tf.OpSpec{
-		Type: "DeserializeManySparse",
-		Input: []tf.Input{
-			serialized_sparse,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2)
-}
-
 // Sets the index-th position of the list to contain the given tensor.
 //
 // input_handle: the list
@@ -39888,49 +39698,6 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
-// StatelessRandomUniformFullIntAttr is an optional argument to StatelessRandomUniformFullInt.
-type StatelessRandomUniformFullIntAttr func(optionalAttr)
-
-// StatelessRandomUniformFullIntDtype sets the optional dtype attribute to value.
-//
-// value: The type of the output.
-// If not specified, defaults to DT_UINT64
-func StatelessRandomUniformFullIntDtype(value tf.DataType) StatelessRandomUniformFullIntAttr {
-	return func(m optionalAttr) {
-		m["dtype"] = value
-	}
-}
-
-// Outputs deterministic pseudorandom random integers from a uniform distribution.
-//
-// The generated values are uniform integers covering the whole range of `dtype`.
-//
-// The outputs are a deterministic function of `shape` and `seed`.
-//
-// Arguments:
-//	shape: The shape of the output tensor.
-//	seed: 2 seeds (shape [2]).
-//
-// Returns Random values with specified shape.
-func StatelessRandomUniformFullInt(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformFullIntAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "StatelessRandomUniformFullInt",
-		Input: []tf.Input{
-			shape, seed,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp.
 type ResourceApplyRMSPropAttr func(optionalAttr)
 
@@ -40086,98 +39853,6 @@
 	return op.Output(0)
 }
 
-// Returns a batched diagonal tensor with a given batched diagonal values.
-//
-// Given a `diagonal`, this operation returns a tensor with the `diagonal` and
-// everything else padded with zeros. The diagonal is computed as follows:
-//
-// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a
-// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where:
-//
-// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`.
-//
-// For example:
-//
-// ```
-// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]]
-//
-// and diagonal.shape = (2, 4)
-//
-// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0]
-//                                      [0, 2, 0, 0]
-//                                      [0, 0, 3, 0]
-//                                      [0, 0, 0, 4]],
-//                                     [[5, 0, 0, 0]
-//                                      [0, 6, 0, 0]
-//                                      [0, 0, 7, 0]
-//                                      [0, 0, 0, 8]]]
-//
-// which has shape (2, 4, 4)
-// ```
-//
-// Arguments:
-//	diagonal: Rank `k`, where `k >= 1`.
-//
-// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`.
-func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "MatrixDiag",
-		Input: []tf.Input{
-			diagonal,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal.
-type StatelessTruncatedNormalAttr func(optionalAttr)
-
-// StatelessTruncatedNormalDtype sets the optional dtype attribute to value.
-//
-// value: The type of the output.
-// If not specified, defaults to DT_FLOAT
-func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr {
-	return func(m optionalAttr) {
-		m["dtype"] = value
-	}
-}
-
-// Outputs deterministic pseudorandom values from a truncated normal distribution.
-//
-// The generated values follow a normal distribution with mean 0 and standard
-// deviation 1, except that values whose magnitude is more than 2 standard
-// deviations from the mean are dropped and re-picked.
-//
-// The outputs are a deterministic function of `shape` and `seed`.
-//
-// Arguments:
-//	shape: The shape of the output tensor.
-//	seed: 2 seeds (shape [2]).
-//
-// Returns Random values with specified shape.
-func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "StatelessTruncatedNormal",
-		Input: []tf.Input{
-			shape, seed,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Returns a copy of the input tensor.
 func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
@@ -40653,51 +40328,6 @@
 	return sparse_indices, sparse_values, sparse_shapes, dense_values
 }
 
-// StatelessRandomUniformFullIntV2Attr is an optional argument to StatelessRandomUniformFullIntV2.
-type StatelessRandomUniformFullIntV2Attr func(optionalAttr)
-
-// StatelessRandomUniformFullIntV2Dtype sets the optional dtype attribute to value.
-//
-// value: The type of the output.
-// If not specified, defaults to DT_UINT64
-func StatelessRandomUniformFullIntV2Dtype(value tf.DataType) StatelessRandomUniformFullIntV2Attr {
-	return func(m optionalAttr) {
-		m["dtype"] = value
-	}
-}
-
-// Outputs deterministic pseudorandom random integers from a uniform distribution.
-//
-// The generated values are uniform integers covering the whole range of `dtype`.
-//
-// The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
-//
-// Arguments:
-//	shape: The shape of the output tensor.
-//	key: Key for the counter-based RNG algorithm (shape uint64[1]).
-//	counter: Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
-//	alg: The RNG algorithm (shape int32[]).
-//
-// Returns Random values with specified shape.
-func StatelessRandomUniformFullIntV2(scope *Scope, shape tf.Output, key tf.Output, counter tf.Output, alg tf.Output, optional ...StatelessRandomUniformFullIntV2Attr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "StatelessRandomUniformFullIntV2",
-		Input: []tf.Input{
-			shape, key, counter, alg,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
 type ResourceScatterNdUpdateAttr func(optionalAttr)
 
@@ -41094,6 +40724,46 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
+// PrelinearizeTupleAttr is an optional argument to PrelinearizeTuple.
+type PrelinearizeTupleAttr func(optionalAttr)
+
+// PrelinearizeTupleLayouts sets the optional layouts attribute to value.
+//
+// value: A vector holding the requested layout in minor-to-major sequence for all the
+// tuple shapes in the order the shapes appear in the "shapes" input. The layout
+// elements for a sub-shape can be set to -1 in which case the corresponding layout
+// will be computed by the infeed operation.
+// If not specified, defaults to <>
+func PrelinearizeTupleLayouts(value []int64) PrelinearizeTupleAttr {
+	return func(m optionalAttr) {
+		m["layouts"] = value
+	}
+}
+
+// An op which linearizes multiple Tensor values to an opaque variant tensor.
+//
+// Arguments:
+//	inputs: A list of tensors that will be provided using the infeed mechanism.
+//	shapes: The shapes of each tensor in `inputs`.
+func PrelinearizeTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...PrelinearizeTupleAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"shapes": shapes}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "PrelinearizeTuple",
+		Input: []tf.Input{
+			tf.OutputList(inputs),
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Forwards `data` to the output port determined by `pred`.
 //
 // If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
@@ -43125,46 +42795,6 @@
 	return scope.AddOperation(opspec)
 }
 
-// PrelinearizeTupleAttr is an optional argument to PrelinearizeTuple.
-type PrelinearizeTupleAttr func(optionalAttr)
-
-// PrelinearizeTupleLayouts sets the optional layouts attribute to value.
-//
-// value: A vector holding the requested layout in minor-to-major sequence for all the
-// tuple shapes in the order the shapes appear in the "shapes" input. The layout
-// elements for a sub-shape can be set to -1 in which case the corresponding layout
-// will be computed by the infeed operation.
-// If not specified, defaults to <>
-func PrelinearizeTupleLayouts(value []int64) PrelinearizeTupleAttr {
-	return func(m optionalAttr) {
-		m["layouts"] = value
-	}
-}
-
-// An op which linearizes multiple Tensor values to an opaque variant tensor.
-//
-// Arguments:
-//	inputs: A list of tensors that will be provided using the infeed mechanism.
-//	shapes: The shapes of each tensor in `inputs`.
-func PrelinearizeTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...PrelinearizeTupleAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"shapes": shapes}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "PrelinearizeTuple",
-		Input: []tf.Input{
-			tf.OutputList(inputs),
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Computes the LSTM cell backward propagation for the entire time sequence.
 //
 // This implementation is to be used in conjunction of LSTMBlock.
@@ -43272,114 +42902,6 @@
 	return scope.AddOperation(opspec)
 }
 
-// FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient.
-type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr)
-
-// FakeQuantWithMinMaxVarsPerChannelGradientNumBits sets the optional num_bits attribute to value.
-//
-// value: The bitwidth of the quantization; between 2 and 16, inclusive.
-// If not specified, defaults to 8
-func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelGradientAttr {
-	return func(m optionalAttr) {
-		m["num_bits"] = value
-	}
-}
-
-// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value.
-//
-// value: Whether to quantize into 2^num_bits - 1 distinct values.
-// If not specified, defaults to false
-func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr {
-	return func(m optionalAttr) {
-		m["narrow_range"] = value
-	}
-}
-
-// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
-//
-// Arguments:
-//	gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation,
-// shape one of: `[d]`, `[b, d]`,  `[b, h, w, d]`.
-//	inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape
-//   same as `gradients`.
-// min, max: Quantization interval, floats of shape `[d]`.
-//
-//
-//
-// Returns:
-//	backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as
-// `inputs`:
-//   `gradients * (inputs >= min && inputs <= max)`.
-//	backprop_wrt_min: Backpropagated gradients w.r.t. min parameter, shape `[d]`:
-// `sum_per_d(gradients * (inputs < min))`.
-//	backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
-// `sum_per_d(gradients * (inputs > max))`.
-func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "FakeQuantWithMinMaxVarsPerChannelGradient",
-		Input: []tf.Input{
-			gradients, inputs, min, max,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// PrintV2Attr is an optional argument to PrintV2.
-type PrintV2Attr func(optionalAttr)
-
-// PrintV2OutputStream sets the optional output_stream attribute to value.
-//
-// value: A string specifying the output stream or logging level to print to.
-// If not specified, defaults to "stderr"
-func PrintV2OutputStream(value string) PrintV2Attr {
-	return func(m optionalAttr) {
-		m["output_stream"] = value
-	}
-}
-
-// PrintV2End sets the optional end attribute to value.
-// If not specified, defaults to "\n"
-func PrintV2End(value string) PrintV2Attr {
-	return func(m optionalAttr) {
-		m["end"] = value
-	}
-}
-
-// Prints a string scalar.
-//
-// Prints a string scalar to the desired output_stream.
-//
-// Arguments:
-//	input: The string scalar to print.
-//
-// Returns the created operation.
-func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "PrintV2",
-		Input: []tf.Input{
-			input,
-		},
-		Attrs: attrs,
-	}
-	return scope.AddOperation(opspec)
-}
-
 // Inserts a dimension of 1 into a tensor's shape.
 //
 // Given a tensor `input`, this operation inserts a dimension of 1 at the
@@ -43729,143 +43251,6 @@
 	return op.Output(0)
 }
 
-// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2.
-type TensorArrayConcatV2Attr func(optionalAttr)
-
-// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value.
-// If not specified, defaults to <unknown_rank:true >
-func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr {
-	return func(m optionalAttr) {
-		m["element_shape_except0"] = value
-	}
-}
-
-// Deprecated. Use TensorArrayConcatV3
-func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"dtype": dtype}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "TensorArrayConcatV2",
-		Input: []tf.Input{
-			handle, flow_in,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
-}
-
-// Writes contents to the file at input filename. Creates file and recursively
-//
-// creates directory if not existing.
-//
-// Arguments:
-//	filename: scalar. The name of the file to which we write the contents.
-//	contents: scalar. The content to be written to the output file.
-//
-// Returns the created operation.
-func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "WriteFile",
-		Input: []tf.Input{
-			filename, contents,
-		},
-	}
-	return scope.AddOperation(opspec)
-}
-
-// WriteImageSummaryAttr is an optional argument to WriteImageSummary.
-type WriteImageSummaryAttr func(optionalAttr)
-
-// WriteImageSummaryMaxImages sets the optional max_images attribute to value.
-// If not specified, defaults to 3
-//
-// REQUIRES: value >= 1
-func WriteImageSummaryMaxImages(value int64) WriteImageSummaryAttr {
-	return func(m optionalAttr) {
-		m["max_images"] = value
-	}
-}
-
-// Writes an image summary.
-//
-// Writes image `tensor` at `step` with `tag` using summary `writer`.
-// `tensor` is image with shape [height, width, channels].
-//
-// Returns the created operation.
-func WriteImageSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Output, tensor tf.Output, bad_color tf.Output, optional ...WriteImageSummaryAttr) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "WriteImageSummary",
-		Input: []tf.Input{
-			writer, step, tag, tensor, bad_color,
-		},
-		Attrs: attrs,
-	}
-	return scope.AddOperation(opspec)
-}
-
-// MatrixSolveAttr is an optional argument to MatrixSolve.
-type MatrixSolveAttr func(optionalAttr)
-
-// MatrixSolveAdjoint sets the optional adjoint attribute to value.
-//
-// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
-// adjoint.
-// If not specified, defaults to false
-func MatrixSolveAdjoint(value bool) MatrixSolveAttr {
-	return func(m optionalAttr) {
-		m["adjoint"] = value
-	}
-}
-
-// Solves systems of linear equations.
-//
-// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
-// a tensor shape `[..., M, K]`.  If `adjoint` is `False` then each output matrix
-// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
-// If `adjoint` is `True` then each output matrix satisfies
-// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
-//
-// Arguments:
-//	matrix: Shape is `[..., M, M]`.
-//	rhs: Shape is `[..., M, K]`.
-//
-// Returns Shape is `[..., M, K]`.
-func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "MatrixSolve",
-		Input: []tf.Input{
-			matrix, rhs,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
 type ResourceApplyAdaMaxAttr func(optionalAttr)
 
@@ -45003,72 +44388,6 @@
 	return op.Output(0)
 }
 
-// LuAttr is an optional argument to Lu.
-type LuAttr func(optionalAttr)
-
-// LuOutputIdxType sets the optional output_idx_type attribute to value.
-// If not specified, defaults to DT_INT32
-func LuOutputIdxType(value tf.DataType) LuAttr {
-	return func(m optionalAttr) {
-		m["output_idx_type"] = value
-	}
-}
-
-// Computes the LU decomposition of one or more square matrices.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices.
-//
-// The input has to be invertible.
-//
-// The output consists of two tensors LU and P containing the LU decomposition
-// of all input submatrices `[..., :, :]`. LU encodes the lower triangular and
-// upper triangular factors.
-//
-// For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of
-// shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower
-// triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose
-// entries correspond to the upper triangular part, including the diagonal, of LU.
-//
-// P represents a permutation matrix encoded as a list of indices each between `0`
-// and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to
-// P, then the L, U and P satisfies P_mat * input = L * U.
-//
-// Arguments:
-//	input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of
-// size `[M, M]`.
-//
-// Returns:
-//	lu: A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the
-// lower triangular factor `L` with unit diagonal, and whose upper triangular part
-// denotes the upper triangular factor `U`.
-//	p: Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is
-// `[..., M]`.
-// @compatibility(scipy)
-// Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are
-// packed into a single tensor, the permutation is applied to `input` instead of
-// the right hand side and the permutation `P` is returned as a list of indices
-// instead of a permutation matrix.
-// @end_compatibility
-func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "Lu",
-		Input: []tf.Input{
-			input,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
-}
-
 // Outputs deterministic pseudorandom random numbers from a Poisson distribution.
 //
 // Outputs random values from a Poisson distribution.
@@ -45484,6 +44803,210 @@
 	return scope.AddOperation(opspec)
 }
 
+// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2.
+type TensorArrayConcatV2Attr func(optionalAttr)
+
+// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value.
+// If not specified, defaults to <unknown_rank:true >
+func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr {
+	return func(m optionalAttr) {
+		m["element_shape_except0"] = value
+	}
+}
+
+// Deprecated. Use TensorArrayConcatV3
+func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"dtype": dtype}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "TensorArrayConcatV2",
+		Input: []tf.Input{
+			handle, flow_in,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1)
+}
+
+// Writes contents to the file at input filename.
+//
+// Creates the file and recursively creates the directory if they do not exist.
+//
+// Arguments:
+//	filename: scalar. The name of the file to which we write the contents.
+//	contents: scalar. The content to be written to the output file.
+//
+// Returns the created operation.
+func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "WriteFile",
+		Input: []tf.Input{
+			filename, contents,
+		},
+	}
+	return scope.AddOperation(opspec)
+}
+
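As an illustrative aside (not part of the generated wrapper itself), the following minimal sketch shows one way to drive `WriteFile` from the tensorflow/go bindings; the output path is hypothetical, and the usual Scope/Finalize/Session flow of the `op` package is assumed.

```go
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	filename := op.Const(s, "/tmp/write_file_example.txt") // hypothetical output path
	contents := op.Const(s, "hello from WriteFile")
	write := op.WriteFile(s, filename, contents)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	// No fetches: the op is run purely for its side effect of writing the file.
	if _, err := sess.Run(nil, nil, []*tf.Operation{write}); err != nil {
		panic(err)
	}
}
```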
+// MatrixSolveAttr is an optional argument to MatrixSolve.
+type MatrixSolveAttr func(optionalAttr)
+
+// MatrixSolveAdjoint sets the optional adjoint attribute to value.
+//
+// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
+// adjoint.
+// If not specified, defaults to false
+func MatrixSolveAdjoint(value bool) MatrixSolveAttr {
+	return func(m optionalAttr) {
+		m["adjoint"] = value
+	}
+}
+
+// Solves systems of linear equations.
+//
+// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
+// a tensor of shape `[..., M, K]`.  If `adjoint` is `False` then each output matrix
+// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+// If `adjoint` is `True` then each output matrix satisfies
+// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
+//
+// Arguments:
+//	matrix: Shape is `[..., M, M]`.
+//	rhs: Shape is `[..., M, K]`.
+//
+// Returns Shape is `[..., M, K]`.
+func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixSolve",
+		Input: []tf.Input{
+			matrix, rhs,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
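A minimal usage sketch for `MatrixSolve`, solving a single 2x2 system; it is a fragment that assumes the same `s := op.NewScope()` and Finalize/NewSession/Run flow as the `WriteFile` sketch above.

```go
// A is 2x2 and b is 2x1; x solves A * x = b.
a := op.Const(s, [][]float32{{3, 1}, {1, 2}})
b := op.Const(s, [][]float32{{9}, {8}})
x := op.MatrixSolve(s, a, b)
// Fetching x after Finalize/NewSession/Run yields approximately [[2], [3]],
// since 3*2 + 1*3 = 9 and 1*2 + 2*3 = 8.
_ = x
```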
+// WriteImageSummaryAttr is an optional argument to WriteImageSummary.
+type WriteImageSummaryAttr func(optionalAttr)
+
+// WriteImageSummaryMaxImages sets the optional max_images attribute to value.
+// If not specified, defaults to 3
+//
+// REQUIRES: value >= 1
+func WriteImageSummaryMaxImages(value int64) WriteImageSummaryAttr {
+	return func(m optionalAttr) {
+		m["max_images"] = value
+	}
+}
+
+// Writes an image summary.
+//
+// Writes image `tensor` at `step` with `tag` using summary `writer`.
+// `tensor` is image with shape [height, width, channels].
+//
+// Returns the created operation.
+func WriteImageSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Output, tensor tf.Output, bad_color tf.Output, optional ...WriteImageSummaryAttr) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "WriteImageSummary",
+		Input: []tf.Input{
+			writer, step, tag, tensor, bad_color,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Computes the Cholesky decomposition of one or more square matrices.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices.
+//
+// The input has to be symmetric and positive definite. Only the lower-triangular
+// part of the input will be used for this operation. The upper-triangular part
+// will not be read.
+//
+// The output is a tensor of the same shape as the input
+// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+//
+// **Note**: The gradient computation on GPU is faster for large matrices but
+// not for large batch dimensions when the submatrices are small. In this
+// case it might be faster to use the CPU.
+//
+// Arguments:
+//	input: Shape is `[..., M, M]`.
+//
+// Returns Shape is `[..., M, M]`.
+func Cholesky(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Cholesky",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
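For a concrete sense of the output, a small sketch (fragment, same assumed Scope/Session setup as above) factors a 2x2 symmetric positive-definite matrix.

```go
// A = [[4, 2], [2, 3]] is symmetric positive definite.
a := op.Const(s, [][]float32{{4, 2}, {2, 3}})
l := op.Cholesky(s, a)
// Fetching l yields the lower-triangular factor, approximately
// [[2, 0], [1, 1.4142]], and L * L^T reproduces A.
_ = l
```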
+// Computes the sum along sparse segments of a tensor divided by the sqrt of N.
+//
+// N is the size of the segment being reduced.
+//
+// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
+// missing, the `output` tensor at that position will be zeroed.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
+// for an explanation of segments.
+//
+// Arguments:
+//
+//	indices: A 1-D tensor. Has same rank as `segment_ids`.
+//	segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//	num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SparseSegmentSqrtNWithNumSegments",
+		Input: []tf.Input{
+			data, indices, segment_ids, num_segments,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
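A short sketch (fragment, same assumed setup) showing the sqrt(N) scaling and the zero rows produced for segment ids that never appear.

```go
data := op.Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}})
indices := op.Const(s, []int32{0, 1})    // select rows 0 and 1 of data
segmentIDs := op.Const(s, []int32{0, 0}) // both selected rows belong to segment 0
numSegments := op.Const(s, int32(3))     // segments 1 and 2 have no ids
out := op.SparseSegmentSqrtNWithNumSegments(s, data, indices, segmentIDs, numSegments)
// Row 0 ≈ (data[0] + data[1]) / sqrt(2) ≈ [2.83, 4.24]; rows 1 and 2 are all zeros.
_ = out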
 // Computes softmax cross entropy cost and gradients to backpropagate.
 //
 // Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
@@ -45629,6 +45152,86 @@
 	return op.Output(0)
 }
 
+// Computes the matrix logarithm of one or more square matrices:
+//
+//
+// \\(log(exp(A)) = A\\)
+//
+// This op is only defined for complex matrices. If A is positive-definite and
+// real, then casting to a complex matrix, taking the logarithm and casting back
+// to a real matrix will give the correct result.
+//
+// This function computes the matrix logarithm using the Schur-Parlett algorithm.
+// Details of the algorithm can be found in Section 11.6.2 of:
+// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008.
+// ISBN 978-0-898716-46-7.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices. The output is a tensor of the same shape as the input
+// containing the matrix logarithm for all input submatrices `[..., :, :]`.
+//
+// Arguments:
+//	input: Shape is `[..., M, M]`.
+//
+// Returns Shape is `[..., M, M]`.
+//
+// @compatibility(scipy)
+// Equivalent to scipy.linalg.logm
+// @end_compatibility
+func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixLogarithm",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
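Because the op is only defined for complex inputs, a minimal sketch (fragment, same assumed setup, plus the standard library `math` package) takes the logarithm of a complex diagonal matrix, where the result is simply the elementwise log of the diagonal.

```go
// diag(e, e^2) as a complex128 matrix; MatrixLogarithm requires a complex dtype.
a := op.Const(s, [][]complex128{
	{complex(math.E, 0), 0},
	{0, complex(math.E*math.E, 0)},
})
logA := op.MatrixLogarithm(s, a)
// Fetching logA yields approximately [[1+0i, 0], [0, 2+0i]].
_ = logA
```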
+// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent.
+type ResourceApplyGradientDescentAttr func(optionalAttr)
+
+// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, the subtraction will be protected by a lock;
+// otherwise the behavior is undefined, but may exhibit less contention.
+// If not specified, defaults to false
+func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr {
+	return func(m optionalAttr) {
+		m["use_locking"] = value
+	}
+}
+
+// Update '*var' by subtracting 'alpha' * 'delta' from it.
+//
+// Arguments:
+//	var_: Should be from a Variable().
+//	alpha: Scaling factor. Must be a scalar.
+//	delta: The change.
+//
+// Returns the created operation.
+func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "ResourceApplyGradientDescent",
+		Input: []tf.Input{
+			var_, alpha, delta,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
+
 // Creates and returns an empty tensor list.
 //
 // All list elements must be tensors of dtype element_dtype and shape compatible
@@ -49796,6 +49399,45 @@
 	return scope.AddOperation(opspec)
 }
 
+// CollectiveBcastSendV2Attr is an optional argument to CollectiveBcastSendV2.
+type CollectiveBcastSendV2Attr func(optionalAttr)
+
+// CollectiveBcastSendV2CommunicationHint sets the optional communication_hint attribute to value.
+// If not specified, defaults to "auto"
+func CollectiveBcastSendV2CommunicationHint(value string) CollectiveBcastSendV2Attr {
+	return func(m optionalAttr) {
+		m["communication_hint"] = value
+	}
+}
+
+// CollectiveBcastSendV2TimeoutSeconds sets the optional timeout_seconds attribute to value.
+// If not specified, defaults to 0
+func CollectiveBcastSendV2TimeoutSeconds(value float32) CollectiveBcastSendV2Attr {
+	return func(m optionalAttr) {
+		m["timeout_seconds"] = value
+	}
+}
+
+// Broadcasts a tensor value to one or more other devices.
+func CollectiveBcastSendV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, optional ...CollectiveBcastSendV2Attr) (data tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "CollectiveBcastSendV2",
+		Input: []tf.Input{
+			input, group_size, group_key, instance_key,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple.
 type InfeedEnqueueTupleAttr func(optionalAttr)
 
@@ -50054,6 +49696,173 @@
 	return scope.AddOperation(opspec)
 }
 
+// DebugNumericSummaryAttr is an optional argument to DebugNumericSummary.
+type DebugNumericSummaryAttr func(optionalAttr)
+
+// DebugNumericSummaryDeviceName sets the optional device_name attribute to value.
+// If not specified, defaults to ""
+func DebugNumericSummaryDeviceName(value string) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["device_name"] = value
+	}
+}
+
+// DebugNumericSummaryTensorName sets the optional tensor_name attribute to value.
+//
+// value: Name of the input tensor.
+// If not specified, defaults to ""
+func DebugNumericSummaryTensorName(value string) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["tensor_name"] = value
+	}
+}
+
+// DebugNumericSummaryDebugUrls sets the optional debug_urls attribute to value.
+//
+// value: List of URLs to debug targets, e.g.,
+//   file:///foo/tfdbg_dump, grpc://localhost:11011.
+// If not specified, defaults to <>
+func DebugNumericSummaryDebugUrls(value []string) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["debug_urls"] = value
+	}
+}
+
+// DebugNumericSummaryLowerBound sets the optional lower_bound attribute to value.
+//
+// value: (float) The lower bound; values <= this bound are counted in the
+//   generalized -inf count. Default: -inf.
+// If not specified, defaults to -inf
+func DebugNumericSummaryLowerBound(value float32) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["lower_bound"] = value
+	}
+}
+
+// DebugNumericSummaryUpperBound sets the optional upper_bound attribute to value.
+//
+// value: (float) The upper bound; values >= this bound are counted in the
+//   generalized +inf count. Default: +inf.
+// If not specified, defaults to inf
+func DebugNumericSummaryUpperBound(value float32) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["upper_bound"] = value
+	}
+}
+
+// DebugNumericSummaryMuteIfHealthy sets the optional mute_if_healthy attribute to value.
+//
+// value: (bool) Do not send data to the debug URLs unless at least one
+//   of elements [2], [3] and [7] (i.e., the nan count and the generalized -inf and
+//   inf counts) is non-zero.
+// If not specified, defaults to false
+func DebugNumericSummaryMuteIfHealthy(value bool) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["mute_if_healthy"] = value
+	}
+}
+
+// DebugNumericSummaryGatedGrpc sets the optional gated_grpc attribute to value.
+//
+// value: Whether this op will be gated. If any of the debug_urls of this
+//   debug node is of the grpc:// scheme, when the value of this attribute is set
+//   to True, the data will not actually be sent via the grpc stream unless this
+//   debug op has been enabled at the debug_url. If all of the debug_urls of this
+//   debug node are of the grpc:// scheme and the debug op is enabled at none of
+//   them, the output will be an empty Tensor.
+// If not specified, defaults to false
+func DebugNumericSummaryGatedGrpc(value bool) DebugNumericSummaryAttr {
+	return func(m optionalAttr) {
+		m["gated_grpc"] = value
+	}
+}
+
+// Debug Numeric Summary Op.
+//
+// Provides a basic summary of numeric value types, range, and distribution.
+//
+// output: A double tensor of shape [14 + nDimensions], where nDimensions is the
+//   number of dimensions of the tensor's shape. The elements of output are:
+//   [0]: is initialized (1.0) or not (0.0).
+//   [1]: total number of elements
+//   [2]: NaN element count
+//   [3]: generalized -inf count: elements <= lower_bound. lower_bound is -inf by
+//     default.
+//   [4]: negative element count (excluding -inf), if lower_bound is the default
+//     -inf. Otherwise, this is the count of elements > lower_bound and < 0.
+//   [5]: zero element count
+//   [6]: positive element count (excluding +inf), if upper_bound is the default
+//     +inf. Otherwise, this is the count of elements < upper_bound and > 0.
+//   [7]: generalized +inf count, elements >= upper_bound. upper_bound is +inf by
+//     default.
+// Output elements [1:8] are all zero, if the tensor is uninitialized.
+//   [8]: minimum of all non-inf and non-NaN elements.
+//        If uninitialized or no such element exists: +inf.
+//   [9]: maximum of all non-inf and non-NaN elements.
+//        If uninitialized or no such element exists: -inf.
+//   [10]: mean of all non-inf and non-NaN elements.
+//         If uninitialized or no such element exists: NaN.
+//   [11]: variance of all non-inf and non-NaN elements.
+//         If uninitialized or no such element exists: NaN.
+//   [12]: Data type of the tensor encoded as an enum integer. See the DataType
+//         proto for more details.
+//   [13]: Number of dimensions of the tensor (ndims).
+//   [14+]: Sizes of the dimensions.
+//
+//
+// Arguments:
+//	input: Input tensor, non-Reference type.
+func DebugNumericSummary(scope *Scope, input tf.Output, optional ...DebugNumericSummaryAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "DebugNumericSummary",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
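The sketch below (fragment, same assumed setup, plus `math`) feeds a small tensor containing a NaN and a +inf through the op and reads a few of the documented slots of the summary vector.

```go
x := op.Const(s, []float32{float32(math.NaN()), -1, 0, 2, float32(math.Inf(1))})
summary := op.DebugNumericSummary(s, x) // DT_DOUBLE vector of length 14 + ndims

// After Finalize/NewSession/Run, the fetched value is a []float64 where
// vals[1] is the element count (5), vals[2] the NaN count (1),
// vals[7] the generalized +inf count (1), and vals[13] the rank (1).
_ = summary
```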
+// Outputs random integers from a uniform distribution.
+//
+// The generated values are uniform integers in the range `[minval, maxval)`.
+// The lower bound `minval` is included in the range, while the upper bound
+// `maxval` is excluded.
+//
+// The random integers are slightly biased unless `maxval - minval` is an exact
+// power of two.  The bias is small for values of `maxval - minval` significantly
+// smaller than the range of the output (either `2^32` or `2^64`).
+//
+// Arguments:
+//	resource: The handle of the resource variable that stores the state of the RNG.
+//	algorithm: The RNG algorithm.
+//	shape: The shape of the output tensor.
+//	minval: Minimum value (inclusive, scalar).
+//	maxval: Maximum value (exclusive, scalar).
+//
+// Returns Random values with specified shape.
+func StatefulUniformInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "StatefulUniformInt",
+		Input: []tf.Input{
+			resource, algorithm, shape, minval, maxval,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // SerializeManySparseAttr is an optional argument to SerializeManySparse.
 type SerializeManySparseAttr func(optionalAttr)
 
@@ -50661,6 +50470,58 @@
 	return op.Output(0)
 }
 
+// Sends the named tensor to another XLA computation. Wraps the XLA Send operator
+//
+// documented at
+//  https://www.tensorflow.org/performance/xla/operation_semantics#send .
+//
+// Arguments:
+//	tensor: The tensor to send.
+//	tensor_name: A string key that identifies the channel.
+//
+// Returns the created operation.
+func XlaSend(scope *Scope, tensor tf.Output, tensor_name string) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"tensor_name": tensor_name}
+	opspec := tf.OpSpec{
+		Type: "XlaSend",
+		Input: []tf.Input{
+			tensor,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Returns the index of a data point that should be added to the seed set.
+//
+// Entries in distances are assumed to be squared distances of candidate points to
+// the already sampled centers in the seed set. The op constructs one Markov chain
+// of the k-MC^2 algorithm and returns the index of one candidate point to be added
+// as an additional cluster center.
+//
+// Arguments:
+//	distances: Vector with squared distances to the closest previously sampled cluster center
+// for each candidate point.
+//	seed: Scalar. Seed for initializing the random number generator.
+//
+// Returns Scalar with the index of the sampled point.
+func KMC2ChainInitialization(scope *Scope, distances tf.Output, seed tf.Output) (index tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "KMC2ChainInitialization",
+		Input: []tf.Input{
+			distances, seed,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Creates a tree ensemble model and returns a handle to it.
 //
 // Arguments:
@@ -50711,6 +50572,126 @@
 	return op.Output(0)
 }
 
+// QrAttr is an optional argument to Qr.
+type QrAttr func(optionalAttr)
+
+// QrFullMatrices sets the optional full_matrices attribute to value.
+//
+// value: If true, compute full-sized `q` and `r`. If false
+// (the default), compute only the leading `P` columns of `q`.
+// If not specified, defaults to false
+func QrFullMatrices(value bool) QrAttr {
+	return func(m optionalAttr) {
+		m["full_matrices"] = value
+	}
+}
+
+// Computes the QR decompositions of one or more matrices.
+//
+// Computes the QR decomposition of each inner matrix in `tensor` such that
+// `tensor[..., :, :] = q[..., :, :] * r[..., :, :]`
+//
+// Currently, the gradient for the QR decomposition is well-defined only when
+// the first `P` columns of the inner matrix are linearly independent, where
+// `P` is the minimum of `M` and `N`, the 2 inner-most dimensions of `tensor`.
+//
+// ```python
+// # a is a tensor.
+// # q is a tensor of orthonormal matrices.
+// # r is a tensor of upper triangular matrices.
+// q, r = qr(a)
+// q_full, r_full = qr(a, full_matrices=True)
+// ```
+//
+// Arguments:
+//	input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+//
+// Returns:
+//	q: Orthonormal basis for range of `a`. If `full_matrices` is `False` then
+// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
+// `[..., M, M]`.
+//	r: Triangular factor. If `full_matrices` is `False` then shape is
+// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
+func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "Qr",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1)
+}
+
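A brief sketch (fragment, same assumed setup) contrasting the reduced and full factorizations on a 3x2 input.

```go
a := op.Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}}) // shape [3, 2], so P = 2
q, r := op.Qr(s, a)                                   // reduced: q is [3, 2], r is [2, 2]
qf, rf := op.Qr(s, a, op.QrFullMatrices(true))        // full:    qf is [3, 3], rf is [3, 2]
_, _, _, _ = q, r, qf, rf
```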
+// Retrieve multiple values from the computation outfeed. Device ordinal is a
+// tensor allowing dynamic outfeed.
+//
+// This operation will block indefinitely until data is available. Output `i`
+// corresponds to XLA tuple element `i`.
+//
+// Arguments:
+//	device_ordinal: An int scalar tensor, representing the TPU device to use. This should be -1 when
+// the Op is running on a TPU device, and >= 0 when the Op is running on the CPU
+// device.
+//	dtypes: The element types of each element in `outputs`.
+//	shapes: The shapes of each tensor in `outputs`.
+//
+// Returns A list of tensors that will be read from the outfeed.
+func OutfeedDequeueTupleV2(scope *Scope, device_ordinal tf.Output, dtypes []tf.DataType, shapes []tf.Shape) (outputs []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes}
+	opspec := tf.OpSpec{
+		Type: "OutfeedDequeueTupleV2",
+		Input: []tf.Input{
+			device_ordinal,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil {
+		scope.UpdateErr("OutfeedDequeueTupleV2", err)
+		return
+	}
+	return outputs
+}
+
+// Makes a copy of `x`.
+//
+// Arguments:
+//	x: The source tensor of type `T`.
+//
+// Returns     y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
+//       is not an alias of `x`.
+func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "DeepCopy",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Retrieves a single tensor from the computation outfeed. Device ordinal is a
 // tensor allowing dynamic outfeed.
 //
@@ -51495,6 +51476,104 @@
 	return op.Output(0)
 }
 
+// TridiagonalSolveAttr is an optional argument to TridiagonalSolve.
+type TridiagonalSolveAttr func(optionalAttr)
+
+// TridiagonalSolvePartialPivoting sets the optional partial_pivoting attribute to value.
+//
+// value: Whether to apply partial pivoting. Partial pivoting makes the procedure more
+// stable, but slower.
+// If not specified, defaults to true
+func TridiagonalSolvePartialPivoting(value bool) TridiagonalSolveAttr {
+	return func(m optionalAttr) {
+		m["partial_pivoting"] = value
+	}
+}
+
+// Solves tridiagonal systems of equations.
+//
+//   Supports batch dimensions and multiple right-hand sides per left-hand
+//   side.
+//   On CPU, solution is computed via Gaussian elimination with or without partial
+//   pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
+//   library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
+//   Partial pivoting is not yet supported by XLA backends.
+//
+// Arguments:
+//	diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions represent the
+// tridiagonal matrices with three rows being the superdiagonal, diagonals, and
+// subdiagonals, in order. The last element of the superdiagonal and the first
+// element of the subdiagonal are ignored.
+//	rhs: Tensor of shape `[..., M, K]`, representing K right-hand sides per each
+// left-hand side.
+//
+// Returns Tensor of shape `[..., M, K]` containing the solutions
+func TridiagonalSolve(scope *Scope, diagonals tf.Output, rhs tf.Output, optional ...TridiagonalSolveAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "TridiagonalSolve",
+		Input: []tf.Input{
+			diagonals, rhs,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
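
A sketch of building a single 3x3 solve (op.NewScope and op.Const assumed); the three rows of `diagonals` are the superdiagonal, main diagonal, and subdiagonal, with the ignored padding slots set to zero:

package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildTridiagonalSolve constructs a 3x3 solve with one right-hand side and
// returns the finalized graph plus the solution output of shape [3, 1].
func buildTridiagonalSolve() (*tf.Graph, tf.Output, error) {
	s := op.NewScope()
	diagonals := op.Const(s, [][]float32{
		{1, 1, 0}, // superdiagonal; the last element is ignored
		{4, 4, 4}, // main diagonal
		{0, 1, 1}, // subdiagonal; the first element is ignored
	})
	rhs := op.Const(s, [][]float32{{1}, {2}, {3}}) // shape [M=3, K=1]
	x := op.TridiagonalSolve(s, diagonals, rhs,
		op.TridiagonalSolvePartialPivoting(true))
	graph, err := s.Finalize()
	return graph, x, err
}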
+
+// An Op to exchange data across TPU replicas.
+//
+// On each replica, the input is split into `split_count` blocks along
+// `split_dimension` and sent to the other replicas given `group_assignment`. After
+// receiving `split_count` - 1 blocks from other replicas, we concatenate the
+// blocks along `concat_dimension` as the output.
+//
+// For example, suppose there are 2 TPU replicas:
+// replica 0 receives input: `[[A, B]]`
+// replica 1 receives input: `[[C, D]]`
+//
+// group_assignment=`[[0, 1]]`
+// concat_dimension=0
+// split_dimension=1
+// split_count=2
+//
+// replica 0's output: `[[A], [C]]`
+// replica 1's output: `[[B], [D]]`
+//
+// Arguments:
+//	input: The local input to the sum.
+//	group_assignment: An int32 tensor with shape
+// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+// replica ids in the ith subgroup.
+//	concat_dimension: The dimension number to concatenate.
+//	split_dimension: The dimension number to split.
+//	split_count: The number of splits; this number must equal the sub-group
+// size (group_assignment.get_shape()[1]).
+//
+// Returns The exchanged result.
+func AllToAll(scope *Scope, input tf.Output, group_assignment tf.Output, concat_dimension int64, split_dimension int64, split_count int64) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"concat_dimension": concat_dimension, "split_dimension": split_dimension, "split_count": split_count}
+	opspec := tf.OpSpec{
+		Type: "AllToAll",
+		Input: []tf.Input{
+			input, group_assignment,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // LoadTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to LoadTPUEmbeddingMDLAdagradLightParameters.
 type LoadTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr)
 
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 6d884f3..df5b34c 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -215,25 +215,29 @@
 func (t *Tensor) Shape() []int64 { return t.shape }
 
 // Reshape  updates tensor's shape in place if this is possible or returns an error otherwise.
-func (t *Tensor) Reshape(new_shape []int64) error {
-	old_shape_size := numElements(t.shape)
-	new_shape_size := numElements(new_shape)
+func (t *Tensor) Reshape(newShape []int64) error {
+	oldShapeSize := numElements(t.shape)
+	newShapeSize := numElements(newShape)
 
-	if old_shape_size != new_shape_size {
-		return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size)
+	if oldShapeSize != newShapeSize {
+		return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize)
 	}
 
-	if len(new_shape) == 0 {
+	if len(newShape) == 0 {
 		return nil
 	}
 
 	var shapePtr *C.int64_t
-	shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0]))
+	shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0]))
 
 	status := newStatus()
-	C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c)
+	C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c)
 
-	return status.Err()
+	if err := status.Err(); err != nil {
+		return err
+	}
+	t.shape = newShape
+	return nil
 }
 
 // Value converts the Tensor to a Go value. For now, not all Tensor types are
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 15b2ea5..8aa7106 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -358,3 +358,31 @@
 	})
 
 }
+
+func TestReshape(t *testing.T) {
+	tensor, err := NewTensor([]int64{1, 2})
+	if err != nil {
+		t.Fatalf("Unable to create new tensor: %v", err)
+	}
+
+	if got, want := len(tensor.Shape()), 1; got != want {
+		t.Fatalf("len(tensor.Shape()): got %d, want %d", got, want)
+	}
+	if got, want := tensor.Shape()[0], int64(2); got != want {
+		t.Errorf("tensor.Shape()[0]: got %d, want %d", got, want)
+	}
+
+	if err := tensor.Reshape([]int64{1, 2}); err != nil {
+		t.Fatalf("tensor.Reshape([1, 2]) failed: %v", err)
+	}
+
+	if got, want := len(tensor.Shape()), 2; got != want {
+		t.Fatalf("After reshape, len(tensor.Shape()): got %d, want %d", got, want)
+	}
+	if got, want := tensor.Shape()[0], int64(1); got != want {
+		t.Errorf("After reshape, tensor.Shape()[0]: got %d, want %d", got, want)
+	}
+	if got, want := tensor.Shape()[1], int64(2); got != want {
+		t.Errorf("After reshape, tensor.Shape()[1]: got %d, want %d", got, want)
+	}
+}
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index bb73617..35b72ee 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -14,7 +14,10 @@
     "testdata/*.tflite",
     "testdata/*.csv",
     "models/testdata/*",
-]))
+]) + [
+    "create_op_resolver.h",
+    "create_op_resolver_with_selected_ops.cc",
+])
 
 config_setting(
     name = "gemmlowp_profiling",
@@ -260,7 +263,7 @@
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/core/api:verifier",
-        "//tensorflow/lite/delegates:status",
+        "//tensorflow/lite/delegates:telemetry",
         "//tensorflow/lite/experimental/resource",
         "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/profiling:platform_profiler",
@@ -347,7 +350,7 @@
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/core/api:verifier",
-        "//tensorflow/lite/delegates:status",
+        "//tensorflow/lite/delegates:telemetry",
         "//tensorflow/lite/experimental/resource",
         "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/profiling:platform_profiler",
@@ -707,6 +710,19 @@
     ],
 )
 
+# Defines CreateOpResolver with all builtin ops.
+cc_library(
+    name = "create_op_resolver_with_builtin_ops",
+    srcs = ["create_op_resolver_with_builtin_ops.cc"],
+    hdrs = ["create_op_resolver.h"],
+    copts = tflite_copts(),
+    deps = [
+        "//tensorflow/lite:op_resolver",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/shims:builtin_ops",
+    ],
+)
+
 cc_test(
     name = "util_test",
     size = "small",
diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt
index aeb271d..24d405a 100644
--- a/tensorflow/lite/CMakeLists.txt
+++ b/tensorflow/lite/CMakeLists.txt
@@ -176,6 +176,9 @@
 # XNNPACK delegate is preferred to the weak-symbol one.
 list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*tflite_with_xnnpack\\.cc$")
 
+# Exclude Flex related files.
+list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*with_selected_ops\\.cc$")
+
 if(_TFLITE_ENABLE_MMAP)
   list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation_disabled\\.cc$")
 else()
@@ -217,23 +220,6 @@
     FILTER "(_test)\\.(cc|h)$"
   )
   populate_tflite_source_vars(
-    "delegates/gpu/cl/kernels/special"
-    TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS
-    FILTER "(_test)\\.(cc|h)$"
-  )
-  populate_tflite_source_vars(
-    "delegates/gpu/cl/selectors" TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS
-    FILTER "(_test)\\.(cc|h)$"
-  )
-  populate_tflite_source_vars(
-    "delegates/gpu/cl/selectors/default" TFLITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SRCS
-    FILTER "(_test)\\.(cc|h)$"
-  )
-  populate_tflite_source_vars(
-    "delegates/gpu/common" TFLITE_DELEGATES_GPU_COMMON_SRCS
-    FILTER "(_test)\\.(cc|h)$"
-  )
-  populate_tflite_source_vars(
     "delegates/gpu/common/default" TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS
     FILTER "(_test)\\.(cc|h)$"
   )
@@ -243,6 +229,18 @@
     FILTER "(_test)\\.(cc|h)$"
   )
   populate_tflite_source_vars(
+    "delegates/gpu/common/selectors" TFLITE_DELEGATES_GPU_COMMON_SELECTORS_SRCS
+    FILTER "(_test)\\.(cc|h)$"
+  )
+  populate_tflite_source_vars(
+    "delegates/gpu/common/selectors/default" TFLITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SRCS
+    FILTER "(_test)\\.(cc|h)$"
+  )
+  populate_tflite_source_vars(
+    "delegates/gpu/common" TFLITE_DELEGATES_GPU_COMMON_SRCS
+    FILTER "(_test)\\.(cc|h)$"
+  )
+  populate_tflite_source_vars(
     "delegates/gpu/common/task"
     TFLITE_DELEGATES_GPU_COMMON_TASK_SRCS
     FILTER "(_test)\\.(cc|h)$"
@@ -267,12 +265,11 @@
     ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc
     ${TFLITE_DELEGATES_GPU_CL_SRCS}
     ${TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS}
-    ${TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS}
-    ${TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS}
-    ${TFLITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SRCS}
-    ${TFLITE_DELEGATES_GPU_COMMON_SRCS}
     ${TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS}
     ${TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS}
+    ${TFLITE_DELEGATES_GPU_COMMON_SELECTORS_SRCS}
+    ${TFLITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SRCS}
+    ${TFLITE_DELEGATES_GPU_COMMON_SRCS}
     ${TFLITE_DELEGATES_GPU_COMMON_TASK_SRCS}
     ${TFLITE_DELEGATES_GPU_COMMON_TASKS_SRCS}
     ${TFLITE_DELEGATES_GPU_COMMON_TASKS_SPECIAL_SRCS}
@@ -331,7 +328,7 @@
 endif()
 populate_tflite_source_vars("kernels"
   TFLITE_KERNEL_SRCS
-  FILTER ".*(_test_util_internal|test_main)\\.(cc|h)"
+  FILTER "(.*_test_util_internal|test_.*)\\.(cc|h)"
 )
 populate_tflite_source_vars("kernels/internal" TFLITE_KERNEL_INTERNAL_SRCS)
 populate_tflite_source_vars("kernels/internal/optimized"
diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc
index 8b8913d..1edfcaf 100644
--- a/tensorflow/lite/arena_planner.cc
+++ b/tensorflow/lite/arena_planner.cc
@@ -205,7 +205,9 @@
     for (int j = 0; j < node_temporaries->size; ++j) {
       int tensor_index = node_temporaries->data[j];
       alloc_node_[tensor_index] = i;
-      dealloc_node_[tensor_index] = i;
+      if (!preserve_intermediates_) {
+        dealloc_node_[tensor_index] = i;
+      }
     }
   }
 
diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc
index ca2127d..f9d735e 100644
--- a/tensorflow/lite/arena_planner_test.cc
+++ b/tensorflow/lite/arena_planner_test.cc
@@ -164,12 +164,13 @@
 
 class ArenaPlannerTest : public ::testing::Test {
  protected:
-  void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
+  void SetGraph(TestGraph* graph, bool preserve_inputs = false,
+                bool preserve_intermediates = false) {
     graph_ = graph;
     context_.ReportError = ReportError;
     planner_.reset(new ArenaPlanner(
         &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
-        preserve_inputs, /*preserve intermediates*/ false, kTensorAlignment));
+        preserve_inputs, preserve_intermediates, kTensorAlignment));
     CHECK(planner_->ResetAllocations() == kTfLiteOk);
     CHECK(planner_->PlanAllocations() == kTfLiteOk);
   }
@@ -745,6 +746,35 @@
   EXPECT_EQ(GetOffset(2), GetOffsetAfter(5));
 }
 
+TEST_F(ArenaPlannerTest, DebugTensors) {
+  TestGraph graph({0, 1},
+                  {
+                      /* in, out, tmp */
+                      {{0, 1}, {2}, {5}},  // First op, with temporary
+                      {{2, 0}, {4}, {6}},  // Second op, with temporary
+                      {{4}, {3}, {7}}      // Third op, with temporary
+                  },
+                  {3});
+  SetGraph(&graph, false, /*preserve_intermediates=*/false);
+  Execute(0, 10);
+
+  // Memory of temporary tensors is shared by default.
+  EXPECT_EQ(GetOffset(5), 0);
+  EXPECT_EQ(GetOffset(6), 0);
+  EXPECT_EQ(GetOffset(7), 0);
+
+  SetGraph(&graph, false, /*preserve_intermediates=*/true);
+  Execute(0, 10);
+
+  std::set<std::ptrdiff_t> tensorOffsets;
+  for (int i = 0; i < 8; i++) {
+    tensorOffsets.insert(GetOffset(i));
+  }
+  // Every tensor should have a unique memory allocation when
+  // preserve_intermediates is set.
+  EXPECT_EQ(tensorOffsets.size(), 8);
+}
+
 }  // namespace
 }  // namespace tflite
 
diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl
index 9f729d9..afc7615 100644
--- a/tensorflow/lite/build_def.bzl
+++ b/tensorflow/lite/build_def.bzl
@@ -635,6 +635,8 @@
         flags += " --ignore_converter_errors --run_with_flex"
     elif conversion_mode == "forward-compat":
         flags += " --make_forward_compat_test"
+    elif conversion_mode == "mlir-quant":
+        flags += " --mlir_quantizer"
     if test_name.startswith(merged_test_model_name() + "_"):
         flags += flags_for_merged_test_models(test_name, conversion_mode)
 
@@ -800,17 +802,17 @@
             model = models,
         )
         real_srcs.append(":%s_registration" % name)
-        real_deps.append("//tensorflow/lite/java/src/main/native:selected_ops_jni")
+        real_srcs.append("//tensorflow/lite:create_op_resolver_with_selected_ops.cc")
     else:
         # Support all operators if `models` not specified.
-        real_deps.append("//tensorflow/lite/java/src/main/native")
+        real_deps.append("//tensorflow/lite:create_op_resolver_with_builtin_ops")
 
     native.cc_library(
         name = name,
         srcs = real_srcs,
         hdrs = [
             # TODO(b/161323860) replace this by generated header.
-            "//tensorflow/lite/java/src/main/native:op_resolver.h",
+            "//tensorflow/lite:create_op_resolver.h",
         ],
         copts = tflite_copts(),
         linkopts = select({
@@ -854,9 +856,10 @@
     tflite_jni_binary(
         name = "libtensorflowlite_jni.so",
         linkscript = "//tensorflow/lite/java:tflite_version_script.lds",
+        # Do not sort: "native_framework_only" must come before custom tflite library.
         deps = [
-            ":%s_cc" % name,
             "//tensorflow/lite/java/src/main/native:native_framework_only",
+            ":%s_cc" % name,
         ],
     )
 
@@ -882,3 +885,61 @@
         name = "%s_aar" % name,
         android_library = name,
     )
+
+def tflite_custom_c_library(
+        name,
+        models = [],
+        **kwargs):
+    """Generates a tflite cc library, stripping off unused operators.
+
+    This library includes the C API and the op kernels used in the given models.
+
+    Args:
+        name: Str, name of the target.
+        models: List of models. This TFLite build will only include
+            operators used in these models. If the list is empty, all builtin
+            operators are included.
+        **kwargs: Custom c_api cc_library kwargs.
+    """
+    op_resolver_deps = "//tensorflow/lite:create_op_resolver_with_builtin_ops"
+    if models:
+        gen_selected_ops(
+            name = "%s_registration" % name,
+            model = models,
+        )
+
+        native.cc_library(
+            name = "%s_create_op_resolver" % name,
+            srcs = [
+                ":%s_registration" % name,
+                "//tensorflow/lite:create_op_resolver_with_selected_ops.cc",
+            ],
+            hdrs = ["//tensorflow/lite:create_op_resolver.h"],
+            copts = tflite_copts(),
+            deps = [
+                "//tensorflow/lite:op_resolver",
+                "//tensorflow/lite:framework",
+                "//tensorflow/lite/kernels:builtin_ops",
+            ],
+        )
+        op_resolver_deps = "%s_create_op_resolver" % name
+
+    native.cc_library(
+        name = name,
+        srcs = ["//tensorflow/lite/c:c_api_srcs"],
+        hdrs = ["//tensorflow/lite/c:c_api.h"],
+        copts = tflite_copts(),
+        deps = [
+            op_resolver_deps,
+            "//tensorflow/lite/c:common",
+            "//tensorflow/lite/c:c_api_types",
+            "//tensorflow/lite:builtin_ops",
+            "//tensorflow/lite:framework",
+            "//tensorflow/lite:version",
+            "//tensorflow/lite/core/api",
+            "//tensorflow/lite/delegates:interpreter_utils",
+            "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
+            "//tensorflow/lite/kernels/internal:compatibility",
+        ],
+        **kwargs
+    )
diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD
index c366ec0..3119cf3 100644
--- a/tensorflow/lite/c/BUILD
+++ b/tensorflow/lite/c/BUILD
@@ -2,6 +2,7 @@
     "//tensorflow/lite:build_def.bzl",
     "tflite_cc_shared_object",
     "tflite_copts",
+    "tflite_custom_c_library",
 )
 load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
@@ -60,13 +61,13 @@
     copts = tflite_copts(),
     deps = [
         ":c_api_internal",
-        ":common",
+        ":c_api_types",
         "//tensorflow/lite:builtin_ops",
+        "//tensorflow/lite:create_op_resolver_with_builtin_ops",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:version",
         "//tensorflow/lite/delegates:interpreter_utils",
         "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
-        "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels/internal:compatibility",
     ],
     alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
@@ -87,6 +88,13 @@
     alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
 )
 
+cc_library(
+    name = "c_api_types",
+    hdrs = ["c_api_types.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts(),
+)
+
 cc_test(
     name = "c_api_test",
     size = "small",
@@ -104,6 +112,33 @@
     ],
 )
 
+tflite_custom_c_library(
+    name = "selectively_built_c_api_test_lib",
+    testonly = 1,
+    models = [
+        "//tensorflow/lite:testdata/add.bin",
+        "//tensorflow/lite:testdata/add_quantized.bin",
+    ],
+    visibility = ["//visibility:private"],
+)
+
+cc_test(
+    name = "selectively_built_c_api_test",
+    size = "small",
+    srcs = ["c_api_test.cc"],
+    copts = tflite_copts(),
+    data = [
+        "//tensorflow/lite:testdata/add.bin",
+        "//tensorflow/lite:testdata/add_quantized.bin",
+    ],
+    deps = [
+        ":common",
+        ":selectively_built_c_api_test_lib",
+        "//tensorflow/lite/testing:util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "c_api_experimental_test",
     size = "small",
@@ -130,6 +165,9 @@
     ],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
+    deps = [
+        ":c_api_types",
+    ],
     alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
 )
 
@@ -137,9 +175,19 @@
 exports_files([
     "c_api.h",
     "c_api_experimental.h",
+    "c_api_types.h",
     "common.h",
 ])
 
+# For use in selective build rule for C API.
+filegroup(
+    name = "c_api_srcs",
+    srcs = [
+        "c_api.cc",
+        "c_api_internal.h",
+    ],
+)
+
 # Test the C extension API code.
 cc_test(
     name = "common_test",
diff --git a/tensorflow/lite/c/c_api.cc b/tensorflow/lite/c/c_api.cc
index ab247a1..e59cb0f 100644
--- a/tensorflow/lite/c/c_api.cc
+++ b/tensorflow/lite/c/c_api.cc
@@ -18,19 +18,15 @@
 
 #include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/create_op_resolver.h"
 #include "tensorflow/lite/delegates/interpreter_utils.h"
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 #include "tensorflow/lite/error_reporter.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
-#include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/version.h"
 
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
 namespace {
 class CallbackErrorReporter : public tflite::ErrorReporter {
  public:
@@ -85,6 +81,8 @@
 
 }  // namespace
 
+extern "C" {
+
 // LINT.IfChange
 
 const char* TfLiteVersion() { return TFLITE_VERSION_STRING; }
@@ -133,9 +131,10 @@
 TfLiteInterpreter* TfLiteInterpreterCreate(
     const TfLiteModel* model,
     const TfLiteInterpreterOptions* optional_options) {
-  tflite::ops::builtin::BuiltinOpResolver resolver;
+  std::unique_ptr<tflite::MutableOpResolver> resolver =
+      tflite::CreateOpResolver();
   return tflite::internal::InterpreterCreateWithOpResolver(
-      model, optional_options, &resolver);
+      model, optional_options, resolver.get());
 }
 
 void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) {
@@ -198,9 +197,7 @@
   return tensor->bytes;
 }
 
-void* TfLiteTensorData(const TfLiteTensor* tensor) {
-  return static_cast<void*>(tensor->data.raw);
-}
+void* TfLiteTensorData(const TfLiteTensor* tensor) { return tensor->data.raw; }
 
 const char* TfLiteTensorName(const TfLiteTensor* tensor) {
   return tensor->name;
@@ -233,9 +230,7 @@
 
 // LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs)
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
 
 namespace tflite {
 namespace internal {
diff --git a/tensorflow/lite/c/c_api.h b/tensorflow/lite/c/c_api.h
index dd4178d..b5a137b 100644
--- a/tensorflow/lite/c/c_api.h
+++ b/tensorflow/lite/c/c_api.h
@@ -17,8 +17,9 @@
 
 #include <stdarg.h>
 #include <stdint.h>
+#include <stdlib.h>
 
-#include "common.h"
+#include "tensorflow/lite/c/c_api_types.h"  // IWYU pragma: export
 
 // --------------------------------------------------------------------------
 /// C API for TensorFlow Lite.
@@ -71,14 +72,29 @@
 #endif  // __cplusplus
 
 // --------------------------------------------------------------------------
+// Opaque types used by the C API.
+
+// TfLiteModel wraps a loaded TensorFlow Lite model.
+typedef struct TfLiteModel TfLiteModel;
+
+// TfLiteInterpreterOptions allows customized interpreter configuration.
+typedef struct TfLiteInterpreterOptions TfLiteInterpreterOptions;
+
+// Allows delegation of nodes to alternative backends.
+typedef struct TfLiteDelegate TfLiteDelegate;
+
+// TfLiteInterpreter provides inference from a provided model.
+typedef struct TfLiteInterpreter TfLiteInterpreter;
+
+// A tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct TfLiteTensor TfLiteTensor;
+
+// --------------------------------------------------------------------------
 // TfLiteVersion returns a string describing version information of the
 // TensorFlow Lite library. TensorFlow Lite uses semantic versioning.
 TFL_CAPI_EXPORT extern const char* TfLiteVersion(void);
 
-// --------------------------------------------------------------------------
-// TfLiteModel wraps a loaded TensorFlow Lite model.
-typedef struct TfLiteModel TfLiteModel;
-
 // Returns a model from the provided buffer, or null on failure.
 TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreate(const void* model_data,
                                                       size_t model_size);
@@ -90,10 +106,6 @@
 // Destroys the model instance.
 TFL_CAPI_EXPORT extern void TfLiteModelDelete(TfLiteModel* model);
 
-// --------------------------------------------------------------------------
-// TfLiteInterpreterOptions allows customized interpreter configuration.
-typedef struct TfLiteInterpreterOptions TfLiteInterpreterOptions;
-
 // Returns a new interpreter options instances.
 TFL_CAPI_EXPORT extern TfLiteInterpreterOptions*
 TfLiteInterpreterOptionsCreate();
@@ -127,10 +139,6 @@
     void (*reporter)(void* user_data, const char* format, va_list args),
     void* user_data);
 
-// --------------------------------------------------------------------------
-// TfLiteInterpreter provides inference from a provided model.
-typedef struct TfLiteInterpreter TfLiteInterpreter;
-
 // Returns a new interpreter using the provided model and options, or null on
 // failure.
 //
diff --git a/tensorflow/lite/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc
index c8f165e..938ff8b 100644
--- a/tensorflow/lite/c/c_api_experimental.cc
+++ b/tensorflow/lite/c/c_api_experimental.cc
@@ -24,9 +24,7 @@
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/interpreter.h"
 
-#ifdef __cplusplus
 extern "C" {
-#endif  // __cplusplus
 
 TfLiteStatus TfLiteInterpreterResetVariableTensors(
     TfLiteInterpreter* interpreter) {
@@ -82,6 +80,4 @@
   options->enable_delegate_fallback = enable;
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/c/c_api_types.h b/tensorflow/lite/c/c_api_types.h
new file mode 100644
index 0000000..e066e54
--- /dev/null
+++ b/tensorflow/lite/c/c_api_types.h
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file declares types used by the pure C inference API defined in c_api.h,
+// some of which are also used in the C++ and C kernel and interpreter APIs.
+
+#ifndef TENSORFLOW_LITE_C_C_API_TYPES_H_
+#define TENSORFLOW_LITE_C_C_API_TYPES_H_
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
+// library.
+#ifdef SWIG
+#define TFL_CAPI_EXPORT
+#else
+#if defined(_WIN32)
+#ifdef TFL_COMPILE_LIBRARY
+#define TFL_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TFL_CAPI_EXPORT __declspec(dllimport)
+#endif  // TFL_COMPILE_LIBRARY
+#else
+#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // _WIN32
+#endif  // SWIG
+
+typedef enum TfLiteStatus {
+  kTfLiteOk = 0,
+
+  // Generally referring to an error in the runtime (i.e. interpreter)
+  kTfLiteError = 1,
+
+  // Generally referring to an error from a TfLiteDelegate itself.
+  kTfLiteDelegateError = 2,
+
+  // Generally referring to an error in applying a delegate due to
+  // incompatibility between runtime and delegate, e.g., this error is returned
+  // when trying to apply a TfLite delegate onto a model graph that's already
+  // immutable.
+  kTfLiteApplicationError = 3
+} TfLiteStatus;
+
+// Types supported by tensor
+typedef enum {
+  kTfLiteNoType = 0,
+  kTfLiteFloat32 = 1,
+  kTfLiteInt32 = 2,
+  kTfLiteUInt8 = 3,
+  kTfLiteInt64 = 4,
+  kTfLiteString = 5,
+  kTfLiteBool = 6,
+  kTfLiteInt16 = 7,
+  kTfLiteComplex64 = 8,
+  kTfLiteInt8 = 9,
+  kTfLiteFloat16 = 10,
+  kTfLiteFloat64 = 11,
+  kTfLiteComplex128 = 12,
+  kTfLiteUInt64 = 13,
+} TfLiteType;
+
+// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
+// If per-layer quantization is specified this field will still be populated in
+// addition to TfLiteAffineQuantization.
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+//     real_value = scale * (quantized_value - zero_point)
+typedef struct TfLiteQuantizationParams {
+  float scale;
+  int32_t zero_point;
+} TfLiteQuantizationParams;
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+#endif  // TENSORFLOW_LITE_C_C_API_TYPES_H_
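
For reference, the asymmetric quantization formula above written out as a small Go helper (an illustrative sketch, not part of the TFLite API):

package example

// dequantize applies real_value = scale * (quantized_value - zero_point)
// from TfLiteQuantizationParams.
func dequantize(q int8, scale float32, zeroPoint int32) float32 {
	return scale * float32(int32(q)-zeroPoint)
}

// For example, with scale = 0.5 and zero_point = -128, the int8 value 0
// dequantizes to 0.5 * (0 - (-128)) = 64.0.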
diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c
index e43984e..0b54fda 100644
--- a/tensorflow/lite/c/common.c
+++ b/tensorflow/lite/c/common.c
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/c/c_api_types.h"
+
 #ifndef TF_LITE_STATIC_MEMORY
 #include <stdlib.h>
 #include <string.h>
diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h
index 923a0fa..59ad977 100644
--- a/tensorflow/lite/c/common.h
+++ b/tensorflow/lite/c/common.h
@@ -40,26 +40,12 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "tensorflow/lite/c/c_api_types.h"  // IWYU pragma: export
+
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
-typedef enum TfLiteStatus {
-  kTfLiteOk = 0,
-
-  // Generally referring to an error in the runtime (i.e. interpreter)
-  kTfLiteError = 1,
-
-  // Generally referring to an error from a TfLiteDelegate itself.
-  kTfLiteDelegateError = 2,
-
-  // Generally referring to an error in applying a delegate due to
-  // incompatibility between runtime and delegate, e.g., this error is returned
-  // when trying to apply a TfLite delegate onto a model graph that's already
-  // immutable.
-  kTfLiteApplicationError = 3
-} TfLiteStatus;
-
 // The list of external context types known to TF Lite. This list exists solely
 // to avoid conflicts and to ensure ops can share the external contexts they
 // need. Access to the external contexts is controlled by one of the
@@ -254,22 +240,6 @@
     }                                      \
   } while (0)
 
-// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
-// library.
-#ifdef SWIG
-#define TFL_CAPI_EXPORT
-#else
-#if defined(_WIN32)
-#ifdef TFL_COMPILE_LIBRARY
-#define TFL_CAPI_EXPORT __declspec(dllexport)
-#else
-#define TFL_CAPI_EXPORT __declspec(dllimport)
-#endif  // TFL_COMPILE_LIBRARY
-#else
-#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
-#endif  // _WIN32
-#endif  // SWIG
-
 // Single-precision complex data type compatible with the C99 definition.
 typedef struct TfLiteComplex64 {
   float re, im;  // real and imaginary parts, respectively.
@@ -285,24 +255,6 @@
   uint16_t data;
 } TfLiteFloat16;
 
-// Types supported by tensor
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-  kTfLiteInt8 = 9,
-  kTfLiteFloat16 = 10,
-  kTfLiteFloat64 = 11,
-  kTfLiteComplex128 = 12,
-  kTfLiteUInt64 = 13,
-} TfLiteType;
-
 // Return the name of a given type, for error reporting purposes.
 const char* TfLiteTypeGetName(TfLiteType type);
 
@@ -319,22 +271,12 @@
 typedef struct TfLiteQuantization {
   // The type of quantization held by params.
   TfLiteQuantizationType type;
-  // Holds a reference to one of the quantization param structures specified
-  // below.
+  // Holds an optional reference to a quantization param structure. The actual
+  // type depends on the value of the `type` field (see the comment there for
+  // the values and corresponding types).
   void* params;
 } TfLiteQuantization;
 
-// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
-// If per-layer quantization is specified this field will still be populated in
-// addition to TfLiteAffineQuantization.
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct TfLiteQuantizationParams {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
 // Parameters for asymmetric quantization across a dimension (i.e per output
 // channel quantization).
 // quantized_dimension specifies which dimension the scales and zero_points
@@ -536,7 +478,7 @@
   // WARNING: This is an experimental interface that is subject to change.
   struct TfLiteDelegate* delegate;
 } TfLiteNode;
-#else  // defined(TF_LITE_STATIC_MEMORY)?
+#else   // defined(TF_LITE_STATIC_MEMORY)?
 // NOTE: This flag is opt-in only at compile time.
 //
 // Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index a4bfc77..e9a0560 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -201,6 +201,14 @@
       return ParseDequantize(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_DIV: {
+      return ParseDiv(op, error_reporter, allocator, builtin_data);
+    }
+
+    case BuiltinOperator_EXP: {
+      return ParseExp(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_FILL: {
       return ParseFill(op, error_reporter, allocator, builtin_data);
     }
@@ -487,16 +495,7 @@
     case BuiltinOperator_HASHTABLE_LOOKUP:
       // no-op.
       return kTfLiteOk;
-    case BuiltinOperator_DIV: {
-      auto params = safe_allocator.Allocate<TfLiteDivParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params = op->builtin_options_as_DivOptions()) {
-        params->activation =
-            ConvertActivation(schema_params->fused_activation_function());
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
+
     case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
       auto params = safe_allocator.Allocate<TfLiteLocalResponseNormParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -761,6 +760,8 @@
               op->builtin_options_as_BatchMatMulOptions()) {
         params->adj_x = bmm_params->adj_x();
         params->adj_y = bmm_params->adj_y();
+        params->asymmetric_quantize_inputs =
+            bmm_params->asymmetric_quantize_inputs();
       }
       *builtin_data = params.release();
       return kTfLiteOk;
@@ -796,7 +797,6 @@
     case BuiltinOperator_ELU:
     case BuiltinOperator_EMBEDDING_LOOKUP:
     case BuiltinOperator_EQUAL:
-    case BuiltinOperator_EXP:
     case BuiltinOperator_EXPAND_DIMS:
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_MATRIX_DIAG:
@@ -1089,6 +1089,21 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseDiv(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteDivParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* schema_params = op->builtin_options_as_DivOptions()) {
+    params->activation =
+        ConvertActivation(schema_params->fused_activation_function());
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -1100,6 +1115,14 @@
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseExp(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
+                      void**) {
+  return kTfLiteOk;
+}
+
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
 TfLiteStatus ParseFill(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
                        void**) {
   return kTfLiteOk;
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index 2540d35..82d5bbe 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -104,9 +104,15 @@
                              BuiltinDataAllocator* allocator,
                              void** builtin_data);
 
+TfLiteStatus ParseDiv(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseEqual(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);
 
+TfLiteStatus ParseExp(const Operator* op, ErrorReporter* error_reporter,
+                      BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseFill(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h
index 897efbe..f2dd12c 100644
--- a/tensorflow/lite/core/api/profiler.h
+++ b/tensorflow/lite/core/api/profiler.h
@@ -181,12 +181,12 @@
       _profile_, __COUNTER__)((profiler), (tag), (node_index))
 
 #define TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT(                          \
-    profiler, tag, delegate_status, interpreter_status)                    \
+    profiler, tag, event_metadata1, event_metadata2)                       \
   do {                                                                     \
-    if (!profiler) {                                                       \
+    if (profiler) {                                                        \
       const auto handle = profiler->BeginEvent(                            \
           tag, Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, \
-          delegate_status, interpreter_status);                            \
+          event_metadata1, event_metadata2);                               \
       profiler->EndEvent(handle);                                          \
     }                                                                      \
   } while (false);
diff --git a/tensorflow/lite/core/shims/BUILD b/tensorflow/lite/core/shims/BUILD
index aa25059..e09573d 100644
--- a/tensorflow/lite/core/shims/BUILD
+++ b/tensorflow/lite/core/shims/BUILD
@@ -112,6 +112,10 @@
         "//tensorflow/lite/kernels:fully_connected.h",
     ],
     compatible_with = get_compatible_with_portable(),
+    visibility = [
+        "//tensorflow/lite:__subpackages__",
+        "//tensorflow_lite_support:__subpackages__",
+    ],
     deps = [
         "//tensorflow/lite:cc_api",
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/core/shims/c/builtin_op_data.h b/tensorflow/lite/core/shims/c/builtin_op_data.h
index f75b014..747c805 100644
--- a/tensorflow/lite/core/shims/c/builtin_op_data.h
+++ b/tensorflow/lite/core/shims/c/builtin_op_data.h
@@ -12,9 +12,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
-#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 
-#endif  // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/lite/core/shims/c/c_api.h b/tensorflow/lite/core/shims/c/c_api.h
index 90e0147..a42d163 100644
--- a/tensorflow/lite/core/shims/c/c_api.h
+++ b/tensorflow/lite/core/shims/c/c_api.h
@@ -12,9 +12,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
-#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
 
 #include "tensorflow/lite/c/c_api.h"
 
-#endif  // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
diff --git a/tensorflow/lite/core/shims/c/c_api_experimental.h b/tensorflow/lite/core/shims/c/c_api_experimental.h
index ceb1cb7..ec1222c 100644
--- a/tensorflow/lite/core/shims/c/c_api_experimental.h
+++ b/tensorflow/lite/core/shims/c/c_api_experimental.h
@@ -12,9 +12,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
-#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
 
 #include "tensorflow/lite/c/c_api_experimental.h"
 
-#endif  // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
diff --git a/tensorflow/lite/core/shims/c/common.h b/tensorflow/lite/core/shims/c/common.h
index a531546..bcbd168 100644
--- a/tensorflow/lite/core/shims/c/common.h
+++ b/tensorflow/lite/core/shims/c/common.h
@@ -12,9 +12,15 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
-#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
 
 #include "tensorflow/lite/c/common.h"
 
-#endif  // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
+// TfLiteOpaqueDelegate: allows delegation of nodes to alternative backends.
+// TfLiteOpaqueDelegate is an abstract type that is intended to have the same
+// role as TfLiteDelegate, but without necessarily exposing the implementation
+// details of how delegates are implemented.
+typedef TfLiteDelegate TfLiteOpaqueDelegate;
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
diff --git a/tensorflow/lite/java/src/main/native/op_resolver.h b/tensorflow/lite/create_op_resolver.h
similarity index 76%
rename from tensorflow/lite/java/src/main/native/op_resolver.h
rename to tensorflow/lite/create_op_resolver.h
index 08ff0ce..ab00d27 100644
--- a/tensorflow/lite/java/src/main/native/op_resolver.h
+++ b/tensorflow/lite/create_op_resolver.h
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
-#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
+#ifndef TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_
+#define TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_
 
 #include <memory>
 
@@ -21,8 +21,7 @@
 
 namespace tflite {
 
-std::unique_ptr<OpResolver> CreateOpResolver();
-
+std::unique_ptr<MutableOpResolver> CreateOpResolver();
 }
 
-#endif  // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
+#endif  // TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_
diff --git a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
similarity index 60%
rename from tensorflow/lite/java/src/main/native/builtin_ops_jni.cc
rename to tensorflow/lite/create_op_resolver_with_builtin_ops.cc
index eb17fdc..0eda0b7 100644
--- a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc
+++ b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
@@ -15,18 +15,17 @@
 
 #include <memory>
 
-#include "tensorflow/lite/core/api/op_resolver.h"
-#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/core/shims/cc/kernels/register.h"
+#include "tensorflow/lite/create_op_resolver.h"
 
 namespace tflite {
 
-// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
-// the tflite namespace. This one instantiates a
-// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but
-// without applying any TfLite delegates by default (like the XNNPACK delegate).
-// For smaller binary sizes users should avoid linking this in, and should
-// provide a custom make CreateOpResolver() instead.
-std::unique_ptr<OpResolver> CreateOpResolver() {  // NOLINT
+// This function instantiates a BuiltinOpResolverWithoutDefaultDelegates, with
+// all the builtin ops but without applying any TfLite delegates by default
+// (like the XNNPACK delegate). For smaller binary sizes users should avoid
+// linking this in, and should provide a CreateOpResolver() with selected ops
+// instead.
+std::unique_ptr<MutableOpResolver> CreateOpResolver() {  // NOLINT
   return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
       new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
 }
diff --git a/tensorflow/lite/java/src/main/native/selected_ops_jni.cc b/tensorflow/lite/create_op_resolver_with_selected_ops.cc
similarity index 86%
rename from tensorflow/lite/java/src/main/native/selected_ops_jni.cc
rename to tensorflow/lite/create_op_resolver_with_selected_ops.cc
index d8eb233..c7c0978 100644
--- a/tensorflow/lite/java/src/main/native/selected_ops_jni.cc
+++ b/tensorflow/lite/create_op_resolver_with_selected_ops.cc
@@ -13,11 +13,11 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/java/src/main/native/op_resolver.h"
+#include "tensorflow/lite/create_op_resolver.h"
 #include "tensorflow/lite/mutable_op_resolver.h"
 
 // This method is generated by `gen_selected_ops`.
-// TODO(b/153652701): Instead of relying on a global method, make
+// TODO(b/174972014): Instead of relying on a global method, make
 // `gen_selected_ops` generating a header file with custom namespace.
 void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
 
@@ -26,11 +26,11 @@
 // regardless if selective registration is being used. C++ client will call
 // this method directly and Java client will call this method indirectly via
 // JNI code in interpreter_jni.cc.
-std::unique_ptr<OpResolver> CreateOpResolver() {
+std::unique_ptr<MutableOpResolver> CreateOpResolver() {
   std::unique_ptr<MutableOpResolver> resolver =
       std::unique_ptr<MutableOpResolver>(new MutableOpResolver());
   RegisterSelectedOps(resolver.get());
-  return std::move(resolver);
+  return resolver;
 }
 
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD
index 4ac2493..4cbd58d 100644
--- a/tensorflow/lite/delegates/BUILD
+++ b/tensorflow/lite/delegates/BUILD
@@ -22,12 +22,31 @@
 )
 
 cc_library(
-    name = "status",
-    hdrs = ["status.h"],
+    name = "telemetry",
+    srcs = ["telemetry.cc"],
+    hdrs = ["telemetry.h"],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = [
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
+    ],
+)
+
+cc_test(
+    name = "telemetry_test",
+    srcs = ["telemetry_test.cc"],
+    linkopts = tflite_linkopts(),
+    linkstatic = 1,
+    deps = [
+        ":telemetry",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
+        "//tensorflow/lite/profiling:profile_buffer",
+        "@com_google_googletest//:gtest_main",
+        "@flatbuffers",
     ],
 )
 
@@ -41,6 +60,7 @@
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:kernel_util",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
index 688212e..266ea09 100644
--- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
+++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
@@ -548,6 +548,7 @@
           "StridedSliceAssign",
           "StridedSliceGrad",
           "StringJoin",
+          "StringLength",
           "StringLower",
           "StringSplit",
           "StringSplitV2",
diff --git a/tensorflow/lite/delegates/flex/java/src/main/native/flex_delegate_jni.cc b/tensorflow/lite/delegates/flex/java/src/main/native/flex_delegate_jni.cc
index fef7191..682eb5c 100644
--- a/tensorflow/lite/delegates/flex/java/src/main/native/flex_delegate_jni.cc
+++ b/tensorflow/lite/delegates/flex/java/src/main/native/flex_delegate_jni.cc
@@ -19,9 +19,7 @@
 #include "tensorflow/lite/delegates/utils/simple_delegate.h"
 #include "tensorflow/lite/testing/init_tensorflow.h"
 
-#ifdef __cplusplus
 extern "C" {
-#endif  // __cplusplus
 
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_flex_FlexDelegate_nativeInitTensorFlow(JNIEnv* env,
@@ -42,6 +40,4 @@
       reinterpret_cast<struct TfLiteDelegate*>(delegate));
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 069230e..2c0080c 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -92,6 +92,7 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_builder",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
+        "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:quantization_util",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 989a163..34c1296 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -395,8 +395,6 @@
         ":opencl_wrapper",
         ":serialization_cc_fbs",
         ":tensor",
-        "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector",
-        "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:memory_management",
         "//tensorflow/lite/delegates/gpu/common:model",
@@ -409,6 +407,8 @@
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/selectors:operation_selector",
+        "//tensorflow/lite/delegates/gpu/common/selectors:special_selector",
         "//tensorflow/lite/delegates/gpu/common/task:arguments",
         "//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
@@ -515,7 +515,6 @@
         "//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc
index 83b9f15..e5b4ea9 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc
@@ -346,7 +346,7 @@
     result += "  " + dispatch.label + " - " +
               std::to_string(absl::ToDoubleMilliseconds(dispatch.duration)) +
               " ms\n";
-    auto name = dispatch.label.substr(0, dispatch.label.find(" "));
+    auto name = dispatch.label.substr(0, dispatch.label.find(' '));
     if (statistics.find(name) != statistics.end()) {
       statistics[name].count++;
       statistics[name].total_time +=
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.cc b/tensorflow/lite/delegates/gpu/cl/cl_device.cc
index 1bd5db7..81b504c 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_device.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_device.cc
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -120,35 +121,6 @@
   }
 }
 
-GpuVendor ParseVendor(const std::string& device_name,
-                      const std::string& vendor_name) {
-  std::string d_name = device_name;
-  std::string v_name = vendor_name;
-  std::transform(d_name.begin(), d_name.end(), d_name.begin(), ::tolower);
-  std::transform(v_name.begin(), v_name.end(), v_name.begin(), ::tolower);
-  if (d_name.find("qualcomm") != std::string::npos ||
-      v_name.find("qualcomm") != std::string::npos) {
-    return GpuVendor::kQualcomm;
-  } else if (d_name.find("mali") != std::string::npos ||
-             v_name.find("mali") != std::string::npos) {
-    return GpuVendor::kMali;
-  } else if (d_name.find("power") != std::string::npos ||
-             v_name.find("power") != std::string::npos) {
-    return GpuVendor::kPowerVR;
-  } else if (d_name.find("nvidia") != std::string::npos ||
-             v_name.find("nvidia") != std::string::npos) {
-    return GpuVendor::kNvidia;
-  } else if (d_name.find("advanced micro devices") != std::string::npos ||
-             v_name.find("advanced micro devices") != std::string::npos) {
-    return GpuVendor::kAMD;
-  } else if (d_name.find("intel") != std::string::npos ||
-             v_name.find("intel") != std::string::npos) {
-    return GpuVendor::kIntel;
-  } else {
-    return GpuVendor::kUnknown;
-  }
-}
-
 // check that gpu_version belong to range min_version-max_version
 // min_version is included and max_version is excluded.
 bool IsGPUVersionInRange(int gpu_version, int min_version, int max_version) {
@@ -162,13 +134,9 @@
   const auto vendor_name = GetDeviceInfo<std::string>(id, CL_DEVICE_VENDOR);
   const auto opencl_c_version =
       GetDeviceInfo<std::string>(id, CL_DEVICE_OPENCL_C_VERSION);
-  info.gpu_api = GpuApi::kOpenCl;
-  info.vendor = ParseVendor(device_name, vendor_name);
-  if (info.IsAdreno()) {
-    info.adreno_info = AdrenoInfo(opencl_c_version);
-  } else if (info.IsMali()) {
-    info.mali_info = MaliInfo(device_name);
-  }
+  const std::string gpu_description =
+      absl::StrCat(device_name, " ", vendor_name, " ", opencl_c_version);
+  GetGpuInfoFromDeviceDescription(gpu_description, GpuApi::kOpenCl, &info);
   info.opencl_info.cl_version = ParseCLVersion(opencl_c_version);
   info.opencl_info.extensions =
       absl::StrSplit(GetDeviceInfo<std::string>(id, CL_DEVICE_EXTENSIONS), ' ');
diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc
index 2d4e6c5..941159d 100644
--- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc
+++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc
@@ -209,14 +209,16 @@
   //   c) EglSync->CLEvent or GlSync->CLEvent mapping
   //      Fast, as it allows to map sync to CL event and use it as a dependency
   //      later without stalling GPU pipeline.
+  CLEvent inbound_event;
+  std::vector<cl_event> inbound_events;
   if (is_egl_sync_supported_) {
     EglSync sync;
     RETURN_IF_ERROR(EglSync::NewFence(egl_display_, &sync));
     if (is_egl_to_cl_mapping_supported_) {
       // (c) EglSync->CLEvent or GlSync->CLEvent mapping
       glFlush();
-      RETURN_IF_ERROR(
-          CreateClEventFromEglSync(context_, sync, &inbound_event_));
+      RETURN_IF_ERROR(CreateClEventFromEglSync(context_, sync, &inbound_event));
+      inbound_events.push_back(inbound_event.event());
     } else {
       // (b) EglSync + ClientWait
       RETURN_IF_ERROR(sync.ClientWait());
@@ -227,25 +229,20 @@
   }
 
   // Acquire all GL objects needed while processing.
-  auto make_acquire_wait = [&]() -> std::vector<cl_event> {
-    if (inbound_event_.is_valid()) {
-      return {inbound_event_.event()};
-    }
-    return {};
-  };
-  return AcquiredGlObjects::Acquire(memory_, queue_, make_acquire_wait(),
-                                    nullptr, &gl_objects_);
+  return AcquiredGlObjects::Acquire(memory_, queue_, inbound_events, nullptr,
+                                    &gl_objects_);
 }
 
 absl::Status GlInteropFabric::Finish() {
   if (!is_enabled()) {
     return absl::OkStatus();
   }
-  RETURN_IF_ERROR(gl_objects_.Release({}, &outbound_event_));
+  CLEvent outbound_event;
+  RETURN_IF_ERROR(gl_objects_.Release({}, &outbound_event));
 
   // if (is_egl_sync_supported_ && is_cl_to_egl_mapping_supported_) {
   //   EglSync egl_outbound_sync;
-  //   RETURN_IF_ERROR(CreateEglSyncFromClEvent(outbound_event_.event(),
+  //   RETURN_IF_ERROR(CreateEglSyncFromClEvent(outbound_event.event(),
   //                                            egl_display_,
   //                                            &egl_outbound_sync));
   //   // Instruct GL pipeline to wait until corresponding CL event is signaled.
@@ -254,12 +251,12 @@
   // } else {
   //   // Slower option if proper sync is not supported. It is equivalent to
   //   // clFinish, but, hopefully, faster.
-  //   outbound_event_.Wait();
+  //   outbound_event.Wait();
   // }
 
   // This slow sync is the only working solution right now. We still need to
   // debug why the version above is not working fast and reliably.
-  outbound_event_.Wait();
+  outbound_event.Wait();
   return absl::OkStatus();
 }
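The inbound_events vector built above ends up as an OpenCL wait list. As a hedged illustration (AcquiredGlObjects::Acquire is assumed to wrap a call of roughly this shape; its real implementation is not shown in this diff), forwarding the events to clEnqueueAcquireGLObjects looks like:

#include <vector>

#include <CL/cl_gl.h>

cl_int AcquireWithWaitList(cl_command_queue queue,
                           const std::vector<cl_mem>& objects,
                           const std::vector<cl_event>& wait_events) {
  // OpenCL requires a null wait-list pointer when the wait list is empty.
  return clEnqueueAcquireGLObjects(
      queue, static_cast<cl_uint>(objects.size()), objects.data(),
      static_cast<cl_uint>(wait_events.size()),
      wait_events.empty() ? nullptr : wait_events.data(),
      /*event=*/nullptr);
}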
 
diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.h b/tensorflow/lite/delegates/gpu/cl/gl_interop.h
index aac769b..28c37c9 100644
--- a/tensorflow/lite/delegates/gpu/cl/gl_interop.h
+++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.h
@@ -136,8 +136,6 @@
   const EGLDisplay egl_display_;
   cl_context context_;
   cl_command_queue queue_;
-  CLEvent inbound_event_;
-  CLEvent outbound_event_;
   std::vector<cl_mem> memory_;
   AcquiredGlObjects gl_objects_;  // transient during Start/Finish calls.
 };
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index f24eeb2..2b966be 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -27,14 +27,14 @@
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc
index 470b0d8..9e24cdb 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc
@@ -63,7 +63,7 @@
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, src_shape,
-                                 op_def.src_tensors[0], &src[i]));
+                                 op_def.src_tensors[i], &src[i]));
     RETURN_IF_ERROR(src[i].WriteData(creation_context.queue, src_cpu[i]));
     operation->SetSrc(&src[i], i);
   }
@@ -76,7 +76,7 @@
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, dst_shape,
-                                 op_def.dst_tensors[0], &dst[i]));
+                                 op_def.dst_tensors[i], &dst[i]));
 
     operation->SetDst(&dst[i], i);
   }
@@ -111,7 +111,7 @@
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, src_shape,
-                                 op_def.src_tensors[0], &src[i]));
+                                 op_def.src_tensors[i], &src[i]));
     RETURN_IF_ERROR(src[i].WriteData(creation_context.queue, src_cpu[i]));
     operation->SetSrc(&src[i], i);
   }
@@ -124,7 +124,7 @@
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, dst_shape,
-                                 op_def.dst_tensors[0], &dst[i]));
+                                 op_def.dst_tensors[i], &dst[i]));
 
     operation->SetDst(&dst[i], i);
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc
index 9db5315..717a6f6 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc
@@ -59,6 +59,44 @@
   }
 }
 
+TEST_F(OpenCLOperationTest, Softmax1x1BigNumber) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 1, 4);
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  src_tensor.data.resize(4);
+  src_tensor.data[0] = doubles[0];
+  src_tensor.data[1] = doubles[1];
+  src_tensor.data[2] = doubles[2];
+  src_tensor.data[3] = doubles[3];
+  EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3])));
+  EXPECT_FALSE(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+              std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      Softmax1x1 operation = CreateSoftmax1x1(op_def);
+      ASSERT_OK(ExecuteGPUOperation(
+          src_tensor, creation_context_,
+          absl::make_unique<Softmax1x1>(std::move(operation)), BHWC(1, 1, 1, 4),
+          &dst_tensor));
+      EXPECT_THAT(
+          dst_tensor.data,
+          Pointwise(FloatNear(eps),
+                    {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0,
+                     std::exp(doubles[2]) / s0, std::exp(doubles[3]) / s0}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
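The Softmax1x1BigNumber test above feeds an input whose exponential overflows in float32. The usual remedy is to subtract the per-row maximum before exponentiating; the sketch below shows that trick on the host purely as an illustration of the property being tested, not as the implementation used by the GPU kernel.

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax over one vector (illustrative only).
std::vector<float> StableSoftmax(const std::vector<float>& x) {
  const float max_v = *std::max_element(x.begin(), x.end());
  std::vector<float> e;
  e.reserve(x.size());
  float sum = 0.0f;
  for (float v : x) {
    const float ev = std::exp(v - max_v);  // exp(100 - 100) == 1, no overflow
    e.push_back(ev);
    sum += ev;
  }
  for (float& v : e) v /= sum;
  return e;
}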
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc
index 8b1675b..09f247d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc
@@ -60,6 +60,44 @@
   }
 }
 
+TEST_F(OpenCLOperationTest, SoftmaxBigNumber) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 2);
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  src_tensor.data.resize(4);
+  src_tensor.data[0] = doubles[0];
+  src_tensor.data[1] = doubles[1];
+  src_tensor.data[2] = doubles[2];
+  src_tensor.data[3] = doubles[3];
+  EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3])));
+  EXPECT_FALSE(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
+  double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation = CreateSoftmax(op_def);
+      ASSERT_OK(ExecuteGPUOperation(
+          src_tensor, creation_context_,
+          absl::make_unique<GPUOperation>(std::move(operation)),
+          BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(
+          dst_tensor.data,
+          Pointwise(FloatNear(eps),
+                    {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0,
+                     std::exp(doubles[2]) / s1, std::exp(doubles[3]) / s1}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.cc b/tensorflow/lite/delegates/gpu/cl/serialization.cc
index 812782c..4f20c41 100644
--- a/tensorflow/lite/delegates/gpu/cl/serialization.cc
+++ b/tensorflow/lite/delegates/gpu/cl/serialization.cc
@@ -53,7 +53,25 @@
       return data::DataType::FLOAT16;
     case DataType::FLOAT32:
       return data::DataType::FLOAT32;
-    default:
+    case DataType::FLOAT64:
+      return data::DataType::FLOAT64;
+    case DataType::UINT8:
+      return data::DataType::UINT8;
+    case DataType::INT8:
+      return data::DataType::INT8;
+    case DataType::UINT16:
+      return data::DataType::UINT16;
+    case DataType::INT16:
+      return data::DataType::INT16;
+    case DataType::UINT32:
+      return data::DataType::UINT32;
+    case DataType::INT32:
+      return data::DataType::INT32;
+    case DataType::UINT64:
+      return data::DataType::UINT64;
+    case DataType::INT64:
+      return data::DataType::INT64;
+    case DataType::UNKNOWN:
       return data::DataType::UNKNOWN;
   }
 }
@@ -118,7 +136,25 @@
       return DataType::FLOAT16;
     case data::DataType::FLOAT32:
       return DataType::FLOAT32;
-    default:
+    case data::DataType::FLOAT64:
+      return DataType::FLOAT64;
+    case data::DataType::UINT8:
+      return DataType::UINT8;
+    case data::DataType::INT8:
+      return DataType::INT8;
+    case data::DataType::UINT16:
+      return DataType::UINT16;
+    case data::DataType::INT16:
+      return DataType::INT16;
+    case data::DataType::UINT32:
+      return DataType::UINT32;
+    case data::DataType::INT32:
+      return DataType::INT32;
+    case data::DataType::UINT64:
+      return DataType::UINT64;
+    case data::DataType::INT64:
+      return DataType::INT64;
+    case data::DataType::UNKNOWN:
       return DataType::UNKNOWN;
   }
 }
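Dropping the default: label is what makes these conversions compiler-checked: with -Wswitch (part of -Wall) an unhandled enumerator now produces a warning at every switch. A tiny illustration with a hypothetical enum:

enum class Color { kRed, kGreen, kBlue };

// Without a default: label, adding a new enumerator to Color makes the
// compiler warn here until the new case is handled.
const char* Name(Color c) {
  switch (c) {
    case Color::kRed:
      return "red";
    case Color::kGreen:
      return "green";
    case Color::kBlue:
      return "blue";
  }
  return "";  // Unreachable for valid enumerators.
}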
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index 5243124..69d4af3 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 
 #include <cstring>
+#include <memory>
 
 #include "absl/strings/str_cat.h"
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
@@ -469,26 +470,23 @@
 
 cl_mem Tensor::GetMemoryPtrForWriting() const { return memory_; }
 
-absl::Status Tensor::WriteDataBHWDC(absl::Span<const float> in,
-                                    CLCommandQueue* queue) {
+absl::Status Tensor::WriteDataBHWDC(const float* in, CLCommandQueue* queue) {
   void* data_ptr = nullptr;
   const int aligned_channels = GetAlignedChannels();
   const int elements_count =
       shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
 
   const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
-  std::vector<float> data_f;
-  std::vector<half> data_h;
+  std::unique_ptr<float[]> data_f;
+  std::unique_ptr<half[]> data_h;
   if (descriptor_.data_type == DataType::FLOAT32) {
-    data_f.resize(elements_count);
-    data_ptr = data_f.data();
-    DataFromBHWDC(in, shape_, descriptor_,
-                  absl::MakeSpan(data_f.data(), data_f.size()));
+    data_f.reset(new float[elements_count]);
+    data_ptr = data_f.get();
+    DataFromBHWDC(in, shape_, descriptor_, data_f.get());
   } else {
-    data_h.resize(elements_count);
-    data_ptr = data_h.data();
-    DataFromBHWDC(in, shape_, descriptor_,
-                  absl::MakeSpan(data_h.data(), data_h.size()));
+    data_h.reset(new half[elements_count]);
+    data_ptr = data_h.get();
+    DataFromBHWDC(in, shape_, descriptor_, data_h.get());
   }
 
   switch (descriptor_.storage_type) {
@@ -513,42 +511,41 @@
 absl::Status Tensor::WriteData(CLCommandQueue* queue,
                                const TensorFloat32& src) {
   RETURN_IF_ERROR(IsValid(src.shape));
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+  return WriteDataBHWDC(src.data.data(), queue);
 }
 
 absl::Status Tensor::WriteData(
     CLCommandQueue* queue,
     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+  return WriteDataBHWDC(src.data.data(), queue);
 }
 
 absl::Status Tensor::WriteData(
     CLCommandQueue* queue,
     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+  return WriteDataBHWDC(src.data.data(), queue);
 }
 
 absl::Status Tensor::WriteData(CLCommandQueue* queue,
                                const Tensor5DFloat32& src) {
   RETURN_IF_ERROR(IsValid(src.shape));
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+  return WriteDataBHWDC(src.data.data(), queue);
 }
 
-absl::Status Tensor::ReadDataBHWDC(absl::Span<float> out,
-                                   CLCommandQueue* queue) const {
+absl::Status Tensor::ReadDataBHWDC(float* out, CLCommandQueue* queue) const {
   void* data_ptr = nullptr;
   const int aligned_channels = GetAlignedChannels();
   const int elements_count =
       shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
   const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
-  std::vector<float> data_f;
-  std::vector<half> data_h;
+  std::unique_ptr<float[]> data_f;
+  std::unique_ptr<half[]> data_h;
   if (descriptor_.data_type == DataType::FLOAT32) {
-    data_f.resize(elements_count);
-    data_ptr = data_f.data();
+    data_f.reset(new float[elements_count]);
+    data_ptr = data_f.get();
   } else {
-    data_h.resize(elements_count);
-    data_ptr = data_h.data();
+    data_h.reset(new half[elements_count]);
+    data_ptr = data_h.get();
   }
 
   switch (descriptor_.storage_type) {
@@ -568,11 +565,9 @@
   }
 
   if (descriptor_.data_type == DataType::FLOAT32) {
-    DataToBHWDC(absl::MakeConstSpan(data_f.data(), data_f.size()), shape_,
-                descriptor_, out);
+    DataToBHWDC(data_f.get(), shape_, descriptor_, out);
   } else {
-    DataToBHWDC(absl::MakeConstSpan(data_h.data(), data_h.size()), shape_,
-                descriptor_, out);
+    DataToBHWDC(data_h.get(), shape_, descriptor_, out);
   }
 
   return absl::OkStatus();
@@ -580,13 +575,13 @@
 
 absl::Status Tensor::ReadData(CLCommandQueue* queue, TensorFloat32* dst) const {
   RETURN_IF_ERROR(IsValid(dst->shape));
-  return ReadDataBHWDC(absl::MakeSpan(dst->data), queue);
+  return ReadDataBHWDC(dst->data.data(), queue);
 }
 
 absl::Status Tensor::ReadData(CLCommandQueue* queue,
                               Tensor5DFloat32* dst) const {
   RETURN_IF_ERROR(IsValid(dst->shape));
-  return ReadDataBHWDC(absl::MakeSpan(dst->data), queue);
+  return ReadDataBHWDC(dst->data.data(), queue);
 }
 
 absl::Status Tensor::CreateFromDescriptor(const TensorDescriptor& desc,
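One plausible reason for moving the staging buffers from std::vector to std::unique_ptr<T[]> (an assumption; the change itself does not say) is that new T[n] leaves trivially constructible elements uninitialized, whereas std::vector<T>(n) value-initializes them to zero even though the buffer is about to be fully overwritten:

#include <memory>
#include <vector>

// Stand-in for the conversion routines; overwrites every element.
void FillScratch(float* dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<float>(i);
}

void Example(int elements_count) {
  // Zero-initializes elements_count floats, then overwrites them all.
  std::vector<float> v(elements_count);
  FillScratch(v.data(), elements_count);

  // Allocates without initializing; the fill below is the only write.
  std::unique_ptr<float[]> p(new float[elements_count]);
  FillScratch(p.get(), elements_count);
}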
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index 331adf0..a1d1343 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -19,7 +19,6 @@
 #include <cstdint>
 #include <memory>
 
-#include "absl/types/span.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
@@ -103,10 +102,8 @@
   int GetChannelsAlignment() const;
   int GetAlignedChannels() const;
 
-  absl::Status WriteDataBHWDC(absl::Span<const float> in,
-                              CLCommandQueue* queue);
-  absl::Status ReadDataBHWDC(absl::Span<float> out,
-                             CLCommandQueue* queue) const;
+  absl::Status WriteDataBHWDC(const float* in, CLCommandQueue* queue);
+  absl::Status ReadDataBHWDC(float* out, CLCommandQueue* queue) const;
 
   int3 GetFullTensorRegion() const;
   void Release();
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index 856d89e..c47438f 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -358,12 +358,6 @@
   bool IsCL30OrHigher() const;
 };
 
-inline bool IsOpenGl31OrAbove(const GpuInfo& gpu_info) {
-  return (gpu_info.opengl_info.major_version == 3 &&
-          gpu_info.opengl_info.minor_version >= 1) ||
-         gpu_info.opengl_info.major_version > 3;
-}
-
 // Currently it initializes:
 // vendor
 // AdrenoInfo if vendor is kQualcomm
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index 97eb075..5fca629 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -1534,10 +1534,10 @@
     ReduceAttributes attr;
     Tensor<Linear, DataType::INT32> axes;
     RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
-    const TfLiteTensor* output = reader->GetOutputTensor(0);
+    const TfLiteTensor* input = reader->GetInputTensor(0);
     for (int i = 0; i < axes.data.size(); i++) {
       Axis axis;
-      RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
+      RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis));
       attr.dims.insert(axis);
     }
     node->operation.attributes = attr;
@@ -2615,10 +2615,10 @@
     MeanAttributes attr;
     Tensor<Linear, DataType::INT32> axes;
     RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
-    const TfLiteTensor* output = reader->GetOutputTensor(0);
+    const TfLiteTensor* input = reader->GetInputTensor(0);
     for (int i = 0; i < axes.data.size(); i++) {
       Axis axis;
-      RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
+      RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis));
       attr.dims.insert(axis);
     }
     node->operation.attributes = attr;
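Both hunks above resolve reduction axes against the input tensor instead of the output: the reduced output has fewer dimensions, so an axis index that is valid for the input can be out of range (or map to the wrong dimension) when checked against the output. A hedged sketch of the idea with a hypothetical helper, not the actual ExtractAxisFromIndex:

#include <optional>

// Hypothetical: maps a possibly negative axis index to a 0-based axis,
// validated against the *input* rank. Reducing a rank-4 input over axis 3 is
// valid even though the rank-3 output has no axis 3.
std::optional<int> ResolveAxis(int index, int input_rank) {
  if (index < 0) index += input_rank;  // e.g. -1 means the last axis
  if (index < 0 || index >= input_rank) return std::nullopt;
  return index;
}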
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
index 5bf0d60..24b7324 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
@@ -187,7 +187,9 @@
 
 class InterpreterFp16 : public DelegatedInterpreter {
  public:
-  explicit InterpreterFp16(TfLiteBuiltinOperator op) : DelegatedInterpreter(3) {
+  explicit InterpreterFp16(TfLiteBuiltinOperator op,
+                           bool const_dequantize_inputs = true)
+      : DelegatedInterpreter(3) {
     void* builtin_data = malloc(sizeof(int));
     EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
     EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
@@ -243,6 +245,15 @@
         interpreter_.SetTensorParametersReadWrite(
             2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
         kTfLiteOk);
+    if (const_dequantize_inputs) {
+      // This simulates the dequantize inputs being constants in the graph.
+      // If this is not true, FP16GraphPartitionHelper should not consider the
+      // corresponding DEQUANTIZE ops.
+      auto* tensor0 = interpreter_.tensor(0);
+      auto* tensor2 = interpreter_.tensor(2);
+      tensor0->allocation_type = kTfLiteMmapRo;
+      tensor2->allocation_type = kTfLiteMmapRo;
+    }
     EXPECT_EQ(
         interpreter_.SetTensorParametersReadWrite(
             1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
@@ -337,6 +348,64 @@
   TfLiteIntArrayFree(ops_to_replace);
 }
 
+InterpreterFp16* interpreter_fp16_non_constant =
+    new InterpreterFp16(kTfLiteBuiltinAdd, /*const_dequantize_inputs=*/false);
+
+// Same as GetOpsToReplaceAcceptsFp16DequantizeNodes, but the DEQUANTIZE inputs
+// are not constant. As a result, we don't allow the delegate to accept them.
+TEST(ModelBuilderTest, GetOpsToReplaceRejectsNonConstantFp16DequantizeNodes) {
+  TfLiteContext* context = interpreter_fp16_non_constant->context();
+
+  // These functions are meant to be called inside delegates. Swap out
+  // for similar functions to permit direct calling of GetOpsToReplace.
+  context->GetExecutionPlan = [](struct TfLiteContext* context,
+                                 TfLiteIntArray** execution_plan) {
+    *execution_plan = interpreter_fp16_non_constant->exec_plan();
+    return kTfLiteOk;
+  };
+  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
+                                       TfLiteNode** node,
+                                       TfLiteRegistration** registration) {
+    *node = interpreter_fp16_non_constant->node(node_index);
+    *registration = interpreter_fp16_non_constant->registration(node_index);
+    return kTfLiteOk;
+  };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        // The partitioner should accept only the Add op initially.
+        EXPECT_EQ(nodes_to_replace->size, 1);
+        // Single partition output.
+        auto params = interpreter_fp16_non_constant->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 2;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 3;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+
+        *partition_params_array =
+            interpreter_fp16_non_constant->delegate_params();
+        *num_partitions = interpreter_fp16_non_constant->num_delegate_params();
+        return kTfLiteOk;
+      };
+
+  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
+
+  // Only ADD is delegated, with FP32 (dequantized) inputs.
+  EXPECT_EQ(ops_to_replace->size, 1);
+  TfLiteNode* node = nullptr;
+  TfLiteRegistration* registration = nullptr;
+  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
+                                  &registration);
+  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
+            TfLiteType::kTfLiteFloat32);
+  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
+            TfLiteType::kTfLiteFloat32);
+  TfLiteIntArrayFree(ops_to_replace);
+}
+
 InterpreterFp16* interpreter_fp16_gt_op =
     new InterpreterFp16(kTfLiteBuiltinGreater);
 
@@ -800,6 +869,13 @@
         interpreter_.SetTensorParametersReadWrite(
             2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
         kTfLiteOk);
+    // Simulate DEQUANTIZE inputs being constants.
+    auto* tensor0 = interpreter_.tensor(0);
+    auto* tensor1 = interpreter_.tensor(1);
+    auto* tensor2 = interpreter_.tensor(2);
+    tensor0->allocation_type = kTfLiteMmapRo;
+    tensor1->allocation_type = kTfLiteMmapRo;
+    tensor2->allocation_type = kTfLiteMmapRo;
     EXPECT_EQ(
         interpreter_.SetTensorParametersReadWrite(
             3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
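Marking the tensors with kTfLiteMmapRo is what makes the tests treat them as constants: read-only memory-mapped allocation is the convention constant-tensor checks key on. A minimal check of that shape (IsConstant here is a local helper, not a TFLite API):

#include "tensorflow/lite/c/common.h"

// Local helper mirroring the convention used above.
bool IsConstant(const TfLiteTensor& tensor) {
  return tensor.allocation_type == kTfLiteMmapRo;
}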
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/BUILD
similarity index 89%
rename from tensorflow/lite/delegates/gpu/cl/selectors/BUILD
rename to tensorflow/lite/delegates/gpu/common/selectors/BUILD
index 8be7228..0198d62 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/selectors/BUILD
@@ -7,11 +7,11 @@
     name = "convolution_selector",
     hdrs = ["convolution_selector.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/selectors/default:convolution_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/selectors/default:convolution_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
     ],
@@ -21,9 +21,9 @@
     name = "convolution_transposed_selector",
     hdrs = ["convolution_transposed_selector.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/selectors/default:convolution_transposed_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/selectors/default:convolution_transposed_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
         "@com_google_absl//absl/memory",
@@ -35,10 +35,10 @@
     hdrs = ["default_selector.h"],
     deps = [
         ":subgraph",
-        "//tensorflow/lite/delegates/gpu/cl/selectors/default:default_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/selectors/default:default_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
     ],
@@ -48,9 +48,9 @@
     name = "dw_convolution_selector",
     hdrs = ["dw_convolution_selector.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/selectors/default:dw_convolution_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/selectors/default:dw_convolution_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
@@ -60,9 +60,9 @@
     name = "fully_connected_selector",
     hdrs = ["fully_connected_selector.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/selectors/default:fully_connected_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/selectors/default:fully_connected_selector",  # buildcleaner: keep
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
@@ -80,8 +80,8 @@
         ":fully_connected_selector",
         ":simple_selectors",
         ":subgraph",
-        "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:operations",
@@ -105,7 +105,7 @@
     srcs = ["simple_selectors.cc"],
     hdrs = ["simple_selectors.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl:cl_device",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
@@ -141,8 +141,8 @@
     hdrs = ["special_selector.h"],
     deps = [
         ":subgraph",
-        "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h
similarity index 88%
rename from tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h
index cef1e01..0fa11f5 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_SELECTOR_H_
 
 #include <memory>
 
@@ -27,7 +27,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectConvolution(
     const Convolution2DAttributes& attr, const BHWC& dst_shape,
@@ -46,8 +45,7 @@
     const WeightsDescription& weights_desc, const OperationDef& op_def,
     ModelHints hints);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h
similarity index 82%
rename from tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h
index 4a2a6d9..5c94b89 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_
 
 #include <memory>
 
@@ -25,7 +25,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectConvolutionTransposed(
     const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info,
@@ -35,8 +34,7 @@
     const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info,
     const OperationDef& op_def, WeightsDescription* weights_desc);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD
similarity index 96%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD
rename to tensorflow/lite/delegates/gpu/common/selectors/default/BUILD
index 33edadf..0bcc41c 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD
@@ -46,11 +46,11 @@
     name = "default_selector",
     srcs = ["default_selector.cc"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/selectors:subgraph",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/selectors:subgraph",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
     ],
@@ -60,7 +60,7 @@
     name = "dw_convolution_selector",
     srcs = ["dw_convolution_selector.cc"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl:cl_device",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc
similarity index 99%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc
index f76fc56..9f0fdb5 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc
@@ -30,7 +30,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 
 std::unique_ptr<GPUOperation> SelectConvolutionAdreno(
@@ -201,6 +200,5 @@
   return absl::make_unique<ConverterToConvWeights>(std::move(converter));
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_transposed_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_transposed_selector.cc
similarity index 99%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_transposed_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/default/convolution_transposed_selector.cc
index e33d848..d4205ed 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_transposed_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_transposed_selector.cc
@@ -23,7 +23,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 
 std::unique_ptr<GPUOperation> SelectConvolutionTransposedAdreno(
@@ -142,6 +141,5 @@
   }
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/default_selector.cc
similarity index 93%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/default/default_selector.cc
index a7d94fa..2223938 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/default_selector.cc
@@ -16,16 +16,15 @@
 #include <memory>
 
 #include "absl/strings/str_cat.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def,
                            ModelHints hints, const std::vector<Value*>& inputs,
@@ -35,6 +34,5 @@
       absl::StrCat("No selector for ", node.operation.type));
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/dw_convolution_selector.cc
similarity index 96%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default/dw_convolution_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/default/dw_convolution_selector.cc
index 968d061..07b1113 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/dw_convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/dw_convolution_selector.cc
@@ -14,13 +14,12 @@
 ==============================================================================*/
 
 #include "absl/memory/memory.h"
-#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.h"
 #include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv_3x3.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 
 std::unique_ptr<GPUOperation> SelectDWConvolutionAdreno(
@@ -79,6 +78,5 @@
   }
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc
similarity index 98%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default/fully_connected_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc
index 43f6a26..8634093 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/fully_connected_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc
@@ -22,7 +22,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectFullyConnectedGeneric(
     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
@@ -96,6 +95,5 @@
   }
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/default_selector.h
similarity index 81%
rename from tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/default_selector.h
index 1efa215..c6f7758 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default_selector.h
@@ -13,29 +13,27 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SELECTOR_H_
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def,
                            ModelHints hints, const std::vector<Value*>& inputs,
                            const std::vector<Value*>& outputs, const Node& node,
                            GPUOperationsSubgraph* gpu_subgraph);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h
similarity index 80%
rename from tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h
index 0c92098..f3e50a9 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DW_CONVOLUTION_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DW_CONVOLUTION_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DW_CONVOLUTION_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DW_CONVOLUTION_SELECTOR_H_
 
 #include <memory>
 
@@ -24,14 +24,12 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectDWConvolution(
     const DepthwiseConvolution2DAttributes& attr, const GpuInfo& gpu_info,
     const OperationDef& op_def);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DW_CONVOLUTION_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DW_CONVOLUTION_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h
similarity index 80%
rename from tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h
index 5b1563a..e2e910e 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_FULLY_CONNECTED_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_FULLY_CONNECTED_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_FULLY_CONNECTED_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_FULLY_CONNECTED_SELECTOR_H_
 
 #include <memory>
 
@@ -24,14 +24,12 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectFullyConnected(
     const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
     const OperationDef& op_def, int batch_size);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_FULLY_CONNECTED_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_FULLY_CONNECTED_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
similarity index 97%
rename from tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
index 3fca1c6..41c6937 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
@@ -13,19 +13,19 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
 
 #include "absl/strings/str_cat.h"
 #include "absl/types/any.h"
-#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/default_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
@@ -38,7 +38,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 bool IsRecommendedForWinograd4x4To6x6(const Convolution2DAttributes& attr,
                                       const GpuInfo& gpu_info,
@@ -530,6 +529,5 @@
   }
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h
similarity index 82%
rename from tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h
index b81bdaa..dfffa9b 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h
@@ -13,21 +13,20 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_OPERATION_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_OPERATION_SELECTOR_H_
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
                                   const OperationDef& op_def, ModelHints hints,
@@ -36,8 +35,7 @@
                                   const Node& node,
                                   GPUOperationsSubgraph* gpu_subgraph);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_OPERATION_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
similarity index 98%
rename from tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
index 48bb8fc..6f7baef 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h"
 
 #include <memory>
 #include <set>
@@ -44,7 +44,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
                                          const GpuInfo& gpu_info) {
@@ -194,6 +193,5 @@
       CreateQuantizeAndDequantize(op_def, attr));
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
similarity index 93%
rename from tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
rename to tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
index 52c2310..4f757a8 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
@@ -13,12 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SIMPLE_SELECTORS_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SIMPLE_SELECTORS_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SIMPLE_SELECTORS_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SIMPLE_SELECTORS_H_
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -26,7 +26,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
                                          const GpuInfo& gpu_info);
@@ -98,8 +97,7 @@
 std::unique_ptr<GPUOperation> SelectQuantizeAndDequantize(
     const QuantizeAndDequantizeAttributes& attr, const OperationDef& op_def);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SIMPLE_SELECTORS_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SIMPLE_SELECTORS_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc
similarity index 97%
rename from tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc
index 6d5300d..1416001 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc
@@ -13,10 +13,9 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
 
 #include "absl/types/any.h"
-#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
@@ -28,7 +27,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 absl::Status TryDepthwiseConvPlus1x1Conv(
     CalculationsPrecision precision, const GraphFloat32& graph,
@@ -208,6 +206,5 @@
   return absl::NotFoundError("No special combination.");
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.h
similarity index 79%
rename from tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
rename to tensorflow/lite/delegates/gpu/common/selectors/special_selector.h
index aecd0a0..fc33d51 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.h
@@ -13,22 +13,22 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SPECIAL_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SPECIAL_SELECTOR_H_
 
 #include <map>
 #include <set>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 absl::Status GPUSubgraphFromGraph(
     const GpuInfo& gpu_info, CalculationsPrecision precision,
@@ -37,8 +37,7 @@
     std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph,
     std::string* name);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SPECIAL_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.cc
similarity index 93%
rename from tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
rename to tensorflow/lite/delegates/gpu/common/selectors/subgraph.cc
index cd3c987..5bb11b6 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
 
 #include <memory>
 
@@ -23,7 +23,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
     const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
@@ -41,6 +40,5 @@
   return &gpu_subgraph->operations[0].operation;
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.h
similarity index 87%
rename from tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
rename to tensorflow/lite/delegates/gpu/common/selectors/subgraph.h
index f94e0c4..243e2d4 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SUBGRAPH_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SUBGRAPH_H_
 
 #include <memory>
 #include <vector>
@@ -25,7 +25,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 struct GPUOperationWithRefs {
   std::unique_ptr<GPUOperation> operation;
@@ -46,8 +45,7 @@
     const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
     GPUOperationsSubgraph* gpu_subgraph);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SUBGRAPH_H_
diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs
index 8d94347..5b1918d 100644
--- a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs
+++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs
@@ -68,8 +68,17 @@
 
 enum DataType : byte {
   UNKNOWN = 0,
-  FLOAT32 = 1,
-  FLOAT16 = 2,
+  FLOAT16 = 1,
+  FLOAT32 = 2,
+  FLOAT64 = 3,
+  UINT8 = 4,
+  INT8 = 5,
+  UINT16 = 6,
+  INT16 = 7,
+  UINT32 = 8,
+  INT32 = 9,
+  UINT64 = 10,
+  INT64 = 11,
 }
 
 enum MemoryType : byte {
diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
index e3c3a5c..82fadcc 100644
--- a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
+++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
@@ -122,33 +122,39 @@
 
 enum class DataType : int8_t {
   UNKNOWN = 0,
-  FLOAT32 = 1,
-  FLOAT16 = 2,
+  FLOAT16 = 1,
+  FLOAT32 = 2,
+  FLOAT64 = 3,
+  UINT8 = 4,
+  INT8 = 5,
+  UINT16 = 6,
+  INT16 = 7,
+  UINT32 = 8,
+  INT32 = 9,
+  UINT64 = 10,
+  INT64 = 11,
   MIN = UNKNOWN,
-  MAX = FLOAT16
+  MAX = INT64
 };
 
-inline const DataType (&EnumValuesDataType())[3] {
+inline const DataType (&EnumValuesDataType())[12] {
   static const DataType values[] = {
-    DataType::UNKNOWN,
-    DataType::FLOAT32,
-    DataType::FLOAT16
-  };
+      DataType::UNKNOWN, DataType::FLOAT16, DataType::FLOAT32,
+      DataType::FLOAT64, DataType::UINT8,   DataType::INT8,
+      DataType::UINT16,  DataType::INT16,   DataType::UINT32,
+      DataType::INT32,   DataType::UINT64,  DataType::INT64};
   return values;
 }
 
 inline const char * const *EnumNamesDataType() {
-  static const char * const names[4] = {
-    "UNKNOWN",
-    "FLOAT32",
-    "FLOAT16",
-    nullptr
-  };
+  static const char *const names[13] = {
+      "UNKNOWN", "FLOAT16", "FLOAT32", "FLOAT64", "UINT8", "INT8", "UINT16",
+      "INT16",   "UINT32",  "INT32",   "UINT64",  "INT64", nullptr};
   return names;
 }
 
 inline const char *EnumNameDataType(DataType e) {
-  if (flatbuffers::IsOutRange(e, DataType::UNKNOWN, DataType::FLOAT16)) return "";
+  if (flatbuffers::IsOutRange(e, DataType::UNKNOWN, DataType::INT64)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesDataType()[index];
 }
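
Note on the regenerated enum: FLOAT16 and FLOAT32 swap their numeric values (1 and 2), so blobs serialized with the old schema decode differently under the new one. A minimal C++ usage sketch of the regenerated helpers, assuming the generated header above is included and its namespace is in scope:

    // Sketch only; DataType is the generated enum from the header above.
    static_assert(static_cast<int8_t>(DataType::FLOAT16) == 1, "new numbering");
    static_assert(static_cast<int8_t>(DataType::FLOAT32) == 2, "new numbering");

    const char* NameOrEmpty(DataType t) {
      // EnumNameDataType() guards with flatbuffers::IsOutRange, so anything
      // outside [UNKNOWN, INT64] yields "" instead of indexing past the array.
      return EnumNameDataType(t);
    }
    // NameOrEmpty(DataType::INT64) == "INT64";
    // NameOrEmpty(static_cast<DataType>(42)) == "".
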
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
index 3c02221..0acb7fb 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc
@@ -747,22 +747,22 @@
 void TensorDescriptor::UploadData(
     const tflite::gpu::Tensor<BHWC, DataType::FLOAT32>& src) {
   shape = BHWDC(src.shape.b, src.shape.h, src.shape.w, 1, src.shape.c);
-  UploadData(absl::MakeConstSpan(src.data));
+  UploadData(src.data.data());
 }
 
 void TensorDescriptor::UploadData(
     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
   shape = BHWDC(1, src.shape.h, src.shape.w, 1, src.shape.c);
-  UploadData(absl::MakeConstSpan(src.data));
+  UploadData(src.data.data());
 }
 
 void TensorDescriptor::UploadData(
     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
   shape = BHWDC(1, 1, 1, 1, src.shape.v);
-  UploadData(absl::MakeConstSpan(src.data));
+  UploadData(src.data.data());
 }
 
-void TensorDescriptor::UploadData(absl::Span<const float> src) {
+void TensorDescriptor::UploadData(const float* src) {
   int aligned_channels = storage_type == TensorStorageType::SINGLE_TEXTURE_2D
                              ? shape.c
                              : AlignByN(shape.c, 4);
@@ -770,10 +770,10 @@
   data.resize(elements_count * SizeOf(data_type));
   if (data_type == DataType::FLOAT32) {
     float* gpu_data = reinterpret_cast<float*>(data.data());
-    DataFromBHWDC(src, shape, *this, absl::MakeSpan(gpu_data, elements_count));
+    DataFromBHWDC(src, shape, *this, gpu_data);
   } else {
     half* gpu_data = reinterpret_cast<half*>(data.data());
-    DataFromBHWDC(src, shape, *this, absl::MakeSpan(gpu_data, elements_count));
+    DataFromBHWDC(src, shape, *this, gpu_data);
   }
 }
 
@@ -848,8 +848,8 @@
 }  // namespace
 
 template <typename T>
-void DataFromBHWDC(absl::Span<const float> src, const BHWDC& shape,
-                   const TensorDescriptor& desc, absl::Span<T> dst) {
+void DataFromBHWDC(const float* src, const BHWDC& shape,
+                   const TensorDescriptor& desc, T* dst) {
   const int channels_alignment = GetChannelsAlignment(desc, shape);
   const int slices = DivideRoundUp(shape.c, 4);
   for (int b = 0; b < shape.b; ++b) {
@@ -876,18 +876,14 @@
   }
 }
 
-template void DataFromBHWDC<float>(absl::Span<const float> src,
-                                   const BHWDC& shape,
-                                   const TensorDescriptor& desc,
-                                   absl::Span<float> dst);
-template void DataFromBHWDC<half>(absl::Span<const float> src,
-                                  const BHWDC& shape,
-                                  const TensorDescriptor& desc,
-                                  absl::Span<half> dst);
+template void DataFromBHWDC<float>(const float* src, const BHWDC& shape,
+                                   const TensorDescriptor& desc, float* dst);
+template void DataFromBHWDC<half>(const float* src, const BHWDC& shape,
+                                  const TensorDescriptor& desc, half* dst);
 
 template <typename T>
-void DataToBHWDC(absl::Span<const T> src, const BHWDC& shape,
-                 const TensorDescriptor& desc, absl::Span<float> dst) {
+void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
+                 float* dst) {
   const int channels_alignment = GetChannelsAlignment(desc, shape);
   const int slices = DivideRoundUp(shape.c, 4);
   for (int b = 0; b < shape.b; ++b) {
@@ -910,13 +906,10 @@
   }
 }
 
-template void DataToBHWDC<float>(absl::Span<const float> src,
-                                 const BHWDC& shape,
-                                 const TensorDescriptor& desc,
-                                 absl::Span<float> dst);
-template void DataToBHWDC<half>(absl::Span<const half> src, const BHWDC& shape,
-                                const TensorDescriptor& desc,
-                                absl::Span<float> dst);
+template void DataToBHWDC<float>(const float* src, const BHWDC& shape,
+                                 const TensorDescriptor& desc, float* dst);
+template void DataToBHWDC<half>(const half* src, const BHWDC& shape,
+                                const TensorDescriptor& desc, float* dst);
 
 }  // namespace gpu
 }  // namespace tflite
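
A minimal calling sketch for the pointer-based API above (assumptions: the tflite::gpu namespace is in scope and the typed UploadData overloads defined here are declared on TensorDescriptor). The typed overloads set `shape` first and then forward src.data.data(), so the only contract left for the raw-pointer overload is that the buffer matches that shape:

    // Sketch, not delegate code.
    Tensor<Linear, DataType::FLOAT32> bias;
    bias.shape.v = 8;
    bias.data.assign(8, 0.5f);

    TensorDescriptor desc{DataType::FLOAT32, TensorStorageType::BUFFER, Layout::HWC};
    desc.UploadData(bias);  // sets shape = BHWDC(1, 1, 1, 1, 8), then forwards
                            // bias.data.data() to UploadData(const float*).
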
diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h
index 8f339ba..6486280 100644
--- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h
+++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h
@@ -169,16 +169,16 @@
                            std::string* xc, std::string* yc, std::string* zc,
                            std::string* sc, std::string* bc) const;
 
-  void UploadData(absl::Span<const float> src);
+  void UploadData(const float* src);
 };
 
 template <typename T>
-void DataFromBHWDC(absl::Span<const float> src, const BHWDC& shape,
-                   const TensorDescriptor& desc, absl::Span<T> dst);
+void DataFromBHWDC(const float* src, const BHWDC& shape,
+                   const TensorDescriptor& desc, T* dst);
 
 template <typename T>
-void DataToBHWDC(absl::Span<const T> src, const BHWDC& shape,
-                 const TensorDescriptor& desc, absl::Span<float> dst);
+void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
+                 float* dst);
 
 std::string ToString(TensorStorageType type);
 
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc b/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc
index 2bbad5c..09316cd 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc
@@ -33,15 +33,28 @@
   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
        "return; \n";
   c += "  float sum = 0.0f;\n";
+  c += "  float maximum = args.src_tensor.Read<float>(X, Y, 0).x;\n";
   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n";
+  c += "    maximum = max(maximum, t.x);\n";
+  c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) maximum = max(maximum, "
+       "t.y);\n";
+  c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) maximum = max(maximum, "
+       "t.z);\n";
+  c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) maximum = max(maximum, "
+       "t.w);\n";
+  c += "  }\n";
+  c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
+  c += "    float4 t = args.src_tensor.Read<float>(X, Y, d) - "
+       "(float4)(maximum);\n";
   c += "    sum += exp(t.x);\n";
   c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) sum += exp(t.y);\n";
   c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) sum += exp(t.z);\n";
   c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) sum += exp(t.w);\n";
   c += "  }\n";
   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
-  c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n";
+  c += "    float4 t = args.src_tensor.Read<float>(X, Y, d) - "
+       "(float4)(maximum);\n";
   c += "    t = exp(t) / sum;\n";
   c += "    FLT4 result = TO_FLT4(t);\n";
   c += "    args.dst_tensor.Write(result, X, Y, d);\n";
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc b/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc
index 952f081..b5fe668 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc
@@ -45,7 +45,6 @@
   args_.AddFloat("mask_y");
   args_.AddFloat("mask_z");
   args_.AddFloat("mask_w");
-  args_.AddInt("slices_x32");
 
   std::string c;
   c += "__kernel void main_function(\n";
@@ -58,24 +57,47 @@
   }
   c += "  float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
        "args.mask_w);\n";
-  c += "  int offset = 0;\n";
-  c += "  float sum = 0.0f;\n";
-  c += "  int s = 0;\n";
+  c += "  float4 maxx4 = (float4)(args.src_tensor.Read<float>(0, 0, 0).x);\n";
   c += "  int tid = get_local_id(0);\n";
-  c += "  do {\n";
-  c += "    int z = offset + tid;\n";
-  c += "    if (z < args.dst_tensor.Slices()) {\n";
-  c += "      float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : "
+  c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n";
+  c += "    float4 mask_a = s == args.src_tensor.Slices() - 1 ? mask : "
        "(float4)(1.0f);\n";
-  c += "      float4 src = args.src_tensor.Read<float>(0, 0, z);\n";
-  c += "      sum += dot(mask_temp, exp(src));\n";
-  c += "      offset += 32;\n";
-  c += "    }\n";
-  c += "    s++;\n";
-  c += "  } while (s < args.slices_x32);\n";
-  c += "\n";
+  c += "    float4 mask_b = (float4)(1.0f) - mask_a;\n";
+  c += "    float4 src = args.src_tensor.Read<float>(0, 0, s);\n";
+  c += "    src = src * mask_a + mask_b * src.x;\n";
+  c += "    maxx4 = max(maxx4, src);\n";
+  c += "  }\n";
+  c += "  float maximum = max(maxx4.x, maxx4.y);\n";
+  c += "  maximum = max(maximum, maxx4.z);\n";
+  c += "  maximum = max(maximum, maxx4.w);\n";
   c += "  __local float4 tmp[8];\n";
   c += "  __local float* tmpx1 = (__local float*)tmp;\n";
+  c += "  tmpx1[tid] = maximum;\n";
+  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  if (tid == 0) {\n";
+  c += "    maxx4 = max(tmp[0], tmp[1]);\n";
+  c += "    maxx4 = max(maxx4, tmp[2]);\n";
+  c += "    maxx4 = max(maxx4, tmp[3]);\n";
+  c += "    maxx4 = max(maxx4, tmp[4]);\n";
+  c += "    maxx4 = max(maxx4, tmp[5]);\n";
+  c += "    maxx4 = max(maxx4, tmp[6]);\n";
+  c += "    maxx4 = max(maxx4, tmp[7]);\n";
+  c += "    maximum = max(maxx4.x, maxx4.y);\n";
+  c += "    maximum = max(maximum, maxx4.z);\n";
+  c += "    maximum = max(maximum, maxx4.w);\n";
+  c += "    tmpx1[0] = maximum;\n";
+  c += "  }\n";
+  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  maximum = tmpx1[0];\n";
+  c += "  float sum = 0.0f;\n";
+  c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n";
+  c += "    float4 mask_temp = s == args.src_tensor.Slices() - 1 ? mask : "
+       "(float4)(1.0f);\n";
+  c += "    float4 src = args.src_tensor.Read<float>(0, 0, s) - "
+       "(float4)(maximum);\n";
+  c += "    sum += dot(mask_temp, exp(src));\n";
+  c += "  }\n";
+  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
   c += "  tmpx1[tid] = sum;\n";
   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
   c += "  if (tid == 0) {\n";
@@ -92,18 +114,13 @@
   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
   c += "  sum = tmpx1[0];\n";
   c += "\n";
-  c += "  offset = 0;\n";
-  c += "  s = 0;\n";
-  c += "  do {\n";
-  c += "    int z = offset + tid;\n";
-  c += "    if (z < args.dst_tensor.Slices()) {\n";
-  c += "      FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, "
-       "z))*sum);\n";
-  c += "      args.dst_tensor.Write(res, 0, 0, z);\n";
-  c += "      offset += 32;\n";
-  c += "    }\n";
-  c += "    s++;\n";
-  c += "  } while (s < args.slices_x32);\n";
+  c += "  int dst_s = get_global_id(0);\n";
+  c += "  if (dst_s < args.dst_tensor.Slices()) {\n";
+  c += "    float4 src = args.src_tensor.Read<float>(0, 0, dst_s) - "
+       "(float4)(maximum);\n";
+  c += "    FLT4 res = TO_FLT4(exp(src) * sum);\n";
+  c += "    args.dst_tensor.Write(res, 0, 0, dst_s);\n";
+  c += "  }\n";
   c += "}\n";
   return c;
 }
@@ -114,12 +131,12 @@
   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
   RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w));
-  RETURN_IF_ERROR(
-      args->SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32)));
   return absl::OkStatus();
 }
 
-int3 Softmax1x1::GetGridSize() const { return int3(32, dst_[0]->Batch(), 1); }
+int3 Softmax1x1::GetGridSize() const {
+  return int3(dst_[0]->Slices(), dst_[0]->Batch(), 1);
+}
 
 Softmax1x1 CreateSoftmax1x1(const OperationDef& definition) {
   return Softmax1x1(definition);
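
A CPU sketch of the reduction pattern the kernel above uses twice (once for the maximum, once for the sum of exponentials): each of the 32 work-group invocations folds slices tid, tid+32, tid+64, ... into a private partial value, the partials go through the __local tmp buffer between barriers, and one lane combines them. The 32 loop iterations below stand in for the 32 invocations. The grid-size change above (one invocation per destination slice instead of a fixed 32) pairs with the final write block, where each invocation writes exactly one slice.

    // Sketch only; assumes a work-group of 32 invocations.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    float WorkgroupMax(const std::vector<float>& slice_values) {
      constexpr int kLanes = 32;
      if (slice_values.empty()) return 0.0f;
      std::vector<float> partial(kLanes, slice_values[0]);
      for (int tid = 0; tid < kLanes; ++tid) {  // each "invocation"
        for (std::size_t s = tid; s < slice_values.size(); s += kLanes) {
          partial[tid] = std::max(partial[tid], slice_values[s]);
        }
      }
      // barrier(CLK_LOCAL_MEM_FENCE) in the kernel; then one lane folds partials.
      float maximum = partial[0];
      for (int i = 1; i < kLanes; ++i) maximum = std::max(maximum, partial[i]);
      return maximum;
    }
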
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc
index 7352e7f..32424a9 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc
@@ -23,7 +23,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 void UploadWeights(const DepthwiseConvolution2DAttributes& dw_attr,
                    const Convolution2DAttributes& conv_attr,
@@ -264,6 +263,5 @@
   return result;
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h
index 93ef127..c891261 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h
@@ -30,7 +30,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 bool IsDepthwiseConvPlus1x1ConvSupported(
     const OperationDef& definition,
@@ -42,7 +41,6 @@
     const DepthwiseConvolution2DAttributes& dw_attr,
     const Convolution2DAttributes& conv_attr);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc
index 7a10f49..a632dff 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc
@@ -27,7 +27,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
 bool UseBufferForWeights(const GpuInfo& gpu_info) {
   return gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsMali();
@@ -195,6 +194,5 @@
   return result;
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h
index c447c80..5c14729 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h
@@ -35,7 +35,6 @@
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 template <DataType T, typename S>
 void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor<OHWI, T>& weights,
@@ -176,7 +175,6 @@
                       const FullyConnectedAttributes& attr0,
                       const FullyConnectedAttributes& attr1);
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
index f546fdd..49e10a7 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
@@ -66,25 +66,58 @@
     };
     std::vector<Variable> uniform_parameters = {
         {"depth", depth},
-        {"depth_div_32", DivideRoundUp(depth, 32)},
         {"mask", GetMask(ctx.output_shapes[0][3])},
     };
     std::string source_code = R"(
   highp vec4 kOnes = vec4(1.0);
-  highp float sum = 0.0;
-  int offset = 0;
-  int s = 0;
   int tid = int(gl_LocalInvocationID.x);
-  do {
-    int z = offset + tid;
-    if (z < $depth$) {
-      highp vec4 mask_temp = z == $depth$ - 1 ? $mask$ : kOnes;
-      highp vec4 src = $input_data_0[0, 0, z]$;
-      sum += dot(mask_temp, exp(src));
-      offset += 32;
-    }
-    s++;
-  } while (s < $depth_div_32$);
+  highp vec4 maxx4 = $input_data_0[0, 0, 0]$;
+  maxx4.y = maxx4.x;
+  maxx4.z = maxx4.x;
+  maxx4.w = maxx4.x;
+  for (int s = tid; s < $depth$; s += 32) {
+    highp vec4 mask_a = s == $depth$ - 1 ? $mask$ : kOnes;
+    highp vec4 mask_b = kOnes - mask_a;
+    highp vec4 src = $input_data_0[0, 0, s]$;
+    src = src * mask_a + mask_b * src.x;
+    maxx4 = max(maxx4, src);
+  }
+  highp float maximum = max(maxx4.x, maxx4.y);
+  maximum = max(maximum, maxx4.z);
+  maximum = max(maximum, maxx4.w);
+  partial_sum[tid / 4][tid % 4] = maximum;
+
+  memoryBarrierShared();
+  barrier();
+
+  if (tid == 0) {
+    maxx4 = max(partial_sum[0], partial_sum[1]);
+    maxx4 = max(maxx4, partial_sum[2]);
+    maxx4 = max(maxx4, partial_sum[3]);
+    maxx4 = max(maxx4, partial_sum[4]);
+    maxx4 = max(maxx4, partial_sum[5]);
+    maxx4 = max(maxx4, partial_sum[6]);
+    maxx4 = max(maxx4, partial_sum[7]);
+    maximum = max(maxx4.x, maxx4.y);
+    maximum = max(maximum, maxx4.z);
+    maximum = max(maximum, maxx4.w);
+    partial_sum[0][0] = maximum;
+  }
+
+  memoryBarrierShared();
+  barrier();
+
+  maximum = partial_sum[0][0];
+
+  highp float sum = 0.0;
+  for (int s = tid; s < $depth$; s += 32) {
+    highp vec4 mask_temp = s == $depth$ - 1 ? $mask$ : kOnes;
+    highp vec4 src = $input_data_0[0, 0, s]$ - vec4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+
+  memoryBarrierShared();
+  barrier();
 
   partial_sum[tid / 4][tid % 4] = sum;
 
@@ -108,24 +141,19 @@
 
   sum = partial_sum[0][0];
 
-  offset = 0;
-  s = 0;
-  do {
-    int z = offset + tid;
-    if (z < $depth$) {
-      highp vec4 src = $input_data_0[0, 0, z]$;
-      highp vec4 temp = exp(src) * sum;
-      $output_data_0[0, 0, z] = temp$;
-      offset += 32;
-    }
-    s++;
-  } while (s < $depth_div_32$);
+  int dst_s = int(gl_GlobalInvocationID.x);
+  if (dst_s < $depth$) {
+    highp vec4 src = $input_data_0[0, 0, dst_s]$ - vec4(maximum);
+    highp vec4 temp = exp(src) * sum;
+    $output_data_0[0, 0, dst_s] = temp$;
+  }
 )";
+
     *generated_code = {
         /*parameters=*/std::move(uniform_parameters),
         /*objects=*/{},
         /*shared_variables=*/std::move(shared_variables),
-        /*workload=*/uint3(32, 1, 1),
+        /*workload=*/uint3(depth, 1, 1),
         /*workgroup=*/uint3(32, 1, 1),
         /*source_code=*/std::move(source_code),
         /*input=*/IOStructure::ONLY_DEFINITIONS,
@@ -145,17 +173,24 @@
     std::string source_code = R"(
   highp vec4 kOnes = vec4(1.0);
   highp float sum = 0.0;
-  for (int d = 0; d < $src_depth$ - 1; ++d) {
+  highp float maximum = $input_data_0[gid.x, gid.y, 0]$.x;
+  for (int d = 0; d < $src_depth$; ++d) {
+    highp vec4 mask_a = d == $src_depth$ - 1 ? $mask$ : kOnes;
+    highp vec4 mask_b = kOnes - mask_a;
     highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
-    sum += dot(kOnes, exp(src));
-  }
-  {
-    int d = $src_depth$ - 1;
-    highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
-    sum += dot($mask$, exp(src));
+    src = src * mask_a + mask_b * src.x;
+    maximum = max(maximum, src.x);
+    maximum = max(maximum, src.y);
+    maximum = max(maximum, src.z);
+    maximum = max(maximum, src.w);
   }
   for (int d = 0; d < $src_depth$; ++d) {
-    highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
+    highp vec4 mask_temp = d == $src_depth$ - 1 ? $mask$ : kOnes;
+    highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+  for (int d = 0; d < $src_depth$; ++d) {
+    highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
     highp vec4 temp_sum = exp(src) / sum;
     $output_data_0[gid.x, gid.y, d] = temp_sum$;
   }
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc
index 1707e1e..67577bb 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc
@@ -121,6 +121,76 @@
                          std::exp(0.3f) / sum, std::exp(0.4f) / sum}));
 }
 
+TEST(SoftmaxTest, SoftmaxBigNumber) {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 2, 1, 2);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 2, 1, 2);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  ASSERT_TRUE(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  ASSERT_FALSE(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
+  double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
+                      {output});
+  ASSERT_TRUE(model.PopulateTensor(
+      0, {static_cast<float>(doubles[0]), static_cast<float>(doubles[1]),
+          static_cast<float>(doubles[2]), static_cast<float>(doubles[3])}));
+  ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6f),
+                        {static_cast<float>(std::exp(doubles[0]) / s0),
+                         static_cast<float>(std::exp(doubles[1]) / s0),
+                         static_cast<float>(std::exp(doubles[2]) / s1),
+                         static_cast<float>(std::exp(doubles[3]) / s1)}));
+}
+
+TEST(SoftmaxTest, Softmax1x1BigNumber) {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 1, 1, 4);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 1, 1, 4);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  ASSERT_TRUE(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  ASSERT_FALSE(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+              std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
+                      {output});
+  ASSERT_TRUE(model.PopulateTensor(
+      0, {static_cast<float>(doubles[0]), static_cast<float>(doubles[1]),
+          static_cast<float>(doubles[2]), static_cast<float>(doubles[3])}));
+  ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6f),
+                        {static_cast<float>(std::exp(doubles[0]) / s0),
+                         static_cast<float>(std::exp(doubles[1]) / s0),
+                         static_cast<float>(std::exp(doubles[2]) / s0),
+                         static_cast<float>(std::exp(doubles[3]) / s0)}));
+}
+
 }  // namespace
 }  // namespace gl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
index d31d058..dc7fd6f 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
+++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
@@ -23,9 +23,7 @@
 #include "tensorflow/lite/experimental/acceleration/compatibility/android_info.h"
 #include "tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h"
 
-#ifdef __cplusplus
 extern "C" {
-#endif  // __cplusplus
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
     JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
@@ -113,6 +111,4 @@
   delete compatibility_list;
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 6dcde34..8a01066 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -26,23 +26,27 @@
     deps = [
         ":compiled_model",
         ":compute_task_descriptor",
-        ":runtime_options",
         "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common:winograd_util",
         "//tensorflow/lite/delegates/gpu/metal/kernels",
-        "//tensorflow/lite/delegates/gpu/metal/kernels:custom_registry",
+        "//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector",
+        "//tensorflow/lite/delegates/gpu/metal/selectors:subgraph",
     ],
 )
 
 objc_library(
     name = "buffer",
-    srcs = ["buffer.mm"],
+    srcs = ["buffer.cc"],
     hdrs = ["buffer.h"],
-    copts = DEFAULT_COPTS,
+    copts = DEFAULT_COPTS + [
+        "-ObjC++",
+    ],
     sdk_frameworks = ["Metal"],
     deps = [
         ":gpu_object",
@@ -155,16 +159,19 @@
 
 objc_library(
     name = "compute_task",
-    srcs = ["compute_task.mm"],
+    srcs = ["compute_task.cc"],
     hdrs = ["compute_task.h"],
-    copts = DEFAULT_COPTS,
+    copts = DEFAULT_COPTS + [
+        "-ObjC++",
+    ],
     sdk_frameworks = ["Metal"],
     deps = [
         ":common",
         ":compute_task_descriptor",
         ":metal_arguments",
-        ":runtime_options",
+        ":metal_spatial_tensor",
         "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
@@ -203,17 +210,20 @@
 
 objc_library(
     name = "inference_context",
-    srcs = ["inference_context.mm"],
+    srcs = ["inference_context.cc"],
     hdrs = ["inference_context.h"],
-    copts = DEFAULT_COPTS,
+    copts = DEFAULT_COPTS + [
+        "-ObjC++",
+    ],
     sdk_frameworks = ["Metal"],
     deps = [
         ":compiled_model",
         ":compute_task",
         ":compute_task_descriptor",
-        ":runtime_options",
+        ":metal_spatial_tensor",
         "//tensorflow/lite/delegates/gpu/common:memory_management",
         "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:util",
@@ -245,9 +255,11 @@
 
 objc_library(
     name = "metal_arguments",
-    srcs = ["metal_arguments.mm"],
+    srcs = ["metal_arguments.cc"],
     hdrs = ["metal_arguments.h"],
-    copts = DEFAULT_COPTS,
+    copts = DEFAULT_COPTS + [
+        "-ObjC++",
+    ],
     sdk_frameworks = ["Metal"],
     deps = [
         ":buffer",
@@ -263,9 +275,11 @@
 
 objc_library(
     name = "metal_spatial_tensor",
-    srcs = ["metal_spatial_tensor.mm"],
+    srcs = ["metal_spatial_tensor.cc"],
     hdrs = ["metal_spatial_tensor.h"],
-    copts = DEFAULT_COPTS,
+    copts = DEFAULT_COPTS + [
+        "-ObjC++",
+    ],
     sdk_frameworks = ["Metal"],
     deps = [
         ":gpu_object",
@@ -274,7 +288,6 @@
         "//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
@@ -292,11 +305,6 @@
     ],
 )
 
-cc_library(
-    name = "runtime_options",
-    hdrs = ["runtime_options.h"],
-)
-
 objc_library(
     name = "TestBinary",
     testonly = 1,
@@ -342,7 +350,6 @@
         "//tensorflow/lite/delegates/gpu/metal:common",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
         "//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
-        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
         "@com_google_absl//absl/memory",
     ],
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index b04632b..02ee339 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -25,519 +25,23 @@
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/relu.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/reshape.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h"
 
 namespace tflite {
 namespace gpu {
 namespace metal {
-namespace {
 
-ComputeTaskDescriptorPtr SelectDepthWiseConv(
-    const OperationDef& op_def, const DepthwiseConvolution2DAttributes& attr) {
-  if (CheckDepthWiseConv3x3Stride1x1Support(attr)) {
-    auto gpu_op = DepthWiseConv3x3Stride1x1(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else if (CheckDepthWiseConv3x3Stride2Support(attr)) {
-    auto gpu_op = DepthWiseConv3x3Stride2(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else {
-    auto gpu_op = DepthWiseConvolution(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-}
-
-ComputeTaskDescriptorPtr SelectConvolutionTransposed(
-    const OperationDef& op_def, const ConvolutionTransposedAttributes& attr,
-    const GpuInfo& gpu_info) {
-  if (CheckConvolutionTransposed4x4Support(attr)) {
-    auto gpu_op = ConvolutionTransposed4x4(op_def, attr, gpu_info);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else {
-    auto gpu_op = ConvolutionTransposed(op_def, attr, gpu_info);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-}
-
-ComputeTaskDescriptorPtr SelectQuantizeAndDequantize(
-    const OperationDef& op_def, const QuantizeAndDequantizeAttributes& attr) {
-  auto gpu_op = QuantizeAndDequantize(op_def, attr);
-  return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-}
-
-ComputeTaskDescriptorPtr SelectPReLU(const OperationDef& op_def,
-                                     const BHWC& src_shape,
-                                     const PReLUAttributes& attr) {
-  auto alpha = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
-  if (alpha) {
-    auto gpu_op = PReLU(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-  auto alpha3d = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
-  if (!alpha3d) {
-    return {};
-  }
-  if (alpha3d->shape.h != src_shape.h || alpha3d->shape.w != src_shape.w ||
-      alpha3d->shape.c != src_shape.c) {
-    return {};
-  }
-  auto gpu_op = PReLUFull(op_def, attr);
-  return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-}
-
-ComputeTaskDescriptorPtr SelectReshape(const OperationDef& op_def,
-                                       const BHWC& src_shape,
-                                       const ReshapeAttributes& attr) {
-  if (src_shape.c % 4 == 0 && attr.new_shape.c % 4 == 0) {
-    auto gpu_op = Reshapex4(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else {
-    auto gpu_op = Reshape(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-}
-
-ComputeTaskDescriptorPtr SelectSoftmax(const OperationDef& op_def,
-                                       const BHWC& src_shape,
-                                       const GpuInfo& gpu_info) {
-  if (src_shape.w == 1 && src_shape.h == 1) {
-    auto gpu_op = Softmax1x1(op_def, gpu_info, src_shape.c);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else {
-    auto gpu_op = Softmax(op_def, src_shape.c);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-}
-
-ComputeTaskDescriptorPtr SelectSpaceToDepth(
-    const OperationDef& op_def, const SpaceToDepthAttributes& attr) {
-  auto gpu_op = SpaceToDepth(op_def, attr);
-  return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-}
-
-ComputeTaskDescriptorPtr SelectWinograd4x4To36(
-    const OperationDef& op_def, const Winograd4x4To36Attributes& attr,
-    const GpuInfo& gpu_info) {
-  if (gpu_info.IsApple()) {
-    auto gpu_op = Winograd4x4To36(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else {
-    auto gpu_op = Winograd4x4To36TileX6(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-}
-
-ComputeTaskDescriptorPtr SelectWinograd36To4x4(
-    const OperationDef& op_def, const Winograd36To4x4Attributes& attr,
-    const GpuInfo& gpu_info) {
-  if (gpu_info.IsApple()) {
-    auto gpu_op = Winograd36To4x4(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  } else {
-    auto gpu_op = Winograd36To4x4Tile4x1(op_def, attr);
-    return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-  }
-}
-
-bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
-                                   const BHWC& dst_shape) {
-  const int tiles_x = DivideRoundUp(dst_shape.w, 4);
-  const int tiles_y = DivideRoundUp(dst_shape.h, 4);
-  const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
-  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
-  const bool suitable_attributes =
-      attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
-      attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
-
-  const int min_depth = 16;
-  const int min_hw = 32;
-  const bool recommended_channels =
-      src_depth >= min_depth && dst_depth >= min_depth;
-  const bool recommended_hw = tiles_x * tiles_y >= min_hw;
-  return suitable_attributes && recommended_channels && recommended_hw;
-}
-
-absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
-                                const std::vector<ValueId>& inputs,
-                                const std::vector<ValueId>& outputs,
-                                const GpuInfo& gpu_info,
-                                const RuntimeOptions& options,
-                                int* last_value_id,
-                                std::map<ValueId, BHWC>* tensor_shapes,
-                                std::vector<NodeDescriptor>* nodes) {
+absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
+                     CalculationsPrecision precision,
+                     CompiledModel* compiled_model) {
   if (!IsBatchMatchesForAllValues(graph)) {
     return absl::InvalidArgumentError(
         "Only identical batch dimension is supported");
   }
-  int node_id = static_cast<int>(node->id);
-  auto op_type = OperationTypeFromString(node->operation.type);
-  nodes->push_back({});
-  auto& node_desc = nodes->back();
-  node_desc.description = node->operation.type + "_" + std::to_string(node->id);
-  node_desc.src_tensors_ids = inputs;
-  node_desc.dst_tensors_ids = outputs;
-  OperationDef op_def;
-  if (options.storage_precision == RuntimeOptions::Precision::FP32) {
-    op_def.precision = CalculationsPrecision::F32;
-  } else {
-    if (options.accumulator_precision == RuntimeOptions::Precision::FP32) {
-      op_def.precision = CalculationsPrecision::F32_F16;
-    } else {
-      op_def.precision = CalculationsPrecision::F16;
-    }
-  }
-  DataType data_type = DeduceDataTypeFromPrecision(op_def.precision);
-  TensorDescriptor tensor_descriptor =
-      TensorDescriptor{data_type, TensorStorageType::BUFFER, Layout::HWC};
-  op_def.src_tensors.resize(inputs.size(), tensor_descriptor);
-  op_def.dst_tensors.resize(outputs.size(), tensor_descriptor);
-  switch (op_type) {
-    case OperationType::ADD: {
-      if (inputs.size() == 1) {
-        if (node->operation.attributes.has_value()) {
-          auto attr =
-              absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
-          auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
-              op_def, op_type, attr.param);
-          node_desc.task =
-              std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-        } else {
-          return absl::UnimplementedError(
-              "Missing attributes for single input op: " +
-              node->operation.type);
-        }
-      } else if (inputs.size() == 2) {
-        const auto srcs = graph.FindInputs(node_id);
-        auto gpu_op =
-            ElementwiseWithTwoInputs(op_def, srcs[1]->tensor.shape, op_type);
-        node_desc.task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      } else {  // more than 2 inputs
-        auto gpu_op = Add(op_def);
-        node_desc.task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      }
-      break;
-    }
-    case OperationType::CONCAT: {
-      std::vector<BHWC> input_shapes;
-      for (auto& input : graph.FindInputs(node->id)) {
-        input_shapes.push_back(input->tensor.shape);
-      }
-      auto gpu_op = Concat(
-          op_def, absl::any_cast<ConcatAttributes>(node->operation.attributes),
-          input_shapes);
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::CONVOLUTION_2D: {
-      if (graph.FindInputs(node->id).size() != 1) {
-        return absl::UnimplementedError(
-            "Convolution does not support more than 1 runtime tensor");
-      }
-      const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
-      const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape;
-      auto attr =
-          absl::any_cast<Convolution2DAttributes>(node->operation.attributes);
-      if (IsSuitableForWinograd4x4To6x6(attr, dst_shape)) {
-        int tiles_x = DivideRoundUp(dst_shape.w, 4);
-        int tiles_y = DivideRoundUp(dst_shape.h, 4);
-        const BHWC shape_0{src_shape.b, 36, tiles_x * tiles_y, src_shape.c};
-        const BHWC shape_1{src_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
-
-        Winograd4x4To36Attributes wino_up_attr;
-        wino_up_attr.padding = attr.padding;
-        int value_id = *last_value_id + 1;
-        (*tensor_shapes)[value_id] = shape_0;
-        (*tensor_shapes)[value_id + 1] = shape_1;
-        nodes->resize(3);
-        (*nodes)[0].description = "winograd_up_" + std::to_string(node->id);
-        (*nodes)[1].description =
-            node->operation.type + std::to_string(node->id);
-        (*nodes)[2].description = "winograd_down_" + std::to_string(node->id);
-        (*nodes)[0].task =
-            SelectWinograd4x4To36(op_def, wino_up_attr, gpu_info);
-        (*nodes)[0].src_tensors_ids = {inputs[0]};
-        (*nodes)[0].dst_tensors_ids = {static_cast<unsigned int>(value_id)};
-
-        auto gpu_op = ConvolutionWino4x4To6x6(op_def, shape_1, attr, gpu_info);
-        (*nodes)[1].task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-        (*nodes)[1].src_tensors_ids = {static_cast<unsigned int>(value_id)};
-        (*nodes)[1].dst_tensors_ids = {static_cast<unsigned int>(value_id + 1)};
-
-        Winograd36To4x4Attributes wino_down_attr;
-        wino_down_attr.output_shape = dst_shape;
-        wino_down_attr.biases = attr.bias;
-        (*nodes)[2].task =
-            SelectWinograd36To4x4(op_def, wino_down_attr, gpu_info);
-        (*nodes)[2].src_tensors_ids = {static_cast<unsigned int>(value_id + 1)};
-        (*nodes)[2].dst_tensors_ids = {outputs[0]};
-        (*last_value_id) += 2;
-      } else {
-        auto gpu_op = ConvolutionGeneric(op_def, dst_shape, attr, gpu_info);
-        node_desc.task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      }
-      break;
-    }
-    case OperationType::CONVOLUTION_TRANSPOSED:
-      if (graph.FindInputs(node->id).size() != 1) {
-        return absl::UnimplementedError(
-            "Convolution Transposed does not support more than 1 runtime "
-            "tensor");
-      }
-      node_desc.task = SelectConvolutionTransposed(
-          op_def,
-          absl::any_cast<ConvolutionTransposedAttributes>(
-              node->operation.attributes),
-          gpu_info);
-      break;
-    case OperationType::DEPTHWISE_CONVOLUTION:
-      if (graph.FindInputs(node->id).size() != 1) {
-        return absl::UnimplementedError(
-            "DepthWise Convolution does not support more than 1 runtime "
-            "tensor");
-      }
-      node_desc.task = SelectDepthWiseConv(
-          op_def, absl::any_cast<DepthwiseConvolution2DAttributes>(
-                      node->operation.attributes));
-      break;
-    case OperationType::FULLY_CONNECTED: {
-      auto gpu_op = FullyConnected(
-          op_def,
-          absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
-          gpu_info);
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::MAX_UNPOOLING_2D: {
-      auto gpu_op = MaxUnpooling(
-          op_def,
-          absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes));
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::MEAN: {
-      auto attr = absl::any_cast<MeanAttributes>(node->operation.attributes);
-      if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
-        return absl::UnimplementedError("Mean supports HW axis only in Metal");
-      }
-      auto gpu_op = Mean(op_def, attr);
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::MUL:
-      if (inputs.size() == 1) {
-        if (node->operation.attributes.has_value()) {
-          auto attr =
-              absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
-          auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
-              op_def, op_type, attr.param);
-          node_desc.task =
-              std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-        } else {
-          return absl::UnimplementedError(
-              "Missing attributes for single input op: " +
-              node->operation.type);
-        }
-      } else if (inputs.size() == 2) {
-        const auto srcs = graph.FindInputs(node_id);
-        auto gpu_op =
-            ElementwiseWithTwoInputs(op_def, srcs[1]->tensor.shape, op_type);
-        node_desc.task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      }
-      break;
-    case OperationType::PAD: {
-      auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);
-      if (attr.appended.b != 0 || attr.prepended.b != 0) {
-        return absl::UnimplementedError("Padding for BATCH is not supported.");
-      }
-      auto gpu_op = Padding(op_def, attr);
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::POOLING_2D: {
-      auto attr =
-          absl::any_cast<Pooling2DAttributes>(node->operation.attributes);
-      op_def.dst_tensors = {tensor_descriptor};
-      auto gpu_op = Pooling(op_def, attr, false);
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      node_desc.dst_tensors_ids = {outputs[0]};
-      if (attr.type == PoolingType::MAX && attr.output_indices) {
-        auto gpu_ind_op = Pooling(op_def, attr, true);
-        nodes->push_back({});
-        nodes->back().description =
-            node->operation.type + "_indices_" + std::to_string(node->id);
-        nodes->back().task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_ind_op));
-        nodes->back().src_tensors_ids = {inputs[0]};
-        nodes->back().dst_tensors_ids = {outputs[1]};
-      }
-      break;
-    }
-    case OperationType::PRELU: {
-      const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
-      node_desc.task = SelectPReLU(
-          op_def, src_shape,
-          absl::any_cast<PReLUAttributes>(node->operation.attributes));
-      break;
-    }
-    case OperationType::RELU: {
-      auto gpu_op = ReLU(
-          op_def, absl::any_cast<ReLUAttributes>(node->operation.attributes));
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::QUANTIZE_AND_DEQUANTIZE:
-      node_desc.task = SelectQuantizeAndDequantize(
-          op_def, absl::any_cast<QuantizeAndDequantizeAttributes>(
-                      node->operation.attributes));
-      break;
-    case OperationType::RESHAPE: {
-      const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
-      node_desc.task = SelectReshape(
-          op_def, src_shape,
-          absl::any_cast<ReshapeAttributes>(node->operation.attributes));
-      break;
-    }
-    case OperationType::RESIZE: {
-      auto gpu_op = Resize(op_def, absl::any_cast<Resize2DAttributes>(
-                                       node->operation.attributes));
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::SLICE: {
-      auto gpu_op = Slice(
-          op_def, absl::any_cast<SliceAttributes>(node->operation.attributes));
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::SOFTMAX: {
-      auto attr = absl::any_cast<SoftmaxAttributes>(node->operation.attributes);
-      if (attr.axis != Axis::CHANNELS) {
-        return absl::UnimplementedError(
-            "Softmax supports only CHANNELS dimension");
-      }
-      const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
-      node_desc.task = SelectSoftmax(op_def, src_shape, gpu_info);
-      break;
-    }
-    case OperationType::SPACE_TO_DEPTH:
-      node_desc.task = SelectSpaceToDepth(
-          op_def,
-          absl::any_cast<SpaceToDepthAttributes>(node->operation.attributes));
-      break;
-    case OperationType::ABS:
-    case OperationType::COPY:
-    case OperationType::COS:
-    case OperationType::ELU:
-    case OperationType::EXP:
-    case OperationType::HARD_SWISH:
-    case OperationType::LOG:
-    case OperationType::NEG:
-    case OperationType::RSQRT:
-    case OperationType::SIGMOID:
-    case OperationType::SIN:
-    case OperationType::SQRT:
-    case OperationType::SQUARE:
-    case OperationType::TANH: {
-      auto gpu_op = ElementwiseWithOneInput(op_def, op_type);
-      node_desc.task =
-          std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      break;
-    }
-    case OperationType::DIV:
-    case OperationType::MAXIMUM:
-    case OperationType::MINIMUM:
-    case OperationType::POW:
-    case OperationType::SQUARED_DIFF:
-    case OperationType::SUB: {
-      if (inputs.size() == 1) {
-        if (node->operation.attributes.has_value()) {
-          auto attr =
-              absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
-          auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
-              op_def, op_type, attr.param);
-          node_desc.task =
-              std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-        } else {
-          return absl::UnimplementedError(
-              "Missing attributes for single input op: " +
-              node->operation.type);
-        }
-      } else if (inputs.size() == 2) {
-        const auto srcs = graph.FindInputs(node_id);
-        auto gpu_op =
-            ElementwiseWithTwoInputs(op_def, srcs[1]->tensor.shape, op_type);
-        node_desc.task =
-            std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
-      }
-    } break;
-    case OperationType::BATCH_NORMALIZATION:
-    case OperationType::BATCH_TO_SPACE:
-    case OperationType::BATCHED_MATMUL:
-    case OperationType::CONST:
-    case OperationType::LSTM:
-    // TODO(b/162763635): implement MeanStddevNormalization for Metal.
-    case OperationType::MEAN_STDDEV_NORMALIZATION:
-    case OperationType::REDUCE_MAXIMUM:
-    case OperationType::REDUCE_MINIMUM:
-    case OperationType::REDUCE_PRODUCT:
-    case OperationType::REDUCE_SUM:
-    // comparison operations
-    case OperationType::LESS:
-    case OperationType::LESS_EQUAL:
-    case OperationType::EQUAL:
-    case OperationType::NOT_EQUAL:
-    case OperationType::GREATER:
-    case OperationType::GREATER_EQUAL:
-    case OperationType::SPACE_TO_BATCH:
-    case OperationType::TRANSPOSE:
-    case OperationType::UNKNOWN:
-      return absl::UnimplementedError("Unsupported op: " +
-                                      node->operation.type);
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace
-
-absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
-                     const RuntimeOptions& options,
-                     CompiledModel* compiled_model) {
   int last_value_id = 0;
   for (const auto& value : graph.values()) {
     compiled_model->tensor_shapes[value->id] = value->tensor.shape;
@@ -545,43 +49,54 @@
   }
   int node_linear_id = 0;
   for (const auto& node : graph.nodes()) {
-    std::vector<ValueId> inputs;
-    for (auto& input : graph.FindInputs(node->id)) {
-      inputs.push_back(static_cast<ValueId>(input->id));
+    auto inputs = graph.FindInputs(node->id);
+    auto outputs = graph.FindOutputs(node->id);
+    DataType data_type = DeduceDataTypeFromPrecision(precision);
+    TensorDescriptor tensor_descriptor =
+        TensorDescriptor{data_type, TensorStorageType::BUFFER, Layout::HWC};
+    OperationDef op_def;
+    op_def.precision = precision;
+    for (int j = 0; j < inputs.size(); ++j) {
+      op_def.src_tensors.push_back(tensor_descriptor);
     }
-    std::vector<ValueId> outputs;
-    for (auto& output : graph.FindOutputs(node->id)) {
-      outputs.push_back(static_cast<ValueId>(output->id));
+    for (int j = 0; j < outputs.size(); ++j) {
+      op_def.dst_tensors.push_back(tensor_descriptor);
     }
-    std::vector<NodeDescriptor> node_descs;
-    std::vector<ComputeTaskDescriptorPtr> custom_tasks;
-    auto custom_status =
-        RegisterCustomOps(graph, node, inputs, outputs, options, &custom_tasks);
-    if (!custom_status.ok()) {
-      auto primary_status = RegisterPrimaryOps(
-          graph, node, inputs, outputs, gpu_info, options, &last_value_id,
-          &compiled_model->tensor_shapes, &node_descs);
-      if (!primary_status.ok()) {
-        return absl::UnimplementedError(
-            absl::Substitute("Unsupported op type: $0; custom registry error: "
-                             "$1; primary registry error: $2;",
-                             node->operation.type, custom_status.message(),
-                             primary_status.message()));
+    GPUOperationsSubgraph gpu_subgraph;
+    RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, inputs, outputs,
+                                         *node, &gpu_subgraph));
+    std::map<int, ValueId> mapping_to_global_ids;
+    for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
+      const auto& t = gpu_subgraph.new_tensors[j];
+      last_value_id++;
+      compiled_model->tensor_shapes[last_value_id] = t.first;
+      mapping_to_global_ids[j] = last_value_id;
+    }
+    for (auto& gpu_op : gpu_subgraph.operations) {
+      NodeDescriptor metal_node;
+      metal_node.task = std::move(gpu_op.operation);
+      metal_node.src_tensors_ids.resize(gpu_op.input_ids.size());
+      for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
+        int id = gpu_op.input_ids[j];
+        if (id >= 0) {
+          metal_node.src_tensors_ids[j] = id;
+        } else {
+          metal_node.src_tensors_ids[j] = mapping_to_global_ids[-(id + 1)];
+        }
       }
-    } else {
-      for (auto& custom_task : custom_tasks) {
-        NodeDescriptor node_desc;
-        node_desc.task = custom_task;
-        node_desc.description =
-            node->operation.type + "_" + std::to_string(node->id);
-        node_desc.src_tensors_ids = inputs;
-        node_desc.dst_tensors_ids = outputs;
-        node_descs.push_back(node_desc);
+      metal_node.dst_tensors_ids.resize(gpu_op.output_ids.size());
+      for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
+        int id = gpu_op.output_ids[j];
+        if (id >= 0) {
+          metal_node.dst_tensors_ids[j] = id;
+        } else {
+          metal_node.dst_tensors_ids[j] = mapping_to_global_ids[-(id + 1)];
+        }
       }
-    }
-    for (auto& node_desc : node_descs) {
-      node_desc.id = node_linear_id++;
-      compiled_model->nodes.push_back(node_desc);
+      metal_node.description =
+          node->operation.type + " " + std::to_string(node->id);
+      metal_node.id = node_linear_id++;
+      compiled_model->nodes.push_back(std::move(metal_node));
     }
   }
   return absl::OkStatus();
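Note on the id remapping above: GPUOperationFromNode may emit extra internal tensors in gpu_subgraph.new_tensors, and its operations reference them with negative ids, so -(id + 1) indexes new_tensors while non-negative ids are already global ValueIds. A minimal standalone sketch of that convention (ResolveTensorId and the sample mapping are illustrative names, not part of this patch):

#include <cassert>
#include <map>

// Mirrors the lookup in Compile(): ids >= 0 are global tensor ids,
// id == -1 refers to new_tensors[0], id == -2 to new_tensors[1], and so on.
int ResolveTensorId(int id, const std::map<int, int>& new_to_global) {
  if (id >= 0) return id;
  return new_to_global.at(-(id + 1));
}

int main() {
  const std::map<int, int> new_to_global = {{0, 17}, {1, 18}};
  assert(ResolveTensorId(5, new_to_global) == 5);    // existing graph tensor
  assert(ResolveTensorId(-1, new_to_global) == 17);  // first internal tensor
  assert(ResolveTensorId(-2, new_to_global) == 18);  // second internal tensor
  return 0;
}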
diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h
index f7cdfa4..a2ef5c2 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.h
+++ b/tensorflow/lite/delegates/gpu/metal/api.h
@@ -18,9 +18,9 @@
 
 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
 namespace gpu {
@@ -28,7 +28,7 @@
 
 // Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions.
 absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
-                     const RuntimeOptions& options,
+                     CalculationsPrecision precision,
                      CompiledModel* compiled_model);
 
 }  // namespace metal
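With RuntimeOptions gone, callers pass the shared CalculationsPrecision enum directly. A hedged usage sketch of the new signature (BuildCompiledModel is an illustrative wrapper; graph and gpu_info are assumed to be prepared elsewhere):

#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/metal/api.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"

// Compiles the graph with FP16 storage and FP32 accumulation.
absl::Status BuildCompiledModel(const tflite::gpu::GraphFloat32& graph,
                                const tflite::gpu::GpuInfo& gpu_info,
                                tflite::gpu::metal::CompiledModel* model) {
  return tflite::gpu::metal::Compile(
      graph, gpu_info, tflite::gpu::CalculationsPrecision::F32_F16, model);
}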
diff --git a/tensorflow/lite/delegates/gpu/metal/buffer.mm b/tensorflow/lite/delegates/gpu/metal/buffer.cc
similarity index 86%
rename from tensorflow/lite/delegates/gpu/metal/buffer.mm
rename to tensorflow/lite/delegates/gpu/metal/buffer.cc
index 72a11e4..99dc052 100644
--- a/tensorflow/lite/delegates/gpu/metal/buffer.mm
+++ b/tensorflow/lite/delegates/gpu/metal/buffer.cc
@@ -62,18 +62,18 @@
                                                 id<MTLDevice> device) {
   size_ = desc.size;
   if (desc.data.empty()) {
-    buffer_ = [device newBufferWithLength:size_
-                                         options:MTLResourceStorageModeShared];
+    buffer_ =
+        [device newBufferWithLength:size_ options:MTLResourceStorageModeShared];
   } else {
     buffer_ = [device newBufferWithBytes:desc.data.data()
-                                 length:size_
-                                options:MTLResourceStorageModeShared];
+                                  length:size_
+                                 options:MTLResourceStorageModeShared];
   }
   return absl::OkStatus();
 }
 
 absl::Status CreateBuffer(size_t size_in_bytes, const void* data,
-                                  id<MTLDevice> device, Buffer* result) {
+                          id<MTLDevice> device, Buffer* result) {
   id<MTLBuffer> buffer;
   if (data) {
     buffer = [device newBufferWithBytes:data
@@ -81,7 +81,7 @@
                                 options:MTLResourceStorageModeShared];
   } else {
     buffer = [device newBufferWithLength:size_in_bytes
-                                         options:MTLResourceStorageModeShared];
+                                 options:MTLResourceStorageModeShared];
   }
 
   *result = Buffer(buffer, size_in_bytes);
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
new file mode 100644
index 0000000..d3919db
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
@@ -0,0 +1,243 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
+
+#include <Availability.h>
+#include <string>
+#include <tuple>
+
+#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/metal/common.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
+                                            const NodeDescriptor& desc,
+                                            CalculationsPrecision precision) {
+  size_t offset = desc.task->src_tensors_names.size() +
+                  desc.task->uniform_buffers.size() +
+                  desc.task->immutable_buffers.size() + 1;
+  RETURN_IF_ERROR(metal_args_.Init(device, offset, &desc.task->args,
+                                   &desc.task->shader_source));
+  NSString* barrier;
+  // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language
+  // version 2.0
+  if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
+    barrier = @"simdgroup_barrier";
+  } else {
+    barrier = @"threadgroup_barrier";
+  }
+  NSString* storageType;
+  NSString* accumulatorType;
+  NSString* toAccumulatorType = @"";
+  NSString* toAccumulatorType2 = @"";
+  NSString* toAccumulatorType3 = @"";
+  NSString* toAccumulatorType4 = @"";
+  if (precision == CalculationsPrecision::F32) {
+    storageType = @"float";
+    accumulatorType = @"float";
+  } else {
+    // FP16
+    storageType = @"half";
+    if (precision == CalculationsPrecision::F32_F16) {
+      accumulatorType = @"float";
+      toAccumulatorType = @"float";
+      toAccumulatorType2 = @"float2";
+      toAccumulatorType3 = @"float3";
+      toAccumulatorType4 = @"float4";
+    } else {
+      accumulatorType = @"half";
+    }
+  }
+  NSDictionary<NSString*, NSString*>* macros = @{
+    @"FLT" : storageType,
+    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
+    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
+    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
+    @"ACCUM_FLT" : accumulatorType,
+    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
+    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
+    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
+    @"TO_ACCUM_TYPE" : toAccumulatorType,
+    @"TO_ACCUM2_TYPE" : toAccumulatorType2,
+    @"TO_ACCUM3_TYPE" : toAccumulatorType3,
+    @"TO_ACCUM4_TYPE" : toAccumulatorType4,
+    @"SIMDGROUP_BARRIER" : barrier,
+  };
+
+  NSString* code =
+      [NSString stringWithCString:desc.task->shader_source.c_str()
+                         encoding:[NSString defaultCStringEncoding]];
+  id<MTLComputePipelineState> program;
+  RETURN_IF_ERROR(
+      CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
+  if (!program) {
+    return absl::InternalError("Unknown shader compilation error");
+  }
+  for (auto& id : desc.src_tensors_ids) {
+    input_buffers_.emplace_back(InputBuffer{id, nil});
+  }
+  for (auto& uniform : desc.task->uniform_buffers) {
+    uniform_buffers_.emplace_back(UniformBuffer{{}, uniform.data_function});
+  }
+  output_buffers_.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
+  const bool f32_storage = precision == CalculationsPrecision::F32;
+  for (auto& immutable : desc.task->immutable_buffers) {
+    int padding = 4 * (f32_storage ? sizeof(float) : sizeof(HalfBits));
+    int paddedSize = AlignByN(immutable.data.size(), padding);
+    immutable.data.resize(paddedSize);
+    id<MTLBuffer> metalBuffer =
+        [device newBufferWithBytes:immutable.data.data()
+                            length:immutable.data.size()
+                           options:MTLResourceStorageModeShared];
+    immutable_buffers_.emplace_back(metalBuffer);
+  }
+  resize_function_ = desc.task->resize_function;
+  program_ = program;
+  return absl::OkStatus();
+}
+
+absl::Status ComputeTask::UpdateParamsWithDevice(
+      id<MTLDevice> device, const std::map<ValueId, BHWC>& tensor_shapes) {
+  std::vector<BHWC> src_shapes;
+  std::vector<BHWC> dst_shapes;
+  for (const auto& in_buf : input_buffers_) {
+    auto it = tensor_shapes.find(in_buf.uid);
+    if (it == tensor_shapes.end()) {
+      return absl::InvalidArgumentError("Missing tensor shape");
+    }
+    src_shapes.push_back(it->second);
+  }
+  for (const auto& out_buf : output_buffers_) {
+    auto it = tensor_shapes.find(out_buf.uid);
+    if (it == tensor_shapes.end()) {
+      return absl::InvalidArgumentError("Missing tensor shape");
+    }
+    dst_shapes.push_back(it->second);
+  }
+  for (auto& uniform : uniform_buffers_) {
+    uniform.data = uniform.data_function(src_shapes, dst_shapes);
+  }
+
+  // Dispatch parameters re-calculation
+  auto workGroups = resize_function_(src_shapes, dst_shapes);
+  groups_size_ = workGroups.first;
+  MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
+  if (groups_size_.x > threadsPerGroup.width ||
+      groups_size_.y > threadsPerGroup.height ||
+      groups_size_.z > threadsPerGroup.depth) {
+    std::string error("Threads per working group: ");
+    error += std::to_string(groups_size_.x) + ", " +
+             std::to_string(groups_size_.y) + ", " +
+             std::to_string(groups_size_.z);
+    error += "is larger than the MTLDevice can support: ";
+    error += std::to_string(threadsPerGroup.width) + ", " +
+             std::to_string(threadsPerGroup.height) + ", " +
+             std::to_string(threadsPerGroup.depth);
+    return absl::InvalidArgumentError(error);
+  }
+  groups_count_ = workGroups.second;
+  return absl::OkStatus();
+}
+
+bool ComputeTask::HasInOutIds(const std::set<ValueId>& ids) const {
+  for (auto& buffer : input_buffers_) {
+    if (ids.count(buffer.uid)) {
+      return true;
+    }
+  }
+  for (auto& buffer : output_buffers_) {
+    if (ids.count(buffer.uid)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void ComputeTask::EncodeWithEncoder(id<MTLComputeCommandEncoder> encoder) {
+  // The dispatch call is intended to be skipped.
+  if (groups_count_.x * groups_count_.y * groups_count_.z == 0) {
+    return;
+  }
+
+  [encoder setComputePipelineState:program_];
+
+  int bindIndex = 0;
+  for (const auto& buffer : output_buffers_) {
+    [encoder setBuffer:buffer.metal_handle offset:0 atIndex:bindIndex];
+    bindIndex++;
+  }
+  for (const auto& buffer : input_buffers_) {
+    [encoder setBuffer:buffer.metal_handle offset:0 atIndex:bindIndex];
+    bindIndex++;
+  }
+  for (auto& immutable : immutable_buffers_) {
+    [encoder setBuffer:immutable offset:0 atIndex:bindIndex];
+    bindIndex++;
+  }
+  for (auto& uniform : uniform_buffers_) {
+    [encoder setBytes:uniform.data.data()
+               length:uniform.data.size()
+              atIndex:bindIndex];
+    bindIndex++;
+  }
+  metal_args_.Encode(encoder, bindIndex);
+
+  MTLSize groupsCount =
+      MTLSizeMake(groups_count_.x, groups_count_.y, groups_count_.z);
+  MTLSize groupsSize =
+      MTLSizeMake(groups_size_.x, groups_size_.y, groups_size_.z);
+  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
+}
+
+std::vector<ValueId> ComputeTask::GetOutputIds() const {
+  std::vector<tflite::gpu::ValueId> result;
+  for (auto& buffer : output_buffers_) {
+    result.push_back(buffer.uid);
+  }
+  return result;
+}
+
+std::vector<ValueId> ComputeTask::GetInputIds() const {
+  std::vector<tflite::gpu::ValueId> result;
+  for (auto& buffer : input_buffers_) {
+    result.push_back(buffer.uid);
+  }
+  return result;
+}
+
+void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
+  input_buffers_[index].metal_handle = tensor.GetBufferHandle();
+}
+
+void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
+  output_buffers_[index].metal_handle = tensor.GetBufferHandle();
+}
+
+void ComputeTask::SetDescription(const std::string& description) {
+  description_ = description;
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
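For reference, the precision handling above reduces to a small type table: F32 maps both the FLT* and ACCUM_FLT* macros to float types, F32_F16 stores in half but accumulates in float (so TO_ACCUM*_TYPE becomes an explicit float conversion), and F16 uses half throughout. A toy sketch of that mapping, with illustrative names only:

#include <string>

struct MacroTypes {
  std::string flt4;        // what FLT4 expands to
  std::string accum_flt4;  // what ACCUM_FLT4 expands to
};

MacroTypes SelectMacroTypes(bool fp32_storage, bool fp32_accumulation) {
  if (fp32_storage) return {"float4", "float4"};      // F32
  if (fp32_accumulation) return {"half4", "float4"};  // F32_F16
  return {"half4", "half4"};                          // F16
}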
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.h b/tensorflow/lite/delegates/gpu/metal/compute_task.h
index b3c32f4..aeb0732 100644
--- a/tensorflow/lite/delegates/gpu/metal/compute_task.h
+++ b/tensorflow/lite/delegates/gpu/metal/compute_task.h
@@ -24,50 +24,80 @@
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/metal/common.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
 
-@interface TFLComputeTask : NSObject
+namespace tflite {
+namespace gpu {
+namespace metal {
 
-/// Returns empty string or error if shader can't be compiled.
-- (absl::Status)compileWithDevice:(id<MTLDevice>)device
-                   taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
-                   runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options;
+class ComputeTask {
+ public:
+  ComputeTask() = default;
 
-/// Updates parameters for inputs/outputs/intermediate tensors
-- (absl::Status)updateParamsWithDevice:(id<MTLDevice>)device
-                          tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)
-                                           tensorShapes;
+  // Move only
+  ComputeTask(ComputeTask&& args) = default;
+  ComputeTask& operator=(ComputeTask&& args) = default;
+  ComputeTask(const ComputeTask&) = delete;
+  ComputeTask& operator=(const ComputeTask&) = delete;
 
-/// Updates buffers for intermediate tensors only. Returns error if out of memory or a buffer is
-/// larger than MTLDevice can support.
-/// @param buffers is a map from intermediate tensors' ValueId to metal handles with corresponding
-///        buffers.
-/// @param outputIDs must match the output of added operations.
-/// @param usageRecordIds is a map from intermediate tensors' ValueId to corresponding tensor usage
-/// records ids.
-/// @param sharedBufferIds contain shared buffer id for each tensor usage record id.
-/// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id.
-/// TODO(ypisarchyk): probably we can decrease the number of parameters here
-- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
-                    outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
-               usageRecordIds:(const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds
-              sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
-                sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers;
+  /// Returns an error if the shader can't be compiled.
+  absl::Status CompileWithDevice(id<MTLDevice> device,
+                                 const NodeDescriptor& desc,
+                                 CalculationsPrecision precision);
 
-- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids;
+  /// Updates parameters for inputs/outputs/intermediate tensors
+  absl::Status UpdateParamsWithDevice(
+      id<MTLDevice> device, const std::map<ValueId, BHWC>& tensor_shapes);
 
-- (void)updateBuffers:(const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers;
+  bool HasInOutIds(const std::set<ValueId>& ids) const;
 
-- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder;
+  void EncodeWithEncoder(id<MTLComputeCommandEncoder> encoder);
 
-- (std::vector<tflite::gpu::ValueId>)getOutputIds;
-- (std::vector<tflite::gpu::ValueId>)getInputIds;
+  std::vector<ValueId> GetOutputIds() const;
+  std::vector<ValueId> GetInputIds() const;
 
-- (void)setDescription:(const std::string&)description;
+  void SetSrcTensor(const MetalSpatialTensor& tensor, int index);
 
-@end
+  void SetDstTensor(const MetalSpatialTensor& tensor, int index);
+
+  void SetDescription(const std::string& description);
+
+ private:
+  struct InputBuffer {
+    ValueId uid;
+    id<MTLBuffer> metal_handle;
+  };
+
+  struct OutputBuffer {
+    ValueId uid;
+    id<MTLBuffer> metal_handle;
+  };
+
+  struct UniformBuffer {
+    std::vector<uint8_t> data;
+    UniformsFunction data_function;
+  };
+
+  id<MTLComputePipelineState> program_;
+  std::vector<InputBuffer> input_buffers_;
+  std::vector<OutputBuffer> output_buffers_;
+  std::vector<id<MTLBuffer>> immutable_buffers_;
+  std::vector<UniformBuffer> uniform_buffers_;
+  uint3 groups_size_;
+  uint3 groups_count_;
+  DispatchParamsFunction resize_function_;
+  std::string description_;
+  MetalArguments metal_args_;
+};
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm
deleted file mode 100644
index 62a6a61..0000000
--- a/tensorflow/lite/delegates/gpu/metal/compute_task.mm
+++ /dev/null
@@ -1,301 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
-
-#include <Availability.h>
-#include <string>
-#include <tuple>
-
-#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/shape.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-#include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/metal/common.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
-
-using ::tflite::gpu::AlignByN;
-using ::tflite::gpu::BHWC;
-using ::tflite::gpu::HalfBits;
-using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
-using ::tflite::gpu::metal::CreateComputeProgram;
-using ::tflite::gpu::metal::DispatchParamsFunction;
-using ::tflite::gpu::metal::RuntimeOptions;
-using ::tflite::gpu::metal::UniformsFunction;
-using ::tflite::gpu::uint3;
-using ::tflite::gpu::ValueId;
-
-namespace {
-
-struct InputBuffer {
-  ValueId uid;
-  id<MTLBuffer> metalHandle;
-};
-
-struct OutputBuffer {
-  ValueId uid;
-  id<MTLBuffer> metalHandle;
-};
-
-struct UniformBuffer {
-  std::vector<uint8_t> data;
-  UniformsFunction dataFunction;
-};
-
-}  // namespace
-
-@implementation TFLComputeTask {
-  id<MTLComputePipelineState> _program;
-  std::vector<InputBuffer> _inputBuffers;
-  std::vector<OutputBuffer> _outputBuffers;
-  std::vector<id<MTLBuffer>> _immutableBuffers;
-  std::vector<UniformBuffer> _uniformBuffers;
-  uint3 _groupsSize;
-  uint3 _groupsCount;
-  DispatchParamsFunction _resizeFunction;
-  std::string _description;
-  tflite::gpu::metal::MetalArguments _metal_args;
-}
-
-- (absl::Status)compileWithDevice:(id<MTLDevice>)device
-                   taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
-                   runtimeOptions:(const RuntimeOptions&)options {
-  size_t offset = desc.task->src_tensors_names.size() + desc.task->uniform_buffers.size()
-                  + desc.task->immutable_buffers.size() + 1;
-  RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc.task->args, &desc.task->shader_source));
-  NSString* barrier;
-  // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
-  if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
-    barrier = @"simdgroup_barrier";
-  } else {
-    barrier = @"threadgroup_barrier";
-  }
-  NSString* storageType;
-  NSString* accumulatorType;
-  NSString* toAccumulatorType = @"";
-  NSString* toAccumulatorType2 = @"";
-  NSString* toAccumulatorType3 = @"";
-  NSString* toAccumulatorType4 = @"";
-  if (options.storage_precision == RuntimeOptions::Precision::FP32) {
-    storageType = @"float";
-    accumulatorType = @"float";
-  } else {
-    // FP16
-    storageType = @"half";
-    if (options.accumulator_precision == RuntimeOptions::Precision::FP32) {
-      accumulatorType = @"float";
-      toAccumulatorType = @"float";
-      toAccumulatorType2 = @"float2";
-      toAccumulatorType3 = @"float3";
-      toAccumulatorType4 = @"float4";
-    } else {
-      accumulatorType = @"half";
-    }
-  }
-  NSDictionary<NSString*, NSString*>* macros = @{
-    @"FLT" : storageType,
-    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
-    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
-    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
-    @"ACCUM_FLT" : accumulatorType,
-    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
-    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
-    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
-    @"TO_ACCUM_TYPE" : toAccumulatorType,
-    @"TO_ACCUM2_TYPE" : toAccumulatorType2,
-    @"TO_ACCUM3_TYPE" : toAccumulatorType3,
-    @"TO_ACCUM4_TYPE" : toAccumulatorType4,
-    @"SIMDGROUP_BARRIER" : barrier,
-  };
-
-  NSString* code = [NSString stringWithCString:desc.task->shader_source.c_str()
-                                      encoding:[NSString defaultCStringEncoding]];
-  id<MTLComputePipelineState> program;
-  RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
-  if (!program) {
-    return absl::InternalError("Unknown shader compilation error");
-  }
-  for (auto& id : desc.src_tensors_ids) {
-    _inputBuffers.emplace_back(InputBuffer{id, nil});
-  }
-  for (auto& uniform : desc.task->uniform_buffers) {
-    _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
-  }
-  _outputBuffers.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
-  for (auto& immutable : desc.task->immutable_buffers) {
-    int padding =
-        4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float)
-                                                                          : sizeof(HalfBits));
-    int paddedSize = AlignByN(immutable.data.size(), padding);
-    immutable.data.resize(paddedSize);
-    id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
-                                                    length:immutable.data.size()
-                                                   options:MTLResourceStorageModeShared];
-    _immutableBuffers.emplace_back(metalBuffer);
-  }
-  _resizeFunction = desc.task->resize_function;
-  _program = program;
-  return absl::OkStatus();
-}
-
-- (absl::Status)
-    updateParamsWithDevice:(id<MTLDevice>)device
-              tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)tensorShapes {
-  std::vector<BHWC> src_shapes;
-  std::vector<BHWC> dst_shapes;
-  for (const auto& in_buf : _inputBuffers) {
-    auto it = tensorShapes.find(in_buf.uid);
-    if (it == tensorShapes.end()) {
-      return absl::InvalidArgumentError("Missing tensor shape");
-    }
-    src_shapes.push_back(it->second);
-  }
-  for (const auto& out_buf : _outputBuffers) {
-    auto it = tensorShapes.find(out_buf.uid);
-    if (it == tensorShapes.end()) {
-      return absl::InvalidArgumentError("Missing tensor shape");
-    }
-    dst_shapes.push_back(it->second);
-  }
-  for (auto& uniform : _uniformBuffers) {
-    uniform.data = uniform.dataFunction(src_shapes, dst_shapes);
-  }
-
-  // Dispatch parameters re-calculation
-  auto workGroups = _resizeFunction(src_shapes, dst_shapes);
-  _groupsSize = workGroups.first;
-  MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
-  if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height ||
-      _groupsSize.z > threadsPerGroup.depth) {
-    std::string error("Threads per working group: ");
-    error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " +
-             std::to_string(_groupsSize.z);
-    error += "is larger than the MTLDevice can support: ";
-    error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
-             ", " + std::to_string(threadsPerGroup.depth);
-    return absl::InvalidArgumentError(error);
-  }
-  _groupsCount = workGroups.second;
-  return absl::OkStatus();
-}
-
-- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
-                    outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
-               usageRecordIds:(const std::map<ValueId, size_t>&)usageRecordIds
-              sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
-                sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers {
-  for (auto& buffer : _outputBuffers) {
-    // If the buffer is intermediate: set its metalHandle from sharedBuffers
-    if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) == outputIds.end()) {
-      auto usageRecordIt = usageRecordIds.find(buffer.uid);
-      if (usageRecordIt == usageRecordIds.end()) {
-        return absl::InternalError("TensorUsageRecord for intermediate tensor is not found.");
-      }
-      buffer.metalHandle = sharedBuffers.at(sharedBufferIds.at(usageRecordIt->second));
-      (*buffers)[buffer.uid] = buffer.metalHandle;
-    }
-  }
-
-  // Re-assign input buffers
-  for (auto& buffer : _inputBuffers) {
-    buffer.metalHandle = (*buffers)[buffer.uid];
-  }
-  return absl::OkStatus();
-}
-
-- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids {
-  for (auto& buffer : _inputBuffers) {
-    if (ids.count(buffer.uid)) {
-      return true;
-    }
-  }
-  for (auto& buffer : _outputBuffers) {
-    if (ids.count(buffer.uid)) {
-      return true;
-    }
-  }
-  return false;
-}
-
-- (void)updateBuffers:(const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers {
-  for (auto& buffer : _inputBuffers) {
-    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
-    if (externalBuffer != inputOutputBuffers.end()) {
-      buffer.metalHandle = externalBuffer->second;
-    }
-  }
-  for (auto& buffer : _outputBuffers) {
-    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
-    if (externalBuffer != inputOutputBuffers.end()) {
-      buffer.metalHandle = externalBuffer->second;
-    }
-  }
-}
-
-- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder {
-  // The dispatch call is intended to be skipped.
-  if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) {
-    return;
-  }
-
-  [encoder setComputePipelineState:_program];
-
-  int bindIndex = 0;
-  for (const auto& buffer : _outputBuffers) {
-    [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
-    bindIndex++;
-  }
-  for (const auto& buffer : _inputBuffers) {
-    [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
-    bindIndex++;
-  }
-  for (auto& immutable : _immutableBuffers) {
-    [encoder setBuffer:immutable offset:0 atIndex:bindIndex];
-    bindIndex++;
-  }
-  for (auto& uniform : _uniformBuffers) {
-    [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
-    bindIndex++;
-  }
-  _metal_args.Encode(encoder, bindIndex);
-
-  MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
-  MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
-  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
-}
-
-- (std::vector<tflite::gpu::ValueId>)getOutputIds {
-  std::vector<tflite::gpu::ValueId> result;
-  for (auto& buffer : _outputBuffers) {
-    result.push_back(buffer.uid);
-  }
-  return result;
-}
-
-- (std::vector<tflite::gpu::ValueId>)getInputIds {
-  std::vector<tflite::gpu::ValueId> result;
-  for (auto& buffer : _inputBuffers) {
-    result.push_back(buffer.uid);
-  }
-  return result;
-}
-
-- (void)setDescription:(const std::string&)description {
-  _description = description;
-}
-
-@end
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
new file mode 100644
index 0000000..76e6782
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
@@ -0,0 +1,308 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
+
+#include <map>
+#include <vector>
+
+#include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+namespace {
+void AddUsage(ValueId id, int task_index,
+              std::map<ValueId, int2>* usage_records) {
+  auto it = usage_records->find(id);
+  if (it == usage_records->end()) {
+    // initializing start index(.x) and end index(.y)
+    (*usage_records)[id].x = task_index;
+    (*usage_records)[id].y = task_index;
+  } else {
+    // updating end index(.y)
+    (*usage_records)[id].y = task_index;
+  }
+}
+}  // namespace
+
+absl::Status InferenceContext::CompileModelWithDevice(
+    id<MTLDevice> device, const CompiledModel& compiled_model,
+    const std::vector<ValueId>& input_ids,
+    const std::vector<ValueId>& output_ids, CalculationsPrecision precision) {
+  input_ids_ = input_ids;
+  output_ids_ = output_ids;
+  precision_ = precision;
+  // Metal resources are created here.
+  for (const auto& node : compiled_model.nodes) {
+    ComputeTask task;
+    RETURN_IF_ERROR(task.CompileWithDevice(device, node, precision_));
+    task.SetDescription(node.description);
+    compute_tasks_.emplace_back(std::move(task));
+  }
+  tensor_shapes_ = compiled_model.tensor_shapes;
+  for (auto& task : compute_tasks_) {
+    // The same device must be used here as at the shader compilation stage.
+    RETURN_IF_ERROR(task.UpdateParamsWithDevice(device, tensor_shapes_));
+  }
+  RETURN_IF_ERROR(AllocateTensors(device));
+  return absl::OkStatus();
+}
+
+absl::Status InferenceContext::AllocateTensors(id<MTLDevice> device) {
+  std::set<ValueId> preallocated_ids;
+  for (auto tensor_id : input_ids_) {
+    preallocated_ids.insert(tensor_id);
+  }
+  for (const auto& outputId : output_ids_) {
+    preallocated_ids.insert(outputId);
+  }
+  for (int i = 0; i < compute_tasks_.size(); ++i) {
+    auto& task = compute_tasks_[i];
+    if (task.HasInOutIds(preallocated_ids)) {
+      task_ids_with_preallocated_tensors_.push_back(i);
+    }
+  }
+
+  const bool f32_storage = precision_ == CalculationsPrecision::F32;
+  for (auto& tensor_id : preallocated_ids) {
+    BHWC shape = tensor_shapes_[tensor_id];
+    TensorDescriptor descriptor;
+    descriptor.storage_type = TensorStorageType::BUFFER;
+    descriptor.data_type = f32_storage ? DataType::FLOAT32 : DataType::FLOAT16;
+    descriptor.layout = Layout::HWC;
+    preallocated_tensors_[tensor_id] =
+        CreateSharedBufferTensor(nil, shape, descriptor);
+  }
+
+  RETURN_IF_ERROR(AllocateMemoryForBuffers(device));
+  BindTensorsToOperations();
+  return absl::OkStatus();
+}
+
+MetalSpatialTensor* InferenceContext::GetTensor(ValueId tensor_id) {
+  if (preallocated_tensors_.find(tensor_id) != preallocated_tensors_.end()) {
+    return &preallocated_tensors_[tensor_id];
+  } else if (graph_ids_to_shared_buffer_tensors_.find(tensor_id) !=
+             graph_ids_to_shared_buffer_tensors_.end()) {
+    return &shared_buffer_tensors_
+        [graph_ids_to_shared_buffer_tensors_[tensor_id]];
+  }
+  return nullptr;
+}
+
+void InferenceContext::BindTensorsToOperations() {
+  for (auto& task : compute_tasks_) {
+    const auto& src_ids = task.GetInputIds();
+    for (int i = 0; i < src_ids.size(); ++i) {
+      MetalSpatialTensor* tensor = GetTensor(src_ids[i]);
+      task.SetSrcTensor(*tensor, i);
+    }
+    const auto& dst_ids = task.GetOutputIds();
+    for (int i = 0; i < dst_ids.size(); ++i) {
+      MetalSpatialTensor* tensor = GetTensor(dst_ids[i]);
+      task.SetDstTensor(*tensor, i);
+    }
+  }
+}
+
+void InferenceContext::GetUsages(std::map<ValueId, int2>* usages) {
+  for (ValueId in_id : input_ids_) {
+    if (preallocated_tensors_.find(in_id) == preallocated_tensors_.end()) {
+      AddUsage(in_id, 0, usages);
+    }
+  }
+  for (int op_index = 0; op_index < compute_tasks_.size(); ++op_index) {
+    for (auto& tensor_id : compute_tasks_[op_index].GetInputIds()) {
+      if (preallocated_tensors_.find(tensor_id) ==
+          preallocated_tensors_.end()) {
+        AddUsage(tensor_id, op_index, usages);
+      }
+    }
+    for (auto& tensor_id : compute_tasks_[op_index].GetOutputIds()) {
+      if (preallocated_tensors_.find(tensor_id) ==
+          preallocated_tensors_.end()) {
+        AddUsage(tensor_id, op_index, usages);
+      }
+    }
+  }
+  for (ValueId out_id : output_ids_) {
+    if (preallocated_tensors_.find(out_id) == preallocated_tensors_.end()) {
+      AddUsage(out_id, compute_tasks_.size(), usages);
+    }
+  }
+}
+
+absl::Status InferenceContext::AllocateMemoryForBuffers(id<MTLDevice> device) {
+  std::map<ValueId, int2> buffer_usages;
+  GetUsages(&buffer_usages);
+
+  std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
+  for (auto& usage : buffer_usages) {
+    const auto& shape = tensor_shapes_[usage.first];
+    const size_t buffer_size =
+        shape.b * shape.w * shape.h * AlignByN(shape.c, 4);
+    graph_ids_to_shared_buffer_tensors_[usage.first] =
+        buffer_usage_records.size();
+    buffer_usage_records.push_back({buffer_size,
+                                    static_cast<TaskId>(usage.second.x),
+                                    static_cast<TaskId>(usage.second.y)});
+  }
+
+  ObjectsAssignment<size_t> buffer_assignment;
+  RETURN_IF_ERROR(AssignObjectsToTensors(
+      buffer_usage_records, MemoryStrategy::GREEDY_BEST, &buffer_assignment));
+
+  const bool f32_storage = precision_ == CalculationsPrecision::F32;
+  size_t dataTypeSize = f32_storage ? sizeof(float) : sizeof(HalfBits);
+  shared_buffers_.resize(buffer_assignment.object_sizes.size());
+  for (int i = 0; i < buffer_assignment.object_sizes.size(); ++i) {
+    // Initialize metal buffer
+    NSUInteger bufferSize = dataTypeSize * buffer_assignment.object_sizes[i];
+
+#if (defined(__MAC_10_14) &&                               \
+     __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_14) ||    \
+    (defined(__IPHONE_12_0) &&                             \
+     __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_12_0) || \
+    (defined(__TVOS_12_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_12_0)
+    if (bufferSize > [device maxBufferLength]) {
+      std::string error("Tensor id: ");
+      error += std::to_string(buffer_assignment.object_ids[i]) +
+               " with size: " + std::to_string(bufferSize) +
+               " exceeds MTLDevice maxBufferLength: " +
+               std::to_string([device maxBufferLength]);
+      return absl::ResourceExhaustedError(error);
+    }
+#endif
+#if defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12
+    if ([device currentAllocatedSize] + bufferSize >
+        [device recommendedMaxWorkingSetSize]) {
+      std::string error(
+          "Out of memory in MTLBuffer allocation. Currently allocated: ");
+      error += std::to_string([device currentAllocatedSize]);
+      return absl::ResourceExhaustedError(error);
+    }
+#endif
+
+    shared_buffers_[i] =
+        [device newBufferWithLength:bufferSize
+                            options:MTLResourceStorageModeShared];
+  }
+
+  std::vector<bool> created_tensors(buffer_usage_records.size(), false);
+  shared_buffer_tensors_.resize(buffer_usage_records.size());
+  TensorDescriptor descriptor;
+  descriptor.storage_type = TensorStorageType::BUFFER;
+  descriptor.data_type = f32_storage ? DataType::FLOAT32 : DataType::FLOAT16;
+  descriptor.layout = Layout::HWC;
+  for (auto& task : compute_tasks_) {
+    const std::vector<ValueId> input_ids = task.GetInputIds();
+    const std::vector<ValueId> output_ids = task.GetOutputIds();
+    std::vector<ValueId> all_ids = input_ids;
+    all_ids.insert(all_ids.end(), output_ids.begin(), output_ids.end());
+    for (auto& tensor_id : all_ids) {
+      if (preallocated_tensors_.find(tensor_id) != preallocated_tensors_.end())
+        continue;
+      const int tensor_index = graph_ids_to_shared_buffer_tensors_[tensor_id];
+      if (created_tensors[tensor_index]) continue;
+      const auto& shape = tensor_shapes_[tensor_id];
+      const int buffer_index = buffer_assignment.object_ids[tensor_index];
+      shared_buffer_tensors_[tensor_index] = CreateSharedBufferTensor(
+          shared_buffers_[buffer_index], shape, descriptor);
+      created_tensors[tensor_index] = true;
+    }
+  }
+  return absl::OkStatus();
+}
+
+void InferenceContext::EncodeWithEncoder(
+    id<MTLComputeCommandEncoder> command_encoder,
+    const std::map<ValueId, id<MTLBuffer>>& in_out_buffers) {
+  UpdatePreallocatedTensors(in_out_buffers);
+  for (int i = 0; i < compute_tasks_.size(); ++i) {
+    auto& task = compute_tasks_[i];
+    task.EncodeWithEncoder(command_encoder);
+  }
+}
+
+void InferenceContext::EncodeWithCommandBuffer(
+    id<MTLCommandBuffer> command_buffer,
+    const std::map<ValueId, id<MTLBuffer>>& in_out_buffers) {
+  UpdatePreallocatedTensors(in_out_buffers);
+  for (int i = 0; i < compute_tasks_.size(); ++i) {
+    id<MTLComputeCommandEncoder> encoder =
+        [command_buffer computeCommandEncoder];
+    auto& task = compute_tasks_[i];
+    task.EncodeWithEncoder(encoder);
+    [encoder endEncoding];
+  }
+}
+
+void InferenceContext::EncodeWithCommandQueue(
+    id<MTLCommandQueue> command_queue,
+    const std::map<ValueId, id<MTLBuffer>>& in_out_buffers, int flush_period) {
+  UpdatePreallocatedTensors(in_out_buffers);
+  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
+  for (int i = 0; i < compute_tasks_.size(); ++i) {
+    id<MTLComputeCommandEncoder> encoder =
+        [command_buffer computeCommandEncoder];
+    auto& task = compute_tasks_[i];
+    task.EncodeWithEncoder(encoder);
+    [encoder endEncoding];
+    if (i % flush_period == (flush_period - 1)) {
+      [command_buffer commit];
+      command_buffer = [command_queue commandBuffer];
+    }
+  }
+  [command_buffer commit];
+}
+
+void InferenceContext::UpdatePreallocatedTensors(
+    const std::map<ValueId, id<MTLBuffer>>& preallocated) {
+  for (const auto& it : preallocated) {
+    preallocated_tensors_[it.first].SetBufferHandle(it.second);
+  }
+  for (auto& task_index : task_ids_with_preallocated_tensors_) {
+    auto& task = compute_tasks_[task_index];
+    const auto& src_ids = task.GetInputIds();
+    for (int i = 0; i < src_ids.size(); ++i) {
+      const auto& it = preallocated_tensors_.find(src_ids[i]);
+      if (it != preallocated_tensors_.end()) {
+        task.SetSrcTensor(it->second, i);
+      }
+    }
+    const auto& dst_ids = task.GetOutputIds();
+    for (int i = 0; i < dst_ids.size(); ++i) {
+      const auto& it = preallocated_tensors_.find(dst_ids[i]);
+      if (it != preallocated_tensors_.end()) {
+        task.SetDstTensor(it->second, i);
+      }
+    }
+  }
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h
index c215a91..36bc5bc 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.h
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h
@@ -23,11 +23,17 @@
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
 
 /// Stages of model preprocessing:
 /// 1. Operations' initialization. All operations are initialized and added into
@@ -39,38 +45,97 @@
 /// 3. GPU compute tasks generation. Shader code generation happens here.
 /// 4. Intermediate resource allocation.
 /// Inference.
-@interface TFLInferenceContext : NSObject
 
-/// Compiles model: groups operations to be fused; validates model structure.
-/// @param device Used to create resources: shaders, buffers. Also the device is used in
-///             consecutive call setInputDimensions().
-/// @param model Contains ordered vector of shader programs ready to be compiled for GPU and
-///             with all supplementary buffers data.
-/// @param inputBufferIDs IDs must match the input of added operations.
-/// @param outputBufferIDs IDs must match the output of added operations.
-/// @param runtimeOptions Options are used to specify data/calculations precision.
-/// @return Status signals whether model is compiled successfully or not.
-/// @discussion Previously added operations are distilled into sorted list of sets of
-///             ComputeTaskDescriptors, which can be fused into a single GPU task.
-- (absl::Status)compileModelWithDevice:(id<MTLDevice>)device
-                                 model:(const tflite::gpu::metal::CompiledModel&)compiledModel
-                        inputBufferIDs:(const std::vector<tflite::gpu::ValueId>&)inputBufferIDs
-                       outputBufferIDs:(const std::vector<tflite::gpu::ValueId>&)outputBufferIDs
-                        runtimeOptions:(const tflite::gpu::metal::RuntimeOptions&)options;
+class InferenceContext {
+ public:
+  InferenceContext() = default;
+  /// Compiles model: groups operations to be fused; validates model structure.
+  /// @param device Used to create resources: shaders, buffers. The same device
+  ///             must be used in all consecutive calls.
+  /// @param compiled_model Contains an ordered vector of shader programs ready
+  ///             to be compiled for GPU, with all supplementary buffer data.
+  /// @param input_ids IDs must match the inputs of the added operations.
+  /// @param output_ids IDs must match the outputs of the added operations.
+  /// @param precision Specifies data/calculations precision.
+  /// @return Status signals whether the model was compiled successfully or not.
+  /// @discussion Previously added operations are distilled into a sorted list
+  ///             of sets of ComputeTaskDescriptors, which can be fused into a
+  ///             single GPU task.
+  absl::Status CompileModelWithDevice(id<MTLDevice> device,
+                                      const CompiledModel& compiled_model,
+                                      const std::vector<ValueId>& input_ids,
+                                      const std::vector<ValueId>& output_ids,
+                                      CalculationsPrecision precision);
 
-/// Inserts all GPU compute tasks into the command encoder.
-/// @param inputOutputBuffers Must be created and passed into the method with pairs ID:buffer
-/// @param encoderBlock User-defined block to take control over command encoder. Can be nil.
-///             The block can be used, for example, for fine-grained benchmarking where end encoding
-///             is performed and command buffer is committed with completion block. A new command
-///             buffer must be created and new command encoder must be returned by the block.
-///             The block is called after every dispatch encoding.
-/// @discussion No GPU synchronization functions are used inside. All GPU resources must be created
-///             with the same device which has been used in compileModelWithDevice() method.
-- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)commandEncoder
-       inputOutputBuffers:(const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers
-             encoderBlock:(id<MTLComputeCommandEncoder> (^)(bool isLast))encoderBlock;
+  /// Inserts all GPU compute tasks into the command encoder.
+  /// @param in_out_buffers Must be created and passed into the method as
+  ///             ID:buffer pairs.
+  /// @discussion No GPU synchronization functions are used inside. All GPU
+  ///             resources must be created with the same device that was used
+  ///             in CompileModelWithDevice().
+  void EncodeWithEncoder(
+      id<MTLComputeCommandEncoder> command_encoder,
+      const std::map<ValueId, id<MTLBuffer>>& in_out_buffers);
 
-@end
+  /// Inserts all GPU compute tasks into the command buffer. A separate encoder
+  ///             is used for every task.
+  /// @param in_out_buffers Must be created and passed into the method as
+  ///             ID:buffer pairs.
+  /// @discussion No GPU synchronization functions are used inside. All GPU
+  ///             resources must be created with the same device that was used
+  ///             in CompileModelWithDevice().
+  void EncodeWithCommandBuffer(
+      id<MTLCommandBuffer> command_buffer,
+      const std::map<ValueId, id<MTLBuffer>>& in_out_buffers);
+
+  /// Adds all GPU compute tasks to the command queue. A separate encoder is
+  ///             used for every task; encoders are batched into command buffers
+  ///             and committed every flush_period tasks.
+  /// @param in_out_buffers Must be created and passed into the method as
+  ///             ID:buffer pairs.
+  /// @discussion No GPU synchronization functions are used inside. All GPU
+  ///             resources must be created with the same device that was used
+  ///             in CompileModelWithDevice().
+  void EncodeWithCommandQueue(
+      id<MTLCommandQueue> command_queue,
+      const std::map<ValueId, id<MTLBuffer>>& in_out_buffers, int flush_period);
+
+ private:
+  absl::Status AllocateTensors(id<MTLDevice> device);
+  absl::Status AllocateMemoryForBuffers(id<MTLDevice> device);
+  void BindTensorsToOperations();
+  MetalSpatialTensor* GetTensor(ValueId tensor_id);
+  void GetUsages(std::map<ValueId, int2>* usages);
+  void UpdatePreallocatedTensors(
+      const std::map<ValueId, id<MTLBuffer>>& preallocated);
+
+  std::vector<ComputeTask> compute_tasks_;
+  // contains indexes of compute_tasks_
+  std::vector<int> task_ids_with_preallocated_tensors_;
+  std::vector<ValueId> input_ids_;
+  std::vector<ValueId> output_ids_;
+  CalculationsPrecision precision_;
+  std::map<ValueId, BHWC> tensor_shapes_;
+  std::map<ValueId, MetalSpatialTensor> preallocated_tensors_;
+
+  std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;
+  std::vector<id<MTLBuffer>> shared_buffers_;
+  std::vector<MetalSpatialTensor>
+      shared_buffer_tensors_;  // references memory owned by
+                               // shared_buffers_
+};
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
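The header above replaces the Objective-C TFLInferenceContext with a plain C++ class. A hedged end-to-end sketch of one possible flow, assuming a valid MTLDevice, a CompiledModel produced by Compile(), matching input/output ValueIds, and externally created MTLBuffers (RunOnce and all argument names are placeholders, not part of this patch):

#import <Metal/Metal.h>

#include <map>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"

absl::Status RunOnce(
    id<MTLDevice> device,
    const tflite::gpu::metal::CompiledModel& compiled_model,
    const std::vector<tflite::gpu::ValueId>& input_ids,
    const std::vector<tflite::gpu::ValueId>& output_ids,
    const std::map<tflite::gpu::ValueId, id<MTLBuffer>>& in_out_buffers) {
  tflite::gpu::metal::InferenceContext context;
  RETURN_IF_ERROR(context.CompileModelWithDevice(
      device, compiled_model, input_ids, output_ids,
      tflite::gpu::CalculationsPrecision::F16));
  id<MTLCommandQueue> queue = [device newCommandQueue];
  // Batches every 8 encoders into one command buffer before committing.
  context.EncodeWithCommandQueue(queue, in_out_buffers, /*flush_period=*/8);
  return absl::OkStatus();
}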
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.mm b/tensorflow/lite/delegates/gpu/metal/inference_context.mm
deleted file mode 100644
index 84322a4..0000000
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.mm
+++ /dev/null
@@ -1,183 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
-
-#include <map>
-#include <vector>
-
-#include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/common/memory_management.h"
-#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/shape.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
-#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
-
-using ::tflite::gpu::BHWC;
-using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
-using ::tflite::gpu::metal::RuntimeOptions;
-using ::tflite::gpu::ValueId;
-using ::tflite::gpu::AlignByN;
-using ::tflite::gpu::HalfBits;
-using ::tflite::gpu::MemoryStrategy;
-using ::tflite::gpu::TensorUsageRecord;
-
-@implementation TFLInferenceContext {
-  std::vector<TFLComputeTask*> _computeTasks;
-  // contains indexes of _computeTasks
-  std::vector<int> _taskIdsWithInOutBuffers;
-  std::vector<ValueId> _inputIds;
-  std::vector<ValueId> _outputIds;
-  id<MTLDevice> _device;
-  RuntimeOptions _options;
-  std::map<ValueId, BHWC> _tensorShapes;
-}
-
-- (absl::Status)compileModelWithDevice:(id<MTLDevice>)device
-                                 model:(const tflite::gpu::metal::CompiledModel&) compiledModel
-                        inputBufferIDs:(const std::vector<tflite::gpu::ValueId>&)inputBufferIDs
-                       outputBufferIDs:(const std::vector<tflite::gpu::ValueId>&)outputBufferIDs
-                        runtimeOptions:(const RuntimeOptions&)options {
-  _device = device;
-  _inputIds = inputBufferIDs;
-  _outputIds = outputBufferIDs;
-  _options = options;
-  // Metal resources are created here.
-  for (const auto& node : compiledModel.nodes) {
-    TFLComputeTask* task = [[TFLComputeTask alloc] init];
-    RETURN_IF_ERROR([task compileWithDevice:_device
-                             taskDescriptor:node
-                             runtimeOptions:_options]);
-    [task setDescription:node.description];
-    _computeTasks.emplace_back(task);
-  }
-  _tensorShapes = compiledModel.tensor_shapes;
-  [self allocateTensors];
-  return absl::OkStatus();
-}
-
-- (absl::Status)allocateTensors {
-  // These maps contain all input/output/intermediate buffers shared across model.
-  std::map<ValueId, id<MTLBuffer>> buffers;
-  std::set<ValueId> preallocatedIds;
-  // Insert uninitialized input buffers. This buffers will be set externally.
-  for (auto tensor_id : _inputIds) {
-    buffers[tensor_id] = nil;
-    preallocatedIds.insert(tensor_id);
-  }
-  for (const auto& outputId : _outputIds) {
-    preallocatedIds.insert(outputId);
-  }
-  for (auto& task : _computeTasks) {
-    // The same device must be used here as well as on shader compilation stage.
-    RETURN_IF_ERROR([task updateParamsWithDevice:_device tensorShapes:_tensorShapes]);
-  }
-
-  // TODO(ypisarchyk): it make sense to move it to separate function
-  // Generate usage records for each intermediate tensor in order of their first_task
-  std::vector<TensorUsageRecord<size_t>> usageRecords;
-  std::map<ValueId, size_t> usageRecordIds;
-  for (uint32_t i = 0; i < _computeTasks.size(); ++i) {
-    for (const auto tensor_id : [_computeTasks[i] getOutputIds]) {
-      if (!preallocatedIds.count(tensor_id)) {
-        if (!usageRecordIds.count(tensor_id)) {
-          const auto it = _tensorShapes.find(tensor_id);
-          if (it == _tensorShapes.end()) {
-            return absl::InternalError("Dimensions for intermediate tensor not found.");
-          }
-          usageRecordIds[tensor_id] = usageRecords.size();
-          usageRecords.emplace_back(it->second.w * it->second.h * AlignByN(it->second.c, 4), i, i);
-        } else {
-          usageRecords[usageRecordIds[tensor_id]].last_task = i;
-        }
-      }
-    }
-    for (const auto tensor_id : [_computeTasks[i] getInputIds]) {
-      if (!preallocatedIds.count(tensor_id)) {
-        usageRecords[usageRecordIds[tensor_id]].last_task = i;
-      }
-    }
-  }
-
-  tflite::gpu::ObjectsAssignment<size_t> assignment;
-  RETURN_IF_ERROR(AssignObjectsToTensors(usageRecords, MemoryStrategy::GREEDY_BEST, &assignment));
-  auto objectsCount = assignment.object_sizes.size();
-  std::vector<id<MTLBuffer>> sharedBuffers(objectsCount);
-  size_t dataTypeSize = _options.storage_precision == RuntimeOptions::Precision::FP32
-                            ? sizeof(float)
-                            : sizeof(HalfBits);
-
-  // allocate buffers for each shared object
-  for (size_t i = 0; i < objectsCount; ++i) {
-    // Initialize metal buffer
-    NSUInteger bufferSize = dataTypeSize * assignment.object_sizes[i];
-
-#if (defined(__MAC_10_14) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_14) ||      \
-    (defined(__IPHONE_12_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_12_0) || \
-    (defined(__TVOS_12_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_12_0)
-    if (bufferSize > [_device maxBufferLength]) {
-      std::string error("Tensor id: ");
-      error += std::to_string(assignment.object_ids[i]) +
-               " with size: " + std::to_string(bufferSize) +
-               " exceeds MTLDevice maxBufferLength: " + std::to_string([_device maxBufferLength]);
-      return absl::ResourceExhaustedError(error);
-    }
-#endif
-#if defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12
-    if ([_device currentAllocatedSize] + bufferSize > [_device recommendedMaxWorkingSetSize]) {
-      std::string error("Out of memory in MTLBuffer allocation. Currently allocated: ");
-      error += std::to_string([_device currentAllocatedSize]);
-      return absl::ResourceExhaustedError(error);
-    }
-#endif
-
-    sharedBuffers[i] = [_device newBufferWithLength:bufferSize
-                                            options:MTLResourceStorageModeShared];
-  }
-  for (int i = 0; i < _computeTasks.size(); ++i) {
-    auto& task = _computeTasks[i];
-    if ([task hasInOutIds:preallocatedIds]) {
-      _taskIdsWithInOutBuffers.push_back(i);
-    }
-    RETURN_IF_ERROR([task assignBuffers:&buffers
-                              outputIds:_outputIds
-                         usageRecordIds:usageRecordIds
-                        sharedBufferIds:assignment.object_ids
-                          sharedBuffers:sharedBuffers]);
-  }
-  return absl::OkStatus();
-}
-
-- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)commandEncoder
-       inputOutputBuffers:(const std::map<ValueId, id<MTLBuffer>>&)inputOutputBuffers
-             encoderBlock:(id<MTLComputeCommandEncoder> (^)(bool isLast))encoderBlock {
-  for (auto& task_index : _taskIdsWithInOutBuffers) {
-    auto& task = _computeTasks[task_index];
-    [task updateBuffers:inputOutputBuffers];
-  }
-  for (int i = 0; i < _computeTasks.size(); ++i) {
-    auto& task = _computeTasks[i];
-    [task encodeWithEncoder:commandEncoder];
-    if (encoderBlock != nil) {
-      commandEncoder = encoderBlock(i == _computeTasks.size() - 1);
-    }
-  }
-}
-
-@end
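
The memory-allocation path removed above first records, for every intermediate tensor, the index of the first and last compute task that touches it, and then lets the greedy memory strategy pack tensors with non-overlapping lifetimes into the same shared MTLBuffer. A minimal sketch of that interval-sharing idea, with hypothetical names (UsageRecord, AssignShared) rather than the actual TFLite GPU API:

#include <algorithm>
#include <cstddef>
#include <vector>

// One record per intermediate tensor: padded element count plus the range of
// compute tasks during which the tensor must stay alive.
struct UsageRecord {
  size_t size;
  size_t first_task;
  size_t last_task;
};

// Greedily reuses a shared object whose previous user has already finished;
// each object grows to the largest tensor assigned to it. Records are assumed
// to be ordered by first_task, as in the removed code.
std::vector<size_t> AssignShared(const std::vector<UsageRecord>& records,
                                 std::vector<size_t>* object_sizes) {
  std::vector<size_t> object_last_task;
  std::vector<size_t> assignment(records.size());
  for (size_t i = 0; i < records.size(); ++i) {
    size_t obj = object_last_task.size();
    for (size_t j = 0; j < object_last_task.size(); ++j) {
      if (object_last_task[j] < records[i].first_task) {
        obj = j;
        break;
      }
    }
    if (obj == object_last_task.size()) {
      object_last_task.push_back(0);
      object_sizes->push_back(0);
    }
    object_last_task[obj] = records[i].last_task;
    (*object_sizes)[obj] = std::max((*object_sizes)[obj], records[i].size);
    assignment[i] = obj;
  }
  return assignment;
}

The resulting per-object sizes are then scaled by sizeof(float) or sizeof(HalfBits) and checked against maxBufferLength and recommendedMaxWorkingSetSize before the shared MTLBuffers are created, as in the removed block.
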
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 4dd0ed1..ee7f0e8 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -169,18 +169,6 @@
 )
 
 cc_library(
-    name = "custom_registry",
-    srcs = ["custom_registry.cc"],
-    hdrs = ["custom_registry.h"],
-    deps = [
-        "//tensorflow/lite/delegates/gpu/common:model",
-        "//tensorflow/lite/delegates/gpu/common:status",
-        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
-        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
-    ],
-)
-
-cc_library(
     name = "depthwise_conv",
     srcs = ["depthwise_conv.cc"],
     hdrs = ["depthwise_conv.h"],
@@ -803,9 +791,12 @@
     name = "test_util",
     testonly = 1,
     srcs = [
-        "test_util.mm",
+        "test_util.cc",
     ],
     hdrs = ["test_util.h"],
+    copts = [
+        "-ObjC++",
+    ],
     sdk_frameworks = [
         "Metal",
     ],
@@ -814,6 +805,7 @@
         "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
@@ -823,9 +815,10 @@
         "//tensorflow/lite/delegates/gpu/metal:common",
         "//tensorflow/lite/delegates/gpu/metal:compiled_model",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
-        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
+        "//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
         "@FP16",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -897,12 +890,12 @@
     deps = [
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:common",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
-        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm
index 22a798c..3facbc4 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::ElementwiseAttributes;
 using ::tflite::gpu::BHWC;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm
index 195a298..6ac084c 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::Axis;
 using ::tflite::gpu::BHWC;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
index 71ea6f9..f9aa9e8 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
@@ -28,7 +28,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::Axis;
 using ::tflite::gpu::BHWC;
@@ -285,16 +284,7 @@
     src_tensor.data[i] = sin(i);
   }
 
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  tflite::gpu::metal::RuntimeOptions options;
-  options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-  options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-
-  std::map<ValueId, TensorFloat32> inputs_v0;
-  inputs_v0[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs_v0;
-  outputs_v0[1].shape = dst_shape;
-  outputs_v0[1].data.resize(dst_shape.DimensionsProduct());
+  TensorFloat32 output0;
 
   tflite::gpu::OperationDef op_def;
   op_def.precision = tflite::gpu::CalculationsPrecision::F32;
@@ -303,61 +293,43 @@
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
 
-  std::string device_name = std::string([[device name] UTF8String]);
-  tflite::gpu::GpuInfo gpu_info;
-  tflite::gpu::GetGpuInfoFromDeviceDescription(device_name, tflite::gpu::GpuApi::kMetal, &gpu_info);
-  auto gpu_op0 = ConvolutionGeneric(op_def, dst_shape, attr, gpu_info);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op0));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};
-  auto status = RunGraph(nodes, device, inputs_v0, &outputs_v0);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto gpu_op0 = ConvolutionGeneric(op_def, dst_shape, attr, env.GetGpuInfo());
+  auto op0_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op0));
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op0_ptr), dst_shape, &output0);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 
   tflite::gpu::metal::Winograd4x4To36Attributes wino_up_attr;
   wino_up_attr.padding = attr.padding;
   auto gpu_op1 = tflite::gpu::metal::Winograd4x4To36(op_def, wino_up_attr);
+  auto op1_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op1));
 
-  auto gpu_op2 = ConvolutionWino4x4To6x6(op_def, conv_shape, attr, gpu_info);
+  auto gpu_op2 = ConvolutionWino4x4To6x6(op_def, conv_shape, attr, env.GetGpuInfo());
+  auto op2_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op2));
 
   tflite::gpu::metal::Winograd36To4x4Attributes wino_down_attr;
   wino_down_attr.output_shape = dst_shape;
   wino_down_attr.biases = attr.bias;
   auto gpu_op3 = tflite::gpu::metal::Winograd36To4x4(op_def, wino_down_attr);
+  auto op3_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op3));
 
-  std::map<ValueId, TensorFloat32> inputs_v1;
-  inputs_v1[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs_v1;
-  outputs_v1[2].shape = conv_shape;
-  outputs_v1[2].shape.c = src_shape.c;
-  outputs_v1[2].data.resize(outputs_v1[2].shape.DimensionsProduct());
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op1));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {2};
-  status = RunGraph(nodes, device, inputs_v1, &outputs_v1);
-
-  std::map<ValueId, TensorFloat32> inputs_v2;
-  inputs_v2[2] = outputs_v1[2];
-  std::map<ValueId, TensorFloat32> outputs_v2;
-  outputs_v2[3].shape = conv_shape;
-  outputs_v2[3].data.resize(outputs_v2[3].shape.DimensionsProduct());
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op2));
-  nodes[0].src_tensors_ids = {2};
-  nodes[0].dst_tensors_ids = {3};
-  status = RunGraph(nodes, device, inputs_v2, &outputs_v2);
-
-  std::map<ValueId, TensorFloat32> inputs_v3;
-  inputs_v3[3] = outputs_v2[3];
-  std::map<ValueId, TensorFloat32> outputs_v3;
-  outputs_v3[1].shape = dst_shape;
-  outputs_v3[1].data.resize(outputs_v3[1].shape.DimensionsProduct());
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op3));
-  nodes[0].src_tensors_ids = {3};
-  nodes[0].dst_tensors_ids = {1};
-  status = RunGraph(nodes, device, inputs_v3, &outputs_v3);
+  TensorFloat32 output1;
+  BHWC output1_shape = conv_shape;
+  output1_shape.c = src_shape.c;
+  status = env.ExecuteGPUOperation(src_tensor, std::move(op1_ptr), output1_shape, &output1);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 
-  status = CompareVectors(outputs_v0[1].data, outputs_v3[1].data, 1e-4f);
+  TensorFloat32 output2;
+  BHWC output2_shape = conv_shape;
+  status = env.ExecuteGPUOperation(output1, std::move(op2_ptr), output2_shape, &output2);
+  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
+
+  TensorFloat32 output3;
+  BHWC output3_shape = dst_shape;
+  status = env.ExecuteGPUOperation(output2, std::move(op3_ptr), output3_shape, &output3);
+  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
+
+  status = CompareVectors(output0.data, output3.data, 1e-4f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }
 
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc
deleted file mode 100644
index 620a458..0000000
--- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h"
-
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node,
-                               const std::vector<ValueId>& inputs,
-                               const std::vector<ValueId>& outputs,
-                               const RuntimeOptions& options,
-                               std::vector<ComputeTaskDescriptorPtr>* tasks) {
-  return absl::UnimplementedError("Unsupported op: " + node->operation.type);
-}
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h
deleted file mode 100644
index eee1632..0000000
--- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CUSTOM_REGISTRY_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CUSTOM_REGISTRY_H_
-
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-// Registers custom operations.
-absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node,
-                               const std::vector<ValueId>& inputs,
-                               const std::vector<ValueId>& outputs,
-                               const RuntimeOptions& options,
-                               std::vector<ComputeTaskDescriptorPtr>* tasks);
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CUSTOM_REGISTRY_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
index dcf550f..817a371 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::Axis;
 using ::tflite::gpu::BHWC;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
index 867ed59..5826e2b 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::DataType;
 using ::tflite::gpu::HWC;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm
index e57f9aa..b6e4cb9 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm
index cf4aacf..5ee3603 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm
index 67325c1..e4fa301 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::Axis;
 using ::tflite::gpu::BHWC;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm
index 9c55cfc..e8c0ef6 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm
index d2d95b3..a28dd64 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm
index 1df08be..76642f5 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
@@ -40,10 +39,10 @@
 using ::tflite::gpu::metal::CompareVectors;
 using ::tflite::gpu::metal::SingleOpModel;
 
-@interface SoftmaxTest : XCTestCase
+@interface PReLUTest : XCTestCase
 @end
 
-@implementation SoftmaxTest
+@implementation PReLUTest
 - (void)setUp {
   [super setUp];
 }
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm
index 7a16f1d..7eb71bf 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm
@@ -25,7 +25,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 
 using ::tflite::NudgeQuantizationRange;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm
index 52de77e..c4eb8f5 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
@@ -37,10 +36,10 @@
 using ::tflite::gpu::metal::CompareVectors;
 using ::tflite::gpu::metal::SingleOpModel;
 
-@interface SliceTest : XCTestCase
+@interface ReLUTest : XCTestCase
 @end
 
-@implementation SliceTest
+@implementation ReLUTest
 - (void)setUp {
   [super setUp];
 }
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm
index 684e83b..9a64ef5 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm
index 082f2c8..f087777 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm
index e0c2956..25b45d4 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
index 57609bc..9a3a8ea 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
@@ -50,21 +50,55 @@
                             uint tid[[thread_index_in_threadgroup]],
                             uint3 ugid[[thread_position_in_grid]])
 {
-  int offset = 0;
-  float sum = 0.0f;
-  int s = 0;
-  do {
-    if (offset + tid < params.size.x) {
-      float4 mask_temp = offset + tid == params.size.x - 1 ? params.mask : float4(1.0h);
-      float4 src = float4(src_tensor[offset + tid]);
-      sum += dot(mask_temp, exp(src));
-      offset += 32;
-    }
-    s++;
-  } while (s < params.size.y);
+
+  float4 maxx4 = float4(src_tensor[0].x);
+  for (int s = int(tid); s < params.size.x; s += 32) {
+    float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
+    float4 mask_b = float4(1.0f) - mask_a;
+    float4 src = float4(src_tensor[s]);
+    src = src * mask_a + mask_b * src.x;
+    maxx4 = max(maxx4, src);
+  }
+  float maximum = max(maxx4.x, maxx4.y);
+  maximum = max(maximum, maxx4.z);
+  maximum = max(maximum, maxx4.w);
 
   threadgroup float4 tmp[8];
   threadgroup float* tmpx1 = (threadgroup float*)tmp;
+
+  tmpx1[tid] = maximum;
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+  if (tid == 0) {
+    maxx4 = max(tmp[0], tmp[1]);
+    maxx4 = max(maxx4, tmp[2]);
+    maxx4 = max(maxx4, tmp[3]);
+    maxx4 = max(maxx4, tmp[4]);
+    maxx4 = max(maxx4, tmp[5]);
+    maxx4 = max(maxx4, tmp[6]);
+    maxx4 = max(maxx4, tmp[7]);
+    maximum = max(maxx4.x, maxx4.y);
+    maximum = max(maximum, maxx4.z);
+    maximum = max(maximum, maxx4.w);
+    tmpx1[0] = maximum;
+  }
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+  maximum = tmpx1[0];
+
+  float sum = 0.0f;
+  for (int s = int(tid); s < params.size.x; s += 32) {
+    float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(src_tensor[s]) - float4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+
   tmpx1[tid] = sum;
 )";
   code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
@@ -85,74 +119,90 @@
   code += R"(
   sum = tmpx1[0];
 
-  offset = 0;
-  s = 0;
-  do {
-    if (offset + tid < params.size.x) {
-      int linear_index = offset + tid;
-      FLT4 value = FLT4(exp(float4(src_tensor[linear_index])) * sum);
-      uint3 gid = uint3(0, 0, linear_index);
-      $2
-      dst_tensor[linear_index] = value;
-      offset += 32;
-    }
-    s++;
-  } while (s < params.size.y);
+  int dst_s = int(ugid.x);
+  if (dst_s < params.size.x) {
+    int linear_index = dst_s;
+    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+    FLT4 value = FLT4(exp(src) * sum);
+    uint3 gid = uint3(0, 0, linear_index);
+    $2
+    dst_tensor[linear_index] = value;
+  }
 })";
   return code;
 }
 }  // namespace
 
-ComputeTaskDescriptor Softmax(const OperationDef& definition,
-                              int channels_count) {
+ComputeTaskDescriptor Softmax(const OperationDef& definition) {
   ComputeTaskDescriptor desc(definition);
   desc.shader_source = R"(
-    #include <metal_stdlib>
-    using namespace metal;
-    constant int src_channels = )";
-  desc.shader_source += std::to_string(channels_count);
-  desc.shader_source += R"(;
-    $0
-    kernel void ComputeFunction(
-                                $1
-                                uint3 gid[[thread_position_in_grid]]) {
-      if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
-        return;
-      }
-      float shift = 0.0f;
-      int remaining_channels = src_channels % 4;
+#include <metal_stdlib>
+using namespace metal;
 
-      float sum = 0.0f;
-      for (int d = 0; d < src_channels / 4; ++d) {
-        int buffer_index = (d * size.y + gid.y) * size.x + gid.x;
-        sum += dot(float4(1.0f), exp(float4(src_tensor[buffer_index]) - shift));
-      }
-      if (remaining_channels > 0) {
-        int buffer_index = ((src_channels / 4) * size.y + gid.y) * size.x + gid.x;
-        float4 last_element = float4(src_tensor[buffer_index]);
-        sum += exp(last_element.x - shift);
-        if (remaining_channels > 1) sum += exp(last_element.y - shift);
-        if (remaining_channels == 3) sum += exp(last_element.z - shift);
-      }
+struct uniforms {
+  int4 size;
+  float4 mask;
+};
+$0
+kernel void ComputeFunction(
+                            $1
+                            uint3 gid[[thread_position_in_grid]]) {
+  if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
+    return;
+  }
 
-      for (int d = 0; d < (src_channels + 3) / 4; ++d) {
-        const int linear_index = (d * size.y + gid.y) * size.x + gid.x;
-        FLT4 value = FLT4(exp(float4(src_tensor[linear_index]) - shift) / sum);
-        $2
-        dst_tensor[linear_index] = value;
-      }
-    }
+  float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
+  for (int d = 0; d < params.size.z; ++d) {
+    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
+    float4 mask_b = float4(1.0f) - mask_a;
+    float4 src = float4(src_tensor[buffer_index]);
+    src = src * mask_a + mask_b * src.x;
+    maximum = max(maximum, src.x);
+    maximum = max(maximum, src.y);
+    maximum = max(maximum, src.z);
+    maximum = max(maximum, src.w);
+  }
+
+  float sum = 0.0f;
+  for (int d = 0; d < params.size.z; ++d) {
+    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+
+  for (int d = 0; d < params.size.z; ++d) {
+    const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+    FLT4 value = FLT4(exp(src) / sum);
+    $2
+    dst_tensor[linear_index] = value;
+  }
+}
   )";
 
   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
 
   desc.uniform_buffers = {
-      {"constant int2& size",
+      {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         std::vector<int> sizes{dst_shapes[0].w, dst_shapes[0].h};
-         return GetByteBuffer(sizes);
+         const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
+         struct uniforms {
+           int4 size;
+           float4 mask;
+         };
+         uniforms params;
+         params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
+         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
+         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
+         for (int i = 0; i < reminder; ++i) {
+           params.mask[i] = 1.0f;
+         }
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
        }},
   };
 
@@ -168,7 +218,7 @@
 }
 
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
-                                 const GpuInfo& gpu_info, int channels_count) {
+                                 const GpuInfo& gpu_info) {
   ComputeTaskDescriptor desc(definition);
   desc.shader_source = GetSoftmax1x1Code(gpu_info);
 
@@ -177,9 +227,9 @@
 
   desc.uniform_buffers = {
       {"constant uniforms& params",
-       [channels_count](const std::vector<BHWC>& src_shapes,
-                        const std::vector<BHWC>& dst_shapes) {
-         const int src_depth = DivideRoundUp(channels_count, 4);
+       [](const std::vector<BHWC>& src_shapes,
+          const std::vector<BHWC>& dst_shapes) {
+         const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
          struct uniforms {
            int4 size;
            float4 mask;
@@ -187,7 +237,7 @@
          uniforms params;
          params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
          params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         const int reminder = channels_count % 4 == 0 ? 4 : channels_count % 4;
+         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
          for (int i = 0; i < reminder; ++i) {
            params.mask[i] = 1.0f;
          }
@@ -198,7 +248,10 @@
 
   desc.resize_function = [](const std::vector<BHWC>& src_shapes,
                             const std::vector<BHWC>& dst_shapes) {
-    return std::make_pair(uint3{32u, 1u, 1u}, uint3{1u, 1u, 1u});
+    uint3 groups_size{32, 1, 1};
+    uint3 groups_count{
+        DivideRoundUp(DivideRoundUp(dst_shapes[0].c, 4), groups_size.x), 1, 1};
+    return std::make_pair(groups_size, groups_count);
   };
 
   return desc;
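
Both rewritten softmax kernels above subtract the per-pixel maximum over the channel axis before exponentiation, and use params.mask so that the zero-padded lanes of the last channel quad contribute neither to the maximum nor to the sum. A CPU reference of that computation, as a standalone sketch (SoftmaxPHWC4 and the zero-padded input layout are illustrative assumptions, not the delegate's API):

#include <algorithm>
#include <cmath>
#include <vector>

// `src` holds `slices * 4` floats, where slice d packs channels 4*d..4*d+3 and
// the tail of the last slice is zero-padded. The mask mirrors params.mask.
std::vector<float> SoftmaxPHWC4(const std::vector<float>& src, int channels) {
  const int slices = (channels + 3) / 4;
  float mask[4] = {0.f, 0.f, 0.f, 0.f};
  const int remainder = channels % 4 == 0 ? 4 : channels % 4;
  for (int i = 0; i < remainder; ++i) mask[i] = 1.f;

  // Pass 1: masked maximum; padded lanes are replaced by lane 0 of the slice,
  // exactly like `src * mask_a + mask_b * src.x` in the shader.
  float maximum = src[0];
  for (int d = 0; d < slices; ++d) {
    for (int i = 0; i < 4; ++i) {
      const float m = d == slices - 1 ? mask[i] : 1.f;
      const float v = src[d * 4 + i] * m + (1.f - m) * src[d * 4];
      maximum = std::max(maximum, v);
    }
  }

  // Pass 2: masked sum of exp(x - max), i.e. dot(mask_temp, exp(src)).
  float sum = 0.f;
  for (int d = 0; d < slices; ++d) {
    for (int i = 0; i < 4; ++i) {
      const float m = d == slices - 1 ? mask[i] : 1.f;
      sum += m * std::exp(src[d * 4 + i] - maximum);
    }
  }

  // Pass 3: normalized outputs; padded lanes are written too, as in the shader.
  std::vector<float> dst(slices * 4);
  for (int j = 0; j < slices * 4; ++j) {
    dst[j] = std::exp(src[j] - maximum) / sum;
  }
  return dst;
}
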
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
index c64fd8a..c2d372a 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
@@ -27,13 +27,12 @@
 namespace gpu {
 namespace metal {
 
-ComputeTaskDescriptor Softmax(const OperationDef& definition,
-                              int channels_count);
+ComputeTaskDescriptor Softmax(const OperationDef& definition);
 
 // Softmax for case when width = height = 1 and AXIS = CHANNELS
 // We have this case in MobilenetV1/V2.
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
-                                 const GpuInfo& gpu_info, int channels_count);
+                                 const GpuInfo& gpu_info);
 
 }  // namespace metal
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm
index 9196e9f..c841a0f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::Axis;
 using ::tflite::gpu::BHWC;
@@ -134,4 +133,76 @@
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
+- (void)testSoftmaxBigNumber {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 2, 1, 2);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 2, 1, 2);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  XCTAssertFalse(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
+  double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
+  XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
+                                         static_cast<float>(doubles[1]),
+                                         static_cast<float>(doubles[2]),
+                                         static_cast<float>(doubles[3])}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
+                           static_cast<float>(std::exp(doubles[1]) / s0),
+                           static_cast<float>(std::exp(doubles[2]) / s1),
+                           static_cast<float>(std::exp(doubles[3]) / s1)},
+                          model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testSoftmax1x1BigNumber {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 1, 1, 4);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 1, 1, 4);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  XCTAssertFalse(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+      std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
+  XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
+                                         static_cast<float>(doubles[1]),
+                                         static_cast<float>(doubles[2]),
+                                         static_cast<float>(doubles[3])}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
+                           static_cast<float>(std::exp(doubles[1]) / s0),
+                           static_cast<float>(std::exp(doubles[2]) / s0),
+                           static_cast<float>(std::exp(doubles[3]) / s0)},
+                          model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end
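
The two *BigNumber tests above pin down why the maximum is subtracted: exp(100) is roughly 2.69e43, which overflows a 32-bit float (FLT_MAX is about 3.40e38), so exponentiating raw inputs yields an infinite sum and the normalized outputs degenerate to zeros and NaNs, while the shifted form never passes exp() anything greater than zero. A quick standalone check of that arithmetic:

#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  std::printf("FLT_MAX       = %g\n", FLT_MAX);                    // ~3.40282e+38
  std::printf("expf(100)     = %g\n", std::exp(100.0f));           // inf (overflow)
  std::printf("exp(100)      = %g\n", std::exp(100.0));            // ~2.68812e+43
  std::printf("expf(100-100) = %g\n", std::exp(100.0f - 100.0f));  // 1, no overflow
  return 0;
}
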
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm
index 17e3988..b7c474e 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::BHWC;
 using ::tflite::gpu::DataType;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc
new file mode 100644
index 0000000..218e4df
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc
@@ -0,0 +1,397 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
+
+#import <Metal/Metal.h>
+
+#include <functional>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/convert.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/metal/api.h"
+#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+SingleOpModel::SingleOpModel(Operation&& operation,
+                             const std::vector<TensorRef<BHWC>>& inputs,
+                             const std::vector<TensorRef<BHWC>>& outputs) {
+  auto node = graph_.NewNode();
+  node->operation = std::move(operation);
+
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto input = graph_.NewValue();
+    input->tensor = inputs[i];
+    graph_.AddConsumer(node->id, input->id).IgnoreError();
+    TensorFloat32 tensor;
+    tensor.id = input->tensor.ref;
+    tensor.shape = input->tensor.shape;
+    inputs_.emplace_back(std::move(tensor));
+  }
+
+  for (int i = 0; i < outputs.size(); ++i) {
+    auto output = graph_.NewValue();
+    output->tensor = outputs[i];
+    graph_.SetProducer(node->id, output->id).IgnoreError();
+    TensorFloat32 tensor;
+    tensor.id = output->id;
+    tensor.shape = output->tensor.shape;
+    outputs_.emplace_back(std::move(tensor));
+  }
+}
+
+absl::Status SingleOpModel::Invoke() {
+  std::vector<ValueId> input_ids;
+  input_ids.reserve(inputs_.size());
+  for (const auto& input : inputs_) {
+    input_ids.push_back(input.id);
+  }
+  std::vector<ValueId> output_ids;
+  output_ids.reserve(outputs_.size());
+  std::map<ValueId, BHWC> output_dimensions;
+  for (const auto& output : outputs_) {
+    output_ids.push_back(output.id);
+    output_dimensions[output.id] = output.shape;
+  }
+
+  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+  std::string device_name = std::string([[device name] UTF8String]);
+  GpuInfo gpu_info;
+  GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
+  CalculationsPrecision precision = CalculationsPrecision::F32;
+  CompiledModel compiled_model;
+  RETURN_IF_ERROR(Compile(graph_, gpu_info, precision, &compiled_model));
+  CompiledModel optimized_model;
+  RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model,
+                                        &optimized_model));
+
+  InferenceContext inference_context;
+  RETURN_IF_ERROR(inference_context.CompileModelWithDevice(
+      device, optimized_model, input_ids, output_ids, precision));
+  std::map<ValueId, BHWC> input_dimensions;
+  std::map<ValueId, id<MTLBuffer>> input_buffers;
+  for (auto& input : inputs_) {
+    input_dimensions[input.id] = input.shape;
+    NSUInteger elements_count = input.shape.w * input.shape.h *
+                                AlignByN(input.shape.c, 4) * input.shape.b;
+    std::vector<float> src_gpu(elements_count);
+    id<MTLBuffer> input_buffer;
+    RETURN_IF_ERROR(ConvertToPHWC4(absl::MakeConstSpan(input.data), input.shape,
+                                   absl::MakeSpan(src_gpu)));
+    input_buffer = [device newBufferWithBytes:src_gpu.data()
+                                       length:(elements_count * sizeof(float))
+                                      options:MTLResourceStorageModeShared];
+    input_buffers[input.id] = input_buffer;
+  }
+
+  std::map<ValueId, id<MTLBuffer>> output_buffers;
+  for (const auto& outputDimension : output_dimensions) {
+    // Uninitialized output buffer.
+    const ValueId key = outputDimension.first;
+    const BHWC& dims = outputDimension.second;
+    const NSUInteger size =
+        dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
+    output_buffers[key] =
+        [device newBufferWithLength:size options:MTLResourceStorageModeShared];
+  }
+
+  // Inference itself.
+  std::map<ValueId, id<MTLBuffer>> inout_buffers(input_buffers.begin(),
+                                                 input_buffers.end());
+  inout_buffers.insert(output_buffers.begin(), output_buffers.end());
+  id<MTLCommandQueue> command_queue = [device newCommandQueue];
+  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
+  id<MTLComputeCommandEncoder> command_encoder =
+      [command_buffer computeCommandEncoder];
+  inference_context.EncodeWithEncoder(command_encoder, inout_buffers);
+  [command_encoder endEncoding];
+  [command_buffer commit];
+  [command_buffer waitUntilCompleted];
+
+  for (auto& output : outputs_) {
+    const auto& dim = output_dimensions[output.id];
+    NSUInteger elements_count = dim.w * dim.h * AlignByN(dim.c, 4) * dim.b;
+    output.shape = dim;
+    output.data.resize(output.shape.DimensionsProduct());
+    float* output_pointer =
+        reinterpret_cast<float*>([output_buffers[output.id] contents]);
+    RETURN_IF_ERROR(
+        ConvertFromPHWC4(absl::MakeConstSpan(output_pointer, elements_count),
+                         output.shape, absl::MakeSpan(output.data)));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CompareVectors(const std::vector<float>& reference,
+                            const std::vector<float>& output, float max_error) {
+  if (reference.size() != output.size()) {
+    const std::string message =
+        "CompareVectors: vectors size does not match for reference: " +
+        std::to_string(reference.size()) +
+        " vs. output: " + std::to_string(output.size());
+    return absl::InternalError(message);
+  }
+  for (int i = 0; i < reference.size(); i++) {
+    float error = std::abs(reference[i] - output[i]);
+    if (error > max_error) {
+      const std::string message =
+          "Reference: " + std::to_string(reference[i]) +
+          ", output: " + std::to_string(output[i]) +
+          ", error: " + std::to_string(error) +
+          ", max allowed error: " + std::to_string(max_error);
+      return absl::InternalError(message);
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes,
+                      id<MTLDevice> device,
+                      const std::map<ValueId, TensorFloat32>& inputs,
+                      std::map<ValueId, TensorFloat32>* outputs) {
+  std::vector<ValueId> inputBufferIDs;
+  inputBufferIDs.reserve(inputs.size());
+  for (const auto& input : inputs) {
+    inputBufferIDs.push_back(input.first);
+  }
+  std::vector<ValueId> outputBufferIDs;
+  outputBufferIDs.reserve(outputs->size());
+  for (const auto& output : *outputs) {
+    outputBufferIDs.push_back(output.first);
+  }
+  std::map<ValueId, BHWC> outputDimensions;
+  CompiledModel raw_model;
+  raw_model.nodes = nodes;
+  for (const auto& input : inputs) {
+    raw_model.tensor_shapes[input.first] = input.second.shape;
+  }
+  for (const auto& output : *outputs) {
+    outputDimensions[output.first] = output.second.shape;
+    raw_model.tensor_shapes[output.first] = output.second.shape;
+  }
+  CompiledModel optimized_model;
+  RETURN_IF_ERROR(ValidateOptimizeModel(inputBufferIDs, outputBufferIDs,
+                                        raw_model, &optimized_model));
+
+  CalculationsPrecision precision = CalculationsPrecision::F32;
+
+  InferenceContext inference_context;
+  RETURN_IF_ERROR(inference_context.CompileModelWithDevice(
+      device, optimized_model, inputBufferIDs, outputBufferIDs, precision));
+  std::map<ValueId, BHWC> inputDimensions;
+  std::map<ValueId, std::vector<float>> inputBuffersCPU;
+  std::map<ValueId, id<MTLBuffer>> inputBuffersGPU;
+  for (auto& input : inputs) {
+    const auto& src = input.second;
+    inputDimensions[input.first] = src.shape;
+    const int paddedDepth = AlignByN(src.shape.c, 4);
+    NSUInteger elementsCount =
+        src.shape.w * src.shape.h * paddedDepth * src.shape.b;
+    std::vector<float> src_gpu(elementsCount);
+    id<MTLBuffer> inputBuffer;
+    RETURN_IF_ERROR(ConvertToPHWC4(absl::MakeConstSpan(src.data), src.shape,
+                                   absl::MakeSpan(src_gpu)));
+    inputBuffer = [device newBufferWithBytes:src_gpu.data()
+                                      length:(elementsCount * sizeof(float))
+                                     options:MTLResourceStorageModeShared];
+    inputBuffersGPU[input.first] = inputBuffer;
+  }
+
+  std::map<ValueId, id<MTLBuffer>> outputBuffers;
+  for (const auto& outputDimension : outputDimensions) {
+    // Uninitialized output buffer.
+    const ValueId key = outputDimension.first;
+    const BHWC& dims = outputDimension.second;
+    const NSUInteger outputDataSize =
+        dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
+    outputBuffers[key] =
+        [device newBufferWithLength:outputDataSize
+                            options:MTLResourceStorageModeShared];
+  }
+
+  // Inference itself.
+  std::map<ValueId, id<MTLBuffer>> inputOutputBuffers(inputBuffersGPU.begin(),
+                                                      inputBuffersGPU.end());
+  inputOutputBuffers.insert(outputBuffers.begin(), outputBuffers.end());
+  id<MTLCommandQueue> commandQueue = [device newCommandQueue];
+  id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
+  id<MTLComputeCommandEncoder> commandEncoder =
+      [commandBuffer computeCommandEncoder];
+  inference_context.EncodeWithEncoder(commandEncoder, inputOutputBuffers);
+  [commandEncoder endEncoding];
+  [commandBuffer commit];
+  [commandBuffer waitUntilCompleted];
+
+  for (auto& output : *outputs) {
+    const auto& dim = outputDimensions[output.first];
+    const int paddedDepth = AlignByN(dim.c, 4);
+    NSUInteger elementsCount = dim.w * dim.h * paddedDepth * dim.b;
+    auto& dst = output.second;
+    dst.shape = dim;
+    dst.data.resize(dst.shape.DimensionsProduct());
+    float* outputPointer =
+        reinterpret_cast<float*>([outputBuffers[output.first] contents]);
+    RETURN_IF_ERROR(
+        ConvertFromPHWC4(absl::MakeConstSpan(outputPointer, elementsCount),
+                         dst.shape, absl::MakeSpan(dst.data)));
+  }
+
+  return absl::OkStatus();
+}
+
+MetalExecutionEnvironment::MetalExecutionEnvironment() {
+  device_ = MTLCreateSystemDefaultDevice();
+  std::string device_name = std::string([[device_ name] UTF8String]);
+  GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info_);
+}
+
+std::vector<CalculationsPrecision>
+MetalExecutionEnvironment::GetSupportedPrecisions() const {
+  return {CalculationsPrecision::F32, CalculationsPrecision::F32_F16,
+          CalculationsPrecision::F16};
+}
+
+std::vector<TensorStorageType> MetalExecutionEnvironment::GetSupportedStorages()
+    const {
+  return {TensorStorageType::BUFFER};
+}
+
+// returns storage types that support zero clamping when reading OOB in HW
+// (Height/Width) dimensions.
+std::vector<TensorStorageType>
+MetalExecutionEnvironment::GetSupportedStoragesWithHWZeroClampSupport() const {
+  return {};
+}
+
+absl::Status MetalExecutionEnvironment::ExecuteGPUOperation(
+    const std::vector<TensorFloat32>& src_cpu,
+    std::unique_ptr<ComputeTaskDescriptor>&& operation,
+    const std::vector<BHWC>& dst_sizes,
+    const std::vector<TensorFloat32*>& dst_cpu) {
+  const OperationDef op_def = operation->definition;
+  std::vector<MetalSpatialTensor> src(src_cpu.size());
+  for (int i = 0; i < src_cpu.size(); ++i) {
+    auto src_shape = src_cpu[i].shape;
+    if (src_shape.b != 1 && !op_def.IsBatchSupported()) {
+      return absl::InvalidArgumentError(
+          "Layout doesn't have Batch dimension, but shape.b != 1");
+    }
+    RETURN_IF_ERROR(
+        CreateTensor(device_, src_shape, op_def.src_tensors[i], &src[i]));
+    RETURN_IF_ERROR(src[i].WriteData(src_cpu[i]));
+  }
+
+  std::vector<MetalSpatialTensor> dst(dst_cpu.size());
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    auto dst_shape = dst_sizes[i];
+    if (dst_shape.b != 1 && !op_def.IsBatchSupported()) {
+      return absl::InvalidArgumentError(
+          "Layout doesn't have Batch dimension, but shape.b != 1");
+    }
+    RETURN_IF_ERROR(
+        CreateTensor(device_, dst_shape, op_def.dst_tensors[i], &dst[i]));
+  }
+
+  std::map<ValueId, BHWC> tensor_shapes;
+  NodeDescriptor metal_node;
+  metal_node.task = std::move(operation);
+  metal_node.src_tensors_ids.resize(src_cpu.size());
+  for (int i = 0; i < src_cpu.size(); ++i) {
+    metal_node.src_tensors_ids[i] = i;
+    tensor_shapes[i] = src_cpu[i].shape;
+  }
+  metal_node.dst_tensors_ids.resize(dst_cpu.size());
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    metal_node.dst_tensors_ids[i] = src_cpu.size() + i;
+    tensor_shapes[src_cpu.size() + i] = dst_sizes[i];
+  }
+  metal_node.description = "test_op";
+  metal_node.id = 0;
+
+  std::string buffer_declarations;
+  int index = 0;
+  for (int i = 0; i < metal_node.task->dst_tensors_names.size(); ++i) {
+    buffer_declarations += metal_node.task->dst_tensors_names[i] + "[[buffer(" +
+                           std::to_string(index) + ")]],\n";
+    index++;
+  }
+  for (int i = 0; i < metal_node.task->src_tensors_names.size(); ++i) {
+    buffer_declarations += metal_node.task->src_tensors_names[i] + "[[buffer(" +
+                           std::to_string(index) + ")]],\n";
+    index++;
+  }
+  for (const auto& buffer : metal_node.task->immutable_buffers) {
+    buffer_declarations +=
+        buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
+    index++;
+  }
+  for (const auto& buffer : metal_node.task->uniform_buffers) {
+    buffer_declarations +=
+        buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
+    index++;
+  }
+
+  metal_node.task->shader_source = absl::Substitute(
+      metal_node.task->shader_source, "$0", buffer_declarations + "$1", "");
+
+  ComputeTask gpu_task;
+  RETURN_IF_ERROR(
+      gpu_task.CompileWithDevice(device_, metal_node, op_def.precision));
+  RETURN_IF_ERROR(gpu_task.UpdateParamsWithDevice(device_, tensor_shapes));
+  for (int i = 0; i < src_cpu.size(); ++i) {
+    gpu_task.SetSrcTensor(src[i], i);
+  }
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    gpu_task.SetDstTensor(dst[i], i);
+  }
+
+  id<MTLCommandQueue> command_queue = [device_ newCommandQueue];
+  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
+  id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+  gpu_task.EncodeWithEncoder(encoder);
+  [encoder endEncoding];
+  [command_buffer commit];
+  [command_buffer waitUntilCompleted];
+
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    dst_cpu[i]->shape = dst_sizes[i];
+    dst_cpu[i]->data = std::vector<float>(dst_sizes[i].DimensionsProduct(), 0);
+    RETURN_IF_ERROR(dst[i].ReadData(dst_cpu[i]));
+  }
+
+  return absl::OkStatus();
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
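
The new test_util.cc uploads inputs with ConvertToPHWC4 and sizes every buffer as w * h * AlignByN(c, 4) * b floats, matching the per-float4 indexing (d * H + y) * W + x used by the shaders earlier in this change. A minimal, self-contained sketch of that layout for batch 1 (ToPHWC4 is a hypothetical helper; the layout is inferred from the shader indexing, not taken from the TFLite implementation):

#include <cstddef>
#include <vector>

// HWC float data -> PHWC4: channels padded to a multiple of 4 and grouped into
// slices of 4; element (slice d, row y, column x, lane i) lands at
// ((d * H + y) * W + x) * 4 + i. Padded lanes stay zero.
std::vector<float> ToPHWC4(const std::vector<float>& hwc, int H, int W, int C) {
  const int slices = (C + 3) / 4;
  std::vector<float> out(static_cast<size_t>(slices) * H * W * 4, 0.0f);
  for (int y = 0; y < H; ++y) {
    for (int x = 0; x < W; ++x) {
      for (int c = 0; c < C; ++c) {
        const int d = c / 4;
        const int i = c % 4;
        out[(static_cast<size_t>(d) * H + y) * W * 4 + x * 4 + i] =
            hwc[(static_cast<size_t>(y) * W + x) * C + c];
      }
    }
  }
  return out;
}
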
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h
index 14b64d3..b5740f2 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TEST_UTIL_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TEST_UTIL_H_
 
+#import <Metal/Metal.h>
+
 #include <map>
 #include <vector>
 
@@ -26,7 +28,6 @@
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
 namespace gpu {
@@ -68,6 +69,48 @@
                       const std::map<ValueId, TensorFloat32>& inputs,
                       std::map<ValueId, TensorFloat32>* outputs);
 
+class MetalExecutionEnvironment {
+ public:
+  MetalExecutionEnvironment();
+  ~MetalExecutionEnvironment() = default;
+
+  std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
+  std::vector<TensorStorageType> GetSupportedStorages() const;
+  // returns storage types that support zero clamping when reading OOB in HW
+  // (Height/Width) dimensions.
+  std::vector<TensorStorageType> GetSupportedStoragesWithHWZeroClampSupport()
+      const;
+
+  const GpuInfo& GetGpuInfo() const { return gpu_info_; }
+
+  absl::Status ExecuteGPUOperation(
+      const std::vector<TensorFloat32>& src_cpu,
+      std::unique_ptr<ComputeTaskDescriptor>&& operation,
+      const std::vector<BHWC>& dst_sizes,
+      const std::vector<TensorFloat32*>& dst_cpu);
+
+  absl::Status ExecuteGPUOperation(
+      const TensorFloat32& src_cpu,
+      std::unique_ptr<ComputeTaskDescriptor>&& operation, const BHWC& dst_size,
+      TensorFloat32* result) {
+    return ExecuteGPUOperation(std::vector<TensorFloat32>{src_cpu},
+                               std::move(operation), dst_size, result);
+  }
+
+  absl::Status ExecuteGPUOperation(
+      const std::vector<TensorFloat32>& src_cpu,
+      std::unique_ptr<ComputeTaskDescriptor>&& operation, const BHWC& dst_size,
+      TensorFloat32* result) {
+    return ExecuteGPUOperation(src_cpu, std::move(operation),
+                               std::vector<BHWC>{dst_size},
+                               std::vector<TensorFloat32*>{result});
+  }
+
+ private:
+  id<MTLDevice> device_;
+  GpuInfo gpu_info_;
+};
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
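
The class above replaces the RunGraph-based flow in the deleted test_util.mm below. A minimal usage sketch, modeled on the updated winograd tests later in this patch (the operation, shapes, and src_tensor are assumed to be set up as in those tests):

    // Assumes a filled tflite::gpu::TensorFloat32 src_tensor and a
    // ComputeTaskDescriptor gpu_op built from an OperationDef, e.g.
    // Winograd4x4To36(op_def, attr) as in winograd_test.mm.
    tflite::gpu::metal::MetalExecutionEnvironment env;
    auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(
        std::move(gpu_op));
    tflite::gpu::TensorFloat32 gpu_output;
    auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
                                          tflite::gpu::BHWC(1, 36, 1, 1),
                                          &gpu_output);
    // On success, gpu_output.shape is BHWC(1, 36, 1, 1) and gpu_output.data
    // holds the values read back from the Metal buffer.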
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm
deleted file mode 100644
index 912910c..0000000
--- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm
+++ /dev/null
@@ -1,264 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-
-#import <Metal/Metal.h>
-
-#include <functional>
-#include <map>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/convert.h"
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/shape.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/tensor.h"
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-#include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/metal/api.h"
-#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
-#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
-#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
-#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-SingleOpModel::SingleOpModel(Operation&& operation, const std::vector<TensorRef<BHWC>>& inputs,
-                             const std::vector<TensorRef<BHWC>>& outputs) {
-  auto node = graph_.NewNode();
-  node->operation = std::move(operation);
-
-  for (int i = 0; i < inputs.size(); ++i) {
-    auto input = graph_.NewValue();
-    input->tensor = inputs[i];
-    graph_.AddConsumer(node->id, input->id).IgnoreError();
-    TensorFloat32 tensor;
-    tensor.id = input->tensor.ref;
-    tensor.shape = input->tensor.shape;
-    inputs_.emplace_back(std::move(tensor));
-  }
-
-  for (int i = 0; i < outputs.size(); ++i) {
-    auto output = graph_.NewValue();
-    output->tensor = outputs[i];
-    graph_.SetProducer(node->id, output->id).IgnoreError();
-    TensorFloat32 tensor;
-    tensor.id = output->id;
-    tensor.shape = output->tensor.shape;
-    outputs_.emplace_back(std::move(tensor));
-  }
-}
-
-absl::Status SingleOpModel::Invoke() {
-  std::vector<ValueId> input_ids;
-  input_ids.reserve(inputs_.size());
-  for (const auto& input : inputs_) {
-    input_ids.push_back(input.id);
-  }
-  std::vector<ValueId> output_ids;
-  output_ids.reserve(outputs_.size());
-  std::map<ValueId, BHWC> output_dimensions;
-  for (const auto& output : outputs_) {
-    output_ids.push_back(output.id);
-    output_dimensions[output.id] = output.shape;
-  }
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  std::string device_name = std::string([[device name] UTF8String]);
-  GpuInfo gpu_info;
-  GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
-  RuntimeOptions options;
-  options.storage_precision = RuntimeOptions::Precision::FP32;
-  options.accumulator_precision = RuntimeOptions::Precision::FP32;
-  CompiledModel compiled_model;
-  RETURN_IF_ERROR(Compile(graph_, gpu_info, options, &compiled_model));
-  CompiledModel optimized_model;
-  RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model));
-
-  TFLInferenceContext* graph = [[TFLInferenceContext alloc] init];
-  RETURN_IF_ERROR([graph compileModelWithDevice:device
-                                          model:optimized_model
-                                 inputBufferIDs:input_ids
-                                outputBufferIDs:output_ids
-                                 runtimeOptions:options]);
-  std::map<ValueId, BHWC> input_dimensions;
-  std::map<ValueId, id<MTLBuffer>> input_buffers;
-  for (auto& input : inputs_) {
-    input_dimensions[input.id] = input.shape;
-    NSUInteger elements_count =
-        input.shape.w * input.shape.h * AlignByN(input.shape.c, 4) * input.shape.b;
-    std::vector<float> src_gpu(elements_count);
-    id<MTLBuffer> input_buffer;
-    RETURN_IF_ERROR(
-        ConvertToPHWC4(absl::MakeConstSpan(input.data), input.shape, absl::MakeSpan(src_gpu)));
-    input_buffer = [device newBufferWithBytes:src_gpu.data()
-                                       length:(elements_count * sizeof(float))
-                                      options:MTLResourceStorageModeShared];
-    input_buffers[input.id] = input_buffer;
-  }
-
-  std::map<ValueId, id<MTLBuffer>> output_buffers;
-  for (const auto& outputDimension : output_dimensions) {
-    // Uninitialized output buffer.
-    const ValueId key = outputDimension.first;
-    const BHWC& dims = outputDimension.second;
-    const NSUInteger size = dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
-    output_buffers[key] = [device newBufferWithLength:size options:MTLResourceStorageModeShared];
-  }
-
-  // Inference itself.
-  std::map<ValueId, id<MTLBuffer>> inout_buffers(input_buffers.begin(), input_buffers.end());
-  inout_buffers.insert(output_buffers.begin(), output_buffers.end());
-  id<MTLCommandQueue> command_queue = [device newCommandQueue];
-  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
-  id<MTLComputeCommandEncoder> command_encoder = [command_buffer computeCommandEncoder];
-  [graph encodeWithEncoder:command_encoder inputOutputBuffers:inout_buffers encoderBlock:nil];
-  [command_encoder endEncoding];
-  [command_buffer commit];
-  [command_buffer waitUntilCompleted];
-
-  for (auto& output : outputs_) {
-    const auto& dim = output_dimensions[output.id];
-    NSUInteger elements_count = dim.w * dim.h * AlignByN(dim.c, 4) * dim.b;
-    output.shape = dim;
-    output.data.resize(output.shape.DimensionsProduct());
-    float* output_pointer = reinterpret_cast<float*>([output_buffers[output.id] contents]);
-    RETURN_IF_ERROR(ConvertFromPHWC4(absl::MakeConstSpan(output_pointer, elements_count),
-                                     output.shape, absl::MakeSpan(output.data)));
-  }
-  return absl::OkStatus();
-}
-
-absl::Status CompareVectors(const std::vector<float>& reference, const std::vector<float>& output,
-                            float max_error) {
-  if (reference.size() != output.size()) {
-    const std::string message = "CompareVectors: vectors size does not match for reference: " +
-                                std::to_string(reference.size()) +
-                                " vs. output: " + std::to_string(output.size());
-    return absl::InternalError(message);
-  }
-  for (int i = 0; i < reference.size(); i++) {
-    float error = std::abs(reference[i] - output[i]);
-    if (error > max_error) {
-      const std::string message =
-          "Reference: " + std::to_string(reference[i]) + ", output: " + std::to_string(output[i]) +
-          ", error: " + std::to_string(error) + ", max allowed error: " + std::to_string(max_error);
-      return absl::InternalError(message);
-    }
-  }
-  return absl::OkStatus();
-}
-
-absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> device,
-                      const std::map<ValueId, TensorFloat32>& inputs,
-                      std::map<ValueId, TensorFloat32>* outputs) {
-  std::vector<ValueId> inputBufferIDs;
-  inputBufferIDs.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    inputBufferIDs.push_back(input.first);
-  }
-  std::vector<ValueId> outputBufferIDs;
-  outputBufferIDs.reserve(outputs->size());
-  for (const auto& output : *outputs) {
-    outputBufferIDs.push_back(output.first);
-  }
-  std::map<ValueId, BHWC> outputDimensions;
-  CompiledModel raw_model;
-  raw_model.nodes = nodes;
-  for(const auto& input : inputs) {
-    raw_model.tensor_shapes[input.first] = input.second.shape;
-  }
-  for(const auto& output : *outputs) {
-    outputDimensions[output.first] = output.second.shape;
-    raw_model.tensor_shapes[output.first] = output.second.shape;
-  }
-  CompiledModel optimized_model;
-  RETURN_IF_ERROR(
-      ValidateOptimizeModel(inputBufferIDs, outputBufferIDs, raw_model, &optimized_model));
-
-  RuntimeOptions options;
-  options.storage_precision = RuntimeOptions::Precision::FP32;
-  options.accumulator_precision = RuntimeOptions::Precision::FP32;
-
-  TFLInferenceContext* graph = [[TFLInferenceContext alloc] init];
-  RETURN_IF_ERROR([graph compileModelWithDevice:device
-                                          model:optimized_model
-                                 inputBufferIDs:inputBufferIDs
-                                outputBufferIDs:outputBufferIDs
-                                 runtimeOptions:options]);
-  std::map<ValueId, BHWC> inputDimensions;
-  std::map<ValueId, std::vector<float>> inputBuffersCPU;
-  std::map<ValueId, id<MTLBuffer>> inputBuffersGPU;
-  for (auto& input : inputs) {
-    const auto& src = input.second;
-    inputDimensions[input.first] = src.shape;
-    const int paddedDepth = AlignByN(src.shape.c, 4);
-    NSUInteger elementsCount = src.shape.w * src.shape.h * paddedDepth * src.shape.b;
-    std::vector<float> src_gpu(elementsCount);
-    id<MTLBuffer> inputBuffer;
-    RETURN_IF_ERROR(
-        ConvertToPHWC4(absl::MakeConstSpan(src.data), src.shape, absl::MakeSpan(src_gpu)));
-    inputBuffer = [device newBufferWithBytes:src_gpu.data()
-                                      length:(elementsCount * sizeof(float))
-                                     options:MTLResourceStorageModeShared];
-    inputBuffersGPU[input.first] = inputBuffer;
-  }
-
-  std::map<ValueId, id<MTLBuffer>> outputBuffers;
-  for (const auto& outputDimension : outputDimensions) {
-    // Uninitialized output buffer.
-    const ValueId key = outputDimension.first;
-    const BHWC& dims = outputDimension.second;
-    const NSUInteger outputDataSize =
-        dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
-    outputBuffers[key] = [device newBufferWithLength:outputDataSize
-                                             options:MTLResourceStorageModeShared];
-  }
-
-  // Inference itself.
-  std::map<ValueId, id<MTLBuffer>> inputOutputBuffers(inputBuffersGPU.begin(),
-                                                      inputBuffersGPU.end());
-  inputOutputBuffers.insert(outputBuffers.begin(), outputBuffers.end());
-  id<MTLCommandQueue> commandQueue = [device newCommandQueue];
-  id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
-  id<MTLComputeCommandEncoder> commandEncoder = [commandBuffer computeCommandEncoder];
-  [graph encodeWithEncoder:commandEncoder inputOutputBuffers:inputOutputBuffers encoderBlock:nil];
-  [commandEncoder endEncoding];
-  [commandBuffer commit];
-  [commandBuffer waitUntilCompleted];
-
-  for (auto& output : *outputs) {
-    const auto& dim = outputDimensions[output.first];
-    const int paddedDepth = AlignByN(dim.c, 4);
-    NSUInteger elementsCount = dim.w * dim.h * paddedDepth * dim.b;
-    auto& dst = output.second;
-    dst.shape = dim;
-    dst.data.resize(dst.shape.DimensionsProduct());
-    float* outputPointer = reinterpret_cast<float*>([outputBuffers[output.first] contents]);
-    RETURN_IF_ERROR(ConvertFromPHWC4(absl::MakeConstSpan(outputPointer, elementsCount), dst.shape,
-                                     absl::MakeSpan(dst.data)));
-  }
-
-  return absl::OkStatus();
-}
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm
index 3d716ec..dd5f412 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm
@@ -27,7 +27,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 using ::tflite::gpu::ConvolutionTransposedAttributes;
 using ::tflite::gpu::BHWC;
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
index 90d6c2e..ac053da 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm
@@ -26,7 +26,6 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
 
 using ::tflite::gpu::BHWC;
@@ -93,22 +92,15 @@
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd4x4To36(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};
 
-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 36, 1, 1);
-  outputs[1].data.resize(36, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 36, 1, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 
-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }
 
@@ -151,10 +143,6 @@
     }
   }
 
-  tflite::gpu::metal::RuntimeOptions options;
-  options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-  options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-
   tflite::gpu::metal::Winograd4x4To36Attributes attr;
   attr.padding.prepended = tflite::gpu::HW(1, 1);
   attr.padding.appended = tflite::gpu::HW(1, 1);
@@ -167,22 +155,15 @@
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd4x4To36TileX6(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};
 
-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 36, 1, 1);
-  outputs[1].data.resize(36, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 36, 1, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 
-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }
 
@@ -229,10 +210,6 @@
   attr.biases.shape = tflite::gpu::Linear(1);
   attr.biases.data.resize(1, 0.0f);
 
-  tflite::gpu::metal::RuntimeOptions options;
-  options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-  options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-
   tflite::gpu::OperationDef op_def;
   op_def.precision = tflite::gpu::CalculationsPrecision::F32;
   tflite::gpu::TensorDescriptor tensor_descriptor = tflite::gpu::TensorDescriptor{
@@ -242,22 +219,15 @@
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd36To4x4(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};
 
-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 4, 4, 1);
-  outputs[1].data.resize(16, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 4, 4, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 
-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-5f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-5f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }
 
@@ -304,10 +274,6 @@
   attr.biases.shape = tflite::gpu::Linear(1);
   attr.biases.data.resize(1, 0.0f);
 
-  tflite::gpu::metal::RuntimeOptions options;
-  options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-  options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
-
   tflite::gpu::OperationDef op_def;
   op_def.precision = tflite::gpu::CalculationsPrecision::F32;
   tflite::gpu::TensorDescriptor tensor_descriptor = tflite::gpu::TensorDescriptor{
@@ -317,22 +283,15 @@
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd36To4x4Tile4x1(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};
 
-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 4, 4, 1);
-  outputs[1].data.resize(16, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 4, 4, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 
-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }
 
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
similarity index 90%
rename from tensorflow/lite/delegates/gpu/metal/metal_arguments.mm
rename to tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
index d5c7671..400f001 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
@@ -135,7 +135,8 @@
 // Static
 constexpr char MetalArguments::kArgsPrefix[];
 
-absl::Status MetalArguments::Init(id<MTLDevice> device, int buffer_offset, Arguments* args, std::string* code) {
+absl::Status MetalArguments::Init(id<MTLDevice> device, int buffer_offset,
+                                  Arguments* args, std::string* code) {
   RETURN_IF_ERROR(AllocateObjects(*args, device));
   RETURN_IF_ERROR(AddObjectArgs(args));
   RETURN_IF_ERROR(ResolveSelectorsPass(*args, {}, code));
@@ -174,13 +175,15 @@
     const_data_.resize(aligned_pos * 4);
     for (auto& it : float_values_) {
       if (it.second.active) {
-        float* ptr = reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
+        float* ptr =
+            reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
         *ptr = it.second.value;
       }
     }
     for (auto& it : int_values_) {
       if (it.second.active) {
-        int32_t* ptr = reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
+        int32_t* ptr =
+            reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
         *ptr = it.second.value;
       }
     }
@@ -200,7 +203,8 @@
   }
   it->second.value = value;
   if (it->second.active) {
-    int32_t* ptr = reinterpret_cast<int32_t*>(&const_data_[it->second.bytes_offset]);
+    int32_t* ptr =
+        reinterpret_cast<int32_t*>(&const_data_[it->second.bytes_offset]);
     *ptr = value;
   }
   return absl::OkStatus();
@@ -213,23 +217,28 @@
   }
   it->second.value = value;
   if (it->second.active) {
-    float* ptr = reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
+    float* ptr =
+        reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
     *ptr = value;
   }
   return absl::OkStatus();
 }
 
 absl::Status MetalArguments::SetHalf(const std::string& name, half value) {
-  return absl::UnimplementedError("No support of half uniforms in Metal backend");
+  return absl::UnimplementedError(
+      "No support of half uniforms in Metal backend");
 }
 
-void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const {
+void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder,
+                            int buffer_offset) const {
   for (auto& b : buffers_) {
     [encoder setBuffer:b.second.handle offset:0 atIndex:buffer_offset];
     buffer_offset++;
   }
   if (!const_data_.empty()) {
-    [encoder setBytes:const_data_.data() length:const_data_.size() atIndex:buffer_offset];
+    [encoder setBytes:const_data_.data()
+               length:const_data_.size()
+              atIndex:buffer_offset];
   }
 }
 
@@ -262,17 +271,18 @@
       attributes += absl::StrCat("  __attribute__((", attr, "))");
     }
     AppendArgument(
-        absl::StrCat(
-            MemoryTypeToMetalType(t.second.desc.memory_type), " ",
-            ToMetalDataType(t.second.desc.data_type, t.second.desc.element_size),
-            "* ", t.first, "[[buffer(", buffer_offset, ")]]", attributes),
+        absl::StrCat(MemoryTypeToMetalType(t.second.desc.memory_type), " ",
+                     ToMetalDataType(t.second.desc.data_type,
+                                     t.second.desc.element_size),
+                     "* ", t.first, "[[buffer(", buffer_offset, ")]]",
+                     attributes),
         &result);
     buffer_offset++;
   }
   if (!const_data_.empty()) {
-    AppendArgument(
-        absl::StrCat("constant uniforms_buffer& U[[buffer(", buffer_offset, ")]]"),
-        &result);
+    AppendArgument(absl::StrCat("constant uniforms_buffer& U[[buffer(",
+                                buffer_offset, ")]]"),
+                   &result);
     buffer_offset++;
   }
   if (!result.empty()) {
@@ -295,7 +305,8 @@
   return absl::OkStatus();
 }
 
-void MetalArguments::AddBuffer(const std::string& name, const GPUBufferDescriptor& desc) {
+void MetalArguments::AddBuffer(const std::string& name,
+                               const GPUBufferDescriptor& desc) {
   buffers_[name].desc = desc;
 }
 
@@ -313,7 +324,8 @@
   }
 }
 
-absl::Status MetalArguments::SetBuffer(const std::string& name, id<MTLBuffer> handle) {
+absl::Status MetalArguments::SetBuffer(const std::string& name,
+                                       id<MTLBuffer> handle) {
   auto it = buffers_.find(name);
   if (it == buffers_.end()) {
     return absl::NotFoundError(
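
As the reflowed Encode above shows, MetalArguments binds its buffer arguments at consecutive indices starting from buffer_offset and, when scalar uniforms exist, appends them with a single setBytes call at the next index. A hedged sketch of a call site (reserved_buffers, groups_count, and group_size are illustrative names, not from the patch):

    // The operation's own src/dst tensors occupy indices [0, reserved_buffers).
    id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
    // ... setBuffer:...atIndex: for the tensors at indices 0..reserved_buffers-1 ...
    metal_args.Encode(encoder, /*buffer_offset=*/reserved_buffers);
    [encoder dispatchThreadgroups:groups_count threadsPerThreadgroup:group_size];
    [encoder endEncoding];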
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.mm b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
similarity index 80%
rename from tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.mm
rename to tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
index 7e923dc..f021dc4 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.mm
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
 
+#include <memory>
+
 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
 
 namespace tflite {
@@ -27,7 +29,7 @@
                                   const void* data_ptr, id<MTLBuffer>* buffer) {
   const int slices = DivideRoundUp(shape.c, 4);
   switch (descriptor.storage_type) {
-    case TensorStorageType::BUFFER:{
+    case TensorStorageType::BUFFER: {
       const size_t data_size = shape.b * shape.w * shape.h * shape.d * slices *
                                4 * SizeOf(descriptor.data_type);
       if (data_ptr) {
@@ -51,8 +53,8 @@
 }
 
 absl::Status CreateTensor(id<MTLDevice> device, const BHWDC& shape,
-                          const TensorDescriptor& descriptor, id<MTLBuffer> buffer,
-                          MetalSpatialTensor* result) {
+                          const TensorDescriptor& descriptor,
+                          id<MTLBuffer> buffer, MetalSpatialTensor* result) {
   const bool memory_owner = buffer == nullptr;
   if (memory_owner) {
     RETURN_IF_ERROR(
@@ -64,15 +66,17 @@
 }
 }  // namespace
 
-MetalSpatialTensor::MetalSpatialTensor(id<MTLBuffer> buffer, bool memory_owner, const BHWC& shape,
-               const TensorDescriptor& descriptor)
+MetalSpatialTensor::MetalSpatialTensor(id<MTLBuffer> buffer, bool memory_owner,
+                                       const BHWC& shape,
+                                       const TensorDescriptor& descriptor)
     : memory_(buffer),
       memory_owner_(memory_owner),
       shape_(shape.b, shape.h, shape.w, 1, shape.c),
       descriptor_(descriptor) {}
 
-MetalSpatialTensor::MetalSpatialTensor(id<MTLBuffer> buffer, bool memory_owner, const BHWDC& shape,
-               const TensorDescriptor& descriptor)
+MetalSpatialTensor::MetalSpatialTensor(id<MTLBuffer> buffer, bool memory_owner,
+                                       const BHWDC& shape,
+                                       const TensorDescriptor& descriptor)
     : memory_(buffer),
       memory_owner_(memory_owner),
       shape_(shape),
@@ -103,8 +107,9 @@
   }
 }
 
-absl::Status MetalSpatialTensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
-                                     GPUResourcesWithValue* resources) const {
+absl::Status MetalSpatialTensor::GetGPUResources(
+    const GPUObjectDescriptor* obj_ptr,
+    GPUResourcesWithValue* resources) const {
   const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(obj_ptr);
   if (buffer_desc) {
     if (descriptor_.storage_type != TensorStorageType::BUFFER) {
@@ -234,25 +239,23 @@
              : AlignByN(shape_.c, 4);
 }
 
-absl::Status MetalSpatialTensor::WriteDataBHWDC(absl::Span<const float> in) {
+absl::Status MetalSpatialTensor::WriteDataBHWDC(const float* in) {
   void* data_ptr = nullptr;
   const int aligned_channels = GetAlignedChannels();
   const int elements_count =
       shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
 
   const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
-  std::vector<float> data_f;
-  std::vector<half> data_h;
+  std::unique_ptr<float[]> data_f;
+  std::unique_ptr<half[]> data_h;
   if (descriptor_.data_type == DataType::FLOAT32) {
-    data_f.resize(elements_count);
-    data_ptr = data_f.data();
-    DataFromBHWDC(in, shape_, descriptor_,
-                  absl::MakeSpan(data_f.data(), data_f.size()));
+    data_f.reset(new float[elements_count]);
+    data_ptr = data_f.get();
+    DataFromBHWDC(in, shape_, descriptor_, data_f.get());
   } else {
-    data_h.resize(elements_count);
-    data_ptr = data_h.data();
-    DataFromBHWDC(in, shape_, descriptor_,
-                  absl::MakeSpan(data_h.data(), data_h.size()));
+    data_h.reset(new half[elements_count]);
+    data_ptr = data_h.get();
+    DataFromBHWDC(in, shape_, descriptor_, data_h.get());
   }
 
   switch (descriptor_.storage_type) {
@@ -273,38 +276,38 @@
 
 absl::Status MetalSpatialTensor::WriteData(const TensorFloat32& src) {
   RETURN_IF_ERROR(IsValid(src.shape));
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data));
+  return WriteDataBHWDC(src.data.data());
 }
 
 absl::Status MetalSpatialTensor::WriteData(
     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data));
+  return WriteDataBHWDC(src.data.data());
 }
 
 absl::Status MetalSpatialTensor::WriteData(
     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data));
+  return WriteDataBHWDC(src.data.data());
 }
 
 absl::Status MetalSpatialTensor::WriteData(const Tensor5DFloat32& src) {
   RETURN_IF_ERROR(IsValid(src.shape));
-  return WriteDataBHWDC(absl::MakeConstSpan(src.data));
+  return WriteDataBHWDC(src.data.data());
 }
 
-absl::Status MetalSpatialTensor::ReadDataBHWDC(absl::Span<float> out) const {
+absl::Status MetalSpatialTensor::ReadDataBHWDC(float* out) const {
   void* data_ptr = nullptr;
   const int aligned_channels = GetAlignedChannels();
   const int elements_count =
       shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
   const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
-  std::vector<float> data_f;
-  std::vector<half> data_h;
+  std::unique_ptr<float[]> data_f;
+  std::unique_ptr<half[]> data_h;
   if (descriptor_.data_type == DataType::FLOAT32) {
-    data_f.resize(elements_count);
-    data_ptr = data_f.data();
+    data_f.reset(new float[elements_count]);
+    data_ptr = data_f.get();
   } else {
-    data_h.resize(elements_count);
-    data_ptr = data_h.data();
+    data_h.reset(new half[elements_count]);
+    data_ptr = data_h.get();
   }
 
   switch (descriptor_.storage_type) {
@@ -321,11 +324,9 @@
   }
 
   if (descriptor_.data_type == DataType::FLOAT32) {
-    DataToBHWDC(absl::MakeConstSpan(data_f.data(), data_f.size()), shape_,
-                descriptor_, out);
+    DataToBHWDC(data_f.get(), shape_, descriptor_, out);
   } else {
-    DataToBHWDC(absl::MakeConstSpan(data_h.data(), data_h.size()), shape_,
-                descriptor_, out);
+    DataToBHWDC(data_h.get(), shape_, descriptor_, out);
   }
 
   return absl::OkStatus();
@@ -333,15 +334,16 @@
 
 absl::Status MetalSpatialTensor::ReadData(TensorFloat32* dst) const {
   RETURN_IF_ERROR(IsValid(dst->shape));
-  return ReadDataBHWDC(absl::MakeSpan(dst->data));
+  return ReadDataBHWDC(dst->data.data());
 }
 
 absl::Status MetalSpatialTensor::ReadData(Tensor5DFloat32* dst) const {
   RETURN_IF_ERROR(IsValid(dst->shape));
-  return ReadDataBHWDC(absl::MakeSpan(dst->data));
+  return ReadDataBHWDC(dst->data.data());
 }
 
-absl::Status MetalSpatialTensor::CreateFromDescriptor(const TensorDescriptor& desc, id<MTLDevice> device) {
+absl::Status MetalSpatialTensor::CreateFromDescriptor(
+    const TensorDescriptor& desc, id<MTLDevice> device) {
   shape_ = desc.shape;
   descriptor_.data_type = desc.data_type;
   descriptor_.storage_type = desc.storage_type;
@@ -357,17 +359,38 @@
   return absl::OkStatus();
 }
 
+void MetalSpatialTensor::SetBufferHandle(id<MTLBuffer> buffer) {
+  memory_ = buffer;
+}
+
+id<MTLBuffer> MetalSpatialTensor::GetBufferHandle() const { return memory_; }
+
 absl::Status CreateTensor(id<MTLDevice> device, const BHWC& shape,
-                          const TensorDescriptor& descriptor, MetalSpatialTensor* result) {
+                          const TensorDescriptor& descriptor,
+                          MetalSpatialTensor* result) {
   const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
   return CreateTensor(device, shape5D, descriptor, nullptr, result);
 }
 
 absl::Status CreateTensor(id<MTLDevice> device, const BHWDC& shape,
-                          const TensorDescriptor& descriptor, MetalSpatialTensor* result) {
+                          const TensorDescriptor& descriptor,
+                          MetalSpatialTensor* result) {
   return CreateTensor(device, shape, descriptor, nullptr, result);
 }
 
+MetalSpatialTensor CreateSharedBufferTensor(
+    id<MTLBuffer> buffer, const BHWC& shape,
+    const TensorDescriptor& descriptor) {
+  const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
+  return MetalSpatialTensor(buffer, false, shape5D, descriptor);
+}
+
+MetalSpatialTensor CreateSharedBufferTensor(
+    id<MTLBuffer> buffer, const BHWDC& shape,
+    const TensorDescriptor& descriptor) {
+  return MetalSpatialTensor(buffer, false, shape, descriptor);
+}
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
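
The new CreateSharedBufferTensor helpers build a non-owning MetalSpatialTensor around an externally allocated MTLBuffer, and SetBufferHandle/GetBufferHandle allow that buffer to be swapped later. A rough sketch, under the assumption that the buffer holds (or will receive) data in the tensor's PHWC4 layout:

    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    const tflite::gpu::BHWC shape(1, 8, 8, 3);
    tflite::gpu::TensorDescriptor desc{tflite::gpu::DataType::FLOAT32,
                                       tflite::gpu::TensorStorageType::BUFFER,
                                       tflite::gpu::Layout::HWC};
    const size_t size_bytes = shape.b * shape.h * shape.w *
                              tflite::gpu::AlignByN(shape.c, 4) * sizeof(float);
    id<MTLBuffer> buffer = [device newBufferWithLength:size_bytes
                                               options:MTLResourceStorageModeShared];
    auto tensor =
        tflite::gpu::metal::CreateSharedBufferTensor(buffer, shape, desc);
    // The tensor does not take ownership; it can later be pointed at a
    // different buffer with tensor.SetBufferHandle(other_buffer).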
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
index 791aae1..b5d9a28 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
+++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h
@@ -18,7 +18,6 @@
 
 #import <Metal/Metal.h>
 
-#include "absl/types/span.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
@@ -74,12 +73,15 @@
   absl::Status CreateFromDescriptor(const TensorDescriptor& desc,
                                     id<MTLDevice> device);
 
+  void SetBufferHandle(id<MTLBuffer> buffer);
+  id<MTLBuffer> GetBufferHandle() const;
+
  private:
   absl::Status IsValid(const BHWC& shape) const;
   absl::Status IsValid(const BHWDC& shape) const;
 
-  absl::Status WriteDataBHWDC(absl::Span<const float> in);
-  absl::Status ReadDataBHWDC(absl::Span<float> out) const;
+  absl::Status WriteDataBHWDC(const float* in);
+  absl::Status ReadDataBHWDC(float* out) const;
 
   int GetAlignedChannels() const;
   int3 GetFullTensorRegion() const;
@@ -99,6 +101,14 @@
                           const TensorDescriptor& descriptor,
                           MetalSpatialTensor* result);
 
+MetalSpatialTensor CreateSharedBufferTensor(id<MTLBuffer> buffer,
+                                            const BHWC& shape,
+                                            const TensorDescriptor& descriptor);
+
+MetalSpatialTensor CreateSharedBufferTensor(id<MTLBuffer> buffer,
+                                            const BHWDC& shape,
+                                            const TensorDescriptor& descriptor);
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/runtime_options.h b/tensorflow/lite/delegates/gpu/metal/runtime_options.h
deleted file mode 100644
index d8e8fe3..0000000
--- a/tensorflow/lite/delegates/gpu/metal/runtime_options.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-struct RuntimeOptions {
-  enum class Precision {
-    FP16,
-    FP32,
-  };
-  // Buffer storage format. If FP32 then accumulator must be FP32.
-  Precision storage_precision = Precision::FP32;
-  // Accumulator precision. Defines the precision for convolutions.
-  Precision accumulator_precision = Precision::FP32;
-};
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
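
With runtime_options.h removed, the FP16/FP32 choice that RuntimeOptions carried now travels with each operation through OperationDef::precision (the test environment additionally exposes GetSupportedPrecisions()). A minimal sketch of the replacement pattern used by the updated tests above:

    tflite::gpu::OperationDef op_def;
    op_def.precision = tflite::gpu::CalculationsPrecision::F32;  // or F16 / F32_F16
    tflite::gpu::TensorDescriptor tensor_descriptor{
        tflite::gpu::DataType::FLOAT32, tflite::gpu::TensorStorageType::BUFFER,
        tflite::gpu::Layout::HWC};
    op_def.src_tensors.push_back(tensor_descriptor);
    op_def.dst_tensors.push_back(tensor_descriptor);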
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/BUILD b/tensorflow/lite/delegates/gpu/metal/selectors/BUILD
new file mode 100644
index 0000000..9f5b885
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/BUILD
@@ -0,0 +1,46 @@
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "default_selector",
+    hdrs = ["default_selector.h"],
+    deps = [
+        ":subgraph",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/metal/selectors/default:default_selector",  # buildcleaner: keep
+    ],
+)
+
+cc_library(
+    name = "operation_selector",
+    srcs = ["operation_selector.cc"],
+    hdrs = ["operation_selector.h"],
+    deps = [
+        ":default_selector",
+        ":subgraph",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:precision",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common:winograd_util",
+        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal/kernels",
+    ],
+)
+
+cc_library(
+    name = "subgraph",
+    srcs = ["subgraph.cc"],
+    hdrs = ["subgraph.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+    ],
+)
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/default/BUILD b/tensorflow/lite/delegates/gpu/metal/selectors/default/BUILD
new file mode 100644
index 0000000..a9f1dde
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/default/BUILD
@@ -0,0 +1,16 @@
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "default_selector",
+    srcs = ["default_selector.cc"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/metal/selectors:subgraph",
+        "@com_google_absl//absl/strings",
+    ],
+)
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/metal/selectors/default/default_selector.cc
similarity index 80%
copy from tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
copy to tensorflow/lite/delegates/gpu/metal/selectors/default/default_selector.cc
index a7d94fa..eb8a8f7 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/default/default_selector.cc
@@ -16,25 +16,23 @@
 #include <memory>
 
 #include "absl/strings/str_cat.h"
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
+namespace metal {
 
 absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def,
-                           ModelHints hints, const std::vector<Value*>& inputs,
+                           const std::vector<Value*>& inputs,
                            const std::vector<Value*>& outputs, const Node& node,
                            GPUOperationsSubgraph* gpu_subgraph) {
   return absl::UnimplementedError(
       absl::StrCat("No selector for ", node.operation.type));
 }
 
-}  // namespace cl
+}  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
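
The Metal copy of the default selector, like its CL counterpart, is a hook: open-source builds fall through to UnimplementedError, while an alternative default_selector target can populate the subgraph instead. A hypothetical override, assuming the InitSingleOpSubgraph helper from selectors/subgraph.h (the Add fallback is purely illustrative):

    absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def,
                               const std::vector<Value*>& inputs,
                               const std::vector<Value*>& outputs, const Node& node,
                               GPUOperationsSubgraph* gpu_subgraph) {
      std::unique_ptr<ComputeTaskDescriptor>* task =
          InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
      auto gpu_op = Add(op_def);  // Illustrative only; a real selector would
                                  // dispatch on node.operation.type.
      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
      return absl::OkStatus();
    }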
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/metal/selectors/default_selector.h
similarity index 65%
copy from tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
copy to tensorflow/lite/delegates/gpu/metal/selectors/default_selector.h
index 1efa215..75033c3 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/default_selector.h
@@ -13,29 +13,26 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_DEFAULT_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_DEFAULT_SELECTOR_H_
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
+namespace metal {
 
 absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def,
-                           ModelHints hints, const std::vector<Value*>& inputs,
+                           const std::vector<Value*>& inputs,
                            const std::vector<Value*>& outputs, const Node& node,
                            GPUOperationsSubgraph* gpu_subgraph);
 
-}  // namespace cl
+}  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_DEFAULT_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
new file mode 100644
index 0000000..2556a26
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
@@ -0,0 +1,529 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h"
+
+#include <vector>
+
+#include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/relu.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/reshape.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/default_selector.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+namespace {
+
+std::unique_ptr<ComputeTaskDescriptor> SelectDepthWiseConv(
+    const OperationDef& op_def, const DepthwiseConvolution2DAttributes& attr) {
+  if (CheckDepthWiseConv3x3Stride1x1Support(attr)) {
+    auto gpu_op = DepthWiseConv3x3Stride1x1(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else if (CheckDepthWiseConv3x3Stride2Support(attr)) {
+    auto gpu_op = DepthWiseConv3x3Stride2(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else {
+    auto gpu_op = DepthWiseConvolution(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectConvolutionTransposed(
+    const OperationDef& op_def, const ConvolutionTransposedAttributes& attr,
+    const GpuInfo& gpu_info) {
+  if (CheckConvolutionTransposed4x4Support(attr)) {
+    auto gpu_op = ConvolutionTransposed4x4(op_def, attr, gpu_info);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else {
+    auto gpu_op = ConvolutionTransposed(op_def, attr, gpu_info);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectQuantizeAndDequantize(
+    const OperationDef& op_def, const QuantizeAndDequantizeAttributes& attr) {
+  auto gpu_op = QuantizeAndDequantize(op_def, attr);
+  return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectPReLU(
+    const OperationDef& op_def, const BHWC& src_shape,
+    const PReLUAttributes& attr) {
+  auto alpha = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
+  if (alpha) {
+    auto gpu_op = PReLU(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+  auto alpha3d = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
+  if (!alpha3d) {
+    return {};
+  }
+  if (alpha3d->shape.h != src_shape.h || alpha3d->shape.w != src_shape.w ||
+      alpha3d->shape.c != src_shape.c) {
+    return {};
+  }
+  auto gpu_op = PReLUFull(op_def, attr);
+  return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectReshape(
+    const OperationDef& op_def, const BHWC& src_shape,
+    const ReshapeAttributes& attr) {
+  if (src_shape.c % 4 == 0 && attr.new_shape.c % 4 == 0) {
+    auto gpu_op = Reshapex4(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else {
+    auto gpu_op = Reshape(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectSoftmax(const OperationDef& op_def,
+                                                     const BHWC& src_shape,
+                                                     const GpuInfo& gpu_info) {
+  if (src_shape.w == 1 && src_shape.h == 1) {
+    auto gpu_op = Softmax1x1(op_def, gpu_info);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else {
+    auto gpu_op = Softmax(op_def);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectSpaceToDepth(
+    const OperationDef& op_def, const SpaceToDepthAttributes& attr) {
+  auto gpu_op = SpaceToDepth(op_def, attr);
+  return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectWinograd4x4To36(
+    const OperationDef& op_def, const Winograd4x4To36Attributes& attr,
+    const GpuInfo& gpu_info) {
+  if (gpu_info.IsApple()) {
+    auto gpu_op = Winograd4x4To36(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else {
+    auto gpu_op = Winograd4x4To36TileX6(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+}
+
+std::unique_ptr<ComputeTaskDescriptor> SelectWinograd36To4x4(
+    const OperationDef& op_def, const Winograd36To4x4Attributes& attr,
+    const GpuInfo& gpu_info) {
+  if (gpu_info.IsApple()) {
+    auto gpu_op = Winograd36To4x4(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  } else {
+    auto gpu_op = Winograd36To4x4Tile4x1(op_def, attr);
+    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  }
+}
+
+bool IsRecommendedForWinograd4x4To6x6(const Convolution2DAttributes& attr,
+                                      const GpuInfo& gpu_info,
+                                      const BHWC& dst_shape) {
+  const int tiles_x = DivideRoundUp(dst_shape.w, 4);
+  const int tiles_y = DivideRoundUp(dst_shape.h, 4);
+  const int total_tiles = tiles_x * tiles_y;
+  const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+  int min_depth = 16;
+  const int min_tiles = 32;
+  if (total_tiles >= min_tiles * 8) {
+    min_depth /= 4;
+    min_depth = std::max(min_depth, 8);
+  } else if (total_tiles >= min_tiles * 4) {
+    min_depth /= 2;
+    min_depth = std::max(min_depth, 8);
+  }
+  const bool recommended_channels =
+      src_depth >= min_depth && dst_depth >= min_depth;
+  const bool recommended_hw = total_tiles >= min_tiles;
+  return recommended_channels && recommended_hw;
+}
+
+absl::Status WinogradFromNode(const GpuInfo& gpu_info,
+                              const std::vector<Value*>& inputs,
+                              const std::vector<Value*>& outputs,
+                              const OperationDef& op_def,
+                              const BHWC& input_shape, const BHWC& output_shape,
+                              const Convolution2DAttributes& attr,
+                              GPUOperationsSubgraph* gpu_subgraph) {
+  if (!IsSuitableForWinograd4x4To6x6(attr)) {
+    return absl::UnimplementedError("No implementation for this case.");
+  }
+  if (!IsRecommendedForWinograd4x4To6x6(attr, gpu_info, output_shape)) {
+    return absl::UnimplementedError("Not recommended for this case.");
+  }
+
+  const int tiles_x = DivideRoundUp(output_shape.w, 4);
+  const int tiles_y = DivideRoundUp(output_shape.h, 4);
+  const BHWC shape_0{input_shape.b, 36, tiles_x * tiles_y, input_shape.c};
+  const BHWC shape_1{input_shape.b, 36, tiles_x * tiles_y, output_shape.c};
+  TensorDescriptor tensor_desc = op_def.src_tensors[0];
+  gpu_subgraph->new_tensors = {{shape_0, tensor_desc}, {shape_1, tensor_desc}};
+  gpu_subgraph->operations.clear();
+  gpu_subgraph->operations.resize(3);
+
+  OperationDef winograd_up_def;
+  winograd_up_def.precision = op_def.precision;
+  winograd_up_def.src_tensors.push_back(op_def.src_tensors[0]);
+  winograd_up_def.dst_tensors.push_back(op_def.src_tensors[0]);
+  auto& winograd_up = gpu_subgraph->operations[0];
+  Winograd4x4To36Attributes wino_up_attr;
+  wino_up_attr.padding = attr.padding;
+  winograd_up.operation =
+      SelectWinograd4x4To36(winograd_up_def, wino_up_attr, gpu_info);
+  winograd_up.input_ids = {static_cast<int>(inputs[0]->id)};
+  winograd_up.output_ids = {-1};
+
+  OperationDef conv_def;
+  conv_def.precision = op_def.precision;
+  conv_def.src_tensors.push_back(op_def.src_tensors[0]);
+  conv_def.dst_tensors.push_back(op_def.src_tensors[0]);
+  auto& conv = gpu_subgraph->operations[1];
+  conv.input_ids = {-1};
+  conv.output_ids = {-2};
+  auto gpu_op = ConvolutionWino4x4To6x6(conv_def, shape_1, attr, gpu_info);
+  conv.operation = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+  OperationDef winograd_down_def;
+  winograd_down_def.precision = op_def.precision;
+  winograd_down_def.src_tensors.push_back(op_def.src_tensors[0]);
+  winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
+  auto& winograd_down = gpu_subgraph->operations[2];
+  winograd_down.input_ids = {-2};
+  winograd_down.output_ids = {static_cast<int>(outputs[0]->id)};
+  Winograd36To4x4Attributes wino_down_attr;
+  wino_down_attr.output_shape = outputs[0]->tensor.shape;
+  wino_down_attr.biases = attr.bias;
+  winograd_down.operation =
+      SelectWinograd36To4x4(winograd_down_def, wino_down_attr, gpu_info);
+  return absl::OkStatus();
+}
+
+}  // namespace
+
+absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
+                                  const OperationDef& op_def,
+                                  const std::vector<Value*>& inputs,
+                                  const std::vector<Value*>& outputs,
+                                  const Node& node,
+                                  GPUOperationsSubgraph* gpu_subgraph) {
+  std::unique_ptr<ComputeTaskDescriptor>* task =
+      InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
+  auto op_type = OperationTypeFromString(node.operation.type);
+  switch (op_type) {
+    case OperationType::ADD: {
+      if (inputs.size() == 1) {
+        if (node.operation.attributes.has_value()) {
+          auto attr =
+              absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
+          auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
+              op_def, op_type, attr.param);
+          *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+        } else {
+          return absl::UnimplementedError(
+              "Missing attributes for single input op: " + node.operation.type);
+        }
+      } else if (inputs.size() == 2) {
+        auto gpu_op =
+            ElementwiseWithTwoInputs(op_def, inputs[1]->tensor.shape, op_type);
+        *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      } else {  // more than 2 inputs
+        auto gpu_op = Add(op_def);
+        *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      }
+      break;
+    }
+    case OperationType::CONCAT: {
+      std::vector<BHWC> input_shapes;
+      for (auto& input : inputs) {
+        input_shapes.push_back(input->tensor.shape);
+      }
+      auto gpu_op = Concat(
+          op_def, absl::any_cast<ConcatAttributes>(node.operation.attributes),
+          input_shapes);
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::CONVOLUTION_2D: {
+      if (inputs.size() != 1) {
+        return absl::UnimplementedError(
+            "Convolution does not support more than 1 runtime tensor");
+      }
+      auto attr =
+          absl::any_cast<Convolution2DAttributes>(node.operation.attributes);
+      auto input_shape = inputs[0]->tensor.shape;
+      auto output_shape = outputs[0]->tensor.shape;
+      if (WinogradFromNode(gpu_info, inputs, outputs, op_def, input_shape,
+                           output_shape, attr, gpu_subgraph)
+              .ok()) {
+        return absl::OkStatus();
+      } else {
+        auto gpu_op = ConvolutionGeneric(op_def, output_shape, attr, gpu_info);
+        *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      }
+      break;
+    }
+    case OperationType::CONVOLUTION_TRANSPOSED:
+      if (inputs.size() != 1) {
+        return absl::UnimplementedError(
+            "Convolution Transposed does not support more than 1 runtime "
+            "tensor");
+      }
+      *task = SelectConvolutionTransposed(
+          op_def,
+          absl::any_cast<ConvolutionTransposedAttributes>(
+              node.operation.attributes),
+          gpu_info);
+      break;
+    case OperationType::DEPTHWISE_CONVOLUTION:
+      if (inputs.size() != 1) {
+        return absl::UnimplementedError(
+            "DepthWise Convolution does not support more than 1 runtime "
+            "tensor");
+      }
+      *task = SelectDepthWiseConv(
+          op_def, absl::any_cast<DepthwiseConvolution2DAttributes>(
+                      node.operation.attributes));
+      break;
+    case OperationType::FULLY_CONNECTED: {
+      auto gpu_op = FullyConnected(
+          op_def,
+          absl::any_cast<FullyConnectedAttributes>(node.operation.attributes),
+          gpu_info);
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::MAX_UNPOOLING_2D: {
+      auto gpu_op = MaxUnpooling(
+          op_def,
+          absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes));
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::MEAN: {
+      auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
+      if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
+        return absl::UnimplementedError("Mean supports HW axis only in Metal");
+      }
+      auto gpu_op = Mean(op_def, attr);
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::MUL:
+      if (inputs.size() == 1) {
+        if (node.operation.attributes.has_value()) {
+          auto attr =
+              absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
+          auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
+              op_def, op_type, attr.param);
+          *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+        } else {
+          return absl::UnimplementedError(
+              "Missing attributes for single input op: " + node.operation.type);
+        }
+      } else if (inputs.size() == 2) {
+        auto gpu_op =
+            ElementwiseWithTwoInputs(op_def, inputs[1]->tensor.shape, op_type);
+        *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      }
+      break;
+    case OperationType::PAD: {
+      auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
+      if (attr.appended.b != 0 || attr.prepended.b != 0) {
+        return absl::UnimplementedError("Padding for BATCH is not supported.");
+      }
+      auto gpu_op = Padding(op_def, attr);
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::POOLING_2D: {
+      auto attr =
+          absl::any_cast<Pooling2DAttributes>(node.operation.attributes);
+      auto pooling_op_def = op_def;
+      pooling_op_def.dst_tensors = {op_def.dst_tensors[0]};
+      auto gpu_op = Pooling(op_def, attr, false);
+      gpu_subgraph->operations[0].operation =
+          absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      gpu_subgraph->operations[0].input_ids = {static_cast<int>(inputs[0]->id)};
+      gpu_subgraph->operations[0].output_ids = {
+          static_cast<int>(outputs[0]->id)};
+      if (attr.type == PoolingType::MAX && attr.output_indices) {
+        gpu_subgraph->operations.push_back({});
+        auto gpu_ind_op = Pooling(op_def, attr, true);
+        gpu_subgraph->operations[1].operation =
+            absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_ind_op));
+        gpu_subgraph->operations[1].input_ids = {
+            static_cast<int>(inputs[0]->id)};
+        gpu_subgraph->operations[1].output_ids = {
+            static_cast<int>(outputs[1]->id)};
+      }
+      break;
+    }
+    case OperationType::PRELU: {
+      const auto src_shape = inputs[0]->tensor.shape;
+      *task = SelectPReLU(
+          op_def, src_shape,
+          absl::any_cast<PReLUAttributes>(node.operation.attributes));
+      break;
+    }
+    case OperationType::RELU: {
+      auto gpu_op = ReLU(
+          op_def, absl::any_cast<ReLUAttributes>(node.operation.attributes));
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::QUANTIZE_AND_DEQUANTIZE:
+      *task = SelectQuantizeAndDequantize(
+          op_def, absl::any_cast<QuantizeAndDequantizeAttributes>(
+                      node.operation.attributes));
+      break;
+    case OperationType::RESHAPE: {
+      const auto src_shape = inputs[0]->tensor.shape;
+      *task = SelectReshape(
+          op_def, src_shape,
+          absl::any_cast<ReshapeAttributes>(node.operation.attributes));
+      break;
+    }
+    case OperationType::RESIZE: {
+      auto gpu_op =
+          Resize(op_def,
+                 absl::any_cast<Resize2DAttributes>(node.operation.attributes));
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::SLICE: {
+      auto gpu_op = Slice(
+          op_def, absl::any_cast<SliceAttributes>(node.operation.attributes));
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::SOFTMAX: {
+      auto attr = absl::any_cast<SoftmaxAttributes>(node.operation.attributes);
+      if (attr.axis != Axis::CHANNELS) {
+        return absl::UnimplementedError(
+            "Softmax supports only CHANNELS dimension");
+      }
+      const auto src_shape = inputs[0]->tensor.shape;
+      *task = SelectSoftmax(op_def, src_shape, gpu_info);
+      break;
+    }
+    case OperationType::SPACE_TO_DEPTH:
+      *task = SelectSpaceToDepth(op_def, absl::any_cast<SpaceToDepthAttributes>(
+                                             node.operation.attributes));
+      break;
+    case OperationType::ABS:
+    case OperationType::COPY:
+    case OperationType::COS:
+    case OperationType::ELU:
+    case OperationType::EXP:
+    case OperationType::HARD_SWISH:
+    case OperationType::LOG:
+    case OperationType::NEG:
+    case OperationType::RSQRT:
+    case OperationType::SIGMOID:
+    case OperationType::SIN:
+    case OperationType::SQRT:
+    case OperationType::SQUARE:
+    case OperationType::TANH: {
+      auto gpu_op = ElementwiseWithOneInput(op_def, op_type);
+      *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      break;
+    }
+    case OperationType::DIV:
+    case OperationType::MAXIMUM:
+    case OperationType::MINIMUM:
+    case OperationType::POW:
+    case OperationType::SQUARED_DIFF:
+    case OperationType::SUB: {
+      if (inputs.size() == 1) {
+        if (node.operation.attributes.has_value()) {
+          auto attr =
+              absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
+          auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
+              op_def, op_type, attr.param);
+          *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+        } else {
+          return absl::UnimplementedError(
+              "Missing attributes for single input op: " + node.operation.type);
+        }
+      } else if (inputs.size() == 2) {
+        auto gpu_op =
+            ElementwiseWithTwoInputs(op_def, inputs[1]->tensor.shape, op_type);
+        *task = absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+      }
+    } break;
+    case OperationType::BATCH_NORMALIZATION:
+    case OperationType::BATCH_TO_SPACE:
+    case OperationType::BATCHED_MATMUL:
+    case OperationType::CONST:
+    case OperationType::LSTM:
+    // TODO(b/162763635): implement MeanStddevNormalization for Metal.
+    case OperationType::MEAN_STDDEV_NORMALIZATION:
+    case OperationType::REDUCE_MAXIMUM:
+    case OperationType::REDUCE_MINIMUM:
+    case OperationType::REDUCE_PRODUCT:
+    case OperationType::REDUCE_SUM:
+    // comparison operations
+    case OperationType::LESS:
+    case OperationType::LESS_EQUAL:
+    case OperationType::EQUAL:
+    case OperationType::NOT_EQUAL:
+    case OperationType::GREATER:
+    case OperationType::GREATER_EQUAL:
+    case OperationType::SPACE_TO_BATCH:
+    case OperationType::TRANSPOSE:
+      return absl::UnimplementedError("Unsupported op: " + node.operation.type);
+    default:
+      return SelectDefault(gpu_info, op_def, inputs, outputs, node,
+                           gpu_subgraph);
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
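
For context on the Winograd path added above: when `WinogradFromNode` succeeds, the single CONVOLUTION_2D node is replaced by a three-operation subgraph (a 4x4-to-36 input transform, a convolution in the transformed domain, and a 36-to-4x4 output transform), with two intermediate tensors of shape {b, 36, tiles_x * tiles_y, channels}. A minimal sketch of the tile arithmetic follows; the concrete spatial size is an illustrative assumption, not taken from the patch.

```c++
// Hedged sketch of the tile math used by WinogradFromNode above.
#include <iostream>

int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

int main() {
  // Assumed (illustrative) spatial size of the convolution output: 224x224.
  const int tiles_x = DivideRoundUp(224, 4);  // 56
  const int tiles_y = DivideRoundUp(224, 4);  // 56
  // The two intermediate tensors placed in gpu_subgraph->new_tensors have
  // shapes {b, 36, tiles_x * tiles_y, src_channels} and
  // {b, 36, tiles_x * tiles_y, dst_channels}, i.e. {1, 36, 3136, C} here.
  std::cout << tiles_x * tiles_y << " 4x4 tiles per image\n";
}
```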
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h
similarity index 64%
copy from tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
copy to tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h
index b81bdaa..64f6b29 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,31 +13,27 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_OPERATION_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_OPERATION_SELECTOR_H_
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
+namespace metal {
 
 absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
-                                  const OperationDef& op_def, ModelHints hints,
+                                  const OperationDef& op_def,
                                   const std::vector<Value*>& inputs,
                                   const std::vector<Value*>& outputs,
                                   const Node& node,
                                   GPUOperationsSubgraph* gpu_subgraph);
-
-}  // namespace cl
+}  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_OPERATION_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc b/tensorflow/lite/delegates/gpu/metal/selectors/subgraph.cc
similarity index 85%
copy from tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
copy to tensorflow/lite/delegates/gpu/metal/selectors/subgraph.cc
index cd3c987..5d84f75 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/subgraph.cc
@@ -13,19 +13,19 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
+#include "tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h"
 
 #include <memory>
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
+namespace metal {
 
-std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
+std::unique_ptr<ComputeTaskDescriptor>* InitSingleOpSubgraph(
     const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
     GPUOperationsSubgraph* gpu_subgraph) {
   gpu_subgraph->operations.clear();
@@ -41,6 +41,6 @@
   return &gpu_subgraph->operations[0].operation;
 }
 
-}  // namespace cl
+}  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h b/tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h
similarity index 77%
copy from tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
copy to tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h
index f94e0c4..9fe12d6 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/subgraph.h
@@ -13,22 +13,22 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_SUBGRAPH_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_SUBGRAPH_H_
 
 #include <memory>
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
+namespace metal {
 
 struct GPUOperationWithRefs {
-  std::unique_ptr<GPUOperation> operation;
+  std::unique_ptr<ComputeTaskDescriptor> operation;
 
   // input and output ids can be positive or negative.
   // if we have positive id, we will use preallocated tensor from GraphFloat32
@@ -42,12 +42,12 @@
   std::vector<std::pair<BHWC, TensorDescriptor>> new_tensors;
 };
 
-std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
+std::unique_ptr<ComputeTaskDescriptor>* InitSingleOpSubgraph(
     const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
     GPUOperationsSubgraph* gpu_subgraph);
 
-}  // namespace cl
+}  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_SUBGRAPH_H_
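
As the comment in `GPUOperationWithRefs` notes, non-negative input/output ids refer to preallocated tensors in the GraphFloat32, while negative ids refer to entries of `new_tensors` allocated for intermediate results (the Winograd subgraph above uses -1 and -2). Below is a small illustrative sketch of how a scheduler could translate such ids; `first_new_tensor_id` is a hypothetical base id chosen by the caller and is not part of the patch.

```c++
// Illustrative only: map the -1, -2, ... ids produced by the selectors onto
// global tensor ids, with new_tensors[k] assigned first_new_tensor_id + k.
inline int ResolveTensorId(int id, int first_new_tensor_id) {
  if (id >= 0) return id;                  // preallocated GraphFloat32 tensor
  return first_new_tensor_id + (-id - 1);  // -1 -> new_tensors[0], -2 -> [1]
}
```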
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm
index 782fdd8..ceed25a 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate.mm
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm
@@ -45,10 +45,11 @@
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/minimal_logging.h"
 
+
 namespace tflite {
 namespace gpu {
 namespace metal {
@@ -229,8 +230,8 @@
     return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index));
   }
 
-  void SetCommandEncoder(id<MTLComputeCommandEncoder> encoder) {
-    external_command_encoder_ = encoder;
+  void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) {
+    external_command_buffer_ = command_buffer;
   }
 
   // This directs the runtime to allocate memory for input/output temporary
@@ -338,19 +339,17 @@
     GpuInfo gpu_info;
     GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
     size_t storage_type_size;
-    RuntimeOptions runtime_options;
+    CalculationsPrecision precision;
     if (options_.allow_precision_loss) {
       storage_type_size = sizeof(HalfBits);
-      runtime_options.storage_precision = RuntimeOptions::Precision::FP16;
       if (gpu_info.IsRoundToNearestSupported()) {
-        runtime_options.accumulator_precision = RuntimeOptions::Precision::FP16;
+        precision = CalculationsPrecision::F16;
       } else {
-        runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32;
+        precision = CalculationsPrecision::F32_F16;
       }
     } else {
       storage_type_size = sizeof(float);
-      runtime_options.storage_precision = RuntimeOptions::Precision::FP32;
-      runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32;
+      precision = CalculationsPrecision::F32;
     }
 
     // TODO(impjdi): Merge logic with above.
@@ -435,16 +434,12 @@
 
     // TODO(impjdi): Merge these.
     CompiledModel compiled_model;
-    RETURN_IF_ERROR(Compile(graph, gpu_info, runtime_options, &compiled_model));
+    RETURN_IF_ERROR(Compile(graph, gpu_info, precision, &compiled_model));
     CompiledModel optimized_model;
     RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model));
 
-    inference_context_ = [[TFLInferenceContext alloc] init];
-    RETURN_IF_ERROR([inference_context_ compileModelWithDevice:metal_device_
-                                                         model:optimized_model
-                                                inputBufferIDs:input_ids
-                                               outputBufferIDs:output_ids
-                                                runtimeOptions:runtime_options]);
+    RETURN_IF_ERROR(inference_context_.CompileModelWithDevice(metal_device_, optimized_model,
+                                                              input_ids, output_ids, precision));
     return absl::OkStatus();
   }
 
@@ -454,12 +449,14 @@
     // We need only synchronization so volatile works better than atomic which reads from global
     // memory each time.
     __block volatile bool buffer_completed = false;
-    __block id<MTLCommandBuffer> command_buffer;
-    __block id<MTLComputeCommandEncoder> encoder = external_command_encoder_;
-    if (external_command_encoder_ == nil) {
+    id<MTLCommandBuffer> command_buffer = external_command_buffer_;
+    if (external_command_buffer_ == nil) {
       command_buffer = [command_queue_ commandBuffer];
-      encoder = [command_buffer computeCommandEncoder];
     }
+    const bool flush = external_command_buffer_ == nil &&
+        (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
+         options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
+    const int flush_period = 8;
 
     const bool is_quantized_model = !quant_conversion_map_.empty();
     if (is_quantized_model) {
@@ -475,51 +472,41 @@
       void* gpu_ptr = [input_output_buffers_[input.id] contents];
       std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
       if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue;
-      [converter_to_BPHWC4_ convertWithEncoder:encoder
+      id<MTLComputeCommandEncoder> input_encoder = [command_buffer computeCommandEncoder];
+      [converter_to_BPHWC4_ convertWithEncoder:input_encoder
                                          shape:input.shape
                                   sourceBuffer:input_output_buffers_[input.id]
                                convertedBuffer:bphwc4_buffers_[input.id]];
-      if (external_command_encoder_ == nil) {
-        [encoder endEncoding];
+      [input_encoder endEncoding];
+    }
+
+    @autoreleasepool {
+      if (flush) {
         [command_buffer commit];
+        inference_context_.EncodeWithCommandQueue(command_queue_, bphwc4_buffers_, flush_period);
         command_buffer = [command_queue_ commandBuffer];
-        encoder = [command_buffer computeCommandEncoder];
+      } else {
+        inference_context_.EncodeWithCommandBuffer(command_buffer, bphwc4_buffers_);
       }
     }
 
-    [inference_context_
-         encodeWithEncoder:encoder
-        inputOutputBuffers:bphwc4_buffers_
-              encoderBlock:^(bool isLast) {
-                if (external_command_encoder_ != nil ||
-                    options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
-                  return encoder;
-                }
-                if (isLast) {
-                  if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
-                    [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
-                      buffer_completed = true;
-                    }];
-                  }
-                } else {
-                  [encoder endEncoding];
-                  [command_buffer commit];
-                  command_buffer = [command_queue_ commandBuffer];
-                  encoder = [command_buffer computeCommandEncoder];
-                }
-                return encoder;
-              }];
     for (const auto& output : graph_outputs_) {
       if (output.set_externally) continue;
       if (bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue;
-      [converter_from_BPHWC4_ convertWithEncoder:encoder
+      id<MTLComputeCommandEncoder> output_encoder = [command_buffer computeCommandEncoder];
+      [converter_from_BPHWC4_ convertWithEncoder:output_encoder
                                            shape:output.shape
                                     sourceBuffer:bphwc4_buffers_[output.id]
                                  convertedBuffer:input_output_buffers_[output.id]];
+      [output_encoder endEncoding];
     }
 
-    if (external_command_encoder_ == nil) {
-      [encoder endEncoding];
+    if (external_command_buffer_ == nil) {
+      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
+        [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
+          buffer_completed = true;
+        }];
+      }
       [command_buffer commit];
       if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
         while (!buffer_completed) {
@@ -531,16 +518,16 @@
         // passive wait: this thread sleeps until GPU finishes.
         [command_buffer waitUntilCompleted];
       } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
-        command_buffer = [command_queue_ commandBuffer];
-        encoder = [command_buffer computeCommandEncoder];
-        [encoder setComputePipelineState:signal_program_];
-        [encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
+        id<MTLCommandBuffer> signal_cb = [command_queue_ commandBuffer];
+        id<MTLComputeCommandEncoder> signal_encoder = [signal_cb computeCommandEncoder];
+        [signal_encoder setComputePipelineState:signal_program_];
+        [signal_encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
         signal_value_++;
-        [encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
-        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
+        [signal_encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
+        [signal_encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                 threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-        [encoder endEncoding];
-        [command_buffer commit];
+        [signal_encoder endEncoding];
+        [signal_cb commit];
         gpu_alarm_clock_->Start();
         const int* signal_ptr = reinterpret_cast<const int*>([signal_buffer_ contents]);
         while (signal_ptr[0] != signal_value_) {
@@ -550,9 +537,9 @@
         }
       }
     } else {
-      // External command encoder must be set before every invoke call.
-      external_command_encoder_ = nil;
-      // External command encoder is assigned so all output buffers are controlled by a user.
+      // External command buffer must be set before every invoke call.
+      external_command_buffer_ = nil;
+      // External command buffer is assigned so all output buffers are controlled by a user.
       for (const auto& output : graph_outputs_) {
         if (!output.set_externally) {
           return absl::InternalError(
@@ -604,7 +591,7 @@
   // model_builder - and vice versa.
   absl::flat_hash_map<int, int> quant_conversion_map_;
 
-  TFLInferenceContext* inference_context_;
+  InferenceContext inference_context_;
   // input and output buffers are passed into Metal inference engine
   std::map<::tflite::gpu::ValueId, id<MTLBuffer>> input_output_buffers_;
   std::map<::tflite::gpu::ValueId, id<MTLBuffer>> bphwc4_buffers_;
@@ -620,7 +607,7 @@
   std::vector<BufferDescriptor> graph_inputs_;
   std::vector<BufferDescriptor> graph_outputs_;
 
-  id<MTLComputeCommandEncoder> external_command_encoder_ = nil;
+  id<MTLCommandBuffer> external_command_buffer_ = nil;
   id<MTLCommandQueue> command_queue_;
   std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_;
   id<MTLComputePipelineState> signal_program_;
@@ -712,11 +699,11 @@
 
 // Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
 // `metal_delegate_internal.h`.
-bool TFLGpuDelegateSetCommandEncoder(
-    TfLiteDelegate* delegate, id<MTLComputeCommandEncoder> encoder) {
+bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
+                                    id<MTLCommandBuffer> command_buffer) {
   auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
   if (!metal_delegate) return false;
-  metal_delegate->SetCommandEncoder(encoder);
+  metal_delegate->SetCommandBuffer(command_buffer);
   return true;
 }
 
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
index 1f35bda..121caef 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
+++ b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
@@ -33,9 +33,9 @@
                                            int tensor_index,
                                            id<MTLBuffer> metal_buffer);
 
-// Binds user-defined MTLComputeCommandEncoder. The delegate puts all GPU tasks
-// into this encoder instead of the internal encoder.
-bool TFLGpuDelegateSetCommandEncoder(TfLiteDelegate* delegate,
-                                     id<MTLComputeCommandEncoder> encoder);
+// Binds user-defined MTLCommandBuffer. The delegate puts all GPU tasks
+// into this buffer instead of the internal command buffer.
+bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
+                                    id<MTLCommandBuffer> command_buffer);
 
 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_INTERNAL_H_
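
The delegate now accepts an external MTLCommandBuffer instead of an MTLComputeCommandEncoder. A hedged usage sketch is below; `device`, `delegate`, and `interpreter` are placeholders for an already-configured Metal device, Metal delegate, and TFLite interpreter, and (per the code above) all delegate outputs must then be bound externally and the buffer must be re-set before every Invoke.

```objective-c++
// Illustrative only: encode the delegate's GPU work into a user-owned command
// buffer. The caller commits the buffer; the delegate will not commit it.
id<MTLCommandQueue> queue = [device newCommandQueue];
id<MTLCommandBuffer> command_buffer = [queue commandBuffer];
if (!TFLGpuDelegateSetCommandBuffer(delegate, command_buffer)) {
  // `delegate` is not a Metal delegate; fall back to the internal buffer path.
}
TfLiteStatus status = interpreter->Invoke();  // work is encoded into command_buffer
[command_buffer commit];
```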
diff --git a/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc b/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc
index 9254b82..5d79946 100644
--- a/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc
+++ b/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc
@@ -18,9 +18,7 @@
 
 #include "tensorflow/lite/delegates/hexagon/hexagon_delegate.h"
 
-#ifdef __cplusplus
 extern "C" {
-#endif
 
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_HexagonDelegate_createDelegate(
@@ -51,6 +49,4 @@
              : JNI_FALSE;
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/delegates/hexagon/utils.cc b/tensorflow/lite/delegates/hexagon/utils.cc
index 397400c..1a179a3 100644
--- a/tensorflow/lite/delegates/hexagon/utils.cc
+++ b/tensorflow/lite/delegates/hexagon/utils.cc
@@ -94,6 +94,7 @@
     case kTfLiteBuiltinSlice:
     case kTfLiteBuiltinSoftmax:
     case kTfLiteBuiltinSpaceToDepth:
+    case kTfLiteBuiltinDepthToSpace:
     case kTfLiteBuiltinSplit:
     case kTfLiteBuiltinStridedSlice:
     case kTfLiteBuiltinSub:
diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index b3c6eff..86694d9 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -135,6 +135,7 @@
         ":nnapi_delegate_mock_test",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:deprecated_backends",
         "//tensorflow/lite/kernels:test_util",
         "//tensorflow/lite/nnapi:nnapi_implementation",
         "//tensorflow/lite/nnapi:nnapi_lib",
diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc
index c94a523..e3cb57b 100644
--- a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc
+++ b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc
@@ -17,12 +17,10 @@
 
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
 using namespace tflite;
 
+extern "C" {
+
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate(
     JNIEnv* env, jclass clazz, jint preference, jstring accelerator_name,
@@ -87,6 +85,4 @@
   delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 279984c..8984650 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -221,6 +221,8 @@
     case kTfLiteBuiltinPow:
     case kTfLiteBuiltinMaximum:
     case kTfLiteBuiltinMinimum:
+    case kTfLiteBuiltinPrelu:
+    case kTfLiteBuiltinLeakyRelu:
       return true;
     default:
       return false;
@@ -2448,6 +2450,7 @@
             &val_ctx);
       }
     } break;
+    case kTfLiteBuiltinLeakyRelu:
     case kTfLiteBuiltinPrelu: {
       ExpectOpVersion(version, 1, &val_ctx);
       ExpectMinAndroidSdkVersion(android_sdk_version, kMinSdkVersionForNNAPI12,
@@ -3357,6 +3360,38 @@
     case kTfLiteBuiltinCast: {
       *nn_op_type = ANEURALNETWORKS_CAST;
     } break;
+    case kTfLiteBuiltinLeakyRelu: {
+      const auto input_type =
+          mapping_args.context->tensors[mapping_args.node->inputs->data[0]]
+              .type;
+      auto builtin = reinterpret_cast<TfLiteLeakyReluParams*>(
+          mapping_args.node->builtin_data);
+
+      TfLiteTensor alpha_tensor;
+      alpha_tensor.type = input_type;
+      alpha_tensor.allocation_type = kTfLiteDynamic;
+      alpha_tensor.dims = TfLiteIntArrayCreate(1);
+      alpha_tensor.dims->data[0] = 1;
+      alpha_tensor.params.zero_point = 0;
+
+      int new_tensor_index = -1;
+      if (input_type == kTfLiteFloat32) {
+        alpha_tensor.params.scale = 0;
+        std::vector<float> alpha_value = {builtin->alpha};
+        mapping_args.builder->AddNewInputConstantTensor(
+            ANEURALNETWORKS_TENSOR_FLOAT32, kTfLiteFloat32, alpha_tensor.dims,
+            alpha_value, alpha_tensor.params, &new_tensor_index);
+      } else {
+        alpha_tensor.params.scale = builtin->alpha;
+        std::vector<uint8_t> alpha_value = {1};
+        mapping_args.builder->AddNewInputConstantTensor(
+            ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+            alpha_tensor.dims, alpha_value, alpha_tensor.params,
+            &new_tensor_index);
+      }
+
+      *nn_op_type = ANEURALNETWORKS_PRELU;
+    } break;
     case kTfLiteBuiltinPrelu: {
       *nn_op_type = ANEURALNETWORKS_PRELU;
     } break;
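
The new `kTfLiteBuiltinLeakyRelu` mapping above lowers LEAKY_RELU onto `ANEURALNETWORKS_PRELU` with a single-element alpha tensor. That works because PReLU with one broadcast alpha value computes exactly the LeakyRelu function; a purely illustrative reference check:

```c++
// Reference check only (not NNAPI code): with a scalar alpha, PReLU and
// LeakyRelu agree, which is what the PRELU-based lowering above relies on.
#include <algorithm>
#include <cassert>
#include <initializer_list>

float PreluScalar(float x, float alpha) {
  return std::max(0.0f, x) + alpha * std::min(0.0f, x);
}
float LeakyRelu(float x, float alpha) { return x >= 0.0f ? x : alpha * x; }

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 1.0f, 3.0f}) {
    assert(PreluScalar(x, 0.5f) == LeakyRelu(x, 0.5f));
  }
}
```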
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
index 38e4fcd..16e7a26 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -2056,7 +2056,13 @@
 
   m.Invoke();
 
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  // Hash returns differently on machines with different endianness
+  EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 1, 1, 1, 0));
+#else
   EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
+#endif
 }
 
 TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) {
@@ -2067,7 +2073,13 @@
 
   m.Invoke();
 
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  // The hash function returns different values on machines with different endianness.
+  EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
+#else
   EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
+#endif
 }
 
 TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) {
@@ -2080,7 +2092,13 @@
 
   m.Invoke();
 
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  // The hash function returns different values on machines with different endianness.
+  EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
+#else
   EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
+#endif
 }
 
 class BaseActivationsOpModel : public SingleOpModelWithNNAPI {
@@ -5316,6 +5334,69 @@
   AdvancedDynamicValuedTest<int8_t, TensorType_INT8>();
 }
 
+// Op model for Leaky ReLU. It is used by both the float (LeakyReluFloat)
+// and quantized (LeakyReluQuantized) tests below.
+class LeakyReluOpModel : public SingleOpModelWithNNAPI {
+ public:
+  LeakyReluOpModel(const TensorData& input, const float& alpha)
+      : input_type_(input.type) {
+    input_ = AddInput(input);
+    output_ = AddOutput({input.type, input.shape, input.min, input.max});
+
+    SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
+                 CreateLeakyReluOptions(builder_, alpha).Union());
+    BuildInterpreterWithNNAPI({GetShape(input_)});
+  }
+
+  void SetInput(std::initializer_list<float> data) {
+    SetData(input_, input_type_, data);
+  }
+
+  std::vector<float> GetOutput() {
+    std::vector<float> output;
+    GetData(output_, input_type_, &output);
+    return output;
+  }
+
+ protected:
+  int input_;
+  int output_;
+
+  const TensorType input_type_;
+};
+
+TEST(NNAPIDelegate, LeakyReluFloat) {
+  LeakyReluOpModel m({TensorType_FLOAT32, {2, 3}}, 0.5f);
+
+  m.SetInput({
+      0.0f, 1.0f, 3.0f,    // Row 1
+      1.0f, -1.0f, -2.0f,  // Row 2
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0.0f, 1.0f, 3.0f,    // Row 1
+                                 1.0f, -0.5f, -1.0f,  // Row 2
+
+                             }));
+}
+
+TEST(NNAPIDelegate, LeakyReluQuantized) {
+  const float kMin = -1;
+  const float kMax = 127.f / 128.f;
+  LeakyReluOpModel m({TensorType_UINT8, {2, 3}, 8 * kMin, 8 * kMax}, 0.5f);
+  m.SetInput({
+      0.0f, 1.0f, 3.0f,    // Row 1
+      1.0f, -1.0f, -2.0f,  // Row 2
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     0.0f, 1.0f, 3.0f,    // Row 1
+                                     1.0f, -0.5f, -1.0f,  // Row 2
+                                 },
+                                 kQuantizedTolerance)));
+}
+
 }  // namespace
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/status.h b/tensorflow/lite/delegates/status.h
deleted file mode 100644
index e56bf7c..0000000
--- a/tensorflow/lite/delegates/status.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_DELEGATES_STATUS_H_
-#define TENSORFLOW_LITE_DELEGATES_STATUS_H_
-
-#include <cstdint>
-#include <limits>
-
-#include "tensorflow/lite/c/common.h"
-
-// This file defines data structures to represent detailed TFLite delegate
-// status, e.g. NNAPI delegate application failure because of a driver issue
-// etc. Such status is ONLY to be used for internal APIs.
-// Note, we simply use TfLiteStatus to represent high-level status while
-// delegate-specific status codes are defined with DelegateStatus.
-// WARNING: This is an experimental feature that is subject to change.
-namespace tflite {
-namespace delegates {
-
-// Defines the source of the code where it is generated from. We list all TFLite
-// delegates that're officially implemented and available as of April, 2020
-// (i.e. w/ 'TFLITE_' prefix to imply this).
-enum class DelegateStatusSource {
-  NONE = 0,
-  TFLITE_GPU = 1,
-  TFLITE_NNAPI = 2,
-  TFLITE_HEXAGON = 3,
-  TFLITE_XNNPACK = 4,
-  TFLITE_COREML = 5,
-  MAX_NUM_SOURCES = std::numeric_limits<int32_t>::max(),
-};
-
-// Defines the detailed status that combines a DelegateStatusSource and a
-// status int32_t code.
-class DelegateStatus {
- public:
-  DelegateStatus() : DelegateStatus(DelegateStatusSource::NONE, 0) {}
-  explicit DelegateStatus(int32_t code)
-      : DelegateStatus(DelegateStatusSource::NONE, code) {}
-  explicit DelegateStatus(int64_t full_status)
-      : DelegateStatus(
-            static_cast<DelegateStatusSource>(
-                full_status >> 32 &
-                static_cast<int32_t>(DelegateStatusSource::MAX_NUM_SOURCES)),
-            static_cast<int32_t>(full_status &
-                                 std::numeric_limits<int32_t>::max())) {}
-  DelegateStatus(DelegateStatusSource source, int32_t code)
-      : source_(static_cast<int32_t>(source)), code_(code) {}
-
-  // Return the detailed full status encoded as a int64_t value.
-  int64_t full_status() const {
-    return static_cast<int64_t>(source_) << 32 | code_;
-  }
-
-  DelegateStatusSource source() const {
-    return static_cast<DelegateStatusSource>(source_);
-  }
-
-  int32_t code() const { return code_; }
-
- private:
-  // value of a DelegateStatusSource, like DelegateStatusSource::TFLITE_GPU
-  int32_t source_;
-  // value of a status code, like kTfLiteOk.
-  int32_t code_;
-};
-
-}  // namespace delegates
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_STATUS_H_
diff --git a/tensorflow/lite/delegates/telemetry.cc b/tensorflow/lite/delegates/telemetry.cc
new file mode 100644
index 0000000..cba8486
--- /dev/null
+++ b/tensorflow/lite/delegates/telemetry.cc
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/telemetry.h"
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/profiler.h"
+
+namespace tflite {
+namespace delegates {
+
+// TODO(b/153131797): Add an IFTTT here once we have a profiler to interpret
+// these events, so that the two components don't go out of sync.
+
+TfLiteStatus ReportDelegateSettings(TfLiteContext* context,
+                                    TfLiteDelegate* delegate,
+                                    const TFLiteSettings& settings) {
+  auto* profiler = reinterpret_cast<Profiler*>(context->profiler);
+  const int64_t event_metadata1 = reinterpret_cast<int64_t>(delegate);
+  const int64_t event_metadata2 = reinterpret_cast<int64_t>(&settings);
+  TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT(profiler, kDelegateSettingsTag,
+                                           event_metadata1, event_metadata2);
+  return kTfLiteOk;
+}
+
+TfLiteStatus ReportDelegateStatus(TfLiteContext* context,
+                                  TfLiteDelegate* delegate,
+                                  const DelegateStatus& status) {
+  auto* profiler = reinterpret_cast<Profiler*>(context->profiler);
+  TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT(profiler, kDelegateStatusTag,
+                                           status.full_status(),
+                                           static_cast<int64_t>(kTfLiteOk));
+  return kTfLiteOk;
+}
+
+}  // namespace delegates
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/telemetry.h b/tensorflow/lite/delegates/telemetry.h
new file mode 100644
index 0000000..d7e92be
--- /dev/null
+++ b/tensorflow/lite/delegates/telemetry.h
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_DELEGATES_TELEMETRY_H_
+#define TENSORFLOW_LITE_DELEGATES_TELEMETRY_H_
+
+#include <cstdint>
+#include <limits>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
+
+// This file implements utilities for delegate telemetry. These enable
+// representation and reporting of hardware-specific configurations, status
+// codes, etc.
+// These APIs are for internal use *only*, and should be modified with care to
+// avoid incompatibilities between delegates & runtime.
+// WARNING: This is an experimental feature that is subject to change.
+namespace tflite {
+namespace delegates {
+
+// Used to identify specific events for tflite::Profiler.
+constexpr char kDelegateSettingsTag[] = "delegate_settings";
+constexpr char kDelegateStatusTag[] = "delegate_status";
+
+// Defines the delegate or hardware-specific 'namespace' that a status code
+// belongs to. For example, GPU delegate errors might belong to TFLITE_GPU,
+// while OpenCL-specific ones might be TFLITE_GPU_CL.
+enum class DelegateStatusSource {
+  NONE = 0,
+  TFLITE_GPU = 1,
+  TFLITE_NNAPI = 2,
+  TFLITE_HEXAGON = 3,
+  TFLITE_XNNPACK = 4,
+  TFLITE_COREML = 5,
+  MAX_NUM_SOURCES = std::numeric_limits<int32_t>::max(),
+};
+
+// DelegateStatus defines a namespaced status with a combination of
+// DelegateStatusSource & the corresponding fine-grained 32-bit code. Used to
+// convert to/from a 64-bit representation as follows:
+//
+// delegates::DelegateStatus status(
+//      delegates::DelegateStatusSource::TFLITE_NNAPI,
+//      ANEURALNETWORKS_OP_FAILED);
+// int64_t code = status.full_status();
+//
+// auto parsed_status = delegates::DelegateStatus(code);
+class DelegateStatus {
+ public:
+  DelegateStatus() : DelegateStatus(DelegateStatusSource::NONE, 0) {}
+  explicit DelegateStatus(int32_t code)
+      : DelegateStatus(DelegateStatusSource::NONE, code) {}
+  explicit DelegateStatus(int64_t full_status)
+      : DelegateStatus(
+            static_cast<DelegateStatusSource>(
+                full_status >> 32 &
+                static_cast<int32_t>(DelegateStatusSource::MAX_NUM_SOURCES)),
+            static_cast<int32_t>(full_status &
+                                 std::numeric_limits<int32_t>::max())) {}
+  DelegateStatus(DelegateStatusSource source, int32_t code)
+      : source_(static_cast<int32_t>(source)), code_(code) {}
+
+  // Return the detailed full status encoded as a int64_t value.
+  int64_t full_status() const {
+    return static_cast<int64_t>(source_) << 32 | code_;
+  }
+
+  DelegateStatusSource source() const {
+    return static_cast<DelegateStatusSource>(source_);
+  }
+
+  int32_t code() const { return code_; }
+
+ private:
+  // value of a DelegateStatusSource, like DelegateStatusSource::TFLITE_GPU
+  int32_t source_;
+  // value of a status code, like kTfLiteOk.
+  int32_t code_;
+};
+
+// Used by delegates to report their configuration/settings to TFLite.
+// Calling this method adds a new GENERAL_RUNTIME_INSTRUMENTATION_EVENT to
+// the runtime Profiler.
+TfLiteStatus ReportDelegateSettings(TfLiteContext* context,
+                                    TfLiteDelegate* delegate,
+                                    const TFLiteSettings& settings);
+
+// Used by delegates to report their status to the TFLite runtime.
+// Calling this method adds a new GENERAL_RUNTIME_INSTRUMENTATION_EVENT to
+// the runtime Profiler.
+TfLiteStatus ReportDelegateStatus(TfLiteContext* context,
+                                  TfLiteDelegate* delegate,
+                                  const DelegateStatus& status);
+
+}  // namespace delegates
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_TELEMETRY_H_
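
For reference, `DelegateStatus::full_status()` packs the source into the upper 32 bits and the fine-grained code into the lower 32 bits of a single int64_t, and the int64_t constructor reverses that. A minimal round-trip sketch (the code value 2 is illustrative):

```c++
// Illustrative round trip through the 64-bit encoding defined above.
#include <cassert>
#include <cstdint>

#include "tensorflow/lite/delegates/telemetry.h"

int main() {
  using tflite::delegates::DelegateStatus;
  using tflite::delegates::DelegateStatusSource;

  DelegateStatus status(DelegateStatusSource::TFLITE_NNAPI, /*code=*/2);
  int64_t packed = status.full_status();  // (int64_t{TFLITE_NNAPI} << 32) | 2
  DelegateStatus round_trip(packed);
  assert(round_trip.source() == DelegateStatusSource::TFLITE_NNAPI);
  assert(round_trip.code() == 2);
}
```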
diff --git a/tensorflow/lite/delegates/telemetry_test.cc b/tensorflow/lite/delegates/telemetry_test.cc
new file mode 100644
index 0000000..759097c
--- /dev/null
+++ b/tensorflow/lite/delegates/telemetry_test.cc
@@ -0,0 +1,141 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/telemetry.h"
+
+#include <cstdint>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/profiler.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
+#include "tensorflow/lite/profiling/profile_buffer.h"
+
+namespace tflite {
+namespace delegates {
+namespace {
+
+constexpr int32_t kDummyCode = 2;
+constexpr bool kDummyGpuPrecisionLossAllowed = true;
+constexpr tflite::Delegate kDummyDelegate = tflite::Delegate_GPU;
+constexpr DelegateStatusSource kDummySource =
+    DelegateStatusSource::TFLITE_NNAPI;
+
+TEST(TelemetryTest, StatusConversion) {
+  DelegateStatus status(kDummySource, kDummyCode);
+  int64_t serialized_int = status.full_status();
+  DelegateStatus deserialized_status(serialized_int);
+
+  EXPECT_EQ(kDummyCode, deserialized_status.code());
+  EXPECT_EQ(kDummySource, deserialized_status.source());
+  EXPECT_EQ(serialized_int, deserialized_status.full_status());
+}
+
+// Dummy profiler to test delegate reporting.
+class DelegateProfiler : public Profiler {
+ public:
+  DelegateProfiler() {}
+  ~DelegateProfiler() override = default;
+
+  uint32_t BeginEvent(const char* tag, EventType event_type,
+                      int64_t event_metadata1,
+                      int64_t event_metadata2) override {
+    int event_handle = -1;
+    if (event_type ==
+            Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT &&
+        std::string(tag) == kDelegateSettingsTag) {
+      event_buffer_.emplace_back();
+      event_handle = event_buffer_.size();
+
+      // event_metadata1 is a pointer to a TfLiteDelegate.
+      EXPECT_NE(event_metadata1, 0);
+      auto* delegate = reinterpret_cast<TfLiteDelegate*>(event_metadata1);
+      EXPECT_EQ(delegate->flags, kTfLiteDelegateFlagsNone);
+      // event_metadata2 is a pointer to TFLiteSettings.
+      EXPECT_NE(event_metadata2, 0);
+      auto* settings = reinterpret_cast<TFLiteSettings*>(event_metadata2);
+      EXPECT_EQ(settings->delegate(), kDummyDelegate);
+      EXPECT_EQ(settings->gpu_settings()->is_precision_loss_allowed(),
+                kDummyGpuPrecisionLossAllowed);
+    } else if (event_type ==
+                   Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT &&
+               std::string(tag) == kDelegateStatusTag) {
+      event_buffer_.emplace_back();
+      event_handle = event_buffer_.size();
+
+      EXPECT_EQ(event_metadata2, static_cast<int64_t>(kTfLiteOk));
+      DelegateStatus reported_status(event_metadata1);
+      EXPECT_EQ(reported_status.source(), kDummySource);
+      EXPECT_EQ(reported_status.code(), kDummyCode);
+    }
+
+    EXPECT_NE(-1, event_handle);
+    return event_handle;
+  }
+
+  void EndEvent(uint32_t event_handle) override {
+    EXPECT_EQ(event_handle, event_buffer_.size());
+  }
+
+  int NumRecordedEvents() { return event_buffer_.size(); }
+
+ private:
+  std::vector<profiling::ProfileEvent> event_buffer_;
+};
+
+TEST(TelemetryTest, DelegateStatusReport) {
+  DelegateProfiler profiler;
+  TfLiteDelegate delegate = TfLiteDelegateCreate();
+  TfLiteContext context;
+  context.profiler = &profiler;
+  DelegateStatus status(kDummySource, kDummyCode);
+
+  EXPECT_EQ(ReportDelegateStatus(&context, &delegate, status), kTfLiteOk);
+  EXPECT_EQ(ReportDelegateStatus(&context, &delegate, status), kTfLiteOk);
+  EXPECT_EQ(profiler.NumRecordedEvents(), 2);
+}
+
+TEST(TelemetryTest, DelegateSettingsReport) {
+  DelegateProfiler profiler;
+  TfLiteDelegate delegate = TfLiteDelegateCreate();
+  TfLiteContext context;
+  context.profiler = &profiler;
+
+  flatbuffers::FlatBufferBuilder flatbuffer_builder;
+  flatbuffers::Offset<tflite::GPUSettings> gpu_settings =
+      tflite::CreateGPUSettings(
+          flatbuffer_builder,
+          /*is_precision_loss_allowed=*/kDummyGpuPrecisionLossAllowed);
+  auto* tflite_settings_ptr = flatbuffers::GetTemporaryPointer(
+      flatbuffer_builder,
+      CreateTFLiteSettings(flatbuffer_builder, kDummyDelegate,
+                           /*nnapi_settings=*/0,
+                           /*gpu_settings=*/gpu_settings));
+
+  EXPECT_EQ(ReportDelegateSettings(&context, &delegate, *tflite_settings_ptr),
+            kTfLiteOk);
+  EXPECT_EQ(profiler.NumRecordedEvents(), 1);
+
+  // Also report status to simulate typical use-case.
+  DelegateStatus status(kDummySource, kDummyCode);
+  EXPECT_EQ(ReportDelegateStatus(&context, &delegate, status), kTfLiteOk);
+  EXPECT_EQ(ReportDelegateStatus(&context, &delegate, status), kTfLiteOk);
+  EXPECT_EQ(profiler.NumRecordedEvents(), 3);
+}
+
+}  // namespace
+}  // namespace delegates
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/utils.cc b/tensorflow/lite/delegates/utils.cc
index 289586c..9012066 100644
--- a/tensorflow/lite/delegates/utils.cc
+++ b/tensorflow/lite/delegates/utils.cc
@@ -20,6 +20,7 @@
 
 #include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/context_util.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace delegates {
@@ -183,7 +184,8 @@
     // its value in delegated_dequant_consumers.
     for (int j = 0; j < node->inputs->size; ++j) {
       const int input_tid = node->inputs->data[j];
-      if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
+      if (constant_dequant_consumers_.find(input_tid) !=
+          constant_dequant_consumers_.end()) {
         delegated_dequant_consumers[input_tid] += 1;
       }
     }
@@ -192,9 +194,10 @@
   // If the number of delegated consumers is same as total number of consumers,
   // add the corresponding DEQUANTIZE op to the delegated nodes.
   for (auto tensor_and_consumers : delegated_dequant_consumers) {
-    if (dequant_consumers_[tensor_and_consumers.first] ==
+    if (constant_dequant_consumers_[tensor_and_consumers.first] ==
         tensor_and_consumers.second) {
-      ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
+      ops_to_replace.emplace_back(
+          constant_dequant_nodes_[tensor_and_consumers.first]);
     }
   }
 
@@ -216,16 +219,21 @@
 bool FP16GraphPartitionHelper::IsNodeSupported(
     TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
     int node_id, std::string* unsupported_details) {
-  if (registration->builtin_code == kTfLiteBuiltinDequantize &&
-      context_->tensors[node->inputs->data[0]].type ==
-          TfLiteType::kTfLiteFloat16) {
-    // Update mappings if this node is a fp16 DEQUANTIZE node.
-    dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
-    dequant_nodes_[node->outputs->data[0]] = node_id;
-    // We do not accept these ops right now.
-    // This is done to support use-cases where a DEQUANTIZE output might be
-    // consumed by a CPU op.
-    return false;
+  if (registration->builtin_code == kTfLiteBuiltinDequantize) {
+    auto& dequantize_input = context_->tensors[node->inputs->data[0]];
+    if (dequantize_input.type == kTfLiteFloat16 &&
+        IsConstantTensor(&dequantize_input)) {
+      // Update mappings if this node is a fp16 DEQUANTIZE node that
+      // works on a **constant** input tensor.
+      // If the input is not a constant, the remapping that we do here will
+      // cause bugs due to preceding ops such as DENSIFY.
+      constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
+      constant_dequant_nodes_[node->outputs->data[0]] = node_id;
+      // We do not accept these ops right now.
+      // This is done to support use-cases where a DEQUANTIZE output might be
+      // consumed by a CPU op.
+      return false;
+    }
   }
 
   // To check if a (possibly) FP16 node is supported, we temporarily point the
@@ -234,7 +242,7 @@
   // we remap the original node inputs, so that the TFLite graph remains the
   // same.
   std::vector<int> orig_inputs;
-  if (!dequant_nodes_.empty()) {
+  if (!constant_dequant_nodes_.empty()) {
     RemapFp16InputTensors(node, &orig_inputs);
   }
 
@@ -245,10 +253,11 @@
     // Remapping happened. Restore original inputs.
     for (int j = 0; j < node->inputs->size; ++j) {
       node->inputs->data[j] = orig_inputs[j];
-      if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
+      if (constant_dequant_nodes_.find(orig_inputs[j]) !=
+          constant_dequant_nodes_.end()) {
         // If its a fp16 tensor, increment number of consumers of the
         // corresponding DEQUANTIZE.
-        dequant_consumers_[orig_inputs[j]] += 1;
+        constant_dequant_consumers_[orig_inputs[j]] += 1;
       }
     }
   }
@@ -289,8 +298,8 @@
   bool is_remapped = false;
   for (int j = 0; j < inputs->size; ++j) {
     const int input_tid = inputs->data[j];
-    const auto it = dequant_map_.find(input_tid);
-    if (it != dequant_map_.end()) {
+    const auto it = constant_dequant_map_.find(input_tid);
+    if (it != constant_dequant_map_.end()) {
       inputs->data[j] = it->second;
       is_remapped = true;
     }
diff --git a/tensorflow/lite/delegates/utils.h b/tensorflow/lite/delegates/utils.h
index a9fb673..31b2977 100644
--- a/tensorflow/lite/delegates/utils.h
+++ b/tensorflow/lite/delegates/utils.h
@@ -131,8 +131,8 @@
 // Specialized partitioner for graphs that possibly contain fp16 tensors.
 //
 // From nodes that accept fp16 inputs, this delegates the following:
-// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the
-// delegate (in the TFLite graph, these nodes take in dequantized FP32
+// 1. All nodes (except DEQUANTIZE) that are supported with constant fp16 inputs
+// by the delegate (in the TFLite graph, these nodes take in dequantized FP32
 // outputs).
 // 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
 // delegated partition. This is because TFLite's partitioning algorithm
@@ -168,11 +168,12 @@
 
   // ('dequantize' here refers to fp16 DEQUANTIZE)
   // Mapping of dequantize nodes' output tensor-id to its node id.
-  std::unordered_map<int, int> dequant_nodes_;
+  // TODO(b/156707497): Use absl hash_maps here.
+  std::unordered_map<int, int> constant_dequant_nodes_;
   // Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
-  std::unordered_map<int, int> dequant_map_;
+  std::unordered_map<int, int> constant_dequant_map_;
   // mapping of DEQUANTIZE output tensor-id to its number of consumers.
-  std::unordered_map<int, int> dequant_consumers_;
+  std::unordered_map<int, int> constant_dequant_consumers_;
 };
 
 }  // namespace delegates
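
The constant_dequant_* maps introduced above implement a small piece of bookkeeping: remember each fp16 DEQUANTIZE node that reads a constant tensor, temporarily point its consumers back at the fp16 input while probing delegate support, and pull the DEQUANTIZE into the delegated partition only once every consumer of its output was itself delegated. A minimal, self-contained sketch of that idea follows; the toy graph types and values are illustrative only and are not the TFLite API.

// Toy illustration of the constant_dequant_* bookkeeping (hypothetical types).
#include <iostream>
#include <unordered_map>
#include <vector>

struct ToyNode {
  int id;
  std::vector<int> inputs;  // tensor ids consumed by the node
};

int main() {
  // DEQUANTIZE node 0 turns constant fp16 tensor 10 into fp32 tensor 20.
  std::unordered_map<int, int> constant_dequant_nodes;      // fp32 tid -> node id
  std::unordered_map<int, int> constant_dequant_map;        // fp32 tid -> fp16 tid
  std::unordered_map<int, int> constant_dequant_consumers;  // fp32 tid -> #consumers
  constant_dequant_nodes[20] = 0;
  constant_dequant_map[20] = 10;

  // Two consumer nodes that read tensor 20.
  std::vector<ToyNode> consumers = {{1, {20}}, {2, {20}}};
  for (ToyNode& node : consumers) {
    for (int& tid : node.inputs) {
      auto it = constant_dequant_map.find(tid);
      if (it == constant_dequant_map.end()) continue;
      const int fp32_tid = tid;
      tid = it->second;  // temporarily remap to the fp16 tensor...
      // ...the delegate's IsNodeSupported() check would run here...
      tid = fp32_tid;    // ...then restore the original fp32 input
      constant_dequant_consumers[fp32_tid] += 1;
    }
  }

  // Delegate the DEQUANTIZE only if all of its consumers were delegated.
  const int delegated_consumers = 2;  // pretend both consumers were accepted
  if (constant_dequant_consumers[20] == delegated_consumers) {
    std::cout << "delegate DEQUANTIZE node " << constant_dequant_nodes[20] << "\n";
  }
  return 0;
}
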
diff --git a/tensorflow/lite/delegates/utils/dummy_delegate/external_delegate_adaptor.cc b/tensorflow/lite/delegates/utils/dummy_delegate/external_delegate_adaptor.cc
index 7ae6539..fdefd2e 100644
--- a/tensorflow/lite/delegates/utils/dummy_delegate/external_delegate_adaptor.cc
+++ b/tensorflow/lite/delegates/utils/dummy_delegate/external_delegate_adaptor.cc
@@ -84,9 +84,7 @@
 }  // namespace tools
 }  // namespace tflite
 
-#ifdef __cplusplus
 extern "C" {
-#endif  // __cplusplus
 
 // Defines two symbols that need to be exported to use the TFLite external
 // delegate. See tensorflow/lite/delegates/external for details.
@@ -101,6 +99,4 @@
   TfLiteDummyDelegateDelete(delegate);
 }
 
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
+}  // extern "C"
diff --git a/tensorflow/lite/delegates/xnnpack/README.md b/tensorflow/lite/delegates/xnnpack/README.md
index 9e45c56..44afe25 100644
--- a/tensorflow/lite/delegates/xnnpack/README.md
+++ b/tensorflow/lite/delegates/xnnpack/README.md
@@ -303,22 +303,22 @@
 * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported,
   but fused `TANH` and `SIGN_BIT` activations are not.
 
-### Sparse Inference (experimental)
+### Sparse Inference
 
 XNNPACK backend supports sparse inference for CNN models described in the
-[Fast Sparse ConvNets](https://arxiv.org/abs/1911.09723) paper. This
-functionality must be enabled at build-time via
-`--define xnn_enable_sparse=true` Bazel flag. Sparse inference is restricted
-to subgraphs with the following operators:
+[Fast Sparse ConvNets](https://arxiv.org/abs/1911.09723) paper. Sparse
+inference is restricted to subgraphs that satisfy the following constraints:
 
+* Sparse subgraph must store its weights in sparse representation (using
+  `DENSIFY` operators in the TensorFlow Lite schema).
 * Sparse subgraph must start with a 3x3 stride-2 `CONV_2D` operator with
   padding 1 on each side, no dilation, and 3 input channels.
-* Sparse subgraph must end with a `MEAN` operator that does reduction across
-  spatial axes.
+* Sparse subgraph must end with either a `MEAN` operator with reduction across
+  spatial axes, or a `DEPTH_TO_SPACE` operator.
 * Sparse subgraph may contain the following operators:
-  * `CONV_2D` with 1x1 kernel and no padding. It is important to have high
-    sparsity (at least 70%) in the filter of this operator to get speedup
-    over dense inference.
+  * `CONV_2D` with 1x1 kernel and no padding. At least 2/3rd of filter weights
+    in the 1x1 `CONV_2D` operators across the sparse subgraph must be zeroes
+    to enable sparse inference.
   * `DEPTHWISE_CONV_2D` with 3x3 kernel, stride 1, no dilation, and padding 1
     on each side.
   * `DEPTHWISE_CONV_2D` with 3x3 kernel, stride 2, no dilation, and padding 1
@@ -327,19 +327,18 @@
     on each side.
   * `DEPTHWISE_CONV_2D` with 5x5 kernel, stride 2, no dilation, and padding 2
     on each side.
+  * `RESIZE_BILINEAR` operator with output dimensions greater than 1.
+  * `MEAN` operator with reduction across spatial axes.
   * `ADD` and `MUL` operators where both inputs are 4D tensors. If one of the
     inputs to `ADD` or `MUL` is a constant tensor, it must be representable as
     either a scalar, or a 1D vector.
-  * Unary elementwise operators `ABS`, `CEIL`, `FLOOR`, `HARD_SWISH`,
+  * Unary elementwise operators `ABS`, `CEIL`, `ELU`, `FLOOR`, `HARD_SWISH`,
     `LEAKY_RELU`, `LOGISTIC`, `NEG`, `RELU`, `RELU6`, `RELU_N1_TO_1`, `ROUND`,
-    and `SQUARE`.
+    `SIGMOID`, and `SQUARE`.
 
 Pre-trained [Fast Sparse ConvNets models](https://github.com/google-research/google-research/tree/master/fastconvnets)
 provide examples that satisfy these constraints.
 
-In addition to acceleration, sparse models get the compression benefit by
-storing only non-zero values in the [TensorFlow Lite file format](https://github.com/tensorflow/tensorflow/blob/4aea552e064cf92330e07e83a3b5a1ca2a7034d0/tensorflow/lite/schema/schema.fbs#L84-L109).
-
 ### Other limitations
 
 * Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and
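
The README hunk above removes the build-time `--define xnn_enable_sparse=true` requirement: sparse inference is now driven by the model itself, i.e. weights stored in sparse form via `DENSIFY` operators and 1x1 `CONV_2D` filters that are at least 2/3 zeroes. A toy check of that sparsity threshold is sketched below; the helper name is hypothetical and not part of TFLite.

// Toy sparsity check for the "at least 2/3 zero weights" constraint.
#include <cstddef>
#include <iostream>
#include <vector>

bool MeetsSparsityThreshold(const std::vector<float>& filter_weights) {
  std::size_t zeros = 0;
  for (float w : filter_weights) {
    if (w == 0.0f) ++zeros;
  }
  // zeros / total >= 2/3, written without floating-point division.
  return 3 * zeros >= 2 * filter_weights.size();
}

int main() {
  std::vector<float> filter(90, 0.0f);  // 90 zero weights...
  filter.resize(120, 1.0f);             // ...plus 30 non-zero ones (75% sparse)
  std::cout << std::boolalpha << MeetsSparsityThreshold(filter) << "\n";  // true
  return 0;
}
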
diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
index ca75ec4..d3b073c 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
@@ -89,6 +89,8 @@
   // ignored in the delegate implementation, because their outputs are
   // pre-unpacked in DelegatePrepare.
   std::unordered_set<int> static_unpack_nodes_;
+  // Set of indices of tensors with unpacked static sparse weights.
+  std::unordered_set<int> static_sparse_weights_;
 #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
   // Thread pool with smart-pointer for lifetime management.
   std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_{
@@ -134,17 +136,13 @@
     std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
         subgraph_ptr, &xnn_delete_subgraph);
 
+    bool has_sparse_weights = false;
     // Detect which tensors are used as inputs or outputs of any subgraph nodes.
     // -1 denotes tensor not used in the subgraph. These indexes will be
     // filtered out and removed later.
     std::vector<int> tensors(context->tensors_size, -1);
     for (int i = 0; i < params->nodes_to_replace->size; i++) {
       const int node_index = params->nodes_to_replace->data[i];
-      if (delegate->static_unpack_nodes_.count(node_index)) {
-        // The node unpacks static input and can be skipped because its input
-        // was pre-unpacked in DelegatePrepare.
-        continue;
-      }
 
       TfLiteNode* node = nullptr;
       TfLiteRegistration* registration = nullptr;
@@ -153,6 +151,22 @@
         return nullptr;
       }
 
+      // Detect if any of the node's inputs are sparse weights.
+      if (!has_sparse_weights) {
+        for (int i = 0; i < node->inputs->size; i++) {
+          if (delegate->static_sparse_weights_.count(node->inputs->data[i]) !=
+              0) {
+            has_sparse_weights = true;
+          }
+        }
+      }
+
+      if (delegate->static_unpack_nodes_.count(node_index) != 0) {
+        // The node unpacks static input and can be skipped because its input
+        // was pre-unpacked in DelegatePrepare.
+        continue;
+      }
+
       switch (registration->builtin_code) {
         case kTfLiteBuiltinMean:
         case kTfLiteBuiltinPad:
@@ -260,8 +274,9 @@
     }
 
     xnn_runtime_t runtime_ptr = nullptr;
+    const uint32_t flags = has_sparse_weights ? XNN_FLAG_SPARSE_INFERENCE : 0;
     status = xnn_create_runtime_v2(subgraph.get(), delegate->threadpool(),
-                                   /*flags=*/0, &runtime_ptr);
+                                   flags, &runtime_ptr);
     if (status != xnn_status_success) {
       TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime");
       return nullptr;
@@ -2911,6 +2926,7 @@
   static_unpacked_data_map_.clear();
   static_unpacked_data_.clear();
   static_unpack_nodes_.clear();
+  static_sparse_weights_.clear();
 
   TfLiteIntArray* execution_plan = nullptr;
   if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
@@ -2962,6 +2978,11 @@
           quasi_static_tensors_to_unpack.insert(node->inputs->data[0]);
         }
 
+        // If the dequantized input is sparse, so is its output.
+        if (static_sparse_weights_.count(node->inputs->data[0]) != 0) {
+          static_sparse_weights_.insert(node->outputs->data[0]);
+        }
+
         // Skip this node for now. If output of the node is consumed only by
         // delegated nodes, it will be added to nodes_to_delegate in the end.
         continue;
@@ -2987,6 +3008,7 @@
         static_unpack_nodes_.insert(node_index);
         quasi_static_tensors_producers[node->outputs->data[0]] = node_index;
         quasi_static_tensors.insert(node->outputs->data[0]);
+        static_sparse_weights_.insert(node->outputs->data[0]);
 
         // Skip this node for now. If output of the node is consumed only by
         // delegated nodes, it will be added to nodes_to_delegate in the end.
@@ -3156,22 +3178,25 @@
 
         switch (input_tensor.type) {
           case kTfLiteFloat32: {
+            const size_t dense_size = context->tensors[t].bytes / sizeof(float);
+            float* unpacked_fp32_data = reinterpret_cast<float*>(unpacked_data);
             tflite::optimize::sparsity::FormatConverter<float> converter(
                 vector_shape, *input_tensor.sparsity);
             converter.SparseToDense(
-                static_cast<const float*>(input_tensor.data.data));
-            const std::vector<float> out = converter.GetData();
-            std::memcpy(unpacked_data, out.data(), out.size() * sizeof(float));
+                static_cast<const float*>(input_tensor.data.data), dense_size,
+                unpacked_fp32_data, context);
             break;
           }
           case kTfLiteFloat16: {
+            const size_t dense_size =
+                context->tensors[t].bytes / sizeof(Eigen::half);
+            Eigen::half* unpacked_fp16_data =
+                reinterpret_cast<Eigen::half*>(unpacked_data);
             tflite::optimize::sparsity::FormatConverter<Eigen::half> converter(
                 vector_shape, *input_tensor.sparsity);
             converter.SparseToDense(
-                static_cast<const Eigen::half*>(input_tensor.data.data));
-            const std::vector<Eigen::half> out = converter.GetData();
-            std::memcpy(unpacked_data, out.data(),
-                        out.size() * sizeof(Eigen::half));
+                static_cast<const Eigen::half*>(input_tensor.data.data),
+                dense_size, unpacked_fp16_data, context);
             break;
           }
           default: {
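
The last xnnpack_delegate.cc hunk above also changes how sparse weights are unpacked: instead of converting into a temporary std::vector and memcpy-ing it, SparseToDense now writes directly into the pre-allocated buffer, whose element count is derived from context->tensors[t].bytes. The sketch below shows only that copy-free destination-buffer pattern with a toy index/value format; the real tflite::optimize::sparsity::FormatConverter handles block-sparse layouts.

// Toy "densify into a caller-provided buffer" pattern (not the TFLite API).
#include <cstring>
#include <iostream>
#include <vector>

void SparseToDense(const std::vector<int>& indices,
                   const std::vector<float>& values,
                   std::size_t dense_size, float* dst) {
  // Zero-fill the destination, then scatter the non-zero values by index.
  std::memset(dst, 0, dense_size * sizeof(float));
  for (std::size_t i = 0; i < indices.size(); ++i) dst[indices[i]] = values[i];
}

int main() {
  std::vector<float> unpacked(8);  // stands in for the pre-allocated buffer
  SparseToDense({1, 5}, {0.5f, -2.0f}, unpacked.size(), unpacked.data());
  for (float v : unpacked) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}
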
diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin
index f64f35d..6baadf3 100644
--- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin
+++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin
Binary files differ
diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD
index 023db7b..633da07 100644
--- a/tensorflow/lite/experimental/acceleration/configuration/BUILD
+++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_java_library", "flatc_path")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = [
@@ -32,6 +33,7 @@
     $(location {}) --proto -o $(@D) $(location :configuration.proto)
     perl -p -i -e 's/tflite.proto/tflite/' $(@D)/configuration.fbs
     """.format(flatc_path),
+    compatible_with = get_compatible_with_portable(),
     tools = [
         flatc_path,
     ],
@@ -68,6 +70,7 @@
 flatbuffer_cc_library(
     name = "configuration_fbs",
     srcs = [":configuration.fbs"],
+    compatible_with = get_compatible_with_portable(),
 )
 
 flatbuffer_java_library(
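
The new configuration_generated.h added below provides FlatBuffers tables and builder helpers for the acceleration settings. A minimal usage sketch follows, assuming the flatbuffers library and the generated header are on the include path; the delegate choice and string values are purely illustrative.

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"

// Builds a ComputeSettings buffer that asks for the XNNPACK delegate.
flatbuffers::DetachedBuffer BuildExampleComputeSettings() {
  flatbuffers::FlatBufferBuilder fbb;
  auto xnnpack = tflite::CreateXNNPackSettings(fbb, /*num_threads=*/2);
  auto tflite_settings = tflite::CreateTFLiteSettings(
      fbb, tflite::Delegate_XNNPACK,
      /*nnapi_settings=*/0, /*gpu_settings=*/0, /*hexagon_settings=*/0,
      xnnpack);
  auto settings = tflite::CreateComputeSettingsDirect(
      fbb, tflite::ExecutionPreference_LOW_LATENCY, tflite_settings,
      "example_namespace", "example_model");
  fbb.Finish(settings);
  return fbb.Release();
}
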
diff --git a/tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h b/tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
new file mode 100644
index 0000000..53d455c
--- /dev/null
+++ b/tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
@@ -0,0 +1,1444 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_CONFIGURATION_TFLITE_H_
+#define FLATBUFFERS_GENERATED_CONFIGURATION_TFLITE_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace tflite {
+
+struct ComputeSettings;
+struct ComputeSettingsT;
+
+struct NNAPISettings;
+struct NNAPISettingsT;
+
+struct GPUSettings;
+struct GPUSettingsT;
+
+struct HexagonSettings;
+struct HexagonSettingsT;
+
+struct XNNPackSettings;
+struct XNNPackSettingsT;
+
+struct EdgeTpuSettings;
+struct EdgeTpuSettingsT;
+
+struct CPUSettings;
+struct CPUSettingsT;
+
+struct TFLiteSettings;
+struct TFLiteSettingsT;
+
+struct FallbackSettings;
+struct FallbackSettingsT;
+
+enum ExecutionPreference {
+  ExecutionPreference_ANY = 0,
+  ExecutionPreference_LOW_LATENCY = 1,
+  ExecutionPreference_LOW_POWER = 2,
+  ExecutionPreference_FORCE_CPU = 3,
+  ExecutionPreference_MIN = ExecutionPreference_ANY,
+  ExecutionPreference_MAX = ExecutionPreference_FORCE_CPU
+};
+
+inline const ExecutionPreference (&EnumValuesExecutionPreference())[4] {
+  static const ExecutionPreference values[] = {
+    ExecutionPreference_ANY,
+    ExecutionPreference_LOW_LATENCY,
+    ExecutionPreference_LOW_POWER,
+    ExecutionPreference_FORCE_CPU
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesExecutionPreference() {
+  static const char * const names[5] = {
+    "ANY",
+    "LOW_LATENCY",
+    "LOW_POWER",
+    "FORCE_CPU",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameExecutionPreference(ExecutionPreference e) {
+  if (flatbuffers::IsOutRange(e, ExecutionPreference_ANY, ExecutionPreference_FORCE_CPU)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesExecutionPreference()[index];
+}
+
+enum Delegate {
+  Delegate_NONE = 0,
+  Delegate_NNAPI = 1,
+  Delegate_GPU = 2,
+  Delegate_HEXAGON = 3,
+  Delegate_XNNPACK = 4,
+  Delegate_EDGETPU = 5,
+  Delegate_MIN = Delegate_NONE,
+  Delegate_MAX = Delegate_EDGETPU
+};
+
+inline const Delegate (&EnumValuesDelegate())[6] {
+  static const Delegate values[] = {
+    Delegate_NONE,
+    Delegate_NNAPI,
+    Delegate_GPU,
+    Delegate_HEXAGON,
+    Delegate_XNNPACK,
+    Delegate_EDGETPU
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesDelegate() {
+  static const char * const names[7] = {
+    "NONE",
+    "NNAPI",
+    "GPU",
+    "HEXAGON",
+    "XNNPACK",
+    "EDGETPU",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameDelegate(Delegate e) {
+  if (flatbuffers::IsOutRange(e, Delegate_NONE, Delegate_EDGETPU)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesDelegate()[index];
+}
+
+enum NNAPIExecutionPreference {
+  NNAPIExecutionPreference_UNDEFINED = 0,
+  NNAPIExecutionPreference_NNAPI_LOW_POWER = 1,
+  NNAPIExecutionPreference_NNAPI_FAST_SINGLE_ANSWER = 2,
+  NNAPIExecutionPreference_NNAPI_SUSTAINED_SPEED = 3,
+  NNAPIExecutionPreference_MIN = NNAPIExecutionPreference_UNDEFINED,
+  NNAPIExecutionPreference_MAX = NNAPIExecutionPreference_NNAPI_SUSTAINED_SPEED
+};
+
+inline const NNAPIExecutionPreference (&EnumValuesNNAPIExecutionPreference())[4] {
+  static const NNAPIExecutionPreference values[] = {
+    NNAPIExecutionPreference_UNDEFINED,
+    NNAPIExecutionPreference_NNAPI_LOW_POWER,
+    NNAPIExecutionPreference_NNAPI_FAST_SINGLE_ANSWER,
+    NNAPIExecutionPreference_NNAPI_SUSTAINED_SPEED
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesNNAPIExecutionPreference() {
+  static const char * const names[5] = {
+    "UNDEFINED",
+    "NNAPI_LOW_POWER",
+    "NNAPI_FAST_SINGLE_ANSWER",
+    "NNAPI_SUSTAINED_SPEED",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameNNAPIExecutionPreference(NNAPIExecutionPreference e) {
+  if (flatbuffers::IsOutRange(e, NNAPIExecutionPreference_UNDEFINED, NNAPIExecutionPreference_NNAPI_SUSTAINED_SPEED)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesNNAPIExecutionPreference()[index];
+}
+
+enum NNAPIExecutionPriority {
+  NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED = 0,
+  NNAPIExecutionPriority_NNAPI_PRIORITY_LOW = 1,
+  NNAPIExecutionPriority_NNAPI_PRIORITY_MEDIUM = 2,
+  NNAPIExecutionPriority_NNAPI_PRIORITY_HIGH = 3,
+  NNAPIExecutionPriority_MIN = NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED,
+  NNAPIExecutionPriority_MAX = NNAPIExecutionPriority_NNAPI_PRIORITY_HIGH
+};
+
+inline const NNAPIExecutionPriority (&EnumValuesNNAPIExecutionPriority())[4] {
+  static const NNAPIExecutionPriority values[] = {
+    NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED,
+    NNAPIExecutionPriority_NNAPI_PRIORITY_LOW,
+    NNAPIExecutionPriority_NNAPI_PRIORITY_MEDIUM,
+    NNAPIExecutionPriority_NNAPI_PRIORITY_HIGH
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesNNAPIExecutionPriority() {
+  static const char * const names[5] = {
+    "NNAPI_PRIORITY_UNDEFINED",
+    "NNAPI_PRIORITY_LOW",
+    "NNAPI_PRIORITY_MEDIUM",
+    "NNAPI_PRIORITY_HIGH",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameNNAPIExecutionPriority(NNAPIExecutionPriority e) {
+  if (flatbuffers::IsOutRange(e, NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED, NNAPIExecutionPriority_NNAPI_PRIORITY_HIGH)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesNNAPIExecutionPriority()[index];
+}
+
+enum GPUBackend {
+  GPUBackend_UNSET = 0,
+  GPUBackend_OPENCL = 1,
+  GPUBackend_OPENGL = 2,
+  GPUBackend_MIN = GPUBackend_UNSET,
+  GPUBackend_MAX = GPUBackend_OPENGL
+};
+
+inline const GPUBackend (&EnumValuesGPUBackend())[3] {
+  static const GPUBackend values[] = {
+    GPUBackend_UNSET,
+    GPUBackend_OPENCL,
+    GPUBackend_OPENGL
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesGPUBackend() {
+  static const char * const names[4] = {
+    "UNSET",
+    "OPENCL",
+    "OPENGL",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameGPUBackend(GPUBackend e) {
+  if (flatbuffers::IsOutRange(e, GPUBackend_UNSET, GPUBackend_OPENGL)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesGPUBackend()[index];
+}
+
+namespace EdgeTpuSettings_ {
+
+enum PowerState {
+  PowerState_UNDEFINED = 0,
+  PowerState_TPU_CORE_OFF = 1,
+  PowerState_READY = 2,
+  PowerState_READY_WITH_RETENTION = 3,
+  PowerState_ACTIVE_MIN_POWER = 4,
+  PowerState_ACTIVE_LOW_POWER = 5,
+  PowerState_ACTIVE = 6,
+  PowerState_OVER_DRIVE = 7,
+  PowerState_MIN = PowerState_UNDEFINED,
+  PowerState_MAX = PowerState_OVER_DRIVE
+};
+
+inline const PowerState (&EnumValuesPowerState())[8] {
+  static const PowerState values[] = {
+    PowerState_UNDEFINED,
+    PowerState_TPU_CORE_OFF,
+    PowerState_READY,
+    PowerState_READY_WITH_RETENTION,
+    PowerState_ACTIVE_MIN_POWER,
+    PowerState_ACTIVE_LOW_POWER,
+    PowerState_ACTIVE,
+    PowerState_OVER_DRIVE
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesPowerState() {
+  static const char * const names[9] = {
+    "UNDEFINED",
+    "TPU_CORE_OFF",
+    "READY",
+    "READY_WITH_RETENTION",
+    "ACTIVE_MIN_POWER",
+    "ACTIVE_LOW_POWER",
+    "ACTIVE",
+    "OVER_DRIVE",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNamePowerState(PowerState e) {
+  if (flatbuffers::IsOutRange(e, PowerState_UNDEFINED, PowerState_OVER_DRIVE)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesPowerState()[index];
+}
+
+}  // namespace EdgeTpuSettings_
+
+struct ComputeSettingsT : public flatbuffers::NativeTable {
+  typedef ComputeSettings TableType;
+  tflite::ExecutionPreference preference;
+  std::unique_ptr<tflite::TFLiteSettingsT> tflite_settings;
+  std::string model_namespace_for_statistics;
+  std::string model_identifier_for_statistics;
+  ComputeSettingsT()
+      : preference(tflite::ExecutionPreference_ANY) {
+  }
+};
+
+struct ComputeSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef ComputeSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_PREFERENCE = 4,
+    VT_TFLITE_SETTINGS = 6,
+    VT_MODEL_NAMESPACE_FOR_STATISTICS = 8,
+    VT_MODEL_IDENTIFIER_FOR_STATISTICS = 10
+  };
+  tflite::ExecutionPreference preference() const {
+    return static_cast<tflite::ExecutionPreference>(GetField<int32_t>(VT_PREFERENCE, 0));
+  }
+  const tflite::TFLiteSettings *tflite_settings() const {
+    return GetPointer<const tflite::TFLiteSettings *>(VT_TFLITE_SETTINGS);
+  }
+  const flatbuffers::String *model_namespace_for_statistics() const {
+    return GetPointer<const flatbuffers::String *>(VT_MODEL_NAMESPACE_FOR_STATISTICS);
+  }
+  const flatbuffers::String *model_identifier_for_statistics() const {
+    return GetPointer<const flatbuffers::String *>(VT_MODEL_IDENTIFIER_FOR_STATISTICS);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_PREFERENCE) &&
+           VerifyOffset(verifier, VT_TFLITE_SETTINGS) &&
+           verifier.VerifyTable(tflite_settings()) &&
+           VerifyOffset(verifier, VT_MODEL_NAMESPACE_FOR_STATISTICS) &&
+           verifier.VerifyString(model_namespace_for_statistics()) &&
+           VerifyOffset(verifier, VT_MODEL_IDENTIFIER_FOR_STATISTICS) &&
+           verifier.VerifyString(model_identifier_for_statistics()) &&
+           verifier.EndTable();
+  }
+  ComputeSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(ComputeSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<ComputeSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ComputeSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ComputeSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_preference(tflite::ExecutionPreference preference) {
+    fbb_.AddElement<int32_t>(ComputeSettings::VT_PREFERENCE, static_cast<int32_t>(preference), 0);
+  }
+  void add_tflite_settings(flatbuffers::Offset<tflite::TFLiteSettings> tflite_settings) {
+    fbb_.AddOffset(ComputeSettings::VT_TFLITE_SETTINGS, tflite_settings);
+  }
+  void add_model_namespace_for_statistics(flatbuffers::Offset<flatbuffers::String> model_namespace_for_statistics) {
+    fbb_.AddOffset(ComputeSettings::VT_MODEL_NAMESPACE_FOR_STATISTICS, model_namespace_for_statistics);
+  }
+  void add_model_identifier_for_statistics(flatbuffers::Offset<flatbuffers::String> model_identifier_for_statistics) {
+    fbb_.AddOffset(ComputeSettings::VT_MODEL_IDENTIFIER_FOR_STATISTICS, model_identifier_for_statistics);
+  }
+  explicit ComputeSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ComputeSettingsBuilder &operator=(const ComputeSettingsBuilder &);
+  flatbuffers::Offset<ComputeSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<ComputeSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<ComputeSettings> CreateComputeSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::ExecutionPreference preference = tflite::ExecutionPreference_ANY,
+    flatbuffers::Offset<tflite::TFLiteSettings> tflite_settings = 0,
+    flatbuffers::Offset<flatbuffers::String> model_namespace_for_statistics = 0,
+    flatbuffers::Offset<flatbuffers::String> model_identifier_for_statistics = 0) {
+  ComputeSettingsBuilder builder_(_fbb);
+  builder_.add_model_identifier_for_statistics(model_identifier_for_statistics);
+  builder_.add_model_namespace_for_statistics(model_namespace_for_statistics);
+  builder_.add_tflite_settings(tflite_settings);
+  builder_.add_preference(preference);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ComputeSettings> CreateComputeSettingsDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::ExecutionPreference preference = tflite::ExecutionPreference_ANY,
+    flatbuffers::Offset<tflite::TFLiteSettings> tflite_settings = 0,
+    const char *model_namespace_for_statistics = nullptr,
+    const char *model_identifier_for_statistics = nullptr) {
+  auto model_namespace_for_statistics__ = model_namespace_for_statistics ? _fbb.CreateString(model_namespace_for_statistics) : 0;
+  auto model_identifier_for_statistics__ = model_identifier_for_statistics ? _fbb.CreateString(model_identifier_for_statistics) : 0;
+  return tflite::CreateComputeSettings(
+      _fbb,
+      preference,
+      tflite_settings,
+      model_namespace_for_statistics__,
+      model_identifier_for_statistics__);
+}
+
+flatbuffers::Offset<ComputeSettings> CreateComputeSettings(flatbuffers::FlatBufferBuilder &_fbb, const ComputeSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct NNAPISettingsT : public flatbuffers::NativeTable {
+  typedef NNAPISettings TableType;
+  std::string accelerator_name;
+  std::string cache_directory;
+  std::string model_token;
+  tflite::NNAPIExecutionPreference execution_preference;
+  int32_t no_of_nnapi_instances_to_cache;
+  std::unique_ptr<tflite::FallbackSettingsT> fallback_settings;
+  bool allow_nnapi_cpu_on_android_10_plus;
+  tflite::NNAPIExecutionPriority execution_priority;
+  bool allow_dynamic_dimensions;
+  bool allow_fp16_precision_for_fp32;
+  NNAPISettingsT()
+      : execution_preference(tflite::NNAPIExecutionPreference_UNDEFINED),
+        no_of_nnapi_instances_to_cache(0),
+        allow_nnapi_cpu_on_android_10_plus(false),
+        execution_priority(tflite::NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED),
+        allow_dynamic_dimensions(false),
+        allow_fp16_precision_for_fp32(false) {
+  }
+};
+
+struct NNAPISettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef NNAPISettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_ACCELERATOR_NAME = 4,
+    VT_CACHE_DIRECTORY = 6,
+    VT_MODEL_TOKEN = 8,
+    VT_EXECUTION_PREFERENCE = 10,
+    VT_NO_OF_NNAPI_INSTANCES_TO_CACHE = 12,
+    VT_FALLBACK_SETTINGS = 14,
+    VT_ALLOW_NNAPI_CPU_ON_ANDROID_10_PLUS = 16,
+    VT_EXECUTION_PRIORITY = 18,
+    VT_ALLOW_DYNAMIC_DIMENSIONS = 20,
+    VT_ALLOW_FP16_PRECISION_FOR_FP32 = 22
+  };
+  const flatbuffers::String *accelerator_name() const {
+    return GetPointer<const flatbuffers::String *>(VT_ACCELERATOR_NAME);
+  }
+  const flatbuffers::String *cache_directory() const {
+    return GetPointer<const flatbuffers::String *>(VT_CACHE_DIRECTORY);
+  }
+  const flatbuffers::String *model_token() const {
+    return GetPointer<const flatbuffers::String *>(VT_MODEL_TOKEN);
+  }
+  tflite::NNAPIExecutionPreference execution_preference() const {
+    return static_cast<tflite::NNAPIExecutionPreference>(GetField<int32_t>(VT_EXECUTION_PREFERENCE, 0));
+  }
+  int32_t no_of_nnapi_instances_to_cache() const {
+    return GetField<int32_t>(VT_NO_OF_NNAPI_INSTANCES_TO_CACHE, 0);
+  }
+  const tflite::FallbackSettings *fallback_settings() const {
+    return GetPointer<const tflite::FallbackSettings *>(VT_FALLBACK_SETTINGS);
+  }
+  bool allow_nnapi_cpu_on_android_10_plus() const {
+    return GetField<uint8_t>(VT_ALLOW_NNAPI_CPU_ON_ANDROID_10_PLUS, 0) != 0;
+  }
+  tflite::NNAPIExecutionPriority execution_priority() const {
+    return static_cast<tflite::NNAPIExecutionPriority>(GetField<int32_t>(VT_EXECUTION_PRIORITY, 0));
+  }
+  bool allow_dynamic_dimensions() const {
+    return GetField<uint8_t>(VT_ALLOW_DYNAMIC_DIMENSIONS, 0) != 0;
+  }
+  bool allow_fp16_precision_for_fp32() const {
+    return GetField<uint8_t>(VT_ALLOW_FP16_PRECISION_FOR_FP32, 0) != 0;
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_ACCELERATOR_NAME) &&
+           verifier.VerifyString(accelerator_name()) &&
+           VerifyOffset(verifier, VT_CACHE_DIRECTORY) &&
+           verifier.VerifyString(cache_directory()) &&
+           VerifyOffset(verifier, VT_MODEL_TOKEN) &&
+           verifier.VerifyString(model_token()) &&
+           VerifyField<int32_t>(verifier, VT_EXECUTION_PREFERENCE) &&
+           VerifyField<int32_t>(verifier, VT_NO_OF_NNAPI_INSTANCES_TO_CACHE) &&
+           VerifyOffset(verifier, VT_FALLBACK_SETTINGS) &&
+           verifier.VerifyTable(fallback_settings()) &&
+           VerifyField<uint8_t>(verifier, VT_ALLOW_NNAPI_CPU_ON_ANDROID_10_PLUS) &&
+           VerifyField<int32_t>(verifier, VT_EXECUTION_PRIORITY) &&
+           VerifyField<uint8_t>(verifier, VT_ALLOW_DYNAMIC_DIMENSIONS) &&
+           VerifyField<uint8_t>(verifier, VT_ALLOW_FP16_PRECISION_FOR_FP32) &&
+           verifier.EndTable();
+  }
+  NNAPISettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(NNAPISettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<NNAPISettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NNAPISettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct NNAPISettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_accelerator_name(flatbuffers::Offset<flatbuffers::String> accelerator_name) {
+    fbb_.AddOffset(NNAPISettings::VT_ACCELERATOR_NAME, accelerator_name);
+  }
+  void add_cache_directory(flatbuffers::Offset<flatbuffers::String> cache_directory) {
+    fbb_.AddOffset(NNAPISettings::VT_CACHE_DIRECTORY, cache_directory);
+  }
+  void add_model_token(flatbuffers::Offset<flatbuffers::String> model_token) {
+    fbb_.AddOffset(NNAPISettings::VT_MODEL_TOKEN, model_token);
+  }
+  void add_execution_preference(tflite::NNAPIExecutionPreference execution_preference) {
+    fbb_.AddElement<int32_t>(NNAPISettings::VT_EXECUTION_PREFERENCE, static_cast<int32_t>(execution_preference), 0);
+  }
+  void add_no_of_nnapi_instances_to_cache(int32_t no_of_nnapi_instances_to_cache) {
+    fbb_.AddElement<int32_t>(NNAPISettings::VT_NO_OF_NNAPI_INSTANCES_TO_CACHE, no_of_nnapi_instances_to_cache, 0);
+  }
+  void add_fallback_settings(flatbuffers::Offset<tflite::FallbackSettings> fallback_settings) {
+    fbb_.AddOffset(NNAPISettings::VT_FALLBACK_SETTINGS, fallback_settings);
+  }
+  void add_allow_nnapi_cpu_on_android_10_plus(bool allow_nnapi_cpu_on_android_10_plus) {
+    fbb_.AddElement<uint8_t>(NNAPISettings::VT_ALLOW_NNAPI_CPU_ON_ANDROID_10_PLUS, static_cast<uint8_t>(allow_nnapi_cpu_on_android_10_plus), 0);
+  }
+  void add_execution_priority(tflite::NNAPIExecutionPriority execution_priority) {
+    fbb_.AddElement<int32_t>(NNAPISettings::VT_EXECUTION_PRIORITY, static_cast<int32_t>(execution_priority), 0);
+  }
+  void add_allow_dynamic_dimensions(bool allow_dynamic_dimensions) {
+    fbb_.AddElement<uint8_t>(NNAPISettings::VT_ALLOW_DYNAMIC_DIMENSIONS, static_cast<uint8_t>(allow_dynamic_dimensions), 0);
+  }
+  void add_allow_fp16_precision_for_fp32(bool allow_fp16_precision_for_fp32) {
+    fbb_.AddElement<uint8_t>(NNAPISettings::VT_ALLOW_FP16_PRECISION_FOR_FP32, static_cast<uint8_t>(allow_fp16_precision_for_fp32), 0);
+  }
+  explicit NNAPISettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  NNAPISettingsBuilder &operator=(const NNAPISettingsBuilder &);
+  flatbuffers::Offset<NNAPISettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<NNAPISettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<NNAPISettings> CreateNNAPISettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<flatbuffers::String> accelerator_name = 0,
+    flatbuffers::Offset<flatbuffers::String> cache_directory = 0,
+    flatbuffers::Offset<flatbuffers::String> model_token = 0,
+    tflite::NNAPIExecutionPreference execution_preference = tflite::NNAPIExecutionPreference_UNDEFINED,
+    int32_t no_of_nnapi_instances_to_cache = 0,
+    flatbuffers::Offset<tflite::FallbackSettings> fallback_settings = 0,
+    bool allow_nnapi_cpu_on_android_10_plus = false,
+    tflite::NNAPIExecutionPriority execution_priority = tflite::NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED,
+    bool allow_dynamic_dimensions = false,
+    bool allow_fp16_precision_for_fp32 = false) {
+  NNAPISettingsBuilder builder_(_fbb);
+  builder_.add_execution_priority(execution_priority);
+  builder_.add_fallback_settings(fallback_settings);
+  builder_.add_no_of_nnapi_instances_to_cache(no_of_nnapi_instances_to_cache);
+  builder_.add_execution_preference(execution_preference);
+  builder_.add_model_token(model_token);
+  builder_.add_cache_directory(cache_directory);
+  builder_.add_accelerator_name(accelerator_name);
+  builder_.add_allow_fp16_precision_for_fp32(allow_fp16_precision_for_fp32);
+  builder_.add_allow_dynamic_dimensions(allow_dynamic_dimensions);
+  builder_.add_allow_nnapi_cpu_on_android_10_plus(allow_nnapi_cpu_on_android_10_plus);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<NNAPISettings> CreateNNAPISettingsDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    const char *accelerator_name = nullptr,
+    const char *cache_directory = nullptr,
+    const char *model_token = nullptr,
+    tflite::NNAPIExecutionPreference execution_preference = tflite::NNAPIExecutionPreference_UNDEFINED,
+    int32_t no_of_nnapi_instances_to_cache = 0,
+    flatbuffers::Offset<tflite::FallbackSettings> fallback_settings = 0,
+    bool allow_nnapi_cpu_on_android_10_plus = false,
+    tflite::NNAPIExecutionPriority execution_priority = tflite::NNAPIExecutionPriority_NNAPI_PRIORITY_UNDEFINED,
+    bool allow_dynamic_dimensions = false,
+    bool allow_fp16_precision_for_fp32 = false) {
+  auto accelerator_name__ = accelerator_name ? _fbb.CreateString(accelerator_name) : 0;
+  auto cache_directory__ = cache_directory ? _fbb.CreateString(cache_directory) : 0;
+  auto model_token__ = model_token ? _fbb.CreateString(model_token) : 0;
+  return tflite::CreateNNAPISettings(
+      _fbb,
+      accelerator_name__,
+      cache_directory__,
+      model_token__,
+      execution_preference,
+      no_of_nnapi_instances_to_cache,
+      fallback_settings,
+      allow_nnapi_cpu_on_android_10_plus,
+      execution_priority,
+      allow_dynamic_dimensions,
+      allow_fp16_precision_for_fp32);
+}
+
+flatbuffers::Offset<NNAPISettings> CreateNNAPISettings(flatbuffers::FlatBufferBuilder &_fbb, const NNAPISettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct GPUSettingsT : public flatbuffers::NativeTable {
+  typedef GPUSettings TableType;
+  bool is_precision_loss_allowed;
+  bool enable_quantized_inference;
+  tflite::GPUBackend force_backend;
+  GPUSettingsT()
+      : is_precision_loss_allowed(false),
+        enable_quantized_inference(true),
+        force_backend(tflite::GPUBackend_UNSET) {
+  }
+};
+
+struct GPUSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef GPUSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_IS_PRECISION_LOSS_ALLOWED = 4,
+    VT_ENABLE_QUANTIZED_INFERENCE = 6,
+    VT_FORCE_BACKEND = 8
+  };
+  bool is_precision_loss_allowed() const {
+    return GetField<uint8_t>(VT_IS_PRECISION_LOSS_ALLOWED, 0) != 0;
+  }
+  bool enable_quantized_inference() const {
+    return GetField<uint8_t>(VT_ENABLE_QUANTIZED_INFERENCE, 1) != 0;
+  }
+  tflite::GPUBackend force_backend() const {
+    return static_cast<tflite::GPUBackend>(GetField<int32_t>(VT_FORCE_BACKEND, 0));
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint8_t>(verifier, VT_IS_PRECISION_LOSS_ALLOWED) &&
+           VerifyField<uint8_t>(verifier, VT_ENABLE_QUANTIZED_INFERENCE) &&
+           VerifyField<int32_t>(verifier, VT_FORCE_BACKEND) &&
+           verifier.EndTable();
+  }
+  GPUSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(GPUSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<GPUSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GPUSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GPUSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_is_precision_loss_allowed(bool is_precision_loss_allowed) {
+    fbb_.AddElement<uint8_t>(GPUSettings::VT_IS_PRECISION_LOSS_ALLOWED, static_cast<uint8_t>(is_precision_loss_allowed), 0);
+  }
+  void add_enable_quantized_inference(bool enable_quantized_inference) {
+    fbb_.AddElement<uint8_t>(GPUSettings::VT_ENABLE_QUANTIZED_INFERENCE, static_cast<uint8_t>(enable_quantized_inference), 1);
+  }
+  void add_force_backend(tflite::GPUBackend force_backend) {
+    fbb_.AddElement<int32_t>(GPUSettings::VT_FORCE_BACKEND, static_cast<int32_t>(force_backend), 0);
+  }
+  explicit GPUSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  GPUSettingsBuilder &operator=(const GPUSettingsBuilder &);
+  flatbuffers::Offset<GPUSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<GPUSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<GPUSettings> CreateGPUSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    bool is_precision_loss_allowed = false,
+    bool enable_quantized_inference = true,
+    tflite::GPUBackend force_backend = tflite::GPUBackend_UNSET) {
+  GPUSettingsBuilder builder_(_fbb);
+  builder_.add_force_backend(force_backend);
+  builder_.add_enable_quantized_inference(enable_quantized_inference);
+  builder_.add_is_precision_loss_allowed(is_precision_loss_allowed);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<GPUSettings> CreateGPUSettings(flatbuffers::FlatBufferBuilder &_fbb, const GPUSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct HexagonSettingsT : public flatbuffers::NativeTable {
+  typedef HexagonSettings TableType;
+  int32_t debug_level;
+  int32_t powersave_level;
+  bool print_graph_profile;
+  bool print_graph_debug;
+  HexagonSettingsT()
+      : debug_level(0),
+        powersave_level(0),
+        print_graph_profile(false),
+        print_graph_debug(false) {
+  }
+};
+
+struct HexagonSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef HexagonSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DEBUG_LEVEL = 4,
+    VT_POWERSAVE_LEVEL = 6,
+    VT_PRINT_GRAPH_PROFILE = 8,
+    VT_PRINT_GRAPH_DEBUG = 10
+  };
+  int32_t debug_level() const {
+    return GetField<int32_t>(VT_DEBUG_LEVEL, 0);
+  }
+  int32_t powersave_level() const {
+    return GetField<int32_t>(VT_POWERSAVE_LEVEL, 0);
+  }
+  bool print_graph_profile() const {
+    return GetField<uint8_t>(VT_PRINT_GRAPH_PROFILE, 0) != 0;
+  }
+  bool print_graph_debug() const {
+    return GetField<uint8_t>(VT_PRINT_GRAPH_DEBUG, 0) != 0;
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_DEBUG_LEVEL) &&
+           VerifyField<int32_t>(verifier, VT_POWERSAVE_LEVEL) &&
+           VerifyField<uint8_t>(verifier, VT_PRINT_GRAPH_PROFILE) &&
+           VerifyField<uint8_t>(verifier, VT_PRINT_GRAPH_DEBUG) &&
+           verifier.EndTable();
+  }
+  HexagonSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(HexagonSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<HexagonSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const HexagonSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct HexagonSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_debug_level(int32_t debug_level) {
+    fbb_.AddElement<int32_t>(HexagonSettings::VT_DEBUG_LEVEL, debug_level, 0);
+  }
+  void add_powersave_level(int32_t powersave_level) {
+    fbb_.AddElement<int32_t>(HexagonSettings::VT_POWERSAVE_LEVEL, powersave_level, 0);
+  }
+  void add_print_graph_profile(bool print_graph_profile) {
+    fbb_.AddElement<uint8_t>(HexagonSettings::VT_PRINT_GRAPH_PROFILE, static_cast<uint8_t>(print_graph_profile), 0);
+  }
+  void add_print_graph_debug(bool print_graph_debug) {
+    fbb_.AddElement<uint8_t>(HexagonSettings::VT_PRINT_GRAPH_DEBUG, static_cast<uint8_t>(print_graph_debug), 0);
+  }
+  explicit HexagonSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  HexagonSettingsBuilder &operator=(const HexagonSettingsBuilder &);
+  flatbuffers::Offset<HexagonSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<HexagonSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<HexagonSettings> CreateHexagonSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t debug_level = 0,
+    int32_t powersave_level = 0,
+    bool print_graph_profile = false,
+    bool print_graph_debug = false) {
+  HexagonSettingsBuilder builder_(_fbb);
+  builder_.add_powersave_level(powersave_level);
+  builder_.add_debug_level(debug_level);
+  builder_.add_print_graph_debug(print_graph_debug);
+  builder_.add_print_graph_profile(print_graph_profile);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<HexagonSettings> CreateHexagonSettings(flatbuffers::FlatBufferBuilder &_fbb, const HexagonSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct XNNPackSettingsT : public flatbuffers::NativeTable {
+  typedef XNNPackSettings TableType;
+  int32_t num_threads;
+  XNNPackSettingsT()
+      : num_threads(0) {
+  }
+};
+
+struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef XNNPackSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_NUM_THREADS = 4
+  };
+  int32_t num_threads() const {
+    return GetField<int32_t>(VT_NUM_THREADS, 0);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_NUM_THREADS) &&
+           verifier.EndTable();
+  }
+  XNNPackSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(XNNPackSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<XNNPackSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct XNNPackSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_num_threads(int32_t num_threads) {
+    fbb_.AddElement<int32_t>(XNNPackSettings::VT_NUM_THREADS, num_threads, 0);
+  }
+  explicit XNNPackSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  XNNPackSettingsBuilder &operator=(const XNNPackSettingsBuilder &);
+  flatbuffers::Offset<XNNPackSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<XNNPackSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<XNNPackSettings> CreateXNNPackSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t num_threads = 0) {
+  XNNPackSettingsBuilder builder_(_fbb);
+  builder_.add_num_threads(num_threads);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<XNNPackSettings> CreateXNNPackSettings(flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct EdgeTpuSettingsT : public flatbuffers::NativeTable {
+  typedef EdgeTpuSettings TableType;
+  tflite::EdgeTpuSettings_::PowerState inference_power_state;
+  EdgeTpuSettingsT()
+      : inference_power_state(tflite::EdgeTpuSettings_::PowerState_UNDEFINED) {
+  }
+};
+
+struct EdgeTpuSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef EdgeTpuSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_INFERENCE_POWER_STATE = 4
+  };
+  tflite::EdgeTpuSettings_::PowerState inference_power_state() const {
+    return static_cast<tflite::EdgeTpuSettings_::PowerState>(GetField<int32_t>(VT_INFERENCE_POWER_STATE, 0));
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_INFERENCE_POWER_STATE) &&
+           verifier.EndTable();
+  }
+  EdgeTpuSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(EdgeTpuSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<EdgeTpuSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const EdgeTpuSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct EdgeTpuSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_inference_power_state(tflite::EdgeTpuSettings_::PowerState inference_power_state) {
+    fbb_.AddElement<int32_t>(EdgeTpuSettings::VT_INFERENCE_POWER_STATE, static_cast<int32_t>(inference_power_state), 0);
+  }
+  explicit EdgeTpuSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  EdgeTpuSettingsBuilder &operator=(const EdgeTpuSettingsBuilder &);
+  flatbuffers::Offset<EdgeTpuSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<EdgeTpuSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<EdgeTpuSettings> CreateEdgeTpuSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::EdgeTpuSettings_::PowerState inference_power_state = tflite::EdgeTpuSettings_::PowerState_UNDEFINED) {
+  EdgeTpuSettingsBuilder builder_(_fbb);
+  builder_.add_inference_power_state(inference_power_state);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<EdgeTpuSettings> CreateEdgeTpuSettings(flatbuffers::FlatBufferBuilder &_fbb, const EdgeTpuSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct CPUSettingsT : public flatbuffers::NativeTable {
+  typedef CPUSettings TableType;
+  int32_t num_threads;
+  CPUSettingsT()
+      : num_threads(0) {
+  }
+};
+
+struct CPUSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef CPUSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_NUM_THREADS = 4
+  };
+  int32_t num_threads() const {
+    return GetField<int32_t>(VT_NUM_THREADS, 0);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_NUM_THREADS) &&
+           verifier.EndTable();
+  }
+  CPUSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(CPUSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<CPUSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CPUSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CPUSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_num_threads(int32_t num_threads) {
+    fbb_.AddElement<int32_t>(CPUSettings::VT_NUM_THREADS, num_threads, 0);
+  }
+  explicit CPUSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  CPUSettingsBuilder &operator=(const CPUSettingsBuilder &);
+  flatbuffers::Offset<CPUSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<CPUSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<CPUSettings> CreateCPUSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t num_threads = 0) {
+  CPUSettingsBuilder builder_(_fbb);
+  builder_.add_num_threads(num_threads);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<CPUSettings> CreateCPUSettings(flatbuffers::FlatBufferBuilder &_fbb, const CPUSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct TFLiteSettingsT : public flatbuffers::NativeTable {
+  typedef TFLiteSettings TableType;
+  tflite::Delegate delegate;
+  std::unique_ptr<tflite::NNAPISettingsT> nnapi_settings;
+  std::unique_ptr<tflite::GPUSettingsT> gpu_settings;
+  std::unique_ptr<tflite::HexagonSettingsT> hexagon_settings;
+  std::unique_ptr<tflite::XNNPackSettingsT> xnnpack_settings;
+  std::unique_ptr<tflite::CPUSettingsT> cpu_settings;
+  int32_t max_delegated_partitions;
+  std::unique_ptr<tflite::EdgeTpuSettingsT> edgetpu_settings;
+  std::unique_ptr<tflite::FallbackSettingsT> fallback_settings;
+  TFLiteSettingsT()
+      : delegate(tflite::Delegate_NONE),
+        max_delegated_partitions(0) {
+  }
+};
+
+struct TFLiteSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef TFLiteSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DELEGATE = 4,
+    VT_NNAPI_SETTINGS = 6,
+    VT_GPU_SETTINGS = 8,
+    VT_HEXAGON_SETTINGS = 10,
+    VT_XNNPACK_SETTINGS = 12,
+    VT_CPU_SETTINGS = 14,
+    VT_MAX_DELEGATED_PARTITIONS = 16,
+    VT_EDGETPU_SETTINGS = 18,
+    VT_FALLBACK_SETTINGS = 20
+  };
+  tflite::Delegate delegate() const {
+    return static_cast<tflite::Delegate>(GetField<int32_t>(VT_DELEGATE, 0));
+  }
+  const tflite::NNAPISettings *nnapi_settings() const {
+    return GetPointer<const tflite::NNAPISettings *>(VT_NNAPI_SETTINGS);
+  }
+  const tflite::GPUSettings *gpu_settings() const {
+    return GetPointer<const tflite::GPUSettings *>(VT_GPU_SETTINGS);
+  }
+  const tflite::HexagonSettings *hexagon_settings() const {
+    return GetPointer<const tflite::HexagonSettings *>(VT_HEXAGON_SETTINGS);
+  }
+  const tflite::XNNPackSettings *xnnpack_settings() const {
+    return GetPointer<const tflite::XNNPackSettings *>(VT_XNNPACK_SETTINGS);
+  }
+  const tflite::CPUSettings *cpu_settings() const {
+    return GetPointer<const tflite::CPUSettings *>(VT_CPU_SETTINGS);
+  }
+  int32_t max_delegated_partitions() const {
+    return GetField<int32_t>(VT_MAX_DELEGATED_PARTITIONS, 0);
+  }
+  const tflite::EdgeTpuSettings *edgetpu_settings() const {
+    return GetPointer<const tflite::EdgeTpuSettings *>(VT_EDGETPU_SETTINGS);
+  }
+  const tflite::FallbackSettings *fallback_settings() const {
+    return GetPointer<const tflite::FallbackSettings *>(VT_FALLBACK_SETTINGS);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_DELEGATE) &&
+           VerifyOffset(verifier, VT_NNAPI_SETTINGS) &&
+           verifier.VerifyTable(nnapi_settings()) &&
+           VerifyOffset(verifier, VT_GPU_SETTINGS) &&
+           verifier.VerifyTable(gpu_settings()) &&
+           VerifyOffset(verifier, VT_HEXAGON_SETTINGS) &&
+           verifier.VerifyTable(hexagon_settings()) &&
+           VerifyOffset(verifier, VT_XNNPACK_SETTINGS) &&
+           verifier.VerifyTable(xnnpack_settings()) &&
+           VerifyOffset(verifier, VT_CPU_SETTINGS) &&
+           verifier.VerifyTable(cpu_settings()) &&
+           VerifyField<int32_t>(verifier, VT_MAX_DELEGATED_PARTITIONS) &&
+           VerifyOffset(verifier, VT_EDGETPU_SETTINGS) &&
+           verifier.VerifyTable(edgetpu_settings()) &&
+           VerifyOffset(verifier, VT_FALLBACK_SETTINGS) &&
+           verifier.VerifyTable(fallback_settings()) &&
+           verifier.EndTable();
+  }
+  TFLiteSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(TFLiteSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<TFLiteSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const TFLiteSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct TFLiteSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_delegate(tflite::Delegate delegate) {
+    fbb_.AddElement<int32_t>(TFLiteSettings::VT_DELEGATE, static_cast<int32_t>(delegate), 0);
+  }
+  void add_nnapi_settings(flatbuffers::Offset<tflite::NNAPISettings> nnapi_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_NNAPI_SETTINGS, nnapi_settings);
+  }
+  void add_gpu_settings(flatbuffers::Offset<tflite::GPUSettings> gpu_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_GPU_SETTINGS, gpu_settings);
+  }
+  void add_hexagon_settings(flatbuffers::Offset<tflite::HexagonSettings> hexagon_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_HEXAGON_SETTINGS, hexagon_settings);
+  }
+  void add_xnnpack_settings(flatbuffers::Offset<tflite::XNNPackSettings> xnnpack_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_XNNPACK_SETTINGS, xnnpack_settings);
+  }
+  void add_cpu_settings(flatbuffers::Offset<tflite::CPUSettings> cpu_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_CPU_SETTINGS, cpu_settings);
+  }
+  void add_max_delegated_partitions(int32_t max_delegated_partitions) {
+    fbb_.AddElement<int32_t>(TFLiteSettings::VT_MAX_DELEGATED_PARTITIONS, max_delegated_partitions, 0);
+  }
+  void add_edgetpu_settings(flatbuffers::Offset<tflite::EdgeTpuSettings> edgetpu_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_EDGETPU_SETTINGS, edgetpu_settings);
+  }
+  void add_fallback_settings(flatbuffers::Offset<tflite::FallbackSettings> fallback_settings) {
+    fbb_.AddOffset(TFLiteSettings::VT_FALLBACK_SETTINGS, fallback_settings);
+  }
+  explicit TFLiteSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  TFLiteSettingsBuilder &operator=(const TFLiteSettingsBuilder &);
+  flatbuffers::Offset<TFLiteSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<TFLiteSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<TFLiteSettings> CreateTFLiteSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::Delegate delegate = tflite::Delegate_NONE,
+    flatbuffers::Offset<tflite::NNAPISettings> nnapi_settings = 0,
+    flatbuffers::Offset<tflite::GPUSettings> gpu_settings = 0,
+    flatbuffers::Offset<tflite::HexagonSettings> hexagon_settings = 0,
+    flatbuffers::Offset<tflite::XNNPackSettings> xnnpack_settings = 0,
+    flatbuffers::Offset<tflite::CPUSettings> cpu_settings = 0,
+    int32_t max_delegated_partitions = 0,
+    flatbuffers::Offset<tflite::EdgeTpuSettings> edgetpu_settings = 0,
+    flatbuffers::Offset<tflite::FallbackSettings> fallback_settings = 0) {
+  TFLiteSettingsBuilder builder_(_fbb);
+  builder_.add_fallback_settings(fallback_settings);
+  builder_.add_edgetpu_settings(edgetpu_settings);
+  builder_.add_max_delegated_partitions(max_delegated_partitions);
+  builder_.add_cpu_settings(cpu_settings);
+  builder_.add_xnnpack_settings(xnnpack_settings);
+  builder_.add_hexagon_settings(hexagon_settings);
+  builder_.add_gpu_settings(gpu_settings);
+  builder_.add_nnapi_settings(nnapi_settings);
+  builder_.add_delegate(delegate);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<TFLiteSettings> CreateTFLiteSettings(flatbuffers::FlatBufferBuilder &_fbb, const TFLiteSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct FallbackSettingsT : public flatbuffers::NativeTable {
+  typedef FallbackSettings TableType;
+  bool allow_automatic_fallback_on_compilation_error;
+  bool allow_automatic_fallback_on_execution_error;
+  FallbackSettingsT()
+      : allow_automatic_fallback_on_compilation_error(false),
+        allow_automatic_fallback_on_execution_error(false) {
+  }
+};
+
+struct FallbackSettings FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef FallbackSettingsT NativeTableType;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_ALLOW_AUTOMATIC_FALLBACK_ON_COMPILATION_ERROR = 4,
+    VT_ALLOW_AUTOMATIC_FALLBACK_ON_EXECUTION_ERROR = 6
+  };
+  bool allow_automatic_fallback_on_compilation_error() const {
+    return GetField<uint8_t>(VT_ALLOW_AUTOMATIC_FALLBACK_ON_COMPILATION_ERROR, 0) != 0;
+  }
+  bool allow_automatic_fallback_on_execution_error() const {
+    return GetField<uint8_t>(VT_ALLOW_AUTOMATIC_FALLBACK_ON_EXECUTION_ERROR, 0) != 0;
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint8_t>(verifier, VT_ALLOW_AUTOMATIC_FALLBACK_ON_COMPILATION_ERROR) &&
+           VerifyField<uint8_t>(verifier, VT_ALLOW_AUTOMATIC_FALLBACK_ON_EXECUTION_ERROR) &&
+           verifier.EndTable();
+  }
+  FallbackSettingsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(FallbackSettingsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<FallbackSettings> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FallbackSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FallbackSettingsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_allow_automatic_fallback_on_compilation_error(bool allow_automatic_fallback_on_compilation_error) {
+    fbb_.AddElement<uint8_t>(FallbackSettings::VT_ALLOW_AUTOMATIC_FALLBACK_ON_COMPILATION_ERROR, static_cast<uint8_t>(allow_automatic_fallback_on_compilation_error), 0);
+  }
+  void add_allow_automatic_fallback_on_execution_error(bool allow_automatic_fallback_on_execution_error) {
+    fbb_.AddElement<uint8_t>(FallbackSettings::VT_ALLOW_AUTOMATIC_FALLBACK_ON_EXECUTION_ERROR, static_cast<uint8_t>(allow_automatic_fallback_on_execution_error), 0);
+  }
+  explicit FallbackSettingsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  FallbackSettingsBuilder &operator=(const FallbackSettingsBuilder &);
+  flatbuffers::Offset<FallbackSettings> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<FallbackSettings>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<FallbackSettings> CreateFallbackSettings(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    bool allow_automatic_fallback_on_compilation_error = false,
+    bool allow_automatic_fallback_on_execution_error = false) {
+  FallbackSettingsBuilder builder_(_fbb);
+  builder_.add_allow_automatic_fallback_on_execution_error(allow_automatic_fallback_on_execution_error);
+  builder_.add_allow_automatic_fallback_on_compilation_error(allow_automatic_fallback_on_compilation_error);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<FallbackSettings> CreateFallbackSettings(flatbuffers::FlatBufferBuilder &_fbb, const FallbackSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+inline ComputeSettingsT *ComputeSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new ComputeSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void ComputeSettings::UnPackTo(ComputeSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = preference(); _o->preference = _e; }
+  { auto _e = tflite_settings(); if (_e) _o->tflite_settings = std::unique_ptr<tflite::TFLiteSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = model_namespace_for_statistics(); if (_e) _o->model_namespace_for_statistics = _e->str(); }
+  { auto _e = model_identifier_for_statistics(); if (_e) _o->model_identifier_for_statistics = _e->str(); }
+}
+
+inline flatbuffers::Offset<ComputeSettings> ComputeSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ComputeSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateComputeSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ComputeSettings> CreateComputeSettings(flatbuffers::FlatBufferBuilder &_fbb, const ComputeSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ComputeSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _preference = _o->preference;
+  auto _tflite_settings = _o->tflite_settings ? CreateTFLiteSettings(_fbb, _o->tflite_settings.get(), _rehasher) : 0;
+  auto _model_namespace_for_statistics = _o->model_namespace_for_statistics.empty() ? 0 : _fbb.CreateString(_o->model_namespace_for_statistics);
+  auto _model_identifier_for_statistics = _o->model_identifier_for_statistics.empty() ? 0 : _fbb.CreateString(_o->model_identifier_for_statistics);
+  return tflite::CreateComputeSettings(
+      _fbb,
+      _preference,
+      _tflite_settings,
+      _model_namespace_for_statistics,
+      _model_identifier_for_statistics);
+}
+
+inline NNAPISettingsT *NNAPISettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new NNAPISettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void NNAPISettings::UnPackTo(NNAPISettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = accelerator_name(); if (_e) _o->accelerator_name = _e->str(); }
+  { auto _e = cache_directory(); if (_e) _o->cache_directory = _e->str(); }
+  { auto _e = model_token(); if (_e) _o->model_token = _e->str(); }
+  { auto _e = execution_preference(); _o->execution_preference = _e; }
+  { auto _e = no_of_nnapi_instances_to_cache(); _o->no_of_nnapi_instances_to_cache = _e; }
+  { auto _e = fallback_settings(); if (_e) _o->fallback_settings = std::unique_ptr<tflite::FallbackSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = allow_nnapi_cpu_on_android_10_plus(); _o->allow_nnapi_cpu_on_android_10_plus = _e; }
+  { auto _e = execution_priority(); _o->execution_priority = _e; }
+  { auto _e = allow_dynamic_dimensions(); _o->allow_dynamic_dimensions = _e; }
+  { auto _e = allow_fp16_precision_for_fp32(); _o->allow_fp16_precision_for_fp32 = _e; }
+}
+
+inline flatbuffers::Offset<NNAPISettings> NNAPISettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NNAPISettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateNNAPISettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<NNAPISettings> CreateNNAPISettings(flatbuffers::FlatBufferBuilder &_fbb, const NNAPISettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NNAPISettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _accelerator_name = _o->accelerator_name.empty() ? 0 : _fbb.CreateString(_o->accelerator_name);
+  auto _cache_directory = _o->cache_directory.empty() ? 0 : _fbb.CreateString(_o->cache_directory);
+  auto _model_token = _o->model_token.empty() ? 0 : _fbb.CreateString(_o->model_token);
+  auto _execution_preference = _o->execution_preference;
+  auto _no_of_nnapi_instances_to_cache = _o->no_of_nnapi_instances_to_cache;
+  auto _fallback_settings = _o->fallback_settings ? CreateFallbackSettings(_fbb, _o->fallback_settings.get(), _rehasher) : 0;
+  auto _allow_nnapi_cpu_on_android_10_plus = _o->allow_nnapi_cpu_on_android_10_plus;
+  auto _execution_priority = _o->execution_priority;
+  auto _allow_dynamic_dimensions = _o->allow_dynamic_dimensions;
+  auto _allow_fp16_precision_for_fp32 = _o->allow_fp16_precision_for_fp32;
+  return tflite::CreateNNAPISettings(
+      _fbb,
+      _accelerator_name,
+      _cache_directory,
+      _model_token,
+      _execution_preference,
+      _no_of_nnapi_instances_to_cache,
+      _fallback_settings,
+      _allow_nnapi_cpu_on_android_10_plus,
+      _execution_priority,
+      _allow_dynamic_dimensions,
+      _allow_fp16_precision_for_fp32);
+}
+
+inline GPUSettingsT *GPUSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new GPUSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void GPUSettings::UnPackTo(GPUSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = is_precision_loss_allowed(); _o->is_precision_loss_allowed = _e; }
+  { auto _e = enable_quantized_inference(); _o->enable_quantized_inference = _e; }
+  { auto _e = force_backend(); _o->force_backend = _e; }
+}
+
+inline flatbuffers::Offset<GPUSettings> GPUSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GPUSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateGPUSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GPUSettings> CreateGPUSettings(flatbuffers::FlatBufferBuilder &_fbb, const GPUSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GPUSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _is_precision_loss_allowed = _o->is_precision_loss_allowed;
+  auto _enable_quantized_inference = _o->enable_quantized_inference;
+  auto _force_backend = _o->force_backend;
+  return tflite::CreateGPUSettings(
+      _fbb,
+      _is_precision_loss_allowed,
+      _enable_quantized_inference,
+      _force_backend);
+}
+
+inline HexagonSettingsT *HexagonSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new HexagonSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void HexagonSettings::UnPackTo(HexagonSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = debug_level(); _o->debug_level = _e; }
+  { auto _e = powersave_level(); _o->powersave_level = _e; }
+  { auto _e = print_graph_profile(); _o->print_graph_profile = _e; }
+  { auto _e = print_graph_debug(); _o->print_graph_debug = _e; }
+}
+
+inline flatbuffers::Offset<HexagonSettings> HexagonSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const HexagonSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateHexagonSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<HexagonSettings> CreateHexagonSettings(flatbuffers::FlatBufferBuilder &_fbb, const HexagonSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const HexagonSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _debug_level = _o->debug_level;
+  auto _powersave_level = _o->powersave_level;
+  auto _print_graph_profile = _o->print_graph_profile;
+  auto _print_graph_debug = _o->print_graph_debug;
+  return tflite::CreateHexagonSettings(
+      _fbb,
+      _debug_level,
+      _powersave_level,
+      _print_graph_profile,
+      _print_graph_debug);
+}
+
+inline XNNPackSettingsT *XNNPackSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new XNNPackSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void XNNPackSettings::UnPackTo(XNNPackSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = num_threads(); _o->num_threads = _e; }
+}
+
+inline flatbuffers::Offset<XNNPackSettings> XNNPackSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateXNNPackSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<XNNPackSettings> CreateXNNPackSettings(flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const XNNPackSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _num_threads = _o->num_threads;
+  return tflite::CreateXNNPackSettings(
+      _fbb,
+      _num_threads);
+}
+
+inline EdgeTpuSettingsT *EdgeTpuSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new EdgeTpuSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void EdgeTpuSettings::UnPackTo(EdgeTpuSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = inference_power_state(); _o->inference_power_state = _e; }
+}
+
+inline flatbuffers::Offset<EdgeTpuSettings> EdgeTpuSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EdgeTpuSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateEdgeTpuSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<EdgeTpuSettings> CreateEdgeTpuSettings(flatbuffers::FlatBufferBuilder &_fbb, const EdgeTpuSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EdgeTpuSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _inference_power_state = _o->inference_power_state;
+  return tflite::CreateEdgeTpuSettings(
+      _fbb,
+      _inference_power_state);
+}
+
+inline CPUSettingsT *CPUSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new CPUSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void CPUSettings::UnPackTo(CPUSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = num_threads(); _o->num_threads = _e; }
+}
+
+inline flatbuffers::Offset<CPUSettings> CPUSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CPUSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateCPUSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CPUSettings> CreateCPUSettings(flatbuffers::FlatBufferBuilder &_fbb, const CPUSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CPUSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _num_threads = _o->num_threads;
+  return tflite::CreateCPUSettings(
+      _fbb,
+      _num_threads);
+}
+
+inline TFLiteSettingsT *TFLiteSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new TFLiteSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void TFLiteSettings::UnPackTo(TFLiteSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = delegate(); _o->delegate = _e; }
+  { auto _e = nnapi_settings(); if (_e) _o->nnapi_settings = std::unique_ptr<tflite::NNAPISettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = gpu_settings(); if (_e) _o->gpu_settings = std::unique_ptr<tflite::GPUSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = hexagon_settings(); if (_e) _o->hexagon_settings = std::unique_ptr<tflite::HexagonSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = xnnpack_settings(); if (_e) _o->xnnpack_settings = std::unique_ptr<tflite::XNNPackSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = cpu_settings(); if (_e) _o->cpu_settings = std::unique_ptr<tflite::CPUSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = max_delegated_partitions(); _o->max_delegated_partitions = _e; }
+  { auto _e = edgetpu_settings(); if (_e) _o->edgetpu_settings = std::unique_ptr<tflite::EdgeTpuSettingsT>(_e->UnPack(_resolver)); }
+  { auto _e = fallback_settings(); if (_e) _o->fallback_settings = std::unique_ptr<tflite::FallbackSettingsT>(_e->UnPack(_resolver)); }
+}
+
+inline flatbuffers::Offset<TFLiteSettings> TFLiteSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TFLiteSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateTFLiteSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<TFLiteSettings> CreateTFLiteSettings(flatbuffers::FlatBufferBuilder &_fbb, const TFLiteSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TFLiteSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _delegate = _o->delegate;
+  auto _nnapi_settings = _o->nnapi_settings ? CreateNNAPISettings(_fbb, _o->nnapi_settings.get(), _rehasher) : 0;
+  auto _gpu_settings = _o->gpu_settings ? CreateGPUSettings(_fbb, _o->gpu_settings.get(), _rehasher) : 0;
+  auto _hexagon_settings = _o->hexagon_settings ? CreateHexagonSettings(_fbb, _o->hexagon_settings.get(), _rehasher) : 0;
+  auto _xnnpack_settings = _o->xnnpack_settings ? CreateXNNPackSettings(_fbb, _o->xnnpack_settings.get(), _rehasher) : 0;
+  auto _cpu_settings = _o->cpu_settings ? CreateCPUSettings(_fbb, _o->cpu_settings.get(), _rehasher) : 0;
+  auto _max_delegated_partitions = _o->max_delegated_partitions;
+  auto _edgetpu_settings = _o->edgetpu_settings ? CreateEdgeTpuSettings(_fbb, _o->edgetpu_settings.get(), _rehasher) : 0;
+  auto _fallback_settings = _o->fallback_settings ? CreateFallbackSettings(_fbb, _o->fallback_settings.get(), _rehasher) : 0;
+  return tflite::CreateTFLiteSettings(
+      _fbb,
+      _delegate,
+      _nnapi_settings,
+      _gpu_settings,
+      _hexagon_settings,
+      _xnnpack_settings,
+      _cpu_settings,
+      _max_delegated_partitions,
+      _edgetpu_settings,
+      _fallback_settings);
+}
+
+inline FallbackSettingsT *FallbackSettings::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new FallbackSettingsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void FallbackSettings::UnPackTo(FallbackSettingsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = allow_automatic_fallback_on_compilation_error(); _o->allow_automatic_fallback_on_compilation_error = _e; }
+  { auto _e = allow_automatic_fallback_on_execution_error(); _o->allow_automatic_fallback_on_execution_error = _e; }
+}
+
+inline flatbuffers::Offset<FallbackSettings> FallbackSettings::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FallbackSettingsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateFallbackSettings(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FallbackSettings> CreateFallbackSettings(flatbuffers::FlatBufferBuilder &_fbb, const FallbackSettingsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FallbackSettingsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _allow_automatic_fallback_on_compilation_error = _o->allow_automatic_fallback_on_compilation_error;
+  auto _allow_automatic_fallback_on_execution_error = _o->allow_automatic_fallback_on_execution_error;
+  return tflite::CreateFallbackSettings(
+      _fbb,
+      _allow_automatic_fallback_on_compilation_error,
+      _allow_automatic_fallback_on_execution_error);
+}
+
+}  // namespace tflite
+
+#endif  // FLATBUFFERS_GENERATED_CONFIGURATION_TFLITE_H_
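
For orientation, here is a minimal usage sketch of the generated API above. It is not part of the generated header or of this change; it assumes the header is included as `configuration_generated.h`, that the FlatBuffers runtime headers are available, and the field values (the fallback flags and the partition limit of 2) are purely illustrative. The same `Create*`, `Pack`, and `UnPack` helpers exist for each of the other settings tables (NNAPISettings, GPUSettings, HexagonSettings, XNNPackSettings, CPUSettings, EdgeTpuSettings).

```c++
// Sketch only: include paths and field values are assumptions, not part of this change.
#include <memory>

#include "flatbuffers/flatbuffers.h"
#include "configuration_generated.h"  // the generated header shown above

int main() {
  flatbuffers::FlatBufferBuilder fbb;

  // Build the nested FallbackSettings table first, then reference it from
  // TFLiteSettings via the CreateTFLiteSettings convenience helper.
  auto fallback = tflite::CreateFallbackSettings(
      fbb,
      /*allow_automatic_fallback_on_compilation_error=*/true,
      /*allow_automatic_fallback_on_execution_error=*/false);
  auto settings = tflite::CreateTFLiteSettings(
      fbb, tflite::Delegate_NONE,
      /*nnapi_settings=*/0, /*gpu_settings=*/0, /*hexagon_settings=*/0,
      /*xnnpack_settings=*/0, /*cpu_settings=*/0,
      /*max_delegated_partitions=*/2,
      /*edgetpu_settings=*/0, fallback);
  fbb.Finish(settings);

  // Verify the finished table and round-trip it through the object ("T")
  // API: UnPack() produces a TFLiteSettingsT that can be mutated in C++.
  flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
  const auto* root =
      flatbuffers::GetRoot<tflite::TFLiteSettings>(fbb.GetBufferPointer());
  if (!verifier.VerifyTable(root)) return 1;
  std::unique_ptr<tflite::TFLiteSettingsT> unpacked(root->UnPack());
  return unpacked->max_delegated_partitions == 2 ? 0 : 1;
}
```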
diff --git a/tensorflow/lite/g3doc/guide/build_cmake.md b/tensorflow/lite/g3doc/guide/build_cmake.md
index 646b1ed..ec1ec38 100644
--- a/tensorflow/lite/g3doc/guide/build_cmake.md
+++ b/tensorflow/lite/g3doc/guide/build_cmake.md
@@ -110,6 +110,20 @@
 cmake --build . -j -t benchmark_model
 ```
 
+## Available Options to build TensorFlow Lite
+
+Here is the list of available options. You can override them with
+`-D<option_name>=[ON|OFF]`. For example, `-DTFLITE_ENABLE_XNNPACK=OFF`
+disables XNNPACK, which is enabled by default.
+
+Option Name           | Feature                                  | Default
+--------------------- | ---------------------------------------- | ------------
+TFLITE_ENABLE_RUY     | Enable RUY matrix multiplication library | OFF
+TFLITE_ENABLE_NNAPI   | Enable NNAPI delegate                    | ON (Android)
+TFLITE_ENABLE_GPU     | Enable GPU delegate                      | OFF
+TFLITE_ENABLE_XNNPACK | Enable XNNPACK delegate                  | ON
+TFLITE_ENABLE_MMAP    | Enable MMAP (unsupported on Windows)     | ON
+
 ## Create a CMake project which uses TensorFlow Lite
 
 Here is the CMakeLists.txt of
diff --git a/tensorflow/lite/g3doc/guide/ops_select.md b/tensorflow/lite/g3doc/guide/ops_select.md
index e56b783..3dd54ef 100644
--- a/tensorflow/lite/g3doc/guide/ops_select.md
+++ b/tensorflow/lite/g3doc/guide/ops_select.md
@@ -261,15 +261,11 @@
     input/output types that are typically available in TensorFlow.
 *   Unsupported ops: Control flow ops and ops that require explicit
     initialization from resources, like `HashTableV2`, are not yet supported.
-*   Unsupported optimizations: If you apply an optimization known as
-    [post training quantization](../performance/post_training_quantization.md),
-    only the TensorFlow Lite ops will be quantized (or optimized), but the
-    TensorFlow ops will remain as float (or unoptimized).
 
-## Future plans
+## Updates
 
-The following is a list of improvements to this pipeline that are in progress:
-
-*   *Improved performance* - Work is being done to ensure TensorFlow Lite with
-    TensorFlow ops nicely cooperates with hardware accelerated delegates, for
-    example, NNAPI and GPU delegates.
+*   Version 2.5 (not yet officially released)
+    -   You can apply an optimization known as
+        [post training quantization](../performance/post_training_quantization.md).
+*   Version 2.4
+    -   Compatibility with hardware-accelerated delegates has improved.
diff --git a/tensorflow/lite/g3doc/performance/implementing_delegate.md b/tensorflow/lite/g3doc/performance/implementing_delegate.md
index d97e56c..e290812 100644
--- a/tensorflow/lite/g3doc/performance/implementing_delegate.md
+++ b/tensorflow/lite/g3doc/performance/implementing_delegate.md
@@ -72,7 +72,7 @@
 
 ### 1 - `SimpleDelegateInterface`
 
-This class represents the capabilities of the delegate, which operations aer
+This class represents the capabilities of the delegate, which operations are
 supported, and a factory class for creating a kernel which encapsulates the
 delegated graph. For more details, see the interface defined in this
 [C++ header file](https://github.com/tensorflow/tensorflow/blob/8a643858ce174b8bd1b4bb8fa4bfaa62f7e8c45f/tensorflow/lite/delegates/utils/simple_delegate.h#L71).
diff --git a/tensorflow/lite/g3doc/performance/measurement.md b/tensorflow/lite/g3doc/performance/measurement.md
index 2d03fa8..8947054 100644
--- a/tensorflow/lite/g3doc/performance/measurement.md
+++ b/tensorflow/lite/g3doc/performance/measurement.md
@@ -32,9 +32,16 @@
 Download the nightly pre-built Android benchmark apps using the links below:
 
 *   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model.apk)
-
 *   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model.apk)
 
+For Android benchmark apps that support [TF ops](https://www.tensorflow.org/lite/guide/ops_select)
+via the [Flex delegate](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/flex),
+use the links below:
+
+*   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model_plus_flex.apk)
+*   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model_plus_flex.apk)
+
 You can also build the app from source by following these
 [instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android).
 
@@ -115,6 +122,16 @@
 *   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model)
 *   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model)
 
+For nightly pre-built binaries that support [TF ops](https://www.tensorflow.org/lite/guide/ops_select)
+via the [Flex delegate](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/flex),
+use the links below:
+
+*   [linux_x86-64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_x86-64_benchmark_model_plus_flex)
+*   [linux_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_aarch64_benchmark_model_plus_flex)
+*   [linux_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_arm_benchmark_model_plus_flex)
+*   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model_plus_flex)
+*   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model_plus_flex)
+
 You can also build the native benchmark binary from
 [source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark)
 on your computer.
diff --git a/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
index 2ebaaaf..53c57d2 100644
--- a/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
+++ b/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
@@ -11,7 +11,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 1,
+      "execution_count": null,
       "metadata": {
         "cellView": "form",
         "id": "I9sUhVL_VZNO"
@@ -46,20 +46,20 @@
         "id": "CGuqeuPSVNo-"
       },
       "source": [
-        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_float16_quant\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "\u003c/table\u003e"
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_float16_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+        "  </td>\n",
+        "</table>"
       ]
     },
     {
@@ -97,7 +97,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 2,
+      "execution_count": null,
       "metadata": {
         "id": "gyqAw1M9lyab"
       },
@@ -114,24 +114,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 3,
+      "execution_count": null,
       "metadata": {
         "id": "c6nb7OPlXs_3"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "tf.float16"
-            ]
-          },
-          "execution_count": 3,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "tf.float16"
       ]
@@ -147,34 +134,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
+      "execution_count": null,
       "metadata": {
         "id": "hWSAjQWagIHl"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
-            "11493376/11490434 [==============================] - 0s 0us/step\n",
-            "11501568/11490434 [==============================] - 0s 0us/step\n",
-            "1875/1875 [==============================] - 12s 6ms/step - loss: 0.2864 - accuracy: 0.9207 - val_loss: 0.1467 - val_accuracy: 0.9560\n"
-          ]
-        },
-        {
-          "data": {
-            "text/plain": [
-              "\u003ctensorflow.python.keras.callbacks.History at 0x7fcd75df46a0\u003e"
-            ]
-          },
-          "execution_count": 4,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
         "mnist = keras.datasets.mnist\n",
@@ -230,7 +194,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 5,
+      "execution_count": null,
       "metadata": {
         "id": "_i8B2nDZmAgQ"
       },
@@ -251,7 +215,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 6,
+      "execution_count": null,
       "metadata": {
         "id": "vptWZq2xnclo"
       },
@@ -263,24 +227,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": null,
       "metadata": {
         "id": "Ie9pQaQrn5ue"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "84452"
-            ]
-          },
-          "execution_count": 7,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
         "tflite_model_file.write_bytes(tflite_model)"
@@ -297,7 +248,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": null,
       "metadata": {
         "id": "HEZ6ET1AHAS3"
       },
@@ -318,24 +269,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": null,
       "metadata": {
         "id": "yuNfl3CoHNK3"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "44272"
-            ]
-          },
-          "execution_count": 9,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "tflite_fp16_model = converter.convert()\n",
         "tflite_model_fp16_file = tflite_models_dir/\"mnist_model_quant_f16.tflite\"\n",
@@ -353,21 +291,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 10,
+      "execution_count": null,
       "metadata": {
         "id": "JExfcfLDscu4"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "total 128K\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828 44K Jun 23 06:04 mnist_model_quant_f16.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828 83K Jun 23 06:04 mnist_model.tflite\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "!ls -lh {tflite_models_dir}"
       ]
@@ -401,7 +329,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": null,
       "metadata": {
         "id": "Jn16Rc23zTss"
       },
@@ -413,7 +341,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": null,
       "metadata": {
         "id": "J8Pztk1mvNVL"
       },
@@ -434,7 +362,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": null,
       "metadata": {
         "id": "AKslvo2kwWac"
       },
@@ -452,24 +380,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 14,
+      "execution_count": null,
       "metadata": {
         "id": "XZClM2vo3_bm"
       },
-      "outputs": [
-        {
-          "data": {
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFxZJREFUeJzt3XtU1HXeB/D3cE0RVDSG4eKMPJBL\nIrI6ZqXhBTFrVwwpw5WEAGnLc9ZL2nbbI1arPPV4nix99jRR7aiFz7qmtIu6KhulVrJj4baYHiKI\nq6DCE4pyG7/PH51mI5nf4DAX9Pt+neM5zO/z/f2+H37ynt/M/GbmpxJCCBCRdDzc3QARuQfDTyQp\nhp9IUgw/kaQYfiJJMfxEkmL4yeF6enqgUqlQXV0NAMjOzsaGDRucPm9+fj5mzpzp9HluFgy/nYYN\nG2b55+HhgSFDhlhuv/vuu06fPzs7u1cPvr6+GDlypNPntUd+fj6effZZm+OmT5+OP/7xj07p4Ztv\nvum1v4YNGwaVSoXNmzc7Zb4bgZe7G7hRXbp0yfKzTqdDfn4+5syZY3V8T08PvLwct7vz8/ORn59v\nuZ2WloahQ4c6bPs/Zjab4enp6ZRtu0pERESv/7Ovv/4a48aNw8KFC93YlXvxyO8kzz//PB5++GEs\nXrwY/v7+2LFjB9LS0pCbm2sZc/jwYeh0Osvturo6JCcn49Zbb8XYsWOxdevWfs118eJF7NmzB+np\n6f0a/8O8L7zwAkaNGoWxY8di586dlnpaWhqWL1+OefPmwc/PD0eOHEFHRwdWr16N8PBwqNVqPPHE\nE+jo6LCsk5eXh+DgYISGhsJoNPaa76e/9/vvv4+4uDgEBAQgMjISBw8exG9/+1t8+umn+PWvf41h\nw4Zh5cqVAIBTp05hzpw5CAwMxM9+9jPs3r3bsp1z587hl7/8JQICAnDnnXeiqqqqX78/ABiNRsye\nPRvh4eH9XuemI2jAtFqtOHToUK9lzz33nPD29hYffPCBMJvN4vLly2LJkiVi3bp1ljGHDh0SWq1W\nCCFET0+PmDhxovj9738vOjs7RUVFhdBqteLw4cNCCCFKSkrEqFGj+pz/rbfeEpGRkf3u99ChQ8LT\n01OsWbNGdHR0iOLiYjFkyBBRUVEhhBBiyZIlYsSIEeKTTz4RZrNZdHR0iOXLl4sHHnhAtLS0iO++\n+07cd9994vnnnxdCCPGXv/xFBAcHi/LycnHp0iXx0EMPCQCiqqrKsr0ffu9jx46J4cOHi8OHDwuz\n2SxqamrE6dOnhRBCTJs2TbzzzjuWPtva2kRISIgwGo2iu7tbmEwmERgYaBmfkpIiUlNTRXt7uzh5\n8qQIDg4WM2bMsKw/b9488corr1zz+1+9elVotVqxffv2fu+zmxHD7wDWwj9r1qxey5TCf/ToUTF2\n7Nhe41944QWRnZ1tc/74+Hjx4osv9rvfQ4cOCW9vb9He3m5ZlpycLDZs2GDp89FHH7XUzGaz8PX1\nFdXV1ZZlH3/8seUO55FHHhHPPfecpVZeXm41/JmZmWLNmjV99vXT8O/YsUPMnDmz15jMzEzx0ksv\nia6uLuHp6Wm5wxJCiLVr1/YKvzV///vfhb+/f6/fX0Z8zu9E1/OQ8ttvv0VNTQ1GjBhhWWY2m22+\nel1VVYWjR49i27Zt19XbqFGjer1GoNVq0dDQYLn9497Pnj2Lzs5OTJw40bJM/OjzYA0NDZg2bVqv\nbVlTW1uLKVOm9KvHb7/9FseOHeu1T3p6epCRkYGmpiaYzeZefWq1WpSWltrcrtFoxEMPPeS010hu\nFAy/E6lUql63/fz8cPnyZcvts2fPWn4ODw9HVFQUvvrqq+uaY9u2bZgxY4Zi4Ppy4cIFXLlyBUOG\nDAEA1NTUQK/X99m7Wq2Gj48Pzpw5A7Vafc22NBoNamtrLbdramqszhseHo7Kyso+az/dX+Hh4UhI\nSMD+/fuvGdvd3Q0PDw/U1tYiMjLS5rw/aG9vx+7du1FUVGRz7M2OL/i5UFxcHIqKitDa2orGxka8\n9tprltpdd90FHx8fbNq0CR0dHTCbzfjyyy9x4sQJxW1u27YNGRkZ1yxPS0tDdna21fWuXr2K3Nxc\ndHV1oaSkBPv378eDDz7Y51hPT09kZ2dj5cqVOHfuHIQQqKurw8GDBwEAixYtwttvv43Tp0+jvb0d\n69evtzpvVlYW8vPz8eGHH+Lq1auoq6vDmTNnAHx/J/PNN99YxiYlJaG8vBzvvfceuru70d3djdLS\nUpw5cwbe3t544IEHsG7dOly5cgX/+te/sH37dsV9BQC7d+9GUFAQ7rnnHptjb3YMvwtlZGQgOjoa\nWq0W8+bNQ2pqqqXm5eWFffv2obS0FDqdDqNHj8Zjjz2GtrY2AEBJSUmvh78AcOTIETQ1NSElJeWa\nuWpra3s9FP+psLAw+Pn5QaPRID09Hfn5+YiKirI6ftOmTdBqtbjjjjswfPhwzJ07FxUVFQCA+fPn\nY/ny5ZgxYwZuu+02JCYmWt3O3XffjTfffBO/+c1vMHz4cMyaNcvyqGHlypUoKCjAiBEjsHr1agwf\nPhx/+9vfsGPHDmg0GgQHB+OZZ55BZ2cnAOAPf/gDWltboVarkZWVhUcffbTXXHPnzsXLL7/ca5nR\naMTSpUuveZQhI5UQ/DKPm01HRwd+/vOf48svv+zzvQWHDx9Gdna25R14JCc+578J3XLLLdf92gHJ\nhw/7iSTFh/1EkuKRn0hSLn3O76PyxS3wc+WURFLpQDu6RGe/xg4o/AcOHMCKFStgNpuRnZ2Np59+\nWnH8LfDDVFXCQKYkIgXHRXG/x9r9sN9sNmP58uXYv38/Tp06hYKCApw6dcrezRGRi9kd/tLSUkRG\nRiIiIgI+Pj5ITU1FYWGhI3sjIieyO/z19fW9PlQRFhaG+vr6a8YZDAbo9Xro9Xp0o3/PRYjI+ewO\nf19nCPt6y2ROTg5MJhNMJhO84WvvdETkYHaHPywsrNcnuerq6hASEuKQpojI+ewO/5QpU1BRUYGq\nqip0dXVh586dSEpKcmRvROREdp/q8/LywpYtW3DvvffCbDYjMzMT48ePd2RvRORELn17b4AqkOf5\niZzouChGm2jp11i+vZdIUgw/kaQYfiJJMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJ\nMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9I\nUgw/kaQYfiJJMfxEkmL4iSTF8BNJymsgK+t0Ovj7+8PT0xNeXl4wmUyO6ouInGxA4QeADz/8EKNH\nj3ZEL0TkQnzYTySpAYVfpVJh7ty5mDx5MgwGQ59jDAYD9Ho99Ho9utE5kOmIyIFUQghh78oNDQ0I\nCQlBc3MzEhMT8frrry
M+Pt7q+ABVIKaqEuydjohsOC6K0SZa+jV2QEf+kJAQAEBQUBCSk5NRWlo6\nkM0RkQvZHf729nZcvHjR8vPBgwcRExPjsMaIyLnsfrW/qakJycnJAICenh786le/wrx58xzWGBE5\nl93hj4iIwMmTJx3ZCxG5EE/1EUmK4SeSFMNPJCmGn0hSDD+RpAb8wR5ZXFh2l9XamEe+Vlz3dLNa\nsd7V6a1YDy1Qrg+tu2S1drXslOK6JC8e+YkkxfATSYrhJ5IUw08kKYafSFIMP5GkGH4iSfE8fz89\ntfY9q7UUv1bllf9jgJPPVC5X91y2Wtt8btYAJ79xlTZrrdb8Ng1XXNer+ISj2xl0eOQnkhTDTyQp\nhp9IUgw/kaQYfiJJMfxEkmL4iSQ1oCv2XK8b+Yo97Q9OtVo7H6t8HzryK+Vd3BqtUqz7xP6fYv3l\nmPet1hKHXFFct+jyMMX6L4Za/66AgboiuhTrxzv9FOszb+m2e+7IoscU67fl/MPubbuTy67YQ0Q3\nLoafSFIMP5GkGH4iSTH8RJJi+IkkxfATSYqf5+8nvz8fV6gNbNsBA1sdrwfPtFp7aZpOee6PlK85\n8PLMSDs66h+vK1cV637/bFSsj/p4t2J9go/16x0MrVa+FoIMbB75MzMzERQUhJiYGMuylpYWJCYm\nIioqComJiWhttfFlFkQ06NgMf0ZGBg4cONBrWV5eHhISElBRUYGEhATk5eU5rUEicg6b4Y+Pj0dg\nYGCvZYWFhUhPTwcApKenY+/evc7pjoicxq7n/E1NTdBoNAAAjUaD5uZmq2MNBgMMBgMAoBud9kxH\nRE7g9Ff7c3JyYDKZYDKZ4A1fZ09HRP1kV/jVajUaG79/JbaxsRFBQUEObYqInM+u8CclJcFoNAIA\njEYjFixY4NCmiMj5bD7nX7x4MUpKSnD+/HmEhYVh/fr1ePrpp7Fo0SK89dZbGDNmDHbt2uWKXsmK\nnrNNVmt+u63XAMBsY9t+f75gR0eO0ZR9l2J9vI/yn+9/tYyzWtO9843iuj2K1ZuDzfAXFBT0uby4\nuNjhzRCR6/DtvUSSYviJJMXwE0mK4SeSFMNPJCl+pJfcxksbrljf8uwWxbq3ylOxvmvzHKu1UY2f\nKq4rAx75iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJ8Tw/uc3pVaGK9Sm+ypcuL+9Svvx44KnL\n192TTHjkJ5IUw08kKYafSFIMP5GkGH4iSTH8RJJi+IkkxfP85FSdv5hitfb5g/9tY23lKzw9vmKF\nYn3IJ6U2ti83HvmJJMXwE0mK4SeSFMNPJCmGn0hSDD+RpBh+IknxPD85Vc191o8vw1TK5/EXVyUq\n1oceOKlYF4pVsnnkz8zMRFBQEGJiYizLcnNzERoairi4OMTFxWHfvn1ObZKIHM9m+DMyMnDgwIFr\nlq9atQplZWUoKyvD/fff75TmiMh5bIY/Pj4egYGBruiFiFzI7hf8tmzZgtjYWGRmZqK1tdXqOIPB\nAL1eD71ej2502jsdETmYXeF//PHHUVlZibKyMmg0Gjz55JNWx+bk5MBkMsFkMsHbxgc1iMh17Aq/\nWq2Gp6cnPDw8sGzZMpSW8tNTRDcau8Lf2Nho+XnPnj29zgQQ0Y3B5nn+xYsXo6SkBOfPn0dYWBjW\nr1+PkpISlJWVQaVSQafT4Y033nBFrzQIefj7K9Yfueeo1Vrb1Q7FdZs3RCjWfTv/oVgnZTbDX1BQ\ncM2yrKwspzRDRK7Dt/cSSYrhJ5IUw08kKYafSFIMP5Gk+JFeGpCK3PGK9b+O/h+rtQUVKYrr+u7j\nqTxn4pGfSFIMP5GkGH4iSTH8RJJi+IkkxfATSYrhJ5IUz/OTou/S7lSs//Ph1xTrlT3dVmuX/jNM\ncV1fNCrWaWB45CeSFMNPJCmGn0hSDD+RpBh+Ikkx/ESSYviJJMXz/JLzCg1RrK/83f8q1n1Vyn9C\nqScfsVq7dT8/r+9OPPITSYrhJ5IUw08kKYafSFIMP5GkGH4iSTH8RJKyeZ6/trYWS5cuxdmzZ+Hh\n4YGcnBysWLECLS0tePjhh1FdXQ2dToc//elPGDlypCt6puug8lL+L5741zrF+kPDLijW370YpFhX\n/8768eWq4prkbDaP/F5eXti0aRO++uorfPbZZ9i6dStOnTqFvLw8JCQkoKKiAgkJCcjLy3NFv0Tk\nIDbDr9FoMGnSJACAv78/oqOjUV9fj8LCQqSnpwMA0tPTsXfvXud2SkQOdV3P+aurq/HFF19g6tSp\naGpqgkajAfD9HURzc7NTGiQi5+j3e/svXbqElJQUvPrqqwgICOj3BAaDAQaDAQDQjc7r75CInKJf\nR/7u7m6kpKRgyZIlWLhwIQBArVajsfH7L1hsbGxEUFDfL/zk5OTAZDLBZDLBG74OapuIBspm+IUQ\nyMrKQnR0NFavXm1ZnpSUBKPRCAAwGo1YsGCB87okIodTCSGE0oCjR4/innvuwYQJE+Dh8f19xYYN\nGzB16lQsWrQINTU1GDNmDHbt2oXAwEDFyQJUgZiqSnBc92STarLyJbSLPtg+oO3f/cxyxfqIbZ8O\naPt0fY6LYrSJln6Ntfmcf/r06bB2/1BcXHx9nRHRoMF3+BFJiuEnkhTDTyQphp9IUgw/kaQYfiJJ\n8au7bwKet99mtZazs3BA2779beXz+Lrtnw1o++Q+PPITSYrhJ5IUw08kKYafSFIMP5GkGH4iSTH8\nRJLief6bwOknrH9l+vyhbQPadlhJl/IA5a+DoEGMR34iSTH8RJJi+IkkxfATSYrhJ5IUw08kKYaf\nSFI8z38D6Jh/h2K9eP4mhepQxzZDNw0e+YkkxfATSYrhJ5IUw08kKYafSFIMP5GkGH4iSdk8z19b\nW4ulS5fi7Nmz8PDwQE5ODlasWIHc3Fy8+eabuPXWWwEAGzZswP333+/0hmXUMM1TsT7Gy/5z+e9e\nDFKse7cpf56fn+a/cdkMv5eXFzZt2oRJkybh4sWLmDx5MhITEwEAq1atwpo1a5zeJBE5ns3wazQa\naDQaAIC/vz+io6NRX1/v9MaIyLmu6zl/dXU1vvjiC0ydOhUAsGXLFsTGxiIzMxOtra19rmMwGKDX\n66HX69GNzoF3TEQO0e/wX7p0CSkpKXj11VcREBCAxx9/HJWVlSgrK4NGo8GTTz7Z53o5OTkwmUww\nmUzwhq/DGieigelX+Lu7u5GSkoIlS5Zg4cKFAAC1Wg1PT094eHhg2bJlKC0tdWqjRORYNsMvhEBW\nVhaio6OxevVqy/LGxkbLz3v27EFMTIxzOiQip7D5gt+xY8ewfft2TJgwAXFxcQC+P61XUFCAsrIy\nqFQq6HQ6vPHGG05vlq7fxgu3K9Y/vVenWBeNXzqwGxpMbIZ/+vTpEH18NzvP6RPd2PgOPyJJMfxE\nkmL4iSTF8BNJiuEnkhTDTyQplejrPJ6TBKgCMVWV4KrpiKRzXBSjTbT0ayyP/ES
SYviJJMXwE0mK\n4SeSFMNPJCmGn0hSDD+RpFx6iW6fUR5o1VVZbp87d87y1d+DzWDtbbD2BbA3ezmyN5/q/h/PXfom\nn5/S6/UwmUzuml7RYO1tsPYFsDd7uas3PuwnkhTDTyQpz9zc3Fx3NjB58mR3Tq9osPY2WPsC2Ju9\n3NGbW5/zE5H78GE/kaQYfiJJuSX8Bw4cwLhx4xAZGYm8vDx3tGCVTqezXKNAr9e7tZfMzEwEBQX1\nuiBKS0sLEhMTERUVhcTERKvXSHRHb7m5uQgNDUVcXBzi4uKwb98+t/RWW1uLWbNmITo6GuPHj8fm\nzZsBuH/fWevLbftNuFhPT4+IiIgQlZWVorOzU8TGxory8nJXt2GVVqsV586dc3cbQgghPvroI3Hi\nxAkxfvx4y7K1a9eKjRs3CiGE2Lhxo3jqqacGTW/r1q0Tr7zyilv6+bGGhgZx4sQJIYQQbW1tIioq\nSpSXl7t931nry137zeVH/tLSUkRGRiIiIgI+Pj5ITU1FYWGhq9u4IcTHxyMwMLDXssLCQqSnpwMA\n0tPTsXfvXne01mdvg4VGo8GkSZMA9L6svLv3nbW+3MXl4a+vr0d4eLjldlhYmFt3wE+pVCrMnTsX\nkydPhsFgcHc712hqaoJGowHw/R9Tc3OzmzvqrT+XbXelH19WfjDtO3sud+9oLg+/6OPMokqlcnUb\nVh07dgyff/459u/fj61bt+Ljjz92d0s3jP5ett1VfnpZ+cHC3svdO5rLwx8WFoba2lrL7bq6OoSE\nhLi6Dat+6CUoKAjJycmD7tLjarXacoXkxsZGBAUFubmjfxtMl223dll5d++7wXS5e5eHf8qUKaio\nqEBVVRW6urqwc+dOJCUlubqNPrW3t+PixYuWnw8ePDjoLj2elJQEo9EIADAajViwYIGbO/q3wXLZ\ndmHlsvLu3nfW+nLbfnP5S4xCiKKiIhEVFSUiIiLESy+95I4W+lRZWSliY2NFbGysuP32293eW2pq\nqggODhZeXl4iNDRU5Ofni/Pnz4vZs2eLyMhIMXv2bHHhwoVB01taWpqIiYkREyZMEPPnzxcNDQ1u\n6e3IkSMCgJgwYYKYOHGimDhxoigqKnL7vrPWl7v2G9/eSyQpvsOPSFIMP5GkGH4iSTH8RJJi+Ikk\nxfATSYrhJ5LU/wOdAGX9nfSgHgAAAABJRU5ErkJggg==\n",
-            "text/plain": [
-              "\u003cFigure size 600x400 with 1 Axes\u003e"
-            ]
-          },
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "display_data"
-        }
-      ],
+      "outputs": [],
       "source": [
         "import matplotlib.pylab as plt\n",
         "\n",
@@ -482,7 +397,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 15,
+      "execution_count": null,
       "metadata": {
         "id": "3gwhv4lKbYZ4"
       },
@@ -500,24 +415,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 16,
+      "execution_count": null,
       "metadata": {
         "id": "CIH7G_MwbY2x"
       },
-      "outputs": [
-        {
-          "data": {
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFxZJREFUeJzt3XtU1HXeB/D3cE0RVDSG4eKMPJBL\nIrI6ZqXhBTFrVwwpw5WEAGnLc9ZL2nbbI1arPPV4nix99jRR7aiFz7qmtIu6KhulVrJj4baYHiKI\nq6DCE4pyG7/PH51mI5nf4DAX9Pt+neM5zO/z/f2+H37ynt/M/GbmpxJCCBCRdDzc3QARuQfDTyQp\nhp9IUgw/kaQYfiJJMfxEkmL4yeF6enqgUqlQXV0NAMjOzsaGDRucPm9+fj5mzpzp9HluFgy/nYYN\nG2b55+HhgSFDhlhuv/vuu06fPzs7u1cPvr6+GDlypNPntUd+fj6effZZm+OmT5+OP/7xj07p4Ztv\nvum1v4YNGwaVSoXNmzc7Zb4bgZe7G7hRXbp0yfKzTqdDfn4+5syZY3V8T08PvLwct7vz8/ORn59v\nuZ2WloahQ4c6bPs/Zjab4enp6ZRtu0pERESv/7Ovv/4a48aNw8KFC93YlXvxyO8kzz//PB5++GEs\nXrwY/v7+2LFjB9LS0pCbm2sZc/jwYeh0Osvturo6JCcn49Zbb8XYsWOxdevWfs118eJF7NmzB+np\n6f0a/8O8L7zwAkaNGoWxY8di586dlnpaWhqWL1+OefPmwc/PD0eOHEFHRwdWr16N8PBwqNVqPPHE\nE+jo6LCsk5eXh+DgYISGhsJoNPaa76e/9/vvv4+4uDgEBAQgMjISBw8exG9/+1t8+umn+PWvf41h\nw4Zh5cqVAIBTp05hzpw5CAwMxM9+9jPs3r3bsp1z587hl7/8JQICAnDnnXeiqqqqX78/ABiNRsye\nPRvh4eH9XuemI2jAtFqtOHToUK9lzz33nPD29hYffPCBMJvN4vLly2LJkiVi3bp1ljGHDh0SWq1W\nCCFET0+PmDhxovj9738vOjs7RUVFhdBqteLw4cNCCCFKSkrEqFGj+pz/rbfeEpGRkf3u99ChQ8LT\n01OsWbNGdHR0iOLiYjFkyBBRUVEhhBBiyZIlYsSIEeKTTz4RZrNZdHR0iOXLl4sHHnhAtLS0iO++\n+07cd9994vnnnxdCCPGXv/xFBAcHi/LycnHp0iXx0EMPCQCiqqrKsr0ffu9jx46J4cOHi8OHDwuz\n2SxqamrE6dOnhRBCTJs2TbzzzjuWPtva2kRISIgwGo2iu7tbmEwmERgYaBmfkpIiUlNTRXt7uzh5\n8qQIDg4WM2bMsKw/b9488corr1zz+1+9elVotVqxffv2fu+zmxHD7wDWwj9r1qxey5TCf/ToUTF2\n7Nhe41944QWRnZ1tc/74+Hjx4osv9rvfQ4cOCW9vb9He3m5ZlpycLDZs2GDp89FHH7XUzGaz8PX1\nFdXV1ZZlH3/8seUO55FHHhHPPfecpVZeXm41/JmZmWLNmjV99vXT8O/YsUPMnDmz15jMzEzx0ksv\nia6uLuHp6Wm5wxJCiLVr1/YKvzV///vfhb+/f6/fX0Z8zu9E1/OQ8ttvv0VNTQ1GjBhhWWY2m22+\nel1VVYWjR49i27Zt19XbqFGjer1GoNVq0dDQYLn9497Pnj2Lzs5OTJw40bJM/OjzYA0NDZg2bVqv\nbVlTW1uLKVOm9KvHb7/9FseOHeu1T3p6epCRkYGmpiaYzeZefWq1WpSWltrcrtFoxEMPPeS010hu\nFAy/E6lUql63/fz8cPnyZcvts2fPWn4ODw9HVFQUvvrqq+uaY9u2bZgxY4Zi4Ppy4cIFXLlyBUOG\nDAEA1NTUQK/X99m7Wq2Gj48Pzpw5A7Vafc22NBoNamtrLbdramqszhseHo7Kyso+az/dX+Hh4UhI\nSMD+/fuvGdvd3Q0PDw/U1tYiMjLS5rw/aG9vx+7du1FUVGRz7M2OL/i5UFxcHIqKitDa2orGxka8\n9tprltpdd90FHx8fbNq0CR0dHTCbzfjyyy9x4sQJxW1u27YNGRkZ1yxPS0tDdna21fWuXr2K3Nxc\ndHV1oaSkBPv378eDDz7Y51hPT09kZ2dj5cqVOHfuHIQQqKurw8GDBwEAixYtwttvv43Tp0+jvb0d\n69evtzpvVlYW8vPz8eGHH+Lq1auoq6vDmTNnAHx/J/PNN99YxiYlJaG8vBzvvfceuru70d3djdLS\nUpw5cwbe3t544IEHsG7dOly5cgX/+te/sH37dsV9BQC7d+9GUFAQ7rnnHptjb3YMvwtlZGQgOjoa\nWq0W8+bNQ2pqqqXm5eWFffv2obS0FDqdDqNHj8Zjjz2GtrY2AEBJSUmvh78AcOTIETQ1NSElJeWa\nuWpra3s9FP+psLAw+Pn5QaPRID09Hfn5+YiKirI6ftOmTdBqtbjjjjswfPhwzJ07FxUVFQCA+fPn\nY/ny5ZgxYwZuu+02JCYmWt3O3XffjTfffBO/+c1vMHz4cMyaNcvyqGHlypUoKCjAiBEjsHr1agwf\nPhx/+9vfsGPHDmg0GgQHB+OZZ55BZ2cnAOAPf/gDWltboVarkZWVhUcffbTXXHPnzsXLL7/ca5nR\naMTSpUuveZQhI5UQ/DKPm01HRwd+/vOf48svv+zzvQWHDx9Gdna25R14JCc+578J3XLLLdf92gHJ\nhw/7iSTFh/1EkuKRn0hSLn3O76PyxS3wc+WURFLpQDu6RGe/xg4o/AcOHMCKFStgNpuRnZ2Np59+\nWnH8LfDDVFXCQKYkIgXHRXG/x9r9sN9sNmP58uXYv38/Tp06hYKCApw6dcrezRGRi9kd/tLSUkRG\nRiIiIgI+Pj5ITU1FYWGhI3sjIieyO/z19fW9PlQRFhaG+vr6a8YZDAbo9Xro9Xp0o3/PRYjI+ewO\nf19nCPt6y2ROTg5MJhNMJhO84WvvdETkYHaHPywsrNcnuerq6hASEuKQpojI+ewO/5QpU1BRUYGq\nqip0dXVh586dSEpKcmRvROREdp/q8/LywpYtW3DvvffCbDYjMzMT48ePd2RvRORELn17b4AqkOf5\niZzouChGm2jp11i+vZdIUgw/kaQYfiJJMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJ\nMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9I\nUgw/kaQYfiJJMfxEkmL4iSTF8BNJymsgK+t0Ovj7+8PT0xNeXl4wmUyO6ouInGxA4QeADz/8EKNH\nj3ZEL0TkQnzYTySpAYVfpVJh7ty5mDx5MgwGQ59jDAYD9Ho99Ho9utE5kOmIyIFUQghh78oNDQ0I\nCQlBc3MzEhMT8frrry
M+Pt7q+ABVIKaqEuydjohsOC6K0SZa+jV2QEf+kJAQAEBQUBCSk5NRWlo6\nkM0RkQvZHf729nZcvHjR8vPBgwcRExPjsMaIyLnsfrW/qakJycnJAICenh786le/wrx58xzWGBE5\nl93hj4iIwMmTJx3ZCxG5EE/1EUmK4SeSFMNPJCmGn0hSDD+RpAb8wR5ZXFh2l9XamEe+Vlz3dLNa\nsd7V6a1YDy1Qrg+tu2S1drXslOK6JC8e+YkkxfATSYrhJ5IUw08kKYafSFIMP5GkGH4iSfE8fz89\ntfY9q7UUv1bllf9jgJPPVC5X91y2Wtt8btYAJ79xlTZrrdb8Ng1XXNer+ISj2xl0eOQnkhTDTyQp\nhp9IUgw/kaQYfiJJMfxEkmL4iSQ1oCv2XK8b+Yo97Q9OtVo7H6t8HzryK+Vd3BqtUqz7xP6fYv3l\nmPet1hKHXFFct+jyMMX6L4Za/66AgboiuhTrxzv9FOszb+m2e+7IoscU67fl/MPubbuTy67YQ0Q3\nLoafSFIMP5GkGH4iSTH8RJJi+IkkxfATSYqf5+8nvz8fV6gNbNsBA1sdrwfPtFp7aZpOee6PlK85\n8PLMSDs66h+vK1cV637/bFSsj/p4t2J9go/16x0MrVa+FoIMbB75MzMzERQUhJiYGMuylpYWJCYm\nIioqComJiWhttfFlFkQ06NgMf0ZGBg4cONBrWV5eHhISElBRUYGEhATk5eU5rUEicg6b4Y+Pj0dg\nYGCvZYWFhUhPTwcApKenY+/evc7pjoicxq7n/E1NTdBoNAAAjUaD5uZmq2MNBgMMBgMAoBud9kxH\nRE7g9Ff7c3JyYDKZYDKZ4A1fZ09HRP1kV/jVajUaG79/JbaxsRFBQUEObYqInM+u8CclJcFoNAIA\njEYjFixY4NCmiMj5bD7nX7x4MUpKSnD+/HmEhYVh/fr1ePrpp7Fo0SK89dZbGDNmDHbt2uWKXsmK\nnrNNVmt+u63XAMBsY9t+f75gR0eO0ZR9l2J9vI/yn+9/tYyzWtO9843iuj2K1ZuDzfAXFBT0uby4\nuNjhzRCR6/DtvUSSYviJJMXwE0mK4SeSFMNPJCl+pJfcxksbrljf8uwWxbq3ylOxvmvzHKu1UY2f\nKq4rAx75iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJ8Tw/uc3pVaGK9Sm+ypcuL+9Svvx44KnL\n192TTHjkJ5IUw08kKYafSFIMP5GkGH4iSTH8RJJi+IkkxfP85FSdv5hitfb5g/9tY23lKzw9vmKF\nYn3IJ6U2ti83HvmJJMXwE0mK4SeSFMNPJCmGn0hSDD+RpBh+IknxPD85Vc191o8vw1TK5/EXVyUq\n1oceOKlYF4pVsnnkz8zMRFBQEGJiYizLcnNzERoairi4OMTFxWHfvn1ObZKIHM9m+DMyMnDgwIFr\nlq9atQplZWUoKyvD/fff75TmiMh5bIY/Pj4egYGBruiFiFzI7hf8tmzZgtjYWGRmZqK1tdXqOIPB\nAL1eD71ej2502jsdETmYXeF//PHHUVlZibKyMmg0Gjz55JNWx+bk5MBkMsFkMsHbxgc1iMh17Aq/\nWq2Gp6cnPDw8sGzZMpSW8tNTRDcau8Lf2Nho+XnPnj29zgQQ0Y3B5nn+xYsXo6SkBOfPn0dYWBjW\nr1+PkpISlJWVQaVSQafT4Y033nBFrzQIefj7K9Yfueeo1Vrb1Q7FdZs3RCjWfTv/oVgnZTbDX1BQ\ncM2yrKwspzRDRK7Dt/cSSYrhJ5IUw08kKYafSFIMP5Gk+JFeGpCK3PGK9b+O/h+rtQUVKYrr+u7j\nqTxn4pGfSFIMP5GkGH4iSTH8RJJi+IkkxfATSYrhJ5IUz/OTou/S7lSs//Ph1xTrlT3dVmuX/jNM\ncV1fNCrWaWB45CeSFMNPJCmGn0hSDD+RpBh+Ikkx/ESSYviJJMXz/JLzCg1RrK/83f8q1n1Vyn9C\nqScfsVq7dT8/r+9OPPITSYrhJ5IUw08kKYafSFIMP5GkGH4iSTH8RJKyeZ6/trYWS5cuxdmzZ+Hh\n4YGcnBysWLECLS0tePjhh1FdXQ2dToc//elPGDlypCt6puug8lL+L5741zrF+kPDLijW370YpFhX\n/8768eWq4prkbDaP/F5eXti0aRO++uorfPbZZ9i6dStOnTqFvLw8JCQkoKKiAgkJCcjLy3NFv0Tk\nIDbDr9FoMGnSJACAv78/oqOjUV9fj8LCQqSnpwMA0tPTsXfvXud2SkQOdV3P+aurq/HFF19g6tSp\naGpqgkajAfD9HURzc7NTGiQi5+j3e/svXbqElJQUvPrqqwgICOj3BAaDAQaDAQDQjc7r75CInKJf\nR/7u7m6kpKRgyZIlWLhwIQBArVajsfH7L1hsbGxEUFDfL/zk5OTAZDLBZDLBG74OapuIBspm+IUQ\nyMrKQnR0NFavXm1ZnpSUBKPRCAAwGo1YsGCB87okIodTCSGE0oCjR4/innvuwYQJE+Dh8f19xYYN\nGzB16lQsWrQINTU1GDNmDHbt2oXAwEDFyQJUgZiqSnBc92STarLyJbSLPtg+oO3f/cxyxfqIbZ8O\naPt0fY6LYrSJln6Ntfmcf/r06bB2/1BcXHx9nRHRoMF3+BFJiuEnkhTDTyQphp9IUgw/kaQYfiJJ\n8au7bwKet99mtZazs3BA2779beXz+Lrtnw1o++Q+PPITSYrhJ5IUw08kKYafSFIMP5GkGH4iSTH8\nRJLief6bwOknrH9l+vyhbQPadlhJl/IA5a+DoEGMR34iSTH8RJJi+IkkxfATSYrhJ5IUw08kKYaf\nSFI8z38D6Jh/h2K9eP4mhepQxzZDNw0e+YkkxfATSYrhJ5IUw08kKYafSFIMP5GkGH4iSdk8z19b\nW4ulS5fi7Nmz8PDwQE5ODlasWIHc3Fy8+eabuPXWWwEAGzZswP333+/0hmXUMM1TsT7Gy/5z+e9e\nDFKse7cpf56fn+a/cdkMv5eXFzZt2oRJkybh4sWLmDx5MhITEwEAq1atwpo1a5zeJBE5ns3wazQa\naDQaAIC/vz+io6NRX1/v9MaIyLmu6zl/dXU1vvjiC0ydOhUAsGXLFsTGxiIzMxOtra19rmMwGKDX\n66HX69GNzoF3TEQO0e/wX7p0CSkpKXj11VcREBCAxx9/HJWVlSgrK4NGo8GTTz7Z53o5OTkwmUww\nmUzwhq/DGieigelX+Lu7u5GSkoIlS5Zg4cKFAAC1Wg1PT094eHhg2bJlKC0tdWqjRORYNsMvhEBW\nVhaio6OxevVqy/LGxkbLz3v27EFMTIxzOiQip7D5gt+xY8ewfft2TJgwAXFxcQC+P61XUFCAsrIy\nqFQq6HQ6vPHGG05vlq7fxgu3K9Y/vVenWBeNXzqwGxpMbIZ/+vTpEH18NzvP6RPd2PgOPyJJMfxE\nkmL4iSTF8BNJiuEnkhTDTyQplejrPJ6TBKgCMVWV4KrpiKRzXBSjTbT0ayyP/ES
SYviJJMXwE0mK\n4SeSFMNPJCmGn0hSDD+RpFx6iW6fUR5o1VVZbp87d87y1d+DzWDtbbD2BbA3ezmyN5/q/h/PXfom\nn5/S6/UwmUzuml7RYO1tsPYFsDd7uas3PuwnkhTDTyQpz9zc3Fx3NjB58mR3Tq9osPY2WPsC2Ju9\n3NGbW5/zE5H78GE/kaQYfiJJuSX8Bw4cwLhx4xAZGYm8vDx3tGCVTqezXKNAr9e7tZfMzEwEBQX1\nuiBKS0sLEhMTERUVhcTERKvXSHRHb7m5uQgNDUVcXBzi4uKwb98+t/RWW1uLWbNmITo6GuPHj8fm\nzZsBuH/fWevLbftNuFhPT4+IiIgQlZWVorOzU8TGxory8nJXt2GVVqsV586dc3cbQgghPvroI3Hi\nxAkxfvx4y7K1a9eKjRs3CiGE2Lhxo3jqqacGTW/r1q0Tr7zyilv6+bGGhgZx4sQJIYQQbW1tIioq\nSpSXl7t931nry137zeVH/tLSUkRGRiIiIgI+Pj5ITU1FYWGhq9u4IcTHxyMwMLDXssLCQqSnpwMA\n0tPTsXfvXne01mdvg4VGo8GkSZMA9L6svLv3nbW+3MXl4a+vr0d4eLjldlhYmFt3wE+pVCrMnTsX\nkydPhsFgcHc712hqaoJGowHw/R9Tc3OzmzvqrT+XbXelH19WfjDtO3sud+9oLg+/6OPMokqlcnUb\nVh07dgyff/459u/fj61bt+Ljjz92d0s3jP5ett1VfnpZ+cHC3svdO5rLwx8WFoba2lrL7bq6OoSE\nhLi6Dat+6CUoKAjJycmD7tLjarXacoXkxsZGBAUFubmjfxtMl223dll5d++7wXS5e5eHf8qUKaio\nqEBVVRW6urqwc+dOJCUlubqNPrW3t+PixYuWnw8ePDjoLj2elJQEo9EIADAajViwYIGbO/q3wXLZ\ndmHlsvLu3nfW+nLbfnP5S4xCiKKiIhEVFSUiIiLESy+95I4W+lRZWSliY2NFbGysuP32293eW2pq\nqggODhZeXl4iNDRU5Ofni/Pnz4vZs2eLyMhIMXv2bHHhwoVB01taWpqIiYkREyZMEPPnzxcNDQ1u\n6e3IkSMCgJgwYYKYOHGimDhxoigqKnL7vrPWl7v2G9/eSyQpvsOPSFIMP5GkGH4iSTH8RJJi+Ikk\nxfATSYrhJ5LU/wOdAGX9nfSgHgAAAABJRU5ErkJggg==\n",
-            "text/plain": [
-              "\u003cFigure size 600x400 with 1 Axes\u003e"
-            ]
-          },
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "display_data"
-        }
-      ],
+      "outputs": [],
       "source": [
         "plt.imshow(test_images[0])\n",
         "template = \"True:{true}, predicted:{predict}\"\n",
@@ -537,7 +439,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 17,
+      "execution_count": null,
       "metadata": {
         "id": "05aeAuWjvjPx"
       },
@@ -577,19 +479,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 18,
+      "execution_count": null,
       "metadata": {
         "id": "T5mWkSbMcU5z"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "0.956\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "print(evaluate_model(interpreter))"
       ]
@@ -605,19 +499,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 19,
+      "execution_count": null,
       "metadata": {
         "id": "-9cnwiPp6EGm"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "0.956\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "# NOTE: Colab runs on server CPUs. At the time of writing this, TensorFlow Lite\n",
         "# doesn't have super optimized server CPU kernels. For this reason this may be\n",
diff --git a/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
index 21c7bd9b..b761387 100644
--- a/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
+++ b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
@@ -46,20 +46,20 @@
         "id": "CIGrZZPTZVeO"
       },
       "source": [
-        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_integer_quant\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "\u003c/table\u003e"
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_integer_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+        "  </td>\n",
+        "</table>"
       ]
     },
     {
@@ -110,7 +110,7 @@
         "\n",
         "import tensorflow as tf\n",
         "import numpy as np\n",
-        "assert float(tf.__version__[:3]) \u003e= 2.3"
+        "assert float(tf.__version__[:3]) >= 2.3"
       ]
     },
     {
@@ -139,38 +139,7 @@
       "metadata": {
         "id": "eMsw_6HujaqM"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
-            "11493376/11490434 [==============================] - 0s 0us/step\n",
-            "Epoch 1/5\n",
-            "1875/1875 [==============================] - 5s 2ms/step - loss: 0.2793 - accuracy: 0.9227 - val_loss: 0.1392 - val_accuracy: 0.9618\n",
-            "Epoch 2/5\n",
-            "1875/1875 [==============================] - 5s 2ms/step - loss: 0.1179 - accuracy: 0.9667 - val_loss: 0.0928 - val_accuracy: 0.9719\n",
-            "Epoch 3/5\n",
-            "1875/1875 [==============================] - 4s 2ms/step - loss: 0.0860 - accuracy: 0.9754 - val_loss: 0.0742 - val_accuracy: 0.9755\n",
-            "Epoch 4/5\n",
-            "1875/1875 [==============================] - 4s 2ms/step - loss: 0.0691 - accuracy: 0.9796 - val_loss: 0.0686 - val_accuracy: 0.9776\n",
-            "Epoch 5/5\n",
-            "1875/1875 [==============================] - 4s 2ms/step - loss: 0.0589 - accuracy: 0.9823 - val_loss: 0.0654 - val_accuracy: 0.9787\n"
-          ]
-        },
-        {
-          "data": {
-            "text/plain": [
-              "\u003ctensorflow.python.keras.callbacks.History at 0x7f69e0275a58\u003e"
-            ]
-          },
-          "execution_count": null,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
         "mnist = tf.keras.datasets.mnist\n",
@@ -271,22 +240,7 @@
       "metadata": {
         "id": "HEZ6ET1AHAS3"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "INFO:tensorflow:Assets written to: /tmp/tmpcojyiqri/assets\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:tensorflow:Assets written to: /tmp/tmpcojyiqri/assets\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
         "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
@@ -328,22 +282,7 @@
       "metadata": {
         "id": "FiwiWU3gHdkW"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "INFO:tensorflow:Assets written to: /tmp/tmp1bvfr71i/assets\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:tensorflow:Assets written to: /tmp/tmp1bvfr71i/assets\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "def representative_data_gen():\n",
         "  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):\n",
@@ -374,16 +313,7 @@
       "metadata": {
         "id": "id1OEKFELQwp"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "input:  \u003cclass 'numpy.float32'\u003e\n",
-            "output:  \u003cclass 'numpy.float32'\u003e\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)\n",
         "input_type = interpreter.get_input_details()[0]['dtype']\n",
@@ -429,22 +359,7 @@
       "metadata": {
         "id": "kzjEjcDs3BHa"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "INFO:tensorflow:Assets written to: /tmp/tmpvnuxq9pa/assets\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:tensorflow:Assets written to: /tmp/tmpvnuxq9pa/assets\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "def representative_data_gen():\n",
         "  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):\n",
@@ -477,16 +392,7 @@
       "metadata": {
         "id": "PaNkOS-twz4k"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "input:  \u003cclass 'numpy.uint8'\u003e\n",
-            "output:  \u003cclass 'numpy.uint8'\u003e\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)\n",
         "input_type = interpreter.get_input_details()[0]['dtype']\n",
@@ -528,20 +434,7 @@
       "metadata": {
         "id": "BEY59dC14uRv"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "24720"
-            ]
-          },
-          "execution_count": null,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "import pathlib\n",
         "\n",
@@ -677,21 +570,7 @@
       "metadata": {
         "id": "iTK0x980coto"
       },
-      "outputs": [
-        {
-          "data": {
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEXCAYAAABrgzLrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAVZUlEQVR4nO3de9RVdZ3H8fcH5aKICsIQIEFeWN5mxGK8pGM2aBplWtNYTBk2GjVljrOYlWatpEmdVpNZM5VGaqJ5ibximomUYxqhaCgqlTcU6EE0YEArLo/f+WPvpw6Pz9nn4dwffp/XWmdxzv7ty5cDn7Ovv70VEZjZ9q9fqwsws+Zw2M0S4bCbJcJhN0uEw26WCIfdLBEOex8habykkLRjq2uplqSrJF3Qy3GXSTq20TWlxGFvM/l/8j9KeqXkNbrOywhJ+xS0n5aPc0m34Sflw6+qZz3WHA57ezoxInYpef2uBTU8A5zSbUtiGvDbFtRideCw91GSRkuaK2mNpKclfayk7VBJCyStk9Qh6ZuSBuRt9+WjPZpvNXygzCJWAUuA4/PphgFvBeZ2q+M9kp7Il3WvpP1L2g6R9IikDZJ+AAzqNu27JS3Op/2FpL+p8WuxAg5733UDsAIYDbwfuEjS3+dtncC/AcOBI4DJwCcBIuLofJyD862GHxQs42rgI/n7DwK3ARu7GiVNAK4HzgZGAHcCt0sakP+43ApcAwwDfgj8Q8m0hwBXAh8H9gC+A8yVNHCbvwnrFYe9Pd2ar+3WSbq1e6OkscCRwDkR8aeIWAxcTh7MiHg4In4ZEVsiYhlZkN5WRR23AMdI2i2f99Xd2j8A3BER8yJiM/BVYCeyLYDDgf7A1yNic0TcCDxUMu104DsRsTAiOiNiNtkPyeFV1Gm94LC3p5MjYvf8dXIP7aOBNRGxoWTY88AYyNa4kn4kaZWk9cBFZGv5bRIRfwTuAD4P7BERD/RQx/Ml478GLM/rGA2sjK17Wj1f8n4cMKPkR20dMDafzhrAYe+bfgcMkzSkZNgbgZX5+0uBXwP7RsSuwHmAqlzW1cAM4Ptl6hjX9UGSyAK7EugAxuTDSmvsshy4sORHbfeI2Dkirq+yTqvAYe+DImI58AvgPyUNyg9snc5fAjkEWA+8Imk/4F+6zeJFYK9eLu5/geOA/+mhbQ7wLkmTJfUn+1HYmNe2ANgCnCWpv6T3AYeWTPtd4BOSDlNmsKR3dfsBszpy2PuuqcB4srXrLcD5EXFP3vbvwD8BG8hC1f0g3Exgdr75fErRQiIzPyLW9ND2G+DDZD8ELwMnkp023BQRm4D3AacBa8j2728umXYR8DHgm8Ba4Ol8XGsQ+eYVZmnwmt0sEQ67WSIcdrNEOOxmiXDYrSFKu6hKOk/S5U1Y5jGSVjR6OX2Vw14jSW/s1h01JL1a8vnvGrjsd0m6Pz+FtkrS5b09T13SP76rzmWSzm1EnRFxUUSc0Yuaet3ffVtJGijpCknP5x1zFkt6ZyOW1a4c9hpFxAul3VHzwQeXDPt517gNuPHEbsAFZJeY7k92mep/beM8ds/rngp8QdIJ3UfoyzfMKLEj2VV7byP73j4PzJE0voU1NZXD3kD5TSAekHSJpN8DMyXNlPT9knG2ugONpN3yNVCHpJWSLpC0Q0/zj4jrIuKuiPhDRKwlu4DmyGpqjYgFwBPAQV2bw5LOkbQK+J6kfpLOlfSMpN9LmpN3e+36e5yarzV/L+lz3b6H7n/no/IureskLc+/p+nAh4DP5Fsat+fjjpZ0k6SXJD0n6ayS+eyUbw2slfQk8LcFf79XI2JmRCyLiNci4kfAc8Bbqvm++iKHvfEOA54FRgIX9mL8q8guM90HOAR4B3AG/HmXYZ2kN5aZ9miywG6T/HLVI4EDgV/lg99A1jV1HFkPtU8DJ5OtGUeTXfX2rXz6A8iuxz81b9sD2LPMssYBPya76m4EMBFYHBGzgGuBr+RbRCdK6gfcDjxKttUyGThb0vH57M4H9s5fx5PdXKN0Wd+W9O0ydYwEJlDF99VnRYRfdXwBAeyTvz8NeKFb+0zg+yWfx+fT7Ej2g7AR2KmkfSrws14s9ziyAE7oZZ1dy12XT7cUOCtvOwbYBAwqGX8pMLnk8yhgc173F4AbStoG59Mf2/3vDHwWuKVMTVcBF5R8PqyH7++zwPfy988CJ5S0TQdW9OLv3h+4h6yLbcv/zzTrtT3si7W75dsw7jiy/4gdJZ3F+lWah6TDgeuA90fEtt42anhEbOlh+EsR8adutd0i6bWSYZ1kP1CjS2uMiFfz3ZaejCW75VVvjANGK+v+2mUHoOs4yFbLZesutD3KtxauIfsxOrOXdWwXHPbG69754FVg55LPbyh5v5xszV4ugK+j7I4vc4F/joj5tRTaTfe6l+fL6N6nHUkdZAcIuz7vTLYp35PlbN37rdIyn4uIfcuM30H249G1KV5u96arLgFXkP1ATYnshhvJ8D578y0Gjs73v3cj2ywFICI6gLuBiyXtmh8U21tSj3eZkXQQcBfw6Yi4vYf2mZLurVPdlwEX5vvcSBoh6aS87Ubg3fmBtwHAf1D+/9a1wLGSTpG0o6Q9JE3M27p3vX0Q2JAfKNxJ0g6SDpLUdSBuDvBZSUMl7Ul2XKHIpWQ/SidGdmOOpDjsTRYR88i6nD4GPAz8qNsoHwEGAE+S7UvfSLZ/XHpOv2sNNoPsINcVJefLSw84jQVetyau0jfItiDulrQB+CXZPjUR8QTwKbJdiY687h4vbomIF4Apee1ryH78Ds6brwAOyA9C3hoRncC7yQ7iPUfWjfZyslNnAF8k23R/juxH8prSZUm6TNJl+ftxZPe7mwisKvm+PlTLl9KXuIvrdkzSYrKDauX2ny0hDrtZIrwZb5YIh90sEQ67WSKaep59gAbGIAY3c5FmSfkTr7IpNvZ42/Cawp73kPoG2VVNl0fEl4vGH8RgDtPkWhZpZgUWFlxXVfVmfN4T61vAO4EDgKl5hwgza0O17LMfCjwdEc9Gdo/wG4CTKkxjZi1SS9jHsHUnhBX5sK1Imi5pkaRFm//yAFAza7KGH42PiFkRMSkiJvXHT+M1a5Vawr6S7NrrLnvylwcLmlmbqSXsDwH7SnpT3tPpg2QdJcysDVV96i0itkg6E/gJ2am3K/PeT2bWhmo6zx4RdwJ31qkWM2sgXy5rlgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJ8COb+4BlFxxR2N45qPwjvEYc+FLhtAsOvqmqmrrs/dOPFrYPeXCnsm0j//sXNS3bto3X7GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZInyevQ2svWPfwvbHJ36zYcveXP4Ufa/8+u2XF7ZfO2lU2bY5895WOG3n0qeqqsl65jW7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIn2dvgkrn0R+YeEPDln3Zur0K27+24LjC9vHjivvD
333AzYXtHxrSUbbtwtOGF0671zk+z15PNYVd0jJgA9AJbImISfUoyszqrx5r9rdHxMt1mI+ZNZD32c0SUWvYA7hb0sOSpvc0gqTpkhZJWrSZjTUuzsyqVetm/FERsVLSXwHzJP06Iu4rHSEiZgGzAHbVsBq7XZhZtWpas0fEyvzP1cAtwKH1KMrM6q/qsEsaLGlI13vgHcDj9SrMzOqrls34kcAtkrrmc11E3FWXqvqYLZPfUtj+04O/VWEO/Qtbv752QmH7zz5QcMbzd6sLp52wdlFhe79BgwrbL1r414Xt5w1fUrZty9AthdNafVUd9oh4Fji4jrWYWQP51JtZIhx2s0Q47GaJcNjNEuGwmyXCXVzr4JUxAwrb+1X4Ta10au3e9xSf3up89jeF7bV4+ouHFLZfN+ziCnMYWLZlz7u8rmkmf9tmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSJ8nr0Odr96QWH7+xd9uLBda9cXtm/pWLaNFdXPGVPuKWzfpV/58+jWXrxmN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4fPsTdD55G9bXUJZyy48orD99N2/WmEOxbeantFxeNm2IfcsLZy2s8KSbdt4zW6WCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcLn2bdz604tPo/+wEeKz6Pv1q/4PPqCjTsUti++oPx953da/2DhtFZfFdfskq6UtFrS4yXDhkmaJ+mp/M+hjS3TzGrVm834q4ATug07F5gfEfsC8/PPZtbGKoY9Iu4D1nQbfBIwO38/Gzi5znWZWZ1Vu88+MiI68vergJHlRpQ0HZgOMIidq1ycmdWq5qPxERFAFLTPiohJETGpf8FD/syssaoN+4uSRgHkf66uX0lm1gjVhn0uMC1/Pw24rT7lmFmjVNxnl3Q9cAwwXNIK4Hzgy8AcSacDzwOnNLJIq97Lby67hwVUPo9eybR7zyhsn3Crz6W3i4phj4ipZZom17kWM2sgXy5rlgiH3SwRDrtZIhx2s0Q47GaJcBfX7cCmeePKti3Y7+IKUxefejt4wbTC9v1nPFPY7ttBtw+v2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRPg8ex+w417jC9u/tM8Py7YNrdCF9eGNxcse96XiM+Wda9cWz8DahtfsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kifJ69D9h7zsrC9kMGVP+bPXX+JwrbJzz6UNXztvbiNbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgifZ28Da6cdUdj+xZGV7v0+sGzLtGXHFk65/2eeLmz3fd+3HxXX7JKulLRa0uMlw2ZKWilpcf6a0tgyzaxWvdmMvwo4oYfhl0TExPx1Z33LMrN6qxj2iLgPWNOEWsysgWo5QHempMfyzfyh5UaSNF3SIkmLNlPhhmdm1jDVhv1SYG9gItABlD2CFBGzImJSREzqX3Agycwaq6qwR8SLEdEZEa8B3wUOrW9ZZlZvVYVd0qiSj+8FHi83rpm1h4rn2SVdDxwDDJe0AjgfOEbSRCCAZcDHG1hjn7fjmNGF7X931sLC9l36Vb/7s+DJfQrbJ6x1f/VUVAx7REztYfAVDajFzBrIl8uaJcJhN0uEw26WCIfdLBEOu1ki3MW1CZaeN7aw/dY33F7T/N++5B/LtrkLq3Xxmt0sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TPszfBw++5pMIYtd3BZ7dPvla2bcvatTXN27YfXrObJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwefbtwOaRu5Vt679pTBMreb3Ol14u2xYbix8HpoHF1x/sMGJ4VTUBdI7YvbD9qRkDqp53b0Snyrbt9+kK9yBYv76qZXrNbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslojePbB4LXA2MJHtE86yI+IakYcAPgPFkj20+JSLceboF7rjxylaXUNZbf9XTQ4AzL7+4a+G0Q0dsKGxf+Jbrqqqp3R3w+TML2/f6zIKq5tubNfsWYEZEHAAcDnxK0gHAucD8iNgXmJ9/NrM2VTHsEdEREY/k7zcAS4ExwEnA7Hy02cDJjSrSzGq3TfvsksYDhwALgZER0ZE3rSLbzDezNtXrsEvaBbgJODsitro4NyKCbH++p+mmS1okadFmiq+FNrPG6VXYJfUnC/q1EXFzPvhFSaPy9lHA6p6mjYhZETEpIib1r/HGimZWvYphlyTgCmBpRHytpGkuMC1/Pw24rf7lmVm9KNsCLxhBOgr4ObAE6Lpn8Xlk++1zgDcCz5OdeltTNK9dNSwO0+Raa+5z/viTNxW2zz/oxiZVkpY/xKaybZuj/O23e2PKY6cVtv/f4uq73466f0th+8AfP1S2bWHMZ32s6bH/bMXz7BFxP1Cu8216yTXro3wFnVkiHHazRDjsZolw2M0S4bCbJcJhN0uEbyXdBDsd/1xh+4EXFXdpjAb+Kw3Zr/DSiIZ2Iz3w5x8tbI8XBtc0/71ufKV844NLapr3UJ6qqb0VvGY3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRJRsT97PaXan92sWYr6s3vNbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslomLYJY2V9DNJT0p6QtK/5sNnSlopaXH+mtL4cs2sWr15/MAWYEZEPCJpCPCwpHl52yUR8dXGlWdm9VIx7BHRAXTk7zdIWgqMaXRhZlZf27TPLmk8cAiwMB90pqTHJF0paWiZaaZLWiRp0WY21lSsmVWv12GXtAtwE3B2RKwHLgX2BiaSrfkv7mm6iJgVEZMiYlJ/BtahZDOrRq/CLqk/WdCvjYibASLixYjojIjXgO8ChzauTDOrVW+Oxgu4AlgaEV8rGT6qZLT3Ao/Xvzwzq5feHI0/EjgVWCJpcT7sPGCqpIlAAMuAjzekQjOri94cjb8f6Ok+1HfWvxwzaxRfQWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0SoYho3sKkl4DnSwYNB15uWgHbpl1ra9e6wLVVq561jYuIET01NDXsr1u4tCgiJrWsgALtWlu71gWurVrNqs2b8WaJcNjNEtHqsM9q8fKLtGtt7VoXuLZqNaW2lu6zm1nztHrNbmZN4rCbJaIlYZd0gqTfSHpa0rmtqKEcScskLckfQ72oxbVcKWm1pMdLhg2TNE/SU/mfPT5jr0W1tcVjvAseM97S767Vjz9v+j67pB2A3wLHASuAh4CpEfFkUwspQ9IyYFJEtPwCDElHA68AV0fEQfmwrwBrIuLL+Q/l0Ig4p01qmwm80urHeOdPKxpV+phx4GTgNFr43RXUdQpN+N5asWY/FHg6Ip6NiE3ADcBJLaij7UXEfcCaboNPAmbn72eT/WdpujK1tYWI6IiIR/L3G4Cux4y39LsrqKspWhH2McDyks8raK/nvQd
wt6SHJU1vdTE9GBkRHfn7VcDIVhbTg4qP8W6mbo8Zb5vvrprHn9fKB+he76iIeDPwTuBT+eZqW4psH6ydzp326jHezdLDY8b/rJXfXbWPP69VK8K+Ehhb8nnPfFhbiIiV+Z+rgVtov0dRv9j1BN38z9UtrufP2ukx3j09Zpw2+O5a+fjzVoT9IWBfSW+SNAD4IDC3BXW8jqTB+YETJA0G3kH7PYp6LjAtfz8NuK2FtWylXR7jXe4x47T4u2v5488joukvYArZEflngM+1ooYyde0FPJq/nmh1bcD1ZJt1m8mObZwO7AHMB54C7gGGtVFt1wBLgMfIgjWqRbUdRbaJ/hiwOH9NafV3V1BXU743Xy5rlggfoDNLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEvH/9ALsS7Cy9ngAAAAASUVORK5CYII=\n",
-            "text/plain": [
-              "\u003cFigure size 432x288 with 1 Axes\u003e"
-            ]
-          },
-          "metadata": {
-            "needs_background": "light",
-            "tags": []
-          },
-          "output_type": "display_data"
-        }
-      ],
+      "outputs": [],
       "source": [
         "test_model(tflite_model_file, test_image_index, model_type=\"Float\")"
       ]
@@ -711,21 +590,7 @@
       "metadata": {
         "id": "rc1i9umMcp0t"
       },
-      "outputs": [
-        {
-          "data": {
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEXCAYAAABrgzLrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWRklEQVR4nO3de9RVdZ3H8fcHRVRQBHEQ0SBvldoSi9FKKxu1lKm0Vjk5jWLlYGuyci3XlGlTNKPWNJrZTQcvqeUl0kwtM5VyeYkx0UhQKm94oUfRwEQtBPzOH/v32PHhnH0O5w6/z2utZ3HO/u3L9zk8n7Ovv70VEZjZhm9Yrwsws+5w2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHPYMSHpO0o5tnufNko5p5zzbuUxJIWnnTte0PnHYO0DS0ZIWSHpB0hOSviNpdJeWvVYgImJURDzUjeWnGmamsH16yPBPp+Ezu1WL/Y3D3maSTgD+G/h3YDTwJmAycIOk4T0srdv+ABw1ZNj0NNx6wGFvI0lbAl8CPhkR10fEqohYDBwO7Aj8cxrvQkmnVEy3v6THK96fKOlBSSsk3SfpfRVtR0u6TdLpkpZLeljSIantVOCtwLfSpvu30vCQtLOk7dLwwZ8XJEXFvD8qaVGa788lTapoO0jS7yT9Oc1XdT6OO4HNJe2ept8d2DQNr/zM/lXSA5KWSbpG0naNLrOsXlubw95eb6H4g/5R5cCIeA64Dnhng/N5kCK0oym+PL4vaUJF+z7A74FxwFeB8yUpIk4GbgWOS5vuxw2p449p+KiIGAVcBVwOIOlQ4CTg/cA2aT6XpbZx6Xf6fFrmg8C+Dfwe3+Nva/fp6f3LJP0D8GWKL8MJwCMV9ZQus6xeq85hb69xwNMRsbpK2wDFH2VdEfHDFMyXIuIHwP3A3hWjPBIR50bEGuAiiqCMX5dCJX0WeC3w0TTo48CXI2JRqv80YEpaW04D7o2IKyJiFfB14IkGFvN94Ii0+/Kh9L7Sh4ELIuLuiFgJfA54s6TJDSyzrF6rwmFvr6eBcZI2rtI2IbXXJekoSfMlPSPpGWAPii+SQS//0UfEC+nlqEaLTJv9nwYOi4i/pMGTgLMqlrmMYrN5IrAd8FjFMqPyfS0R8SjwAEUQ74+IodNsR7E2Hxz/OeBPDS6zrF6rwmFvr7nASopNy5dJGgUcAtycBj0PbF4xyrYV404CzgWOA7aOiK2AhdTfRx5U2mdZ0msotgYOHxK+x4BjI2Krip/NIuJXFFslO1TMQ5Xv67gYOCH9O9QfKUI7ON+RwNbAkgaWWVavVeGwt1FE/JliH/ubkg6WNDxtks6mWKtfkkadD0yTNFbStsDxFbMZSRHYpwAkfYRizd6oJykOBq4lHUC8Gjg5Im4b0nwO8LmKA2qjJX0wtf0U2F3S+9NWy6eo+IKq4wcUxypmV2m7DPiIpCmSRlBsAdyRDmrWW2ZZvVaFw95mEfFVigNHpwMrgIcp1uIHRsTzabTvAb8FFgM3UARicPr7gDMothKeBF4P3L4OJZwFfCAdof7GkLY3AK8Bzqw8Kp+WexXFKcPLJT1LsTVxSGp7Gvgg8BWKzexdGq0pIv4SETdV7C5Utt0E/AdwJcWafCeKffu6yyyr16qT71TTWWnN/J/Avmkf1qwnHPYukHQksCoiLu91LZYvh90sE95nN8uEw24dIWmxpAPT65MkndeFZb7ismN7JYe9RZJeNeR685D0fMX7t3Zw2f+YrpN/RkXvuvMkbdHgtJNTrYN1LpZ0YifqjIjTIqJu11QN6TPQTpJGSDpf0iMq+hzMTxcXZcNhb1FEPDrkenOAPSuG3To4bo0r61oxGjiF4mqz11FcPfY/6ziPrVLdRwBfkHTw0BE6UHcvbExxIc7bKT63zwOz03UQWXDYO0hFD7XbJZ0p6U/ATBV9vb9fMc7gGnbj9H50WgMNSFoi6RRJG1Wbf0RcmnrXvRARyymuvGukg0q1ec0F7gX2GNwclvRZSU8A35U0TH/rjfcnSbMlja34PY5Ma80/STp5yOcw9HfeT9Kv0hbJY+lzmkFxrfxn0pbGtWnc7SRdKekpFT38PlUxn83S1sBySfcBf1/y+z0fETMjYnHqc/ATimsg3tjM57U+ctg7bx/gIYqOKqc2MP6FwGpgZ2AviqvPjoGXdxmekfSqGtO+jSKw60SFfYHdgd+kwdsCYykuZ50BfBI4jGLNuB2wHPh2mn434GzgyNS2NbB9jWVNAn4GfJOiY9AUYH5EzKK4wvCraYvoPZKGAddSXIA0ETgAOF7Su9LsvkhxIc5OwLsoetZVLus7kr5To47xwK408XmttyLCP238objUdef0+mjg0SHtM4HvV7yfnKbZmOILYSWwWUX7EcAvG1juQRQB3LXBOgeX+0yabhHwqdS2P/AisGnF+IuAAyreTwBWpbq/AFxe0TYyTX/g0N+ZomfbVTVquhA4peL9PlU+v88B302vHwIOrmibATzewO8+HLgJ+N9e/71082dD2Bfrd3V7h1WYRPGHOCC93O9lWL15SHoTcCnwgYhY1zvBjIvqXXKfioi/DqntKkkvVQxbQ/EFNbSH2vNpt6WaHSj6pjdiErCdil5tgzai6LvO0OVS0YOulrS18D2KL6Pj6oy+QXHYO2/oVUs1e7xR/OGupHYA1yJpL+Aa4KMRMaeVQocYWvdjaRlrXRMvaYDiAOHg+80pNuWreYxX9s2vt8yHI2KXGuMP9owb3BSvtXszWJeA8ym+oKZF0U8+G95n7775wNvS/vdois1SACJigKJjzBmStkwHxXaS9PZqM5K0B3A9xW2wrq3SPlPSzW2q+xzg1LTPjaRtVNwtBuAK4N3pwNsmFH0Bav1tXQIcKOlwSRtL2lrSlNQ2tMfer4EV6UDhZpI2krSHpMEDcbMper6NkbQ9xXGFMmdTfCm9J6p0zNnQOexdFhE3UvRyuwe4C/jJkFGOAjYB7qPYl76CYv+48pz+4BrsBIqDXOdXnC+vPOC0A+vWY67MWRRbEDdIWgH8H8U+NRFxL/AJil2JgVR31YtbougMNC3Vvoziy2/P1Hw+sFs6CPnjKO7E826Kg3gPU3QTPo/i1BkU3YkfSW03sPZtr86RdE56PQk4Ns3riYrP68OtfCjrE18bvwGTNJ/ioFqt/WfLiMNulglvxptlwmE3y4TDbpaJrp5n30QjYlNGdnORZln5K8/zYqyseifilsKeekidRXFV03kR8ZWy8TdlJPvogFYWaWYl7ii5rqrpzfjUE+vbFHf03I3iyR+7NTs/M+usVvbZ9wYeiIiHIuJFimd0HVpnGjPrkVbCPpFXdkJ4nCqP3pE0Q9I8SfNWsbKFxZlZKzp+ND4iZkXE1IiYOpwRnV6cmdXQStiX8Mpnb22fhplZH2ol7HcCu0h6derp9CGKjhJm1oeaPvUWEaslHQf8nOLU2wWp95OZ9aGWzrNHxHXAdW2qxcw6yJfLmmXCYTfLhMNulgmH3SwTDrtZ
Jhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJvzI5vXA4lPeXNq+ZtPaj/DaZvenSqedu+eVTdU0aKdffKS0fYtfb1azbfw3ftXSsm3deM1ulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XC59n7wPKf7lLavnDKtzq27FW1T9E35HfvOK+0/ZKpE2q2zb7x7aXTrll0f1M1WXVes5tlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmfB59i6odx799imXd2zZ5zyzY2n71+YeVNo+eVJ5f/gbdvtRafuHtxio2Xbq0eNKp93xsz7P3k4thV3SYmAFsAZYHRFT21GUmbVfO9bs74iIp9swHzPrIO+zm2Wi1bAHcIOkuyTNqDaCpBmS5kmat4qVLS7OzJrV6mb8fhGxRNLfATdK+l1E3FI5QkTMAmYBbKmxLXa7MLNmtbRmj4gl6d+lwFXA3u0oyszar+mwSxopaYvB18A7gYXtKszM2quVzfjxwFWSBudzaURc35aq1jOrD3hjafsv9vx2nTkML239+vJdS9t/+U8lZzz/uLR02l2XzyttH7bppqXtp93x+tL2k8YtqNm2eszq0mmtvZoOe0Q8BOzZxlrMrIN86s0sEw67WSYcdrNMOOxmmXDYzTLhLq5t8NzETUrbh9X5Tq13au3m95af3lrz0O9L21vxwJf2Km2/dOwZdeYwombL9td7XdNN/rTNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0z4PHsbbHXx3NL2D8z7l9J2LX+2tH31wOJ1rKh9jpl2U2n7qGG1z6Nbf/Ga3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhM+zd8Ga+/7Q6xJqWnzqm0vbP7bV6XXmUH6r6RMG3lSzbYubFpVOu6bOkm3deM1ulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XC59k3cM8cWX4e/fajys+jjx5Wfh597sqNStvnn1L7vvObPfvr0mmtvequ2SVdIGmppIUVw8ZKulHS/enfMZ0t08xa1chm/IXAwUOGnQjMiYhdgDnpvZn1sbphj4hbgGVDBh8KXJReXwQc1ua6zKzNmt1nHx8RA+n1E8D4WiNKmgHMANiUzZtcnJm1quWj8RERQJS0z4qIqRExdXjJQ/7MrLOaDfuTkiYApH+Xtq8kM+uEZsN+DTA9vZ4OXN2ecsysU+rus0u6DNgfGCfpceCLwFeA2ZI+BjwCHN7JIq15T7+h5h4WUP88ej3Tbz6mtH3XH/tcer+oG/aIOKJG0wFtrsXMOsiXy5plwmE3y4TDbpYJh90sEw67WSbcxXUD8OKNk2q2zX3tGXWmLj/1tufc6aXtrzvhwdJ23w66f3jNbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwufZ1wMb7zi5tP2/dv5hzbYxdbqw3rWyfNmT/qv8TPma5cvLZ2B9w2t2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTPs++Hthp9pLS9r02af47+4g5Hy9t3/W3dzY9b+svXrObZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZpnwefY+sHz6m0vbvzS+3r3fR9Rsmb74wNIpX/eZB0rbfd/3DUfdNbukCyQtlbSwYthMSUskzU8/0zpbppm1qpHN+AuBg6sMPzMipqSf69pblpm1W92wR8QtwLIu1GJmHdTKAbrjJN2TNvPH1BpJ0gxJ8yTNW0WdG56ZWcc0G/azgZ2AKcAAUPMIUkTMioipETF1eMmBJDPrrKbCHhFPRsSaiHgJOBfYu71lmVm7NRV2SRMq3r4PWFhrXDPrD3XPs0u6DNgfGCfpceCLwP6SpgABLAaO7WCN672NJ25X2v7WT91R2j5qWPO7P3Pv27m0fdfl7q+ei7phj4gjqgw+vwO1mFkH+XJZs0w47GaZcNjNMuGwm2XCYTfLhLu4dsGik3Yobf/xtte2NP93LPhgzTZ3YbVBXrObZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZpnwefYuuOu9Z9YZo7U7+Iz+t5dqtq1evryleduGw2t2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTPs++AVg1fnTNtuEvTuxiJWtb89TTNdtiZfnjwDSi/PqDjbYZ11RNAGu22aq0/f4TNml63o2INarZ9tpP1rkHwbPPNrVMr9nNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w08sjmHYCLgfEUj2ieFRFnSRoL/ACYTPHY5sMjwp2ne+CnV1zQ6xJqestvqj0EuPD0k1uWTjtmmxWl7Xe88dKmaup3u33+uNL2HT8zt6n5NrJmXw2cEBG7AW8CPiFpN+BEYE5E7ALMSe/NrE/VDXtEDETE3en1CmARMBE4FLgojXYRcFinijSz1q3TPrukycBewB3A+IgYSE1PUGzmm1mfajjskkYBVwLHR8QrLs6NiKDYn6823QxJ8yTNW0X5tdBm1jkNhV3ScIqgXxIRP0qDn5Q0IbVPAJZWmzYiZkXE1IiYOrzFGyuaWfPqhl2SgPOBRRHxtYqma4Dp6fV04Or2l2dm7aJiC7xkBGk/4FZgATB4z+KTKPbbZwOvAh6hOPW2rGxeW2ps7KMDWq15vfOXn7+6tH3OHld0qZK8vBAv1mxbFbVvv92IafccXdr+5/nNd7+dcNvq0vYRP7uzZtsdMYdnY1nV/rN1z7NHxG1Arc63+SXXbD3lK+jMMuGwm2XCYTfLhMNulgmH3SwTDrtZJnwr6S7Y7F0Pl7bvflp5l8bo4P/SFq8tvTSio91Id7/1I6Xt8ejIlua/4xXP1W789YKW5j2G+1tq7wWv2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTNTtz95OufZnN+uWsv7sXrObZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZpmoG3ZJO0j6paT7JN0r6dNp+ExJSyTNTz/TOl+umTWrkccPrAZOiIi7JW0B3CXpxtR2ZkSc3rnyzKxd6oY9IgaAgfR6haRFwMROF2Zm7bVO++ySJgN7AXekQcdJukfSBZLG1JhmhqR5kuatYmVLxZpZ8xoOu6RRwJXA8RHxLHA2sBMwhWLNf0a16SJiVkRMjYipwxnRhpLNrBkNhV3ScIqgXxIRPwKIiCcjYk1EvAScC+zduTLNrFWNHI0XcD6wKCK+VjF8QsVo7wMWtr88M2uXRo7G7wscCSyQND8NOwk4QtIUIIDFwLEdqdDM2qKRo/G3AdXuQ31d+8sxs07xFXRmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sE4qI7i1Megp4pGLQOODprhWwbvq1tn6tC1xbs9pZ26SI2KZaQ1fDvtbCpXkRMbVnBZTo19r6tS5wbc3qVm3ejDfLhMNuloleh31Wj5dfpl9r69e6wLU1qyu19XSf3cy6p9drdjPrEofdLBM
9CbukgyX9XtIDkk7sRQ21SFosaUF6DPW8HtdygaSlkhZWDBsr6UZJ96d/qz5jr0e19cVjvEseM97Tz67Xjz/v+j67pI2APwAHAY8DdwJHRMR9XS2kBkmLgakR0fMLMCS9DXgOuDgi9kjDvgosi4ivpC/KMRHx2T6pbSbwXK8f452eVjSh8jHjwGHA0fTwsyup63C68Ln1Ys2+N/BARDwUES8ClwOH9qCOvhcRtwDLhgw+FLgovb6I4o+l62rU1hciYiAi7k6vVwCDjxnv6WdXUldX9CLsE4HHKt4/Tn897z2AGyTdJWlGr4upYnxEDKTXTwDje1lMFXUf491NQx4z3jefXTOPP2+VD9Ctbb+IeANwCPCJtLnal6LYB+unc6cNPca7W6o8Zvxlvfzsmn38eat6EfYlwA4V77dPw/pCRCxJ/y4FrqL/HkX95OATdNO/S3tcz8v66THe1R4zTh98dr18/Hkvwn4nsIukV0vaBPgQcE0P6liLpJHpwAmSRgLvpP8eRX0NMD29ng5c3cNaXqFfHuNd6zHj9Piz6/njzyOi6z/ANIoj8g8CJ/eihhp17Qj8Nv3c2+vagMsoNutWURzb+BiwNTAHuB+4CRjbR7V9D1gA3EMRrAk9qm0/ik30e4D56Wdarz+7krq68rn5clmzTPgAnVkmHHazTDjsZplw2M0y4bCbZcJhN8uEw26Wif8HteKJB66NhMUAAAAASUVORK5CYII=\n",
-            "text/plain": [
-              "\u003cFigure size 432x288 with 1 Axes\u003e"
-            ]
-          },
-          "metadata": {
-            "needs_background": "light",
-            "tags": []
-          },
-          "output_type": "display_data"
-        }
-      ],
+      "outputs": [],
       "source": [
         "test_model(tflite_model_quant_file, test_image_index, model_type=\"Quantized\")"
       ]
@@ -785,15 +650,7 @@
       "metadata": {
         "id": "T5mWkSbMcU5z"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Float model accuracy is 97.8700% (Number of test samples=10000)\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "evaluate_model(tflite_model_file, model_type=\"Float\")"
       ]
@@ -813,15 +670,7 @@
       "metadata": {
         "id": "-9cnwiPp6EGm"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Quantized model accuracy is 97.8100% (Number of test samples=10000)\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "evaluate_model(tflite_model_quant_file, model_type=\"Quantized\")"
       ]
diff --git a/tensorflow/lite/g3doc/performance/post_training_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
index 311c818..0e58e05 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
+++ b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
@@ -11,7 +11,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 1,
+      "execution_count": null,
       "metadata": {
         "cellView": "form",
         "id": "R3yYtBPkM2qZ"
@@ -46,20 +46,20 @@
         "id": "CIGrZZPTZVeO"
       },
       "source": [
-        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_quant\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "\u003c/table\u003e"
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+        "  </td>\n",
+        "</table>"
       ]
     },
     {
@@ -118,7 +118,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 2,
+      "execution_count": null,
       "metadata": {
         "id": "gyqAw1M9lyab"
       },
@@ -144,31 +144,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 3,
+      "execution_count": null,
       "metadata": {
         "id": "hWSAjQWagIHl"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "1875/1875 [==============================] - 10s 5ms/step - loss: 0.2787 - accuracy: 0.9203 - val_loss: 0.1323 - val_accuracy: 0.9624\n"
-          ]
-        },
-        {
-          "data": {
-            "text/plain": [
-              "\u003ctensorflow.python.keras.callbacks.History at 0x7f6443480e80\u003e"
-            ]
-          },
-          "execution_count": 3,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
         "mnist = keras.datasets.mnist\n",
@@ -224,7 +204,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
+      "execution_count": null,
       "metadata": {
         "id": "_i8B2nDZmAgQ"
       },
@@ -245,7 +225,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 5,
+      "execution_count": null,
       "metadata": {
         "id": "vptWZq2xnclo"
       },
@@ -257,24 +237,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 6,
+      "execution_count": null,
       "metadata": {
         "id": "Ie9pQaQrn5ue"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "84452"
-            ]
-          },
-          "execution_count": 6,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
         "tflite_model_file.write_bytes(tflite_model)"
@@ -291,24 +258,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": null,
       "metadata": {
         "id": "g8PUvLWDlmmz"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "23840"
-            ]
-          },
-          "execution_count": 7,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
         "tflite_quant_model = converter.convert()\n",
@@ -327,24 +281,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": null,
       "metadata": {
         "id": "JExfcfLDscu4"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "total 214M\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  44K Jun 23 06:04 mnist_model_quant_f16.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  24K Jun 23 06:12 mnist_model_quant.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  83K Jun 23 06:12 mnist_model.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  44M Jun 23 06:10 resnet_v2_101_quantized.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828 171M Jun 23 06:09 resnet_v2_101.tflite\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "!ls -lh {tflite_models_dir}"
       ]
@@ -372,7 +313,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": null,
       "metadata": {
         "id": "Jn16Rc23zTss"
       },
@@ -384,7 +325,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 10,
+      "execution_count": null,
       "metadata": {
         "id": "J8Pztk1mvNVL"
       },
@@ -405,7 +346,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": null,
       "metadata": {
         "id": "AKslvo2kwWac"
       },
@@ -423,24 +364,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": null,
       "metadata": {
         "id": "XZClM2vo3_bm"
       },
-      "outputs": [
-        {
-          "data": {
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFxZJREFUeJzt3XtU1HXeB/D3cE0RVDSG4eKMPJBL\nIrI6ZqXhBTFrVwwpw5WEAGnLc9ZL2nbbI1arPPV4nix99jRR7aiFz7qmtIu6KhulVrJj4baYHiKI\nq6DCE4pyG7/PH51mI5nf4DAX9Pt+neM5zO/z/f2+H37ynt/M/GbmpxJCCBCRdDzc3QARuQfDTyQp\nhp9IUgw/kaQYfiJJMfxEkmL4yeF6enqgUqlQXV0NAMjOzsaGDRucPm9+fj5mzpzp9HluFgy/nYYN\nG2b55+HhgSFDhlhuv/vuu06fPzs7u1cPvr6+GDlypNPntUd+fj6effZZm+OmT5+OP/7xj07p4Ztv\nvum1v4YNGwaVSoXNmzc7Zb4bgZe7G7hRXbp0yfKzTqdDfn4+5syZY3V8T08PvLwct7vz8/ORn59v\nuZ2WloahQ4c6bPs/Zjab4enp6ZRtu0pERESv/7Ovv/4a48aNw8KFC93YlXvxyO8kzz//PB5++GEs\nXrwY/v7+2LFjB9LS0pCbm2sZc/jwYeh0Osvturo6JCcn49Zbb8XYsWOxdevWfs118eJF7NmzB+np\n6f0a/8O8L7zwAkaNGoWxY8di586dlnpaWhqWL1+OefPmwc/PD0eOHEFHRwdWr16N8PBwqNVqPPHE\nE+jo6LCsk5eXh+DgYISGhsJoNPaa76e/9/vvv4+4uDgEBAQgMjISBw8exG9/+1t8+umn+PWvf41h\nw4Zh5cqVAIBTp05hzpw5CAwMxM9+9jPs3r3bsp1z587hl7/8JQICAnDnnXeiqqqqX78/ABiNRsye\nPRvh4eH9XuemI2jAtFqtOHToUK9lzz33nPD29hYffPCBMJvN4vLly2LJkiVi3bp1ljGHDh0SWq1W\nCCFET0+PmDhxovj9738vOjs7RUVFhdBqteLw4cNCCCFKSkrEqFGj+pz/rbfeEpGRkf3u99ChQ8LT\n01OsWbNGdHR0iOLiYjFkyBBRUVEhhBBiyZIlYsSIEeKTTz4RZrNZdHR0iOXLl4sHHnhAtLS0iO++\n+07cd9994vnnnxdCCPGXv/xFBAcHi/LycnHp0iXx0EMPCQCiqqrKsr0ffu9jx46J4cOHi8OHDwuz\n2SxqamrE6dOnhRBCTJs2TbzzzjuWPtva2kRISIgwGo2iu7tbmEwmERgYaBmfkpIiUlNTRXt7uzh5\n8qQIDg4WM2bMsKw/b9488corr1zz+1+9elVotVqxffv2fu+zmxHD7wDWwj9r1qxey5TCf/ToUTF2\n7Nhe41944QWRnZ1tc/74+Hjx4osv9rvfQ4cOCW9vb9He3m5ZlpycLDZs2GDp89FHH7XUzGaz8PX1\nFdXV1ZZlH3/8seUO55FHHhHPPfecpVZeXm41/JmZmWLNmjV99vXT8O/YsUPMnDmz15jMzEzx0ksv\nia6uLuHp6Wm5wxJCiLVr1/YKvzV///vfhb+/f6/fX0Z8zu9E1/OQ8ttvv0VNTQ1GjBhhWWY2m22+\nel1VVYWjR49i27Zt19XbqFGjer1GoNVq0dDQYLn9497Pnj2Lzs5OTJw40bJM/OjzYA0NDZg2bVqv\nbVlTW1uLKVOm9KvHb7/9FseOHeu1T3p6epCRkYGmpiaYzeZefWq1WpSWltrcrtFoxEMPPeS010hu\nFAy/E6lUql63/fz8cPnyZcvts2fPWn4ODw9HVFQUvvrqq+uaY9u2bZgxY4Zi4Ppy4cIFXLlyBUOG\nDAEA1NTUQK/X99m7Wq2Gj48Pzpw5A7Vafc22NBoNamtrLbdramqszhseHo7Kyso+az/dX+Hh4UhI\nSMD+/fuvGdvd3Q0PDw/U1tYiMjLS5rw/aG9vx+7du1FUVGRz7M2OL/i5UFxcHIqKitDa2orGxka8\n9tprltpdd90FHx8fbNq0CR0dHTCbzfjyyy9x4sQJxW1u27YNGRkZ1yxPS0tDdna21fWuXr2K3Nxc\ndHV1oaSkBPv378eDDz7Y51hPT09kZ2dj5cqVOHfuHIQQqKurw8GDBwEAixYtwttvv43Tp0+jvb0d\n69evtzpvVlYW8vPz8eGHH+Lq1auoq6vDmTNnAHx/J/PNN99YxiYlJaG8vBzvvfceuru70d3djdLS\nUpw5cwbe3t544IEHsG7dOly5cgX/+te/sH37dsV9BQC7d+9GUFAQ7rnnHptjb3YMvwtlZGQgOjoa\nWq0W8+bNQ2pqqqXm5eWFffv2obS0FDqdDqNHj8Zjjz2GtrY2AEBJSUmvh78AcOTIETQ1NSElJeWa\nuWpra3s9FP+psLAw+Pn5QaPRID09Hfn5+YiKirI6ftOmTdBqtbjjjjswfPhwzJ07FxUVFQCA+fPn\nY/ny5ZgxYwZuu+02JCYmWt3O3XffjTfffBO/+c1vMHz4cMyaNcvyqGHlypUoKCjAiBEjsHr1agwf\nPhx/+9vfsGPHDmg0GgQHB+OZZ55BZ2cnAOAPf/gDWltboVarkZWVhUcffbTXXHPnzsXLL7/ca5nR\naMTSpUuveZQhI5UQ/DKPm01HRwd+/vOf48svv+zzvQWHDx9Gdna25R14JCc+578J3XLLLdf92gHJ\nhw/7iSTFh/1EkuKRn0hSLn3O76PyxS3wc+WURFLpQDu6RGe/xg4o/AcOHMCKFStgNpuRnZ2Np59+\nWnH8LfDDVFXCQKYkIgXHRXG/x9r9sN9sNmP58uXYv38/Tp06hYKCApw6dcrezRGRi9kd/tLSUkRG\nRiIiIgI+Pj5ITU1FYWGhI3sjIieyO/z19fW9PlQRFhaG+vr6a8YZDAbo9Xro9Xp0o3/PRYjI+ewO\nf19nCPt6y2ROTg5MJhNMJhO84WvvdETkYHaHPywsrNcnuerq6hASEuKQpojI+ewO/5QpU1BRUYGq\nqip0dXVh586dSEpKcmRvROREdp/q8/LywpYtW3DvvffCbDYjMzMT48ePd2RvRORELn17b4AqkOf5\niZzouChGm2jp11i+vZdIUgw/kaQYfiJJMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJ\nMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJMfxEkmL4iSTF8BNJiuEnkhTDTyQphp9I\nUgw/kaQYfiJJMfxEkmL4iSTF8BNJymsgK+t0Ovj7+8PT0xNeXl4wmUyO6ouInGxA4QeADz/8EKNH\nj3ZEL0TkQnzYTySpAYVfpVJh7ty5mDx5MgwGQ59jDAYD9Ho99Ho9utE5kOmIyIFUQghh78oNDQ0I\nCQlBc3MzEhMT8frrry
M+Pt7q+ABVIKaqEuydjohsOC6K0SZa+jV2QEf+kJAQAEBQUBCSk5NRWlo6\nkM0RkQvZHf729nZcvHjR8vPBgwcRExPjsMaIyLnsfrW/qakJycnJAICenh786le/wrx58xzWGBE5\nl93hj4iIwMmTJx3ZCxG5EE/1EUmK4SeSFMNPJCmGn0hSDD+RpAb8wR5ZXFh2l9XamEe+Vlz3dLNa\nsd7V6a1YDy1Qrg+tu2S1drXslOK6JC8e+YkkxfATSYrhJ5IUw08kKYafSFIMP5GkGH4iSfE8fz89\ntfY9q7UUv1bllf9jgJPPVC5X91y2Wtt8btYAJ79xlTZrrdb8Ng1XXNer+ISj2xl0eOQnkhTDTyQp\nhp9IUgw/kaQYfiJJMfxEkmL4iSQ1oCv2XK8b+Yo97Q9OtVo7H6t8HzryK+Vd3BqtUqz7xP6fYv3l\nmPet1hKHXFFct+jyMMX6L4Za/66AgboiuhTrxzv9FOszb+m2e+7IoscU67fl/MPubbuTy67YQ0Q3\nLoafSFIMP5GkGH4iSTH8RJJi+IkkxfATSYqf5+8nvz8fV6gNbNsBA1sdrwfPtFp7aZpOee6PlK85\n8PLMSDs66h+vK1cV637/bFSsj/p4t2J9go/16x0MrVa+FoIMbB75MzMzERQUhJiYGMuylpYWJCYm\nIioqComJiWhttfFlFkQ06NgMf0ZGBg4cONBrWV5eHhISElBRUYGEhATk5eU5rUEicg6b4Y+Pj0dg\nYGCvZYWFhUhPTwcApKenY+/evc7pjoicxq7n/E1NTdBoNAAAjUaD5uZmq2MNBgMMBgMAoBud9kxH\nRE7g9Ff7c3JyYDKZYDKZ4A1fZ09HRP1kV/jVajUaG79/JbaxsRFBQUEObYqInM+u8CclJcFoNAIA\njEYjFixY4NCmiMj5bD7nX7x4MUpKSnD+/HmEhYVh/fr1ePrpp7Fo0SK89dZbGDNmDHbt2uWKXsmK\nnrNNVmt+u63XAMBsY9t+f75gR0eO0ZR9l2J9vI/yn+9/tYyzWtO9843iuj2K1ZuDzfAXFBT0uby4\nuNjhzRCR6/DtvUSSYviJJMXwE0mK4SeSFMNPJCl+pJfcxksbrljf8uwWxbq3ylOxvmvzHKu1UY2f\nKq4rAx75iSTF8BNJiuEnkhTDTyQphp9IUgw/kaQYfiJJ8Tw/uc3pVaGK9Sm+ypcuL+9Svvx44KnL\n192TTHjkJ5IUw08kKYafSFIMP5GkGH4iSTH8RJJi+IkkxfP85FSdv5hitfb5g/9tY23lKzw9vmKF\nYn3IJ6U2ti83HvmJJMXwE0mK4SeSFMNPJCmGn0hSDD+RpBh+IknxPD85Vc191o8vw1TK5/EXVyUq\n1oceOKlYF4pVsnnkz8zMRFBQEGJiYizLcnNzERoairi4OMTFxWHfvn1ObZKIHM9m+DMyMnDgwIFr\nlq9atQplZWUoKyvD/fff75TmiMh5bIY/Pj4egYGBruiFiFzI7hf8tmzZgtjYWGRmZqK1tdXqOIPB\nAL1eD71ej2502jsdETmYXeF//PHHUVlZibKyMmg0Gjz55JNWx+bk5MBkMsFkMsHbxgc1iMh17Aq/\nWq2Gp6cnPDw8sGzZMpSW8tNTRDcau8Lf2Nho+XnPnj29zgQQ0Y3B5nn+xYsXo6SkBOfPn0dYWBjW\nr1+PkpISlJWVQaVSQafT4Y033nBFrzQIefj7K9Yfueeo1Vrb1Q7FdZs3RCjWfTv/oVgnZTbDX1BQ\ncM2yrKwspzRDRK7Dt/cSSYrhJ5IUw08kKYafSFIMP5Gk+JFeGpCK3PGK9b+O/h+rtQUVKYrr+u7j\nqTxn4pGfSFIMP5GkGH4iSTH8RJJi+IkkxfATSYrhJ5IUz/OTou/S7lSs//Ph1xTrlT3dVmuX/jNM\ncV1fNCrWaWB45CeSFMNPJCmGn0hSDD+RpBh+Ikkx/ESSYviJJMXz/JLzCg1RrK/83f8q1n1Vyn9C\nqScfsVq7dT8/r+9OPPITSYrhJ5IUw08kKYafSFIMP5GkGH4iSTH8RJKyeZ6/trYWS5cuxdmzZ+Hh\n4YGcnBysWLECLS0tePjhh1FdXQ2dToc//elPGDlypCt6puug8lL+L5741zrF+kPDLijW370YpFhX\n/8768eWq4prkbDaP/F5eXti0aRO++uorfPbZZ9i6dStOnTqFvLw8JCQkoKKiAgkJCcjLy3NFv0Tk\nIDbDr9FoMGnSJACAv78/oqOjUV9fj8LCQqSnpwMA0tPTsXfvXud2SkQOdV3P+aurq/HFF19g6tSp\naGpqgkajAfD9HURzc7NTGiQi5+j3e/svXbqElJQUvPrqqwgICOj3BAaDAQaDAQDQjc7r75CInKJf\nR/7u7m6kpKRgyZIlWLhwIQBArVajsfH7L1hsbGxEUFDfL/zk5OTAZDLBZDLBG74OapuIBspm+IUQ\nyMrKQnR0NFavXm1ZnpSUBKPRCAAwGo1YsGCB87okIodTCSGE0oCjR4/innvuwYQJE+Dh8f19xYYN\nGzB16lQsWrQINTU1GDNmDHbt2oXAwEDFyQJUgZiqSnBc92STarLyJbSLPtg+oO3f/cxyxfqIbZ8O\naPt0fY6LYrSJln6Ntfmcf/r06bB2/1BcXHx9nRHRoMF3+BFJiuEnkhTDTyQphp9IUgw/kaQYfiJJ\n8au7bwKet99mtZazs3BA2779beXz+Lrtnw1o++Q+PPITSYrhJ5IUw08kKYafSFIMP5GkGH4iSTH8\nRJLief6bwOknrH9l+vyhbQPadlhJl/IA5a+DoEGMR34iSTH8RJJi+IkkxfATSYrhJ5IUw08kKYaf\nSFI8z38D6Jh/h2K9eP4mhepQxzZDNw0e+YkkxfATSYrhJ5IUw08kKYafSFIMP5GkGH4iSdk8z19b\nW4ulS5fi7Nmz8PDwQE5ODlasWIHc3Fy8+eabuPXWWwEAGzZswP333+/0hmXUMM1TsT7Gy/5z+e9e\nDFKse7cpf56fn+a/cdkMv5eXFzZt2oRJkybh4sWLmDx5MhITEwEAq1atwpo1a5zeJBE5ns3wazQa\naDQaAIC/vz+io6NRX1/v9MaIyLmu6zl/dXU1vvjiC0ydOhUAsGXLFsTGxiIzMxOtra19rmMwGKDX\n66HX69GNzoF3TEQO0e/wX7p0CSkpKXj11VcREBCAxx9/HJWVlSgrK4NGo8GTTz7Z53o5OTkwmUww\nmUzwhq/DGieigelX+Lu7u5GSkoIlS5Zg4cKFAAC1Wg1PT094eHhg2bJlKC0tdWqjRORYNsMvhEBW\nVhaio6OxevVqy/LGxkbLz3v27EFMTIxzOiQip7D5gt+xY8ewfft2TJgwAXFxcQC+P61XUFCAsrIy\nqFQq6HQ6vPHGG05vlq7fxgu3K9Y/vVenWBeNXzqwGxpMbIZ/+vTpEH18NzvP6RPd2PgOPyJJMfxE\nkmL4iSTF8BNJiuEnkhTDTyQplejrPJ6TBKgCMVWV4KrpiKRzXBSjTbT0ayyP/ES
SYviJJMXwE0mK\n4SeSFMNPJCmGn0hSDD+RpFx6iW6fUR5o1VVZbp87d87y1d+DzWDtbbD2BbA3ezmyN5/q/h/PXfom\nn5/S6/UwmUzuml7RYO1tsPYFsDd7uas3PuwnkhTDTyQpz9zc3Fx3NjB58mR3Tq9osPY2WPsC2Ju9\n3NGbW5/zE5H78GE/kaQYfiJJuSX8Bw4cwLhx4xAZGYm8vDx3tGCVTqezXKNAr9e7tZfMzEwEBQX1\nuiBKS0sLEhMTERUVhcTERKvXSHRHb7m5uQgNDUVcXBzi4uKwb98+t/RWW1uLWbNmITo6GuPHj8fm\nzZsBuH/fWevLbftNuFhPT4+IiIgQlZWVorOzU8TGxory8nJXt2GVVqsV586dc3cbQgghPvroI3Hi\nxAkxfvx4y7K1a9eKjRs3CiGE2Lhxo3jqqacGTW/r1q0Tr7zyilv6+bGGhgZx4sQJIYQQbW1tIioq\nSpSXl7t931nry137zeVH/tLSUkRGRiIiIgI+Pj5ITU1FYWGhq9u4IcTHxyMwMLDXssLCQqSnpwMA\n0tPTsXfvXne01mdvg4VGo8GkSZMA9L6svLv3nbW+3MXl4a+vr0d4eLjldlhYmFt3wE+pVCrMnTsX\nkydPhsFgcHc712hqaoJGowHw/R9Tc3OzmzvqrT+XbXelH19WfjDtO3sud+9oLg+/6OPMokqlcnUb\nVh07dgyff/459u/fj61bt+Ljjz92d0s3jP5ett1VfnpZ+cHC3svdO5rLwx8WFoba2lrL7bq6OoSE\nhLi6Dat+6CUoKAjJycmD7tLjarXacoXkxsZGBAUFubmjfxtMl223dll5d++7wXS5e5eHf8qUKaio\nqEBVVRW6urqwc+dOJCUlubqNPrW3t+PixYuWnw8ePDjoLj2elJQEo9EIADAajViwYIGbO/q3wXLZ\ndmHlsvLu3nfW+nLbfnP5S4xCiKKiIhEVFSUiIiLESy+95I4W+lRZWSliY2NFbGysuP32293eW2pq\nqggODhZeXl4iNDRU5Ofni/Pnz4vZs2eLyMhIMXv2bHHhwoVB01taWpqIiYkREyZMEPPnzxcNDQ1u\n6e3IkSMCgJgwYYKYOHGimDhxoigqKnL7vrPWl7v2G9/eSyQpvsOPSFIMP5GkGH4iSTH8RJJi+Ikk\nxfATSYrhJ5LU/wOdAGX9nfSgHgAAAABJRU5ErkJggg==\n",
-            "text/plain": [
-              "\u003cFigure size 600x400 with 1 Axes\u003e"
-            ]
-          },
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "display_data"
-        }
-      ],
+      "outputs": [],
       "source": [
         "import matplotlib.pylab as plt\n",
         "\n",
@@ -462,7 +390,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": null,
       "metadata": {
         "id": "05aeAuWjvjPx"
       },
@@ -502,19 +430,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 14,
+      "execution_count": null,
       "metadata": {
         "id": "DqXBnDfJ7qxL"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "0.9624\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "print(evaluate_model(interpreter))"
       ]
@@ -530,19 +450,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 15,
+      "execution_count": null,
       "metadata": {
         "id": "-9cnwiPp6EGm"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "0.9626\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "print(evaluate_model(interpreter_quant))"
       ]
@@ -573,7 +485,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 16,
+      "execution_count": null,
       "metadata": {
         "id": "jrXZxSJiJfYN"
       },
@@ -591,24 +503,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 17,
+      "execution_count": null,
       "metadata": {
         "id": "LwnV4KxwVEoG"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "178509092"
-            ]
-          },
-          "execution_count": 17,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "# Convert to TF Lite without quantization\n",
         "resnet_tflite_file = tflite_models_dir/\"resnet_v2_101.tflite\"\n",
@@ -617,24 +516,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 18,
+      "execution_count": null,
       "metadata": {
         "id": "2qkZD0VoVExe"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "45182656"
-            ]
-          },
-          "execution_count": 18,
-          "metadata": {
-            "tags": []
-          },
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "# Convert to TF Lite with quantization\n",
         "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
@@ -644,23 +530,11 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 19,
+      "execution_count": null,
       "metadata": {
         "id": "vhOjeg1x9Knp"
       },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  44K Jun 23 06:04 /tmp/mnist_tflite_models/mnist_model_quant_f16.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  24K Jun 23 06:12 /tmp/mnist_tflite_models/mnist_model_quant.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  83K Jun 23 06:12 /tmp/mnist_tflite_models/mnist_model.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828  44M Jun 23 06:13 /tmp/mnist_tflite_models/resnet_v2_101_quantized.tflite\n",
-            "-rw-rw-r-- 1 colaboratory-playground 50844828 171M Jun 23 06:12 /tmp/mnist_tflite_models/resnet_v2_101.tflite\n"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "!ls -lh {tflite_models_dir}/*.tflite"
       ]
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 2fd4f07..da48e8f 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -56,9 +56,29 @@
 compatibility with integer only hardware devices or accelerators by making sure
 all model math is integer quantized.
 
-For full integer quantization, you need to measure the dynamic range of
-activations and inputs by supplying sample input data to the converter. Refer to
-the `representative_dataset_gen()` function used in the following code.
+For full integer quantization, you need to calibrate or estimate the range,
+i.e., (min, max) of all floating-point tensors in the model. Unlike constant
+tensors such as weights and biases, variable tensors such as the model input,
+activations (outputs of intermediate layers), and the model output cannot be
+calibrated unless we run a few inference cycles. As a result, the converter
+requires a representative dataset to calibrate them. This dataset can be a
+small subset (around 100-500 samples) of the training or validation data.
+Refer to the `representative_dataset()` function below.
+
+<pre>
+def representative_dataset():
+  for data in tf.data.Dataset.from_tensor_slices((images)).batch(1).take(100):
+    yield [tf.cast(data, tf.float32)]
+</pre>
+
+For testing purposes, you can use a dummy dataset as follows:
+
+<pre>
+import numpy as np
+
+def representative_dataset():
+  for _ in range(100):
+    data = np.random.rand(1, 244, 244, 3)
+    yield [data.astype(np.float32)]
+</pre>
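
The hard-coded input shape `(1, 244, 244, 3)` in the dummy generator above has to match the model being converted. As a sketch that is not part of the original guide (it assumes a single-input Keras `model` object and uses the hypothetical helper name `make_dummy_representative_dataset`), the shape can be read from the model instead of being hard-coded:

<pre>
import numpy as np

# Sketch only: build a dummy calibration generator whose shape is taken from
# the model's input signature instead of being hard-coded.
def make_dummy_representative_dataset(model, num_steps=100):
  input_shape = [1] + list(model.input_shape[1:])
  def representative_dataset():
    for _ in range(num_steps):
      data = np.random.rand(*input_shape)
      yield [data.astype(np.float32)]
  return representative_dataset
</pre>

Random data like this only exercises the conversion path; for a model you actually plan to deploy, calibrating on real training or validation samples gives much more accurate ranges.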
 
 #### Integer with float fallback (using default float input/output)
 
@@ -70,11 +90,7 @@
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 <b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen</b>
+converter.representative_dataset = representative_dataset</b>
 tflite_quant_model = converter.convert()
 </pre>
 
@@ -101,11 +117,7 @@
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 <b>converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]</b>
 <b>converter.inference_input_type = tf.int8</b>  # or tf.uint8
 <b>converter.inference_output_type = tf.int8</b>  # or tf.uint8
@@ -158,11 +170,7 @@
 <pre>
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 <b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
 converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]</b>
 tflite_quant_model = converter.convert()
@@ -174,11 +182,7 @@
 <pre>
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
 converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
 <b>tf.lite.OpsSet.TFLITE_BUILTINS</b>]
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index b5ef31e..6b2d24e 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -24,7 +24,7 @@
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/delegates/status.h"
+#include "tensorflow/lite/delegates/telemetry.h"
 #include "tensorflow/lite/graph_info.h"
 #include "tensorflow/lite/memory_planner.h"
 #include "tensorflow/lite/minimal_logging.h"
diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index 4249c85..dcf6705 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -163,6 +163,12 @@
 #endif
   void* lib_tf_internal =
       SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
+#if defined(_WIN32)
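+  // On Windows, fall back to the interpreter wrapper's own .pyd if the pywrap
+  // TensorFlow internal library could not be loaded above.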
+  if (lib_tf_internal == nullptr) {
+    lib_tf_internal = SharedLibrary::LoadLibrary(
+        "_pywrap_tensorflow_interpreter_wrapper.pyd");
+  }
+#endif
   if (lib_tf_internal) {
     acquire_flex_delegate_func =
         reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
diff --git a/tensorflow/lite/ios/BUILD.apple b/tensorflow/lite/ios/BUILD.apple
index 44cd298..540c598 100644
--- a/tensorflow/lite/ios/BUILD.apple
+++ b/tensorflow/lite/ios/BUILD.apple
@@ -31,6 +31,7 @@
     name = "strip_common_include_path_core",
     hdr_labels = [
         "//tensorflow/lite/c:c_api.h",
+        "//tensorflow/lite/c:common.h",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h",
     ],
 )
@@ -49,8 +50,9 @@
     name = "TensorFlowLiteC_framework",
     hdrs = [
         ":c_api.h",
+        ":common.h",
         ":xnnpack_delegate.h",
-        "//tensorflow/lite/c:common.h",
+        "//tensorflow/lite/c:c_api_types.h",
     ],
     allowlist_symbols_file = ":allowlist_TensorFlowLiteC.txt",
     bundle_name = "TensorFlowLiteC",
@@ -121,6 +123,7 @@
     name = "tensorflow_lite_c",
     hdrs = [
         "//tensorflow/lite/c:c_api.h",
+        "//tensorflow/lite/c:c_api_types.h",
         "//tensorflow/lite/c:common.h",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h",
     ],
@@ -130,6 +133,7 @@
     ],
     deps = [
         "//tensorflow/lite/c:c_api",
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
     ],
 )
diff --git a/tensorflow/lite/ios/TensorFlowLiteC.podspec b/tensorflow/lite/ios/TensorFlowLiteC.podspec
index 1b98693..58f8f2b 100644
--- a/tensorflow/lite/ios/TensorFlowLiteC.podspec
+++ b/tensorflow/lite/ios/TensorFlowLiteC.podspec
@@ -1,10 +1,10 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteC'
-  s.version          = '2.3.0'
+  s.version          = '2.4.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
-  s.source           = { :http => "https://dl.google.com/dl/cpdc/b03814d8b5a44ad2/TensorFlowLiteC-#{s.version}.tar.gz" }
+  s.source           = { :http => "https://dl.google.com/dl/cpdc/e8a95c1d411b795e/TensorFlowLiteC-#{s.version}.tar.gz" }
   s.summary          = 'TensorFlow Lite'
   s.description      = <<-DESC
 
diff --git a/tensorflow/lite/ios/TensorFlowLiteSelectTfOps.podspec b/tensorflow/lite/ios/TensorFlowLiteSelectTfOps.podspec
index 393040b..7fc4dc2 100644
--- a/tensorflow/lite/ios/TensorFlowLiteSelectTfOps.podspec
+++ b/tensorflow/lite/ios/TensorFlowLiteSelectTfOps.podspec
@@ -1,10 +1,10 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteSelectTfOps'
-  s.version          = '2.3.0'
+  s.version          = '2.4.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
-  s.source           = { :http => "https://dl.google.com/dl/cpdc/4f626bc24212fd61/TensorFlowLiteSelectTfOps-#{s.version}.tar.gz" }
+  s.source           = { :http => "https://dl.google.com/dl/cpdc/dde267f91a6cd441/TensorFlowLiteSelectTfOps-#{s.version}.tar.gz" }
   s.summary          = 'TensorFlow Lite Select TF Ops'
   s.description      = <<-DESC
 
diff --git a/tensorflow/lite/ios/ios.bzl b/tensorflow/lite/ios/ios.bzl
index 4373e71..a9e98aa 100644
--- a/tensorflow/lite/ios/ios.bzl
+++ b/tensorflow/lite/ios/ios.bzl
@@ -74,16 +74,16 @@
 
 # When the static framework is built with bazel, all the header files are moved
 # to the "Headers" directory with no header path prefixes. This auxiliary rule
-# is used for stripping the path prefix to the "common.h" file included by the
-# "c_api.h" header.
+# is used for stripping the path prefixes from header inclusions in the
+# provided headers.
 def strip_common_include_path_prefix(name, hdr_labels, prefix = ""):
-    """Create modified header files with the common.h include path stripped out.
+    """Create modified header files with the inclusion path prefixes removed.
 
     Args:
       name: The name to be used as a prefix to the generated genrules.
       hdr_labels: List of header labels to strip out the include path. Each
           label must end with a colon followed by the header file name.
-      prefix: Optional prefix path to prepend to the common.h inclusion path.
+      prefix: Optional prefix path to prepend to the final inclusion path.
     """
 
     for hdr_label in hdr_labels:
@@ -95,7 +95,7 @@
             srcs = [hdr_label],
             outs = [hdr_filename],
             cmd = """
-            sed 's|#include ".*common.h"|#include "{}common.h"|'\
+            sed -E 's|#include ".*/([^/]+\\.h)"|#include "{}\\1"|g'\
             "$(location {})"\
             > "$@"
             """.format(prefix, hdr_label),
diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index 822a573..d791105 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -16,6 +16,7 @@
     "src/testdata/add.bin",
     "src/testdata/add_unknown_dimensions.bin",
     "src/testdata/grace_hopper_224.jpg",
+    "src/testdata/mul_add_signature_def.bin",
     "src/testdata/tile_with_bool_input.bin",
     "AndroidManifest.xml",
     "proguard.flags",
@@ -37,7 +38,9 @@
     headers = [
         "//tensorflow/lite:builtin_ops.h",
         "//tensorflow/lite/c:c_api.h",
+        "//tensorflow/lite/c:c_api_types.h",
         "//tensorflow/lite/c:c_api_experimental.h",
+        # TODO(b/175298345): Clean up and if possible remove common.h here.
         "//tensorflow/lite/c:common.h",
     ],
 )
@@ -260,6 +263,7 @@
     data = [
         "src/testdata/add.bin",
         "src/testdata/add_unknown_dimensions.bin",
+        "src/testdata/mul_add_signature_def.bin",
         "src/testdata/tile_with_bool_input.bin",
         "//tensorflow/lite:testdata/dynamic_shapes.bin",
         "//tensorflow/lite:testdata/multi_add.bin",
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index e14c38d..865cd84 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -19,6 +19,7 @@
 import java.nio.ByteBuffer;
 import java.nio.MappedByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -65,17 +66,20 @@
  * model with Toco, as are the default shapes of the inputs.
  *
  * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
- * be implicitly resized according to that array's shape. When inputs are provided as {@link Buffer}
- * types, no implicit resizing is done; the caller must ensure that the {@link Buffer} byte size
- * either matches that of the corresponding tensor, or that they first resize the tensor via {@link
- * #resizeInput()}. Tensor shape and type information can be obtained via the {@link Tensor} class,
- * available via {@link #getInputTensor(int)} and {@link #getOutputTensor(int)}.
+ * be implicitly resized according to that array's shape. When inputs are provided as {@link
+ * java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link
+ * java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first
+ * resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
+ * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
+ * #getOutputTensor(int)}.
  *
  * <p><b>WARNING:</b>Instances of a {@code Interpreter} is <b>not</b> thread-safe. A {@code
  * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}
  *
  * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
  * but is not guaranteed.
+ *
+ * <p>Note: This class is not thread safe.
  */
 public final class Interpreter implements AutoCloseable {
 
@@ -221,6 +225,7 @@
    */
   public Interpreter(@NonNull File modelFile, Options options) {
     wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
+    signatureNameList = getSignatureDefNames();
   }
 
   /**
@@ -269,7 +274,7 @@
 
   /**
    * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
-   * {@link #Options}.
+   * {@link Interpreter.Options}.
    *
    * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
    * {@code ByteBuffer} can be either a {@link MappedByteBuffer} that memory-maps a model file, or a
@@ -280,38 +285,41 @@
    */
   public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
     wrapper = new NativeInterpreterWrapper(byteBuffer, options);
+    signatureNameList = getSignatureDefNames();
   }
 
   /**
    * Runs model inference if the model takes only one input, and provides only one output.
    *
-   * <p>Warning: The API is more efficient if a {@link Buffer} (preferably direct, but not required)
-   * is used as the input/output data type. Please consider using {@link Buffer} to feed and fetch
-   * primitive data for better performance. The following concrete {@link Buffer} types are
-   * supported:
+   * <p>Warning: The API is more efficient if a {@link java.nio.Buffer} (preferably direct, but not
+   * required) is used as the input/output data type. Please consider using {@link java.nio.Buffer}
+   * to feed and fetch primitive data for better performance. The following concrete {@link
+   * java.nio.Buffer} types are supported:
    *
    * <ul>
    *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
-   *   <li>{@link FloatBuffer} - compatible with float Tensors.
-   *   <li>{@link IntBuffer} - compatible with int32 Tensors.
-   *   <li>{@link LongBuffer} - compatible with int64 Tensors.
+   *   <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
+   *   <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
+   *   <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
    * </ul>
    *
-   * Note that boolean types are only supported as arrays, not {@link Buffer}s, or as scalar inputs.
+   * Note that boolean types are only supported as arrays, not {@link java.nio.Buffer}s, or as
+   * scalar inputs.
    *
-   * @param input an array or multidimensional array, or a {@link Buffer} of primitive types
-   *     including int, float, long, and byte. {@link Buffer} is the preferred way to pass large
-   *     input data for primitive types, whereas string types require using the (multi-dimensional)
-   *     array input path. When a {@link Buffer} is used, its content should remain unchanged until
-   *     model inference is done, and the caller must ensure that the {@link Buffer} is at the
-   *     appropriate read position. A {@code null} value is allowed only if the caller is using a
-   *     {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to the
-   *     input {@link Tensor}.
-   * @param output a multidimensional array of output data, or a {@link Buffer} of primitive types
-   *     including int, float, long, and byte. When a {@link Buffer} is used, the caller must ensure
-   *     that it is set the appropriate write position. A null value is allowed only if the caller
-   *     is using a {@link Delegate} that allows buffer handle interop, and such a buffer has been
-   *     bound to the output {@link Tensor}. See {@link Options#setAllowBufferHandleOutput()}.
+   * @param input an array or multidimensional array, or a {@link java.nio.Buffer} of primitive
+   *     types including int, float, long, and byte. {@link java.nio.Buffer} is the preferred way to
+   *     pass large input data for primitive types, whereas string types require using the
+   *     (multi-dimensional) array input path. When a {@link java.nio.Buffer} is used, its content
+   *     should remain unchanged until model inference is done, and the caller must ensure that the
+   *     {@link java.nio.Buffer} is at the appropriate read position. A {@code null} value is
+   *     allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
+   *     and such a buffer has been bound to the input {@link Tensor}.
+   * @param output a multidimensional array of output data, or a {@link java.nio.Buffer} of
+   *     primitive types including int, float, long, and byte. When a {@link java.nio.Buffer} is
+   *     used, the caller must ensure that it is set the appropriate write position. A null value is
+   *     allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
+   *     and such a buffer has been bound to the output {@link Tensor}. See {@link
+   *     Interpreter.Options#setAllowBufferHandleOutput(boolean)}.
    * @throws IllegalArgumentException if {@code input} or {@code output} is null or empty, or if
    *     error occurs when running the inference.
    * @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is
@@ -327,35 +335,36 @@
   /**
    * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
    *
-   * <p>Warning: The API is more efficient if {@link Buffer}s (preferably direct, but not required)
-   * are used as the input/output data types. Please consider using {@link Buffer} to feed and fetch
-   * primitive data for better performance. The following concrete {@link Buffer} types are
-   * supported:
+   * <p>Warning: The API is more efficient if {@link java.nio.Buffer}s (preferably direct, but not
+   * required) are used as the input/output data types. Please consider using {@link
+   * java.nio.Buffer} to feed and fetch primitive data for better performance. The following
+   * concrete {@link java.nio.Buffer} types are supported:
    *
    * <ul>
    *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
-   *   <li>{@link FloatBuffer} - compatible with float Tensors.
-   *   <li>{@link IntBuffer} - compatible with int32 Tensors.
-   *   <li>{@link LongBuffer} - compatible with int64 Tensors.
+   *   <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
+   *   <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
+   *   <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
    * </ul>
    *
-   * Note that boolean types are only supported as arrays, not {@link Buffer}s, or as scalar inputs.
+   * Note that boolean types are only supported as arrays, not {@link java.nio.Buffer}s, or as
+   * scalar inputs.
    *
    * <p>Note: {@code null} values for invididual elements of {@code inputs} and {@code outputs} is
    * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
    * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
    *
    * @param inputs an array of input data. The inputs should be in the same order as inputs of the
-   *     model. Each input can be an array or multidimensional array, or a {@link Buffer} of
-   *     primitive types including int, float, long, and byte. {@link Buffer} is the preferred way
-   *     to pass large input data, whereas string types require using the (multi-dimensional) array
-   *     input path. When {@link Buffer} is used, its content should remain unchanged until model
-   *     inference is done, and the caller must ensure that the {@link Buffer} is at the appropriate
-   *     read position.
+   *     model. Each input can be an array or multidimensional array, or a {@link java.nio.Buffer}
+   *     of primitive types including int, float, long, and byte. {@link java.nio.Buffer} is the
+   *     preferred way to pass large input data, whereas string types require using the
+   *     (multi-dimensional) array input path. When {@link java.nio.Buffer} is used, its content
+   *     should remain unchanged until model inference is done, and the caller must ensure that the
+   *     {@link java.nio.Buffer} is at the appropriate read position.
    * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
-   *     Buffer}s of primitive types including int, float, long, and byte. It only needs to keep
-   *     entries for the outputs to be used. When a {@link Buffer} is used, the caller must ensure
-   *     that it is set the appropriate write position.
+   *     java.nio.Buffer}s of primitive types including int, float, long, and byte. It only needs to
+   *     keep entries for the outputs to be used. When a {@link java.nio.Buffer} is used, the caller
+   *     must ensure that it is set the appropriate write position.
    * @throws IllegalArgumentException if {@code inputs} or {@code outputs} is null or empty, or if
    *     error occurs when running the inference.
    */
@@ -366,6 +375,49 @@
   }
 
   /**
+   * Runs model inference based on the SignatureDef provided through {@code methodName}.
+   *
+   * <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
+   * data types.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   *
+   * @param inputs a map from input name in the SignatureDef to an input object.
+   * @param outputs a map from output name in the SignatureDef to output data.
+   * @param methodName the exported method name identifying the SignatureDef.
+   * @throws IllegalArgumentException if {@code inputs} or {@code outputs} or {@code methodName} is
+   *     null or empty, or if an error occurs when running the inference.
+   */
+  public void runSignature(
+      @NonNull Map<String, Object> inputs,
+      @NonNull Map<String, Object> outputs,
+      String methodName) {
+    checkNotClosed();
+    if (methodName == null && signatureNameList.length == 1) {
+      methodName = signatureNameList[0];
+    }
+    if (methodName == null) {
+      throw new IllegalArgumentException(
+          "Input error: SignatureDef methodName should not be null. null is only allowed if the"
+              + " model has a single Signature. Available Signatures: "
+              +  Arrays.toString(signatureNameList));
+    }
+    wrapper.runSignature(inputs, outputs, methodName);
+  }
+
+  /**
+   * Same as {@link Interpreter#runSignature(Map, Map, String)} but doesn't require passing a
+   * methodName, assuming the model has exactly one SignatureDef. If the model has more than one
+   * SignatureDef it will throw an exception.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public void runSignature(
+      @NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
+    checkNotClosed();
+    runSignature(inputs, outputs, null);
+  }
+
+  /**
    * Expicitly updates allocations for all tensors, if necessary.
    *
    * <p>This will propagate shapes and memory allocations for all dependent tensors using the input
@@ -446,6 +498,36 @@
     return wrapper.getInputTensor(inputIndex);
   }
 
+  /**
+   * Gets the list of SignatureDef exported method names available in the model.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public String[] getSignatureDefNames() {
+    checkNotClosed();
+    return wrapper.getSignatureDefNames();
+  }
+
+  /**
+   * Gets the list of SignatureDef inputs for method {@code methodName}.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public String[] getSignatureInputs(String methodName) {
+    checkNotClosed();
+    return wrapper.getSignatureInputs(methodName);
+  }
+
+  /**
+   * Gets the list of SignatureDef outputs for method {@code methodName}.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public String[] getSignatureOutputs(String methodName) {
+    checkNotClosed();
+    return wrapper.getSignatureOutputs(methodName);
+  }
+
   /** Gets the number of output Tensors. */
   public int getOutputTensorCount() {
     checkNotClosed();
@@ -494,8 +576,8 @@
   /**
    * Sets the number of threads to be used for ops that support multi-threading.
    *
-   * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling thread
-   *     multi-threading. This method will be removed in a future release.
+   * @deprecated Prefer using {@link Interpreter.Options#setNumThreads(int)} directly for
+   *     controlling thread multi-threading. This method will be removed in a future release.
    */
   @Deprecated
   public void setNumThreads(int numThreads) {
@@ -507,8 +589,8 @@
    * Advanced: Modifies the graph with the provided {@link Delegate}.
    *
    * @throws IllegalArgumentException if error occurs when modifying graph with {@code delegate}.
-   * @deprecated Prefer using {@link Options#addDelegate} to provide delegates at creation time.
-   *     This method will be removed in a future release.
+   * @deprecated Prefer using {@link Interpreter.Options#addDelegate} to provide delegates at
+   *     creation time. This method will be removed in a future release.
    */
   @Deprecated
   public void modifyGraphWithDelegate(Delegate delegate) {
@@ -580,4 +662,5 @@
   }
 
   NativeInterpreterWrapper wrapper;
+  String[] signatureNameList;
 }
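To make the new SignatureDef API above concrete, here is a minimal, hypothetical usage sketch (not part of the patch). It assumes a model equivalent to the mul_add_signature_def.bin test asset added in this change, i.e. a single signature named "mul_add" with inputs "x" and "y" and output "output_0"; the class and method names in the sketch are illustrative only.

import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import org.tensorflow.lite.Interpreter;

/** Illustrative only: exercises the experimental SignatureDef runner API. */
final class SignatureRunnerSketch {
  static float runMulAdd(ByteBuffer modelBuffer) {
    try (Interpreter interpreter = new Interpreter(modelBuffer)) {
      // Discover the exported signatures and their input/output names.
      String[] signatures = interpreter.getSignatureDefNames();              // e.g. {"mul_add"}
      String[] inputNames = interpreter.getSignatureInputs(signatures[0]);   // e.g. {"x", "y"}
      String[] outputNames = interpreter.getSignatureOutputs(signatures[0]); // e.g. {"output_0"}

      // Feed inputs by signature name rather than by tensor index.
      Map<String, Object> inputs = new HashMap<>();
      inputs.put("x", new float[] {2.0f});
      inputs.put("y", new float[] {4.0f});
      FloatBuffer result = FloatBuffer.allocate(1);
      Map<String, Object> outputs = new HashMap<>();
      outputs.put("output_0", result);

      // The methodName argument may be null (or the two-argument overload used)
      // only when the model has exactly one SignatureDef.
      interpreter.runSignature(inputs, outputs, "mul_add");
      return result.get(0); // x * 3.0 + y == 10.0 for the test model above.
    }
  }
}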
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 1eaaafd..a8006ce 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -22,6 +22,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.TreeMap;
 import org.tensorflow.lite.nnapi.NnApiDelegate;
 
 /**
@@ -30,6 +31,8 @@
  * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
  * explicitly freed by invoking the {@link #close()} method when the {@code
  * NativeInterpreterWrapper} object is no longer needed.
+ *
+ * <p>Note: This class is not thread safe.
  */
 final class NativeInterpreterWrapper implements AutoCloseable {
 
@@ -136,6 +139,36 @@
     ownedDelegates.clear();
   }
 
+  public void runSignature(
+      Map<String, Object> inputs, Map<String, Object> outputs, String methodName) {
+    if (inputs == null || inputs.isEmpty()) {
+      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
+    }
+    if (outputs == null || outputs.isEmpty()) {
+      throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
+    }
+    initTensorIndexesMaps();
+    // Map inputs/output to input indexes.
+    Map<Integer, Object> inputsWithInputIndex = new TreeMap<>();
+    Map<Integer, Object> outputsWithOutputIndex = new TreeMap<>();
+    for (Map.Entry<String, Object> input : inputs.entrySet()) {
+      int tensorIndex =
+          getInputTensorIndexFromSignature(interpreterHandle, input.getKey(), methodName);
+      inputsWithInputIndex.put(tensorToInputsIndexes.get(tensorIndex), input.getValue());
+    }
+    for (Map.Entry<String, Object> output : outputs.entrySet()) {
+      int tensorIndex =
+          getOutputTensorIndexFromSignature(interpreterHandle, output.getKey(), methodName);
+      outputsWithOutputIndex.put(tensorToOutputsIndexes.get(tensorIndex), output.getValue());
+    }
+    Object[] inputsList = new Object[inputs.size()];
+    int index = 0;
+    for (Map.Entry<Integer, Object> input : inputsWithInputIndex.entrySet()) {
+      inputsList[index++] = input.getValue();
+    }
+    run(inputsList, outputsWithOutputIndex);
+  }
+
   /** Sets inputs, runs model inference and returns outputs. */
   void run(Object[] inputs, Map<Integer, Object> outputs) {
     inferenceDurationNanoseconds = -1;
@@ -257,7 +290,26 @@
           String.format(
               "Input error: '%s' is not a valid name for any input. Names of inputs and their "
                   + "indexes are %s",
-              name, inputsIndexes.toString()));
+              name, inputsIndexes));
+    }
+  }
+
+  /** Initializes the mappings from tensor index to input/output index. */
+  private void initTensorIndexesMaps() {
+    if (tensorToInputsIndexes != null) {
+      return;
+    }
+    tensorToInputsIndexes = new HashMap<>();
+    tensorToOutputsIndexes = new HashMap<>();
+    int inputCount = getInputTensorCount();
+    for (int i = 0; i < inputCount; ++i) {
+      int tensorIndex = getInputTensorIndex(interpreterHandle, i);
+      tensorToInputsIndexes.put(tensorIndex, i);
+    }
+    int outputCount = getOutputTensorCount();
+    for (int i = 0; i < outputCount; ++i) {
+      int tensorIndex = getOutputTensorIndex(interpreterHandle, i);
+      tensorToOutputsIndexes.put(tensorIndex, i);
     }
   }
 
@@ -279,7 +331,7 @@
           String.format(
               "Input error: '%s' is not a valid name for any output. Names of outputs and their "
                   + "indexes are %s",
-              name, outputsIndexes.toString()));
+              name, outputsIndexes));
     }
   }
 
@@ -314,6 +366,27 @@
     return inputTensor;
   }
 
+  /** Gets the list of SignatureDef names available in the model, if any. */
+  public String[] getSignatureDefNames() {
+    return getSignatureDefNames(interpreterHandle);
+  }
+
+  private static native String[] getSignatureDefNames(long interpreterHandle);
+
+  /** Gets the list of SignatureDef inputs for method {@code methodName}. */
+  String[] getSignatureInputs(String methodName) {
+    return getSignatureInputs(interpreterHandle, methodName);
+  }
+
+  private static native String[] getSignatureInputs(long interpreterHandle, String methodName);
+
+  /** Gets the list of SignatureDef outputs for method {@code methodName}. */
+  String[] getSignatureOutputs(String methodName) {
+    return getSignatureOutputs(interpreterHandle, methodName);
+  }
+
+  private static native String[] getSignatureOutputs(long interpreterHandle, String methodName);
+
   /** Gets the number of output tensors. */
   int getOutputTensorCount() {
     return outputTensors.length;
@@ -430,6 +503,9 @@
   // Lazily constructed maps of input and output names to input and output Tensor indexes.
   private Map<String, Integer> inputsIndexes;
   private Map<String, Integer> outputsIndexes;
+  // Lazily constructed maps of tensor index to index in input and output indexes.
+  private Map<Integer, Integer> tensorToInputsIndexes;
+  private Map<Integer, Integer> tensorToOutputsIndexes;
 
   // Lazily constructed and populated arrays of input and output Tensor wrappers.
   private Tensor[] inputTensors;
@@ -448,6 +524,12 @@
 
   private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
 
+  private static native int getInputTensorIndexFromSignature(
+      long interpreterHandle, String signatureInputName, String methodName);
+
+  private static native int getOutputTensorIndexFromSignature(
+      long interpreterHandle, String signatureInputName, String methodName);
+
   private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);
 
   private static native int getInputCount(long interpreterHandle);
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
index 3c2e7b4..d6e29f4 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -85,7 +85,7 @@
     }
   }
 
-  public static native String nativeRuntimeVersion();
+  private static native String nativeRuntimeVersion();
 
-  public static native String nativeSchemaVersion();
+  private static native String nativeSchemaVersion();
 }
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index 8b6bad5..9dc0010 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -26,51 +26,36 @@
         "-ldl",
     ],
     deps = [
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite:op_resolver",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite:string_util",
         "//tensorflow/lite:util",
-        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/shims:common",
+        "//tensorflow/lite/core/shims:framework",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
         "//tensorflow/lite/java/jni",
     ],
     alwayslink = 1,
 )
 
-# This includes all ops. If you want a smaller binary, you should copy and
-# modify builtin_ops_jni.cc.  You should then link your binary against both
-# ":native_framework_only" and your own version of ":native_builtin_ops".
+# This includes all ops. If you want a smaller binary, you should use
+# tflite_custom_cc_library or tflite_custom_android_library rules.
 cc_library(
     name = "native",
-    srcs = [
-        "builtin_ops_jni.cc",
-    ],
-    hdrs = ["op_resolver.h"],
     copts = tflite_copts(),
     deps = [
         ":native_framework_only",
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite:create_op_resolver_with_builtin_ops",
         "//tensorflow/lite/core/api",
-        "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/core/shims:builtin_ops",
+        "//tensorflow/lite/core/shims:framework",
     ],
     alwayslink = 1,
 )
 
-# TODO(b/153652701): Generate this target to give CreateOpResolver a custom namespace.
-cc_library(
-    name = "selected_ops_jni",
-    srcs = ["selected_ops_jni.cc"],
-    hdrs = ["op_resolver.h"],
-    copts = tflite_copts(),
-    deps = [
-        "//tensorflow/lite:framework",
-    ],
-)
-
 exports_files(
     [
         "exported_symbols.lds",
         "version_script.lds",
-        "op_resolver.h",
     ],
 )
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 3551286..840985b 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -21,40 +21,43 @@
 #include <atomic>
 #include <vector>
 
-#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/shims/c/common.h"
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
+#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
+#include "tensorflow/lite/core/shims/cc/model_builder.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/interpreter_builder.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
-#include "tensorflow/lite/model_builder.h"
 #include "tensorflow/lite/util.h"
 
 namespace tflite {
 // This is to be provided at link-time by a library.
-extern std::unique_ptr<OpResolver> CreateOpResolver();
+extern std::unique_ptr<MutableOpResolver> CreateOpResolver();
 }  // namespace tflite
 
 using tflite::jni::BufferErrorReporter;
 using tflite::jni::ThrowException;
+using tflite_shims::FlatBufferModel;
+using tflite_shims::Interpreter;
+using tflite_shims::InterpreterBuilder;
 
 namespace {
 
-tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
+Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
   if (handle == 0) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to Interpreter.");
     return nullptr;
   }
-  return reinterpret_cast<tflite::Interpreter*>(handle);
+  return reinterpret_cast<Interpreter*>(handle);
 }
 
-tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
+FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
   if (handle == 0) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to model.");
     return nullptr;
   }
-  return reinterpret_cast<tflite::FlatBufferModel*>(handle);
+  return reinterpret_cast<FlatBufferModel*>(handle);
 }
 
 BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
@@ -155,17 +158,57 @@
   return tflite::VerifyModelBuffer(verifier);
 }
 
+// Helper method that fetches the tensor index based on SignatureDef details
+// from either inputs or outputs.
+// Returns -1 if invalid names are passed.
+int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name,
+                               jstring method_name,
+                               tflite::Interpreter* interpreter,
+                               bool is_input) {
+  // Fetch name strings.
+  const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
+  const char* signature_input_name_ptr =
+      env->GetStringUTFChars(signature_tensor_name, nullptr);
+  // Lookup if the input is valid.
+  const auto& signature_list =
+      (is_input ? interpreter->signature_inputs(method_name_ptr)
+                : interpreter->signature_outputs(method_name_ptr));
+  const auto& tensor = signature_list.find(signature_input_name_ptr);
+  // Release the memory before returning.
+  env->ReleaseStringUTFChars(method_name, method_name_ptr);
+  env->ReleaseStringUTFChars(signature_tensor_name, signature_input_name_ptr);
+  return tensor == signature_list.end() ? -1 : tensor->second;
+}
+
+jobjectArray GetSignatureInputsOutputsList(
+    const std::map<std::string, uint32_t>& input_output_list, JNIEnv* env) {
+  jclass string_class = env->FindClass("java/lang/String");
+  if (string_class == nullptr) {
+    ThrowException(env, tflite::jni::kUnsupportedOperationException,
+                   "Internal error: Can not find java/lang/String class to get "
+                   "SignatureDef names.");
+    return nullptr;
+  }
+
+  jobjectArray names = env->NewObjectArray(input_output_list.size(),
+                                           string_class, env->NewStringUTF(""));
+  int i = 0;
+  for (const auto& input : input_output_list) {
+    env->SetObjectArrayElement(names, i++,
+                               env->NewStringUTF(input.first.c_str()));
+  }
+  return names;
+}
+
 }  // namespace
 
-#ifdef __cplusplus
 extern "C" {
-#endif
 
 JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
@@ -187,7 +230,7 @@
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
     JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -205,7 +248,7 @@
 JNIEXPORT jboolean JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return JNI_FALSE;
 
   // TODO(b/132995737): Remove this logic by caching whether an unresolved
@@ -225,10 +268,78 @@
   return JNI_FALSE;
 }
 
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return nullptr;
+  jclass string_class = env->FindClass("java/lang/String");
+  if (string_class == nullptr) {
+    ThrowException(env, tflite::jni::kUnsupportedOperationException,
+                   "Internal error: Can not find java/lang/String class to get "
+                   "SignatureDef names.");
+    return nullptr;
+  }
+  const auto& signature_defs = interpreter->signature_def_names();
+  jobjectArray names = static_cast<jobjectArray>(env->NewObjectArray(
+      signature_defs.size(), string_class, env->NewStringUTF("")));
+  for (int i = 0; i < signature_defs.size(); ++i) {
+    env->SetObjectArrayElement(names, i,
+                               env->NewStringUTF(signature_defs[i]->c_str()));
+  }
+  return names;
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
+    JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return nullptr;
+  const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
+  const jobjectArray signature_inputs = GetSignatureInputsOutputsList(
+      interpreter->signature_inputs(method_name_ptr), env);
+  // Release the memory before returning.
+  env->ReleaseStringUTFChars(method_name, method_name_ptr);
+  return signature_inputs;
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs(
+    JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return nullptr;
+  const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
+  const jobjectArray signature_outputs = GetSignatureInputsOutputsList(
+      interpreter->signature_outputs(method_name_ptr), env);
+  // Release the memory before returning.
+  env->ReleaseStringUTFChars(method_name, method_name_ptr);
+  return signature_outputs;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature(
+    JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name,
+    jstring method_name) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return -1;
+  return GetTensorIndexForSignature(env, signature_input_name, method_name,
+                                    interpreter, /*is_input=*/true);
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature(
+    JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name,
+    jstring method_name) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return -1;
+  return GetTensorIndexForSignature(env, signature_output_name, method_name,
+                                    interpreter, /*is_input=*/false);
+}
+
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->inputs()[input_index];
 }
@@ -236,7 +347,7 @@
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->outputs()[output_index];
 }
@@ -244,7 +355,7 @@
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->execution_plan().size());
 }
@@ -253,7 +364,7 @@
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->inputs().size());
 }
@@ -262,7 +373,7 @@
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->outputs().size());
 }
@@ -271,7 +382,7 @@
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
@@ -293,7 +404,7 @@
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
 }
@@ -301,7 +412,7 @@
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowBufferHandleOutput(allow);
 }
@@ -315,7 +426,7 @@
     return;
   }
 
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) {
     return;
   }
@@ -343,8 +454,8 @@
     if (num_threads > 0) {
       options.num_threads = num_threads;
     }
-    tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
-                                                    xnnpack_delete);
+    Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
+                                            xnnpack_delete);
     auto delegation_status =
         interpreter->ModifyGraphWithDelegate(std::move(delegate));
     // kTfLiteApplicationError occurs in cases where delegation fails but
@@ -376,7 +487,7 @@
                                                              jclass clazz,
                                                              jlong handle,
                                                              jint num_threads) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetNumThreads(static_cast<int>(num_threads));
 }
@@ -413,8 +524,8 @@
   std::unique_ptr<tflite::TfLiteVerifier> verifier;
   verifier.reset(new JNIFlatBufferVerifier());
 
-  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
-      path, verifier.get(), error_reporter);
+  auto model = FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(),
+                                                       error_reporter);
   if (!model) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Contents of %s does not encode a valid "
@@ -442,7 +553,7 @@
     return 0;
   }
 
-  auto model = tflite::FlatBufferModel::BuildFromBuffer(
+  auto model = FlatBufferModel::BuildFromBuffer(
       buf, static_cast<size_t>(capacity), error_reporter);
   if (!model) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -457,14 +568,14 @@
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
     JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
     jint num_threads) {
-  tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
+  FlatBufferModel* model = convertLongToModel(env, model_handle);
   if (model == nullptr) return 0;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return 0;
   auto resolver = ::tflite::CreateOpResolver();
-  std::unique_ptr<tflite::Interpreter> interpreter;
-  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+  std::unique_ptr<Interpreter> interpreter;
+  TfLiteStatus status = InterpreterBuilder(*model, *(resolver.get()))(
       &interpreter, static_cast<int>(num_threads));
   if (status != kTfLiteOk) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -480,8 +591,7 @@
 // Sets inputs, runs inference, and returns outputs as long handles.
 JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -499,7 +609,7 @@
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
     JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return -1;
   const int idx = static_cast<int>(output_idx);
   if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
@@ -520,8 +630,7 @@
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return JNI_FALSE;
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return JNI_FALSE;
   if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -557,8 +666,7 @@
 Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
     jlong delegate_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
   BufferErrorReporter* error_reporter =
@@ -579,8 +687,7 @@
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
   BufferErrorReporter* error_reporter =
@@ -598,8 +705,7 @@
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
     JNIEnv* env, jclass clazz, jlong interpreter_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to interpreter.");
@@ -646,6 +752,4 @@
   }
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 24a13bb..00f2a69 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -19,12 +19,13 @@
 #include <memory>
 #include <string>
 
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/core/shims/c/common.h"
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
 #include "tensorflow/lite/string_util.h"
 
 using tflite::jni::ThrowException;
+using tflite_shims::Interpreter;
 
 namespace {
 
@@ -39,14 +40,14 @@
 // invalidate all TfLiteTensor* handles during inference or allocation.
 class TensorHandle {
  public:
-  TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
+  TensorHandle(Interpreter* interpreter, int tensor_index)
       : interpreter_(interpreter), tensor_index_(tensor_index) {}
 
   TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
   int index() const { return tensor_index_; }
 
  private:
-  tflite::Interpreter* const interpreter_;
+  Interpreter* const interpreter_;
   const int tensor_index_;
 };
 
@@ -392,14 +393,11 @@
 
 }  // namespace
 
-#ifdef __cplusplus
 extern "C" {
-#endif  // __cplusplus
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
-  tflite::Interpreter* interpreter =
-      reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
+  Interpreter* interpreter = reinterpret_cast<Interpreter*>(interpreter_handle);
   return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
 }
 
@@ -615,6 +613,4 @@
   return static_cast<jint>(tensor ? tensor->params.zero_point : 0);
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index f5b0217..9494d15 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -44,6 +44,8 @@
       "tensorflow/lite/testdata/dynamic_shapes.bin";
   private static final String BOOL_MODEL =
       "tensorflow/lite/java/src/testdata/tile_with_bool_input.bin";
+  private static final String MODEL_WITH_SIGNATURE_PATH =
+      "tensorflow/lite/java/src/testdata/mul_add_signature_def.bin";
 
   private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
   private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
@@ -55,6 +57,8 @@
   private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
       TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
   private static final ByteBuffer BOOL_MODEL_BUFFER = TestUtils.getTestFileAsBuffer(BOOL_MODEL);
+  private static final ByteBuffer MODEL_WITH_SIGNATURE_BUFFER =
+      TestUtils.getTestFileAsBuffer(MODEL_WITH_SIGNATURE_PATH);
 
   @Test
   public void testInterpreter() throws Exception {
@@ -723,6 +727,91 @@
     }
   }
 
+  @Test
+  public void testModelWithSignatureDef() {
+    try (Interpreter interpreter = new Interpreter(MODEL_WITH_SIGNATURE_BUFFER)) {
+      String[] signatureNames = interpreter.getSignatureDefNames();
+      String[] expectedSignatureNames = {"mul_add"};
+      assertThat(signatureNames).isEqualTo(expectedSignatureNames);
+
+      String[] signatureInputs = interpreter.getSignatureInputs(expectedSignatureNames[0]);
+      String[] expectedSignatureInputs = {"x", "y"};
+      assertThat(signatureInputs).isEqualTo(expectedSignatureInputs);
+
+      String[] signatureOutputs = interpreter.getSignatureOutputs(expectedSignatureNames[0]);
+      String[] expectedSignatureOutputs = {"output_0"};
+      assertThat(signatureOutputs).isEqualTo(expectedSignatureOutputs);
+
+      FloatBuffer output = FloatBuffer.allocate(1);
+      float[] inputX = {2.0f};
+      float[] inputY = {4.0f};
+      Map<String, Object> inputs = new HashMap<>();
+      inputs.put("x", inputX);
+      inputs.put("y", inputY);
+      Map<String, Object> outputs = new HashMap<>();
+      outputs.put("output_0", output);
+      interpreter.runSignature(inputs, outputs, "mul_add");
+      // Result should be x * 3.0 + y
+      FloatBuffer expected = fill(FloatBuffer.allocate(1), 10.0f);
+      assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
+    }
+  }
+
+  @Test
+  public void testModelWithSignatureDefNullMethodName() {
+    try (Interpreter interpreter = new Interpreter(MODEL_WITH_SIGNATURE_BUFFER)) {
+      String[] signatureNames = interpreter.getSignatureDefNames();
+      String[] expectedSignatureNames = {"mul_add"};
+      assertThat(signatureNames).isEqualTo(expectedSignatureNames);
+
+      String[] signatureInputs = interpreter.getSignatureInputs(expectedSignatureNames[0]);
+      String[] expectedSignatureInputs = {"x", "y"};
+      assertThat(signatureInputs).isEqualTo(expectedSignatureInputs);
+
+      String[] signatureOutputs = interpreter.getSignatureOutputs(expectedSignatureNames[0]);
+      String[] expectedSignatureOutputs = {"output_0"};
+      assertThat(signatureOutputs).isEqualTo(expectedSignatureOutputs);
+
+      FloatBuffer output = FloatBuffer.allocate(1);
+      float[] inputX = {2.0f};
+      float[] inputY = {4.0f};
+      Map<String, Object> inputs = new HashMap<>();
+      inputs.put("x", inputX);
+      inputs.put("y", inputY);
+      Map<String, Object> outputs = new HashMap<>();
+      outputs.put("output_0", output);
+      interpreter.runSignature(inputs, outputs, null);
+      // Result should be x * 3.0 + y
+      FloatBuffer expected = fill(FloatBuffer.allocate(1), 10.0f);
+      assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
+      output = FloatBuffer.allocate(1);
+      outputs.put("output_0", output);
+      interpreter.runSignature(inputs, outputs);
+      assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
+    }
+  }
+
+  @Test
+  public void testModelWithSignatureDefNoSignatures() {
+    try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
+      String[] signatureNames = interpreter.getSignatureDefNames();
+      String[] expectedSignatureNames = {};
+      assertThat(signatureNames).isEqualTo(expectedSignatureNames);
+      Map<String, Object> inputs = new HashMap<>();
+      Map<String, Object> outputs = new HashMap<>();
+      try {
+        interpreter.runSignature(inputs, outputs);
+        fail();
+      } catch (IllegalArgumentException e) {
+        assertThat(e)
+            .hasMessageThat()
+            .contains(
+                "Input error: SignatureDef methodName should not be null. null is only allowed if"
+                    + " the model has a single Signature");
+      }
+    }
+  }
+
   private static native long getNativeHandleForDelegate();
 
   private static native long getNativeHandleForInvalidDelegate();
diff --git a/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc b/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc
index 7981a8b..de12df9 100644
--- a/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc
+++ b/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc
@@ -20,9 +20,7 @@
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
-#ifdef __cplusplus
 extern "C" {
-#endif
 
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
@@ -97,6 +95,4 @@
   return reinterpret_cast<jlong>(&delegate);
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/java/src/testdata/mul_add_signature_def.bin b/tensorflow/lite/java/src/testdata/mul_add_signature_def.bin
new file mode 100644
index 0000000..fe06d1d
--- /dev/null
+++ b/tensorflow/lite/java/src/testdata/mul_add_signature_def.bin
Binary files differ
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 67fc33b..f730f7e 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -320,6 +320,18 @@
     }),
 )
 
+# Provide a library for clients to link to if they need to stay on deprecated
+# arithmetic backends. Include as a dependency of cpu_backend_gemm to start.
+# TODO(b/168923364): Move to dependent targets.
+cc_library(
+    name = "deprecated_backends",
+    srcs = [
+        "deprecated_backends.cc",
+    ],
+    compatible_with = get_compatible_with_portable(),
+    alwayslink = 1,
+)
+
 cc_library(
     name = "cpu_backend_context",
     srcs = [
@@ -337,6 +349,9 @@
         "//conditions:default": ["-DTFLITE_HAVE_CPUINFO"],
     }),
     deps = [
+        # TODO(b/168923364): Remove deprecated_backends after it is added to all
+        # necessary targets.
+        ":deprecated_backends",
         ":tflite_with_ruy",
         ":op_macros",
         # For now this unconditionally depends on both ruy and gemmlowp.
@@ -345,6 +360,7 @@
         "@ruy//ruy:context",
         "@gemmlowp",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite:macros",
         "//tensorflow/lite:external_cpu_backend_context",
         "//tensorflow/lite/kernels/internal:compatibility",
     ] + select({
@@ -1141,14 +1157,17 @@
     srcs = ["numeric_verify_test.cc"],
     tags = ["tflite_nnapi"],
     deps = [
+        ":kernel_util",
         ":test_main",
         ":test_util",
         "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels/internal:reference",
         "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/schema:schema_fbs",
         "//third_party/eigen3",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
@@ -2239,6 +2258,7 @@
         ":builtin_ops",
         ":kernel_util",
         ":variable_op_kernels",
+        "//tensorflow/lite:builtin_ops",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
         "@com_google_googletest//:gtest",
diff --git a/tensorflow/lite/kernels/arg_min_max.cc b/tensorflow/lite/kernels/arg_min_max.cc
index f782f94..03a6961 100644
--- a/tensorflow/lite/kernels/arg_min_max.cc
+++ b/tensorflow/lite/kernels/arg_min_max.cc
@@ -37,7 +37,13 @@
 
 TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input,
                           const TfLiteTensor* axis, TfLiteTensor* output) {
-  int axis_value = *GetTensorData<int>(axis);
+  int axis_value;
+  // Retrieve all 8 bytes when the axis type is kTfLiteInt64 to avoid data loss.
+  if (axis->type == kTfLiteInt64) {
+    axis_value = static_cast<int>(*GetTensorData<int64_t>(axis));
+  } else {
+    axis_value = *GetTensorData<int>(axis);
+  }
   if (axis_value < 0) {
     axis_value += NumDimensions(input);
   }
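As an illustrative note (not part of the patch): when the axis tensor is int64, its value occupies 8 bytes, so reading it through GetTensorData<int> as before was only correct by accident of value range and byte order; casting the full int64_t value to int keeps the read well-defined for the small axis values this op expects.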
diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc
index 5f6afa3..23c2833 100644
--- a/tensorflow/lite/kernels/batch_matmul.cc
+++ b/tensorflow/lite/kernels/batch_matmul.cc
@@ -450,6 +450,8 @@
                         TfLiteTensor* scaling_factors,
                         TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                         TfLiteTensor* input_offsets, TfLiteTensor* output) {
+  const auto* params =
+      reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
   const int32_t num_input_dims = input_shape.DimensionsCount();
 
   // Input row/cols have been swapped at this point, so dims are
@@ -465,18 +467,20 @@
   float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
   int32_t* input_offset_ptr = nullptr;
   int32_t* row_sums_ptr = nullptr;
-  // Only asymmetric quantization is supported.
   input_offset_ptr = GetTensorData<int32_t>(input_offsets);
   row_sums_ptr = GetTensorData<int32_t>(row_sums);
+  if (!params->asymmetric_quantize_inputs) {
+    memset(input_offset_ptr, 0, input_offsets->bytes);
+  }
   int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
   const int8_t* filter_data = GetTensorData<int8_t>(filter);
   const float* input_ptr = GetTensorData<float>(input);
   // Quantize each batch independently.
+  tensor_utils::BatchQuantizeFloats(input_ptr, num_batches_to_quantize,
+                                    input_size, quant_data, scaling_factors_ptr,
+                                    input_offset_ptr,
+                                    params->asymmetric_quantize_inputs);
   for (int b = 0; b < num_batches_to_quantize; ++b) {
-    const int offset = b * input_size;
-    tensor_utils::AsymmetricQuantizeFloats(
-        input_ptr + offset, input_size, quant_data + offset,
-        &scaling_factors_ptr[b], &input_offset_ptr[b]);
     // Incorporate scaling of the filter.
     scaling_factors_ptr[b] *= filter->params.scale;
   }
diff --git a/tensorflow/lite/kernels/batch_matmul_test.cc b/tensorflow/lite/kernels/batch_matmul_test.cc
index 7abef73..2975069 100644
--- a/tensorflow/lite/kernels/batch_matmul_test.cc
+++ b/tensorflow/lite/kernels/batch_matmul_test.cc
@@ -281,12 +281,12 @@
 
 // In the hybrid model the weights are quantized int8. But the input
 // and output are expected to be in float precision.
-class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel {
+class HybridBatchMatMulOpModel : public SingleOpModel {
  public:
-  HybridAsymmetricBatchMatMulOpModel(
-      int units, int batches, const TensorData& lhs, const TensorData& rhs,
-      const TensorData& output = {TensorType_FLOAT32}, bool adj_x = false,
-      bool adj_y = false)
+  HybridBatchMatMulOpModel(int units, int batches, const TensorData& lhs,
+                           const TensorData& rhs,
+                           const TensorData& output = {TensorType_FLOAT32},
+                           bool asymmetric_quantize_inputs = true)
       : units_(units), batches_(batches) {
     int total_input_size = 1;
     for (size_t i = 0; i < lhs.shape.size(); ++i) {
@@ -299,9 +299,11 @@
 
     output_id_ = AddOutput(output);
 
-    SetBuiltinOp(BuiltinOperator_BATCH_MATMUL,
-                 BuiltinOptions_BatchMatMulOptions,
-                 CreateBatchMatMulOptions(builder_, adj_x, adj_y).Union());
+    SetBuiltinOp(
+        BuiltinOperator_BATCH_MATMUL, BuiltinOptions_BatchMatMulOptions,
+        CreateBatchMatMulOptions(builder_, /*adj_x=*/false, /*adj_y=*/false,
+                                 asymmetric_quantize_inputs)
+            .Union());
     BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
   }
   void SetWeights(const std::vector<float>& data) {
@@ -340,7 +342,7 @@
 };
 
 TEST_P(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 10}},
       /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0});
@@ -371,7 +373,7 @@
 }
 
 TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
       /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0});
@@ -402,7 +404,7 @@
 }
 
 TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/9, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
       /*rhs=*/{TensorType_INT8, {10, 9}, 0, 0, 10.0 / 127.0, 0});
@@ -437,7 +439,7 @@
 }
 
 TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 10}},
       /*rhs=*/{TensorType_INT8, {2, 10, 3}, 0, 0, 10.0 / 127.0, 0});
@@ -470,6 +472,148 @@
     HybridAsymmetricBatchMatMulOpTest, HybridAsymmetricBatchMatMulOpTest,
     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
 
+class HybridSymmetricBatchMatMulOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 10}},
+      /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0},
+      /*output=*/{TensorType_FLOAT32}, /*asymmetric_quantize_inputs=*/false);
+
+  m.SetSignedWeights({
+      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5,  5,  5,
+      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput({
+      11, 12, 13, 14, 15, 16, 17, 18,  -19, -20,  // batch 1, 0
+      11, 12, 13, 14, 15, 16, 17, -18, 19,  -20,  // batch 1, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     194,
+                                     194,
+                                     194,
+                                     248,
+                                     248,
+                                     248,
+                                 },
+                                 /*max_abs_error=*/0.64f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+}
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
+      /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0},
+      /*output=*/{TensorType_FLOAT32}, /*asymmetric_quantize_inputs=*/false);
+
+  m.SetSignedWeights({
+      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5,  5,  5,
+      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput({
+      1,  2,  3,  4,  5,  6,  7,  8,   -9,  -10,  // batch 0, 0
+      1,  2,  3,  4,  5,  6,  7,  -8,  9,   -10,  // batch 0, 1
+      11, 12, 13, 14, 15, 16, 17, 18,  -19, -20,  // batch 1, 0
+      11, 12, 13, 14, 15, 16, 17, -18, 19,  -20,  // batch 1, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     24, 24, 24,     //
+                                     56, 56, 56,     //
+                                     194, 194, 194,  //
+                                     248, 248, 248,  //
+                                 },
+                                 /*max_abs_error=*/1.3f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+}
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/9, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
+      /*rhs=*/{TensorType_INT8, {10, 9}, 0, 0, 10.0 / 127.0, 0},
+      {TensorType_FLOAT32}, false);
+
+  m.SetSignedWeights({
+      1, 1, 1, 17, 17, 17, 26, 26, 26, 2,  2,  2,  18, 18, 18, 27, 27, 27,
+      3, 3, 3, 19, 19, 19, 28, 28, 28, 4,  4,  4,  20, 20, 20, 29, 29, 29,
+      5, 5, 5, 21, 21, 21, 30, 30, 30, 6,  6,  6,  22, 22, 22, 31, 31, 31,
+      7, 7, 7, 23, 23, 23, 32, 32, 32, 8,  8,  8,  24, 24, 24, 33, 33, 33,
+      9, 9, 9, 25, 25, 25, 34, 34, 34, 10, 10, 10, 26, 26, 26, 35, 35, 35,
+  });
+
+  m.SetInput({
+      1,  2,  3,  4,  5,  6,  7,  8,   -9,  -10,  // batch 0, 0
+      1,  2,  3,  4,  5,  6,  7,  -8,  9,   -10,  // batch 0, 1
+      11, 12, 13, 14, 15, 16, 17, 18,  -19, -20,  // batch 1, 0
+      11, 12, 13, 14, 15, 16, 17, -18, 19,  -20,  // batch 1, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      23,  23,  23,  296,  296,  296,  451,  451,  451,   //
+                      58,  58,  58,  362,  362,  362,  529,  529,  529,   //
+                      193, 193, 193, 1424, 1424, 1424, 2118, 2118, 2118,  //
+                      253, 253, 253, 1519, 1519, 1519, 2223, 2223, 2223   //
+                  },
+                  /*max_abs_error=*/1.3f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 9}));
+}
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 10}},
+      /*rhs=*/{TensorType_INT8, {2, 10, 3}, 0, 0, 10.0 / 127.0, 0},
+      {TensorType_FLOAT32}, false);
+
+  m.SetSignedWeights({
+      1, -3, 1, 2, -2, 2, 3, -1, 3, 4,  0, 4, 5, 1, 5, 6, 2, 6,  7,  3,
+      7, 8,  4, 8, 9,  5, 9, 10, 6, 10, 1, 1, 1, 2, 2, 2, 3, 3,  3,  4,
+      4, 4,  5, 5, 5,  6, 6, 6,  7, 7,  7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // batch 0, 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // batch 0, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     24, -45, 24,  //
+                                     56, -19, 56,  //
+                                     24, 24, 24,   //
+                                     56, 56, 56,   //
+                                 },
+                                 /*max_abs_error=*/0.64f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    HybridSymmetricBatchMatMulOpTest, HybridSymmetricBatchMatMulOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
 class QuantizedBatchMatMulOpModel : public SingleOpModel {
  public:
   QuantizedBatchMatMulOpModel(int units, int batches, const TensorData& lhs,
diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc
index e683a2a..063cf7d 100644
--- a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -856,7 +856,7 @@
 
 // TODO(mirkov): add another test which directly compares to TF once TOCO
 // supports the conversion from dynamic_rnn with BasicRNNCell.
-TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
+TEST_P(BidirectionalRNNOpTest, ClosedBoxTest) {
   auto params = GetParam();
   const bool quantize_weights = std::get<0>(params);
   const bool asymmetric_quantize_inputs = std::get<1>(params);
@@ -903,8 +903,8 @@
                   bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
 }
 
-// Same as BlackBox test, but input is reshuffled to time_major format.
-TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
+// Same as ClosedBox test, but input is reshuffled to time_major format.
+TEST_P(BidirectionalRNNOpTest, ClosedBoxTestTimeMajor) {
   auto params = GetParam();
   const bool quantize_weights = std::get<0>(params);
   const bool asymmetric_quantize_inputs = std::get<1>(params);
@@ -950,8 +950,8 @@
           fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
 }
 
-// Same as BlackBox test, yet with merged outputs.
-TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
+// Same as ClosedBox test, yet with merged outputs.
+TEST_P(BidirectionalRNNOpTest, ClosedBoxTestMergeOutputs) {
   auto params = GetParam();
   const bool quantize_weights = std::get<0>(params);
   const bool asymmetric_quantize_inputs = std::get<1>(params);
@@ -995,8 +995,8 @@
                   merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
 }
 
-// Same as BlackBox test, but input is reshuffled to time_major format.
-TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajorMergeOutputs) {
+// Same as ClosedBox test, but input is reshuffled to time_major format.
+TEST(BidirectionalRNNOpTest, ClosedBoxTestTimeMajorMergeOutputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/0,
@@ -1042,7 +1042,7 @@
 
 // Check that if the input sequence is reversed the outputs are the same just
 // forward and backward are swapped (and reversed).
-TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
+TEST(BidirectionalRNNOpTest, ClosedBoxTestReverseInputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/0,
@@ -1163,11 +1163,11 @@
   }
 }
 
-// Same as BlackBox test, but has an auxiliary input. The layer has no
+// Same as ClosedBox test, but has an auxiliary input. The layer has no
 // cross-linking, i.e. the regular input is passed as an input to the forward
 // network only and the auxiliary input is passed as an input to the backward
 // network only.
-TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingRegularAndAuxInput) {
+TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingRegularAndAuxInput) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/8,
@@ -1216,7 +1216,7 @@
 
 // Same as above but the auxiliary input is set to zeroes. This test makes sure
 // that the forward network works as expected in a no-cross-linking mode.
-TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingRegularInputOnly) {
+TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingRegularInputOnly) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/8,
@@ -1264,7 +1264,7 @@
 // Same as above but the regular (i.e. not auxiliary) input is set to zeroes.
 // This test makes sure that the backward network works as expected in a
 // no-cross-linking mode.
-TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingAuxInputOnly) {
+TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingAuxInputOnly) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/8,
@@ -1309,9 +1309,9 @@
   EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
 }
 
-// Same as BlackBox test, but an input is passed to auxiliary input instead of
+// Same as ClosedBox test, but an input is passed to auxiliary input instead of
 // the regular one. Regular input and weights are set to zero.
-TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnly) {
+TEST(BidirectionalRNNOpTest, ClosedBoxTestCrossLinkingAuxInputOnly) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/8,
@@ -1358,10 +1358,10 @@
   EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
 }
 
-// Same as BlackBox test, but an input is passed to auxiliary input instead of
+// Same as ClosedBox test, but an input is passed to auxiliary input instead of
 // the regular one. Regular input and weights are set to zero. Time major inputs
 // and outputs.
-TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnlyTimeMajor) {
+TEST(BidirectionalRNNOpTest, ClosedBoxTestCrossLinkingAuxInputOnlyTimeMajor) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
                               /*input_size=*/8, /*aux_input_size=*/8,
@@ -1408,7 +1408,7 @@
   EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
 }
 
-// Same as BlackBox test, but the input tensor and weights tensor are split
+// Same as ClosedBox test, but the input tensor and weights tensor are split
 // along the last dimension and passed to both regular and auxiliary inputs and
 // weights. The output in this case is the same. To understand this, let's
 // define W and V as regular input weights matrix and auxiliary input weights
@@ -1418,7 +1418,7 @@
 //   f(z) = Uz + b
 // is equivalent to:
 //   f((x^T|y^T)^T) = (Wx + Vy) + b.
-void run_blackbox_test_with_input_split(int input_size, int aux_input_size) {
+void run_closedbox_test_with_input_split(int input_size, int aux_input_size) {
   const int num_units = 16;
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/num_units, /*bw_units=*/num_units,
@@ -1498,14 +1498,14 @@
 }
 
 TEST(BidirectionalRNNOpTest,
-     BlackBoxTestCrossLinkingRegularAndAuxInputEvenSplit) {
-  run_blackbox_test_with_input_split(/*input_size=*/4, /*aux_input_size=*/4);
+     ClosedBoxTestCrossLinkingRegularAndAuxInputEvenSplit) {
+  run_closedbox_test_with_input_split(/*input_size=*/4, /*aux_input_size=*/4);
 }
 
 // Same as above but the input tensor and the weights tensor are split unevenly.
 TEST(BidirectionalRNNOpTest,
-     BlackBoxTestCrossLinkingRegularAndAuxInputUnevenSplit) {
-  run_blackbox_test_with_input_split(/*input_size=*/2, /*aux_input_size=*/6);
+     ClosedBoxTestCrossLinkingRegularAndAuxInputUnevenSplit) {
+  run_closedbox_test_with_input_split(/*input_size=*/2, /*aux_input_size=*/6);
 }
 
 }  // namespace
diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc
index 6eacb3d..c8d658e 100644
--- a/tensorflow/lite/kernels/cpu_backend_context.cc
+++ b/tensorflow/lite/kernels/cpu_backend_context.cc
@@ -24,6 +24,7 @@
 #include "public/gemmlowp.h"
 #include "ruy/context.h"  // from @ruy
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/macros.h"
 #include "tensorflow/lite/external_cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/op_macros.h"
@@ -35,7 +36,13 @@
 
 namespace tflite {
 
-#ifdef TFLITE_HAVE_CPUINFO
+// Use weak symbols if possible to dispatch to deprecated paths.
+#if TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
+extern TFLITE_ATTRIBUTE_WEAK bool UseGemmlowpOnX86();
+#endif  // TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
+
+// TODO(b/138922878) Enable when Ruy builds on Apple.
+#if defined(TFLITE_HAVE_CPUINFO) && !defined(__APPLE__)
 CpuBackendContext::CpuInfo::~CpuInfo() {
   if (init_status_ == InitStatus::kInitialized) {
     cpuinfo_deinitialize();
@@ -144,4 +151,15 @@
   return cpuinfo_.Avx() || cpuinfo_.Avx2Fma() || cpuinfo_.Avx512();
 }
 
+bool CpuBackendContext::PreferGemmlowpOnX86() {
+  bool use_gemmlowp_on_x86 = false;
+#if defined(TFLITE_X86_PLATFORM) && TFLITE_HAS_ATTRIBUTE_WEAK && \
+    !defined(__APPLE__)
+  if (::tflite::UseGemmlowpOnX86 != nullptr) {
+    use_gemmlowp_on_x86 = ::tflite::UseGemmlowpOnX86();
+  }
+#endif  // TFLITE_X86_PLATFORM && TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
+  return use_gemmlowp_on_x86 || !HasAvxOrAbove();
+}
+
 }  // namespace tflite
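
The UseGemmlowpOnX86 hook above relies on weak linkage: the declaration is weak, so its address is null unless some object file in the link provides a strong definition (which is what depending on deprecated_backends.cc does). A minimal standalone sketch of the pattern, assuming a GCC/Clang toolchain and illustrative names:

    #include <cstdio>

    // Weak declaration: if no other object file defines LegacyPathEnabled(),
    // its address resolves to null and the default below is used.
    extern bool LegacyPathEnabled() __attribute__((weak));

    bool PreferLegacyPath() {
      return (LegacyPathEnabled != nullptr) ? LegacyPathEnabled() : false;
    }

    int main() {
      std::printf("legacy path: %d\n", PreferLegacyPath() ? 1 : 0);
      // Linking in another file that defines
      //   bool LegacyPathEnabled() { return true; }
      // flips the answer to 1 without touching this translation unit.
    }
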
diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h
index e020717..eda2712 100644
--- a/tensorflow/lite/kernels/cpu_backend_context.h
+++ b/tensorflow/lite/kernels/cpu_backend_context.h
@@ -16,6 +16,11 @@
 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
 
+#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \
+     defined(_M_X64))
+#define TFLITE_X86_PLATFORM
+#endif
+
 #include <memory>
 
 #include "public/gemmlowp.h"
@@ -52,6 +57,10 @@
 
   bool HasAvxOrAbove();
 
+  // Gemmlowp on x86 is a deprecated path, but some clients may still opt
+  // into it through their link-time dependencies.
+  bool PreferGemmlowpOnX86();
+
  private:
   // Copy the wrapper class for cpuinfo from Ruy.
   class CpuInfo final {
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h
index 6950e18..9c687f6 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm.h
@@ -50,14 +50,7 @@
 //  ENABLED && (AVX
 //  or above available)
 
-
-#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \
-     defined(_M_X64))
-#define TFLITE_X86_PLATFORM
-#endif
-
-// TODO(b/168923364)  Set TFLITE_X86_RUY_ENABLED default 'on' when ready.
-#if defined(TFLITE_X86_PLATFORM) && defined(TFLITE_X86_RUY_ENABLED)
+#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
 /* GEMM dispatch implementation for x86.
  */
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
@@ -72,12 +65,10 @@
           typename DstScalar, QuantizationFlavor quantization_flavor>
 struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
                                            DstScalar, quantization_flavor> {};
-#endif
 
-#if !defined(TFLITE_WITH_RUY) && !defined(TFLITE_X86_RUY_ENABLED)
+#if !defined(TFLITE_WITH_RUY)
 
 /* Specializations using gemmlowp */
-
 template <typename SrcScalar, typename DstScalar,
           QuantizationFlavor quantization_flavor>
 struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
@@ -114,7 +105,9 @@
 struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
     : detail::GemmImplUsingEigen {};
 
-#endif  // not TFLITE_WITH_RUY && not TFLITE_X86_RUY_ENABLED
+#endif  // not TFLITE_WITH_RUY
+
+#endif  // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM
 
 /* Public entry point */
 
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
index 521e7bb..06bc7a0 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
@@ -297,7 +297,7 @@
 // done so far. Until that is done, the best that we can do is to search for
 // a good exponent value by trial-and-error. This is expensive, as each try
 // requires computing a whole GEMM. This is thus probably a major contribution
-// to the overall latency of this tesat. To partially mitigate that,
+// to the overall latency of this test. To partially mitigate that,
 // we use a bisection to reduce the required number of tries.
 //
 // This function is recursive. The bisect_min and bisect_max arguments
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_x86.h b/tensorflow/lite/kernels/cpu_backend_gemm_x86.h
index 20af953..39d37c7 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_x86.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_x86.h
@@ -41,25 +41,27 @@
       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
       CpuBackendContext* context) {
-    // Run-time dispatch to Ruy for platforms with AVX or above.
-    if (context->HasAvxOrAbove()) {
-      detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
-                               quantization_flavor>::Run(lhs_params, lhs_data,
-                                                         rhs_params, rhs_data,
-                                                         dst_params, dst_data,
-                                                         params, context);
-    } else {
-      // Dispatch to gemmlowp for SSE.
+    // TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated
+    // path is enabled.
+    if (context->PreferGemmlowpOnX86()) {
+      // Dispatch to gemmlowp.
       detail::GemmImplUsingGemmlowp<
           LhsScalar, RhsScalar, AccumScalar, DstScalar,
           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                     dst_params, dst_data, params, context);
+
+      return;
     }
+    // Run-time dispatch to Ruy for platforms with AVX or above.
+    detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                             quantization_flavor>::Run(lhs_params, lhs_data,
+                                                       rhs_params, rhs_data,
+                                                       dst_params, dst_data,
+                                                       params, context);
   }
 };
 
-// For float, again prefer Ruy in all cases, but defer to eigen if no flavor of
-// AVX is present.
+// For float, defer to Eigen for now.
 template <>
 struct GemmImplX86<float, float, float, float,
                    QuantizationFlavor::kFloatingPoint> {
@@ -69,19 +71,8 @@
                   const GemmParams<float, float,
                                    QuantizationFlavor::kFloatingPoint>& params,
                   CpuBackendContext* context) {
-    // Run-time dispatch to Ruy for platforms with AVX or above.
-    if (context->HasAvxOrAbove()) {
-      detail::GemmImplUsingRuy<
-          float, float, float, float,
-          QuantizationFlavor::kFloatingPoint>::Run(lhs_params, lhs_data,
-                                                   rhs_params, rhs_data,
-                                                   dst_params, dst_data, params,
-                                                   context);
-    } else {
-      // Dispatch to gemmlowp for SSE.
-      GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
-                              dst_params, dst_data, params, context);
-    }
+    GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
+                            dst_params, dst_data, params, context);
   }
 };
 
diff --git a/tensorflow/lite/kernels/densify.cc b/tensorflow/lite/kernels/densify.cc
index cd0c0a5..ca3eb77 100644
--- a/tensorflow/lite/kernels/densify.cc
+++ b/tensorflow/lite/kernels/densify.cc
@@ -80,21 +80,21 @@
                              GetTensorShape(op_context.input),
                              GetTensorData<float>(op_context.input),
                              GetTensorShape(op_context.output),
-                             GetTensorData<float>(op_context.output));
+                             GetTensorData<float>(op_context.output), context);
       break;
     case kTfLiteFloat16:
-      reference_ops::Densify(op_context.input->sparsity,
-                             GetTensorShape(op_context.input),
-                             GetTensorData<Eigen::half>(op_context.input),
-                             GetTensorShape(op_context.output),
-                             GetTensorData<Eigen::half>(op_context.output));
+      reference_ops::Densify(
+          op_context.input->sparsity, GetTensorShape(op_context.input),
+          GetTensorData<Eigen::half>(op_context.input),
+          GetTensorShape(op_context.output),
+          GetTensorData<Eigen::half>(op_context.output), context);
       break;
     case kTfLiteInt8:
       reference_ops::Densify(op_context.input->sparsity,
                              GetTensorShape(op_context.input),
                              GetTensorData<int8_t>(op_context.input),
                              GetTensorShape(op_context.output),
-                             GetTensorData<int8_t>(op_context.output));
+                             GetTensorData<int8_t>(op_context.output), context);
       break;
 
     default:
diff --git a/tensorflow/lite/java/src/main/native/op_resolver.h b/tensorflow/lite/kernels/deprecated_backends.cc
similarity index 67%
copy from tensorflow/lite/java/src/main/native/op_resolver.h
copy to tensorflow/lite/kernels/deprecated_backends.cc
index 08ff0ce..5688653 100644
--- a/tensorflow/lite/java/src/main/native/op_resolver.h
+++ b/tensorflow/lite/kernels/deprecated_backends.cc
@@ -12,17 +12,13 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
-#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
-
-#include <memory>
-
-#include "tensorflow/lite/op_resolver.h"
 
 namespace tflite {
 
-std::unique_ptr<OpResolver> CreateOpResolver();
+// Include this target as a dependency to define this function for
+// CpuBackendContext. It controls execution of deprecated paths by
+// providing a strong definition for the otherwise "weak" symbol
+// declaration in CpuBackendContext.
+extern bool UseGemmlowpOnX86() { return true; }
 
-}
-
-#endif  // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc
index 4f1040e..f746ad1 100644
--- a/tensorflow/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/lite/kernels/detection_postprocess.cc
@@ -382,9 +382,11 @@
 
 bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
   for (int i = 0; i < num_boxes; ++i) {
-    // ymax>=ymin, xmax>=xmin
     auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
-    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
+    // Note: `ComputeIntersectionOverUnion` properly handles degenerate boxes
+    // (xmin == xmax and/or ymin == ymax), since it simply returns 0 whenever
+    // the box area is <= 0.
+    if (box.ymin > box.ymax || box.xmin > box.xmax) {
       return false;
     }
   }
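
The relaxed check above now admits boxes that collapse to a line or a point. That is safe because an intersection-over-union of the usual form returns 0 whenever a box has non-positive area, so degenerate boxes never suppress anything. A standalone sketch of that property (illustrative code, not the kernel's ComputeIntersectionOverUnion):

    #include <algorithm>
    #include <iostream>

    struct Box { float ymin, xmin, ymax, xmax; };

    float Iou(const Box& a, const Box& b) {
      const float area_a = (a.ymax - a.ymin) * (a.xmax - a.xmin);
      const float area_b = (b.ymax - b.ymin) * (b.xmax - b.xmin);
      if (area_a <= 0.f || area_b <= 0.f) return 0.f;  // degenerate box: no overlap
      const float ih = std::max(0.f, std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin));
      const float iw = std::max(0.f, std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin));
      const float inter = ih * iw;
      return inter / (area_a + area_b - inter);
    }

    int main() {
      // A point-sized box (like anchor #2 in the test added below) scores 0
      // against any other box instead of producing a division by zero or a NaN.
      std::cout << Iou({0.f, 0.f, 1.f, 1.f}, {0.5f, 0.5f, 0.5f, 0.5f}) << "\n";
    }
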
diff --git a/tensorflow/lite/kernels/detection_postprocess_test.cc b/tensorflow/lite/kernels/detection_postprocess_test.cc
index b9c42e7..4f73098 100644
--- a/tensorflow/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/lite/kernels/detection_postprocess_test.cc
@@ -187,6 +187,70 @@
               ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
 }
 
+// Tests the case when a box degenerates to a point (xmin==xmax, ymin==ymax).
+TEST(DetectionPostprocessOpTest, FloatTestWithDegeneratedBox) {
+  BaseDetectionPostprocessOpModel m(
+      {TensorType_FLOAT32, {1, 2, 4}}, {TensorType_FLOAT32, {1, 2, 3}},
+      {TensorType_FLOAT32, {2, 4}}, {TensorType_FLOAT32, {}},
+      {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+      {TensorType_FLOAT32, {}});
+
+  // two boxes in center-size encoding
+  m.SetInput1<float>({
+      0.0, 0.0, 0.0, 0.0,  // box #1
+      0.0, 0.0, 0.0, 0.0,  // box #2
+  });
+  // class scores - two classes with background
+  m.SetInput2<float>({
+      /*background*/ 0., /*class 0*/ .9, /*class 1*/ .8,  // box #1
+      /*background*/ 0., /*class 0*/ .2, /*class 1*/ .7   // box #2
+  });
+  // two anchors in center-size encoding
+  m.SetInput3<float>({
+      0.5, 0.5, 1.0, 1.0,  // anchor #1
+      0.5, 0.5, 0.0, 0.0   // anchor #2 - DEGENERATED!
+  });
+  // Same boxes in box-corner encoding:
+  // { 0.0, 0.0, 1.0, 1.0,
+  //   0.5, 0.5, 0.5, 0.5} // DEGENERATED!
+  // NOTE: this is used instead of `m.Invoke()` to make sure the entire test
+  // gets aborted if an error occurs (which does not happen when e.g. ASSERT_EQ
+  // is used in such a helper function).
+  ASSERT_EQ(m.InvokeUnchecked(), kTfLiteOk);
+  // num_detections
+  std::vector<int> output_shape4 = m.GetOutputShape4();
+  EXPECT_THAT(output_shape4, ElementsAre(1));
+  const int num_detections = static_cast<int>(m.GetOutput4<float>()[0]);
+  EXPECT_EQ(num_detections, 2);
+  // detection_boxes
+  std::vector<int> output_shape1 = m.GetOutputShape1();
+  // NOTE: there are up to 3 detected boxes as per `max_detections` and
+  // `max_classes_per_detection` parameters. But since the actual number of
+  // detections is 2 (see above) only the top-2 results are tested
+  // here and below.
+  EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+  std::vector<float> detection_boxes = m.GetOutput1<float>();
+  detection_boxes.resize(num_detections * 4);
+  EXPECT_THAT(detection_boxes,
+              ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 1.0,   // box #1
+                                               0.5, 0.5, 0.5, 0.5},  // box #2
+                                              1e-1)));
+  // detection_classes
+  std::vector<int> output_shape2 = m.GetOutputShape2();
+  EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+  std::vector<float> detection_classes = m.GetOutput2<float>();
+  detection_classes.resize(num_detections);
+  EXPECT_THAT(detection_classes,
+              ElementsAreArray(ArrayFloatNear({0, 1}, 1e-4)));
+  // detection_scores
+  std::vector<int> output_shape3 = m.GetOutputShape3();
+  EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+  std::vector<float> detection_scores = m.GetOutput3<float>();
+  detection_scores.resize(num_detections);
+  EXPECT_THAT(detection_scores,
+              ElementsAreArray(ArrayFloatNear({0.9, 0.7}, 1e-4)));
+}
+
 TEST(DetectionPostprocessOpTest, QuantizedTest) {
   BaseDetectionPostprocessOpModel m(
       {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index b59dc0a..775fe8a 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -385,6 +385,10 @@
     hdrs = ["quantization_util.h"],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + micro_copts(),
+    linkopts = select({
+        "//tensorflow:windows": [],
+        "//conditions:default": ["-lm"],
+    }),
     deps = [
         ":compatibility",
         ":cppmath",
@@ -457,6 +461,7 @@
         "reference/depthwiseconv_float.h",
         "reference/depthwiseconv_uint8.h",
         "reference/dequantize.h",
+        "reference/div.h",
         "reference/fill.h",
         "reference/floor.h",
         "reference/fully_connected.h",
@@ -504,6 +509,13 @@
     }),
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
+    # We are disabling parse_headers for the tf_lite_static_memory build to
+    # allow it to be consistent with the OSS bazel build. See b/175817116
+    # for more details.
+    features = select({
+        ":tf_lite_static_memory": ["-parse_headers"],
+        "//conditions:default": [],
+    }),
     deps = [
         ":common",
         ":compatibility",
@@ -551,6 +563,7 @@
         "reference/depthwiseconv_float.h",
         "reference/depthwiseconv_uint8.h",
         "reference/dequantize.h",
+        "reference/div.h",
         "reference/fill.h",
         "reference/floor.h",
         "reference/fully_connected.h",
diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h
index 1e241ca..c433fc8 100644
--- a/tensorflow/lite/kernels/internal/common.h
+++ b/tensorflow/lite/kernels/internal/common.h
@@ -178,8 +178,12 @@
   // - input x is in the range -(1<<47) <= x < (1<<47)
   assert(quantized_multiplier >= 0);
   assert(shift >= -31 && shift < 8);
+  assert(x >= -(static_cast<int64_t>(1) << 47) &&
+         x < (static_cast<int64_t>(1) << 47));
 
-  int32_t reduced_multiplier = (quantized_multiplier + (1 << 15)) >> 16;
+  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
+                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
+                                   : 0x7FFF;
   int total_shift = 15 - shift;
   x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
   int32_t result = x >> total_shift;
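
The clamp above matters when quantized_multiplier is close to its int32 maximum: the old expression `(quantized_multiplier + (1 << 15)) >> 16` overflows 32-bit arithmetic and, even when evaluated exactly, rounds up to 0x8000, one past the 15-bit multiplier the routine expects. A quick standalone check of that boundary (plain C++, 64-bit arithmetic just to show the value):

    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      const int64_t qm = std::numeric_limits<int32_t>::max();
      // Done in int32 this addition would overflow; done exactly it yields
      // 0x8000 = 32768, which the clamp replaces with 0x7FFF.
      std::cout << ((qm + (1 << 15)) >> 16) << "\n";  // prints 32768
    }
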
@@ -297,10 +301,11 @@
         TfLiteRound(func(min + i * step + half_step) * 32768.0);
     double midpoint_err = midpoint_interp_val - midpoint_val;
     double bias = TfLiteRound(midpoint_err / 2.0);
-    table[i] = std::min(std::max(sample_val - bias, -32768.0), 32767.0);
+    table[i] = std::min<double>(std::max<double>(sample_val - bias, -32768.0),
+                                32767.0);
   }
-  table[num - 1] =
-      std::min(std::max(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
+  table[num - 1] = std::min<double>(
+      std::max<double>(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
 }
 
 // generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
@@ -325,10 +330,11 @@
         TfLiteRound(func(min + i * step + half_step) * 32768.0f);
     float midpoint_err = midpoint_interp_val - midpoint_val;
     float bias = TfLiteRound(midpoint_err / 2.0f);
-    table[i] = std::min(std::max(sample_val - bias, -32768.0f), 32767.0f);
+    table[i] = std::min<float>(std::max<float>(sample_val - bias, -32768.0f),
+                               32767.0f);
   }
-  table[num - 1] = std::min(
-      std::max(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
+  table[num - 1] = std::min<float>(
+      std::max<float>(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
 }
 
 // int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h
index f269650..1749513 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h
@@ -26,7 +26,6 @@
 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -145,11 +144,10 @@
   }
 }
 
-inline void AveragePool16(const PoolParams& params,
-                          const RuntimeShape& input_shape,
-                          const int8* input_data,
-                          const RuntimeShape& output_shape, int8* output_data) {
-  ruy::profiler::ScopeLabel label("AveragePool/8bitWith16bitAccumulator");
+inline void AveragePool(const PoolParams& params,
+                        const RuntimeShape& input_shape, const int8* input_data,
+                        const RuntimeShape& output_shape, int8* output_data) {
+  ruy::profiler::ScopeLabel label("AveragePool/8bitWith32bitAccumulator");
 
   // Here, and in other pooling ops, in order to maintain locality of reference,
   // to minimize some recalculations, and to load into NEON vector registers, we
@@ -171,7 +169,7 @@
   const int stride_height = params.stride_height;
   const int stride_width = params.stride_width;
 
-  int16 acc[kPoolingAccTrancheSize];
+  int32 acc[kPoolingAccTrancheSize];
   for (int batch = 0; batch < batches; ++batch) {
     // We proceed through the depth in tranches (see comment above). The
     // depth_base is the depth at the beginning of the tranche. The
@@ -207,24 +205,30 @@
               int channel = 0;
 #ifdef USE_NEON
               for (; channel <= tranche_depth - 16; channel += 16) {
-                int16x8_t acc_reg[2];
-                for (int i = 0; i < 2; i++) {
-                  acc_reg[i] = vld1q_s16(acc + channel + 8 * i);
-                }
+                int16x4_t acc_reg[4];
                 int8x16_t input_reg = vld1q_s8(input_channel_ptr);
                 input_channel_ptr += 16;
-                acc_reg[0] = vaddw_s8(acc_reg[0], vget_low_s8(input_reg));
-                acc_reg[1] = vaddw_s8(acc_reg[1], vget_high_s8(input_reg));
-                for (int i = 0; i < 2; i++) {
-                  vst1q_s16(acc + channel + 8 * i, acc_reg[i]);
+                acc_reg[0] = vget_low_s16(vmovl_s8(vget_low_s8(input_reg)));
+                acc_reg[1] = vget_high_s16(vmovl_s8(vget_low_s8(input_reg)));
+                acc_reg[2] = vget_low_s16(vmovl_s8(vget_high_s8(input_reg)));
+                acc_reg[3] = vget_high_s16(vmovl_s8(vget_high_s8(input_reg)));
+                for (int i = 0; i < 4; i++) {
+                  vst1q_s32(
+                      acc + channel + 4 * i,
+                      vaddw_s16(vld1q_s32(acc + channel + 4 * i), acc_reg[i]));
                 }
               }
               for (; channel <= tranche_depth - 8; channel += 8) {
-                int16x8_t acc_reg = vld1q_s16(acc + channel);
-                int8x8_t input_reg = vld1_s8(input_channel_ptr);
+                int16x4_t acc_reg[2];
+                int16x8_t input_reg = vmovl_s8(vld1_s8(input_channel_ptr));
                 input_channel_ptr += 8;
-                acc_reg = vaddw_s8(acc_reg, input_reg);
-                vst1q_s16(acc + channel, acc_reg);
+                acc_reg[0] = vget_low_s16(input_reg);
+                acc_reg[1] = vget_high_s16(input_reg);
+                for (int i = 0; i < 2; i++) {
+                  vst1q_s32(
+                      acc + channel + 4 * i,
+                      vaddw_s16(vld1q_s32(acc + channel + 4 * i), acc_reg[i]));
+                }
               }
 #endif
               for (; channel < tranche_depth; ++channel) {
@@ -237,24 +241,6 @@
                                                   out_x, depth_base);
           int channel = 0;
 #ifdef USE_NEON
-#define AVGPOOL_DIVIDING_BY(FILTER_COUNT)                                    \
-  if (filter_count == FILTER_COUNT) {                                        \
-    for (; channel <= tranche_depth - 8; channel += 8) {                     \
-      int16 buf[8];                                                          \
-      for (int i = 0; i < 8; i++) {                                          \
-        buf[i] = acc[channel + i] > 0                                        \
-                     ? (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT  \
-                     : (acc[channel + i] - FILTER_COUNT / 2) / FILTER_COUNT; \
-      }                                                                      \
-      int8x8_t buf8 = vqmovn_s16(vld1q_s16(buf));                            \
-      buf8 = vmin_s8(buf8, vdup_n_s8(params.quantized_activation_max));      \
-      buf8 = vmax_s8(buf8, vdup_n_s8(params.quantized_activation_min));      \
-      vst1_s8(output_ptr + channel, buf8);                                   \
-    }                                                                        \
-  }
-          AVGPOOL_DIVIDING_BY(9)
-          AVGPOOL_DIVIDING_BY(15)
-#undef AVGPOOL_DIVIDING_BY
           for (; channel <= tranche_depth - 8; channel += 8) {
             int16 buf[8];
             for (int i = 0; i < 8; i++) {
@@ -283,17 +269,6 @@
   }
 }
 
-inline void AveragePool(const PoolParams& params,
-                        const RuntimeShape& input_shape, const int8* input_data,
-                        const RuntimeShape& output_shape, int8* output_data) {
-  if (params.filter_height * params.filter_width > 16 * 16) {
-    reference_integer_ops::AveragePool(params, input_shape, input_data,
-                                       output_shape, output_data);
-  } else {
-    AveragePool16(params, input_shape, input_data, output_shape, output_data);
-  }
-}
-
 }  // namespace optimized_integer_ops
 }  // namespace tflite
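
Widening the accumulator from int16 to int32 above is what allows the 16-bit-specific AveragePool16 and its reference fallback for windows larger than 16x16 to be removed: with int8 inputs, any window larger than 16x16 can already exceed the int16 range in the worst case, while an int32 accumulator has headroom for any realistic pooling window. A quick standalone check:

    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      const int taps = 17 * 17;           // a window just past the old 16x16 cutoff
      const int worst_case = taps * 128;  // all inputs at the int8 extreme
      std::cout << worst_case << " > " << std::numeric_limits<int16_t>::max()
                << ", but fits easily in int32\n";  // 36992 > 32767
    }
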
 
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index cbe6251..41cc2ee 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -1275,8 +1275,6 @@
     gemm_input_data = im2col_data;
     gemm_input_shape = &im2col_shape;
   } else {
-    // TODO(aselle): We need to make sure to not send im2col if it is not
-    // needed.
     TFLITE_DCHECK(!im2col_data);
     gemm_input_data = input_data;
     gemm_input_shape = &input_shape;
@@ -7830,7 +7828,7 @@
   }
 }
 
-// TODO(alanchiao): see if we can reduce the number
+// TODO(b/173718660): see if we can reduce the number
 // of lines of code in branching without affecting latency.
 template <typename T>
 inline void Transpose3D(const TransposeParams& params,
diff --git a/tensorflow/lite/kernels/internal/quantization_util.cc b/tensorflow/lite/kernels/internal/quantization_util.cc
index cf431cf..ed0fe43 100644
--- a/tensorflow/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/lite/kernels/internal/quantization_util.cc
@@ -289,7 +289,7 @@
     input_beta_real_multiplier = (1ll << 31) - 1.0;
   }
 #else   // TFLITE_EMULATE_FLOAT
-  const double input_beta_real_multiplier = std::min(
+  const double input_beta_real_multiplier = std::min<double>(
       beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
 #endif  // TFLITE_EMULATE_FLOAT
 
diff --git a/tensorflow/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc
index 053b311..b14b7e5 100644
--- a/tensorflow/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc
@@ -422,6 +422,72 @@
   EXPECT_THAT(inv_sqrt(kInt32Max), Pair(189812531, 12));
 }
 
+TEST(QuantizationUtilTest, MultiplyByQuantizedMultiplierInt32) {
+  auto quant_and_multiply = [](int32_t x, double multiplier) {
+    int32_t quantized_multiplier;
+    int shift;
+    QuantizeMultiplier(multiplier, &quantized_multiplier, &shift);
+    return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
+  };
+
+  EXPECT_EQ(quant_and_multiply(0, 0.1), 0);
+  EXPECT_EQ(quant_and_multiply(1, 0), 0);
+  EXPECT_EQ(quant_and_multiply(10000, 0.00097656), 10);
+  EXPECT_EQ(quant_and_multiply(10000, -0.00097656), -10);
+  EXPECT_EQ(quant_and_multiply(-10000, 0.00097656), -10);
+  EXPECT_EQ(quant_and_multiply(-10000, -0.00097656), 10);
+  EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::min(), 0.00001),
+            -21475);
+  EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::min(), -0.00001),
+            21475);
+  EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::max(), 0.00001),
+            21475);
+  EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::max(), -0.00001),
+            -21475);
+
+  // Test with maximum possible x and quantized_multiplier
+  const int32_t x = std::numeric_limits<int32_t>::max();
+  const int32_t quantized_multiplier = std::numeric_limits<int32_t>::max();
+  const int shift = -3;
+  const int32_t expected = static_cast<int32_t>(
+      TfLiteRound(static_cast<int64_t>(x) * quantized_multiplier /
+                  static_cast<double>(1ll << (31 - shift))));
+  EXPECT_EQ(MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift),
+            expected);
+  EXPECT_EQ(MultiplyByQuantizedMultiplier(-x, quantized_multiplier, shift),
+            -expected);
+}
+
+TEST(QuantizationUtilTest, MultiplyByQuantizedMultiplierInt64) {
+  auto quant_and_multiply = [](int64_t x, double multiplier) {
+    int32_t quantized_multiplier;
+    int shift;
+    QuantizeMultiplier(multiplier, &quantized_multiplier, &shift);
+    return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
+  };
+
+  // Negative multipliers are not supported by the 64-bit
+  // MultiplyByQuantizedMultiplier, so only multipliers >= 0 are used here.
+  EXPECT_EQ(quant_and_multiply(0, 0.1), 0);
+  EXPECT_EQ(quant_and_multiply(1, 0), 0);
+  EXPECT_EQ(quant_and_multiply(10000, 0.00097656), 10);
+  EXPECT_EQ(quant_and_multiply(-10000, 0.00097656), -10);
+  EXPECT_EQ(quant_and_multiply(-(1ll << 47), 0.00001), -1407385600);
+  EXPECT_EQ(quant_and_multiply((1ll << 47) - 1, 0.00001), 1407385600);
+
+  // Test with maximum possible x and quantized_multiplier
+  const int64_t x = (1ll << 47) - 1;
+  const int32_t quantized_multiplier = std::numeric_limits<int32_t>::max();
+  const int shift = -31;
+  // The expected value is approximately
+  // 'x * quantized_multiplier / 2**(31 - shift)' ~= 65536; rounding error
+  // makes the actual result slightly smaller.
+  const int32_t expected = 65534;
+  EXPECT_EQ(MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift),
+            expected);
+  EXPECT_EQ(MultiplyByQuantizedMultiplier(-x, quantized_multiplier, shift),
+            -expected);
+}
+
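
As a rough check of the 65534 constant used in the test above: the exact product is x * quantized_multiplier / 2^62 = (2^47 - 1)(2^31 - 1) / 2^62, just under 2^16 = 65536, and the 16-bit reduced multiplier inside MultiplyByQuantizedMultiplier (clamped to 0x7FFF rather than the ideal 0x8000) costs roughly two more counts, landing on 65534. A standalone back-of-the-envelope computation:

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main() {
      const double x = std::ldexp(1.0, 47) - 1.0;   // (1ll << 47) - 1
      const double qm = 2147483647.0;               // int32 max
      // Ideal rescaling: x * qm / 2^(31 - shift) with shift = -31.
      std::cout << x * qm / std::ldexp(1.0, 62) << "\n";        // ~65535.99997
      // With the multiplier reduced to 15 bits and clamped at 0x7FFF:
      std::cout << std::floor((x * 32767.0 + std::ldexp(1.0, 45)) /
                              std::ldexp(1.0, 46)) << "\n";     // 65534
    }
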
 TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
   auto quantize = [](double beta, double scale, int integer_bits) {
     int32_t q;
diff --git a/tensorflow/lite/kernels/internal/reference/densify.h b/tensorflow/lite/kernels/internal/reference/densify.h
index 71a9a26..f5179ba 100644
--- a/tensorflow/lite/kernels/internal/reference/densify.h
+++ b/tensorflow/lite/kernels/internal/reference/densify.h
@@ -28,7 +28,8 @@
 template <typename T>
 inline void Densify(const TfLiteSparsity* sparsity,
                     const RuntimeShape& input_shape, const T* input_data,
-                    const RuntimeShape& output_shape, T* output_data) {
+                    const RuntimeShape& output_shape, T* output_data,
+                    TfLiteContext* context) {
   const int dims_count = output_shape.DimensionsCount();
   std::vector<int> vector_shape(dims_count);
   for (int i = 0; i < dims_count; i++) {
@@ -37,11 +38,8 @@
 
   tflite::optimize::sparsity::FormatConverter<T> converter(vector_shape,
                                                            *sparsity);
-  converter.SparseToDense(input_data);
-  const std::vector<T> out = converter.GetData();
-  for (int i = 0; i < out.size(); i++) {
-    output_data[i] = out[i];
-  }
+  converter.SparseToDense(input_data, output_shape.FlatSize(), output_data,
+                          context);
 }
 
 }  // namespace reference_ops
diff --git a/tensorflow/lite/kernels/internal/reference/div.h b/tensorflow/lite/kernels/internal/reference/div.h
new file mode 100644
index 0000000..bdd3ecc
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/div.h
@@ -0,0 +1,194 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
+
+#include <algorithm>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+// Element-wise div that can often be used for inner loop of broadcast Div as
+// well as the non-broadcast Div.
+inline void DivElementwise(int size, const ArithmeticParams& params,
+                           const uint8* input1_data, const uint8* input2_data,
+                           uint8* output_data) {
+  TFLITE_DCHECK_GT(params.input1_offset, -256);
+  TFLITE_DCHECK_LT(params.input1_offset, 256);
+  TFLITE_DCHECK_GT(params.input2_offset, -256);
+  TFLITE_DCHECK_LT(params.input2_offset, 256);
+  TFLITE_DCHECK_GT(params.output_offset, -256);
+  TFLITE_DCHECK_LT(params.output_offset, 256);
+
+  for (int i = 0; i < size; ++i) {
+    const int32 input1_val = params.input1_offset + input1_data[i];
+    const int32 input2_val = params.input2_offset + input2_data[i];
+    TFLITE_DCHECK_NE(input2_val, 0);
+    int recip_shift;
+    const int32 input2_inv =
+        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
+                         : -GetReciprocal(-input2_val, 31, &recip_shift);
+    const int headroom = CountLeadingSignBits(input1_val);
+    const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
+        input1_val, input2_inv, headroom);
+    const int total_shift = params.output_shift - recip_shift - headroom;
+    const int32 unclamped_result =
+        params.output_offset +
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            unscaled_quotient, params.output_multiplier, total_shift);
+    const int32 clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, unclamped_result));
+    output_data[i] = static_cast<uint8>(clamped_output);
+  }
+}
+
+inline void Div(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const uint8* input1_data,
+                const RuntimeShape& input2_shape, const uint8* input2_data,
+                const RuntimeShape& output_shape, uint8* output_data) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+
+  DivElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+template <int N = 5>
+inline void BroadcastDivSlow(const ArithmeticParams& params,
+                             const RuntimeShape& unextended_input1_shape,
+                             const uint8* input1_data,
+                             const RuntimeShape& unextended_input2_shape,
+                             const uint8* input2_data,
+                             const RuntimeShape& unextended_output_shape,
+                             uint8* output_data) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+                 &output_desc);
+
+  TFLITE_DCHECK_GT(params.input1_offset, -256);
+  TFLITE_DCHECK_LT(params.input1_offset, 256);
+  TFLITE_DCHECK_GT(params.input2_offset, -256);
+  TFLITE_DCHECK_LT(params.input2_offset, 256);
+  TFLITE_DCHECK_GT(params.output_offset, -256);
+  TFLITE_DCHECK_LT(params.output_offset, 256);
+
+  auto div_func = [&](int indexes[N]) {
+    const int32 input1_val =
+        params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
+    const int32 input2_val =
+        params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
+    TFLITE_DCHECK_NE(input2_val, 0);
+    int recip_shift;
+    const int32 input2_inv =
+        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
+                         : -GetReciprocal(-input2_val, 31, &recip_shift);
+    const int headroom = CountLeadingSignBits(input1_val);
+    const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
+        input1_val, input2_inv, headroom);
+    const int total_shift = params.output_shift - recip_shift - headroom;
+    const int32 unclamped_result =
+        params.output_offset +
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            unscaled_quotient, params.output_multiplier, total_shift);
+    const int32 clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, unclamped_result));
+    output_data[SubscriptToIndex(output_desc, indexes)] =
+        static_cast<uint8>(clamped_output);
+  };
+  NDOpsHelper<N>(output_desc, div_func);
+}
+
+// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <typename T, int N = 5>
+void BroadcastDivSlow(const ArithmeticParams& params,
+                      const RuntimeShape& unextended_input1_shape,
+                      const T* input1_data,
+                      const RuntimeShape& unextended_input2_shape,
+                      const T* input2_data,
+                      const RuntimeShape& unextended_output_shape,
+                      T* output_data) {
+  T output_activation_min;
+  T output_activation_max;
+  GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+  NdArrayDesc<N> desc1;
+  NdArrayDesc<N> desc2;
+  NdArrayDesc<N> output_desc;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+                 &output_desc);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+
+  auto div_func = [&](int indexes[N]) {
+    output_data[SubscriptToIndex(output_desc, indexes)] =
+        ActivationFunctionWithMinMax(
+            input1_data[SubscriptToIndex(desc1, indexes)] /
+                input2_data[SubscriptToIndex(desc2, indexes)],
+            output_activation_min, output_activation_max);
+  };
+  NDOpsHelper<N>(output_desc, div_func);
+}
+
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const T* input1_data,
+                const RuntimeShape& input2_shape, const T* input2_data,
+                const RuntimeShape& output_shape, T* output_data) {
+  T output_activation_min;
+  T output_activation_max;
+  GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = ActivationFunctionWithMinMax(
+        input1_data[i] / input2_data[i], output_activation_min,
+        output_activation_max);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DIV_H_
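
For orientation, the uint8 routines in the new header above perform the division directly in the quantized domain: input2 is inverted with a 31-bit fixed-point reciprocal (GetReciprocal), the product with input1 is rescaled through output_multiplier/output_shift, and the offsets shift everything back into uint8 range. Under the usual affine model real = scale * (q - zero_point), with the params' offsets taken as the negated zero points, the float computation it approximates is roughly the following sketch (for intuition only, not part of the kernel):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    uint8_t DivReference(uint8_t q1, uint8_t q2, float scale1, int32_t zp1,
                         float scale2, int32_t zp2, float out_scale,
                         int32_t out_zp) {
      const float real1 = scale1 * (static_cast<int32_t>(q1) - zp1);
      const float real2 = scale2 * (static_cast<int32_t>(q2) - zp2);
      const float real_out = real1 / real2;  // the kernel DCHECKs q2 != zp2
      const int32_t q =
          static_cast<int32_t>(std::round(real_out / out_scale)) + out_zp;
      return static_cast<uint8_t>(std::min(255, std::max(0, q)));
    }
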
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index afbb717..0b70440 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -40,6 +40,7 @@
 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
 #include "tensorflow/lite/kernels/internal/reference/conv.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
+#include "tensorflow/lite/kernels/internal/reference/div.h"
 #include "tensorflow/lite/kernels/internal/reference/fill.h"
 #include "tensorflow/lite/kernels/internal/reference/floor.h"
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
@@ -420,172 +421,6 @@
   }
 }
 
-// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-template <typename T, int N = 5>
-void BroadcastDivSlow(const ArithmeticParams& params,
-                      const RuntimeShape& unextended_input1_shape,
-                      const T* input1_data,
-                      const RuntimeShape& unextended_input2_shape,
-                      const T* input2_data,
-                      const RuntimeShape& unextended_output_shape,
-                      T* output_data) {
-  T output_activation_min;
-  T output_activation_max;
-  GetActivationParams(params, &output_activation_min, &output_activation_max);
-
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
-
-  NdArrayDesc<N> desc1;
-  NdArrayDesc<N> desc2;
-  NdArrayDesc<N> output_desc;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
-                 &output_desc);
-
-  // In Tensorflow, the dimensions are canonically named (batch_number, row,
-  // col, channel), with extents (batches, height, width, depth), with the
-  // trailing dimension changing most rapidly (channels has the smallest
-  // stride, typically 1 element).
-  //
-  // In generated C code, we store arrays with the dimensions reversed. The
-  // first dimension has smallest stride.
-
-  auto div_func = [&](int indexes[N]) {
-    output_data[SubscriptToIndex(output_desc, indexes)] =
-        ActivationFunctionWithMinMax(
-            input1_data[SubscriptToIndex(desc1, indexes)] /
-                input2_data[SubscriptToIndex(desc2, indexes)],
-            output_activation_min, output_activation_max);
-  };
-  NDOpsHelper<N>(output_desc, div_func);
-}
-
-template <typename T>
-inline void Div(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const T* input1_data,
-                const RuntimeShape& input2_shape, const T* input2_data,
-                const RuntimeShape& output_shape, T* output_data) {
-  T output_activation_min;
-  T output_activation_max;
-  GetActivationParams(params, &output_activation_min, &output_activation_max);
-
-  const int flat_size =
-      MatchingElementsSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    output_data[i] = ActivationFunctionWithMinMax(
-        input1_data[i] / input2_data[i], output_activation_min,
-        output_activation_max);
-  }
-}
-
-// Element-wise div that can often be used for inner loop of broadcast Div as
-// well as the non-broadcast Div.
-inline void DivElementwise(int size, const ArithmeticParams& params,
-                           const uint8* input1_data, const uint8* input2_data,
-                           uint8* output_data) {
-  TFLITE_DCHECK_GT(params.input1_offset, -256);
-  TFLITE_DCHECK_LT(params.input1_offset, 256);
-  TFLITE_DCHECK_GT(params.input2_offset, -256);
-  TFLITE_DCHECK_LT(params.input2_offset, 256);
-  TFLITE_DCHECK_GT(params.output_offset, -256);
-  TFLITE_DCHECK_LT(params.output_offset, 256);
-
-  for (int i = 0; i < size; ++i) {
-    const int32 input1_val = params.input1_offset + input1_data[i];
-    const int32 input2_val = params.input2_offset + input2_data[i];
-    TFLITE_DCHECK_NE(input2_val, 0);
-    int recip_shift;
-    const int32 input2_inv =
-        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
-                         : -GetReciprocal(-input2_val, 31, &recip_shift);
-    const int headroom = CountLeadingSignBits(input1_val);
-    const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
-        input1_val, input2_inv, headroom);
-    const int total_shift = params.output_shift - recip_shift - headroom;
-    const int32 unclamped_result =
-        params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            unscaled_quotient, params.output_multiplier, total_shift);
-    const int32 clamped_output =
-        std::min(params.quantized_activation_max,
-                 std::max(params.quantized_activation_min, unclamped_result));
-    output_data[i] = static_cast<uint8>(clamped_output);
-  }
-}
-
-inline void Div(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const uint8* input1_data,
-                const RuntimeShape& input2_shape, const uint8* input2_data,
-                const RuntimeShape& output_shape, uint8* output_data) {
-  TFLITE_DCHECK_LE(params.quantized_activation_min,
-                   params.quantized_activation_max);
-  ruy::profiler::ScopeLabel label("Div/8bit");
-  const int flat_size =
-      MatchingElementsSize(input1_shape, input2_shape, output_shape);
-
-  DivElementwise(flat_size, params, input1_data, input2_data, output_data);
-}
-
-template <int N = 5>
-inline void BroadcastDivSlow(const ArithmeticParams& params,
-                             const RuntimeShape& unextended_input1_shape,
-                             const uint8* input1_data,
-                             const RuntimeShape& unextended_input2_shape,
-                             const uint8* input2_data,
-                             const RuntimeShape& unextended_output_shape,
-                             uint8* output_data) {
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
-
-  NdArrayDesc<N> desc1;
-  NdArrayDesc<N> desc2;
-  NdArrayDesc<N> output_desc;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
-                 &output_desc);
-
-  TFLITE_DCHECK_GT(params.input1_offset, -256);
-  TFLITE_DCHECK_LT(params.input1_offset, 256);
-  TFLITE_DCHECK_GT(params.input2_offset, -256);
-  TFLITE_DCHECK_LT(params.input2_offset, 256);
-  TFLITE_DCHECK_GT(params.output_offset, -256);
-  TFLITE_DCHECK_LT(params.output_offset, 256);
-
-  auto div_func = [&](int indexes[N]) {
-    const int32 input1_val =
-        params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
-    const int32 input2_val =
-        params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
-    TFLITE_DCHECK_NE(input2_val, 0);
-    int recip_shift;
-    const int32 input2_inv =
-        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
-                         : -GetReciprocal(-input2_val, 31, &recip_shift);
-    const int headroom = CountLeadingSignBits(input1_val);
-    const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
-        input1_val, input2_inv, headroom);
-    const int total_shift = params.output_shift - recip_shift - headroom;
-    const int32 unclamped_result =
-        params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            unscaled_quotient, params.output_multiplier, total_shift);
-    const int32 clamped_output =
-        std::min(params.quantized_activation_max,
-                 std::max(params.quantized_activation_min, unclamped_result));
-    output_data[SubscriptToIndex(output_desc, indexes)] =
-        static_cast<uint8>(clamped_output);
-  };
-  NDOpsHelper<N>(output_desc, div_func);
-}
-
 inline void Sub16(const ArithmeticParams& params,
                   const RuntimeShape& input1_shape, const int16_t* input1_data,
                   const RuntimeShape& input2_shape, const int16_t* input2_data,
diff --git a/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h
index 0f8a248..81131dd 100644
--- a/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h
+++ b/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h
@@ -35,7 +35,7 @@
   tflite::optimize::sparsity::FormatConverter<float> converter(
       weights_shape_vector, sparsity);
   converter.SparseToDense(weights_data);
-  const std::vector<float> dense_weights_data = converter.GetData();
+  const std::vector<float>& dense_weights_data = converter.GetData();
   FullyConnected(params, input_shape, input_data, weights_shape,
                  dense_weights_data.data(), bias_shape, bias_data, output_shape,
                  output_data);
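
The one-line change above avoids copying the dense weights: assuming
`FormatConverter<float>::GetData()` returns a reference to internal storage,
binding the result to a `const std::vector<float>&` aliases that storage
instead of duplicating it. A minimal sketch with a hypothetical accessor:

#include <vector>

struct Converter {
  // Hypothetical accessor returning a reference, as GetData() is assumed to.
  const std::vector<float>& GetData() const { return data_; }
  std::vector<float> data_ = std::vector<float>(1 << 20, 0.f);
};

int main() {
  Converter c;
  const std::vector<float> copy = c.GetData();    // old form: copies ~1M floats
  const std::vector<float>& alias = c.GetData();  // new form: no copy
  return copy.size() == alias.size() ? 0 : 1;
}
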
diff --git a/tensorflow/lite/kernels/lsh_projection_test.cc b/tensorflow/lite/kernels/lsh_projection_test.cc
index 008a5c4..a716d6c 100644
--- a/tensorflow/lite/kernels/lsh_projection_test.cc
+++ b/tensorflow/lite/kernels/lsh_projection_test.cc
@@ -87,7 +87,13 @@
 
   m.Invoke();
 
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  // The hash returns different values on machines with different endianness.
+  EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 1, 1, 1, 0));
+#else
   EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
+#endif
 }
 
 TEST(LSHProjectionOpTest2, Sparse1DInputs) {
@@ -98,7 +104,13 @@
 
   m.Invoke();
 
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  // The hash returns different values on machines with different endianness.
+  EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
+#else
   EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
+#endif
 }
 
 TEST(LSHProjectionOpTest2, Sparse3DInputs) {
@@ -111,7 +123,13 @@
 
   m.Invoke();
 
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  // The hash returns different values on machines with different endianness.
+  EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
+#else
   EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
+#endif
 }
 
 }  // namespace
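
These test changes pick different golden values at compile time because the
hash behind LSH_PROJECTION operates on raw bytes, so its output depends on the
host's byte order. A minimal sketch of the same preprocessor check outside the
test:

#include <cstdio>

int main() {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  std::puts("big-endian host: use the big-endian golden values");
#else
  std::puts("little-endian host: use the default golden values");
#endif
  return 0;
}
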
diff --git a/tensorflow/lite/kernels/numeric_verify.cc b/tensorflow/lite/kernels/numeric_verify.cc
index 5b4011f..ce1e491 100644
--- a/tensorflow/lite/kernels/numeric_verify.cc
+++ b/tensorflow/lite/kernels/numeric_verify.cc
@@ -21,6 +21,7 @@
 #include <numeric>
 #include <vector>
 
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/dequantize.h"
 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
@@ -36,13 +37,19 @@
 namespace custom {
 namespace numeric_verify {
 
+static constexpr const char kToleranceStr[] = "tolerance";
+static constexpr const char kDebugModeStr[] = "debug_mode";
+static constexpr const int kTemporaryDequantizedTensor = 0;
+
 struct OpContext {
   OpContext(TfLiteContext* context, TfLiteNode* node) {
     input = GetInput(context, node, 0);
     ref = GetInput(context, node, 1);
+    output = GetOutput(context, node, 0);
   }
   const TfLiteTensor* input;
   const TfLiteTensor* ref;
+  TfLiteTensor* output;
 };
 
 const int kTensorNotAllocated = -1;
@@ -50,21 +57,23 @@
 struct OpData {
   // The percentage of the tensor value range. Must be a number less than 1.0.
   float tolerance;
-  // The abstract value allowed for the floating-point value difference.
-  float max_diff;
   // This boolean value is only used when the input tensor is constant.
   bool float_input_initialized;
   int cache_tensor_id = kTensorNotAllocated;
+  // Whether the op runs in debug mode (emit diffs instead of reporting errors).
+  bool debug_mode;
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* op_data = new OpData();
   op_data->float_input_initialized = false;
 
-  // Get the tolerance parameter from the buffer. Use flexbuffers asMap if there
-  // multiple custom options.
-  const float* buffer_t = reinterpret_cast<const float*>(buffer);
-  op_data->tolerance = *buffer_t;
+  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+  const float tolerance = m[kToleranceStr].AsFloat();
+  const bool debug_mode = m[kDebugModeStr].AsBool();
+  op_data->tolerance = tolerance;
+  op_data->debug_mode = debug_mode;
 
   return op_data;
 }
@@ -75,30 +84,19 @@
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
 
   OpContext op_context(context, node);
 
+  const int num_output = (op_data->debug_mode) ? 1 : 0;
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), num_output);
+
   TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
                               op_context.input->type == kTfLiteInt16 ||
                               op_context.input->type == kTfLiteFloat16);
   TF_LITE_ENSURE(context, op_context.ref->type == kTfLiteFloat32);
 
-  op_data->max_diff = op_data->tolerance * op_context.input->params.scale;
-  switch (op_context.input->type) {
-    case kTfLiteUInt8:
-    case kTfLiteInt8:
-      op_data->max_diff *= (1 << 8);
-      break;
-    case kTfLiteInt16:
-      op_data->max_diff *= (1 << 16);
-      break;
-    default:
-      break;
-  }
-
   // Allocate tensor to store the dequantized inputs.
   if (op_data->cache_tensor_id == kTensorNotAllocated) {
     TF_LITE_ENSURE_OK(
@@ -111,7 +109,8 @@
 
   TfLiteTensor* dequantized;
   TF_LITE_ENSURE_OK(context,
-                    GetTemporarySafe(context, node, /*index=*/0, &dequantized));
+                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
+                                     &dequantized));
   dequantized->type = op_context.ref->type;
   dequantized->allocation_type = kTfLiteDynamic;
 
@@ -119,6 +118,14 @@
                                  context, dequantized,
                                  TfLiteIntArrayCopy(op_context.input->dims)));
 
+  if (op_data->debug_mode) {
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
+                                             &op_context.output));
+    op_context.output->type = kTfLiteFloat32;
+    op_context.output->allocation_type = kTfLiteArenaRwPersistent;
+    return context->ResizeTensor(context, op_context.output,
+                                 TfLiteIntArrayCopy(op_context.input->dims));
+  }
   return kTfLiteOk;
 }
 
@@ -146,7 +153,8 @@
   // Dequantize the input
   TfLiteTensor* dequantized;
   TF_LITE_ENSURE_OK(context,
-                    GetTemporarySafe(context, node, /*index=*/0, &dequantized));
+                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
+                                     &dequantized));
   auto status = builtin::dequantize::DequantizeImpl<kernel_type>(
       context, node, op_context.input, dequantized);
   if (status != kTfLiteOk) {
@@ -157,15 +165,32 @@
     op_data->float_input_initialized = true;
   }
 
-  // If the tolerance is very small, we only display the stats of the diff.
-  if (op_data->tolerance < 0.1) {
+  // If debug_mode is on, we don't report any errors. We just compute the
+  // difference between the dequantized and reference float values and let the
+  // Python debugger deal with the information.
+  if (op_data->debug_mode || op_data->tolerance < 0.1) {
+    const int num_output = (op_data->debug_mode) ? 1 : 0;
+    const int n = NumElements(dequantized);
+    if (op_data->debug_mode) {
+      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
+                                               &op_context.output));
+      auto output_data = GetTensorData<float>(op_context.output);
+      for (int i = 0; i < n; ++i) {
+        float dequant = GetTensorData<float>(dequantized)[i];
+        float reference = GetTensorData<float>(op_context.ref)[i];
+        output_data[i] = dequant - reference;
+      }
+    }
+    // This statistics logging was added to help identify errors in practice.
     std::vector<double> diffs, temp;
-    diffs.reserve(NumElements(dequantized));
-    temp.reserve(NumElements(dequantized));
-    for (int i = 0; i < NumElements(op_context.ref); ++i) {
+    diffs.reserve(n);
+    temp.reserve(n);
+    diffs.resize(n);
+    temp.resize(n);
+    for (int i = 0; i < n; ++i) {
       float dequant = GetTensorData<float>(dequantized)[i];
       float reference = GetTensorData<float>(op_context.ref)[i];
-      diffs.push_back(dequant - reference);
+      diffs[i] = static_cast<double>(dequant - reference);
     }
     double mean =
         std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size();
@@ -184,24 +209,24 @@
         mean, max_diff, op_context.input->params.scale,
         op_context.input->params.zero_point);
     return kTfLiteOk;
-  }
-
-  // Verify the dequantized output.
-  auto max_diff = op_data->tolerance * op_context.input->params.scale;
-  for (int i = 0; i < NumElements(op_context.ref); ++i) {
-    int32_t value = GetQuantizedValue(op_context, i);
-    float dequant = GetTensorData<float>(dequantized)[i];
-    float reference = GetTensorData<float>(op_context.ref)[i];
-    float diff = std::abs(reference - dequant);
-    if (diff > max_diff) {
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Mismatch: %f is quantized to %d with (%f, %d). "
-          "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
-          reference, value, op_context.input->params.scale,
-          op_context.input->params.zero_point, reference, dequant, diff,
-          max_diff, op_data->tolerance);
-      return kTfLiteError;
+  } else {
+    // Verify the dequantized output.
+    auto max_diff = op_data->tolerance * op_context.input->params.scale;
+    for (int i = 0; i < NumElements(op_context.ref); ++i) {
+      int32_t value = GetQuantizedValue(op_context, i);
+      float dequant = GetTensorData<float>(dequantized)[i];
+      float reference = GetTensorData<float>(op_context.ref)[i];
+      float diff = std::abs(reference - dequant);
+      if (diff > max_diff) {
+        TF_LITE_KERNEL_LOG(
+            context,
+            "Mismatch: %f is quantized to %d with (%f, %d). "
+            "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
+            reference, value, op_context.input->params.scale,
+            op_context.input->params.zero_point, reference, dequant, diff,
+            max_diff, op_data->tolerance);
+        return kTfLiteError;
+      }
     }
   }
   return kTfLiteOk;
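
The rewritten `Init` above no longer reinterprets the custom-options buffer as
a single float; it decodes a flexbuffers map keyed by "tolerance" and
"debug_mode". A minimal round-trip sketch of that encoding, using the same
flexbuffers API (the values here are made up):

#include <cstdio>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers

int main() {
  // Writer side: what the converter or test harness serializes.
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Float("tolerance", 5.0f);
    fbb.Bool("debug_mode", true);
  });
  fbb.Finish();

  // Reader side: what numeric_verify::Init does with `buffer` and `length`.
  const std::vector<uint8_t>& buf = fbb.GetBuffer();
  const auto m = flexbuffers::GetRoot(buf.data(), buf.size()).AsMap();
  std::printf("tolerance=%f debug_mode=%d\n", m["tolerance"].AsFloat(),
              static_cast<int>(m["debug_mode"].AsBool()));
  return 0;
}
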
diff --git a/tensorflow/lite/kernels/numeric_verify_test.cc b/tensorflow/lite/kernels/numeric_verify_test.cc
index 9fb2e55..e26f560 100644
--- a/tensorflow/lite/kernels/numeric_verify_test.cc
+++ b/tensorflow/lite/kernels/numeric_verify_test.cc
@@ -21,8 +21,11 @@
 #include <gtest/gtest.h>
 #include "absl/memory/memory.h"
 #include "third_party/eigen3/Eigen/Core"
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/test_util.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
@@ -42,15 +45,25 @@
  public:
   NumericVerifyOpModel(TensorType type, std::initializer_list<int> shape,
                        float scale, int32_t zero_point, int version,
-                       float tolerance = 5.0) {
+                       float tolerance = 5.0, bool debug_mode = false) {
     const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
     input_ = AddInput(input_tensor_data);
     ref_ = AddInput({TensorType_FLOAT32, shape});
+    if (debug_mode) {
+      // The output tensor has the same shape as the input tensor.
+      output_ = AddOutput({TensorType_FLOAT32, shape});
+    }
 
     std::vector<uint8_t> custom_options(sizeof(float));
-    memcpy(custom_options.data(), &tolerance, sizeof(float));
 
-    SetCustomOp("NUMERIC_VERIFY", custom_options,
+    flexbuffers::Builder fbb;
+    fbb.Map([&]() {
+      fbb.Float("tolerance", tolerance);
+      fbb.Bool("debug_mode", debug_mode);
+    });
+    fbb.Finish();
+
+    SetCustomOp("NUMERIC_VERIFY", fbb.GetBuffer(),
                 ops::custom::Register_NUMERIC_VERIFY);
 
     BuildInterpreter({GetShape(input_), GetShape(ref_)});
@@ -63,9 +76,12 @@
     PopulateTensor(ref_, ref_data);
   }
 
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
  private:
   int input_;
   int ref_;
+  int output_;
 };
 
 TEST(NumericVerifyOpTest, Uint8) {
@@ -117,5 +133,18 @@
   EXPECT_EQ(m.InvokeUnchecked(), kTfLiteError);
 }
 
+TEST(NumericVerifyOpDebugModeTest, Int8) {
+  // [-63.5, 64] -> scale=0.5, zero_point=-1 for INT8
+  NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, true);
+
+  // The element at index 5 of the quantized input is set to 0.
+  m.SetInputs<int8_t>({-128, -127, -126, -125, -124, 0, 124, 125, 126, 127},
+                      {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64});
+  EXPECT_EQ(m.InvokeUnchecked(), kTfLiteOk);
+  // The element at index 5 has discrepancy -61.5
+  // (= dequantized 0.5 - reference 62).
+  EXPECT_THAT(
+      m.GetOutput(),
+      ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, -61.5, 0, 0, 0, 0})));
+}
 }  // namespace
 }  // namespace tflite
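
As a quick check of the debug-mode expectation above: with scale 0.5 and
zero_point -1, the quantized value 0 dequantizes to (0 - (-1)) * 0.5 = 0.5,
the reference at that position is 62, and the stored difference is
0.5 - 62 = -61.5, which is the -61.5 in the expected output. The same
arithmetic in a few lines:

#include <cstdio>

int main() {
  const float scale = 0.5f;
  const int zero_point = -1;
  const int quantized = 0;       // input element at index 5
  const float reference = 62.f;  // reference element at index 5
  const float dequant = (quantized - zero_point) * scale;  // 0.5
  std::printf("diff = %g\n", dequant - reference);         // -61.5
  return 0;
}
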
diff --git a/tensorflow/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc
index e522ae0..bd68c46 100644
--- a/tensorflow/lite/kernels/pad.cc
+++ b/tensorflow/lite/kernels/pad.cc
@@ -118,7 +118,7 @@
                             op_context.constant_values->type);
   }
 
-  // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
+  // Ensure we do not exceed maximum dimension count.
   TF_LITE_ENSURE(
       context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
 
diff --git a/tensorflow/lite/kernels/parse_example/BUILD b/tensorflow/lite/kernels/parse_example/BUILD
new file mode 100644
index 0000000..af21fe5
--- /dev/null
+++ b/tensorflow/lite/kernels/parse_example/BUILD
@@ -0,0 +1,76 @@
+# Kernel for custom parse_example
+package(
+    default_visibility = [
+        "//visibility:public",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "parse_example",
+    srcs = [
+        "example_proto_fast_parsing.cc",
+        "parse_example.cc",
+    ],
+    hdrs = [
+        "example_proto_fast_parsing.h",
+        "parse_example.h",
+    ],
+    deps = [
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@flatbuffers",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:kernel_util",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "//tensorflow/lite:string_util",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//tensorflow:ios": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:core_cpu",
+            "//tensorflow/core:feature_util",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:framework_internal",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:lib_internal",
+            "//tensorflow/core:protos_all_cc",
+        ],
+    }),
+)
+
+cc_test(
+    name = "parse_example_test",
+    srcs = ["parse_example_test.cc"],
+    tags = ["no_mac"],  # TODO(b/176113117): Fails to load shared object
+    deps = [
+        ":parse_example",
+        "@flatbuffers",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api:op_resolver",
+        "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/kernels:test_main",
+        "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite:string_util",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//tensorflow:ios": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:protos_all_cc",
+            "//tensorflow/core/example:feature_util",
+            "//tensorflow/core/platform:protobuf",
+            "//tensorflow/core/platform:tstring",
+        ],
+    }),
+)
diff --git a/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc
new file mode 100644
index 0000000..5490963
--- /dev/null
+++ b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc
@@ -0,0 +1,170 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
+
+namespace tensorflow {
+namespace example {
+
+string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
+  return example_names.empty() ? "<unknown>" : example_names[n];
+}
+
+void CountSparseFeatures(
+    const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
+    size_t* total_num_features, size_t* max_num_features) {
+  for (auto& sparse_values_tmp : sparse_buffers) {
+    const std::vector<size_t>& end_indices =
+        sparse_values_tmp[d].example_end_indices;
+    *total_num_features += end_indices.back();
+    *max_num_features = std::max(*max_num_features, end_indices[0]);
+    for (size_t i = 1; i < end_indices.size(); ++i) {
+      size_t example_size = end_indices[i] - end_indices[i - 1];
+      *max_num_features = std::max(*max_num_features, example_size);
+    }
+  }
+}
+
+void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
+                              Tensor* dst) {
+  switch (dtype) {
+    case DT_INT64: {
+      std::copy(src->int64_list.begin(), src->int64_list.end(),
+                dst->flat<int64>().data() + offset);
+      break;
+    }
+    case DT_FLOAT: {
+      std::copy(src->float_list.begin(), src->float_list.end(),
+                dst->flat<float>().data() + offset);
+      break;
+    }
+    case DT_STRING: {
+      std::move(src->bytes_list.begin(), src->bytes_list.end(),
+                dst->flat<tstring>().data() + offset);
+      break;
+    }
+    default:
+      ReportUnexpectedDataType(dtype);
+  }
+}
+
+uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
+  DCHECK(stream != nullptr);
+  const void* ptr;
+  int size;
+  if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
+  return *static_cast<const uint8*>(ptr);
+}
+
+bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
+  DCHECK(stream != nullptr);
+  DCHECK(result != nullptr);
+  uint32 length;
+  if (!stream->ReadVarint32(&length)) return false;
+  if (length == 0) {
+    *result = StringPiece(nullptr, 0);
+    return true;
+  }
+  const void* stream_alias;
+  int stream_size;
+  if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
+    return false;
+  }
+  if (static_cast<uint32>(stream_size) < length) return false;
+  *result = StringPiece(static_cast<const char*>(stream_alias), length);
+  stream->Skip(length);
+  return true;
+}
+
+bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
+                          parsed::FeatureMapEntry* feature_map_entry) {
+  DCHECK(stream != nullptr);
+  DCHECK(feature_map_entry != nullptr);
+  uint32 length;
+  if (!stream->ReadVarint32(&length)) return false;
+  auto limit = stream->PushLimit(length);
+  if (!stream->ExpectTag(kDelimitedTag(1))) return false;
+  if (!ParseString(stream, &feature_map_entry->first)) return false;
+  if (!stream->ExpectTag(kDelimitedTag(2))) return false;
+  StringPiece feature_string_piece;
+  if (!ParseString(stream, &feature_string_piece)) return false;
+  feature_map_entry->second = parsed::Feature(feature_string_piece);
+  if (!stream->ExpectAtEnd()) return false;
+  stream->PopLimit(limit);
+  return true;
+}
+
+bool ParseFeatures(protobuf::io::CodedInputStream* stream,
+                   parsed::Example* example) {
+  DCHECK(stream != nullptr);
+  DCHECK(example != nullptr);
+  uint32 length;
+  if (!stream->ReadVarint32(&length)) return false;
+  auto limit = stream->PushLimit(length);
+  while (!stream->ExpectAtEnd()) {
+    parsed::FeatureMapEntry feature_map_entry;
+    if (!stream->ExpectTag(kDelimitedTag(1))) return false;
+    if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
+    example->push_back(std::move(feature_map_entry));
+  }
+  stream->PopLimit(limit);
+  return true;
+}
+
+bool ParseExample(protobuf::io::CodedInputStream* stream,
+                  parsed::Example* example) {
+  DCHECK(stream != nullptr);
+  DCHECK(example != nullptr);
+  // Loop over the input stream which may contain multiple serialized Example
+  // protos merged together as strings. This behavior is consistent with Proto's
+  // ParseFromString when string representations are concatenated.
+  while (!stream->ExpectAtEnd()) {
+    if (!stream->ExpectTag(kDelimitedTag(1))) {
+      if (!SkipExtraneousTag(stream)) return false;
+    } else {
+      if (!ParseFeatures(stream, example)) return false;
+    }
+  }
+  return true;
+}
+
+bool ParseExample(StringPiece serialized, parsed::Example* example) {
+  DCHECK(example != nullptr);
+  protobuf::io::CodedInputStream stream(
+      reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
+  EnableAliasing(&stream);
+  return ParseExample(&stream, example);
+}
+
+template <>
+void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
+  std::move(b, e, t);
+}
+
+template <>
+const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
+  return buffer.int64_list;
+}
+template <>
+const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
+  return buffer.float_list;
+}
+template <>
+const SmallVector<tstring>& GetListFromBuffer<tstring>(
+    const SparseBuffer& buffer) {
+  return buffer.bytes_list;
+}
+
+}  // namespace example
+}  // namespace tensorflow
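
The kDelimitedTag / kVarintTag checks used throughout this parser are ordinary
protobuf wire-format tags, computed as (field_number << 3) | wire_type, where
wire type 2 means length-delimited and 0 means varint. A small sketch of that
arithmetic (the constants are recomputed here for illustration, not taken from
the header):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Same arithmetic as kVarintTag / kDelimitedTag / kFixed32Tag.
constexpr uint8_t Tag(uint32_t field, uint32_t wire_type) {
  return static_cast<uint8_t>((field << 3) | wire_type);
}

int main() {
  assert(Tag(1, 2) == 0x0A);  // kDelimitedTag(1), e.g. a features map entry
  assert(Tag(2, 2) == 0x12);  // kDelimitedTag(2), e.g. a packed float list
  assert(Tag(1, 0) == 0x08);  // kVarintTag(1), e.g. a varint int64 element
  std::printf("wire tags check out\n");
  return 0;
}
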
diff --git a/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h
new file mode 100644
index 0000000..dc0252d
--- /dev/null
+++ b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h
@@ -0,0 +1,688 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
+#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+#include <vector>
+
+#include "absl/base/casts.h"
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/platform/byte_order.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/presized_cuckoo_map.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+namespace example {
+
+template <typename T>
+using SmallVector = gtl::InlinedVector<T, 4>;
+
+template <typename T>
+class LimitedArraySlice {
+ public:
+  using value_type = T;
+
+  LimitedArraySlice(T* begin, size_t num_elements)
+      : current_(begin), begin_(begin), end_(begin + num_elements) {}
+
+  // May return a negative value if there were push_back calls after the slice
+  // was filled.
+  int64 EndDistance() const { return end_ - current_; }
+
+  // Attempts to push value to the back of this. If the slice has
+  // already been filled, this method has no effect on the underlying data, but
+  // it changes the number returned by EndDistance into negative values.
+  void push_back(T&& value) {
+    if (EndDistance() > 0) *current_ = std::move(value);
+    ++current_;
+  }
+
+  // "Constructs" an element at the back of this by resizing the slice, and
+  // returns a mutable reference to the new last element.
+  // REQUIRES: EndDistance() > 0.
+  T& construct_at_end() {
+    DCHECK_GT(EndDistance(), 0);
+    return *(current_++);
+  }
+
+  // Returns a mutable reference to the last element in the slice.
+  // REQUIRES: size() > 0.
+  T& back() { return *(current_ - 1); }
+
+  // Returns the number of elements in the slice.
+  size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
+
+  // Attempts to resize the vector to the given size. It does so by advancing
+  // the pointer to the current element, possibly beyond the end of the slice.
+  // As a consequence, calling `size()` after `resize(x)` was called might
+  // return a value less than `x`.
+  void resize(size_t size) { current_ = begin_ + size; }
+
+  // Returns the pointer to the underlying data buffer.
+  T* data() { return begin_; }
+
+ private:
+  T* current_;
+  T* begin_;
+  T* end_;
+};
+
+template <typename A>
+auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
+  a->EnableAliasing(true);
+}
+
+template <typename A>
+void EnableAliasing(A&& a) {}
+
+uint8 PeekTag(protobuf::io::CodedInputStream* stream);
+
+constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
+constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
+constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
+
+namespace parsed {
+
+// ParseDataType has to be called first, then appropriate ParseZzzzList.
+class Feature {
+ public:
+  Feature() {}
+  explicit Feature(StringPiece serialized) : serialized_(serialized) {}
+
+  Status ParseDataType(DataType* dtype) {
+    DCHECK(dtype != nullptr);
+    if (serialized_.empty()) {
+      *dtype = DT_INVALID;
+      return Status::OK();
+    }
+    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
+    serialized_.remove_prefix(1);
+    switch (oneof_tag) {
+      case kDelimitedTag(1):
+        *dtype = DT_STRING;
+        break;
+      case kDelimitedTag(2):
+        *dtype = DT_FLOAT;
+        break;
+      case kDelimitedTag(3):
+        *dtype = DT_INT64;
+        break;
+      default:
+        // Initialize variable to avoid compiler warning
+        *dtype = DT_INVALID;
+        return errors::InvalidArgument("Unsupported datatype.");
+    }
+    return Status::OK();
+  }
+
+  bool GetNumElementsInBytesList(int* num_elements) {
+    protobuf::io::CodedInputStream stream(
+        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+    EnableAliasing(&stream);
+    uint32 length = 0;
+    if (!stream.ReadVarint32(&length)) return false;
+    auto limit = stream.PushLimit(length);
+    *num_elements = 0;
+    while (!stream.ExpectAtEnd()) {
+      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
+      uint32 bytes_length = 0;
+      if (!stream.ReadVarint32(&bytes_length)) return false;
+      if (!stream.Skip(bytes_length)) return false;
+      ++*num_elements;
+    }
+    stream.PopLimit(limit);
+    return true;
+  }
+
+  // Helper methods
+  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
+    if (bytes_list->EndDistance() <= 0) {
+      return nullptr;
+    }
+    return &bytes_list->construct_at_end();
+  }
+  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
+    return &bytes_list->emplace_back();
+  }
+
+  template <typename Result>
+  bool ParseBytesList(Result* bytes_list) {
+    DCHECK(bytes_list != nullptr);
+
+    protobuf::io::CodedInputStream stream(
+        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+
+    EnableAliasing(&stream);
+
+    uint32 length;
+    if (!stream.ReadVarint32(&length)) return false;
+    auto limit = stream.PushLimit(length);
+
+    while (!stream.ExpectAtEnd()) {
+      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
+      // parse string
+      uint32 bytes_length;
+      if (!stream.ReadVarint32(&bytes_length)) return false;
+      tstring* bytes = construct_at_end(bytes_list);
+      if (bytes == nullptr) return false;
+      bytes->resize_uninitialized(bytes_length);
+      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
+    }
+    stream.PopLimit(limit);
+    return true;
+  }
+
+  template <typename Result>
+  bool ParseFloatList(Result* float_list) {
+    DCHECK(float_list != nullptr);
+    protobuf::io::CodedInputStream stream(
+        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+    EnableAliasing(&stream);
+    uint32 length;
+    if (!stream.ReadVarint32(&length)) return false;
+    auto limit = stream.PushLimit(length);
+
+    if (!stream.ExpectAtEnd()) {
+      uint8 peek_tag = PeekTag(&stream);
+      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
+        return false;
+      }
+
+      constexpr int32 kNumFloatBytes = 4;
+      if (peek_tag == kDelimitedTag(1)) {                       // packed
+        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
+        uint32 packed_length;
+        if (!stream.ReadVarint32(&packed_length)) return false;
+        auto packed_limit = stream.PushLimit(packed_length);
+
+        // Store the initial size to know the offset we have to start writing
+        // data from before resizing the output "vector".
+        const size_t initial_size = float_list->size();
+        float_list->resize(initial_size + packed_length / kNumFloatBytes);
+
+        // If the result data type is float and we are on a little endian
+        // machine then we can simply memcpy the data from the proto into the
+        // result vector.
+        if (port::kLittleEndian &&
+            sizeof(typename Result::value_type) == kNumFloatBytes) {
+          // Calculate the length of the available buffer, which can be less
+          // than what we requested via resize() above in the case of a
+          // LimitedArraySlice.
+          const uint32 bytes_to_copy =
+              std::min(static_cast<uint32>((float_list->size() - initial_size) *
+                                           kNumFloatBytes),
+                       packed_length);
+          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
+            return false;
+        } else {
+          int64 index = initial_size;
+          while (!stream.ExpectAtEnd()) {
+            uint32 buffer32;
+            if (!stream.ReadLittleEndian32(&buffer32)) return false;
+            if (index < float_list->size()) {
+              float_list->data()[index] = absl::bit_cast<float>(buffer32);
+              ++index;
+            }
+          }
+        }
+
+        stream.PopLimit(packed_limit);
+      } else {  // non-packed
+        const size_t initial_size = float_list->size();
+        // 1 byte for the tag (`1` encoded as Varint32) and kNumFloatBytes for
+        // the value.
+        const int64 num_elements =
+            stream.BytesUntilLimit() / (1 + kNumFloatBytes);
+        float_list->resize(initial_size + num_elements);
+        int64 index = initial_size;
+        while (!stream.ExpectAtEnd()) {
+          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
+          uint32 buffer32;
+          if (!stream.ReadLittleEndian32(&buffer32)) return false;
+          float_list->data()[index] = absl::bit_cast<float>(buffer32);
+          ++index;
+        }
+      }
+    }
+
+    stream.PopLimit(limit);
+    return true;
+  }
+
+  template <typename Result>
+  bool ParseInt64List(Result* int64_list) {
+    DCHECK(int64_list != nullptr);
+    protobuf::io::CodedInputStream stream(
+        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+    EnableAliasing(&stream);
+    uint32 length;
+    if (!stream.ReadVarint32(&length)) return false;
+    auto limit = stream.PushLimit(length);
+
+    if (!stream.ExpectAtEnd()) {
+      uint8 peek_tag = PeekTag(&stream);
+      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
+        return false;
+      }
+      if (peek_tag == kDelimitedTag(1)) {                       // packed
+        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
+        uint32 packed_length;
+        if (!stream.ReadVarint32(&packed_length)) return false;
+        auto packed_limit = stream.PushLimit(packed_length);
+
+        while (!stream.ExpectAtEnd()) {
+          protobuf_uint64 n;  // There is no API for int64
+          if (!stream.ReadVarint64(&n)) return false;
+          int64_list->push_back(static_cast<int64>(n));
+        }
+
+        stream.PopLimit(packed_limit);
+      } else {  // non-packed
+        while (!stream.ExpectAtEnd()) {
+          if (!stream.ExpectTag(kVarintTag(1))) return false;
+          protobuf_uint64 n;  // There is no API for int64
+          if (!stream.ReadVarint64(&n)) return false;
+          int64_list->push_back(static_cast<int64>(n));
+        }
+      }
+    }
+    stream.PopLimit(limit);
+    return true;
+  }
+
+  StringPiece GetSerialized() const { return serialized_; }
+
+ private:
+  StringPiece serialized_;
+};
+
+using FeatureMapEntry = std::pair<StringPiece, Feature>;
+using Example = std::vector<FeatureMapEntry>;
+
+}  // namespace parsed
+
+inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
+  uint32 data;
+  protobuf_uint64 dummy;
+  switch (stream->ReadTag() & 0x7) {
+    case 0:  // varint
+      if (!stream->ReadVarint32(&data)) return false;
+      return true;
+    case 1:  // fixed64
+      if (!stream->ReadLittleEndian64(&dummy)) return false;
+      return true;
+    case 2:  // length delimited
+      if (!stream->ReadVarint32(&data)) return false;
+      stream->Skip(data);
+      return true;
+    case 3:          // group begin
+      return false;  // groups not supported.
+    case 4:          // group end
+      return false;  // groups not supported.
+    case 5:          // fixed32
+      if (!stream->ReadLittleEndian32(&data)) return false;
+      return true;
+  }
+  return false;  // unrecognized tag type
+}
+
+bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
+
+bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
+                          parsed::FeatureMapEntry* feature_map_entry);
+
+bool ParseFeatures(protobuf::io::CodedInputStream* stream,
+                   parsed::Example* example);
+
+bool ParseExample(protobuf::io::CodedInputStream* stream,
+                  parsed::Example* example);
+
+bool ParseExample(StringPiece serialized, parsed::Example* example);
+
+using Config = FastParseExampleConfig;
+
+// Enumeration for distinguishing feature types.
+// Note: FastParseSequenceExample constructs a map that includes Type values,
+// and relies on the fact that they are default-initialized to Dense.
+enum class Type { Dense, Sparse, Ragged };
+
+// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
+struct SparseBuffer {
+  // Features are in one of the 3 vectors below depending on config's dtype.
+  // Other 2 vectors remain empty.
+  SmallVector<tstring> bytes_list;
+  SmallVector<float> float_list;
+  SmallVector<int64> int64_list;
+
+  // Features of example i are elements with indices
+  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
+  // appropriate xxxxx_list
+  std::vector<size_t> example_end_indices;
+};
+
+struct SeededHasher {
+  uint64 operator()(StringPiece s) const {
+    return Hash64(s.data(), s.size(), seed);
+  }
+  uint64 seed{0xDECAFCAFFE};
+};
+
+// Use this in the "default" clause of switch statements when dispatching
+// on a dtype variable that was checked by CheckConfigDataType():
+inline void ReportUnexpectedDataType(DataType dtype) {
+  DCHECK(false)
+      << "Encountered unexpected DataType " << DataTypeString(dtype)
+      << "in variable that should have been checked by CheckConfigDataType().";
+}
+
+template <typename T>
+const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
+
+template <>
+const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer);
+
+template <>
+const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
+
+template <>
+const SmallVector<tstring>& GetListFromBuffer<tstring>(
+    const SparseBuffer& buffer);
+
+template <typename T>
+void CopyOrMoveBlock(const T* b, const T* e, T* t) {
+  std::copy(b, e, t);
+}
+template <>
+void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
+
+void CountSparseFeatures(
+    const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
+    size_t* total_num_features, size_t* max_num_features);
+
+void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
+                              Tensor* dst);
+
+// A struct used by FastParseSequenceExample to hold the serialized proto
+// substrings for a single feature, plus some auxiliary information derived
+// from those protos (such as the total value length).
+struct FeatureProtos {
+  // Proto substrings from each serialized SequenceExample that correspond
+  // with this feature.  `protos_present` records whether the proto had a
+  // value defined (even if that value is empty).
+  std::vector<StringPiece> protos;
+  std::vector<bool> protos_present;
+
+  // Information derived from protos:
+  size_t length;    // total length for ragged/sparse, max row length for dense.
+  size_t num_rows;  // only populated for ragged sequence features.
+
+  // Information from the config:
+  Type type;  // Whether this feature is sparse, ragged, or dense.
+  DataType dtype;
+};
+
+// Map from feature name to FeatureProtos for that feature.
+using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
+
+string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
+
+// Return the number of bytes elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
+                             tstring* out) {
+  int num_elements = 0;
+  uint32 length;
+  if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
+    return -1;
+  }
+  if (length > 0) {
+    auto limit = stream->PushLimit(length);
+    while (!stream->ExpectAtEnd()) {
+      uint32 bytes_length;
+      if (!stream->ExpectTag(kDelimitedTag(1)) ||
+          !stream->ReadVarint32(&bytes_length)) {
+        return -1;
+      }
+      if (out == nullptr) {
+        stream->Skip(bytes_length);
+      } else {
+        out->resize_uninitialized(bytes_length);
+        if (!stream->ReadRaw(out->data(), bytes_length)) {
+          return -1;
+        }
+        out++;
+      }
+      num_elements++;
+    }
+    stream->PopLimit(limit);
+  }
+  return num_elements;
+}
+
+inline void PadFloatFeature(int num_to_pad, float* out) {
+  for (int i = 0; i < num_to_pad; i++) {
+    *out++ = 0.0;
+  }
+}
+
+inline void PadInt64Feature(int num_to_pad, int64* out) {
+  for (int i = 0; i < num_to_pad; i++) {
+    *out++ = 0;
+  }
+}
+
+// Return the number of float elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
+                             float* out) {
+  int num_elements = 0;
+  uint32 length;
+  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
+    return -1;
+  }
+  if (length > 0) {
+    auto limit = stream->PushLimit(length);
+    uint8 peek_tag = PeekTag(stream);
+    if (peek_tag == kDelimitedTag(1)) {  // packed
+      uint32 packed_length;
+      if (!stream->ExpectTag(kDelimitedTag(1)) ||
+          !stream->ReadVarint32(&packed_length)) {
+        return -1;
+      }
+      auto packed_limit = stream->PushLimit(packed_length);
+      while (!stream->ExpectAtEnd()) {
+        uint32 buffer32;
+        if (!stream->ReadLittleEndian32(&buffer32)) {
+          return -1;
+        }
+        if (out != nullptr) {
+          *out++ = absl::bit_cast<float>(buffer32);
+        }
+        num_elements++;
+      }
+      stream->PopLimit(packed_limit);
+    } else if (peek_tag == kFixed32Tag(1)) {
+      while (!stream->ExpectAtEnd()) {
+        uint32 buffer32;
+        if (!stream->ExpectTag(kFixed32Tag(1)) ||
+            !stream->ReadLittleEndian32(&buffer32)) {
+          return -1;
+        }
+        if (out != nullptr) {
+          *out++ = absl::bit_cast<float>(buffer32);
+        }
+        num_elements++;
+      }
+    } else {
+      // Unknown tag.
+      return -1;
+    }
+    stream->PopLimit(limit);
+  }
+  return num_elements;
+}
+
+// Return the number of int64 elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
+                             int64* out) {
+  int num_elements = 0;
+  uint32 length;
+  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
+    return -1;
+  }
+  if (length > 0) {
+    auto limit = stream->PushLimit(length);
+    uint8 peek_tag = PeekTag(stream);
+    if (peek_tag == kDelimitedTag(1)) {  // packed
+      uint32 packed_length;
+      if (!stream->ExpectTag(kDelimitedTag(1)) ||
+          !stream->ReadVarint32(&packed_length)) {
+        return -1;
+      }
+      auto packed_limit = stream->PushLimit(packed_length);
+      while (!stream->ExpectAtEnd()) {
+        protobuf_uint64 n;  // There is no API for int64
+        if (!stream->ReadVarint64(&n)) {
+          return -1;
+        }
+        if (out != nullptr) {
+          *out++ = n;
+        }
+        num_elements++;
+      }
+      stream->PopLimit(packed_limit);
+    } else if (peek_tag == kVarintTag(1)) {
+      while (!stream->ExpectAtEnd()) {
+        protobuf_uint64 n;  // There is no API for int64
+        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
+          return -1;
+        }
+        if (out != nullptr) {
+          *out++ = n;
+        }
+        num_elements++;
+      }
+    } else {
+      // Unknown tag.
+      return -1;
+    }
+    stream->PopLimit(limit);
+  }
+  return num_elements;
+}
+
+// Parses the next feature on `stream` into `out` starting at `out_offset`.
+// Updates `out_offset`, and returns the number of values added.
+// Returns -1 if the next feature on `stream` doesn't match `dtype`.
+inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
+                        Tensor* out, size_t* out_offset) {
+  int delta;
+  switch (dtype) {
+    case DT_STRING:
+      delta =
+          ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
+      break;
+    case DT_FLOAT:
+      delta =
+          ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
+      break;
+    case DT_INT64:
+      delta =
+          ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
+      break;
+    default:
+      ReportUnexpectedDataType(dtype);
+      delta = 0;
+  }
+  if (delta > 0) {
+    *out_offset += delta;
+  }
+  return delta;
+}
+
+// Returns the length of the next feature on `stream`.
+// Returns -1 if the next feature on `stream` doesn't match `dtype`.
+inline int GetFeatureLength(DataType dtype,
+                            protobuf::io::CodedInputStream* stream) {
+  switch (dtype) {
+    case DT_STRING:
+      return ParseBytesFeature(stream, nullptr);
+    case DT_FLOAT:
+      return ParseFloatFeature(stream, nullptr);
+    case DT_INT64:
+      return ParseInt64Feature(stream, nullptr);
+    default:
+      ReportUnexpectedDataType(dtype);
+      return -1;
+  }
+}
+
+inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
+  uint8 peek_tag = PeekTag(stream);
+  switch (peek_tag) {
+    case kDelimitedTag(1):
+      return DT_STRING;
+    case kDelimitedTag(2):
+      return DT_FLOAT;
+    case kDelimitedTag(3):
+      return DT_INT64;
+    default:
+      return DT_INVALID;
+  }
+}
+
+inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
+                             DataType dtype) {
+  switch (dtype) {
+    case DT_STRING:
+      if (!stream->ExpectTag(kDelimitedTag(1))) {
+        return false;
+      }
+      break;
+    case DT_FLOAT:
+      if (!stream->ExpectTag(kDelimitedTag(2))) {
+        return false;
+      }
+      break;
+    case DT_INT64:
+      if (!stream->ExpectTag(kDelimitedTag(3))) {
+        return false;
+      }
+      break;
+    default:
+      return false;
+  }
+  uint32 length;
+  return stream->ReadVarint32(&length) && length == 0;
+}
+
+}  // namespace example
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
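
To make the `LimitedArraySlice` contract above concrete: a push_back past the
end leaves the underlying buffer untouched but drives `EndDistance()` negative,
which callers can use to detect that a preallocated output was too small. The
struct below is a simplified stand-in mirroring those semantics, not the real
class:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for LimitedArraySlice<float>.
struct Slice {
  float* current;
  float* begin;
  float* end;
  int64_t EndDistance() const { return end - current; }
  void push_back(float v) {
    if (EndDistance() > 0) *current = v;  // writes past the end are dropped
    ++current;
  }
  size_t size() const { return std::min(current - begin, end - begin); }
};

int main() {
  float storage[2] = {0.f, 0.f};
  Slice s{storage, storage, storage + 2};
  s.push_back(1.f);
  s.push_back(2.f);
  s.push_back(3.f);  // slice already full: value dropped, EndDistance goes negative
  std::printf("size=%zu EndDistance=%lld values=%g,%g\n", s.size(),
              static_cast<long long>(s.EndDistance()), storage[0], storage[1]);
  return 0;
}
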
diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc
new file mode 100644
index 0000000..2f3c98f
--- /dev/null
+++ b/tensorflow/lite/kernels/parse_example/parse_example.cc
@@ -0,0 +1,1004 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/kernels/parse_example/parse_example.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <unordered_map>
+
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/core/example/feature.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+#include "tensorflow/core/util/presized_cuckoo_map.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace parse_example {
+namespace {
+
+namespace tf = ::tensorflow;
+using tf::Status;
+using tf::StringPiece;
+using tf::tstring;
+using tf::example::CopyOrMoveBlock;
+using tf::example::FastParseExampleConfig;
+using tf::example::GetListFromBuffer;
+using tf::example::LimitedArraySlice;
+using tf::example::ParseExample;
+using tf::example::SeededHasher;
+using tf::example::SmallVector;
+using tf::example::SparseBuffer;
+using tf::example::Type;
+using tf::example::parsed::Example;
+
+using ConfigIndex = tf::PresizedCuckooMap<std::pair<int32_t, Type>>;
+
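+// Output tensors produced by the op, plus intermediate tf::Tensors used to
+// stage dense string results before they are packed into TFLite string
+// tensors.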
+struct TfLiteResult {
+  std::vector<TfLiteTensor*> dense_values;
+  std::vector<TfLiteTensor*> sparse_values;
+  std::vector<TfLiteTensor*> sparse_indices;
+  std::vector<TfLiteTensor*> sparse_shapes;
+  std::map<int, tf::Tensor> dense_tensors;
+};
+
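+// Fills the variable-length dense output `values` with the column's default
+// value, then copies each example's parsed values into its slot of the output
+// tensor.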
+template <typename T>
+void FillAndCopyVarLen(const int d, const size_t num_elements,
+                       const size_t num_elements_per_minibatch,
+                       const FastParseExampleConfig& config,
+                       std::vector<SparseBuffer>& varlen_dense_buffers,
+                       TfLiteTensor* values) {
+  const tf::Tensor& default_value = config.dense[d].default_value;
+
+  // Copy-fill the tensors (creating the zero/fill-padding)
+  std::fill(reinterpret_cast<T*>(values->data.raw),
+            reinterpret_cast<T*>(values->data.raw) + num_elements,
+            default_value.flat<T>()(0));
+
+  auto data = reinterpret_cast<T*>(values->data.raw);
+
+  const SparseBuffer& buffer = varlen_dense_buffers[d];
+  // Number of examples being stored in this buffer
+  const auto& end_indices = buffer.example_end_indices;
+  const size_t examples_in_buffer = end_indices.size();
+
+  const auto& list = GetListFromBuffer<T>(buffer);
+  auto list_ptr = list.begin();
+
+  size_t elements_tally = 0;
+  // Iterate through all the examples stored in this buffer.
+  for (size_t j = 0; j < examples_in_buffer; ++j) {
+    // Number of elements stored for this example.
+    const size_t num_elems = end_indices[j] - elements_tally;
+    CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
+    // Move forward this many elements in the varlen buffer.
+    list_ptr += num_elems;
+    // Move forward to the next minibatch entry in the values output.
+    data += num_elements_per_minibatch;
+    elements_tally = end_indices[j];
+  }
+  DCHECK_EQ(elements_tally, list.size());
+}
+
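+// Parses a serialized Example proto referenced by a TFLite StringRef into
+// `example`.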
+bool ParseExample(StringRef serialized, Example* example) {
+  DCHECK(example != nullptr);
+  tf::protobuf::io::CodedInputStream stream(
+      reinterpret_cast<const uint8*>(serialized.str), serialized.len);
+  tensorflow::example::EnableAliasing(&stream);
+  return ParseExample(&stream, example);
+}
+
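+// Parses one serialized Example: fixed-size dense features are written
+// directly into the output tensors, while variable-length dense and sparse
+// features are accumulated in SparseBuffers. Features whose name length or
+// hash is not present in `quick_filter`/`config_index` are skipped.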
+Status FastParseSerializedExample(
+    StringRef serialized_example, const tstring& example_name,
+    const size_t example_index, const FastParseExampleConfig& config,
+    bool* quick_filter, int quick_filter_size,
+    const std::unique_ptr<ConfigIndex>& config_index, int config_index_size,
+    SeededHasher* hasher, std::vector<TfLiteTensor*>* output_dense,
+    std::vector<SparseBuffer>* output_varlen_dense,
+    std::vector<SparseBuffer>* output_sparse,
+    std::map<absl::string_view, int>& stats, TfLiteResult* result) {
+  DCHECK(output_dense != nullptr);
+  tensorflow::example::parsed::Example parsed_example;
+  if (!ParseExample(serialized_example, &parsed_example)) {
+    return tf::errors::Internal("Failed to parse example");
+  }
+  std::vector<tf::int64> dense_feature_last_example(config.dense.size(), -1);
+  std::vector<tf::int64> sparse_feature_last_example(config.sparse.size(), -1);
+  // Handle features present in the example.
+  const size_t parsed_example_size = parsed_example.size();
+  for (size_t i = 0; i < parsed_example_size; ++i) {
+    // This mirrors standard protobuf parsing logic: the last entry in the
+    // map overwrites all previous ones.
+    tensorflow::example::parsed::FeatureMapEntry& name_and_feature =
+        parsed_example[parsed_example_size - i - 1];
+    const StringPiece feature_name = name_and_feature.first;
+    tensorflow::example::parsed::Feature& feature = name_and_feature.second;
+    if (feature_name.length() >= quick_filter_size ||
+        !quick_filter[feature_name.length()]) {
+      continue;
+    }
+    const uint64_t h = (*hasher)(feature_name);
+    std::pair<int32_t, Type> d_and_type;
+    if (!config_index->Find(h, &d_and_type)) {
+      continue;
+    }
+    size_t d = d_and_type.first;
+    bool is_dense = d_and_type.second == Type::Dense;
+
+    auto example_error = [&](StringPiece suffix) {
+      return tf::errors::Internal("Name: ", example_name,
+                                  ", Key: ", feature_name,
+                                  ", Index: ", example_index, ".  ", suffix);
+    };
+
+    auto parse_error = [&] {
+      return example_error("Can't parse serialized Example.");
+    };
+
+    tf::DataType example_dtype;
+    if (feature.ParseDataType(&example_dtype) != Status::OK()) {
+      return parse_error();
+    }
+    if (is_dense) {
+      if (example_dtype == tf::DT_INVALID) continue;
+
+      dense_feature_last_example[d] = example_index;
+
+      if (example_dtype != config.dense[d].dtype) {
+        return example_error(absl::StrCat(
+            "Data types don't match. Data type: ",
+            DataTypeString(example_dtype),
+            " but expected type: ", DataTypeString(config.dense[d].dtype)));
+      }
+      if (!config.dense[d].variable_length) {
+        TfLiteTensor* out = (*output_dense)[d];
+
+        const std::size_t num_elements = config.dense[d].elements_per_stride;
+        const std::size_t offset = example_index * num_elements;
+
+        auto shape_error = [&](size_t size, StringPiece type_str) {
+          return example_error(absl::StrCat(
+              "Number of ", type_str,
+              " values != expected.  "
+              "Values size:",
+              size,
+              " but output shape: ", config.dense[d].shape.DebugString()));
+        };
+
+        switch (config.dense[d].dtype) {
+          case tf::DT_INT64: {
+            auto out_p = reinterpret_cast<tf::int64*>(out->data.raw) + offset;
+            LimitedArraySlice<tf::int64> slice(out_p, num_elements);
+            if (!feature.ParseInt64List(&slice)) return parse_error();
+            if (slice.EndDistance() != 0) {
+              return shape_error(num_elements - slice.EndDistance(), "int64");
+            }
+            break;
+          }
+          case tf::DT_FLOAT: {
+            auto out_p = reinterpret_cast<float*>(out->data.raw) + offset;
+            LimitedArraySlice<float> slice(out_p, num_elements);
+            if (!feature.ParseFloatList(&slice)) return parse_error();
+            if (slice.EndDistance() != 0) {
+              return shape_error(num_elements - slice.EndDistance(), "float");
+            }
+            break;
+          }
+          case tf::DT_STRING: {
+            auto& out_tensor = result->dense_tensors[d];
+            auto out_p = out_tensor.flat<tstring>().data() + offset;
+            LimitedArraySlice<tstring> slice(out_p, num_elements);
+            if (!feature.ParseBytesList(&slice)) return parse_error();
+            if (slice.EndDistance() != 0) {
+              return shape_error(num_elements - slice.EndDistance(), "bytes");
+            }
+            break;
+          }
+          default:
+            return tf::errors::Internal("Unrecognized dense type: ",
+                                        config.dense[d].dtype);
+        }
+      } else {  // if dense variable length
+        SparseBuffer& out = (*output_varlen_dense)[d];
+
+        const std::size_t num_elements = config.dense[d].elements_per_stride;
+
+        if (example_dtype != tf::DT_INVALID &&
+            example_dtype != config.dense[d].dtype) {
+          return example_error(absl::StrCat(
+              "Data types don't match. ",
+              "Expected type: ", DataTypeString(config.dense[d].dtype)));
+        }
+
+        auto shape_error = [&](size_t size, StringPiece type_str) {
+          return example_error(
+              absl::StrCat("Number of ", type_str,
+                           " values is not a multiple of stride length. Saw ",
+                           size, " values but output shape is: ",
+                           config.dense[d].shape.DebugString()));
+        };
+
+        switch (config.dense[d].dtype) {
+          case tf::DT_INT64: {
+            if (example_dtype != tf::DT_INVALID) {
+              if (!feature.ParseInt64List(&out.int64_list)) {
+                return parse_error();
+              }
+              if (out.int64_list.size() % num_elements != 0) {
+                return shape_error(out.int64_list.size(), "int64");
+              }
+            }
+            out.example_end_indices.push_back(out.int64_list.size());
+            break;
+          }
+          case tf::DT_FLOAT: {
+            if (example_dtype != tf::DT_INVALID) {
+              if (!feature.ParseFloatList(&out.float_list)) {
+                return parse_error();
+              }
+              if (out.float_list.size() % num_elements != 0) {
+                return shape_error(out.float_list.size(), "float");
+              }
+            }
+            out.example_end_indices.push_back(out.float_list.size());
+            break;
+          }
+          case tf::DT_STRING: {
+            if (example_dtype != tf::DT_INVALID) {
+              if (!feature.ParseBytesList(&out.bytes_list)) {
+                return parse_error();
+              }
+              if (out.bytes_list.size() % num_elements != 0) {
+                return shape_error(out.bytes_list.size(), "byte");
+              }
+            }
+            out.example_end_indices.push_back(out.bytes_list.size());
+            break;
+          }
+          default:
+            return tf::errors::Internal("Should not happen: ",
+                                        config.dense[d].dtype);
+        }
+      }
+    } else {
+      // is sparse or ragged
+      auto& last_example = sparse_feature_last_example;
+      if (last_example[d] == example_index) {
+        continue;
+      }
+      last_example[d] = example_index;
+      SparseBuffer& out = (*output_sparse)[d];
+      tf::DataType feature_dtype = config.sparse[d].dtype;
+      if (example_dtype != tf::DT_INVALID && example_dtype != feature_dtype) {
+        return tf::errors::Internal("Data types don't match:", example_dtype,
+                                    " != ", feature_dtype);
+      }
+      switch (feature_dtype) {
+        case tf::DT_INT64: {
+          if (example_dtype != tf::DT_INVALID) {
+            if (!feature.ParseInt64List(&out.int64_list)) {
+              return parse_error();
+            }
+          }
+          out.example_end_indices.push_back(out.int64_list.size());
+          break;
+        }
+        case tf::DT_FLOAT: {
+          if (example_dtype != tf::DT_INVALID) {
+            if (!feature.ParseFloatList(&out.float_list)) {
+              return parse_error();
+            }
+          }
+          out.example_end_indices.push_back(out.float_list.size());
+          break;
+        }
+        case tf::DT_STRING: {
+          if (example_dtype != tf::DT_INVALID) {
+            if (!feature.ParseBytesList(&out.bytes_list)) {
+              return parse_error();
+            }
+          }
+          out.example_end_indices.push_back(out.bytes_list.size());
+          break;
+        }
+        default:
+          return tf::errors::Internal("Should not happen: ", feature_dtype);
+      }
+    }
+  }
+  // Handle missing dense features for fixed strides.
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (config.dense[d].variable_length) continue;
+    if (dense_feature_last_example[d] == example_index) continue;
+    if (config.dense[d].default_value.NumElements() == 0) {
+      return tf::errors::Internal(
+          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
+          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
+          " is required but could not be found.");
+    }
+    const tf::Tensor& in = config.dense[d].default_value;
+    TfLiteTensor* out = result->dense_values[d];
+    const std::size_t num_elements = in.shape().num_elements();
+    const std::size_t offset = example_index * num_elements;
+    switch (config.dense[d].dtype) {
+      case tf::DT_INT64: {
+        std::copy_n(in.flat<tf::int64>().data(), num_elements,
+                    out->data.i64 + offset);
+        break;
+      }
+      case tf::DT_FLOAT: {
+        std::copy_n(in.flat<float>().data(), num_elements,
+                    out->data.f + offset);
+        break;
+      }
+      case tf::DT_STRING: {
+        auto& out_tensor = result->dense_tensors[d];
+        std::copy_n(in.flat<tstring>().data(), num_elements,
+                    out_tensor.flat<tstring>().data() + offset);
+        break;
+      }
+      default:
+        return tf::errors::Internal("Should not happen: ",
+                                    config.dense[d].dtype);
+    }
+  }
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (!config.dense[d].variable_length) continue;
+    if (dense_feature_last_example[d] == example_index) continue;
+    SparseBuffer& out = (*output_varlen_dense)[d];
+    size_t prev_example_end_index =
+        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
+    out.example_end_indices.push_back(prev_example_end_index);
+  }
+
+  for (size_t d = 0; d < config.sparse.size(); ++d) {
+    if (sparse_feature_last_example[d] == example_index) continue;
+    SparseBuffer& out = (*output_sparse)[d];
+    size_t prev_example_end_index =
+        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
+    out.example_end_indices.push_back(prev_example_end_index);
+  }
+
+  return Status::OK();
+}
+
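+// Adds this buffer's total number of feature values to `total_num_features`
+// and updates `max_num_features` with the largest per-example count.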
+void CountSparseFeatures(const SparseBuffer& sparse_buffer,
+                         size_t* total_num_features, size_t* max_num_features) {
+  const std::vector<size_t>& end_indices = sparse_buffer.example_end_indices;
+  *total_num_features += end_indices.back();
+  *max_num_features = std::max(*max_num_features, end_indices[0]);
+  for (size_t i = 1; i < end_indices.size(); ++i) {
+    size_t example_size = end_indices[i] - end_indices[i - 1];
+    *max_num_features = std::max(*max_num_features, example_size);
+  }
+}
+
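+// Copies the parsed values from `src` into `dst`: numeric values are written
+// starting at element `offset`, strings are serialized through DynamicBuffer.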
+void CopySparseBufferToTensor(tf::DataType dtype, size_t offset,
+                              SparseBuffer* src, TfLiteTensor* dst) {
+  switch (dtype) {
+    case tf::DT_INT64: {
+      std::copy(src->int64_list.begin(), src->int64_list.end(),
+                reinterpret_cast<int64_t*>(dst->data.raw) + offset);
+      break;
+    }
+    case tf::DT_FLOAT: {
+      std::copy(src->float_list.begin(), src->float_list.end(),
+                reinterpret_cast<float*>(dst->data.raw) + offset);
+      break;
+    }
+    case tf::DT_STRING: {
+      DynamicBuffer buffer;
+      for (auto* begin = src->bytes_list.begin();
+           begin != src->bytes_list.end(); begin++) {
+        buffer.AddString(begin->c_str(), begin->size());
+      }
+      buffer.WriteToTensor(dst, nullptr);
+      break;
+    }
+    default:
+      DCHECK(false) << "Encountered unexpected DataType "
+                    << DataTypeString(dtype)
+                    << "in variable that should have been checked.";
+  }
+}
+
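+// Concatenates string payloads into `tensor_buffer`, `elements_per_stride`
+// strings per example; when there are fewer examples than `batch_size`, the
+// remaining (default) values are appended.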
+inline void CopyToBuffer(tf::gtl::ArraySlice<tstring> vec, char* tensor_buffer,
+                         int num_examples, int batch_size,
+                         int elements_per_stride) {
+  int i = 0, k = 0;
+  int start = 0;
+  for (; i < num_examples; ++i) {
+    for (int j = 0; j < elements_per_stride; ++j) {
+      memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
+      start += vec[k].size();
+      k++;
+    }
+  }
+  // This happens when the number of examples is less than the desired batch
+  // size.
+  for (; i < batch_size; ++i) {
+    for (int j = 0; j < elements_per_stride; ++j) {
+      memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
+      start += vec[k].size();
+      k++;
+    }
+  }
+}
+
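+// Parses every Example string in `serialized`, then writes the buffered
+// sparse and variable-length dense results into the TFLite output tensors,
+// resizing the dynamic sparse outputs as needed.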
+Status FastParseExampleLite(
+    const FastParseExampleConfig& config, const TfLiteTensor* serialized,
+    tf::gtl::ArraySlice<tstring> example_names, bool* quick_filter,
+    int quick_filter_size, const std::unique_ptr<ConfigIndex>& config_index,
+    int config_index_size, SeededHasher* hasher, TfLiteResult* result,
+    std::map<absl::string_view, int>& stats, TfLiteContext* context) {
+  if (result == nullptr) {
+    return tf::errors::Internal("Result is null");
+  }
+  const int count = GetStringCount(serialized);
+  std::vector<tf::Tensor> fixed_dense_values(config.dense.size());
+  std::vector<SparseBuffer> sparse_buffers(config.sparse.size());
+  std::vector<SparseBuffer> varlen_dense_buffers(config.dense.size());
+  Status status_of_minibatch;
+  for (size_t e = 0; e < count; ++e) {
+    status_of_minibatch = FastParseSerializedExample(
+        GetString(serialized, e),
+        (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
+        quick_filter, quick_filter_size, config_index, config_index_size,
+        hasher, &result->dense_values, &varlen_dense_buffers, &sparse_buffers,
+        /*arena,*/ stats, result);
+    if (!status_of_minibatch.ok()) break;
+  }
+  if (!status_of_minibatch.ok()) {
+    return status_of_minibatch;
+  }
+  // Merge SparseBuffers from all minibatches for every config.sparse.
+  for (size_t d = 0; d < config.sparse.size(); ++d) {
+    size_t total_num_features = 0;
+    size_t max_num_features = 0;
+    CountSparseFeatures(sparse_buffers[d], &total_num_features,
+                        &max_num_features);
+    tf::TensorShape indices_shape;
+    TfLiteTensor* indices = result->sparse_indices[d];
+    TfLiteTensor* values = result->sparse_values[d];
+
+    TfLiteTensor* dense_shape = result->sparse_shapes[d];
+    auto* dense_shape_ptr = reinterpret_cast<int64_t*>(dense_shape->data.raw);
+    dense_shape_ptr[1] = max_num_features;
+
+    TfLiteIntArray* index_shape = TfLiteIntArrayCreate(2);
+    index_shape->data[0] = total_num_features;
+    index_shape->data[1] = 2;
+    context->ResizeTensor(context, indices, index_shape);
+
+    TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1);
+    output_shape->data[0] = total_num_features;
+    context->ResizeTensor(context, values, output_shape);
+
+    SparseBuffer& buffer = sparse_buffers[d];
+
+    // Update indices.
+    auto* indices_p = reinterpret_cast<int64_t*>(indices->data.raw);
+    if (!indices_p) {
+      return tf::errors::Internal("Indices tensor not allocated!");
+    }
+
+    if (total_num_features > 0) {
+      int64_t* ix_p = indices_p;
+      size_t example_index = 0;
+      int idx0 = 0;
+      size_t delta = 0;
+      for (size_t example_end_index : buffer.example_end_indices) {
+        size_t feature_index = 0;
+        for (; delta < example_end_index; ++delta) {
+          // Column 0: example index
+          if (idx0 < total_num_features) {
+            *ix_p = example_index;
+            // Column 1: the feature index buffer example
+            *(ix_p + 1) = feature_index;
+            ix_p += 2;
+          }
+          ++feature_index;
+          ++idx0;
+        }
+        ++example_index;
+      }
+      CopySparseBufferToTensor(config.sparse[d].dtype, 0, &buffer, values);
+    }
+  }
+
+  // Merge SparseBuffers from all minibatches for every config.dense having
+  // variable_length.
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (!config.dense[d].variable_length) {
+      continue;
+    }
+    size_t max_num_features = 0;
+    std::vector<size_t>& end_indices =
+        varlen_dense_buffers[d].example_end_indices;
+    max_num_features = std::max(max_num_features, end_indices[0]);
+    for (size_t i = 1; i < end_indices.size(); ++i) {
+      size_t example_size = end_indices[i] - end_indices[i - 1];
+      max_num_features = std::max(max_num_features, example_size);
+    }
+
+    const size_t stride_size = config.dense[d].elements_per_stride;
+    const size_t max_num_elements = max_num_features / stride_size;
+    tf::TensorShape values_shape;
+    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
+    const size_t batch_size = GetStringCount(serialized);
+    values_shape.AddDim(batch_size);
+    values_shape.AddDim(max_num_elements);
+    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
+      values_shape.AddDim(config.dense[d].shape.dim_size(i));
+    }
+    TfLiteTensor* values = result->dense_values[d];
+    const size_t num_elements = GetTensorShape(values).FlatSize();
+
+    // Nothing to write, exit early.
+    if (num_elements == 0) {
+      continue;
+    }
+
+    const size_t num_elements_per_minibatch = num_elements / batch_size;
+    switch (config.dense[d].dtype) {
+      case tf::DT_INT64: {
+        FillAndCopyVarLen<tf::int64>(d, num_elements,
+                                     num_elements_per_minibatch, config,
+                                     varlen_dense_buffers, values);
+        break;
+      }
+      case tf::DT_FLOAT: {
+        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
+                                 config, varlen_dense_buffers, values);
+        break;
+      }
+      default:
+        DCHECK(false) << "Encountered unexpected DataType "
+                      << config.dense[d].dtype
+                      << "in variable that should have been checked";
+    }
+  }
+
+  // Merge tflite string buffers if necessary.
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (config.dense[d].variable_length) {
+      continue;
+    }
+    if (result->dense_values[d]->type == kTfLiteString) {
+      auto& in = result->dense_tensors[d];
+      auto vec = in.vec<tstring>();
+      const int batch_size = result->dense_values[d]->dims->data[0];
+      const int elements_per_stride = config.dense[d].elements_per_stride;
+      int total_size = 0;
+      std::vector<int32_t> offsets;
+      offsets.reserve(vec.size() + 1);
+      offsets.push_back(0);
+      int k = 0;
+      for (int i = 0; i < batch_size; ++i) {
+        for (int j = 0; j < elements_per_stride; ++j) {
+          if (i < count) {
+            total_size += vec(k++).size();
+            offsets.push_back(total_size);
+          } else {
+            offsets.push_back(total_size);
+          }
+        }
+      }
+      const int32_t num_strings = offsets.size() - 1;
+      const size_t required_bytes = sizeof(int32_t) * (num_strings + 2) +
+          total_size;
+      char* tensor_buffer =
+          reinterpret_cast<char*>(result->dense_values[d]->data.raw);
+      if (result->dense_values[d]->bytes < required_bytes) {
+        if (result->dense_values[d]->data.raw) {
+          free(result->dense_values[d]->data.raw);
+        }
+        tensor_buffer = reinterpret_cast<char*>(malloc(required_bytes));
+        result->dense_values[d]->data.raw = tensor_buffer;
+        result->dense_values[d]->bytes = required_bytes;
+      }
+      const int32_t start = sizeof(int32_t) * (num_strings + 2);
+      memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
+      for (size_t i = 0; i < offsets.size(); i++) {
+        int32_t offset_i = start + offsets[i];
+        memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset_i,
+               sizeof(int32_t));
+      }
+      tf::gtl::ArraySlice<tstring> slice(vec.data(), vec.size());
+      CopyToBuffer(slice, tensor_buffer + start, count, batch_size,
+                   elements_per_stride);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+enum InputTensor {
+  kExampleTensor = 0,
+  kNamesTensor = 1,
+  kSparseKeysTensor = 2,
+  kDenseKeysTensor = 3,
+  kRaggedKeysTensor = 4,
+};
+
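+// Per-node state: parsing config, feature-name length filter, hashed feature
+// index and cached output tensors. Built lazily on the first Eval call.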
+struct OpData {
+  FastParseExampleConfig config;
+  std::vector<tf::TensorShape> dense_shapes;
+  int dense_size = 0;
+  int sparse_size = 0;
+  std::unique_ptr<ConfigIndex> config_index;
+  int config_index_size;
+  SeededHasher hasher;
+  TfLiteResult got;
+  bool* quick_filter = nullptr;
+  int quick_filter_size;
+  bool created = false;
+  ~OpData() {
+    if (quick_filter) {
+      free(quick_filter);
+    }
+  }
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  return new OpData;
+}
+
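+// Copies a std::vector into a new 1-D tf::Tensor of the corresponding dtype.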
+template <typename T>
+tf::Tensor AsTensor(const std::vector<T>& val) {
+  tf::Tensor ret(tf::DataTypeToEnum<T>::value,
+                 {static_cast<tf::int64>(val.size())});
+  std::copy_n(val.begin(), val.size(), ret.flat<T>().data());
+  return ret;
+}
+
+enum Version {
+  V1,
+  V2,
+};
+
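+// Converts a TfLiteIntArray of dimensions into a tf::TensorShape.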
+tf::TensorShape TfLiteToTfShape(TfLiteIntArray* array) {
+  tf::TensorShape shape;
+  for (int i = 0; i < array->size; i++) {
+    shape.AddDim(array->data[i]);
+  }
+  return shape;
+}
+
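+// Reads the op's custom options (either a serialized NodeDef or a flexbuffer
+// map) to determine the number of dense and sparse features, then resizes the
+// output tensors for the batch size.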
+template <Version version>
+TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  TF_LITE_ENSURE(context, node->custom_initial_data);
+  data->config.dense.clear();
+  data->config.sparse.clear();
+  data->got.dense_values.clear();
+  const flexbuffers::Vector& v =
+      flexbuffers::GetRoot(
+          reinterpret_cast<const uint8_t*>(node->custom_initial_data),
+          node->custom_initial_data_size)
+          .AsVector();
+  if (v.size() == 2) {
+    tf::NodeDef nodedef;
+    TF_LITE_ENSURE_EQ(context, nodedef.ParseFromString(v[1].AsString().str()),
+                      true);
+    if (version == V1) {
+      data->dense_size = nodedef.attr().at("Ndense").i();
+      data->sparse_size = nodedef.attr().at("Nsparse").i();
+    } else if (version == V2) {
+      data->dense_size = nodedef.attr().at("Tdense").list().type_size();
+      data->sparse_size = nodedef.attr().at("num_sparse").i();
+    }
+    auto dense_shapes = nodedef.attr().at("dense_shapes").list();
+    for (int i = 0; i < dense_shapes.shape_size(); ++i) {
+      data->dense_shapes.push_back(dense_shapes.shape(i));
+    }
+  } else {
+    const flexbuffers::Map& m =
+        flexbuffers::GetRoot(
+            reinterpret_cast<const uint8_t*>(node->custom_initial_data),
+            node->custom_initial_data_size)
+            .AsMap();
+    const flexbuffers::TypedVector keys = m.Keys();
+    int num_sparse = 0;
+    int num_dense = 0;
+    for (int k = 0; k < keys.size(); ++k) {
+      const std::string key = keys[k].ToString();
+      const auto value = m[key];
+      if (key == "Nsparse" || key == "num_sparse") {
+        num_sparse = value.AsInt32();
+      }
+      if (key == "Ndense") {
+        num_dense = value.AsInt32();
+      }
+    }
+    data->sparse_size = num_sparse;
+    data->dense_size = num_dense;
+    if (version == V2) {
+      const TfLiteTensor* dense_key_tensor =
+          GetInput(context, node, kDenseKeysTensor);
+      data->dense_size = GetTensorShape(dense_key_tensor).FlatSize();
+    }
+  }
+
+  data->config.dense.reserve(data->dense_size);
+  data->config.sparse.reserve(data->sparse_size);
+  data->dense_shapes.reserve(data->dense_size);
+  const auto* serialized = GetInput(context, node, 0);
+  const int batch_size =
+      serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
+
+  for (int i = 0; i < data->dense_size; i++) {
+    TfLiteTensor* dense_output =
+        GetOutput(context, node, data->sparse_size * 3 + i);
+    TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_output->dims);
+    if (data->dense_size > 0 && data->dense_shapes.empty()) {
+      data->dense_shapes.push_back(TfLiteToTfShape(output_size));
+    }
+    output_size->data[0] = batch_size * output_size->data[0];
+    context->ResizeTensor(context, dense_output, output_size);
+  }
+
+  size_t offset = 0;
+  for (int i = 0; i < data->sparse_size; i++) {
+    auto* parse_output = GetOutput(context, node, i + offset);
+    SetTensorToDynamic(parse_output);
+    TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(2);
+    sparse_size->data[0] = batch_size;
+    sparse_size->data[1] = 2;
+    context->ResizeTensor(context, parse_output, sparse_size);
+    data->got.sparse_indices.push_back(parse_output);
+  }
+  offset += data->sparse_size;
+  for (int i = 0; i < data->sparse_size; i++) {
+    auto* parse_output = GetOutput(context, node, i + offset);
+    SetTensorToDynamic(parse_output);
+    TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
+    sparse_size->data[0] = 0;
+    context->ResizeTensor(context, parse_output, sparse_size);
+    data->got.sparse_values.push_back(parse_output);
+  }
+  offset += data->sparse_size;
+  for (int i = 0; i < data->sparse_size; i++) {
+    TfLiteTensor* parse_output = GetOutput(context, node, i + offset);
+    SetTensorToDynamic(parse_output);
+    TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
+    sparse_size->data[0] = 2;
+    context->ResizeTensor(context, parse_output, sparse_size);
+    auto* shapes_shape_t = reinterpret_cast<int64_t*>(parse_output->data.i64);
+    shapes_shape_t[0] = batch_size;
+    shapes_shape_t[1] = 1;
+    data->got.sparse_shapes.push_back(parse_output);
+  }
+  data->created = false;
+  return kTfLiteOk;
+}
+
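+// On the first invocation, builds the parsing config (keys, dtypes, defaults),
+// the feature-name length filter and the hashed config index; every call then
+// parses the batch with FastParseExampleLite.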
+template <Version version>
+TfLiteStatus EvalParseExample(TfLiteContext* context, TfLiteNode* node) {
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  if (!data->created) {
+    for (int i = 0; i < data->sparse_size; i++) {
+      int input_index =
+          version == V1 ? kSparseKeysTensor + i : kSparseKeysTensor;
+      int string_index = version == V1 ? 0 : i;
+      const TfLiteTensor* sparse_key_tensor =
+          GetInput(context, node, input_index);
+      const auto key = GetString(sparse_key_tensor, string_index);
+      const auto* sparse_output =
+          GetOutput(context, node, i + data->sparse_size);
+      std::string k(key.str, key.len);
+      switch (sparse_output->type) {
+        case kTfLiteInt64:
+          data->config.sparse.emplace_back(
+              k, tf::DataTypeToEnum<tf::int64>::value);
+          break;
+        case kTfLiteFloat32:
+          data->config.sparse.emplace_back(k, tf::DataTypeToEnum<float>::value);
+          break;
+        case kTfLiteString:
+          data->config.sparse.emplace_back(k,
+                                           tf::DataTypeToEnum<tstring>::value);
+          break;
+        default:
+          return kTfLiteError;
+      }
+    }
+
+    const auto& dense_shapes = data->dense_shapes;
+    for (int i = 0; i < data->dense_size; i++) {
+      const int input_index = version == V1
+                                  ? kSparseKeysTensor + data->sparse_size + i
+                                  : kDenseKeysTensor;
+      const int dense_defaults_index =
+          version == V1
+              ? kSparseKeysTensor + data->sparse_size + data->dense_size + i
+              : kRaggedKeysTensor + i + 1;
+      int string_index = version == V1 ? 0 : i;
+      const TfLiteTensor* dense_key_tensor =
+          GetInput(context, node, input_index);
+      const auto* dense_output =
+          GetOutput(context, node, i + data->sparse_size * 3);
+      const auto* dense_defaults =
+          GetInput(context, node, dense_defaults_index);
+      const auto key = GetString(dense_key_tensor, string_index);
+      std::string k(key.str, key.len);
+      const int elements_per_stride =
+          dense_shapes[i].dims() ? dense_shapes[i].num_elements() : 1;
+      switch (dense_output->type) {
+        case kTfLiteInt64:
+          data->config.dense.emplace_back(
+              k, tf::DataTypeToEnum<tf::int64>::value, dense_shapes[i],
+              AsTensor<tf::int64>(std::vector<tf::int64>(
+                  dense_defaults->data.i64,
+                  dense_defaults->data.i64 + elements_per_stride)),
+              false, elements_per_stride);
+          break;
+        case kTfLiteFloat32:
+          data->config.dense.emplace_back(
+              k, tf::DataTypeToEnum<float>::value, dense_shapes[i],
+              AsTensor<float>(std::vector<float>(
+                  dense_defaults->data.f,
+                  dense_defaults->data.f + elements_per_stride)),
+              false, elements_per_stride);
+          break;
+        case kTfLiteString: {
+          const int num_strings = GetStringCount(dense_defaults);
+          std::vector<tstring> values;
+          for (int j = 0; j < num_strings; ++j) {
+            auto ref = GetString(dense_defaults, j);
+            values.emplace_back(ref.str, ref.len);
+          }
+          data->config.dense.emplace_back(
+              k, tf::DataTypeToEnum<tstring>::value, dense_shapes[i],
+              AsTensor<tstring>(values), false, elements_per_stride);
+          break;
+        }
+        default:
+          return kTfLiteError;
+      }
+    }
+
+    int offset = 3 * data->sparse_size;
+    for (int i = 0; i < data->dense_size; i++) {
+      auto* parse_output = GetOutput(context, node, i + offset);
+      data->got.dense_values.push_back(parse_output);
+      if (parse_output->type == kTfLiteString) {
+        tf::TensorShape shape;
+        if (parse_output->dims->size == 1) {
+          shape.AddDim(parse_output->dims->data[0]);
+        } else {
+          shape.AddDim(GetTensorShape(parse_output).FlatSize());
+        }
+        data->got.dense_tensors[i] =
+            tf::Tensor(tf::DataTypeToEnum<tstring>::value, shape);
+      }
+    }
+
+    size_t config_size = data->config.dense.size();
+    config_size += data->config.sparse.size();
+    data->config_index_size = config_size;
+    auto config_index = std::make_unique<ConfigIndex>(config_size);
+    bool ok = true;
+    int max_length = 0;
+    for (size_t d = 0; d < data->config.dense.size(); ++d) {
+      auto s = data->config.dense[d].feature_name;
+      max_length = s.length() > max_length ? s.length() : max_length;
+    }
+    for (size_t d = 0; d < data->config.sparse.size(); ++d) {
+      auto s = data->config.sparse[d].feature_name;
+      max_length = s.length() > max_length ? s.length() : max_length;
+    }
+    if (data->quick_filter) {
+      free(data->quick_filter);
+    }
+    ++max_length;  // Extra slot so name lengths index directly.
+    data->quick_filter =
+        static_cast<bool*>(malloc(max_length * sizeof(bool)));
+    memset(data->quick_filter, 0, max_length * sizeof(bool));
+    data->quick_filter_size = max_length;
+    for (size_t d = 0; d < data->config.dense.size(); ++d) {
+      const auto& s = data->config.dense[d].feature_name;
+      data->quick_filter[s.length()] = true;
+    }
+    for (size_t d = 0; d < data->config.sparse.size(); ++d) {
+      const auto& s = data->config.sparse[d].feature_name;
+      data->quick_filter[s.length()] = true;
+    }
+
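+    // Retry with a new hash seed until every feature name inserts into the
+    // cuckoo map without collision (bounded number of attempts).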
+    for (int i = 0; i < 1000; ++i) {
+      for (size_t d = 0; d < data->config.dense.size(); ++d) {
+        ok &= config_index->InsertUnique(
+            data->hasher(data->config.dense[d].feature_name), {d, Type::Dense});
+      }
+      for (size_t d = 0; d < data->config.sparse.size(); ++d) {
+        ok &= config_index->InsertUnique(
+            data->hasher(data->config.sparse[d].feature_name),
+            {d, Type::Sparse});
+      }
+      if (ok) {
+        break;
+      }
+      data->hasher.seed++;
+      config_index->Clear(config_size);
+      ok = true;
+    }
+    if (!ok) {
+      return kTfLiteError;
+    }
+    data->config_index = std::move(config_index);
+    data->created = true;
+  }
+
+  const TfLiteTensor* serialized = GetInput(context, node, kExampleTensor);
+
+  std::map<absl::string_view, int> stats;
+  const auto status = FastParseExampleLite(
+      data->config, serialized, {}, data->quick_filter, data->quick_filter_size,
+      data->config_index, data->config_index_size, &data->hasher, &data->got,
+      stats, context);
+  if (status != tf::Status::OK()) {
+    TF_LITE_KERNEL_LOG(context, status.ToString().c_str());
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  auto* obj = reinterpret_cast<OpData*>(buffer);
+  delete obj;
+}
+
+}  // namespace parse_example
+
+TfLiteRegistration* Register_PARSE_EXAMPLE() {
+  static TfLiteRegistration r = {
+      parse_example::Init, parse_example::Free,
+      parse_example::PrepareParseExample<parse_example::V1>,
+      parse_example::EvalParseExample<parse_example::V1>};
+  return &r;
+}
+
+TfLiteRegistration* Register_PARSE_EXAMPLE_V2() {
+  static TfLiteRegistration r = {
+      parse_example::Init, parse_example::Free,
+      parse_example::PrepareParseExample<parse_example::V2>,
+      parse_example::EvalParseExample<parse_example::V2>};
+  return &r;
+}
+
+extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver) {
+  resolver->AddCustom("ParseExample", Register_PARSE_EXAMPLE());
+  resolver->AddCustom("ParseExampleV2", Register_PARSE_EXAMPLE_V2());
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/parse_example/parse_example.h b/tensorflow/lite/kernels/parse_example/parse_example.h
new file mode 100644
index 0000000..ccda857
--- /dev/null
+++ b/tensorflow/lite/kernels/parse_example/parse_example.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
+#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
+
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_PARSE_EXAMPLE();
+TfLiteRegistration* Register_PARSE_EXAMPLE_V2();
+
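+// Example usage (sketch): register the custom ops before building the
+// interpreter so models containing ParseExample / ParseExampleV2 resolve:
+//
+//   tflite::MutableOpResolver resolver;
+//   tflite::ops::custom::AddParseExampleOp(&resolver);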
+extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver);
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
diff --git a/tensorflow/lite/kernels/parse_example/parse_example_test.cc b/tensorflow/lite/kernels/parse_example/parse_example_test.cc
new file mode 100644
index 0000000..ca35da3
--- /dev/null
+++ b/tensorflow/lite/kernels/parse_example/parse_example_test.cc
@@ -0,0 +1,330 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/kernels/parse_example/parse_example.h"
+
+#include <initializer_list>
+
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/core/example/feature_util.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/tstring.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/interpreter_builder.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model_builder.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace tf = ::tensorflow;
+
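+// NodeDef for a ParseExample op with one dense float feature of shape [2] and
+// no sparse features.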
+const char* kNodeDefTxt = R"pb(
+  name: "ParseExample/ParseExample"
+  op: "ParseExample"
+  input: "serialized"
+  input: "ParseExample/ParseExample/names"
+  input: "ParseExample/ParseExample/dense_keys_0"
+  input: "ParseExample/Const"
+  attr {
+    key: "Ndense"
+    value { i: 1 }
+  }
+  attr {
+    key: "Nsparse"
+    value { i: 0 }
+  }
+  attr {
+    key: "Tdense"
+    value { list { type: DT_FLOAT } }
+  }
+  attr {
+    key: "dense_shapes"
+    value { list { shape { dim { size: 2 } } } }
+  }
+  attr {
+    key: "sparse_types"
+    value { list { type: DT_FLOAT } }
+  }
+)pb";
+
+const char* kNodeDefTxt2 = R"pb(
+  name: "ParseExample/ParseExample"
+  op: "ParseExample"
+  input: "serialized"
+  input: "ParseExample/ParseExample/names"
+  input: "ParseExample/ParseExample/sparse_keys_0"
+  attr {
+    key: "Ndense"
+    value { i: 0 }
+  }
+  attr {
+    key: "Nsparse"
+    value { i: 1 }
+  }
+  attr {
+    key: "Tdense"
+    value {}
+  }
+  attr {
+    key: "dense_shapes"
+    value {}
+  }
+  attr {
+    key: "sparse_types"
+    value { list { type: DT_FLOAT } }
+  }
+)pb";
+
+const char* kNodeDefTxt3 = R"pb(
+  name: "ParseExample/ParseExample"
+  op: "ParseExample"
+  input: "serialized"
+  input: "ParseExample/ParseExample/names"
+  input: "ParseExample/ParseExample/sparse_keys_0"
+  attr {
+    key: "Ndense"
+    value { i: 1 }
+  }
+  attr {
+    key: "Nsparse"
+    value { i: 0 }
+  }
+  attr {
+    key: "Tdense"
+    value { list { type: DT_STRING } }
+  }
+  attr {
+    key: "dense_shapes"
+    value { list { shape { dim { size: 1 } } } }
+  }
+  attr {
+    key: "sparse_types"
+    value { list { type: DT_FLOAT } }
+  }
+)pb";
+
+const char* kNodeDefTxt4 = R"pb(
+  name: "ParseExample/ParseExample"
+  op: "ParseExample"
+  input: "serialized"
+  input: "ParseExample/ParseExample/names"
+  input: "ParseExample/ParseExample/sparse_keys_0"
+  attr {
+    key: "Ndense"
+    value { i: 0 }
+  }
+  attr {
+    key: "Nsparse"
+    value { i: 1 }
+  }
+  attr {
+    key: "Tdense"
+    value {}
+  }
+  attr {
+    key: "dense_shapes"
+    value {}
+  }
+  attr {
+    key: "sparse_types"
+    value { list { type: DT_STRING } }
+  }
+)pb";
+
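+// SingleOpModel helper that wires up a ParseExample custom op from a NodeDef
+// text proto and populates its serialized-example and key inputs.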
+template <typename DefaultType>
+class ParseExampleOpModel : public SingleOpModel {
+ public:
+  ParseExampleOpModel(std::string serialized_example,
+                      std::vector<std::string> sparse_keys,
+                      std::vector<std::string> dense_keys,
+                      std::initializer_list<DefaultType> dense_defaults,
+                      std::vector<TensorType> dense_types,
+                      std::vector<TensorType> sparse_types,
+                      const char* text_def, int dense_size = 2) {
+    // Example
+    string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
+    // Names
+    string_indices_.push_back(
+        AddConstInput<std::string>(TensorData(TensorType_STRING, {0}), {""}));
+    std::for_each(sparse_keys.begin(), sparse_keys.end(), [&](auto&&) {
+      string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
+    });
+    std::for_each(dense_keys.begin(), dense_keys.end(), [&](auto&&) {
+      string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
+    });
+    if (dense_size > 0) {
+      dense_defaults_ = AddConstInput<DefaultType>(
+          TensorData(dense_types[0], {dense_size}), dense_defaults);
+    }
+    if (!sparse_keys.empty()) {
+      for (int i = 0; i < sparse_keys.size(); i++) {
+        sparse_indices_outputs_.push_back(AddOutput(TensorType_INT64));
+      }
+      for (int i = 0; i < sparse_keys.size(); i++) {
+        sparse_values_outputs_.push_back(AddOutput(sparse_types[i]));
+      }
+      for (int i = 0; i < sparse_keys.size(); i++) {
+        sparse_shapes_outputs_.push_back(AddOutput({TensorType_INT64, {2}}));
+      }
+    }
+    for (int i = 0; i < dense_keys.size(); i++) {
+      dense_outputs_.push_back(AddOutput({dense_types[i], {dense_size}}));
+    }
+
+    tf::NodeDef nodedef;
+    tf::protobuf::TextFormat::Parser parser;
+    tf::protobuf::io::ArrayInputStream input_stream(text_def, strlen(text_def));
+    if (!parser.Parse(&input_stream, &nodedef)) {
+      abort();
+    }
+    std::string serialized_nodedef;
+    nodedef.SerializeToString(&serialized_nodedef);
+    flexbuffers::Builder fbb;
+    fbb.Vector([&]() {
+      fbb.String(nodedef.op());
+      fbb.String(serialized_nodedef);
+    });
+    fbb.Finish();
+    const auto buffer = fbb.GetBuffer();
+    SetCustomOp("ParseExample", buffer, Register_PARSE_EXAMPLE);
+    BuildInterpreter({});
+    int idx = 0;
+    PopulateStringTensor(string_indices_[idx++], {serialized_example});
+    PopulateStringTensor(string_indices_[idx++], {""});
+    for (const auto& key : sparse_keys) {
+      PopulateStringTensor(string_indices_[idx++], {key});
+    }
+    for (const auto& key : dense_keys) {
+      PopulateStringTensor(string_indices_[idx++], {key});
+    }
+  }
+
+  template <typename T>
+  std::vector<T> GetSparseIndicesOutput(int i) {
+    return ExtractVector<T>(sparse_indices_outputs_[i]);
+  }
+
+  template <typename T>
+  std::vector<T> GetSparseValuesOutput(int i) {
+    return ExtractVector<T>(sparse_values_outputs_[i]);
+  }
+
+  template <typename T>
+  std::vector<T> GetSparseShapesOutput(int i) {
+    return ExtractVector<T>(sparse_shapes_outputs_[i]);
+  }
+
+  template <typename T>
+  std::vector<T> GetDenseOutput(int i) {
+    return ExtractVector<T>(dense_outputs_[i]);
+  }
+
+  std::vector<std::string> GetStringOutput(int i) {
+    auto* t = interpreter_->tensor(i);
+    int count = GetStringCount(t);
+    std::vector<std::string> v;
+    for (int j = 0; j < count; ++j) {
+      auto ref = GetString(t, j);
+      v.emplace_back(ref.str, ref.len);
+    }
+    return v;
+  }
+
+  int DenseDefaults() { return dense_defaults_; }
+
+  int SparseValuesOutputs(int i) { return sparse_values_outputs_[i]; }
+
+  int DenseOutputs(int i) { return dense_outputs_[i]; }
+
+  std::vector<int> dense_outputs_;
+  std::vector<int> sparse_indices_outputs_;
+  std::vector<int> sparse_shapes_outputs_;
+  std::vector<int> sparse_values_outputs_;
+  std::vector<int> string_indices_;
+  int dense_defaults_ = -1;
+};
+
+TEST(ParseExampleOpsTest, SimpleTest) {
+  tf::Example example;
+  tf::AppendFeatureValues<float>({1.5f, 1.5f}, "time", &example);
+  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
+  ParseExampleOpModel<float> m(example.SerializeAsString(), {}, {"time"},
+                               {0.f, 0.f}, {TensorType_FLOAT32}, {},
+                               kNodeDefTxt);
+  m.Invoke();
+  EXPECT_THAT(m.GetDenseOutput<float>(0),
+              ElementsAreArray(ArrayFloatNear({1.5f, 1.5f})));
+}
+
+TEST(ParseExampleOpsTest, SparseTest) {
+  tf::Example example;
+  tf::AppendFeatureValues<float>({1.5f}, "time", &example);
+  ParseExampleOpModel<float> m(example.SerializeAsString(), {"time"}, {}, {},
+                               {}, {TensorType_FLOAT32}, kNodeDefTxt2, 0);
+  m.Invoke();
+  EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
+              ElementsAreArray(ArrayFloatNear({0, 0})));
+  EXPECT_THAT(m.GetSparseValuesOutput<float>(0),
+              ElementsAreArray(ArrayFloatNear({1.5f})));
+  EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
+              ElementsAreArray(ArrayFloatNear({1, 1})));
+}
+
+TEST(ParseExampleOpsTest, SimpleBytesTest) {
+  tf::Example example;
+  const std::string test_data = "simpletest";
+  tf::AppendFeatureValues<tensorflow::tstring>({test_data}, "time", &example);
+  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
+  std::string default_value = "missing";
+  ParseExampleOpModel<std::string> m(example.SerializeAsString(), {}, {"time"},
+                                     {default_value}, {TensorType_STRING}, {},
+                                     kNodeDefTxt3, 1);
+  m.PopulateStringTensor(m.DenseDefaults(), {default_value});
+  m.Invoke();
+  std::vector<string> c = m.GetStringOutput(m.DenseOutputs(0));
+  EXPECT_EQ(1, c.size());
+  EXPECT_EQ(test_data, c[0]);
+}
+
+TEST(ParseExampleOpsTest, SparseBytesTest) {
+  tf::Example example;
+  const std::string test_data = "simpletest";
+  tf::AppendFeatureValues<tensorflow::tstring>({test_data, test_data}, "time",
+                                               &example);
+  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
+  ParseExampleOpModel<std::string> m(example.SerializeAsString(), {"time"}, {},
+                                     {}, {}, {TensorType_STRING}, kNodeDefTxt4,
+                                     0);
+  m.Invoke();
+  EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
+              testing::ElementsAreArray({0, 0, 0, 1}));
+  auto values = m.GetStringOutput(m.SparseValuesOutputs(0));
+  EXPECT_EQ(2, values.size());
+  EXPECT_EQ(test_data, values[0]);
+  EXPECT_EQ(test_data, values[1]);
+  EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
+              testing::ElementsAreArray({1, 2}));
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/BUILD b/tensorflow/lite/kernels/perception/BUILD
new file mode 100644
index 0000000..0cead40
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/BUILD
@@ -0,0 +1,71 @@
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+    default_visibility = [
+        "//visibility:public",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "perception_ops",
+    srcs = [
+        "max_pool_with_argmax.cc",
+        "max_unpooling_2d.cc",
+        "perception_ops.cc",
+    ],
+    hdrs = [
+        "perception_ops.h",
+    ],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:kernel_util",
+        "//tensorflow/lite/kernels:padding",
+        "//tensorflow/lite/kernels/internal:common",
+        "//tensorflow/lite/kernels/internal:compatibility",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "//tensorflow/lite/kernels/internal:tensor_utils",
+        "//tensorflow/lite/kernels/internal:types",
+        "@flatbuffers",
+    ],
+)
+
+cc_test(
+    name = "perception_ops_test",
+    size = "small",
+    srcs = [
+        "max_pool_with_argmax_test.cc",
+        "max_unpooling_2d_test.cc",
+    ],
+    deps = [
+        ":perception_ops",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels:test_main",
+        "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/testing:util",
+        "@flatbuffers",
+    ],
+)
+
+pybind_extension(
+    name = "pywrap_perception_ops",
+    srcs = [
+        "perception_ops_wrapper.cc",
+    ],
+    hdrs = ["perception_ops.h"],
+    additional_exported_symbols = ["PerceptionOpsRegisterer"],
+    link_in_framework = True,
+    module_name = "pywrap_perception_ops",
+    deps = [
+        ":perception_ops",
+        "//tensorflow/lite:mutable_op_resolver",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
diff --git a/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc b/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
new file mode 100644
index 0000000..4e1aca9
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
@@ -0,0 +1,249 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace max_pool_with_argmax {
+namespace {
+// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
+// this op to a builtin op.
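+//
+// Max pooling that additionally records, for each output element, the
+// flattened index ((y * input_width + x) * depth + channel) of the maximal
+// input value.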
+template <typename T>
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+                    const RuntimeShape& output_shape, const T* input_data,
+                    T* output_data, int32_t* indices_data) {
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
+  const int32_t input_height = input_shape.Dims(1);
+  const int32_t input_width = input_shape.Dims(2);
+  const int32_t output_height = output_shape.Dims(1);
+  const int32_t output_width = output_shape.Dims(2);
+  const int32_t stride_height = params.stride_height;
+  const int32_t stride_width = params.stride_width;
+  for (int32_t batch = 0; batch < batches; ++batch) {
+    for (int32_t out_y = 0; out_y < output_height; ++out_y) {
+      for (int32_t out_x = 0; out_x < output_width; ++out_x) {
+        for (int32_t channel = 0; channel < depth; ++channel) {
+          const int32_t in_x_origin =
+              (out_x * stride_width) - params.padding_values.width;
+          const int32_t in_y_origin =
+              (out_y * stride_height) - params.padding_values.height;
+          // Compute the boundaries of the filter region clamped so as to
+          // ensure that the filter window fits in the input array.
+          const int32_t filter_x_start = std::max(0, -in_x_origin);
+          const int32_t filter_x_end =
+              std::min(params.filter_width, input_width - in_x_origin);
+          const int32_t filter_y_start = std::max(0, -in_y_origin);
+          const int32_t filter_y_end =
+              std::min(params.filter_height, input_height - in_y_origin);
+          float max = std::numeric_limits<float>::lowest();
+          int32_t max_x = 0;
+          int32_t max_y = 0;
+
+          for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
+               ++filter_y) {
+            for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
+                 ++filter_x) {
+              const int32_t in_x = in_x_origin + filter_x;
+              const int32_t in_y = in_y_origin + filter_y;
+              float cur =
+                  input_data[Offset(input_shape, batch, in_y, in_x, channel)];
+              if (cur > max) {
+                max = cur;
+                max_x = in_x;
+                max_y = in_y;
+              }
+            }
+          }
+          int32_t output_idx =
+              Offset(output_shape, batch, out_y, out_x, channel);
+          output_data[output_idx] = ActivationFunctionWithMinMax(
+              max, params.float_activation_min, params.float_activation_max);
+          indices_data[output_idx] =
+              (max_y * input_width + max_x) * depth + channel;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace
+
+constexpr int kDataInputTensor = 0;
+constexpr int kDataOutputTensor = 0;
+constexpr int kIndicesOutputTensor = 1;
+
+constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
+constexpr const char kPoolSizeStr[] = "ksize";
+constexpr const char kStridesStr[] = "strides";
+constexpr const char kPaddingStr[] = "padding";
+constexpr const char kPaddingSameStr[] = "SAME";
+constexpr const char kPaddingValidStr[] = "VALID";
+
+struct OpData {
+  TfLitePoolParams params;
+  bool include_batch_in_index;
+};
+
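+// Parses the flexbuffer custom options. The op expects an attribute map of
+// the form {"ksize": [1, filter_h, filter_w, 1],
+// "strides": [1, stride_h, stride_w, 1], "padding": "SAME" | "VALID",
+// "include_batch_in_index": bool}, mirroring the attributes of the
+// TensorFlow MaxPoolWithArgmax op.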
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  const flexbuffers::Map& m =
+      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
+          .AsMap();
+
+  OpData* op_data = new OpData;
+  op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
+  op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
+  op_data->params.activation = kTfLiteActNone;
+
+  const std::string padding = m[kPaddingStr].AsString().str();
+  if (padding == kPaddingValidStr) {
+    op_data->params.padding = kTfLitePaddingValid;
+  } else if (padding == kPaddingSameStr) {
+    op_data->params.padding = kTfLitePaddingSame;
+  } else {
+    op_data->params.padding = kTfLitePaddingUnknown;
+  }
+
+  // The first and last element of pool_size are always 1.
+  const auto pool_size = m[kPoolSizeStr].AsTypedVector();
+  TFLITE_CHECK_EQ(pool_size.size(), 4);
+  TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
+  TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
+  op_data->params.filter_height = pool_size[1].AsInt32();
+  op_data->params.filter_width = pool_size[2].AsInt32();
+
+  // The first and last element of strides are always 1.
+  const auto strides = m[kStridesStr].AsTypedVector();
+  TFLITE_CHECK_EQ(strides.size(), 4);
+  TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
+  TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
+  op_data->params.stride_height = strides[1].AsInt32();
+  op_data->params.stride_width = strides[2].AsInt32();
+
+  return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
+  TfLiteTensor *output, *indices;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kDataOutputTensor, &output));
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kDataInputTensor, &input));
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
+  TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
+  TF_LITE_ENSURE_MSG(
+      context, !op_data->include_batch_in_index,
+      "Include batch dimension in flattened index is not yet supported.");
+
+  int batches = input->dims->data[0];
+  int height = input->dims->data[1];
+  int width = input->dims->data[2];
+  int channels_out = input->dims->data[3];
+
+  // Matching GetWindowedOutputSize in TensorFlow.
+  int out_width, out_height;
+  op_data->params.computed.padding = ComputePaddingHeightWidth(
+      op_data->params.stride_height, op_data->params.stride_width, 1, 1, height,
+      width, op_data->params.filter_height, op_data->params.filter_width,
+      op_data->params.padding, &out_height, &out_width);
+
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+  output_size->data[0] = batches;
+  output_size->data[1] = out_height;
+  output_size->data[2] = out_width;
+  output_size->data[3] = channels_out;
+  TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);
+
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
+  return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  float activation_min, activation_max;
+  CalculateActivationRange(op_data->params.activation, &activation_min,
+                           &activation_max);
+
+  tflite::PoolParams op_params;
+  op_params.stride_height = op_data->params.stride_height;
+  op_params.stride_width = op_data->params.stride_width;
+  op_params.filter_height = op_data->params.filter_height;
+  op_params.filter_width = op_data->params.filter_width;
+  op_params.padding_values.height = op_data->params.computed.padding.height;
+  op_params.padding_values.width = op_data->params.computed.padding.width;
+  op_params.float_activation_min = activation_min;
+  op_params.float_activation_max = activation_max;
+
+  TfLiteTensor *output, *indices;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kDataOutputTensor, &output));
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kDataInputTensor, &input));
+
+  switch (input->type) {
+    case kTfLiteFloat32:
+      MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
+                     GetTensorData<float>(input), GetTensorData<float>(output),
+                     GetTensorData<int32_t>(indices));
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
+                         TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+}  // namespace max_pool_with_argmax
+
+TfLiteRegistration* RegisterMaxPoolWithArgmax() {
+  static TfLiteRegistration r = {
+      max_pool_with_argmax::Init, max_pool_with_argmax::Free,
+      max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc b/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc
new file mode 100644
index 0000000..a0642df
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc
@@ -0,0 +1,298 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/perception/perception_ops.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/testing/util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace {
+
+using testing::ElementsAreArray;
+
+class MaxpoolingWithArgMaxOpModel : public SingleOpModel {
+ public:
+  MaxpoolingWithArgMaxOpModel(const TensorData& input, int stride_height,
+                              int stride_width, int filter_height,
+                              int filter_width, TfLitePadding padding,
+                              const TensorData& output,
+                              const TensorData& indices) {
+    input_ = AddInput(input);
+    output_ = AddOutput(output);
+    indices_ = AddOutput(indices);
+
+    std::vector<uint8_t> custom_option = CreateCustomOptions(
+        stride_height, stride_width, filter_height, filter_width, padding);
+    SetCustomOp("MaxPoolWithArgmax", custom_option, RegisterMaxPoolWithArgmax);
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  void SetInput(const std::vector<float>& data) {
+    PopulateTensor(input_, data);
+  }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+  std::vector<int32_t> GetIndices() { return ExtractVector<int32_t>(indices_); }
+
+  std::vector<int> GetIndicesShape() { return GetTensorShape(indices_); }
+
+ protected:
+  int input_;
+  int output_;
+  int indices_;
+
+ private:
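+  // Serializes the op attributes into the flexbuffer map consumed by the
+  // kernel's Init, i.e. {"include_batch_in_index": false, "padding": ...,
+  // "ksize": [1, fh, fw, 1], "strides": [1, sh, sw, 1]}.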
+  std::vector<uint8_t> CreateCustomOptions(int stride_height, int stride_width,
+                                           int filter_height, int filter_width,
+                                           TfLitePadding padding) {
+    auto flex_builder = std::make_unique<flexbuffers::Builder>();
+    size_t map_start = flex_builder->StartMap();
+    flex_builder->Bool("include_batch_in_index", false);
+    if (padding == kTfLitePaddingValid) {
+      flex_builder->String("padding", "VALID");
+    } else {
+      flex_builder->String("padding", "SAME");
+    }
+
+    auto start = flex_builder->StartVector("ksize");
+    flex_builder->Add(1);
+    flex_builder->Add(filter_height);
+    flex_builder->Add(filter_width);
+    flex_builder->Add(1);
+    flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
+
+    auto strides_start = flex_builder->StartVector("strides");
+    flex_builder->Add(1);
+    flex_builder->Add(stride_height);
+    flex_builder->Add(stride_width);
+    flex_builder->Add(1);
+    flex_builder->EndVector(strides_start, /*typed=*/true, /*fixed=*/false);
+
+    flex_builder->EndMap(map_start);
+    flex_builder->Finish();
+    return flex_builder->GetBuffer();
+  }
+};
+
+TEST(MaxpoolWithArgMaxTest, UnsupportedInt64Test) {
+  EXPECT_DEATH_IF_SUPPORTED(MaxpoolingWithArgMaxOpModel model(
+                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+                                /*stride_height=*/2, /*stride_width=*/2,
+                                /*filter_height=*/2, /*filter_width=*/2,
+                                /*padding=*/kTfLitePaddingSame,
+                                /*output=*/{TensorType_FLOAT32, {}},
+                                /*indices=*/{TensorType_INT64, {}});
+                            , "indices->type == kTfLiteInt32 was not true.");
+}
+
+TEST(MaxpoolWithArgMaxTest, SimpleTest) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+  model.SetInput({0, 13, 2, 0, 0, 1, 4, 0});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({13, 4}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 1, 2, 1}));
+  EXPECT_THAT(model.GetIndices(), ElementsAreArray({1, 6}));
+}
+
+TEST(MaxpoolWithArgMaxTest, Strides2x1Test) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
+      /*stride_height=*/2, /*stride_width=*/1,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput({1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 2}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 4, 0, 4, 7, 8, 0, 8}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 2}));
+  EXPECT_THAT(model.GetIndices(),
+              ElementsAreArray({4, 7, 2, 7, 12, 15, 10, 15}));
+}
+
+TEST(MaxpoolWithArgMaxTest, Strides2x2Test) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 4, 8, 1}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0,
+                  0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 8});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 4, 0, 0, 7, 6, 8}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 4, 1}));
+  EXPECT_THAT(model.GetIndices(),
+              ElementsAreArray({0, 10, 13, 6, 16, 27, 20, 31}));
+}
+
+TEST(MaxpoolWithArgMaxTest, Strides2x2UnfitTest) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 4, 7, 1}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4,
+                  0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 2, 4, 0, 0, 5, 7}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 4, 1}));
+  EXPECT_THAT(model.GetIndices(),
+              ElementsAreArray({0, 10, 5, 13, 14, 16, 19, 27}));
+}
+
+TEST(MaxpoolWithArgMaxTest, PaddingValidTest) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 4, 5, 1}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/3,
+      /*padding=*/kTfLitePaddingValid,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput(
+      {0, 0, 0, 0, 0, 0, 7, 0, 0, 10, 0, 0, 0, 0, 0, 0, 20, 0, 0, 19});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({7, 10, 20, 19}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 1}));
+  EXPECT_THAT(model.GetIndices(), ElementsAreArray({6, 9, 16, 19}));
+}
+
+TEST(MaxpoolWithArgMaxTest, PaddingValidUnfitTest) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 4, 6, 1}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/3,
+      /*padding=*/kTfLitePaddingValid,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput({0, 0, 0, 0, 0,  0, 7, 0,  0,  10, 0, 0,
+                  0, 0, 0, 0, 20, 0, 0, 19, 24, 1,  2, 44});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({7, 10, 24, 24}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 1}));
+  EXPECT_THAT(model.GetIndices(), ElementsAreArray({6, 9, 20, 20}));
+}
+
+TEST(MaxpoolWithArgMaxTest, InputWithBatchTest) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {2, 4, 12, 2}},
+      /*stride_height=*/2, /*stride_width=*/3,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput({0,  0,  1,  0,  0,  0,  0,  0,  3,  4, 0,  0,  5, 0, 0,  6,
+                  0,  0,  0,  0,  0,  0,  0,  2,  0,  0, 0,  0,  0, 0, 0,  0,
+                  0,  0,  0,  0,  0,  0,  0,  0,  7,  0, 0,  8,  9, 0, 0,  10,
+                  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 0,  0,  0, 0, 15, 0,
+                  0,  16, 0,  0,  0,  0,  0,  0,  11, 0, 0,  12, 0, 0, 0,  14,
+                  13, 0,  0,  0,  0,  0,  0,  0,  0,  0, 0,  0,  0, 0, 0,  0,
+                  17, 18, 0,  0,  0,  30, 0,  20, 0,  0, 0,  0,  0, 0, 21, 0,
+                  0,  0,  0,  0,  0,  24, 0,  0,  0,  0, 0,  0,  0, 0, 19, 0,
+                  0,  0,  0,  22, 0,  0,  0,  0,  0,  0, 23, 0,  0, 0, 0,  0,
+                  0,  0,  27, 28, 0,  0,  0,  0,  29, 0, 0,  0,  0, 0, 0,  32,
+                  0,  0,  0,  0,  25, 26, 0,  0,  0,  0, 0,  0,  0, 0, 0,  0,
+                  0,  0,  0,  0,  0,  0,  31, 0,  0,  0, 0,  0,  0, 0, 0,  0});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4, 2}));
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({1,  0,  3,  4,  5,  6,  9,  8,  11, 12, 13,
+                                14, 15, 0,  0,  0,  17, 18, 19, 20, 21, 0,
+                                23, 24, 27, 28, 29, 0,  31, 32, 25, 26}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({2, 2, 4, 2}));
+  EXPECT_THAT(model.GetIndices(),
+              ElementsAreArray({2,  1,  8,  9,  12, 15, 44, 43, 72, 75, 80,
+                                79, 62, 61, 66, 67, 0,  1,  30, 7,  14, 13,
+                                42, 21, 50, 51, 56, 55, 86, 63, 68, 69}));
+}
+
+TEST(MaxpoolWithArgMaxTest, InputWithBatchAndPaddingValidTest) {
+  MaxpoolingWithArgMaxOpModel model(
+      /*input=*/{TensorType_FLOAT32, {2, 4, 11, 2}},
+      /*stride_height=*/2, /*stride_width=*/3,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingValid,
+      /*output=*/{TensorType_FLOAT32, {}},
+      /*indices=*/{TensorType_INT32, {}});
+
+  model.SetInput({0,  0,  1,  0, 0, 0, 0,  0,  3,  4,  0,  0,  5,  0,  0,  6,
+                  0,  0,  0,  0, 0, 0, 0,  2,  0,  0,  0,  0,  0,  0,  0,  0,
+                  0,  0,  0,  0, 0, 0, 0,  0,  7,  0,  0,  8,  9,  0,  0,  10,
+                  0,  0,  0,  0, 0, 0, 0,  0,  0,  0,  0,  0,  0,  0,  15, 0,
+                  0,  16, 0,  0, 0, 0, 0,  0,  11, 0,  0,  12, 0,  0,  0,  14,
+                  13, 0,  0,  0, 0, 0, 0,  0,  17, 18, 0,  0,  0,  30, 0,  20,
+                  0,  0,  0,  0, 0, 0, 21, 0,  0,  0,  0,  0,  0,  24, 0,  0,
+                  0,  0,  0,  0, 0, 0, 19, 0,  0,  0,  0,  22, 0,  0,  0,  0,
+                  0,  0,  23, 0, 0, 0, 0,  0,  0,  0,  27, 28, 0,  0,  0,  0,
+                  29, 0,  0,  0, 0, 0, 0,  32, 0,  0,  0,  0,  25, 26, 0,  0,
+                  0,  0,  0,  0, 0, 0, 0,  0,  0,  0,  0,  0,  0,  0,  31, 0});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4, 2}));
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
+                                23, 24, 25, 26, 27, 28, 29, 0,  31, 32}));
+  EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({2, 2, 4, 2}));
+  EXPECT_THAT(model.GetIndices(),
+              ElementsAreArray({2,  23, 8,  9,  12, 15, 40, 43, 44, 47, 72,
+                                75, 80, 79, 62, 65, 0,  1,  30, 7,  14, 35,
+                                42, 21, 68, 69, 50, 51, 56, 57, 86, 63}));
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/max_unpooling_2d.cc b/tensorflow/lite/kernels/perception/max_unpooling_2d.cc
new file mode 100644
index 0000000..ce51b14
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/max_unpooling_2d.cc
@@ -0,0 +1,132 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace max_unpooling_2d {
+
+constexpr int kDataInputTensor = 0;
+constexpr int kIndicesTensor = 1;
+constexpr int kOutputTensor = 0;
+
+// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
+// this op to a builtin op.
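+//
+// Scatters each element of `input_data` into `output_data` at the flattened
+// per-batch position given by the corresponding entry of `indices_data`;
+// every other output element is left as zero.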
+inline void MaxUnpooling(const RuntimeShape& input_shape,
+                         const float* input_data, const int32_t* indices_data,
+                         const RuntimeShape& output_shape, float* output_data) {
+  std::memset(output_data, 0, output_shape.FlatSize() * sizeof(float));
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+  const int batch_stride =
+      output_shape.Dims(1) * output_shape.Dims(2) * output_shape.Dims(3);
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_shape.Dims(1); ++in_y) {
+      for (int in_x = 0; in_x < input_shape.Dims(2); ++in_x) {
+        for (int channel = 0; channel < depth; ++channel) {
+          const auto input_offset =
+              Offset(input_shape, batch, in_y, in_x, channel);
+          int idx = indices_data[input_offset];
+          output_data[batch * batch_stride + idx] = input_data[input_offset];
+        }
+      }
+    }
+  }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* params =
+      reinterpret_cast<const TfLitePoolParams*>(node->custom_initial_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+  TF_LITE_ENSURE(context, indices != nullptr);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 4);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
+  TF_LITE_ENSURE(context, params->padding != kTfLitePaddingUnknown);
+
+  // Size of input and indices tensor must match.
+  const RuntimeShape input_shape = GetTensorShape(input);
+  const RuntimeShape indices_shape = GetTensorShape(indices);
+  TF_LITE_ENSURE_MSG(
+      context, input_shape.DimensionsCount() == indices_shape.DimensionsCount(),
+      "Input and indices must have the same shape.");
+  for (int i = 0; i < input_shape.DimensionsCount(); ++i) {
+    TF_LITE_ENSURE_MSG(context, input_shape.Dims(i) == indices_shape.Dims(i),
+                       "Input and indices must have the same shape.");
+  }
+
+  int batches = input->dims->data[0];
+  int height = input->dims->data[1];
+  int width = input->dims->data[2];
+  int channels_out = input->dims->data[3];
+
+  int out_width, out_height;
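+  // Recover the spatial size of the tensor that was originally pooled: with
+  // SAME padding the pooled size was ceil(size / stride), so unpooling gives
+  // size * stride; with VALID padding it gives (size - 1) * stride + filter.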
+  if (params->padding == kTfLitePaddingSame) {
+    out_width = width * params->stride_width;
+    out_height = height * params->stride_height;
+  } else {
+    out_width = (width - 1) * params->stride_width + params->filter_width;
+    out_height = (height - 1) * params->stride_height + params->filter_height;
+  }
+
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+  output_size->data[0] = batches;
+  output_size->data[1] = out_height;
+  output_size->data[2] = out_width;
+  output_size->data[3] = channels_out;
+  return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+  TF_LITE_ENSURE(context, indices != nullptr);
+
+  MaxUnpooling(GetTensorShape(input), GetTensorData<float>(input),
+               GetTensorData<int32_t>(indices), GetTensorShape(output),
+               GetTensorData<float>(output));
+  return kTfLiteOk;
+}
+
+}  // namespace max_unpooling_2d
+
+TfLiteRegistration* RegisterMaxUnpooling2D() {
+  static TfLiteRegistration reg = {/*init=*/nullptr,
+                                   /*free=*/nullptr, max_unpooling_2d::Prepare,
+                                   max_unpooling_2d::Eval};
+  return &reg;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/max_unpooling_2d_test.cc b/tensorflow/lite/kernels/perception/max_unpooling_2d_test.cc
new file mode 100644
index 0000000..d052164
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/max_unpooling_2d_test.cc
@@ -0,0 +1,258 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <vector>
+
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/perception/perception_ops.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/testing/util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace {
+
+using testing::ElementsAreArray;
+
+class MaxUnpoolingOpModel : public SingleOpModel {
+ public:
+  MaxUnpoolingOpModel(const TensorData& input, const TensorData& indices,
+                      int stride_height, int stride_width, int filter_height,
+                      int filter_width, TfLitePadding padding,
+                      const TensorData& output) {
+    input_ = AddInput(input);
+    indices_ = AddInput(indices);
+    output_ = AddOutput(output);
+
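+    // The MaxUnpooling2D kernel reads its TfLitePoolParams directly from
+    // node->custom_initial_data, so the raw bytes of the struct are passed
+    // as the custom options.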
+    TfLitePoolParams params{padding,      stride_width,  stride_height,
+                            filter_width, filter_height, kTfLiteActNone};
+    uint8_t* params_ptr = reinterpret_cast<uint8_t*>(&params);
+    std::vector<uint8_t> custom_option;
+    custom_option.assign(params_ptr, params_ptr + sizeof(TfLitePoolParams));
+
+    SetCustomOp("MaxUnpooling2D", custom_option, RegisterMaxUnpooling2D);
+    BuildInterpreter({GetShape(input_), GetShape(indices_)});
+  }
+
+  void SetInput(const std::vector<float>& data) {
+    PopulateTensor(input_, data);
+  }
+  void SetIndices(const std::vector<int32_t>& data) {
+    PopulateTensor(indices_, data);
+  }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+  int input_;
+  int indices_;
+  int output_;
+};
+
+TEST(MaxUnpoolingOpTest, DimensionMisMatchTest) {
+  EXPECT_DEATH(MaxUnpoolingOpModel model(
+                   /*input=*/{TensorType_FLOAT32, {1, 1, 2, 1}},
+                   /*indices=*/{TensorType_INT32, {1, 2, 2, 1}},
+                   /*stride_height=*/2, /*stride_width=*/2,
+                   /*filter_height=*/2, /*filter_width=*/2,
+                   /*padding=*/kTfLitePaddingSame,
+                   /*output=*/{TensorType_FLOAT32, {}}),
+               "Input and indices must have the same shape.");
+}
+
+TEST(MaxUnpoolingOpTest, SimpleTest) {
+  MaxUnpoolingOpModel model(
+      /*input=*/{TensorType_FLOAT32, {1, 1, 2, 1}},
+      /*indices=*/{TensorType_INT32, {1, 1, 2, 1}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}});
+  model.SetInput({13, 4});
+  model.SetIndices({1, 6});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 13, 0, 0, 0, 0, 4, 0}));
+}
+
+TEST(MaxUnpoolingOpTest, Strides2x1Test) {
+  constexpr int kInputB = 1;
+  constexpr int kInputH = 2;
+  constexpr int kInputW = 2;
+  constexpr int kInputC = 2;
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> indices_data{0, 3, 4, 7, 8, 11, 12, 15};
+
+  MaxUnpoolingOpModel model(
+      /*input=*/{TensorType_FLOAT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*indices=*/{TensorType_INT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*stride_height=*/2, /*stride_width=*/1,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}});
+
+  model.SetInput(input_data);
+  model.SetIndices(indices_data);
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 2, 2}));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 2, 3, 0, 0, 4, 5, 0,
+                                                   0, 6, 7, 0, 0, 8}));
+}
+
+TEST(MaxUnpoolingOpTest, Strides2x2Test) {
+  constexpr int kInputB = 1;
+  constexpr int kInputH = 2;
+  constexpr int kInputW = 4;
+  constexpr int kInputC = 1;
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> indices_data{0, 5, 10, 13, 19, 20, 27, 31};
+
+  MaxUnpoolingOpModel model(
+      /*input=*/{TensorType_FLOAT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*indices=*/{TensorType_INT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}});
+
+  model.SetInput(input_data);
+  model.SetIndices(indices_data);
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 8, 1}));
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0,
+                        0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 8}));
+}
+
+TEST(MaxUnpoolingOpTest, PaddingValidTest) {
+  constexpr int kInputB = 1;
+  constexpr int kInputH = 2;
+  constexpr int kInputW = 2;
+  constexpr int kInputC = 1;
+  std::vector<float> input_data{7, 10, 20, 19};
+  std::vector<int32_t> indices_data{6, 9, 16, 19};
+
+  MaxUnpoolingOpModel model(
+      /*input=*/{TensorType_FLOAT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*indices=*/{TensorType_INT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*stride_height=*/2, /*stride_width=*/2,
+      /*filter_height=*/2, /*filter_width=*/3,
+      /*padding=*/kTfLitePaddingValid,
+      /*output=*/{TensorType_FLOAT32, {}});
+
+  model.SetInput(input_data);
+  model.SetIndices(indices_data);
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 5, 1}));
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({0, 0, 0, 0, 0, 0, 7,  0, 0, 10,
+                                0, 0, 0, 0, 0, 0, 20, 0, 0, 19}));
+}
+
+TEST(MaxUnpoolingOpTest, InputWithBatchTest) {
+  constexpr int kInputB = 2;
+  constexpr int kInputH = 2;
+  constexpr int kInputW = 4;
+  constexpr int kInputC = 2;
+  std::vector<float> input_data{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
+                                23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
+  std::vector<int32_t> indices_data{2,  23, 8,  9,  12, 15, 40, 43, 44, 47, 72,
+                                    75, 80, 79, 62, 65, 0,  1,  30, 7,  14, 35,
+                                    42, 21, 68, 69, 50, 51, 56, 5,  86, 63};
+
+  MaxUnpoolingOpModel model(
+      /*input=*/{TensorType_FLOAT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*indices=*/{TensorType_INT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*stride_height=*/2, /*stride_width=*/3,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingSame,
+      /*output=*/{TensorType_FLOAT32, {}});
+
+  model.SetInput(input_data);
+  model.SetIndices(indices_data);
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 12, 2}));
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {0,  0, 1,  0,  0, 0,  0,  0,  3,  4,  0, 0,  5,  0,  0, 6,  0, 0,
+           0,  0, 0,  0,  0, 2,  0,  0,  0,  0,  0, 0,  0,  0,  0, 0,  0, 0,
+           0,  0, 0,  0,  7, 0,  0,  8,  9,  0,  0, 10, 0,  0,  0, 0,  0, 0,
+           0,  0, 0,  0,  0, 0,  0,  0,  15, 0,  0, 16, 0,  0,  0, 0,  0, 0,
+           11, 0, 0,  12, 0, 0,  0,  14, 13, 0,  0, 0,  0,  0,  0, 0,  0, 0,
+           0,  0, 0,  0,  0, 0,  17, 18, 0,  0,  0, 30, 0,  20, 0, 0,  0, 0,
+           0,  0, 21, 0,  0, 0,  0,  0,  0,  24, 0, 0,  0,  0,  0, 0,  0, 0,
+           19, 0, 0,  0,  0, 22, 0,  0,  0,  0,  0, 0,  23, 0,  0, 0,  0, 0,
+           0,  0, 27, 28, 0, 0,  0,  0,  29, 0,  0, 0,  0,  0,  0, 32, 0, 0,
+           0,  0, 25, 26, 0, 0,  0,  0,  0,  0,  0, 0,  0,  0,  0, 0,  0, 0,
+           0,  0, 31, 0,  0, 0,  0,  0,  0,  0,  0, 0}));
+}
+
+TEST(MaxUnpoolingOpTest, InputWithBatchAndPaddingValidTest) {
+  constexpr int kInputB = 2;
+  constexpr int kInputH = 2;
+  constexpr int kInputW = 4;
+  constexpr int kInputC = 2;
+  std::vector<float> input_data{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
+                                23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
+  std::vector<int32_t> indices_data{2,  23, 8,  9,  12, 15, 40, 43, 44, 47, 72,
+                                    75, 80, 79, 62, 65, 0,  1,  30, 7,  14, 35,
+                                    42, 21, 68, 69, 50, 51, 56, 5,  86, 63};
+
+  MaxUnpoolingOpModel model(
+      /*input=*/{TensorType_FLOAT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*indices=*/{TensorType_INT32, {kInputB, kInputH, kInputW, kInputC}},
+      /*stride_height=*/2, /*stride_width=*/3,
+      /*filter_height=*/2, /*filter_width=*/2,
+      /*padding=*/kTfLitePaddingValid,
+      /*output=*/{TensorType_FLOAT32, {}});
+
+  model.SetInput(input_data);
+  model.SetIndices(indices_data);
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 11, 2}));
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {0,  0,  1, 0,  0,  0,  0, 0,  3,  4, 0,  0,  5,  0,  0, 6,  0,  0,
+           0,  0,  0, 0,  0,  2,  0, 0,  0,  0, 0,  0,  0,  0,  0, 0,  0,  0,
+           0,  0,  0, 0,  7,  0,  0, 8,  9,  0, 0,  10, 0,  0,  0, 0,  0,  0,
+           0,  0,  0, 0,  0,  0,  0, 0,  15, 0, 0,  16, 0,  0,  0, 0,  0,  0,
+           11, 0,  0, 12, 0,  0,  0, 14, 13, 0, 0,  0,  0,  0,  0, 0,  17, 18,
+           0,  0,  0, 30, 0,  20, 0, 0,  0,  0, 0,  0,  21, 0,  0, 0,  0,  0,
+           0,  24, 0, 0,  0,  0,  0, 0,  0,  0, 19, 0,  0,  0,  0, 22, 0,  0,
+           0,  0,  0, 0,  23, 0,  0, 0,  0,  0, 0,  0,  27, 28, 0, 0,  0,  0,
+           29, 0,  0, 0,  0,  0,  0, 32, 0,  0, 0,  0,  25, 26, 0, 0,  0,  0,
+           0,  0,  0, 0,  0,  0,  0, 0,  0,  0, 0,  0,  31, 0}));
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/perception_ops.cc b/tensorflow/lite/kernels/perception/perception_ops.cc
new file mode 100644
index 0000000..431530d
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/perception_ops.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/perception/perception_ops.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
+  resolver->AddCustom("MaxUnpooling2D",
+                      tflite::ops::custom::RegisterMaxUnpooling2D());
+  resolver->AddCustom("MaxPoolWithArgmax",
+                      tflite::ops::custom::RegisterMaxPoolWithArgmax());
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/perception_ops.h b/tensorflow/lite/kernels/perception/perception_ops.h
new file mode 100644
index 0000000..e5544b2
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/perception_ops.h
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_KERNELS_PERCEPTION_PERCEPTION_OPS_H_
+#define TENSORFLOW_LITE_KERNELS_PERCEPTION_PERCEPTION_OPS_H_
+
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* RegisterMaxUnpooling2D();
+TfLiteRegistration* RegisterMaxPoolWithArgmax();
+
+extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_PERCEPTION_PERCEPTION_OPS_H_
diff --git a/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc b/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc
new file mode 100644
index 0000000..7fd1282
--- /dev/null
+++ b/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "tensorflow/lite/kernels/perception/perception_ops.h"
+
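+// Exposes AddPerceptionOps to Python so that the perception custom ops can
+// be registered with a MutableOpResolver owned by the Python interpreter
+// wrapper. A sketch of the intended use from Python (illustrative only; the
+// exact import path depends on how the extension target is built):
+//
+//   from pywrap_perception_ops import PerceptionOpsRegisterer
+//   # Pass PerceptionOpsRegisterer wherever a custom-op registerer is
+//   # accepted; it receives the address of a MutableOpResolver.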
+PYBIND11_MODULE(pywrap_perception_ops, m) {
+  m.doc() = R"pbdoc(
+    pywrap_perception_ops
+    -----
+  )pbdoc";
+  m.def(
+      "PerceptionOpsRegisterer",
+      [](uintptr_t resolver) {
+        tflite::ops::custom::AddPerceptionOps(
+            reinterpret_cast<tflite::MutableOpResolver*>(resolver));
+      },
+      R"pbdoc(
+        Perception op registerer function with the correct signature. Registers
+        Perception custom ops.
+      )pbdoc");
+}
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 511c7cbf..a57f358 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -130,7 +130,9 @@
   AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
              /* min_version = */ 1,
              /* max_version = */ 2);
-  AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE());
+  AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE(),
+             /* min_version = */ 1,
+             /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
              /* min_version = */ 1,
              /* max_version = */ 4);
@@ -299,7 +301,7 @@
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
   AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
              /* min_version = */ 1,
-             /* max_version = */ 3);
+             /* max_version = */ 4);
   AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
   // The version one of broadcast to op won't be not supported since the version
   // one was rollbacked and the builtin op code number has been changed because
diff --git a/tensorflow/lite/kernels/resize_nearest_neighbor.cc b/tensorflow/lite/kernels/resize_nearest_neighbor.cc
index bef3955..85e833c 100644
--- a/tensorflow/lite/kernels/resize_nearest_neighbor.cc
+++ b/tensorflow/lite/kernels/resize_nearest_neighbor.cc
@@ -68,7 +68,7 @@
   TF_LITE_ENSURE_OK(context,
                     GetOutputSafe(context, node, kOutputTensor, &output));
 
-  // TODO(ahentz): Our current implementations rely on the input being 4D,
+  // Our current implementations rely on the input being 4D,
   // and the size being 1D tensor with exactly 2 elements.
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
   TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc
index 3f2fd58..2099109 100644
--- a/tensorflow/lite/kernels/strided_slice.cc
+++ b/tensorflow/lite/kernels/strided_slice.cc
@@ -37,7 +37,7 @@
 
 enum KernelType {
   kReference,
-  // TODO(soroosh): add kGenericOptimized
+  // TODO(b/175642009): add kGenericOptimized
 };
 
 constexpr int kInputTensor = 0;
@@ -154,7 +154,7 @@
   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
   // Only INT32 begin/end/strides are supported
-  // TODO(soroosh) add support for INT64
+  // TODO(b/175642009): add support for INT64
   TF_LITE_ENSURE_TYPES_EQ(context, op_context.begin->type, kTfLiteInt32);
   TF_LITE_ENSURE_TYPES_EQ(context, op_context.end->type, kTfLiteInt32);
   TF_LITE_ENSURE_TYPES_EQ(context, op_context.strides->type, kTfLiteInt32);
@@ -256,7 +256,6 @@
   return &r;
 }
 
-// TODO(soroosh): add optimized
 TfLiteRegistration* Register_STRIDED_SLICE() {
   return Register_STRIDED_SLICE_REF();
 }
diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc
index 6cf3e89..3f65b70 100644
--- a/tensorflow/lite/kernels/subgraph_test_util.cc
+++ b/tensorflow/lite/kernels/subgraph_test_util.cc
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include <gtest/gtest.h>
+#include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/subgraph.h"
@@ -113,10 +114,11 @@
   TfLiteAddParams* params =
       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   params->activation = kTfLiteActNone;
+  auto* add_reg = ops::builtin::Register_ADD();
+  add_reg->builtin_code = kTfLiteBuiltinAdd;
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_ADD(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
+                                  params, add_reg, &node_index);
 }
 
 // Build a subgraph with an mul op. Helper function for testing.
@@ -143,10 +145,11 @@
   TfLiteMulParams* params =
       reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
   params->activation = kTfLiteActNone;
+  auto* mul_reg = ops::builtin::Register_MUL();
+  mul_reg->builtin_code = kTfLiteBuiltinMul;
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_MUL(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
+                                  params, mul_reg, &node_index);
 }
 
 // Build a subgraph with a pad op. Helper function for testing.
@@ -172,10 +175,11 @@
 
   TfLitePadParams* params =
       reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
+  auto* pad_reg = ops::builtin::Register_PAD();
+  pad_reg->builtin_code = kTfLiteBuiltinPad;
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_PAD(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
+                                  params, pad_reg, &node_index);
 }
 
 void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
@@ -205,11 +209,12 @@
       reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams)));
   params->then_subgraph_index = 1;
   params->else_subgraph_index = 2;
+  auto* if_reg = ops::builtin::Register_IF();
+  if_reg->builtin_code = kTfLiteBuiltinIf;
 
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_IF(), &node_index);
+  subgraph->AddNodeWithParameters({kCondInput, kInput1, kInput2}, {kOutput}, {},
+                                  nullptr, 0, params, if_reg, &node_index);
 }
 
 void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
@@ -236,11 +241,13 @@
   SetupTensor(subgraph, kInput2, kTfLiteInt32);
   SetupTensor(subgraph, kOutput, kTfLiteBool);
 
+  auto* le_reg = ops::builtin::Register_LESS_EQUAL();
+  le_reg->builtin_code = kTfLiteBuiltinLessEqual;
+
   CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs});
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kConstRhs}, {kOutput}, {}, nullptr, 0, nullptr,
-      ::tflite::ops::builtin::Register_LESS_EQUAL(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kConstRhs}, {kOutput}, {}, nullptr,
+                                  0, nullptr, le_reg, &node_index);
 }
 
 void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
@@ -277,13 +284,13 @@
   TfLiteAddParams* params =
       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   params->activation = kTfLiteActNone;
-  subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params,
-                                  ::tflite::ops::builtin::Register_ADD(),
+  auto* add_reg = ops::builtin::Register_ADD();
+  add_reg->builtin_code = kTfLiteBuiltinAdd;
+  subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params, add_reg,
                                   &node_index);
   params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   params->activation = kTfLiteActNone;
-  subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params,
-                                  ::tflite::ops::builtin::Register_ADD(),
+  subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params, add_reg,
                                   &node_index);
 }
 
@@ -327,14 +334,18 @@
   TfLiteAddParams* add_params =
       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   add_params->activation = kTfLiteActNone;
-  subgraph->AddNodeWithParameters(
-      {kInputCounter, kConstStep}, {kOutputCounter}, {}, nullptr, 0, add_params,
-      ::tflite::ops::builtin::Register_ADD(), &node_index);
+  auto* add_reg = ops::builtin::Register_ADD();
+  add_reg->builtin_code = kTfLiteBuiltinAdd;
+  subgraph->AddNodeWithParameters({kInputCounter, kConstStep}, {kOutputCounter},
+                                  {}, nullptr, 0, add_params, add_reg,
+                                  &node_index);
   TfLitePadParams* pad_params =
       reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLiteAddParams)));
-  subgraph->AddNodeWithParameters(
-      {kInputValue, kConstPadding}, {kOutputValue}, {}, nullptr, 0, pad_params,
-      ::tflite::ops::builtin::Register_PAD(), &node_index);
+  auto* pad_reg = ops::builtin::Register_PAD();
+  pad_reg->builtin_code = kTfLiteBuiltinPad;
+  subgraph->AddNodeWithParameters({kInputValue, kConstPadding}, {kOutputValue},
+                                  {}, nullptr, 0, pad_params, pad_reg,
+                                  &node_index);
 }
 
 void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
@@ -364,11 +375,12 @@
       reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams)));
   params->cond_subgraph_index = 1;
   params->body_subgraph_index = 2;
+  auto* while_reg = ops::builtin::Register_WHILE();
+  while_reg->builtin_code = kTfLiteBuiltinWhile;
 
   int node_index;
   subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params,
-                                  ::tflite::ops::builtin::Register_WHILE(),
-                                  &node_index);
+                                  while_reg, &node_index);
 }
 
 void SubgraphBuilder::BuildAssignRandomValueToVariableSubgraph(
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 9cd272f..7cc986a 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -225,8 +225,8 @@
         t.shape, t.traversal_order, t.format, t.block_size, t.block_map);
     converter.DenseToSparse(dense_data.data());
 
-    const auto dim_metadata = converter.GetDimMetadata();
-    const auto sparse_data = converter.GetData();
+    const auto& dim_metadata = converter.GetDimMetadata();
+    const auto& sparse_data = converter.GetData();
 
     // Build sparsity parameter.
     std::vector<flatbuffers::Offset<DimensionMetadata>> fb_dim_metadata(
diff --git a/tensorflow/lite/micro/CONTRIBUTING.md b/tensorflow/lite/micro/CONTRIBUTING.md
index 78360b9..df76170 100644
--- a/tensorflow/lite/micro/CONTRIBUTING.md
+++ b/tensorflow/lite/micro/CONTRIBUTING.md
@@ -18,8 +18,9 @@
     *   [During the PR review](#during-the-pr-review)
     *   [Reviewer notes](#reviewer-notes)
     *   [Python notes](#python-notes)
+*   [Continuous Integration System](#continuous-integration-system)
 
-<!-- Added by: advaitjain, at: Mon 05 Oct 2020 02:38:02 PM PDT -->
+<!-- Added by: advaitjain, at: Tue 15 Dec 2020 03:06:29 PM PST -->
 
 <!--te-->
 
@@ -305,3 +306,32 @@
     ```
     yapf log_parser.py -i --style='{based_on_style: pep8, indent_width: 2}'
     ```
+
+# Continuous Integration System
+
+*   As a contributor, please make sure that the TfLite Micro build is green.
+    You can click on the details link to see what the errors are:
+
+    [![TfLite Micro Build](docs/images/tflm_continuous_integration_1.png)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/tflite-micro.html)
+
+*   Most of the tests that are run as part of the CI use the
+    [micro/tools/ci_build/test_all.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/tools/ci_build/test_all.sh)
+    script.
+
+    *   There are a few additional tests that use bazel and are not
+        captured by the `test_all.sh` script.
+
+*   If an error is not reproducible on your development machine, you can
+    recreate the docker container that is used on the CI servers with the
+    following commands (run from the root of the tensorflow github repo):
+
+    ```
+    mkdir /tmp/tflm-docker
+    docker build -f tensorflow/tools/ci_build/Dockerfile.micro -t tflm /tmp/tflm-docker
+    docker run -v `pwd`:/tensorflow -it tflm bash
+    ```
+
+    The `docker run` command mounts the tensorflow repository on your
+    machine into the docker container. As a result, any changes made within
+    the docker container are also reflected in the corresponding directory on
+    the host machine.
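+
+    Inside the container you can then run the same CI entry point mentioned
+    above, for example (an illustrative invocation; the repository is mounted
+    at `/tensorflow` by the `docker run` command above):
+
+    ```
+    cd /tensorflow
+    tensorflow/lite/micro/tools/ci_build/test_all.sh
+    ```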
diff --git a/tensorflow/lite/micro/all_ops_resolver.cc b/tensorflow/lite/micro/all_ops_resolver.cc
index f7bbcb9..8b87de8 100644
--- a/tensorflow/lite/micro/all_ops_resolver.cc
+++ b/tensorflow/lite/micro/all_ops_resolver.cc
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/all_ops_resolver.h b/tensorflow/lite/micro/all_ops_resolver.h
index e8105b9..391b4f0 100644
--- a/tensorflow/lite/micro/all_ops_resolver.h
+++ b/tensorflow/lite/micro/all_ops_resolver.h
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/apollo3evb/debug_log.cc b/tensorflow/lite/micro/apollo3evb/debug_log.cc
index 1523d4b..ea33a8e 100644
--- a/tensorflow/lite/micro/apollo3evb/debug_log.cc
+++ b/tensorflow/lite/micro/apollo3evb/debug_log.cc
@@ -37,8 +37,8 @@
 #include "tensorflow/lite/micro/debug_log.h"
 
 // These are headers from Ambiq's Apollo3 SDK.
-#include "am_bsp.h"         // NOLINT
-#include "am_util.h"        // NOLINT
+#include "am_bsp.h"   // NOLINT
+#include "am_util.h"  // NOLINT
 
 extern "C" void DebugLog(const char* s) {
 #ifndef TF_LITE_STRIP_ERROR_STRINGS
diff --git a/tensorflow/lite/micro/bluepill/debug_log.cc b/tensorflow/lite/micro/bluepill/debug_log.cc
index dd8a3b3..3fd2d52 100644
--- a/tensorflow/lite/micro/bluepill/debug_log.cc
+++ b/tensorflow/lite/micro/bluepill/debug_log.cc
@@ -22,6 +22,6 @@
       "mov r1, %[str]\n"
       "bkpt #0xAB\n"
       :
-      : [ str ] "r"(s)
+      : [str] "r"(s)
       : "r0", "r1");
 }
diff --git a/tensorflow/lite/micro/docs/images/tflm_continuous_integration_1.png b/tensorflow/lite/micro/docs/images/tflm_continuous_integration_1.png
new file mode 100644
index 0000000..acecc0e
--- /dev/null
+++ b/tensorflow/lite/micro/docs/images/tflm_continuous_integration_1.png
Binary files differ
diff --git a/tensorflow/lite/micro/docs/renode.md b/tensorflow/lite/micro/docs/renode.md
index 6e411bd..7132091 100644
--- a/tensorflow/lite/micro/docs/renode.md
+++ b/tensorflow/lite/micro/docs/renode.md
@@ -32,19 +32,14 @@
 [here](https://renode.readthedocs.io/en/latest/). For the purpose of Tensorflow
 Lite Micro, we make use of a portable version for Linux.
 
- 1. Download portable version of Renode for Linux:
+Portable Renode will be automatically installed to
+`tensorflow/lite/micro/tools/make/downloads/renode` when using the TfLite Micro
+Makefile.
 
-    ```
-    tensorflow/lite/micro/testing/download_renode.sh tensorflow/lite/micro/tools/make/downloads/renode
-    ```
+The Makefile internally calls the `renode_download.sh` script:
 
- 2. Install the Renode test dependencies
-
-    ```
-    pip3 install -r tensorflow/lite/micro/tools/make/downloads/renode/tests/requirements.txt
-    ```
-
-At this point in time you will be ready to run TFLM tests with Renode.
+```
+tensorflow/lite/micro/testing/renode_download.sh tensorflow/lite/micro/tools/make/downloads
+```
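+
+For example (a sketch; the exact Make invocation may differ), running the
+Bluepill tests through the Makefile triggers this download on first use:
+
+```
+make -f tensorflow/lite/micro/tools/make/Makefile TARGET=bluepill test
+```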
 
 # Running Unit Tests
 
@@ -56,6 +51,7 @@
 
  * This makes use of the robot framework from Renode.
  * Note that the tests can currently not be run in parallel.
+ * It takes about 25 seconds to complete all tests, including around 3 seconds for suite startup/teardown and an average of 0.38 seconds per test.
 
 ## Under the hood of the Testing Infrastructure
 
@@ -74,9 +70,32 @@
 
 # Running a non-test Binary with Renode
 
-It may be useful to run binaries on Renode that are not tests, independent of
-the robot framework. We will be adding some documentation for that in this
-section.
+Renode can also be used to run and debug binaries interactively. For example,
+to debug `kernel_add_test` on the Bluepill platform, run Renode:
+
+```
+tensorflow/lite/micro/tools/make/downloads/renode/renode
+```
+and issue the following commands:
+```
+# Create platform
+include @tensorflow/lite/micro/testing/bluepill_nontest.resc
+# Load ELF file
+sysbus LoadELF @tensorflow/lite/micro/tools/make/gen/bluepill_cortex-m3/bin/kernel_add_test
+# Start simulation
+start
+```
+You can also connect GDB to the simulation.
+To do that, start the GDB server in Renode before issuing the `start` command:
+```
+machine StartGdbServer 3333
+```
+Then you can connect from GDB with:
+```
+target remote localhost:3333
+```
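+
+For source-level debugging, a sketch of the same flow with an ARM-capable GDB
+(the `arm-none-eabi-gdb` binary name is an assumption; use whichever GDB your
+toolchain provides) is:
+```
+arm-none-eabi-gdb tensorflow/lite/micro/tools/make/gen/bluepill_cortex-m3/bin/kernel_add_test
+(gdb) target remote localhost:3333
+```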
+
+For further reference, please see the [Renode documentation](https://renode.readthedocs.io/en/latest/).
 
 # Useful External Links for Renode and Robot Documentation
 
@@ -102,4 +121,3 @@
 
        * [Remove File](http://robotframework.org/robotframework/latest/libraries/OperatingSystem.html#Remove%20File)
        * [List Files In Directory](https://robotframework.org/robotframework/latest/libraries/OperatingSystem.html#List%20Files%20In%20Directory)
-
diff --git a/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf b/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf
index f4d8a9f..4533ed5 100644
--- a/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf
+++ b/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf
@@ -1,10 +1,10 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#   http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/src/assert.cc b/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/src/assert.cc
index 2141c09..2f709c6 100644
--- a/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/src/assert.cc
+++ b/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/src/assert.cc
@@ -15,5 +15,5 @@
 
 extern "C" {
 
-void __assert_func(const char *, int, const char *, const char *) {}
+void __assert_func(const char*, int, const char*, const char*) {}
 }
diff --git a/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/prj.conf b/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/prj.conf
index e415208..449a721 100644
--- a/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/prj.conf
+++ b/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/prj.conf
@@ -1,10 +1,10 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#    http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/src/assert.cc b/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/src/assert.cc
index 2141c09..2f709c6 100644
--- a/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/src/assert.cc
+++ b/tensorflow/lite/micro/examples/magic_wand/zephyr_riscv/src/assert.cc
@@ -15,5 +15,5 @@
 
 extern "C" {
 
-void __assert_func(const char *, int, const char *, const char *) {}
+void __assert_func(const char*, int, const char*, const char*) {}
 }
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/_main.c b/tensorflow/lite/micro/examples/micro_speech/apollo3/_main.c
index b49d5c5..5ea6ac1 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/_main.c
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/_main.c
@@ -14,6 +14,7 @@
 ==============================================================================*/
 
 #include <stdint.h>
+
 #include "am_bsp.h"
 #include "am_mcu_apollo.h"  // Defines AM_CMSIS_REGS
 #include "am_util.h"
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_cmsis_test.cmd b/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_cmsis_test.cmd
index 6988057..47b78e5 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_cmsis_test.cmd
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_cmsis_test.cmd
@@ -1,3 +1,11 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_micro_test.cmd b/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_micro_test.cmd
index dc9cd4f..2b4cda8 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_micro_test.cmd
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_1k_micro_test.cmd
@@ -1,3 +1,11 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_test.cmd b/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_test.cmd
index bd2048e..2f270d2 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_test.cmd
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/preprocessor_test.cmd
@@ -1,3 +1,11 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_scores.cmd b/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_scores.cmd
index ace278f..bf8521b 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_scores.cmd
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_scores.cmd
@@ -1,3 +1,11 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_voice.cmd b/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_voice.cmd
index 5dea48e..a77a96a 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_voice.cmd
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_cmsis_voice.cmd
@@ -1,11 +1,11 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-# 
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_main.c b/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_main.c
index 4f70d47..74f2201 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_main.c
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/pushbutton_main.c
@@ -16,6 +16,7 @@
 /* This file is a modification of the Tensorflow Micro Lite file _main.c */
 
 #include <stdint.h>
+
 #include "am_bsp.h"
 #include "am_mcu_apollo.h"  // Defines AM_CMSIS_REGS
 #include "am_util.h"
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3evb/micro_speech.cmd b/tensorflow/lite/micro/examples/micro_speech/apollo3evb/micro_speech.cmd
index 46d8dfa..a9d235d 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3evb/micro_speech.cmd
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3evb/micro_speech.cmd
@@ -1,3 +1,11 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/emsdp.lcf b/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/emsdp.lcf
index ae17db1..4b252ed 100644
--- a/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/emsdp.lcf
+++ b/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/emsdp.lcf
@@ -1,8 +1,11 @@
 # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/ceva/main_functions.cc b/tensorflow/lite/micro/examples/micro_speech/ceva/main_functions.cc
index db19645..f4af2d8 100644
--- a/tensorflow/lite/micro/examples/micro_speech/ceva/main_functions.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/ceva/main_functions.cc
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/micro_speech/disco_f746ng/command_responder.cc b/tensorflow/lite/micro/examples/micro_speech/disco_f746ng/command_responder.cc
index 1489d76..e5f962f 100644
--- a/tensorflow/lite/micro/examples/micro_speech/disco_f746ng/command_responder.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/disco_f746ng/command_responder.cc
@@ -21,24 +21,24 @@
 
 // When a command is detected, write it to the display and log it to the
 // serial port.
-void RespondToCommand(tflite::ErrorReporter *error_reporter,
-                      int32_t current_time, const char *found_command,
+void RespondToCommand(tflite::ErrorReporter* error_reporter,
+                      int32_t current_time, const char* found_command,
                       uint8_t score, bool is_new_command) {
   if (is_new_command) {
     TF_LITE_REPORT_ERROR(error_reporter, "Heard %s (%d) @%dms", found_command,
                          score, current_time);
     if (*found_command == 'y') {
       lcd.Clear(0xFF0F9D58);
-      lcd.DisplayStringAt(0, LINE(5), (uint8_t *)"Heard yes!", CENTER_MODE);
+      lcd.DisplayStringAt(0, LINE(5), (uint8_t*)"Heard yes!", CENTER_MODE);
     } else if (*found_command == 'n') {
       lcd.Clear(0xFFDB4437);
-      lcd.DisplayStringAt(0, LINE(5), (uint8_t *)"Heard no :(", CENTER_MODE);
+      lcd.DisplayStringAt(0, LINE(5), (uint8_t*)"Heard no :(", CENTER_MODE);
     } else if (*found_command == 'u') {
       lcd.Clear(0xFFF4B400);
-      lcd.DisplayStringAt(0, LINE(5), (uint8_t *)"Heard unknown", CENTER_MODE);
+      lcd.DisplayStringAt(0, LINE(5), (uint8_t*)"Heard unknown", CENTER_MODE);
     } else {
       lcd.Clear(0xFF4285F4);
-      lcd.DisplayStringAt(0, LINE(5), (uint8_t *)"Heard silence", CENTER_MODE);
+      lcd.DisplayStringAt(0, LINE(5), (uint8_t*)"Heard silence", CENTER_MODE);
     }
   }
 }
diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc b/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc
index 3596246..b2bb18b 100644
--- a/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc
@@ -35,9 +35,9 @@
 
 using namespace std;
 
-static const char *TAG = "TF_LITE_AUDIO_PROVIDER";
+static const char* TAG = "TF_LITE_AUDIO_PROVIDER";
 /* ringbuffer to hold the incoming audio data */
-ringbuf_t *g_audio_capture_buffer;
+ringbuf_t* g_audio_capture_buffer;
 volatile int32_t g_latest_audio_timestamp = 0;
 /* model requires 20ms new data from g_audio_capture_buffer and 10ms old data
  * each time , storing old data in the histrory buffer , {
@@ -96,13 +96,13 @@
   }
 }
 
-static void CaptureSamples(void *arg) {
+static void CaptureSamples(void* arg) {
   size_t bytes_read;
   uint8_t i2s_read_buffer[i2s_bytes_to_read] = {};
   i2s_init();
   while (1) {
     /* read 100ms data at once from i2s */
-    i2s_read((i2s_port_t)1, (void *)i2s_read_buffer, i2s_bytes_to_read,
+    i2s_read((i2s_port_t)1, (void*)i2s_read_buffer, i2s_bytes_to_read,
              &bytes_read, 10);
     if (bytes_read <= 0) {
       ESP_LOGE(TAG, "Error in I2S read : %d", bytes_read);
@@ -112,7 +112,7 @@
       }
       /* write bytes read by i2s into ring buffer */
       int bytes_written = rb_write(g_audio_capture_buffer,
-                                   (uint8_t *)i2s_read_buffer, bytes_read, 10);
+                                   (uint8_t*)i2s_read_buffer, bytes_read, 10);
       /* update the timestamp (in ms) to let the model know that new data has
        * arrived */
       g_latest_audio_timestamp +=
@@ -127,7 +127,7 @@
   vTaskDelete(NULL);
 }
 
-TfLiteStatus InitAudioRecording(tflite::ErrorReporter *error_reporter) {
+TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter) {
   g_audio_capture_buffer = rb_init("tf_ringbuffer", kAudioCaptureBufferSize);
   if (!g_audio_capture_buffer) {
     ESP_LOGE(TAG, "Error creating ring buffer");
@@ -142,9 +142,9 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus GetAudioSamples(tflite::ErrorReporter *error_reporter,
+TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
                              int start_ms, int duration_ms,
-                             int *audio_samples_size, int16_t **audio_samples) {
+                             int* audio_samples_size, int16_t** audio_samples) {
   if (!g_is_audio_initialized) {
     TfLiteStatus init_status = InitAudioRecording(error_reporter);
     if (init_status != kTfLiteOk) {
@@ -153,14 +153,14 @@
     g_is_audio_initialized = true;
   }
   /* copy 160 samples (320 bytes) into output_buff from history */
-  memcpy((void *)(g_audio_output_buffer), (void *)(g_history_buffer),
+  memcpy((void*)(g_audio_output_buffer), (void*)(g_history_buffer),
          history_samples_to_keep * sizeof(int16_t));
 
   /* copy 320 samples (640 bytes) from rb at ( int16_t*(g_audio_output_buffer) +
    * 160 ), first 160 samples (320 bytes) will be from history */
   int32_t bytes_read =
       rb_read(g_audio_capture_buffer,
-              ((uint8_t *)(g_audio_output_buffer + history_samples_to_keep)),
+              ((uint8_t*)(g_audio_output_buffer + history_samples_to_keep)),
               new_samples_to_get * sizeof(int16_t), 10);
   if (bytes_read < 0) {
     ESP_LOGE(TAG, " Model Could not read data from Ring Buffer");
@@ -173,8 +173,8 @@
   }
 
   /* copy 320 bytes from output_buff into history */
-  memcpy((void *)(g_history_buffer),
-         (void *)(g_audio_output_buffer + new_samples_to_get),
+  memcpy((void*)(g_history_buffer),
+         (void*)(g_audio_output_buffer + new_samples_to_get),
          history_samples_to_keep * sizeof(int16_t));
 
   *audio_samples_size = kMaxAudioSampleSize;
diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.c b/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.c
index 6bf1585..e50abf7 100644
--- a/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.c
+++ b/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.c
@@ -32,9 +32,9 @@
 
 #define RB_TAG "RINGBUF"
 
-ringbuf_t *rb_init(const char *name, uint32_t size) {
-  ringbuf_t *r;
-  unsigned char *buf;
+ringbuf_t* rb_init(const char* name, uint32_t size) {
+  ringbuf_t* r;
+  unsigned char* buf;
 
   if (size < 2 || !name) {
     return NULL;
@@ -50,7 +50,7 @@
 #endif
   assert(buf);
 
-  r->name = (char *)name;
+  r->name = (char*)name;
   r->base = r->readptr = r->writeptr = buf;
   r->fill_cnt = 0;
   r->size = size;
@@ -70,7 +70,7 @@
   return r;
 }
 
-void rb_cleanup(ringbuf_t *rb) {
+void rb_cleanup(ringbuf_t* rb) {
   free(rb->base);
   rb->base = NULL;
   vSemaphoreDelete(rb->can_read);
@@ -85,17 +85,17 @@
 /*
  * @brief: get the number of filled bytes in the buffer
  */
-ssize_t rb_filled(ringbuf_t *rb) { return rb->fill_cnt; }
+ssize_t rb_filled(ringbuf_t* rb) { return rb->fill_cnt; }
 
 /*
  * @brief: get the number of empty bytes available in the buffer
  */
-ssize_t rb_available(ringbuf_t *rb) {
+ssize_t rb_available(ringbuf_t* rb) {
   ESP_LOGD(RB_TAG, "rb leftover %d bytes", rb->size - rb->fill_cnt);
   return (rb->size - rb->fill_cnt);
 }
 
-int rb_read(ringbuf_t *rb, uint8_t *buf, int buf_len, uint32_t ticks_to_wait) {
+int rb_read(ringbuf_t* rb, uint8_t* buf, int buf_len, uint32_t ticks_to_wait) {
   int read_size;
   int total_read_size = 0;
 
@@ -178,7 +178,7 @@
   return total_read_size;
 }
 
-int rb_write(ringbuf_t *rb, const uint8_t *buf, int buf_len,
+int rb_write(ringbuf_t* rb, const uint8_t* buf, int buf_len,
              uint32_t ticks_to_wait) {
   int write_size;
   int total_write_size = 0;
@@ -245,7 +245,7 @@
 /**
  * abort and set abort_read and abort_write to asked values.
  */
-static void _rb_reset(ringbuf_t *rb, int abort_read, int abort_write) {
+static void _rb_reset(ringbuf_t* rb, int abort_read, int abort_write) {
   if (rb == NULL) {
     return;
   }
@@ -259,9 +259,9 @@
   xSemaphoreGive(rb->lock);
 }
 
-void rb_reset(ringbuf_t *rb) { _rb_reset(rb, 0, 0); }
+void rb_reset(ringbuf_t* rb) { _rb_reset(rb, 0, 0); }
 
-void rb_abort_read(ringbuf_t *rb) {
+void rb_abort_read(ringbuf_t* rb) {
   if (rb == NULL) {
     return;
   }
@@ -270,7 +270,7 @@
   xSemaphoreGive(rb->lock);
 }
 
-void rb_abort_write(ringbuf_t *rb) {
+void rb_abort_write(ringbuf_t* rb) {
   if (rb == NULL) {
     return;
   }
@@ -279,7 +279,7 @@
   xSemaphoreGive(rb->lock);
 }
 
-void rb_abort(ringbuf_t *rb) {
+void rb_abort(ringbuf_t* rb) {
   if (rb == NULL) {
     return;
   }
@@ -296,12 +296,12 @@
  * This serves a special purpose to not allow this abort to be mixed with
  * rb_write.
  */
-void rb_reset_and_abort_write(ringbuf_t *rb) {
+void rb_reset_and_abort_write(ringbuf_t* rb) {
   _rb_reset(rb, 0, 1);
   xSemaphoreGive(rb->can_write);
 }
 
-void rb_signal_writer_finished(ringbuf_t *rb) {
+void rb_signal_writer_finished(ringbuf_t* rb) {
   if (rb == NULL) {
     return;
   }
@@ -309,14 +309,14 @@
   xSemaphoreGive(rb->can_read);
 }
 
-int rb_is_writer_finished(ringbuf_t *rb) {
+int rb_is_writer_finished(ringbuf_t* rb) {
   if (rb == NULL) {
     return RB_FAIL;
   }
   return (rb->writer_finished);
 }
 
-void rb_wakeup_reader(ringbuf_t *rb) {
+void rb_wakeup_reader(ringbuf_t* rb) {
   if (rb == NULL) {
     return;
   }
@@ -324,7 +324,7 @@
   xSemaphoreGive(rb->can_read);
 }
 
-void rb_stat(ringbuf_t *rb) {
+void rb_stat(ringbuf_t* rb) {
   xSemaphoreTake(rb->lock, portMAX_DELAY);
   ESP_LOGI(RB_TAG,
            "filled: %d, base: %p, read_ptr: %p, write_ptr: %p, size: %d\n",
diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.h b/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.h
index 191afce..98b9b3b 100644
--- a/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.h
+++ b/tensorflow/lite/micro/examples/micro_speech/esp/ringbuf.h
@@ -30,11 +30,11 @@
 #define RB_READER_UNBLOCK -3
 
 typedef struct ringbuf {
-  char *name;
-  uint8_t *base; /**< Original pointer */
+  char* name;
+  uint8_t* base; /**< Original pointer */
   /* XXX: these need to be volatile? */
-  uint8_t *volatile readptr;  /**< Read pointer */
-  uint8_t *volatile writeptr; /**< Write pointer */
+  uint8_t* volatile readptr;  /**< Read pointer */
+  uint8_t* volatile writeptr; /**< Write pointer */
   volatile ssize_t fill_cnt;  /**< Number of filled slots */
   ssize_t size;               /**< Buffer size */
   xSemaphoreHandle can_read;
@@ -46,26 +46,26 @@
   int reader_unblock;
 } ringbuf_t;
 
-ringbuf_t *rb_init(const char *rb_name, uint32_t size);
-void rb_abort_read(ringbuf_t *rb);
-void rb_abort_write(ringbuf_t *rb);
-void rb_abort(ringbuf_t *rb);
-void rb_reset(ringbuf_t *rb);
+ringbuf_t* rb_init(const char* rb_name, uint32_t size);
+void rb_abort_read(ringbuf_t* rb);
+void rb_abort_write(ringbuf_t* rb);
+void rb_abort(ringbuf_t* rb);
+void rb_reset(ringbuf_t* rb);
 /**
  * @brief Special function to reset the buffer while keeping rb_write aborted.
  *        This rb needs to be reset again before being useful.
  */
-void rb_reset_and_abort_write(ringbuf_t *rb);
-void rb_stat(ringbuf_t *rb);
-ssize_t rb_filled(ringbuf_t *rb);
-ssize_t rb_available(ringbuf_t *rb);
-int rb_read(ringbuf_t *rb, uint8_t *buf, int len, uint32_t ticks_to_wait);
-int rb_write(ringbuf_t *rb, const uint8_t *buf, int len,
+void rb_reset_and_abort_write(ringbuf_t* rb);
+void rb_stat(ringbuf_t* rb);
+ssize_t rb_filled(ringbuf_t* rb);
+ssize_t rb_available(ringbuf_t* rb);
+int rb_read(ringbuf_t* rb, uint8_t* buf, int len, uint32_t ticks_to_wait);
+int rb_write(ringbuf_t* rb, const uint8_t* buf, int len,
              uint32_t ticks_to_wait);
-void rb_cleanup(ringbuf_t *rb);
-void rb_signal_writer_finished(ringbuf_t *rb);
-void rb_wakeup_reader(ringbuf_t *rb);
-int rb_is_writer_finished(ringbuf_t *rb);
+void rb_cleanup(ringbuf_t* rb);
+void rb_signal_writer_finished(ringbuf_t* rb);
+void rb_wakeup_reader(ringbuf_t* rb);
+int rb_is_writer_finished(ringbuf_t* rb);
 
 #ifdef __cplusplus
 }
diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/sdkconfig.defaults b/tensorflow/lite/micro/examples/micro_speech/esp/sdkconfig.defaults
index 4c3f6b7..fb8c6d3 100644
--- a/tensorflow/lite/micro/examples/micro_speech/esp/sdkconfig.defaults
+++ b/tensorflow/lite/micro/examples/micro_speech/esp/sdkconfig.defaults
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#    http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/lite/micro/examples/micro_speech/nxp_k66f/audio_provider.cc b/tensorflow/lite/micro/examples/micro_speech/nxp_k66f/audio_provider.cc
index fb7df6b..aa47dc4 100644
--- a/tensorflow/lite/micro/examples/micro_speech/nxp_k66f/audio_provider.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/nxp_k66f/audio_provider.cc
@@ -1,8 +1,11 @@
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -45,7 +48,7 @@
      defined(FSL_FEATURE_L1ICACHE_LINESIZE_BYTE))
 #define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
   __attribute__((section("NonCacheable"), zero_init))  \
-      __attribute__((aligned(alignbytes))) var
+  __attribute__((aligned(alignbytes))) var
 #else
 #define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
   __attribute__((aligned(alignbytes))) var
@@ -128,7 +131,7 @@
     {0x24, 0x00}, {0x25, 0x00}, {0x26, 0x20}, {0x20, 0x80}};
 
 // Save audio samples into intermediate buffer
-void CaptureSamples(const int16_t *sample_data) {
+void CaptureSamples(const int16_t* sample_data) {
   const int sample_size = kNoOfSamples;
   const int32_t time_in_ms =
       g_latest_audio_timestamp + (sample_size / (kAudioSampleFrequency / 1000));
@@ -145,17 +148,17 @@
 }
 
 // Callback function for SAI RX EDMA transfer complete
-static void SaiRxCallback(I2S_Type *base, sai_edma_handle_t *handle,
-                          status_t status, void *userData) {
+static void SaiRxCallback(I2S_Type* base, sai_edma_handle_t* handle,
+                          status_t status, void* userData) {
   if (kStatus_SAI_RxError == status) {
     // Handle the error
   } else {
     // Save audio data into intermediate buffer
     CaptureSamples(
-        reinterpret_cast<int16_t *>(g_rx_buffer + g_tx_index * kNoOfSamples));
+        reinterpret_cast<int16_t*>(g_rx_buffer + g_tx_index * kNoOfSamples));
 
     // Submit received audio buffer to SAI TX for audio loopback debug
-    g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_tx_index * kNoOfSamples);
+    g_sai_transfer.data = (uint8_t*)(g_rx_buffer + g_tx_index * kNoOfSamples);
     g_sai_transfer.dataSize = kBufferSize;
     if (kStatus_Success ==
         SAI_TransferSendEDMA(I2S0, &g_tx_sai_handle, &g_sai_transfer)) {
@@ -166,7 +169,7 @@
     }
 
     // Submit buffer to SAI RX to receive audio data
-    g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_rx_index * kNoOfSamples);
+    g_sai_transfer.data = (uint8_t*)(g_rx_buffer + g_rx_index * kNoOfSamples);
     g_sai_transfer.dataSize = kBufferSize;
     if (kStatus_Success ==
         SAI_TransferReceiveEDMA(I2S0, &g_rx_sai_handle, &g_sai_transfer)) {
@@ -179,8 +182,8 @@
 }
 
 // Callback function for TX Buffer transfer
-static void SaiTxCallback(I2S_Type *base, sai_edma_handle_t *handle,
-                          status_t status, void *userData) {
+static void SaiTxCallback(I2S_Type* base, sai_edma_handle_t* handle,
+                          status_t status, void* userData) {
   if (kStatus_SAI_TxError == status) {
     // Handle the error
   }
@@ -240,7 +243,7 @@
   i2c_data.direction = kI2C_Write;
   i2c_data.subaddress = register_address;
   i2c_data.subaddressSize = 1;
-  i2c_data.data = (uint8_t * volatile) data;
+  i2c_data.data = (uint8_t* volatile)data;
   i2c_data.dataSize = 1;
   i2c_data.flags = kI2C_TransferDefaultFlag;
   return I2C_MasterTransferBlocking(I2C1, &i2c_data);
@@ -255,7 +258,7 @@
 }
 
 // Initialization for receiving audio data
-TfLiteStatus InitAudioRecording(tflite::ErrorReporter *error_reporter) {
+TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter) {
   edma_config_t dma_config = {0};
   sai_config_t sai_config;
   sai_transfer_format_t sai_format;
@@ -325,7 +328,7 @@
                               sai_format.masterClockHz);
 
   // Submit buffers to SAI RX to start receiving audio
-  g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_rx_index * kNoOfSamples);
+  g_sai_transfer.data = (uint8_t*)(g_rx_buffer + g_rx_index * kNoOfSamples);
   g_sai_transfer.dataSize = kBufferSize;
   if (kStatus_Success ==
       SAI_TransferReceiveEDMA(I2S0, &g_rx_sai_handle, &g_sai_transfer)) {
@@ -334,7 +337,7 @@
   if (g_rx_index == kNoOfBuffers) {
     g_rx_index = 0U;
   }
-  g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_rx_index * kNoOfSamples);
+  g_sai_transfer.data = (uint8_t*)(g_rx_buffer + g_rx_index * kNoOfSamples);
   g_sai_transfer.dataSize = kBufferSize;
   if (kStatus_Success ==
       SAI_TransferReceiveEDMA(I2S0, &g_rx_sai_handle, &g_sai_transfer)) {
@@ -349,9 +352,9 @@
 }  // namespace
 
 // Main entry point for getting audio data.
-TfLiteStatus GetAudioSamples(tflite::ErrorReporter *error_reporter,
+TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
                              int start_ms, int duration_ms,
-                             int *audio_samples_size, int16_t **audio_samples) {
+                             int* audio_samples_size, int16_t** audio_samples) {
   if (!g_is_audio_initialized) {
     TfLiteStatus init_status = InitAudioRecording(error_reporter);
     if (init_status != kTfLiteOk) {
diff --git a/tensorflow/lite/micro/examples/person_detection/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/Makefile.inc
index 6b8be54..1b7ba8b 100644
--- a/tensorflow/lite/micro/examples/person_detection/Makefile.inc
+++ b/tensorflow/lite/micro/examples/person_detection/Makefile.inc
@@ -1,5 +1,3 @@
-$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
-
 person_detection_MODEL_SRCS := \
 tensorflow/lite/micro/examples/person_detection/model_settings.cc \
 $(MAKEFILE_DIR)/downloads/person_model_int8/person_detect_model_data.cc
diff --git a/tensorflow/lite/micro/examples/person_detection/arc_emsdp/emsdp.lcf b/tensorflow/lite/micro/examples/person_detection/arc_emsdp/emsdp.lcf
index c415093..9486ac6 100644
--- a/tensorflow/lite/micro/examples/person_detection/arc_emsdp/emsdp.lcf
+++ b/tensorflow/lite/micro/examples/person_detection/arc_emsdp/emsdp.lcf
@@ -1,8 +1,11 @@
 # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.c b/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.c
index f231be8..420f74b 100644
--- a/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.c
+++ b/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.c
@@ -15,7 +15,7 @@
 
 #include "app_camera_esp.h"
 
-static const char *TAG = "app_camera";
+static const char* TAG = "app_camera";
 
 int app_camera_init() {
 #if CONFIG_CAMERA_MODEL_ESP_EYE
diff --git a/tensorflow/lite/micro/examples/person_detection/esp/main/Kconfig.projbuild b/tensorflow/lite/micro/examples/person_detection/esp/main/Kconfig.projbuild
index ac769fb..c338ead 100755
--- a/tensorflow/lite/micro/examples/person_detection/esp/main/Kconfig.projbuild
+++ b/tensorflow/lite/micro/examples/person_detection/esp/main/Kconfig.projbuild
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#    http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/lite/micro/examples/person_detection/esp/sdkconfig.defaults b/tensorflow/lite/micro/examples/person_detection/esp/sdkconfig.defaults
index 4365b36..021ea58 100644
--- a/tensorflow/lite/micro/examples/person_detection/esp/sdkconfig.defaults
+++ b/tensorflow/lite/micro/examples/person_detection/esp/sdkconfig.defaults
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#    http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c
index 4fc673a..70adf66 100644
--- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c
+++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c
@@ -673,7 +673,8 @@
 
   am_util_stdio_printf("[%s] +\n", __func__);
 #ifdef ENABLE_ASYNC
-  while (!s_bVsyncAsserted);
+  while (!s_bVsyncAsserted)
+    ;
 
   while (s_bVsyncAsserted) {
     // we don't check HSYNC here on the basis of assuming HM01B0 in the gated
@@ -687,18 +688,21 @@
         goto end;
       }
 
-      while (read_pclk());
+      while (read_pclk())
+        ;
     }
   }
 #else
   uint32_t ui32HsyncCnt = 0x00;
 
   while ((ui32HsyncCnt < HM01B0_PIXEL_Y_NUM)) {
-    while (0x00 == read_hsync());
+    while (0x00 == read_hsync())
+      ;
 
     // read one row
     while (read_hsync()) {
-      while (0x00 == read_pclk());
+      while (0x00 == read_pclk())
+        ;
 
       *(pui8Buffer + ui32Idx++) = read_byte();
 
@@ -706,7 +710,8 @@
         goto end;
       }
 
-      while (read_pclk());
+      while (read_pclk())
+        ;
     }
 
     ui32HsyncCnt++;
diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h
index c7ec4e6..8984d65 100644
--- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h
+++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h
@@ -99,7 +99,7 @@
   am_hal_iom_mode_e eIOMMode;
   uint32_t ui32IOMModule;
   am_hal_iom_config_t sIOMCfg;
-  void *pIOMHandle;
+  void* pIOMHandle;
 
   uint32_t ui32CTimerModule;
   uint32_t ui32CTimerSegment;
@@ -138,8 +138,8 @@
 //! @return Error code.
 //
 //*****************************************************************************
-static uint32_t hm01b0_write_reg(hm01b0_cfg_t *psCfg, uint16_t ui16Reg,
-                                 uint8_t *pui8Value, uint32_t ui32NumBytes);
+static uint32_t hm01b0_write_reg(hm01b0_cfg_t* psCfg, uint16_t ui16Reg,
+                                 uint8_t* pui8Value, uint32_t ui32NumBytes);
 
 //*****************************************************************************
 //
@@ -156,8 +156,8 @@
 //! @return Error code.
 //
 //*****************************************************************************
-static uint32_t hm01b0_read_reg(hm01b0_cfg_t *psCfg, uint16_t ui16Reg,
-                                uint8_t *pui8Value, uint32_t ui32NumBytes);
+static uint32_t hm01b0_read_reg(hm01b0_cfg_t* psCfg, uint16_t ui16Reg,
+                                uint8_t* pui8Value, uint32_t ui32NumBytes);
 
 //*****************************************************************************
 //
@@ -172,7 +172,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-static uint32_t hm01b0_load_script(hm01b0_cfg_t *psCfg, hm_script_t *psScript,
+static uint32_t hm01b0_load_script(hm01b0_cfg_t* psCfg, hm_script_t* psScript,
                                    uint32_t ui32ScriptCmdNum);
 
 //*****************************************************************************
@@ -186,7 +186,7 @@
 //! @return none.
 //
 //*****************************************************************************
-void hm01b0_power_up(hm01b0_cfg_t *psCfg);
+void hm01b0_power_up(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -199,7 +199,7 @@
 //! @return none.
 //
 //*****************************************************************************
-void hm01b0_power_down(hm01b0_cfg_t *psCfg);
+void hm01b0_power_down(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -212,7 +212,7 @@
 //! @return none.
 //
 //*****************************************************************************
-void hm01b0_mclk_enable(hm01b0_cfg_t *psCfg);
+void hm01b0_mclk_enable(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -225,7 +225,7 @@
 //! @return none.
 //
 //*****************************************************************************
-void hm01b0_mclk_disable(hm01b0_cfg_t *psCfg);
+void hm01b0_mclk_disable(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -238,7 +238,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_init_if(hm01b0_cfg_t *psCfg);
+uint32_t hm01b0_init_if(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -251,7 +251,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_deinit_if(hm01b0_cfg_t *psCfg);
+uint32_t hm01b0_deinit_if(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -265,7 +265,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_get_modelid(hm01b0_cfg_t *psCfg, uint16_t *pui16MID);
+uint32_t hm01b0_get_modelid(hm01b0_cfg_t* psCfg, uint16_t* pui16MID);
 
 //*****************************************************************************
 //
@@ -281,7 +281,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_init_system(hm01b0_cfg_t *psCfg, hm_script_t *psScript,
+uint32_t hm01b0_init_system(hm01b0_cfg_t* psCfg, hm_script_t* psScript,
                             uint32_t ui32ScriptCmdNum);
 
 //*****************************************************************************
@@ -295,7 +295,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_test_walking1s(hm01b0_cfg_t *psCfg);
+uint32_t hm01b0_test_walking1s(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -308,7 +308,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_reset_sw(hm01b0_cfg_t *psCfg);
+uint32_t hm01b0_reset_sw(hm01b0_cfg_t* psCfg);
 
 //*****************************************************************************
 //
@@ -323,7 +323,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_get_mode(hm01b0_cfg_t *psCfg, uint8_t *pui8Mode);
+uint32_t hm01b0_get_mode(hm01b0_cfg_t* psCfg, uint8_t* pui8Mode);
 
 //*****************************************************************************
 //
@@ -344,7 +344,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_set_mode(hm01b0_cfg_t *psCfg, uint8_t ui8Mode,
+uint32_t hm01b0_set_mode(hm01b0_cfg_t* psCfg, uint8_t ui8Mode,
                          uint8_t framecnt);
 
 //*****************************************************************************
@@ -360,7 +360,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_hardware_trigger_streaming(hm01b0_cfg_t *psCfg, bool bTrigger);
+uint32_t hm01b0_hardware_trigger_streaming(hm01b0_cfg_t* psCfg, bool bTrigger);
 
 //*****************************************************************************
 //
@@ -375,7 +375,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_set_mirror(hm01b0_cfg_t *psCfg, bool bHmirror, bool bVmirror);
+uint32_t hm01b0_set_mirror(hm01b0_cfg_t* psCfg, bool bHmirror, bool bVmirror);
 
 //*****************************************************************************
 //
@@ -390,7 +390,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_blocking_read_oneframe(hm01b0_cfg_t *psCfg, uint8_t *pui8Buffer,
+uint32_t hm01b0_blocking_read_oneframe(hm01b0_cfg_t* psCfg, uint8_t* pui8Buffer,
                                        uint32_t ui32BufferLen);
 
 //*****************************************************************************
@@ -404,7 +404,7 @@
 //! @return Error code.
 //
 //*****************************************************************************
-uint32_t hm01b0_single_frame_capture(hm01b0_cfg_t *psCfg);
+uint32_t hm01b0_single_frame_capture(hm01b0_cfg_t* psCfg);
 
 #ifdef __cplusplus
 }
diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c
index 3a64b70..9e83315 100644
--- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c
+++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c
@@ -20,7 +20,8 @@
 #ifndef ARDUINO_EXCLUDE_CODE
 
 #include "HM01B0_debug.h"
-#include "am_util.h" // NOLINT
+
+#include "am_util.h"  // NOLINT
 
 void hm01b0_framebuffer_dump(uint8_t* frame, uint32_t length) {
   am_util_stdio_printf("+++ frame +++");
diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c
index 7bc5b2b..e60d874 100644
--- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c
+++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c
@@ -20,9 +20,9 @@
 #ifndef ARDUINO_EXCLUDE_CODE
 
 #include "HM01B0.h"
-#include "am_bsp.h" //NOLINT
-#include "am_mcu_apollo.h" //NOLINT
-#include "platform.h"      // TARGET specific implementation
+#include "am_bsp.h"         //NOLINT
+#include "am_mcu_apollo.h"  //NOLINT
+#include "platform.h"       // TARGET specific implementation
 
 // Image is down-sampled by applying a stride of 2 pixels in both the x and y
 // directions.
@@ -58,7 +58,8 @@
 
   while ((hsync_count < HM01B0_PIXEL_Y_NUM)) {
     // Wait for horizontal sync.
-    while (!read_hsync());
+    while (!read_hsync())
+      ;
 
     // Get resulting image position.  When hsync_count < offset_y, this will
     // underflow resulting in an index out of bounds which we check later,
@@ -69,14 +70,15 @@
     // Read one row. Hsync is held high for the duration of a row read.
     while (read_hsync()) {
       // Wait for pixel value to be ready.
-      while (!read_pclk());
+      while (!read_pclk())
+        ;
 
       // Read 8-bit value from camera.
       const uint8_t value = read_byte();
       const uint32_t output_x = (rowidx++ - offset_x) >> kStrideShift;
       if (output_x < w && output_y < h) {
         const int output_idx = (output_y * w + output_x) * channels;
-        for (int i=0; i<channels; i++) {
+        for (int i = 0; i < channels; i++) {
           // See the top of main_functions.cc for an explanation of and
           // rationale for our unsigned to signed input conversion.
           buffer[output_idx + i] = value - 128;
@@ -84,7 +86,8 @@
       }
 
       // Wait for next pixel clock.
-      while (read_pclk());
+      while (read_pclk())
+        ;
     }
 
     hsync_count++;
diff --git a/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc b/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc
index fc1c486..05db2c2 100644
--- a/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc
+++ b/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc
@@ -23,12 +23,11 @@
 
 #ifndef ARDUINO_EXCLUDE_CODE
 
-#include "tensorflow/lite/micro/examples/person_detection/image_provider.h"
-
 #include "tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h"
 #include "tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_RAW8_QVGA_8bits_lsb_5fps.h"
 #include "tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.h"
 #include "tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.h"
+#include "tensorflow/lite/micro/examples/person_detection/image_provider.h"
 
 // These are headers from Ambiq's Apollo3 SDK.
 #include "am_bsp.h"         // NOLINT
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index f20ce0d..6dba309 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -34,7 +34,9 @@
         ],
     }),
     copts = micro_copts(),
-    deps = select({
+    deps = [
+        ":xtensa",
+    ] + select({
         "//conditions:default": [],
         ":xtensa_hifimini": [
             #"//third_party/xtensa/cstub64s:hifi_mini",
@@ -62,6 +64,7 @@
     deps = [
         ":fixedpoint_utils",
         ":kernel_util",
+        ":xtensa",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels/internal:common",
@@ -117,6 +120,7 @@
         "pad.cc",
         "pooling.cc",
         "prelu.cc",
+        "quantize_common.cc",
         "reduce.cc",
         "reshape.cc",
         "resize_nearest_neighbor.cc",
@@ -126,6 +130,7 @@
         "split_v.cc",
         "strided_slice.cc",
         "sub.cc",
+        "svdf_common.cc",
         "tanh.cc",
         "unpack.cc",
     ] + select({
@@ -144,7 +149,11 @@
             "xtensa/svdf.cc",
         ],
     }),
-    hdrs = ["micro_ops.h"],
+    hdrs = [
+        "micro_ops.h",
+        "quantize.h",
+        "svdf.h",
+    ],
     copts = micro_copts(),
     visibility = [
         # Needed for micro:op_resolvers but visibility can not be finer-grained
@@ -156,6 +165,7 @@
         ":kernel_util",
         ":fixedpoint_utils",
         ":micro_utils",
+        ":xtensa",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:op_macros",
@@ -177,6 +187,24 @@
     }),
 )
 
+cc_library(
+    name = "xtensa",
+    hdrs = select({
+        "//conditions:default": [
+        ],
+        ":xtensa_hifimini": [
+            "xtensa/xtensa.h",
+        ],
+    }),
+    copts = micro_copts(),
+    deps = select({
+        "//conditions:default": [],
+        ":xtensa_hifimini": [
+            #"//third_party/xtensa/cstub64s:hifi_mini",
+        ],
+    }),
+)
+
 test_suite(
     name = "all_tests",
 )
@@ -678,6 +706,7 @@
     srcs = [
         "pad_test.cc",
     ],
+    tags = ["nomsan"],  # b/175133159
     deps = [
         ":kernel_runner",
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv.cc b/tensorflow/lite/micro/kernels/arc_mli/conv.cc
index 2974531..bf5f024 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/conv.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019-2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
index 252b022..1c973a4 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017-2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc b/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc
index 1f439ff..82e233f 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017-2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/tensorflow/lite/micro/kernels/arc_mli/pooling.cc b/tensorflow/lite/micro/kernels/arc_mli/pooling.cc
index 76002c7..d1cd56f 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/pooling.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/pooling.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019-2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc b/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc
index 1e188fc..296b9b6 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc
@@ -63,15 +63,15 @@
 #pragma Bss()
 }  // namespace
 
-static int8_t *scratch_mem[] = {scratch_mem_x, scratch_mem_y, scratch_mem_z};
+static int8_t* scratch_mem[] = {scratch_mem_x, scratch_mem_y, scratch_mem_z};
 static uint32_t scratch_sizes[] = {SCRATCH_MEM_X_SIZE, SCRATCH_MEM_Y_SIZE,
                                    SCRATCH_MEM_Z_SIZE};
 
-void *get_arc_scratch_buffer(int size) {
+void* get_arc_scratch_buffer(int size) {
   // Function to asign fast memory from one of 3 scratch buffers.
   // Best Fit strategy - memory is allocated from that memory bank that leaves
   // the least unused memory.
-  void *buf = NULL;
+  void* buf = NULL;
   int best_mem_idx = -1;
   int best_mem_delta = INT_MAX;
   const int num_mem = sizeof(scratch_mem) / sizeof(scratch_mem[0]);
@@ -85,14 +85,14 @@
     }
   }
   if (best_mem_idx >= 0) {
-    buf = static_cast<void *>(scratch_mem[best_mem_idx]);
+    buf = static_cast<void*>(scratch_mem[best_mem_idx]);
     scratch_mem[best_mem_idx] += size;
     scratch_sizes[best_mem_idx] -= size;
   }
   return buf;
 }
 
-void get_arc_scratch_buffer_max_size(int *size) {
+void get_arc_scratch_buffer_max_size(int* size) {
   int maxavailable = 0;
   const int num_mem = sizeof(scratch_mem) / sizeof(scratch_mem[0]);
   // find the largest available buffer.
@@ -104,7 +104,7 @@
   *size = maxavailable;
 }
 
-void get_arc_scratch_buffer_two_max_sizes(int *size1, int *size2) {
+void get_arc_scratch_buffer_two_max_sizes(int* size1, int* size2) {
   int maxavailable = 0;
   int secondavail = 0;
   const int num_mem = sizeof(scratch_mem) / sizeof(scratch_mem[0]);
diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc
index cdf8634..db1d203 100644
--- a/tensorflow/lite/micro/kernels/detection_postprocess.cc
+++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc
@@ -1,4 +1,5 @@
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -357,8 +358,9 @@
   int counter = 0;
   for (int i = 0; i < size; i++) {
     if (values[i] >= threshold) {
-      keep_values[counter++] = values[i];
-      keep_indices[i] = i;
+      keep_values[counter] = values[i];
+      keep_indices[counter] = i;
+      counter++;
     }
   }
   return counter;
diff --git a/tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc b/tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc
index 3f01096..106deec 100644
--- a/tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc
+++ b/tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc
@@ -1,4 +1,5 @@
 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
diff --git a/tensorflow/lite/micro/kernels/flexbuffers_generated_data.h b/tensorflow/lite/micro/kernels/flexbuffers_generated_data.h
index 0eab0ae..f5b9eae 100644
--- a/tensorflow/lite/micro/kernels/flexbuffers_generated_data.h
+++ b/tensorflow/lite/micro/kernels/flexbuffers_generated_data.h
@@ -1,4 +1,5 @@
 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
diff --git a/tensorflow/lite/micro/kernels/micro_utils.h b/tensorflow/lite/micro/kernels/micro_utils.h
index 85db263..e406ac1 100644
--- a/tensorflow/lite/micro/kernels/micro_utils.h
+++ b/tensorflow/lite/micro/kernels/micro_utils.h
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc
index 8b9bf7e..f62addb 100644
--- a/tensorflow/lite/micro/kernels/quantize.cc
+++ b/tensorflow/lite/micro/kernels/quantize.cc
@@ -12,11 +12,11 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/quantize.h"
+
+#include "tensorflow/lite/micro/kernels/quantize.h"
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/requantize.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -25,24 +25,15 @@
 namespace tflite {
 namespace {
 
-struct OpData {
-  tflite::QuantizationParams quantization_params;
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  int32_t input_zero_point;
-};
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+  return context->AllocatePersistentBuffer(context,
+                                           sizeof(OpDataQuantizeReference));
 }
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* data = static_cast<OpData*>(node->user_data);
+  auto* data = static_cast<OpDataQuantizeReference*>(node->user_data);
 
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -77,8 +68,8 @@
     double effective_scale = static_cast<double>(input->params.scale) /
                              static_cast<double>(output->params.scale);
 
-    QuantizeMultiplier(effective_scale, &data->output_multiplier,
-                       &data->output_shift);
+    QuantizeMultiplier(effective_scale, &data->requantize_output_multiplier,
+                       &data->requantize_output_shift);
   }
 
   data->quantization_params.zero_point = output->params.zero_point;
@@ -88,107 +79,13 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* data = static_cast<OpData*>(node->user_data);
-
-  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
-  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
-
-  if (input->type == kTfLiteFloat32) {
-    switch (output->type) {
-      case kTfLiteInt8:
-        reference_ops::AffineQuantize(
-            data->quantization_params, tflite::micro::GetTensorShape(input),
-            tflite::micro::GetTensorData<float>(input),
-            tflite::micro::GetTensorShape(output),
-            tflite::micro::GetTensorData<int8_t>(output));
-        break;
-      case kTfLiteUInt8:
-        reference_ops::AffineQuantize(
-            data->quantization_params, tflite::micro::GetTensorShape(input),
-            tflite::micro::GetTensorData<float>(input),
-            tflite::micro::GetTensorShape(output),
-            tflite::micro::GetTensorData<uint8_t>(output));
-        break;
-      case kTfLiteInt16:
-        reference_ops::AffineQuantize(
-            data->quantization_params, tflite::micro::GetTensorShape(input),
-            tflite::micro::GetTensorData<float>(input),
-            tflite::micro::GetTensorShape(output),
-            tflite::micro::GetTensorData<int16_t>(output));
-        return kTfLiteOk;
-      default:
-        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                           TfLiteTypeGetName(input->type),
-                           TfLiteTypeGetName(output->type));
-        return kTfLiteError;
-    }
-  } else if (input->type == kTfLiteInt16) {
-    size_t size = ElementCount(*input->dims);
-    switch (output->type) {
-      case kTfLiteInt8:
-        reference_ops::Requantize(tflite::micro::GetTensorData<int16_t>(input),
-                                  size, data->output_multiplier,
-                                  data->output_shift, data->input_zero_point,
-                                  data->quantization_params.zero_point,
-                                  tflite::micro::GetTensorData<int8_t>(output));
-        break;
-      case kTfLiteInt16:
-        reference_ops::Requantize(
-            tflite::micro::GetTensorData<int16_t>(input), size,
-            data->output_multiplier, data->output_shift, data->input_zero_point,
-            data->quantization_params.zero_point,
-            tflite::micro::GetTensorData<int16_t>(output));
-        return kTfLiteOk;
-      case kTfLiteInt32:
-        reference_ops::Requantize(
-            tflite::micro::GetTensorData<int16_t>(input), size,
-            data->output_multiplier, data->output_shift, data->input_zero_point,
-            data->quantization_params.zero_point,
-            tflite::micro::GetTensorData<int32_t>(output));
-        return kTfLiteOk;
-      default:
-        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                           TfLiteTypeGetName(input->type),
-                           TfLiteTypeGetName(output->type));
-        return kTfLiteError;
-    }
-  } else if (input->type == kTfLiteInt8) {
-    // Int8 to Int8 requantization, required if the input and output tensors
-    // have different scales and/or zero points.
-    size_t size = ElementCount(*input->dims);
-    switch (output->type) {
-      case kTfLiteInt8:
-        reference_ops::Requantize(tflite::micro::GetTensorData<int8_t>(input),
-                                  size, data->output_multiplier,
-                                  data->output_shift, data->input_zero_point,
-                                  data->quantization_params.zero_point,
-                                  tflite::micro::GetTensorData<int8_t>(output));
-        break;
-      default:
-        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                           TfLiteTypeGetName(input->type),
-                           TfLiteTypeGetName(output->type));
-        return kTfLiteError;
-    }
-  } else {
-    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                       TfLiteTypeGetName(input->type),
-                       TfLiteTypeGetName(output->type));
-    return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
 }  // namespace
 
 TfLiteRegistration Register_QUANTIZE() {
   return {/*init=*/Init,
           /*free=*/nullptr,
           /*prepare=*/Prepare,
-          /*invoke=*/Eval,
+          /*invoke=*/EvalQuantizeReference,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
           /*custom_name=*/nullptr,
diff --git a/tensorflow/lite/micro/kernels/quantize.h b/tensorflow/lite/micro/kernels/quantize.h
new file mode 100644
index 0000000..aefe624
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/quantize.h
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_QUANTIZE_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_QUANTIZE_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+struct OpDataQuantizeReference {
+  tflite::QuantizationParams quantization_params;
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t requantize_output_multiplier;
+  int requantize_output_shift;
+
+  int32_t input_zero_point;
+};
+
+TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node);
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_QUANTIZE_H_
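OpDataQuantizeReference keeps the requantize scale as a fixed-point multiplier plus a shift. A minimal sketch of that decomposition, assuming the usual M = M0 * 2^shift representation with M0 stored as a Q31 integer; this mirrors the idea behind TFLite's QuantizeMultiplier but is not its exact implementation:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Splits a non-negative real multiplier into a Q31 mantissa and a power-of-two
// exponent, so that multiplier ~= quantized_multiplier * 2^(shift - 31).
void DecomposeMultiplier(double multiplier, int32_t* quantized_multiplier,
                         int* shift) {
  if (multiplier == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(multiplier, shift);  // multiplier = q * 2^shift
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed the mantissa up to 1.0
    q_fixed /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}

int main() {
  // Example effective scale from the requantize tests: 0.4 / 1.0.
  int32_t multiplier;
  int shift;
  DecomposeMultiplier(0.4, &multiplier, &shift);
  // Roughly 0.8 * 2^31 with shift = -1.
  printf("multiplier=%ld shift=%d\n", static_cast<long>(multiplier), shift);
  return 0;
}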
diff --git a/tensorflow/lite/micro/kernels/quantize_common.cc b/tensorflow/lite/micro/kernels/quantize_common.cc
new file mode 100644
index 0000000..2c4a8d2
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/quantize_common.cc
@@ -0,0 +1,122 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/quantize.h"
+#include "tensorflow/lite/kernels/internal/reference/requantize.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/quantize.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+namespace tflite {
+
+TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  auto* data = static_cast<OpDataQuantizeReference*>(node->user_data);
+
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  if (input->type == kTfLiteFloat32) {
+    switch (output->type) {
+      case kTfLiteInt8:
+        reference_ops::AffineQuantize(
+            data->quantization_params, tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<float>(input),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output));
+        break;
+      case kTfLiteUInt8:
+        reference_ops::AffineQuantize(
+            data->quantization_params, tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<float>(input),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<uint8_t>(output));
+        break;
+      case kTfLiteInt16:
+        reference_ops::AffineQuantize(
+            data->quantization_params, tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<float>(input),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int16_t>(output));
+        return kTfLiteOk;
+      default:
+        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                           TfLiteTypeGetName(input->type),
+                           TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
+  } else if (input->type == kTfLiteInt16) {
+    size_t size = ElementCount(*input->dims);
+    switch (output->type) {
+      case kTfLiteInt8:
+        reference_ops::Requantize(
+            tflite::micro::GetTensorData<int16_t>(input), size,
+            data->requantize_output_multiplier, data->requantize_output_shift,
+            data->input_zero_point, data->quantization_params.zero_point,
+            tflite::micro::GetTensorData<int8_t>(output));
+        break;
+      case kTfLiteInt16:
+        reference_ops::Requantize(
+            tflite::micro::GetTensorData<int16_t>(input), size,
+            data->requantize_output_multiplier, data->requantize_output_shift,
+            data->input_zero_point, data->quantization_params.zero_point,
+            tflite::micro::GetTensorData<int16_t>(output));
+        return kTfLiteOk;
+      case kTfLiteInt32:
+        reference_ops::Requantize(
+            tflite::micro::GetTensorData<int16_t>(input), size,
+            data->requantize_output_multiplier, data->requantize_output_shift,
+            data->input_zero_point, data->quantization_params.zero_point,
+            tflite::micro::GetTensorData<int32_t>(output));
+        return kTfLiteOk;
+      default:
+        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                           TfLiteTypeGetName(input->type),
+                           TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
+  } else if (input->type == kTfLiteInt8) {
+    // Int8 to Int8 requantization, required if the input and output tensors
+    // have different scales and/or zero points.
+    size_t size = ElementCount(*input->dims);
+    switch (output->type) {
+      case kTfLiteInt8:
+        reference_ops::Requantize(
+            tflite::micro::GetTensorData<int8_t>(input), size,
+            data->requantize_output_multiplier, data->requantize_output_shift,
+            data->input_zero_point, data->quantization_params.zero_point,
+            tflite::micro::GetTensorData<int8_t>(output));
+        break;
+      default:
+        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                           TfLiteTypeGetName(input->type),
+                           TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
+  } else {
+    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                       TfLiteTypeGetName(input->type),
+                       TfLiteTypeGetName(output->type));
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace tflite
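The Requantize calls above reduce, per element, to: subtract the input zero point, scale by input_scale / output_scale, add the output zero point, and saturate to the output type. A floating-point sketch of that arithmetic (the real kernel uses the fixed-point multiplier and shift computed in Prepare; RequantizeOneValue is an illustrative name, not a TFLite API):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t RequantizeOneValue(int16_t in, double effective_scale,
                          int32_t input_zero_point, int32_t output_zero_point) {
  const int32_t shifted = static_cast<int32_t>(in) - input_zero_point;
  const int32_t scaled =
      static_cast<int32_t>(std::lround(shifted * effective_scale));
  const int32_t out = scaled + output_zero_point;
  return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, out)));
}

int main() {
  // -80 in int16 at scale 0.4 represents -32.0, which is -32 at scale 1.0.
  printf("%d\n", RequantizeOneValue(-80, 0.4 / 1.0, 0, 0));  // prints -32
  return 0;
}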
diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc
index 5803376..50d5021 100644
--- a/tensorflow/lite/micro/kernels/quantize_test.cc
+++ b/tensorflow/lite/micro/kernels/quantize_test.cc
@@ -49,7 +49,7 @@
   }
 }
 
-#if !defined(XTENSA)
+#if !defined(HIFIMINI)
 template <typename T>
 void TestQuantizeFloat(const int* input_dims_data, const float* input_data,
                        const int* output_dims_data, const float* golden,
@@ -79,7 +79,7 @@
   ValidateQuantizeGoldens(tensors, tensors_size, golden, golden_quantized,
                           scale, zero_point, output_dims_count, output_data);
 }
-#endif
+#endif  // !defined(HIFIMINI)
 
 template <typename InputType, typename OutputType>
 void TestRequantize(const int* input_dims_data, const float* input_data,
@@ -121,7 +121,7 @@
 
 TF_LITE_MICRO_TESTS_BEGIN
 
-#if !defined(XTENSA)
+#if !defined(HIFIMINI)
 TF_LITE_MICRO_TEST(QuantizeOpTestUint8) {
   const int length = 10;
   const int dims[] = {2, 2, 5};
@@ -267,13 +267,11 @@
                                   values_quantized, output_scale,
                                   output_zero_point, output_quantized);
 }
-#endif
+#endif  // !defined(HIFIMINI)
 
-#if !defined(XTENSA)
-// TODO(b/174603495): Since the hifimini optimized implementation does support
-// input==int16 and output==int8, it seems like this kernel test should pass. It
-// currently does not, but we are moving it to its own ifdef block to make it
-// more visible and hopefully fix this in the near future.
+#if !defined(HIFIMINI)
+// TODO(b/155682734): Hifimini optimized quantize requires input scale to be
+// smaller than output scale.
 TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
   const int length = 10;
   const int dims[] = {2, 2, 5};
@@ -290,7 +288,7 @@
                                   values_quantized, output_scale,
                                   output_zero_point, output_quantized);
 }
-#endif
+#endif  // !defined(HIFIMINI)
 
 TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt32) {
   const int length = 10;
@@ -309,4 +307,23 @@
                                   output_zero_point, output_quantized);
 }
 
+TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
+  constexpr int length = 10;
+  const int dims[] = {2, 2, 5};
+  const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31};
+  // TODO(b/155682734): Input scale must be smaller than output scale for
+  // xtensa.
+  const float input_scale = 0.4f;
+  const int input_zero_point = 0;
+  const float output_scale = 1.0f;
+  const int output_zero_point = 0;
+  int8_t output_quantized[length];
+  int8_t values_quantized[length];
+  int16_t input_quantized[length];
+  tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
+                                  input_zero_point, dims, values,
+                                  values_quantized, output_scale,
+                                  output_zero_point, output_quantized);
+}
+
 TF_LITE_MICRO_TESTS_END
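A quick, purely illustrative check of the data in the new QuantizeOpTestInt16toInt8 test: with input_scale 0.4 the float -32 quantizes to the int16 value -80, and with output_scale 1.0 the expected int8 output is -32.

#include <cmath>
#include <cstdio>

int main() {
  const float value = -32.0f;
  const float input_scale = 0.4f, output_scale = 1.0f;
  const int zero_point = 0;

  // Quantized value = round(real_value / scale) + zero_point.
  const int q_in = static_cast<int>(std::lround(value / input_scale)) + zero_point;
  const int q_out = static_cast<int>(std::lround(value / output_scale)) + zero_point;
  printf("int16 input = %d, int8 golden = %d\n", q_in, q_out);  // -80, -32
  return 0;
}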
diff --git a/tensorflow/lite/micro/kernels/softmax_test.cc b/tensorflow/lite/micro/kernels/softmax_test.cc
index 4ef8f4e..16dc6e2 100644
--- a/tensorflow/lite/micro/kernels/softmax_test.cc
+++ b/tensorflow/lite/micro/kernels/softmax_test.cc
@@ -20,8 +20,6 @@
 #include "tensorflow/lite/micro/test_helpers.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 
-
-
 namespace tflite {
 namespace testing {
 namespace {
diff --git a/tensorflow/lite/micro/kernels/split_v_test.cc b/tensorflow/lite/micro/kernels/split_v_test.cc
index 06c90cb..73c4e1c 100755
--- a/tensorflow/lite/micro/kernels/split_v_test.cc
+++ b/tensorflow/lite/micro/kernels/split_v_test.cc
@@ -4,7 +4,7 @@
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
-http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/lite/micro/kernels/svdf.cc b/tensorflow/lite/micro/kernels/svdf.cc
index 764fdc1..9ea43d4 100644
--- a/tensorflow/lite/micro/kernels/svdf.cc
+++ b/tensorflow/lite/micro/kernels/svdf.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,6 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/lite/micro/kernels/svdf.h"
+
 #include <math.h>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
@@ -29,21 +31,6 @@
 namespace tflite {
 namespace {
 
-struct OpData {
-  int32_t effective_scale_1_a;
-  int32_t effective_scale_2_a;
-  // b versions of each scale are kept at int since the numbers are just the
-  // shift value - typically between [-32, 32].
-  int effective_scale_1_b;
-  int effective_scale_2_b;
-  int scratch_tensor_index;
-  int scratch_output_tensor_index;
-
-  // Cached tensor zero point values for quantized operations.
-  int input_zero_point;
-  int output_zero_point;
-};
-
 // Input tensors.
 constexpr int kInputTensor = 0;
 constexpr int kWeightsFeatureTensor = 1;
@@ -200,150 +187,6 @@
       bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
 }
 
-void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
-                     const TfLiteEvalTensor* input_tensor,
-                     const TfLiteEvalTensor* weights_feature_tensor,
-                     const TfLiteEvalTensor* weights_time_tensor,
-                     const TfLiteEvalTensor* bias_tensor,
-                     const TfLiteSVDFParams* params,
-                     TfLiteEvalTensor* activation_state_tensor,
-                     TfLiteEvalTensor* output_tensor, const OpData& data) {
-  const int n_rank = params->rank;
-  const int n_batch = input_tensor->dims->data[0];
-  const int n_input = input_tensor->dims->data[1];
-  const int n_filter = weights_feature_tensor->dims->data[0];
-  const int n_unit = n_filter / n_rank;
-  const int n_memory = weights_time_tensor->dims->data[1];
-
-  TFLITE_DCHECK(context != nullptr);
-  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
-
-  int32_t* scratch_tensor = static_cast<int32_t*>(
-      context->GetScratchBuffer(context, data.scratch_tensor_index));
-  int32_t* scratch_output_tensor = static_cast<int32_t*>(
-      context->GetScratchBuffer(context, data.scratch_output_tensor_index));
-
-  // Shift states.
-  int16_t* const state_ptr =
-      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
-
-  // Left shift the activation_state.
-  {
-    int16_t* new_state_start = state_ptr;
-    const int16_t* old_state_start = state_ptr + 1;
-    const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
-    while (old_state_start != old_state_end) {
-      *new_state_start++ = *old_state_start++;
-    }
-  }
-
-  // Note: no need to clear the latest activation, matmul is not accumulative.
-
-  // Feature matmul.
-  {
-    int16_t* state =
-        tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
-    const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
-    const int8_t* weight_feature =
-        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
-    const int32_t output_max = std::numeric_limits<int16_t>::max();
-    const int32_t output_min = std::numeric_limits<int16_t>::min();
-    int16_t* result_in_batch = state + (n_memory - 1);
-    for (int b = 0; b < n_batch; b++) {
-      const int8_t* matrix_ptr = weight_feature;
-      for (int r = 0; r < n_filter; r++) {
-        int32_t dot_prod = 0;
-        const int8_t* vector_in_batch = input + b * n_input;
-        for (int c = 0; c < n_input; c++) {
-          dot_prod +=
-              *matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
-        }
-        dot_prod = MultiplyByQuantizedMultiplier(
-            dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
-        dot_prod = std::min(std::max(output_min, dot_prod), output_max);
-        // This assumes state is symmetrically quantized. Otherwise last bit of
-        // state should be initialized to its zero point and accumulate the
-        // dot_prod.
-        // Equivalent as the following:
-        //     result_in_batch = zero point, which happens to be zero.
-        //     result_in_batch += dot_prod_56.
-        *result_in_batch = dot_prod;
-        result_in_batch += n_memory;
-      }
-    }
-  }
-
-  // Time.
-  {
-    for (int b = 0; b < n_batch; ++b) {
-      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
-
-      // Perform batched vector dot product:
-      const int16_t* vector1_ptr =
-          tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
-      const int16_t* vector2_ptr =
-          tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
-          b * n_memory * n_filter;
-
-      for (int i = 0; i < n_filter; i++) {
-        *scratch_ptr_batch = 0;
-        for (int j = 0; j < n_memory; j++) {
-          *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
-        }
-        scratch_ptr_batch++;
-      }
-    }
-  }
-
-  // Reduce, add bias, rescale, activation.
-  {
-    // Add bias.
-    if (bias_tensor) {
-      // Vector batch assign:
-      const int32_t* bias_data =
-          tflite::micro::GetTensorData<int32_t>(bias_tensor);
-      for (int i = 0; i < n_batch; ++i) {
-        int32_t* output_ptr = scratch_output_tensor + i * n_unit;
-        const int32_t* bias_ptr = bias_data;
-        for (int j = 0; j < n_unit; ++j) {
-          *output_ptr++ = *bias_ptr++;
-        }
-      }
-    } else {
-      int32_t* output_ptr = scratch_output_tensor;
-      for (int i = 0; i < n_batch * n_unit; ++i) {
-        *output_ptr++ = 0;
-      }
-    }
-
-    // Reduce.
-    for (int b = 0; b < n_batch; ++b) {
-      int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
-      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
-
-      // Reduction sum vector
-      for (int i = 0; i < n_unit; ++i) {
-        for (int j = 0; j < n_rank; ++j) {
-          output_temp_ptr[i] += *scratch_ptr_batch++;
-        }
-      }
-    }
-
-    // Rescale.
-    const int32_t output_max = std::numeric_limits<int8_t>::max();
-    const int32_t output_min = std::numeric_limits<int8_t>::min();
-    for (int i = 0; i < n_batch * n_unit; ++i) {
-      int32_t x1 = scratch_output_tensor[i];
-      int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
-                                                 data.effective_scale_2_b);
-      int32_t x3 = x2 + data.output_zero_point;
-      int32_t x4 = std::min(std::max(output_min, x3), output_max);
-      tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
-          static_cast<int8_t>(x4);
-    }
-  }
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -517,8 +360,9 @@
     }
 
     case kTfLiteInt8: {
-      EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
-                      params, activation_state, output, data);
+      EvalIntegerSvdfReference(context, node, input, weights_feature,
+                               weights_time, bias, params, activation_state,
+                               output, data);
       return kTfLiteOk;
       break;
     }
diff --git a/tensorflow/lite/micro/kernels/svdf.h b/tensorflow/lite/micro/kernels/svdf.h
new file mode 100644
index 0000000..b10ede6
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/svdf.h
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+struct OpData {
+  int32_t effective_scale_1_a;
+  int32_t effective_scale_2_a;
+  // b versions of each scale are kept as int since the numbers are just the
+  // shift value - typically between [-32, 32].
+  int effective_scale_1_b;
+  int effective_scale_2_b;
+  int scratch_tensor_index;
+  int scratch_output_tensor_index;
+
+  // Cached tensor zero point values for quantized operations.
+  int input_zero_point;
+  int output_zero_point;
+};
+
+// TensorFlow Lite Micro-specific reference implementation for integer SVDF.
+void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
+                              const TfLiteEvalTensor* input_tensor,
+                              const TfLiteEvalTensor* weights_feature_tensor,
+                              const TfLiteEvalTensor* weights_time_tensor,
+                              const TfLiteEvalTensor* bias_tensor,
+                              const TfLiteSVDFParams* params,
+                              TfLiteEvalTensor* activation_state_tensor,
+                              TfLiteEvalTensor* output_tensor,
+                              const OpData& data);
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
diff --git a/tensorflow/lite/micro/kernels/svdf_common.cc b/tensorflow/lite/micro/kernels/svdf_common.cc
new file mode 100644
index 0000000..dcbe02d
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/svdf_common.cc
@@ -0,0 +1,177 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/activation_utils.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/svdf.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+namespace tflite {
+
+void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
+                              const TfLiteEvalTensor* input_tensor,
+                              const TfLiteEvalTensor* weights_feature_tensor,
+                              const TfLiteEvalTensor* weights_time_tensor,
+                              const TfLiteEvalTensor* bias_tensor,
+                              const TfLiteSVDFParams* params,
+                              TfLiteEvalTensor* activation_state_tensor,
+                              TfLiteEvalTensor* output_tensor,
+                              const OpData& data) {
+  const int n_rank = params->rank;
+  const int n_batch = input_tensor->dims->data[0];
+  const int n_input = input_tensor->dims->data[1];
+  const int n_filter = weights_feature_tensor->dims->data[0];
+  const int n_unit = n_filter / n_rank;
+  const int n_memory = weights_time_tensor->dims->data[1];
+
+  TFLITE_DCHECK(context != nullptr);
+  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
+
+  int32_t* scratch_tensor = static_cast<int32_t*>(
+      context->GetScratchBuffer(context, data.scratch_tensor_index));
+  int32_t* scratch_output_tensor = static_cast<int32_t*>(
+      context->GetScratchBuffer(context, data.scratch_output_tensor_index));
+
+  // Shift states.
+  int16_t* const state_ptr =
+      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
+
+  // Left shift the activation_state.
+  {
+    int16_t* new_state_start = state_ptr;
+    const int16_t* old_state_start = state_ptr + 1;
+    const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
+    while (old_state_start != old_state_end) {
+      *new_state_start++ = *old_state_start++;
+    }
+  }
+
+  // Note: no need to clear the latest activation, matmul is not accumulative.
+
+  // Feature matmul.
+  {
+    int16_t* state =
+        tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
+    const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
+    const int8_t* weight_feature =
+        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
+    const int32_t output_max = std::numeric_limits<int16_t>::max();
+    const int32_t output_min = std::numeric_limits<int16_t>::min();
+    int16_t* result_in_batch = state + (n_memory - 1);
+    for (int b = 0; b < n_batch; b++) {
+      const int8_t* matrix_ptr = weight_feature;
+      for (int r = 0; r < n_filter; r++) {
+        int32_t dot_prod = 0;
+        const int8_t* vector_in_batch = input + b * n_input;
+        for (int c = 0; c < n_input; c++) {
+          dot_prod +=
+              *matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
+        }
+        dot_prod = MultiplyByQuantizedMultiplier(
+            dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
+        dot_prod = std::min(std::max(output_min, dot_prod), output_max);
+        // This assumes state is symmetrically quantized. Otherwise last bit of
+        // state should be initialized to its zero point and accumulate the
+        // dot_prod.
+        // Equivalent to the following:
+        //     result_in_batch = zero point, which happens to be zero.
+        //     result_in_batch += dot_prod.
+        *result_in_batch = dot_prod;
+        result_in_batch += n_memory;
+      }
+    }
+  }
+
+  // Time.
+  {
+    for (int b = 0; b < n_batch; ++b) {
+      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
+
+      // Perform batched vector dot product:
+      const int16_t* vector1_ptr =
+          tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
+      const int16_t* vector2_ptr =
+          tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
+          b * n_memory * n_filter;
+
+      for (int i = 0; i < n_filter; i++) {
+        *scratch_ptr_batch = 0;
+        for (int j = 0; j < n_memory; j++) {
+          *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
+        }
+        scratch_ptr_batch++;
+      }
+    }
+  }
+
+  // Reduce, add bias, rescale, activation.
+  {
+    // Add bias.
+    if (bias_tensor) {
+      // Vector batch assign:
+      const int32_t* bias_data =
+          tflite::micro::GetTensorData<int32_t>(bias_tensor);
+      for (int i = 0; i < n_batch; ++i) {
+        int32_t* output_ptr = scratch_output_tensor + i * n_unit;
+        const int32_t* bias_ptr = bias_data;
+        for (int j = 0; j < n_unit; ++j) {
+          *output_ptr++ = *bias_ptr++;
+        }
+      }
+    } else {
+      int32_t* output_ptr = scratch_output_tensor;
+      for (int i = 0; i < n_batch * n_unit; ++i) {
+        *output_ptr++ = 0;
+      }
+    }
+
+    // Reduce.
+    for (int b = 0; b < n_batch; ++b) {
+      int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
+      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
+
+      // Reduction sum vector
+      for (int i = 0; i < n_unit; ++i) {
+        for (int j = 0; j < n_rank; ++j) {
+          output_temp_ptr[i] += *scratch_ptr_batch++;
+        }
+      }
+    }
+
+    // Rescale.
+    const int32_t output_max = std::numeric_limits<int8_t>::max();
+    const int32_t output_min = std::numeric_limits<int8_t>::min();
+    for (int i = 0; i < n_batch * n_unit; ++i) {
+      int32_t x1 = scratch_output_tensor[i];
+      int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
+                                                 data.effective_scale_2_b);
+      int32_t x3 = x2 + data.output_zero_point;
+      int32_t x4 = std::min(std::max(output_min, x3), output_max);
+      tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
+          static_cast<int8_t>(x4);
+    }
+  }
+}
+
+}  // namespace tflite
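The state handling in EvalIntegerSvdfReference is easier to see in isolation: the entire activation-state buffer is shifted left by one element, after which the feature matmul writes each filter's newest activation into slot n_memory - 1. A toy standalone sketch with made-up sizes and values:

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kBatches = 1, kFilters = 2, kMemory = 4;
  int16_t state[kBatches * kFilters * kMemory] = {1, 2, 3, 4, 5, 6, 7, 8};

  // Left shift the whole state buffer by one element, as the kernel does.
  for (int i = 0; i + 1 < kBatches * kFilters * kMemory; ++i) {
    state[i] = state[i + 1];
  }

  // Pretend the feature matmul produced these new activations per filter.
  const int16_t new_activation[kFilters] = {40, 80};
  for (int f = 0; f < kFilters; ++f) {
    state[f * kMemory + (kMemory - 1)] = new_activation[f];
  }

  for (int i = 0; i < kBatches * kFilters * kMemory; ++i) {
    printf("%d ", state[i]);
  }
  printf("\n");  // 2 3 4 40 6 7 8 80
  return 0;
}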
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
index 0af54c1..41a11a8 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -15,17 +15,17 @@
 
 #include "tensorflow/lite/kernels/internal/reference/conv.h"
 
-#include <xtensa/tie/xt_hifi2.h>
-
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
 namespace {
@@ -60,6 +60,7 @@
   int32_t output_activation_max;
 };
 
+#if defined(HIFIMINI)
 void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
                     const int32_t* output_shift,
                     const RuntimeShape& input_shape, const int8_t* input_data,
@@ -260,6 +261,7 @@
     output_data[ch] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
   }
 }
+#endif
 
 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                              TfLiteConvParams* params, int width, int height,
@@ -379,6 +381,7 @@
   op_params.quantized_activation_min = data->output_activation_min;
   op_params.quantized_activation_max = data->output_activation_max;
 
+#if defined(HIFIMINI)
   ConvPerChannel(op_params, data->per_channel_output_multiplier,
                  data->per_channel_output_shift,
                  tflite::micro::GetTensorShape(input),
@@ -389,6 +392,18 @@
                  tflite::micro::GetTensorData<int32_t>(bias),
                  tflite::micro::GetTensorShape(output),
                  tflite::micro::GetTensorData<int8_t>(output));
+#else
+  reference_integer_ops::ConvPerChannel(
+      op_params, data->per_channel_output_multiplier,
+      data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<int8_t>(input),
+      tflite::micro::GetTensorShape(filter),
+      tflite::micro::GetTensorData<int8_t>(filter),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<int8_t>(output));
+#endif
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -408,6 +423,7 @@
           ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
           : nullptr;
 
+#if defined(HIFIMINI)
   int* input_dims = input->dims->data;
   int* filter_dims = filter->dims->data;
   if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 &&
@@ -427,6 +443,7 @@
         tflite::micro::GetTensorData<int8_t>(output));
     return kTfLiteOk;
   }
+#endif
 
   switch (input->type) {
     case kTfLiteInt8:
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
index b0ecedc..9cfaba7 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include <xtensa/tie/xt_hifi2.h>
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -26,6 +26,7 @@
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
 namespace {
@@ -61,6 +62,7 @@
   int32_t output_activation_max;
 };
 
+#if defined(HIFIMINI)
 inline void DepthwiseConvPerChannel(
     const DepthwiseParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const RuntimeShape& input_shape,
@@ -304,6 +306,7 @@
     output_data[ch_1] = static_cast<int8_t>(AE_TRUNCA32Q48(block_1_acc));
   }
 }
+#endif
 
 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                              TfLiteDepthwiseConvParams* params, int width,
@@ -331,7 +334,7 @@
     int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
 
     // TODO(b/148610881): Consider calculating quantized params at int24
-    // calculations:
+    // precision for the hifimini port.
     TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
         context, input, filter, bias, output, params->activation,
         &data->output_multiplier, &data->output_shift,
@@ -424,6 +427,7 @@
   op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
   op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
 
+#if defined(HIFIMINI)
   DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier,
                           data->per_channel_output_shift,
                           tflite::micro::GetTensorShape(input),
@@ -434,6 +438,18 @@
                           tflite::micro::GetTensorData<int32_t>(bias),
                           tflite::micro::GetTensorShape(output),
                           tflite::micro::GetTensorData<int8_t>(output));
+#else
+  reference_integer_ops::DepthwiseConvPerChannel(
+      op_params, data->per_channel_output_multiplier,
+      data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<int8_t>(input),
+      tflite::micro::GetTensorShape(filter),
+      tflite::micro::GetTensorData<int8_t>(filter),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<int8_t>(output));
+#endif
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -454,6 +470,7 @@
           ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
           : nullptr;
 
+#if defined(HIFIMINI)
   // Handle special case for streaming model.
   int* input_dims = input->dims->data;
   int* filter_dims = filter->dims->data;
@@ -474,6 +491,8 @@
         tflite::micro::GetTensorData<int8_t>(output));
     return kTfLiteOk;
   }
+#endif
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteInt8:
       EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
diff --git a/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h
index a1d14df..2f8a4bd 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h
+++ b/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h
@@ -16,16 +16,17 @@
 #ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
 #define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
 
-#include <xtensa/tie/xt_hifi2.h>
-
 #include <algorithm>
 #include <cmath>
 #include <cstdint>
 
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
 
+#if defined(HIFIMINI)
+
 // INT24 MIN/MAX
 #define INT24_MIN -8388608
 #define INT24_MAX 8388607
@@ -132,6 +133,8 @@
   return static_cast<int>(raw);
 }
 
+#endif  // defined(HIFIMINI)
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index 165e243..a169343 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -15,8 +15,6 @@
 
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 
-#include <xtensa/tie/xt_hifi2.h>
-
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
@@ -26,6 +24,7 @@
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
 namespace {
@@ -54,6 +53,7 @@
 constexpr int kBiasTensor = 2;
 constexpr int kOutputTensor = 0;
 
+#if defined(HIFIMINI)
 void FullyConnected(const FullyConnectedParams& params,
                     const RuntimeShape& input_shape, const int8_t* input_data,
                     const RuntimeShape& filter_shape, const int8_t* filter_data,
@@ -137,6 +137,7 @@
     }
   }
 }
+#endif
 
 TfLiteStatus CalculateOpData(TfLiteContext* context,
                              TfLiteFusedActivation activation,
@@ -147,8 +148,13 @@
   double real_multiplier = 0.0;
   TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
       context, input, filter, bias, output, &real_multiplier));
+#if defined(HIFIMINI)
   QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier,
                              &data->output_shift);
+#else
+  QuantizeMultiplier(real_multiplier, &data->output_multiplier,
+                     &data->output_shift);
+#endif
   return CalculateActivationRangeQuantized(context, activation, output,
                                            &data->output_activation_min,
                                            &data->output_activation_max);
@@ -206,6 +212,7 @@
   op_params.quantized_activation_min = data.output_activation_min;
   op_params.quantized_activation_max = data.output_activation_max;
 
+#if defined(HIFIMINI)
   FullyConnected(op_params, tflite::micro::GetTensorShape(input),
                  tflite::micro::GetTensorData<int8_t>(input),
                  tflite::micro::GetTensorShape(filter),
@@ -214,6 +221,18 @@
                  tflite::micro::GetTensorData<int32_t>(bias),
                  tflite::micro::GetTensorShape(output),
                  tflite::micro::GetTensorData<int8_t>(output));
+#else
+  reference_integer_ops::FullyConnected(
+      op_params, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<int8_t>(input),
+      tflite::micro::GetTensorShape(filter),
+      tflite::micro::GetTensorData<int8_t>(filter),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<int8_t>(output));
+#endif
+
   return kTfLiteOk;
 }
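conv.cc, depthwise_conv.cc and fully_connected.cc above all adopt the same build-time dispatch: the HiFiMini-specific kernel compiles only under #if defined(HIFIMINI), and every other xtensa configuration falls through to the portable reference op, so the same result is produced either way. A self-contained sketch of the pattern (Accumulate and Invoke are stand-ins, not TFLite symbols):

#include <cstdio>

namespace reference_ops {
// Portable kernel shared with all other targets.
inline int Accumulate(const int* data, int n) {
  int acc = 0;
  for (int i = 0; i < n; ++i) acc += data[i];
  return acc;
}
}  // namespace reference_ops

#if defined(HIFIMINI)
// Stand-in for the intrinsics-based kernel that only builds for HiFiMini;
// real code would use the AE_* ops from <xtensa/tie/xt_hifi2.h>.
inline int AccumulateOptimized(const int* data, int n) {
  return reference_ops::Accumulate(data, n);
}
#endif

int Invoke(const int* data, int n) {
#if defined(HIFIMINI)
  return AccumulateOptimized(data, n);
#else
  return reference_ops::Accumulate(data, n);
#endif
}

int main() {
  const int data[] = {1, 2, 3};
  printf("%d\n", Invoke(data, 3));  // 6 on every target
  return 0;
}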
 
diff --git a/tensorflow/lite/micro/kernels/xtensa/quantize.cc b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
index 05646a3..3b84e06 100644
--- a/tensorflow/lite/micro/kernels/xtensa/quantize.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
@@ -15,20 +15,22 @@
 
 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
 
-#include <xtensa/tie/xt_hifi2.h>
-
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/quantize.h"
 #include "tensorflow/lite/kernels/internal/reference/requantize.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/quantize.h"
 #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
 namespace {
 
+#if defined(HIFIMINI)
 struct OpData {
   int32_t zero_point = 0;
   int scale_multiplier = 0;
@@ -107,34 +109,7 @@
   }
 }
 
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-
-  // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions.
-  op_data->scale_multiplier =
-      CreateQConstantForInt24(0, input->params.scale / output->params.scale);
-
-  op_data->zero_point = output->params.zero_point;
-  op_data->input_zero_point = input->params.zero_point;
-
-  double effective_scale = static_cast<double>(input->params.scale) /
-                           static_cast<double>(output->params.scale);
-  QuantizeMultiplier(effective_scale, &op_data->requantize_output_multiplier,
-                     &op_data->requantize_output_shift);
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus EvalHifimini(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   auto* op_data = static_cast<OpData*>(node->user_data);
 
@@ -162,6 +137,54 @@
   }
   return kTfLiteOk;
 }
+#endif  // defined(HIFIMINI)
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+#if defined(HIFIMINI)
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+#else
+  return context->AllocatePersistentBuffer(context,
+                                           sizeof(OpDataQuantizeReference));
+#endif
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input = GetInput(context, node, 0);
+
+#if defined(HIFIMINI)
+  auto* op_data = static_cast<OpData*>(node->user_data);
+  // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions.
+  op_data->scale_multiplier =
+      CreateQConstantForInt24(0, input->params.scale / output->params.scale);
+  op_data->zero_point = output->params.zero_point;
+#else
+  auto* op_data = static_cast<OpDataQuantizeReference*>(node->user_data);
+  op_data->quantization_params.zero_point = output->params.zero_point;
+  op_data->quantization_params.scale =
+      static_cast<double>(output->params.scale);
+#endif
+
+  op_data->input_zero_point = input->params.zero_point;
+
+  double effective_scale = static_cast<double>(input->params.scale) /
+                           static_cast<double>(output->params.scale);
+  QuantizeMultiplier(effective_scale, &op_data->requantize_output_multiplier,
+                     &op_data->requantize_output_shift);
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+#if defined(HIFIMINI)
+  return EvalHifimini(context, node);
+#else
+  return EvalQuantizeReference(context, node);
+#endif
+}
 
 }  // namespace
 
diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
index 5392e50..f9d6e18 100644
--- a/tensorflow/lite/micro/kernels/xtensa/svdf.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
@@ -13,8 +13,9 @@
 limitations under the License.
 ==============================================================================*/
 
-#include <math.h>
-#include <xtensa/tie/xt_hifi2.h>
+#include "tensorflow/lite/micro/kernels/svdf.h"
+
+#include <cmath>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -26,25 +27,11 @@
 #include "tensorflow/lite/micro/kernels/activation_utils.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
 namespace {
 
-struct OpData {
-  int32_t effective_scale_1_a;
-  int32_t effective_scale_2_a;
-  // b versions of each scale are kept at int since the numbers are just the
-  // shift value - typically between [-32, 32].
-  int effective_scale_1_b;
-  int effective_scale_2_b;
-  int scratch_tensor_index;
-  int scratch_output_tensor_index;
-
-  // Cached tensor zero point values for quantized operations.
-  int input_zero_point;
-  int output_zero_point;
-};
-
 // Input tensors.
 constexpr int kInputTensor = 0;
 constexpr int kWeightsFeatureTensor = 1;
@@ -56,6 +43,7 @@
 // Output tensor.
 constexpr int kOutputTensor = 0;
 
+#if defined(HIFIMINI)
 /**
  * This version of SVDF is specific to TFLite Micro. It contains only a full
  * integer recipe with optimizations for the Xtensa HiFiMini platform.
@@ -255,6 +243,7 @@
     }
   }
 }
+#endif
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context != nullptr);
@@ -357,10 +346,17 @@
   TFLITE_DCHECK(node->user_data != nullptr);
   OpData* data = static_cast<OpData*>(node->user_data);
 
+#if defined(HIFIMINI)
   QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a,
                              &data->effective_scale_1_b);
   QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a,
                              &data->effective_scale_2_b);
+#else
+  QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
+                     &(data->effective_scale_1_b));
+  QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
+                     &(data->effective_scale_2_b));
+#endif
 
   data->input_zero_point = input->params.zero_point;
   data->output_zero_point = output->params.zero_point;
@@ -399,8 +395,13 @@
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
 
+#if defined(HIFIMINI)
   EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
                   params, activation_state, output, data);
+#else
+  EvalIntegerSvdfReference(context, node, input, weights_feature, weights_time,
+                           bias, params, activation_state, output, data);
+#endif
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/java/src/main/native/op_resolver.h b/tensorflow/lite/micro/kernels/xtensa/xtensa.h
similarity index 67%
copy from tensorflow/lite/java/src/main/native/op_resolver.h
copy to tensorflow/lite/micro/kernels/xtensa/xtensa.h
index 08ff0ce..7ada7f5 100644
--- a/tensorflow/lite/java/src/main/native/op_resolver.h
+++ b/tensorflow/lite/micro/kernels/xtensa/xtensa.h
@@ -12,17 +12,12 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
-#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
 
-#include <memory>
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_XTENSA_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_XTENSA_H_
 
-#include "tensorflow/lite/op_resolver.h"
+#if defined(HIFIMINI)
+#include <xtensa/tie/xt_hifi2.h>
+#endif
 
-namespace tflite {
-
-std::unique_ptr<OpResolver> CreateOpResolver();
-
-}
-
-#endif  // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_XTENSA_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc
index c501d8a..01a2f4e 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
index 9d08709..68fe4f5 100755
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc
index 8b1a8cf..dbebfc9 100755
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc
index 44aac92..1f2b71e 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc
index 2cbea17..3347af9 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc
index 764bc88..3158a18 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc
old mode 100755
new mode 100644
index ccb3c11..7c32b9e
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -522,7 +501,6 @@
 }
 }  // namespace
 
-
 TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
   OpData data;
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc
index 9d256b3..65ead0f 100755
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc
@@ -1,24 +1,3 @@
-/*******************************************************************************
-* Copyright (c) 2019-2020 Cadence Design Systems, Inc.
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files (the
-* "Software"), to use this Software with Cadence processor cores only and
-* not with any other processors and platforms, subject to
-* the following conditions:
-*
-* The above copyright notice and this permission notice shall be included
-* in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
index a208713..d8ee6b2 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
@@ -1,23 +1,3 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -488,77 +468,76 @@
   TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
                     memory_size * num_filters);
 
-    TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
 
-    if (input->type == kTfLiteInt8) {
-      TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
-      TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
-      TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
-      if (bias != nullptr) {
-        TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
-      }
-
-      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
-
-      const auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
-          input->quantization.params);
-      const auto* weights_feature_params =
-          static_cast<const TfLiteAffineQuantization*>(
-              weights_feature->quantization.params);
-      const auto* state_params = static_cast<const TfLiteAffineQuantization*>(
-          activation_state->quantization.params);
-      const auto* weight_time_params =
-          static_cast<const TfLiteAffineQuantization*>(
-              weights_time->quantization.params);
-      const auto* output_params = static_cast<const TfLiteAffineQuantization*>(
-          output->quantization.params);
-      const double effective_scale_1 =
-          static_cast<double>(input_params->scale->data[0] *
-                              weights_feature_params->scale->data[0] /
-                              state_params->scale->data[0]);
-      const double effective_scale_2 = static_cast<double>(
-          state_params->scale->data[0] * weight_time_params->scale->data[0] /
-          output_params->scale->data[0]);
-
-      TFLITE_DCHECK(node->user_data != nullptr);
-      OpData* data = static_cast<OpData*>(node->user_data);
-
-      QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
-                         &(data->effective_scale_1_b));
-      QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
-                         &(data->effective_scale_2_b));
-
-      TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
-
-      const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-          context, batch_size * num_filters * sizeof(int32_t),
-          &(data->scratch_tensor_index));
-      TF_LITE_ENSURE_OK(context, scratch_status);
-
-      const TfLiteStatus scratch_output_status =
-          context->RequestScratchBufferInArena(
-              context, batch_size * num_units * sizeof(int32_t),
-              &(data->scratch_output_tensor_index));
-      TF_LITE_ENSURE_OK(context, scratch_output_status);
-    } else {
-      TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
-      TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
-      TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
-      if (bias != nullptr) {
-        TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
-      }
-      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
-
-      TFLITE_DCHECK(node->user_data != nullptr);
-      OpData* data = static_cast<OpData*>(node->user_data);
-
-      TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
-      const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-          context, batch_size * num_filters * sizeof(float),
-          &(data->scratch_tensor_index));
-      TF_LITE_ENSURE_OK(context, scratch_status);
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
+    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
+    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
+    if (bias != nullptr) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
     }
 
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
+
+    const auto* input_params =
+        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
+    const auto* weights_feature_params =
+        static_cast<const TfLiteAffineQuantization*>(
+            weights_feature->quantization.params);
+    const auto* state_params = static_cast<const TfLiteAffineQuantization*>(
+        activation_state->quantization.params);
+    const auto* weight_time_params =
+        static_cast<const TfLiteAffineQuantization*>(
+            weights_time->quantization.params);
+    const auto* output_params = static_cast<const TfLiteAffineQuantization*>(
+        output->quantization.params);
+    const double effective_scale_1 = static_cast<double>(
+        input_params->scale->data[0] * weights_feature_params->scale->data[0] /
+        state_params->scale->data[0]);
+    const double effective_scale_2 = static_cast<double>(
+        state_params->scale->data[0] * weight_time_params->scale->data[0] /
+        output_params->scale->data[0]);
+
+    TFLITE_DCHECK(node->user_data != nullptr);
+    OpData* data = static_cast<OpData*>(node->user_data);
+
+    QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
+                       &(data->effective_scale_1_b));
+    QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
+                       &(data->effective_scale_2_b));
+
+    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+
+    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+        context, batch_size * num_filters * sizeof(int32_t),
+        &(data->scratch_tensor_index));
+    TF_LITE_ENSURE_OK(context, scratch_status);
+
+    const TfLiteStatus scratch_output_status =
+        context->RequestScratchBufferInArena(
+            context, batch_size * num_units * sizeof(int32_t),
+            &(data->scratch_output_tensor_index));
+    TF_LITE_ENSURE_OK(context, scratch_output_status);
+  } else {
+    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
+    if (bias != nullptr) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+    }
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+
+    TFLITE_DCHECK(node->user_data != nullptr);
+    OpData* data = static_cast<OpData*>(node->user_data);
+
+    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+        context, batch_size * num_filters * sizeof(float),
+        &(data->scratch_tensor_index));
+    TF_LITE_ENSURE_OK(context, scratch_status);
+  }
+
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
index cf74128..6fe6bae 100755
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
@@ -1,24 +1,3 @@
-/******************************************************************************
- * Copyright (C) 2019 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index 7a1015c..790b93b 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -36,6 +36,8 @@
 template <unsigned int tOpCount>
 class MicroMutableOpResolver : public MicroOpResolver {
  public:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+
   explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
       : error_reporter_(error_reporter) {}
 
@@ -421,8 +423,6 @@
   unsigned int GetRegistrationLength() { return registrations_len_; }
 
  private:
-  TF_LITE_REMOVE_VIRTUAL_DELETE
-
   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                           const TfLiteRegistration& registration,
                           MicroOpResolver::BuiltinParseFunction parser) {
diff --git a/tensorflow/lite/micro/riscv32_mcu/debug_log.cc b/tensorflow/lite/micro/riscv32_mcu/debug_log.cc
index e2a552e..f9459b8 100644
--- a/tensorflow/lite/micro/riscv32_mcu/debug_log.cc
+++ b/tensorflow/lite/micro/riscv32_mcu/debug_log.cc
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/stm32f4/debug_log.cc b/tensorflow/lite/micro/stm32f4/debug_log.cc
index 311005f..7d61d10 100644
--- a/tensorflow/lite/micro/stm32f4/debug_log.cc
+++ b/tensorflow/lite/micro/stm32f4/debug_log.cc
@@ -20,6 +20,6 @@
       "mov r1, %[str]\n"
       "bkpt #0xAB\n"
       :
-      : [ str ] "r"(s)
+      : [str] "r"(s)
       : "r0", "r1");
 }
diff --git a/tensorflow/lite/micro/stm32f4HAL/debug_log.cc b/tensorflow/lite/micro/stm32f4HAL/debug_log.cc
index 6e1936a..117c101 100644
--- a/tensorflow/lite/micro/stm32f4HAL/debug_log.cc
+++ b/tensorflow/lite/micro/stm32f4HAL/debug_log.cc
@@ -28,19 +28,19 @@
 
 #ifdef __GNUC__
 int __io_putchar(int ch) {
-  HAL_UART_Transmit(&DEBUG_UART_HANDLE, (uint8_t *)&ch, 1, HAL_MAX_DELAY);
+  HAL_UART_Transmit(&DEBUG_UART_HANDLE, (uint8_t*)&ch, 1, HAL_MAX_DELAY);
 
   return ch;
 }
 #else
-int fputc(int ch, FILE *f) {
-  HAL_UART_Transmit(&DEBUG_UART_HANDLE, (uint8_t *)&ch, 1, HAL_MAX_DELAY);
+int fputc(int ch, FILE* f) {
+  HAL_UART_Transmit(&DEBUG_UART_HANDLE, (uint8_t*)&ch, 1, HAL_MAX_DELAY);
 
   return ch;
 }
 #endif /* __GNUC__ */
 
-void DebugLog(const char *s) { fprintf(stderr, "%s", s); }
+void DebugLog(const char* s) { fprintf(stderr, "%s", s); }
 
 #ifdef __cplusplus
 }
diff --git a/tensorflow/lite/micro/testing/bluepill.resc b/tensorflow/lite/micro/testing/bluepill.resc
index 5e0aa6e..78af665 100644
--- a/tensorflow/lite/micro/testing/bluepill.resc
+++ b/tensorflow/lite/micro/testing/bluepill.resc
@@ -21,13 +21,5 @@
 # These lines are needed to show the results of DebugLog calls in the output.
 machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
 showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer
-
-logFile $logfile
-
-macro reset
-"""
-    sysbus LoadELF $bin
-"""
-
-runMacro $reset
+cpu.uartSemihosting CreateFileBackend $logfile true
 
diff --git a/tensorflow/lite/micro/testing/bluepill.robot b/tensorflow/lite/micro/testing/bluepill.robot
deleted file mode 100644
index 0a31f08..0000000
--- a/tensorflow/lite/micro/testing/bluepill.robot
+++ /dev/null
@@ -1,46 +0,0 @@
-*** Settings ***
-Suite Setup                   Prepare Tests
-Suite Teardown                Teardown
-Test Setup                    Reset Emulation
-Test Teardown                 Teardown With Custom Message
-Resource                      ${RENODEKEYWORDS}
-
-*** Variables ***
-${CREATE_SNAPSHOT_ON_FAIL}    False
-${UART}                       sysbus.cpu.uartSemihosting
-${RESC}                       undefined_RESC
-${RENODE_LOG}                 /tmp/renode.log
-${UART_LINE_ON_SUCCESS}       ~~~ALL TESTS PASSED~~~
-${DIR_WITH_TESTS}             undefined_DIR_WTH_TESTS
-
-*** Keywords ***
-Prepare Tests
-    [Documentation]           List all binaries with _test suffix and make available from test cases
-    Setup
-    @{tests} =                List Files In Directory  ${DIR_WITH_TESTS}  pattern=*_test  absolute=True
-    Set Suite Variable        @{tests}
-
-Teardown With Custom Message
-    [Documentation]           Replace robot fail message with shorter one to avoid duplicated UART output in log
-    Set Test Message          ${file} - FAILED
-    Test Teardown
-
-Test Binary
-    Remove File               ${RENODE_LOG}
-    Execute Command           $logfile = @${RENODE_LOG}
-    Execute Script            ${RESC}
-    Create Terminal Tester    ${UART}  timeout=2
-    Start Emulation
-    Wait For Line On Uart     ${UART_LINE_ON_SUCCESS}
-
-*** Test Cases ***
-Run All Bluepill Tests
-    [Documentation]           Runs Bluepill tests and waits for a specific string on the semihosting UART
-    FOR  ${TEST}  IN  @{tests}
-        Execute Command       Clear
-        Execute Command       $bin = @${TEST}
-        ${_}  ${file} =       Split Path  ${TEST}
-        Set Test Variable     ${file}
-        Test Binary
-        Log                   \t${file} - PASSED   console=True
-    END
diff --git a/tensorflow/lite/micro/testing/Dockerfile.stm32f4 b/tensorflow/lite/micro/testing/bluepill_nontest.resc
similarity index 60%
rename from tensorflow/lite/micro/testing/Dockerfile.stm32f4
rename to tensorflow/lite/micro/testing/bluepill_nontest.resc
index 75e6118..8a5cdd1 100644
--- a/tensorflow/lite/micro/testing/Dockerfile.stm32f4
+++ b/tensorflow/lite/micro/testing/bluepill_nontest.resc
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,9 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
-# This docker configuration file lets you emulate a stm32f4 board
-# on an x86 desktop or laptop, which can be useful for debugging and
-# automated testing.
-FROM antmicro/renode:latest
+mach create
+# Load platform specification
+machine LoadPlatformDescription @platforms/cpus/stm32f103.repl
+# Create additional semihosting interface peripheral
+machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
+# Open separate window for semihosting UART output
+showAnalyzer sysbus.cpu.uartSemihosting
 
-LABEL maintainer="Pete Warden <petewarden@google.com>"
\ No newline at end of file
diff --git a/tensorflow/lite/micro/testing/download_renode.sh b/tensorflow/lite/micro/testing/download_renode.sh
deleted file mode 100755
index c74b1a4..0000000
--- a/tensorflow/lite/micro/testing/download_renode.sh
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/bin/bash
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-# #
-# # Licensed under the Apache License, Version 2.0 (the "License");
-# # you may not use this file except in compliance with the License.
-# # You may obtain a copy of the License at
-# #
-# #     http://www.apache.org/licenses/LICENSE-2.0
-# #
-# # Unless required by applicable law or agreed to in writing, software
-# # distributed under the License is distributed on an "AS IS" BASIS,
-# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# # See the License for the specific language governing permissions and
-# # limitations under the License.
-# # ==============================================================================
-#
-# # Utility script that handles downloading and extracting portable version of Renode for testing purposes.
-# # Called with one argument:
-# # 1 - Path to new folder to unpack the package into.
-#
-
-if [ $# -ne 1 ]; then
-    echo "Usage: download_renode.sh PATH"
-    echo "    PATH is a path where Renode should be unpacked"
-    echo ""
-    echo "E.g: ./download_renode.sh /tmp/renode"
-    exit 1
-fi
-
-# Colours
-ORANGE="\033[33m"
-RED="\033[31m"
-NC="\033[0m"
-
-# Target version
-RENODE_VERSION='1.11.0'
-# Get target path
-TARGET_PATH=$1
-mkdir -p "${TARGET_PATH}" || exit 1
-
-echo "Downloading Renode portable in version ${RENODE_VERSION}"
-
-# Get link to requested version
-RELEASES_JSON=`curl https://api.github.com/repos/renode/renode/releases 2>/dev/null`
-LINUX_PORTABLE_URL=`echo "${RELEASES_JSON}" |grep 'browser_download_url'|\
-    grep --extended-regexp --only-matching "https://.*${RENODE_VERSION}.*linux-portable.*tar.gz"`
-if [ -z "${LINUX_PORTABLE_URL}" ]; then
-  echo -e "${RED}Portable version of release v${RENODE_VERSION} not found. Please make sure you use correct version format ('[0-9]+.[0-9]+.[0-9]+')${NC}"
-  exit 1
-fi
-
-# Check if newer version available
-LATEST_RENODE_VERSION=`echo "${RELEASES_JSON}" |grep 'tag_name' |\
-    head --lines 1 | grep --extended-regexp --only-matching '[0-9]+\.[0-9]+\.[0-9]+'`
-if [ "${RENODE_VERSION}" != "${LATEST_RENODE_VERSION}" ]; then
-  echo -e "${ORANGE}Latest available version is ${LATEST_RENODE_VERSION}, please consider using it.${NC}"
-fi
-echo "Downloading from url: ${LINUX_PORTABLE_URL}"
-
-# Get portable & unpack
-wget --quiet --output-document - "${LINUX_PORTABLE_URL}" |\
-    tar xz --strip-components=1 --directory "${TARGET_PATH}"
-echo "Unpacked to directory: ${TARGET_PATH}"
diff --git a/tensorflow/lite/micro/testing/robot.resource.txt b/tensorflow/lite/micro/testing/robot.resource.txt
new file mode 100644
index 0000000..e06720c
--- /dev/null
+++ b/tensorflow/lite/micro/testing/robot.resource.txt
@@ -0,0 +1,26 @@
+*** Variables ***
+${UART}                       sysbus.cpu.uartSemihosting
+
+*** Keywords ***
+Teardown With Custom Message
+    Test Teardown
+    [Documentation]           Replace robot fail message with whole UART output
+    ${UART_LOGS}              Get File    ${UART_LOG}
+    Set Test Message          UART OUTPUT:\n\n${UART_LOGS}
+    Remove File               ${UART_LOG}
+
+Create Platform
+    Execute Command           $logfile=@${UART_LOG}
+    Execute Script            ${RESC}
+    Provides                  ready-platform
+
+Test Binary
+    [Arguments]               ${BIN}
+    Requires                  ready-platform
+    Execute Command           sysbus LoadELF ${BIN}
+
+    Create Terminal Tester    ${UART}  timeout=2
+    Start Emulation
+
+    Wait For Line On Uart     ${UART_LINE_ON_SUCCESS}
+
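The resource file above only defines the shared keywords and the UART variable; the RESC script, UART log path, and expected success line are supplied by the caller. As a rough, non-authoritative sketch (paths assume the bluepill target and simply mirror the ROBOT_COMMAND assembled in test_with_renode.sh later in this patch), Renode's Robot entrypoint would be invoked along these lines:

    tensorflow/lite/micro/tools/make/downloads/renode/test.sh \
        /tmp/renode_bluepill_logs/bluepill.robot \
        -r /tmp/renode_bluepill_logs \
        --variable RESC:tensorflow/lite/micro/testing/bluepill.resc \
        --variable UART_LOG:/tmp/renode_bluepill_logs/uart_log.txt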
diff --git a/tensorflow/lite/micro/testing/sifive_fe310.resc b/tensorflow/lite/micro/testing/sifive_fe310.resc
index b2bd20c..676197c 100644
--- a/tensorflow/lite/micro/testing/sifive_fe310.resc
+++ b/tensorflow/lite/micro/testing/sifive_fe310.resc
@@ -1,3 +1,18 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 :name: SiFive-FE310
 :description: This script runs Zephyr RTOS shell sample on SiFive-FE310 platform.
 
diff --git a/tensorflow/lite/micro/testing/stm32f4.resc b/tensorflow/lite/micro/testing/stm32f4.resc
index 45f213c..024c948 100644
--- a/tensorflow/lite/micro/testing/stm32f4.resc
+++ b/tensorflow/lite/micro/testing/stm32f4.resc
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,13 +21,5 @@
 # These lines are needed to show the results of DebugLog calls in the output.
 machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
 showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer
-
-logFile @/tmp/renode_stm32f4_log.txt
-
-macro reset
-"""
-    sysbus LoadELF $bin
-"""
-
-runMacro $reset
+cpu.uartSemihosting CreateFileBackend $logfile true
 
diff --git a/tensorflow/lite/micro/testing/stm32f4.robot b/tensorflow/lite/micro/testing/stm32f4.robot
deleted file mode 100644
index 0833c0b..0000000
--- a/tensorflow/lite/micro/testing/stm32f4.robot
+++ /dev/null
@@ -1,23 +0,0 @@
-*** Settings ***
-Suite Setup                   Setup
-Suite Teardown                Teardown
-Test Setup                    Reset Emulation
-Resource                      /opt/renode/tests/renode-keywords.robot
-
-*** Variables ***
-${UART}                       sysbus.cpu.uartSemihosting
-
-*** Test Cases ***
-Should Run Stm32f4 Test
-    [Documentation]           Runs a Stm32f4 test and waits for a specific string on the semihosting UART
-    [Tags]                    stm32f4  uart  tensorflow  arm
-    ${BIN} =                  Get Environment Variable    BIN
-    ${SCRIPT} =               Get Environment Variable    SCRIPT
-    ${EXPECTED} =             Get Environment Variable    EXPECTED
-    Execute Command           $bin = @${BIN}
-    Execute Script            ${SCRIPT}
-
-    Create Terminal Tester    ${UART}  timeout=60
-    Start Emulation
-
-    Wait For Line On Uart     ${EXPECTED}
diff --git a/tensorflow/lite/micro/testing/test_bluepill_binary.sh b/tensorflow/lite/micro/testing/test_bluepill_binary.sh
deleted file mode 100755
index 45cdbaa..0000000
--- a/tensorflow/lite/micro/testing/test_bluepill_binary.sh
+++ /dev/null
@@ -1,76 +0,0 @@
-#!/bin/bash -e
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-TFLM_ROOT_DIR=${SCRIPT_DIR}/..
-
-# The renode script for the board being emulated.
-RESC_PATH=${TFLM_ROOT_DIR}/testing/bluepill.resc
-
-# Renode's entrypoint for using the Robot Framework.
-RENODE_TEST_SCRIPT=${TFLM_ROOT_DIR}/tools/make/downloads/renode/test.sh
-
-if [ ! -f "${RENODE_TEST_SCRIPT}" ]; then
-  echo "The renode test script: ${RENODE_TEST_SCRIPT} does not exist. Please " \
-       "make sure that you have correctly installed Renode for TFLM. See " \
-       "tensorflow/lite/micro/docs/renode.md for more details."
-  exit 1
-fi
-
-if ! ${RENODE_TEST_SCRIPT} &> /dev/null
-then
-  echo "The following command failed: ${RENODE_TEST_SCRIPT}. Please " \
-       "make sure that you have correctly installed Renode for TFLM. See " \
-       "tensorflow/lite/micro/docs/renode.md for more details."
-  exit 1
-fi
-
-exit_code=0
-
-# The logs from this script will go in the RESULTS_DIRECTORY. These include:
-#  1. RENODE_LOG: Output log from the renode process.
-#  2. html and xml files generated by the Robot Framework.
-#
-# Note that with the current approach (in bluepill.robot), multiple test
-# binaries are run in a loop and RENODE_LOG only has logs from the last test
-# binary since it is deleted prior to running each test binary.
-RESULTS_DIRECTORY=/tmp/renode_bluepill_logs
-mkdir -p ${RESULTS_DIRECTORY}
-RENODE_LOG=${RESULTS_DIRECTORY}/renode_log.txt
-
-ROBOT_COMMAND="${RENODE_TEST_SCRIPT} ${TFLM_ROOT_DIR}/testing/bluepill.robot \
-  -r ${RESULTS_DIRECTORY} \
-  --variable RESC:${RESC_PATH} \
-  --variable RENODE_LOG:${RENODE_LOG} \
-  --variable DIR_WITH_TESTS:${1}"
-
-echo "${ROBOT_COMMAND}"
-
-if ! ${ROBOT_COMMAND}
-then
-  exit_code=1
-fi
-
-if [ $exit_code -eq 0 ]
-then
-  echo "PASS"
-else
-  echo "UART LOGS:"
-  # Extract output from renode log
-  cat ${RENODE_LOG} |grep 'uartSemihosting' |sed 's/^.*from start] *//g'
-fi
-
-exit $exit_code
diff --git a/tensorflow/lite/micro/testing/test_stm32f4_binary.sh b/tensorflow/lite/micro/testing/test_stm32f4_binary.sh
deleted file mode 100755
index de7d749..0000000
--- a/tensorflow/lite/micro/testing/test_stm32f4_binary.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash -e
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Tests a 'stm32f4' STM32F4 ELF by parsing the log output of Renode emulation.
-#
-# First argument is the ELF location.
-# Second argument is a regular expression that's required to be in the output logs
-# for the test to pass.
-#
-# This script must be run from the top-level folder of the tensorflow github
-# repository as it mounts `pwd` to the renode docker image (via docker run -v)
-# and paths in the docker run command assume the entire tensorflow repo is mounted.
-
-declare -r ROOT_DIR=`pwd`
-declare -r TEST_TMPDIR=/tmp/test_stm32f4_binary/
-declare -r MICRO_LOG_PATH=${TEST_TMPDIR}
-declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
-mkdir -p ${MICRO_LOG_PATH}
-
-docker build -t renode_stm32f4 \
-  -f ${ROOT_DIR}/tensorflow/lite/micro/testing/Dockerfile.stm32f4 \
-  ${ROOT_DIR}/tensorflow/lite/micro/testing/
-
-exit_code=0
-# running in `if` to avoid setting +e
-if ! docker run \
-  --log-driver=none -a stdout -a stderr \
-  -v ${ROOT_DIR}:/workspace \
-  -v /tmp:/tmp \
-  -e BIN=/workspace/$1 \
-  -e SCRIPT=/workspace/tensorflow/lite/micro/testing/stm32f4.resc \
-  -e EXPECTED="$2" \
-  -it renode_stm32f4 \
-  /bin/bash -c "/opt/renode/tests/test.sh /workspace/tensorflow/lite/micro/testing/stm32f4.robot 2>&1 >${MICRO_LOG_FILENAME}"
-then
-  exit_code=1
-fi
-
-echo "LOGS:"
-cat ${MICRO_LOG_FILENAME}
-if [ $exit_code -eq 0 ]
-then
-  echo "$1: PASS"
-else
-  echo "$1: FAIL - '$2' not found in logs."
-fi
-exit $exit_code
diff --git a/tensorflow/lite/micro/testing/test_with_renode.sh b/tensorflow/lite/micro/testing/test_with_renode.sh
new file mode 100755
index 0000000..1dea545
--- /dev/null
+++ b/tensorflow/lite/micro/testing/test_with_renode.sh
@@ -0,0 +1,108 @@
+#!/bin/bash -e
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+#
+# Parameters:
+#  ${1} - path to a binary to test, or a directory (every *_test binary in it will be run).
+#  ${2} - target (bluepill, stm32f4, etc.)
+
+set -e
+
+TARGET=${2}
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TFLM_ROOT_DIR=${SCRIPT_DIR}/..
+
+# The renode script for the board being emulated.
+RESC_PATH=${TFLM_ROOT_DIR}/testing/${TARGET}.resc
+
+# Robot file with definition of custom keywords used in test suite.
+ROBOT_RESOURCE=${TFLM_ROOT_DIR}/testing/robot.resource.txt
+
+# Renode's entrypoint for using the Robot Framework.
+RENODE_TEST_SCRIPT=${TFLM_ROOT_DIR}/tools/make/downloads/renode/test.sh
+
+if [ ! -f "${RENODE_TEST_SCRIPT}" ]; then
+  echo "The renode test script: ${RENODE_TEST_SCRIPT} does not exist. Please " \
+       "make sure that you have correctly installed Renode for TFLM. See " \
+       "tensorflow/lite/micro/docs/renode.md for more details."
+  exit 1
+fi
+
+if ! ${RENODE_TEST_SCRIPT} &> /dev/null
+then
+  echo "The following command failed: ${RENODE_TEST_SCRIPT}. Please " \
+       "make sure that you have correctly installed Renode for TFLM. See " \
+       "tensorflow/lite/micro/docs/renode.md for more details."
+  exit 1
+fi
+
+# Files generated by this script will go in the RESULTS_DIRECTORY. These include:
+#  1. UART_LOG: Output log from the renode uart.
+#  2. html and xml files generated by the Robot Framework.
+#  3. ROBOT_SCRIPT: Generated test suite.
+#
+# Note that with the current approach (in the generated ROBOT_SCRIPT), multiple
+# test binaries are run in the same test suite and UART_LOG only has logs from
+# the last test binary since it is deleted prior to running each test binary.
+# If a test fails, the UART_LOG is printed to the console before being deleted.
+RESULTS_DIRECTORY=/tmp/renode_${TARGET}_logs
+mkdir -p ${RESULTS_DIRECTORY}
+
+UART_LOG=${RESULTS_DIRECTORY}/uart_log.txt
+
+ROBOT_SCRIPT=${RESULTS_DIRECTORY}/${TARGET}.robot
+
+echo -e "*** Settings ***\n" \
+        "Suite Setup                   Setup\n" \
+        "Suite Teardown                Teardown\n" \
+        "Test Setup                    Reset Emulation\n" \
+        "Test Teardown                 Teardown With Custom Message\n" \
+        "Resource                      \${RENODEKEYWORDS}\n" \
+        "Resource                      ${ROBOT_RESOURCE}\n" \
+        "Default Tags                  tensorflow\n" \
+        "\n" \
+        "*** Variables ***\n" \
+        "\${RESC}                      undefined_RESC\n" \
+        "\${UART_LOG}                  /tmp/uart.log\n" \
+        "\${UART_LINE_ON_SUCCESS}      ~~~ALL TESTS PASSED~~~\n" \
+        "\${CREATE_SNAPSHOT_ON_FAIL}   False\n" \
+        "\n" \
+        "*** Test Cases ***\n" \
+        "Should Create Platform\n" \
+        "    Create Platform\n" > $ROBOT_SCRIPT
+
+declare -a FILES
+if [[ -d ${1} ]]; then
+    FILES=`ls -1 ${1}/*_test`
+else
+    FILES=${1}
+fi
+
+for binary in ${FILES}
+do
+    echo -e "Should Run $(basename ${binary})\n"\
+            "    Test Binary    @$(realpath ${binary})\n" >> ${ROBOT_SCRIPT}
+done
+
+ROBOT_COMMAND="${RENODE_TEST_SCRIPT} ${ROBOT_SCRIPT} \
+  -r ${RESULTS_DIRECTORY} \
+  --variable RESC:${RESC_PATH} \
+  --variable UART_LOG:${UART_LOG}"
+
+echo "${ROBOT_COMMAND}"
+echo ""
+${ROBOT_COMMAND}
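Taken together with robot.resource.txt, this script replaces the deleted per-target test_bluepill_binary.sh and test_stm32f4_binary.sh wrappers. A hedged usage sketch (the gen/ path is taken from the deleted bluepill wrapper and may differ for other builds):

    # Run every *_test binary from a build directory on the emulated bluepill:
    tensorflow/lite/micro/testing/test_with_renode.sh \
        tensorflow/lite/micro/tools/make/gen/bluepill_cortex-m3/bin bluepill

In this example the generated bluepill.robot suite and the Robot Framework html/xml reports land under /tmp/renode_bluepill_logs; the UART output is captured to uart_log.txt there and echoed into the test message whenever a test fails.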
diff --git a/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
index 50415e7..403b39f 100755
--- a/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
+++ b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
@@ -1,25 +1,4 @@
 #!/bin/bash -e
-# ==============================================================================
-# Copyright (C) 2019 Cadence Design Systems, Inc.
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and associated documentation files (the
-# "Software"), to use this Software with Cadence processor cores only and
-# not with any other processors and platforms, subject to
-# the following conditions:
-#
-# The above copyright notice and this permission notice shall be included
-# in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-# ==============================================================================
-
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/tools/ci_build/test_all.sh b/tensorflow/lite/micro/tools/ci_build/test_all.sh
index c7a53cc..4bd73fe 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_all.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_all.sh
@@ -24,16 +24,15 @@
 cd "${ROOT_DIR}"
 pwd
 
-make -f tensorflow/lite/micro/tools/make/Makefile \
-  clean clean_downloads
+make -f tensorflow/lite/micro/tools/make/Makefile clean_downloads DISABLE_DOWNLOADS=true
+make -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn clean DISABLE_DOWNLOADS=true
+if [ -d tensorflow/lite/micro/tools/make/downloads ]; then
+  echo "ERROR: Downloads directory should not exist, but it does."
+  exit 1
+fi
 
-# We are moving away from having the downloads and installations be part of the
-# Makefile. As a result, we need to manually add the downloads in this script.
-# Once we move more than the renode downloads out of the Makefile, we should
-# have common way to perform the downloads for a given target, tags ...
-echo "Starting renode download at `date`"
-tensorflow/lite/micro/testing/download_renode.sh tensorflow/lite/micro/tools/make/downloads/renode
-pip3 install -r tensorflow/lite/micro/tools/make/downloads/renode/tests/requirements.txt
+echo "Running code style checks at `date`"
+tensorflow/lite/micro/tools/ci_build/test_code_style.sh PRESUBMIT
 
 # Add all the test scripts for the various supported platforms here. This
 # enables running all the tests together as part of the continuous integration
diff --git a/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh b/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh
index f332e4c..dff3b0e 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh
@@ -39,9 +39,4 @@
 # Next, build w/o release so that we can run the tests and get additional
 # debugging info on failures.
 readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
-readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} build
-
-# TODO(b/172939049): Using renode to run the tests is not currently integrated
-# with the Makefile.  So, we manually run the test script with the correct path
-# to the bluepill generated files.
-tensorflow/lite/micro/testing/test_bluepill_binary.sh tensorflow/lite/micro/tools/make/gen/bluepill_cortex-m3/bin/
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} test
diff --git a/tensorflow/lite/micro/tools/ci_build/test_code_style.sh b/tensorflow/lite/micro/tools/ci_build/test_code_style.sh
new file mode 100755
index 0000000..1165ba8
--- /dev/null
+++ b/tensorflow/lite/micro/tools/ci_build/test_code_style.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR=${SCRIPT_DIR}/../../../../..
+cd "${ROOT_DIR}"
+
+source tensorflow/lite/micro/tools/ci_build/helper_functions.sh
+
+# Explicitly call third_party_downloads since we need pigweed for the license
+# and clang-format checks.
+make -f tensorflow/lite/micro/tools/make/Makefile third_party_downloads
+
+# Explicitly disable exit on error so that we can properly clean up the
+# temporary git repository even when one of the scripts fail with an error code.
+set +e
+
+# The pigweed scripts only work from a git repository and the Tensorflow CI
+# infrastructure does not always guarantee that. As an ugly workaround, we
+# create our own git repo when running on the CI servers.
+pushd tensorflow/lite/micro/
+if [[ ${1} == "PRESUBMIT" ]]; then
+  git init .
+  git config user.email "tflm@google.com"
+  git config user.name "TensorflowLite Micro"
+  git add *
+  git commit -a -m "Commit for a temporary repository." > /dev/null
+fi
+
+# Check for license with the necessary exclusions.
+tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
+  . \
+  -p copyright_notice \
+  -e tools/make/downloads \
+  -e tools/make/targets/ecm3531 \
+  -e BUILD \
+  -e leon_commands \
+  -e "\.bzl" \
+  -e "\.h5" \
+  -e "\.ipynb" \
+  -e "\.inc" \
+  -e "\.patch" \
+  -e "\.properties" \
+  -e "\.txt" \
+  -e "\.tpl" \
+  --output-directory /tmp
+
+LICENSE_CHECK_RESULT=$?
+
+# Check that the TFLM-only code is clang-formatted. We are currently ignoring
+# Python files (with yapf as the formatter) because that needs additional setup.
+# We are also ignoring the markdown files to allow for a more gradual rollout of
+# this presubmit check.
+tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
+  . \
+  -e "\.inc" \
+  -e "\.md" \
+  -e "\.py"
+
+CLANG_FORMAT_RESULT=$?
+
+popd
+if [[ ${1} == "PRESUBMIT" ]]; then
+  rm -rf tensorflow/lite/micro/.git
+fi
+
+# Re-enable exit on error now that we are done with the temporary git repo.
+set -e
+
+if [[ ${LICENSE_CHECK_RESULT} != 0 || ${CLANG_FORMAT_RESULT} != 0 ]]
+then
+  exit 1
+fi
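The PRESUBMIT argument only matters on CI, where the checkout is not guaranteed to be a git repository; locally the script can be run directly from an existing checkout. A sketch, run from the repository root:

    # As wired into test_all.sh above (creates and removes a temporary git repo):
    tensorflow/lite/micro/tools/ci_build/test_code_style.sh PRESUBMIT

    # Local run inside an existing git checkout:
    tensorflow/lite/micro/tools/ci_build/test_code_style.sh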
diff --git a/tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh b/tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh
index ba2dee3..2e7de8a 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh
@@ -39,11 +39,4 @@
 # Next, build w/o release so that we can run the tests and get additional
 # debugging info on failures.
 readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
-readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=${TAGS} TARGET=${TARGET} build
-
-# TODO(b/149597202): Running tests via renode are disabled as part of the
-# continuous integration until we can get Docker running inside Docker. However,
-# if this script is run locally, the tests will still be run.
-if [[ ${1} != "PRESUBMIT" ]]; then
-readable_run make -f tensorflow/lite/micro/tools/make/Makefile TAGS=${TAGS} TARGET=${TARGET} test
-fi
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=${TAGS} TARGET=${TARGET} test
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index dfd81e6..3e54526 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -1,3 +1,18 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 ifneq (3.82,$(firstword $(sort $(MAKE_VERSION) 3.82)))
   $(error "Requires make version 3.82 or later (current is $(MAKE_VERSION))")
 endif
@@ -325,6 +340,7 @@
 tensorflow/lite/micro/kernels/pooling.cc \
 tensorflow/lite/micro/kernels/prelu.cc \
 tensorflow/lite/micro/kernels/quantize.cc \
+tensorflow/lite/micro/kernels/quantize_common.cc \
 tensorflow/lite/micro/kernels/reduce.cc \
 tensorflow/lite/micro/kernels/reshape.cc \
 tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc \
@@ -336,6 +352,7 @@
 tensorflow/lite/micro/kernels/strided_slice.cc \
 tensorflow/lite/micro/kernels/sub.cc \
 tensorflow/lite/micro/kernels/svdf.cc \
+tensorflow/lite/micro/kernels/svdf_common.cc \
 tensorflow/lite/micro/kernels/tanh.cc \
 tensorflow/lite/micro/kernels/unpack.cc
 
@@ -367,6 +384,7 @@
 LICENSE \
 tensorflow/core/public/version.h \
 tensorflow/lite/c/builtin_op_data.h \
+tensorflow/lite/c/c_api_types.h \
 tensorflow/lite/c/common.h \
 tensorflow/lite/core/api/error_reporter.h \
 tensorflow/lite/core/api/flatbuffer_conversions.h \
@@ -478,24 +496,38 @@
 ARDUINO_LIBRARY_TARGETS :=
 ARDUINO_LIBRARY_ZIPS :=
 
-# The download scripts require that the downloads directory already exist for
-# improved error checking. To accomodate that, we first create a downloads
-# directory.
-$(shell mkdir -p ${MAKEFILE_DIR}/downloads)
+# For some invocations of the makefile, it is useful to avoid downloads. This
+# can be achieved by explicitly passing in DISABLE_DOWNLOADS=true on the command
+# line. Note that for target-specific downloads (e.g. CMSIS) there will need to
+# be a corresponding check in the respective included makefiles (e.g.
+# ext_libs/cmsis_nn.inc)
+DISABLE_DOWNLOADS :=
 
-# Directly download the flatbuffers library.
-DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/flatbuffers_download.sh ${MAKEFILE_DIR}/downloads)
-ifneq ($(DOWNLOAD_RESULT), SUCCESS)
-  $(error Something went wrong with the flatbuffers download: $(DOWNLOAD_RESULT))
+ifneq ($(DISABLE_DOWNLOADS), true)
+  # The download scripts require that the downloads directory already exist for
+  # improved error checking. To accommodate that, we first create a downloads
+  # directory.
+  $(shell mkdir -p ${MAKEFILE_DIR}/downloads)
+
+  # Directly download the flatbuffers library.
+  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/flatbuffers_download.sh ${MAKEFILE_DIR}/downloads)
+  ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+    $(error Something went wrong with the flatbuffers download: $(DOWNLOAD_RESULT))
+  endif
+
+  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/pigweed_download.sh ${MAKEFILE_DIR}/downloads)
+  ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+    $(error Something went wrong with the pigweed download: $(DOWNLOAD_RESULT))
+  endif
+
+  include $(MAKEFILE_DIR)/third_party_downloads.inc
+  THIRD_PARTY_DOWNLOADS :=
+  $(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,))
+  $(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,))
+  $(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
+  $(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
 endif
 
-include $(MAKEFILE_DIR)/third_party_downloads.inc
-THIRD_PARTY_DOWNLOADS :=
-$(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,))
-$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,))
-$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
-$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
-
 # The target-specific makefile must have a name that is exactly
 # TARGET_makefile.inc and is only needed for cross-compilation (i.e. when TARGET
 # is different from the HOST_OS).
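
As a rough sketch of how the new DISABLE_DOWNLOADS flag is meant to be used (the `microlite` goal here is illustrative rather than part of this change), a build that skips every download step could be invoked as:

    # Assumes downloads/ was already populated by an earlier run of make.
    make -f tensorflow/lite/micro/tools/make/Makefile DISABLE_DOWNLOADS=true microlite

Target-specific makefiles that trigger their own downloads (e.g. ext_libs/cmsis_nn.inc below) are expected to guard those steps behind the same variable.
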
@@ -505,8 +537,11 @@
 # will be separating the project generation from the Makefile in the future.
 TARGETS_WITHOUT_MAKEFILES := \
 $(HOST_OS) \
-arduino \
-chre
+arduino
+
+# ${TARGET}_makefile.inc can set this to a non-zero value to define a custom
+# implementation for `make test`. See bluepill_makefile.inc as an example.
+TARGET_SPECIFIC_MAKE_TEST:=0
 
 ifeq ($(findstring $(TARGET),$(TARGETS_WITHOUT_MAKEFILES)),)
   include $(MAKEFILE_DIR)/targets/$(TARGET)_makefile.inc
@@ -627,7 +662,9 @@
 $(foreach TEST_TARGET,$(filter tensorflow/lite/micro/kernels/%,$(MICROLITE_TEST_SRCS)),\
 $(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET))))
 
+ifeq ($(TARGET_SPECIFIC_MAKE_TEST),0)
 test: $(MICROLITE_TEST_TARGETS)
+endif
 
 # Just build the test targets
 build: $(MICROLITE_BUILD_TARGETS)
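
For a target that sets TARGET_SPECIFIC_MAKE_TEST to a non-zero value, `make test` now dispatches to the target's own rule instead of the generic MICROLITE_TEST_TARGETS list. A minimal sketch of the resulting invocation, assuming the renode-based script wired up for the bluepill target further down:

    # bluepill_makefile.inc sets TARGET_SPECIFIC_MAKE_TEST := 1, so this runs
    # its custom `test: build` rule, which calls test_with_renode.sh.
    make -f tensorflow/lite/micro/tools/make/Makefile TARGET=bluepill test
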
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc
index a778d0f..f03bc8f 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc
@@ -4,14 +4,16 @@
         # CMSIS-NN optimizations not supported
     endif
 
-    # Setup CMSIS-NN lib and add required header files to microlite lib INCLUDE.
-    # Unless an external path is provided we force a download during the first phase of make so
-    # that the files exist prior to the call to recursive_find below. add_third_party_download
-    # prevents the use of wildcards and recursive_find in selecting which files to add to THIRD_PARTY_SRCS.
-    CMSIS_DEFAULT_DOWNLOAD_PATH := $(MAKEFILE_DIR)/downloads/cmsis
-    CMSIS_PATH := $(CMSIS_DEFAULT_DOWNLOAD_PATH)
-    ifeq ($(CMSIS_PATH), $(CMSIS_DEFAULT_DOWNLOAD_PATH))
-      $(call $(or $(shell $(DOWNLOAD_SCRIPT) $(CMSIS_URL) $(CMSIS_MD5) $(CMSIS_PATH) >&2 && echo SUCCESS), $(error $(DOWNLOAD_SCRIPT) failed)))
+    ifneq ($(DISABLE_DOWNLOADS), true)
+      # Setup CMSIS-NN lib and add required header files to microlite lib INCLUDE.
+      # Unless an external path is provided we force a download during the first phase of make so
+      # that the files exist prior to the call to recursive_find below. add_third_party_download
+      # prevents the use of wildcards and recursive_find in selecting which files to add to THIRD_PARTY_SRCS.
+      CMSIS_DEFAULT_DOWNLOAD_PATH := $(MAKEFILE_DIR)/downloads/cmsis
+      CMSIS_PATH := $(CMSIS_DEFAULT_DOWNLOAD_PATH)
+      ifeq ($(CMSIS_PATH), $(CMSIS_DEFAULT_DOWNLOAD_PATH))
+        $(call $(or $(shell $(DOWNLOAD_SCRIPT) $(CMSIS_URL) $(CMSIS_MD5) $(CMSIS_PATH) >&2 && echo SUCCESS), $(error $(DOWNLOAD_SCRIPT) failed)))
+      endif
     endif
 
     THIRD_PARTY_CC_SRCS += \
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
index 25b034f..00fbf45 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
@@ -1,6 +1,6 @@
 ifeq ($(TARGET_ARCH), hifi4)
 
-  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/ext_libs/xtensa_download.sh ${MAKEFILE_DIR}/downloads)
+  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/ext_libs/xtensa_download.sh ${MAKEFILE_DIR}/downloads hifi4)
   ifneq ($(DOWNLOAD_RESULT), SUCCESS)
     $(error Something went wrong with the xtensa download: $(DOWNLOAD_RESULT))
   endif
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
index 1630310..a427f63 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
@@ -19,6 +19,7 @@
 # Called with two arguments:
 # 1 - Path to the downloads folder which is typically
 #     tensorflow/lite/micro/tools/make/downloads
+# 2 - Xtensa variant to download for (e.g. hifi4)
 #
 # This script is called from the Makefile and uses the following convention to
 # enable determination of success/failure:
@@ -39,28 +40,31 @@
   exit 1
 fi
 
-# Name of the xa_nnlib directory once it is unzipped.
-HIFI4_XA_NNLIB_DIRNAME="xa_nnlib_hifi4"
-
-HIFI4_PATH=${DOWNLOADS_DIR}/${HIFI4_XA_NNLIB_DIRNAME}
-if [ -d ${HIFI4_PATH} ]; then
-  echo >&2 "${HIFI4_PATH} already exists, skipping the download."
+if [[ ${2} == "hifi4" ]]; then
+  LIBRARY_URL="http://mirror.tensorflow.org/github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_06_27.zip"
+  LIBRARY_DIRNAME="xa_nnlib_hifi4"
+  LIBRARY_MD5="45fdc1209a8da62ab568aa6040f7eabf"
 else
+  echo "Attempting to download an unsupported xtensa variant: ${2}"
+  exit 1
+fi
 
-  ZIP_ARCHIVE_NAME="xa_nnlib_06_27.zip"
-  HIFI4_URL="http://mirror.tensorflow.org/github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/${ZIP_ARCHIVE_NAME}"
-  HIFI4_MD5="45fdc1209a8da62ab568aa6040f7eabf"
+LIBRARY_INSTALL_PATH=${DOWNLOADS_DIR}/${LIBRARY_DIRNAME}
 
-  wget ${HIFI4_URL} -O /tmp/${ZIP_ARCHIVE_NAME} >&2
-  MD5=`md5sum /tmp/${ZIP_ARCHIVE_NAME} | awk '{print $1}'`
+if [ -d ${LIBRARY_INSTALL_PATH} ]; then
+  echo >&2 "${LIBRARY_INSTALL_PATH} already exists, skipping the download."
+else
+  TMP_ZIP_ARCHIVE_NAME="${LIBRARY_DIRNAME}.zip"
+  wget ${LIBRARY_URL} -O /tmp/${TMP_ZIP_ARCHIVE_NAME} >&2
+  MD5=`md5sum /tmp/${TMP_ZIP_ARCHIVE_NAME} | awk '{print $1}'`
 
-  if [[ ${MD5} != ${HIFI4_MD5} ]]
+  if [[ ${MD5} != ${LIBRARY_MD5} ]]
   then
-    echo "Bad checksum. Expected: ${HIFI4_MD5}, Got: ${MD5}"
+    echo "Bad checksum. Expected: ${LIBRARY_MD5}, Got: ${MD5}"
     exit 1
   fi
 
-  unzip -qo /tmp/${ZIP_ARCHIVE_NAME} -d ${DOWNLOADS_DIR} >&2
+  unzip -qo /tmp/${TMP_ZIP_ARCHIVE_NAME} -d ${DOWNLOADS_DIR} >&2
 fi
 
 echo "SUCCESS"
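
A minimal sketch of invoking the updated download script by hand, assuming the top-level downloads directory already exists (the Makefile normally creates it):

    DOWNLOADS=tensorflow/lite/micro/tools/make/downloads
    mkdir -p ${DOWNLOADS}
    # The second argument selects the Xtensa variant; only hifi4 is handled here.
    tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh ${DOWNLOADS} hifi4
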
diff --git a/tensorflow/lite/micro/tools/make/flatbuffers_download.sh b/tensorflow/lite/micro/tools/make/flatbuffers_download.sh
index 61f5f33..8ac0c4d 100755
--- a/tensorflow/lite/micro/tools/make/flatbuffers_download.sh
+++ b/tensorflow/lite/micro/tools/make/flatbuffers_download.sh
@@ -69,6 +69,15 @@
   mv ${temp_flexbuffers_path} ${input_flexbuffers_path}
 }
 
+# The BUILD files in the downloaded folder result in an error when running:
+#  bazel build tensorflow/lite/micro/...
+#
+# Parameters:
+#   $1 - path to the downloaded flatbuffers code.
+function delete_build_files() {
+  rm -f `find ${1} -name BUILD`
+}
+
 DOWNLOADED_FLATBUFFERS_PATH=${DOWNLOADS_DIR}/flatbuffers
 
 if [ -d ${DOWNLOADED_FLATBUFFERS_PATH} ]; then
@@ -91,6 +100,8 @@
   mv /tmp/flatbuffers-${ZIP_PREFIX} ${DOWNLOADED_FLATBUFFERS_PATH}
 
   patch_to_avoid_strtod ${DOWNLOADED_FLATBUFFERS_PATH}/include/flatbuffers/flexbuffers.h
+  delete_build_files ${DOWNLOADED_FLATBUFFERS_PATH}
+
 fi
 
 echo "SUCCESS"
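
Because the script is expected to print only SUCCESS on stdout when it works (everything else goes to stderr), a quick manual check of the download and of the BUILD-file cleanup might look like:

    DOWNLOADS=tensorflow/lite/micro/tools/make/downloads
    RESULT=$(tensorflow/lite/micro/tools/make/flatbuffers_download.sh ${DOWNLOADS})
    [ "${RESULT}" = "SUCCESS" ] || echo "flatbuffers download failed: ${RESULT}"
    # Should print nothing if delete_build_files removed the BUILD files.
    find ${DOWNLOADS}/flatbuffers -name BUILD
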
diff --git a/tensorflow/lite/micro/tools/make/pigweed.patch b/tensorflow/lite/micro/tools/make/pigweed.patch
new file mode 100644
index 0000000..0231b7e
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/pigweed.patch
@@ -0,0 +1,103 @@
+diff --git a/pw_presubmit/py/pw_presubmit/build.py b/pw_presubmit/py/pw_presubmit/build.py
+index 4a370e33..224ad9c6 100644
+--- a/pw_presubmit/py/pw_presubmit/build.py
++++ b/pw_presubmit/py/pw_presubmit/build.py
+@@ -20,7 +20,6 @@ from pathlib import Path
+ import re
+ from typing import Container, Dict, Iterable, List, Mapping, Set, Tuple
+ 
+-from pw_package import package_manager
+ from pw_presubmit import call, log_run, plural, PresubmitFailure, tools
+ 
+ _LOG = logging.getLogger(__name__)
+diff --git a/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py b/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py
+index 794967db..061db7ea 100755
+--- a/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py
++++ b/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py
+@@ -220,8 +220,8 @@ def clang_tidy(ctx: PresubmitContext):
+ 
+ 
+ # The first line must be regex because of the '20\d\d' date
+-COPYRIGHT_FIRST_LINE = r'Copyright 20\d\d The Pigweed Authors'
+-COPYRIGHT_COMMENTS = r'(#|//| \*|REM|::)'
++COPYRIGHT_FIRST_LINE = r'Copyright 20\d\d The TensorFlow Authors. All Rights Reserved.'
++COPYRIGHT_COMMENTS = r'(#|//|\*|REM|::|/\*)'
+ COPYRIGHT_BLOCK_COMMENTS = (
+     # HTML comments
+     (r'<!--', r'-->'), )
+@@ -232,21 +232,23 @@ COPYRIGHT_FIRST_LINE_EXCEPTIONS = (
+     '@echo off',
+     '# -*-',
+     ':',
++    '# Lint as',
++    '# coding=utf-8'
+ )
+ 
+ COPYRIGHT_LINES = tuple("""\
+ 
+-Licensed under the Apache License, Version 2.0 (the "License"); you may not
+-use this file except in compliance with the License. You may obtain a copy of
+-the License at
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
+ 
+-    https://www.apache.org/licenses/LICENSE-2.0
++    http://www.apache.org/licenses/LICENSE-2.0
+ 
+ Unless required by applicable law or agreed to in writing, software
+-distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+-WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+-License for the specific language governing permissions and limitations under
+-the License.
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
+ """.splitlines())
+ 
+ _EXCLUDE_FROM_COPYRIGHT_NOTICE: Sequence[str] = (
+@@ -344,6 +346,11 @@ def copyright_notice(ctx: PresubmitContext):
+                 errors.append(path)
+                 continue
+ 
++            # Special handling for TFLM style of copyright+license in the cc
++            # files.
++            if comment == '/*':
++              comment = ''
++
+             if end_block_comment:
+                 expected_lines = COPYRIGHT_LINES + (end_block_comment, )
+             else:
+@@ -354,6 +361,10 @@ def copyright_notice(ctx: PresubmitContext):
+                     expected_line = expected + '\n'
+                 elif comment:
+                     expected_line = (comment + ' ' + expected).rstrip() + '\n'
++                else:
++                    # Special handling for TFLM style of copyright+license in
++                    # the cc files.
++                    expected_line = (expected).rstrip() + '\n'
+ 
+                 if expected_line != actual:
+                     _LOG.warning('  bad line: %r', actual)
+@@ -475,6 +486,10 @@ BROKEN = (
+     gn_nanopb_build,
+ )
+ 
++COPYRIGHT_NOTICE = (
++    copyright_notice,
++)
++
+ QUICK = (
+     commit_message_format,
+     init_cipd,
+@@ -509,7 +524,8 @@ FULL = (
+     build_env_setup,
+ )
+ 
+-PROGRAMS = Programs(broken=BROKEN, quick=QUICK, full=FULL)
++PROGRAMS = Programs(broken=BROKEN, quick=QUICK, full=FULL,
++                    copyright_notice=COPYRIGHT_NOTICE)
+ 
+ 
+ def parse_args() -> argparse.Namespace:
diff --git a/tensorflow/lite/micro/tools/make/pigweed_download.sh b/tensorflow/lite/micro/tools/make/pigweed_download.sh
new file mode 100755
index 0000000..9991ee8
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/pigweed_download.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Called with the following arguments:
+# 1 - Path to the downloads folder which is typically
+#     tensorflow/lite/micro/tools/make/downloads
+#
+# This script is called from the Makefile and uses the following convention to
+# enable determination of success/failure:
+#
+#   - If the script is successful, the only output on stdout should be SUCCESS.
+#     The makefile checks for this particular string.
+#
+#   - Any string on stdout that is not SUCCESS will be shown in the makefile as
+#     the cause for the script to have failed.
+#
+#   - Any other informational prints should be on stderr.
+
+set -e
+
+DOWNLOADS_DIR=${1}
+if [ ! -d ${DOWNLOADS_DIR} ]; then
+  echo "The top-level downloads directory: ${DOWNLOADS_DIR} does not exist."
+  exit 1
+fi
+
+# The BUILD files in the downloaded folder result in an error when running:
+#  bazel build tensorflow/lite/micro/...
+#
+# Parameters:
+#   $1 - path to the downloaded pigweed code.
+function delete_build_files() {
+  rm -f `find ${1} -name BUILD`
+}
+
+DOWNLOADED_PIGWEED_PATH=${DOWNLOADS_DIR}/pigweed
+
+if [ -d ${DOWNLOADED_PIGWEED_PATH} ]; then
+  echo >&2 "${DOWNLOADED_PIGWEED_PATH} already exists, skipping the download."
+else
+  git clone https://pigweed.googlesource.com/pigweed/pigweed ${DOWNLOADED_PIGWEED_PATH} >&2
+  pushd ${DOWNLOADED_PIGWEED_PATH} > /dev/null
+  git checkout 47268dff45019863e20438ca3746c6c62df6ef09 >&2
+
+  # Patch for TFLM-specific changes that are not currently upstreamed.
+  git apply ../../pigweed.patch
+  popd > /dev/null
+
+  delete_build_files ${DOWNLOADED_PIGWEED_PATH}
+fi
+
+echo "SUCCESS"
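
Since the clone is pinned to a fixed commit and then patched, the script skips any existing checkout; a hedged sketch of forcing a fresh download after editing pigweed.patch:

    DOWNLOADS=tensorflow/lite/micro/tools/make/downloads
    rm -rf ${DOWNLOADS}/pigweed
    tensorflow/lite/micro/tools/make/pigweed_download.sh ${DOWNLOADS}
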
diff --git a/tensorflow/lite/micro/tools/make/renode_download.sh b/tensorflow/lite/micro/tools/make/renode_download.sh
new file mode 100755
index 0000000..e3acc82
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/renode_download.sh
@@ -0,0 +1,90 @@
+#!/bin/bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Called with the following arguments:
+# 1 - Path to the downloads folder which is typically
+#     tensorflow/lite/micro/tools/make/downloads
+#
+# This script is called from the Makefile and uses the following convention to
+# enable determination of success/failure:
+#
+#   - If the script is successful, the only output on stdout should be SUCCESS.
+#     The makefile checks for this particular string.
+#
+#   - Any string on stdout that is not SUCCESS will be shown in the makefile as
+#     the cause for the script to have failed.
+#
+#   - Any other informational prints should be on stderr.
+
+set -e
+
+DOWNLOADS_DIR=${1}
+if [ ! -d ${DOWNLOADS_DIR} ]; then
+  echo "The top-level downloads directory: ${DOWNLOADS_DIR} does not exist."
+  exit 1
+fi
+
+DOWNLOADED_RENODE_PATH=${DOWNLOADS_DIR}/renode
+
+if [ -d ${DOWNLOADED_RENODE_PATH} ]; then
+  echo >&2 "${DOWNLOADED_RENODE_PATH} already exists, skipping the download."
+else
+  # Colours
+  ORANGE="\033[33m"
+  RED="\033[31m"
+  NC="\033[0m"
+
+  # Target version
+  RENODE_VERSION='1.11.0'
+
+  echo >&2 "Downloading the Renode portable package, version ${RENODE_VERSION}"
+
+  # Get link to requested version
+  RELEASES_JSON=`curl https://api.github.com/repos/renode/renode/releases 2>/dev/null`
+  LINUX_PORTABLE_URL=`echo "${RELEASES_JSON}" |grep 'browser_download_url'|\
+      grep --extended-regexp --only-matching "https://.*${RENODE_VERSION}.*linux-portable.*tar.gz"`
+  if [ -z "${LINUX_PORTABLE_URL}" ]; then
+    echo -e "${RED}Portable version of release v${RENODE_VERSION} not found. Please make sure you use the correct version format ('[0-9]+.[0-9]+.[0-9]+')${NC}"
+    exit 1
+  fi
+
+  # Check if a newer version is available
+  LATEST_RENODE_VERSION=`echo "${RELEASES_JSON}" |grep 'tag_name' |\
+      head --lines 1 | grep --extended-regexp --only-matching '[0-9]+\.[0-9]+\.[0-9]+'`
+  if [ "${RENODE_VERSION}" != "${LATEST_RENODE_VERSION}" ]; then
+    echo -e "${ORANGE}Latest available version is ${LATEST_RENODE_VERSION}, please consider using it.${NC}" >&2
+  fi
+  echo >&2 "Downloading from url: ${LINUX_PORTABLE_URL}"
+
+  TEMP_ARCHIVE="/tmp/renode.tar.gz"
+  wget ${LINUX_PORTABLE_URL} -O ${TEMP_ARCHIVE} >&2
+
+  EXPECTED_MD5="8415361f5caa843f1e31b59c50b2858f"
+  MD5=`md5sum ${TEMP_ARCHIVE} | awk '{print $1}'`
+  if [[ ${MD5} != ${EXPECTED_MD5} ]]
+  then
+    echo "Bad checksum. Expected: ${EXPECTED_MD5}, Got: ${MD5}"
+    exit 1
+  fi
+
+  mkdir ${DOWNLOADED_RENODE_PATH}
+  tar xzf ${TEMP_ARCHIVE} --strip-components=1 --directory "${DOWNLOADED_RENODE_PATH}" >&2
+  echo >&2 "Unpacked to directory: ${DOWNLOADED_RENODE_PATH}"
+
+  pip3 install -r ${DOWNLOADED_RENODE_PATH}/tests/requirements.txt >&2
+fi
+
+echo "SUCCESS"
diff --git a/tensorflow/lite/micro/tools/make/targets/apollo3evb/apollo3evb.ld b/tensorflow/lite/micro/tools/make/targets/apollo3evb/apollo3evb.ld
index cd1182f..6ae8f1f 100644
--- a/tensorflow/lite/micro/tools/make/targets/apollo3evb/apollo3evb.ld
+++ b/tensorflow/lite/micro/tools/make/targets/apollo3evb/apollo3evb.ld
@@ -1,3 +1,18 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 /******************************************************************************
  *
  * apollo3evb.ld - Linker script for applications using startup_gcc.c
diff --git a/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf
index 5dc53cc..0655a4a 100644
--- a/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf
+++ b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf
@@ -1,8 +1,11 @@
 # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf
index 63ef486..a15bce1 100644
--- a/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf
+++ b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf
@@ -1,8 +1,11 @@
 # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -10,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 #
+#
 # Difference with common EMSDP LCF file (to reduce data access time): 
 # - move data from external PSRAM to DCCM
 # - move text from SRAM to ICCM
diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill/bluepill.lds b/tensorflow/lite/micro/tools/make/targets/bluepill/bluepill.lds
index 7497684..b5d823a 100644
--- a/tensorflow/lite/micro/tools/make/targets/bluepill/bluepill.lds
+++ b/tensorflow/lite/micro/tools/make/targets/bluepill/bluepill.lds
@@ -1,8 +1,11 @@
-/* Copyright 2018 Google Inc. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc
index b57b552..863e2e1 100644
--- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc
@@ -6,6 +6,11 @@
 $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,))
 $(eval $(call add_third_party_download,$(STM32_BARE_LIB_URL),$(STM32_BARE_LIB_MD5),stm32_bare_lib,))
 
+DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/renode_download.sh ${MAKEFILE_DIR}/downloads)
+ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+  $(error Something went wrong with the renode download: $(DOWNLOAD_RESULT))
+endif
+
 PLATFORM_FLAGS = \
   -DTF_LITE_MCU_DEBUG_LOG \
   -mcpu=cortex-m3 \
@@ -57,5 +62,10 @@
   tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc
 MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS))
 
-TEST_SCRIPT := tensorflow/lite/micro/testing/test_bluepill_binary.sh
+TEST_SCRIPT := tensorflow/lite/micro/testing/test_with_renode.sh
 
+# We are setting this variable to non-zero to allow us to have a custom
+# implementation of `make test` for bluepill
+TARGET_SPECIFIC_MAKE_TEST := 1
+test: build
+	$(TEST_SCRIPT) $(BINDIR) $(TARGET)
diff --git a/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM.ld b/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM.ld
index 75652b2..666c59a 100755
--- a/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM.ld
+++ b/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM.ld
@@ -1,3 +1,18 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 OUTPUT(a.elf)
 
 /* By default, program starts from reset address (the default location of the interrupt table) */
diff --git a/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.2.ld b/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.2.ld
index 0abbef4..dce5330 100755
--- a/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.2.ld
+++ b/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.2.ld
@@ -1,3 +1,18 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 OUTPUT(a.elf)
 
 /* By default, program starts from reset address (the default location of the interrupt table) */
diff --git a/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.3.ld b/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.3.ld
index 75652b2..0fa2044 100755
--- a/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.3.ld
+++ b/tensorflow/lite/micro/tools/make/targets/ceva/CEVA_BX1_TFLM_18.0.3.ld
@@ -1,3 +1,19 @@
+
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 OUTPUT(a.elf)
 
 /* By default, program starts from reset address (the default location of the interrupt table) */
diff --git a/tensorflow/lite/micro/tools/make/targets/chre_makefile.inc b/tensorflow/lite/micro/tools/make/targets/chre_makefile.inc
new file mode 100644
index 0000000..538c4a9
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/targets/chre_makefile.inc
@@ -0,0 +1,33 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Remove the flexbuffers library and the detection postprocess kernel from the
+# chre build due to their string dependencies.
+EXCLUDED_CC_SRCS := \
+  tensorflow/lite/micro/kernels/detection_postprocess.cc \
+  tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc
+
+EXCLUDED_TESTS := \
+  tensorflow/lite/micro/kernels/detection_postprocess_test.cc
+
+EXCLUDED_HDRS := \
+  third_party/flatbuffers/include/flatbuffers/flexbuffers.h
+
+EXCLUDED_KERNEL_HDRS := \
+  tensorflow/lite/micro/kernels/flexbuffers_generated_data.h
+
+MICROLITE_CC_KERNEL_SRCS := $(filter-out $(EXCLUDED_CC_SRCS),$(MICROLITE_CC_KERNEL_SRCS))
+MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS),$(MICROLITE_TEST_SRCS))
+THIRD_PARTY_CC_HDRS := $(filter-out $(EXCLUDED_HDRS),$(THIRD_PARTY_CC_HDRS))
+MICROLITE_CC_HDRS := $(filter-out $(EXCLUDED_KERNEL_HDRS),$(MICROLITE_CC_HDRS))
diff --git a/tensorflow/lite/micro/tools/make/targets/ecm3531/_main.c b/tensorflow/lite/micro/tools/make/targets/ecm3531/_main.c
index ead3709..e3d0b88 100644
--- a/tensorflow/lite/micro/tools/make/targets/ecm3531/_main.c
+++ b/tensorflow/lite/micro/tools/make/targets/ecm3531/_main.c
@@ -25,6 +25,7 @@
 #include <stdint.h>
 #include <stdio.h>
 #include <string.h>
+
 #include "eta_bsp.h"
 #include "eta_chip.h"
 #include "eta_csp.h"
diff --git a/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531.lds b/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531.lds
index 383b7f9..58cb5eb 100644
--- a/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531.lds
+++ b/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531.lds
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
 limitations under the License.
 ==============================================================================*/
 
-
 /*
  * linker script for use with ECM3531
  * All sections must map to 128KBytes of SRAM beginning at 0x10000000
diff --git a/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531_flash.lds b/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531_flash.lds
index 9cbbea3..7b95754 100644
--- a/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531_flash.lds
+++ b/tensorflow/lite/micro/tools/make/targets/ecm3531/ecm3531_flash.lds
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/tensorflow/lite/micro/tools/make/targets/ecm3531/startup.c b/tensorflow/lite/micro/tools/make/targets/ecm3531/startup.c
index 32d817b..5a1af2b 100644
--- a/tensorflow/lite/micro/tools/make/targets/ecm3531/startup.c
+++ b/tensorflow/lite/micro/tools/make/targets/ecm3531/startup.c
@@ -17,6 +17,7 @@
 calls _main() which is the entry point into the application */
 
 #include <stdint.h>
+
 #include "eta_chip.h"
 #include "memio.h"
 
@@ -30,9 +31,9 @@
 //
 //*****************************************************************************
 
-int _main(int argc, char *argv[]);
+int _main(int argc, char* argv[]);
 void set_vtor(void);
-void *startup_get_my_pc(void);
+void* startup_get_my_pc(void);
 
 //*****************************************************************************
 // Forward DECLS for interrupt service routines (ISR)
@@ -94,7 +95,7 @@
 __attribute__((section(".vectors"), used)) void (*const gVectors[])(void) = {
     //(void (*)(void))((uint32_t)pui32Stack + sizeof(pui32Stack)), // Stack
     // pointer
-    (void *)STARTUP_STACK_TOP,
+    (void*)STARTUP_STACK_TOP,
     ResetISR,           // Reset handler
     NmiSR,              // The NMI handler
     FaultISR,           // The hard fault handler
@@ -402,8 +403,8 @@
 ////////////////////////////////////////////////////////////////////////////////
 // get my PC
 ////////////////////////////////////////////////////////////////////////////////
-void *startup_get_my_pc(void) {
-  void *pc;
+void* startup_get_my_pc(void) {
+  void* pc;
   asm("mov %0, pc" : "=r"(pc));
   return pc;
 }
@@ -411,8 +412,8 @@
 ////////////////////////////////////////////////////////////////////////////////
 // get my SP
 ////////////////////////////////////////////////////////////////////////////////
-void *startup_get_my_sp(void) {
-  void *sp;
+void* startup_get_my_sp(void) {
+  void* sp;
   asm("mov %0, sp" : "=r"(sp));
   return sp;
 }
diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds b/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds
index 1856368..b7603d9 100644
--- a/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds
+++ b/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds
@@ -1,10 +1,9 @@
-/* Copyright 2020 Google Inc. All Rights Reserved.
-
-Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc
index 3bdf1e9..6ae0dda 100644
--- a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc
@@ -9,6 +9,11 @@
 $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,))
 $(eval $(call add_third_party_download,$(STM32_BARE_LIB_URL),$(STM32_BARE_LIB_MD5),stm32_bare_lib,))
 
+DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/renode_download.sh ${MAKEFILE_DIR}/downloads)
+ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+  $(error Something went wrong with the renode download: $(DOWNLOAD_RESULT))
+endif
+
 # TODO(b/161478030): change -Wno-vla to -Wvla and remove -Wno-shadow once
 # we have a solution for fixing / avoiding being tripped up by these warnings.
 PLATFORM_FLAGS = \
@@ -59,7 +64,6 @@
   $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/debug_log.c
 THIRD_PARTY_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(THIRD_PARTY_CC_SRCS))
 MICROLITE_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(MICROLITE_CC_SRCS))
-TEST_SCRIPT := tensorflow/lite/micro/testing/test_stm32f4_binary.sh
 
 # TODO(b/158324045): Examine why some tests fail here.
 EXCLUDED_TESTS := \
@@ -78,11 +82,11 @@
   tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc
 MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS))
 
-# These are microcontroller-specific rules for converting the ELF output
-# of the linker into a binary image that can be loaded directly.
-OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
+TEST_SCRIPT := tensorflow/lite/micro/testing/test_with_renode.sh
 
-$(BINDIR)/%.bin: $(BINDIR)/%
-	@mkdir -p $(dir $@)
-	$(OBJCOPY) $< $@ -O binary
+# We are setting this variable to non-zero to allow us to have a custom
+# implementation of `make test` for stm32f4.
+TARGET_SPECIFIC_MAKE_TEST := 1
+test: build
+	$(TEST_SCRIPT) $(BINDIR) $(TARGET)
 
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
index f59f929..bf7de98 100644
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
@@ -22,14 +22,23 @@
   $(error XTENSA_CORE is undefined)
 endif
 
+ifeq ($(TARGET_ARCH), )
+  $(error TARGET_ARCH must be specified on the command line)
+endif
+
+# Create a cflag based on the specified TARGET_ARCH. For example:
+#   TARGET_ARCH=hifimini --> -DHIFIMINI
+#   TARGET_ARCH=fusion_f1 --> -DFUSION_F1
+TARGET_ARCH_DEFINES := -D$(shell echo $(TARGET_ARCH) | tr [a-z] [A-Z])
+
 PLATFORM_FLAGS = \
   -DTF_LITE_MCU_DEBUG_LOG \
   -DTF_LITE_USE_CTIME \
   --xtensa-core=$(XTENSA_CORE) \
   -mcoproc \
-  -DXTENSA \
   -DMAX_RFFT_PWR=9 \
-  -DMIN_RFFT_PWR=MAX_RFFT_PWR
+  -DMIN_RFFT_PWR=MAX_RFFT_PWR \
+  $(TARGET_ARCH_DEFINES)
 
 ifeq ($(BUILD_TYPE), release)
   PLATFORM_FLAGS += -Wno-unused-private-field
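
The derived define simply upper-cases TARGET_ARCH; the same transformation can be sketched outside of make as:

    # Prints -DFUSION_F1, matching what xtensa_makefile.inc adds to PLATFORM_FLAGS.
    echo "-D$(echo fusion_f1 | tr '[a-z]' '[A-Z]')"
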
diff --git a/tensorflow/lite/micro/xcore/debug_log.cc b/tensorflow/lite/micro/xcore/debug_log.cc
index c206f05..b964706 100644
--- a/tensorflow/lite/micro/xcore/debug_log.cc
+++ b/tensorflow/lite/micro/xcore/debug_log.cc
@@ -1,8 +1,11 @@
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/tensorflow/lite/objc/TensorFlowLiteObjC.podspec b/tensorflow/lite/objc/TensorFlowLiteObjC.podspec
index 145cf02..05f8a51 100644
--- a/tensorflow/lite/objc/TensorFlowLiteObjC.podspec
+++ b/tensorflow/lite/objc/TensorFlowLiteObjC.podspec
@@ -1,6 +1,6 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteObjC'
-  s.version          = '2.3.0'
+  s.version          = '2.4.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
@@ -21,15 +21,7 @@
 
   tfl_dir = 'tensorflow/lite/'
   objc_dir = tfl_dir + 'experimental/objc/'
-  s.public_header_files = objc_dir + 'apis/*.h'
-  s.source_files = [
-    objc_dir + '{apis,sources}/*.{h,m,mm}',
-    tfl_dir + 'c/c_api.h',
-    tfl_dir + 'c/common.h',
-    tfl_dir + 'delegates/xnnpack/xnnpack_delegate.h',
-  ]
-  s.module_map = objc_dir + 'apis/framework.modulemap'
-  s.dependency 'TensorFlowLiteC', "#{s.version}"
+
   s.pod_target_xcconfig = {
     'HEADER_SEARCH_PATHS' =>
       '"${PODS_TARGET_SRCROOT}" ' +
@@ -37,11 +29,60 @@
     'VALID_ARCHS' => 'i386 x86_64 armv7 arm64',
   }
 
-  s.test_spec 'Tests' do |ts|
-    ts.source_files = objc_dir + 'tests/*.m'
-    ts.resources = [
-      tfl_dir + 'testdata/add.bin',
-      tfl_dir + 'testdata/add_quantized.bin',
+  s.default_subspec = 'Core'
+
+  s.subspec 'Core' do |core|
+    core.public_header_files = objc_dir + 'apis/*.h'
+    core.source_files = [
+      objc_dir + '{apis,sources}/*.{h,m,mm}',
+      tfl_dir + 'c/c_api.h',
+      tfl_dir + 'c/c_api_types.h',
+      tfl_dir + 'c/common.h',
+      tfl_dir + 'delegates/xnnpack/xnnpack_delegate.h',
     ]
+    core.exclude_files = [
+      objc_dir + '{apis,sources}/TFL{Metal,CoreML}Delegate.{h,m}',
+    ]
+    core.dependency 'TensorFlowLiteC', "#{s.version}"
+
+    core.test_spec 'Tests' do |ts|
+      ts.source_files = objc_dir + 'tests/*.m'
+      ts.exclude_files = objc_dir + 'tests/TFL{Metal,CoreML}DelegateTests.m'
+      ts.resources = [
+        tfl_dir + 'testdata/add.bin',
+        tfl_dir + 'testdata/add_quantized.bin',
+      ]
+    end
+  end
+
+  s.subspec 'CoreML' do |coreml|
+    coreml.source_files = [
+      objc_dir + '{apis,sources}/TFLCoreMLDelegate.{h,m}',
+    ]
+    coreml.ios.deployment_target = '12.0'
+    coreml.dependency 'TensorFlowLiteC/CoreML', "#{s.version}"
+    coreml.dependency 'TensorFlowLiteObjC/Core', "#{s.version}"
+
+    coreml.test_spec 'Tests' do |ts|
+      ts.source_files = objc_dir + 'tests/TFLCoreMLDelegateTests.m'
+      ts.resources = [
+        tfl_dir + 'testdata/add.bin',
+      ]
+    end
+  end
+
+  s.subspec 'Metal' do |metal|
+    metal.source_files = [
+      objc_dir + '{apis,sources}/TFLMetalDelegate.{h,m}',
+    ]
+    metal.dependency 'TensorFlowLiteC/Metal', "#{s.version}"
+    metal.dependency 'TensorFlowLiteObjC/Core', "#{s.version}"
+
+    metal.test_spec 'Tests' do |ts|
+      ts.source_files = objc_dir + 'tests/TFLMetalDelegateTests.m'
+      ts.resources = [
+        tfl_dir + 'testdata/multi_add.bin',
+      ]
+    end
   end
 end
diff --git a/tensorflow/lite/objc/TensorFlowLiteObjC.podspec.template b/tensorflow/lite/objc/TensorFlowLiteObjC.podspec.template
index 5d0d6c3..4cdbfa1 100644
--- a/tensorflow/lite/objc/TensorFlowLiteObjC.podspec.template
+++ b/tensorflow/lite/objc/TensorFlowLiteObjC.podspec.template
@@ -36,6 +36,7 @@
     core.source_files = [
       objc_dir + '{apis,sources}/*.{h,m,mm}',
       tfl_dir + 'c/c_api.h',
+      tfl_dir + 'c/c_api_types.h',
       tfl_dir + 'c/common.h',
       tfl_dir + 'delegates/xnnpack/xnnpack_delegate.h',
     ]
diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index b1ddcdf..5bb0b03 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -62,6 +62,7 @@
     visibility = ["//visibility:public"],
     deps = [
         ":tflite_convert_main_lib",
+        "//tensorflow:tensorflow_py",
         "@six_archive//:six",
     ],
 )
@@ -73,6 +74,7 @@
     visibility = ["//visibility:public"],
     deps = [
         ":tflite_convert_lib",
+        "//tensorflow:tensorflow_py",
         "@six_archive//:six",
     ],
 )
@@ -186,7 +188,10 @@
 py_test(
     name = "lite_test",
     srcs = ["lite_test.py"],
-    data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
+    data = [
+        "//tensorflow/lite/python/testdata:control_flow_v1.pbtxt",
+        "@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
+    ],
     python_version = "PY3",
     shard_count = 4,
     srcs_version = "PY2AND3",
@@ -205,6 +210,9 @@
 py_test(
     name = "lite_v2_test",
     srcs = ["lite_v2_test.py"],
+    data = [
+        "//tensorflow/lite/python/testdata/control_flow_v1_saved_model:saved_model.pb",
+    ],
     python_version = "PY3",
     shard_count = 12,
     srcs_version = "PY2AND3",
diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
index 8e6e9e8..c75a318 100644
--- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
@@ -73,6 +73,8 @@
       return kTfLiteFloat32;
     case NPY_FLOAT16:
       return kTfLiteFloat16;
+    case NPY_FLOAT64:
+      return kTfLiteFloat64;
     case NPY_INT32:
       return kTfLiteInt32;
     case NPY_INT16:
@@ -83,6 +85,8 @@
       return kTfLiteInt8;
     case NPY_INT64:
       return kTfLiteInt64;
+    case NPY_UINT64:
+      return kTfLiteUInt64;
     case NPY_BOOL:
       return kTfLiteBool;
     case NPY_OBJECT:
@@ -91,7 +95,8 @@
       return kTfLiteString;
     case NPY_COMPLEX64:
       return kTfLiteComplex64;
-      // Avoid default so compiler errors created when new types are made.
+    case NPY_COMPLEX128:
+      return kTfLiteComplex128;
   }
   return kTfLiteNoType;
 }
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 61f48ed..d631867 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -473,7 +473,10 @@
 
     if self._experimental_calibrate_only:
       return calibrated
-    elif self._experimental_new_quantizer:
+    elif self._experimental_new_quantizer and (
+        activations_type != _dtypes.int16):
+      # TODO(b/175659372): remove the activations_type restriction and enable
+      # it for all the activation types.
       return _mlir_quantize(calibrated)
     else:
       return calibrate_quantize.calibrate_and_quantize(
@@ -1334,7 +1337,7 @@
     if calibrate_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
-    if self.experimental_new_converter:
+    if self.experimental_new_converter or self._experimental_new_quantizer:
       flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
           self.inference_input_type, self.inference_output_type)
       if flags_modify_model_io_type:
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 6f1bd19..ed8059a 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -312,10 +312,14 @@
       ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
       ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
       ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
-      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
-  def testIntegerQuantizationWithUnsupportedOps(self, is_int_only,
+      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
+      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
+      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
+  def testIntegerQuantizationWithUnsupportedOps(self,
+                                                is_int_only,
                                                 is_int16_quantize,
-                                                inference_input_output_type):
+                                                inference_input_output_type,
+                                                enable_mlir_quantizer=False):
     with ops.Graph().as_default():
       in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
       in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
@@ -363,23 +367,25 @@
 
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
+    quantized_converter._experimental_new_quantizer = enable_mlir_quantizer
     quantized_tflite_model = quantized_converter.convert()
     self.assertIsNotNone(quantized_tflite_model)
 
+    expected_dtype = inference_input_output_type.as_numpy_dtype
+    # Allow float32 for fallback on non-quantizable op.
+    expected_ceil_dtype = (
+        expected_dtype if enable_mlir_quantizer else dtypes.float32)
+
     interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
     input_details = interpreter.get_input_details()
     self.assertLen(input_details, 2)
-    # Allow float32 for fallback.
-    self.assertEqual(input_details[0]['dtype'], dtypes.float32)
-    self.assertEqual(input_details[1]['dtype'],
-                     inference_input_output_type.as_numpy_dtype)
+    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
+    self.assertEqual(input_details[1]['dtype'], expected_dtype)
     output_details = interpreter.get_output_details()
     self.assertLen(output_details, 2)
-    # Allow float32 for fallback.
-    self.assertEqual(output_details[0]['dtype'], dtypes.float32)
-    self.assertEqual(output_details[1]['dtype'],
-                     inference_input_output_type.as_numpy_dtype)
+    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
+    self.assertEqual(output_details[1]['dtype'], expected_dtype)
 
   @parameterized.named_parameters(
       ('EnableMlirConverter', True),  # enable mlir
@@ -1108,27 +1114,35 @@
 
   @parameterized.named_parameters(
       # Quantize to Float16 even if rep data provided.
-      ('UseRepresentativeData', True, False, True, False, False, False),
+      ('UseRepresentativeData', True, False, True, False, False, False, False),
       # Quantize to Float16 if no rep data provided.
-      ('NoRepresentativeData', False, False, True, False, False, False),
+      ('NoRepresentativeData', False, False, True, False, False, False, False),
       # Post training quantization if both rep data and int8 included.
-      ('UseSampleDataIncludeInt8', True, True, False, False, True, False),
-
+      ('UseSampleDataIncludeInt8', True, True, False, False, True, False, False
+      ),
       # Quantize to Float16 even if rep data provided with mlir.
-      ('UseRepresentativeDataMlir', True, False, True, False, False, True),
+      ('UseRepresentativeDataMlir', True, False, True, False, False, True, False
+      ),
       # Quantize to Float16 if no rep data provided with mlir.
-      ('NoRepresentativeDataMlir', False, False, True, False, False, True),
+      ('NoRepresentativeDataMlir', False, False, True, False, False, True, False
+      ),
       # Post training quantization if both rep data and int8 included with mlir.
-      ('SampleDataIncludeInt8Mlir', True, True, False, False, True, True))
+      ('SampleDataIncludeInt8Mlir', True, True, False, False, True, True, False
+      ),
+      # Same as above, but using MLIR quantizer
+      ('SampleDataIncludeInt8MlirQuant', True, True, False, False, True, True,
+       True))
   def testQuantizeFloat16(self, use_rep_data, include_int8,
                           is_float16_quantized, is_error,
-                          is_post_training_quantized, enable_mlir_converter):
+                          is_post_training_quantized, enable_mlir_converter,
+                          enable_mlir_quantizer):
     with ops.Graph().as_default():
       inp, output, calibration_gen = self._getIntegerQuantizeModel()
       sess = session.Session()
 
-    idx = 1 if enable_mlir_converter else 0
-    node_name = 'Conv2D' if enable_mlir_converter else 'Conv2D_bias'
+    bias_idx = 1 if enable_mlir_converter else 0
+    bias_name = 'Conv2D' if enable_mlir_converter else 'Conv2D_bias'
+
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
     float_converter.experimental_new_converter = enable_mlir_converter
@@ -1136,13 +1150,20 @@
     self.assertIsNotNone(float_tflite_model)
     interpreter = Interpreter(model_content=float_tflite_model)
     interpreter.allocate_tensors()
-    self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name)
-    self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
+    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
+                     bias_name)
+    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                      dtypes.float32)
+
+    # The MLIR quantizer has a different bias index.
+    if enable_mlir_quantizer:
+      bias_idx = 2
+
     # Convert model to quantized version
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.experimental_new_converter = enable_mlir_converter
+    quantized_converter._experimental_new_quantizer = enable_mlir_quantizer
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     quantized_converter.target_spec.supported_types = [dtypes.float16]
     if include_int8:
@@ -1162,15 +1183,16 @@
       self.assertIsNotNone(quantized_tflite_model)
       interpreter = Interpreter(model_content=quantized_tflite_model)
       interpreter.allocate_tensors()
-      self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name)
+      self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
+                       bias_name)
 
       if is_float16_quantized:
         # Verify that bias constant is float16 type.
-        self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
+        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                          dtypes.float16)
       elif is_post_training_quantized:
         # Verify that bias constants is int32 type.
-        self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
+        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                          dtypes.int32)
       else:
         raise ValueError('Invalid test options.')
@@ -2740,5 +2762,24 @@
     self.assertIsNone(converter.conversion_summary_dir)
 
 
+class ControlFlowV1OpsTest(LiteTest):
+
+  def testConverterErrorOnControlFlowV1Ops(self):
+    graph_def_file = resource_loader.get_path_to_datafile(
+        'testdata/control_flow_v1.pbtxt')
+    input_arrays = ['a', 'b', 'c', 'd']
+    output_arrays = ['Merge']
+
+    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+                                                       input_arrays,
+                                                       output_arrays)
+    with self.assertRaises(ConverterError) as error:
+      converter.convert()
+    self.assertIn(
+        'Failed to functionalize Control Flow V1 ops. Consider using Control '
+        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
+        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 97ba64f..b0cea7d 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -28,6 +28,7 @@
 import tensorflow as tf
 
 from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
+from tensorflow.lite.python import convert
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_v2_test_util
 from tensorflow.lite.python.convert import mlir_quantize
@@ -38,6 +39,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import resource_loader
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import save_options
 from tensorflow.python.saved_model import saved_model
@@ -312,9 +314,7 @@
     converter = lite.TFLiteConverterV2.from_concrete_functions([func])
     # TODO(b/156309549): We should add INT16 to the builtin types.
     converter.optimizations = [lite.Optimize.DEFAULT]
-    converter.target_spec.supported_ops = [
-        lite.OpsSet.TFLITE_BUILTINS_INT8
-    ]
+    converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
     converter.representative_dataset = calibration_gen
     converter._experimental_calibrate_only = True
     calibrated_tflite = converter.convert()
@@ -608,11 +608,15 @@
       ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
       ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
       ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
-      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
+      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
+      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
+      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
   @test_util.run_v2_only
-  def testIntegerQuantizationWithUnsupportedOps(self, is_int_only,
+  def testIntegerQuantizationWithUnsupportedOps(self,
+                                                is_int_only,
                                                 is_int16_quantize,
-                                                inference_input_output_type):
+                                                inference_input_output_type,
+                                                enable_mlir_quantizer=False):
     func, calib_gen = self._getIntegerQuantizationModelWithUnsupportedOps()
 
     quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
@@ -644,23 +648,25 @@
 
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
+    quantized_converter._experimental_new_quantizer = enable_mlir_quantizer
     quantized_tflite_model = quantized_converter.convert()
     self.assertIsNotNone(quantized_tflite_model)
 
+    expected_dtype = inference_input_output_type.as_numpy_dtype
+    # Allow float32 for fallback on non-quantizable op.
+    expected_ceil_dtype = (
+        expected_dtype if enable_mlir_quantizer else dtypes.float32)
+
     interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
     input_details = interpreter.get_input_details()
     self.assertLen(input_details, 2)
-    # Allow float32 for fallback.
-    self.assertEqual(input_details[0]['dtype'], dtypes.float32)
-    self.assertEqual(input_details[1]['dtype'],
-                     inference_input_output_type.as_numpy_dtype)
+    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
+    self.assertEqual(input_details[1]['dtype'], expected_dtype)
     output_details = interpreter.get_output_details()
     self.assertLen(output_details, 2)
-    # Allow float32 for fallback.
-    self.assertEqual(output_details[0]['dtype'], dtypes.float32)
-    self.assertEqual(output_details[1]['dtype'],
-                     inference_input_output_type.as_numpy_dtype)
+    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
+    self.assertEqual(output_details[1]['dtype'], expected_dtype)
 
 
 class FromSavedModelTest(lite_v2_test_util.ModelTest):
@@ -919,8 +925,7 @@
     self.assertEqual(len(signature_defs.values()), 1)
     self.assertEqual(
         list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
-    self.assertEqual(
-        sorted(signature_defs['mul_add']['inputs']), ['x', 'y'])
+    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
     self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
 
   @test_util.run_v2_only
@@ -960,8 +965,7 @@
     self.assertEqual(len(signature_defs.values()), 1)
     self.assertEqual(
         list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
-    self.assertEqual(
-        sorted(signature_defs['mul_add']['inputs']), ['x', 'y'])
+    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
     self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
 
   @test_util.run_v2_only
@@ -1264,6 +1268,18 @@
     self.assertAllClose(expected_value, actual_value)
 
   @test_util.run_v2_only
+  def testConverterErrorOnControlFlowV1Ops(self):
+    filename = resource_loader.get_path_to_datafile(
+        'testdata/control_flow_v1_saved_model')
+    converter = lite.TFLiteConverterV2.from_saved_model(filename)
+    with self.assertRaises(convert.ConverterError) as error:
+      converter.convert()
+    self.assertIn(
+        'Failed to functionalize Control Flow V1 ops. Consider using Control '
+        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
+        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
+
+  @test_util.run_v2_only
   def testStaticRnn(self):
     input_data = tf.constant(
         np.array(np.random.random_sample((3, 10)), dtype=np.float32))
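Editor's note: the new `testConverterErrorOnControlFlowV1Ops` test checks that the converter surfaces an actionable message for legacy Merge/Switch-style graphs. A hedged sketch of the remedy the error message suggests, building the SavedModel with Control Flow V2 enabled (export path and graph contents are illustrative):

```
# Sketch: enable Control Flow V2 so functional If/While ops are emitted
# instead of legacy Merge/Switch ops, which the converter cannot functionalize.
import tensorflow.compat.v1 as tf

tf.enable_control_flow_v2()
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[], name="x")
y = tf.cond(x > 0.0, lambda: x + 1.0, lambda: x - 1.0)
with tf.Session() as sess:
  tf.saved_model.simple_save(sess, "/tmp/cfv2_saved_model", {"x": x}, {"y": y})
```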
diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD
index f86cfb6..12f8a6a 100644
--- a/tensorflow/lite/python/testdata/BUILD
+++ b/tensorflow/lite/python/testdata/BUILD
@@ -12,7 +12,10 @@
     licenses = ["notice"],  # Apache 2.0,
 )
 
-exports_files(glob(["*.pb"]))
+exports_files(glob([
+    "*.pb",
+    "*.pbtxt",
+]))
 
 tf_to_tflite(
     name = "permute_float",
diff --git a/tensorflow/lite/python/testdata/control_flow_v1.pbtxt b/tensorflow/lite/python/testdata/control_flow_v1.pbtxt
new file mode 100644
index 0000000..b481359
--- /dev/null
+++ b/tensorflow/lite/python/testdata/control_flow_v1.pbtxt
@@ -0,0 +1,64 @@
+node {
+  name: "a"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "b"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "c"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "d"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "Merge"
+  op: "Merge"
+  input: "a"
+  input: "b"
+  input: "c"
+  input: "d"
+  attr {
+    key: "N"
+    value {
+      i: 4
+    }
+  }
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+
+versions {
+  producer: 27
+}
diff --git a/tensorflow/lite/python/testdata/control_flow_v1_saved_model/BUILD b/tensorflow/lite/python/testdata/control_flow_v1_saved_model/BUILD
new file mode 100644
index 0000000..53005ff
--- /dev/null
+++ b/tensorflow/lite/python/testdata/control_flow_v1_saved_model/BUILD
@@ -0,0 +1,8 @@
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0,
+)
+
+exports_files([
+    "saved_model.pb",
+])
diff --git a/tensorflow/lite/python/testdata/control_flow_v1_saved_model/saved_model.pb b/tensorflow/lite/python/testdata/control_flow_v1_saved_model/saved_model.pb
new file mode 100644
index 0000000..76d1b70
--- /dev/null
+++ b/tensorflow/lite/python/testdata/control_flow_v1_saved_model/saved_model.pb
Binary files differ
diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py
index 49275a7..0833e2b 100644
--- a/tensorflow/lite/python/tflite_convert.py
+++ b/tensorflow/lite/python/tflite_convert.py
@@ -26,6 +26,8 @@
 
 import six
 from six.moves import zip
+# Needed to enable TF2 by default.
+import tensorflow as tf  # pylint: disable=unused-import
 
 from tensorflow.lite.python import lite
 from tensorflow.lite.python.convert import register_custom_opdefs
diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py
index 7736c83..3bccf73 100644
--- a/tensorflow/lite/python/util.py
+++ b/tensorflow/lite/python/util.py
@@ -652,7 +652,12 @@
     if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
       quant_opcode_idxs.append(idx)
   if operators and not quant_opcode_idxs:
-    raise ValueError("Model input is not quantized.")
+    for input_idx in subgraph.inputs:
+      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
+      if input_type == dtypes.float32:
+        raise ValueError("Model input is not quantized.")
+    # None of the inputs is float32, so they must be int16, int8, or bool.
+    return
 
   # Validate that the model input is quantized
   input_quant_ops = []
@@ -663,10 +668,13 @@
       # If found, validate that the operator's input type is float
       float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
       if float_type != dtypes.float32:
-        raise ValueError(
-            "Initial model input type must be tf.float32. Expected type for "
-            "tensor with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(float_type)))
+        if float_type == inference_input_type:
+          continue
+        else:
+          raise ValueError(
+              "Initial model input type must be tf.float32. Expected type for "
+              "tensor with name '{}' is tf.float32, instead type is {}".format(
+                  float_tensor.name, _get_tf_type_name(float_type)))
       # If found, validate that the operator output is quantized and compatible
       # with the final model input type
       quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@@ -737,7 +745,12 @@
     if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
       dequant_opcode_idxs.append(idx)
   if operators and not dequant_opcode_idxs:
-    raise ValueError("Model output is not dequantized.")
+    for output in subgraph.outputs:
+      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
+      if output_type == dtypes.float32:
+        raise ValueError("Model output is not dequantized.")
+    # None of the outputs is float32, so they must be int16, int8, or bool.
+    return
 
   # Validate that the model output is dequantized
   output_dequant_ops = []
@@ -749,10 +762,13 @@
       quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
       float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
       if float_type != dtypes.float32:
-        raise ValueError(
-            "Initial model output type must be tf.float32. Expected type for "
-            "tensor with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(float_type)))
+        if float_type == inference_output_type:
+          continue
+        else:
+          raise ValueError(
+              "Initial model output type must be tf.float32. Expected type for "
+              "tensor with name '{}' is tf.float32, instead type is {}".format(
+                  float_tensor.name, _get_tf_type_name(float_type)))
       # If found, validate that the operator input is quantized and compatible
       # with the final model output type
       quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
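Editor's note: the relaxed checks above make the I/O rewriting tolerant of models whose interface is already integer typed. A small usage sketch under that assumption (the model path is illustrative; the call signature follows the usage in util_test.py below):

```
# Sketch only: "/tmp/quantized.tflite" stands in for any quantized model with
# a float32 interface.
import tensorflow as tf
from tensorflow.lite.python import util

with open("/tmp/quantized.tflite", "rb") as f:
  tflite_model = f.read()

# First call rewrites the float32 interface to int8 by folding the leading
# QUANTIZE and trailing DEQUANTIZE ops into the model inputs and outputs.
model_int8 = util.modify_model_io_type(tflite_model, tf.int8, tf.int8)
# A second call finds no float32 inputs or outputs and now treats the request
# as a no-op instead of raising a ValueError.
model_int8_again = util.modify_model_io_type(model_int8, tf.int8, tf.int8)
```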
diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py
index e98b50d..cf71fb5 100644
--- a/tensorflow/lite/python/util_test.py
+++ b/tensorflow/lite/python/util_test.py
@@ -371,11 +371,18 @@
       model = None
     # Run model inference with float input output type
     output_data = _run_tflite_inference(model, tf.float32, tf.float32)
-    # Run model inference with modified integer input output type
+    # Modify the model io types to the target input/output types.
     model_io = util.modify_model_io_type(model, in_tftype, out_tftype)
+    # Run model inference with modified integer input output type
     output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
+    # Validate that both the outputs are the same
+    self.assertAllClose(output_data, output_io_data, atol=1.0)
 
-     # Validate that both the outputs are the same
+    # Modifying the model again with the same types should be a no-op.
+    model_io = util.modify_model_io_type(model_io, in_tftype, out_tftype)
+    # Run model inference with modified integer input output type
+    output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
+    # Validate that both the outputs are the same
     self.assertAllClose(output_data, output_io_data, atol=1.0)
 
 
diff --git a/tensorflow/lite/shared_library.h b/tensorflow/lite/shared_library.h
index a7bd91b..90b3dba 100644
--- a/tensorflow/lite/shared_library.h
+++ b/tensorflow/lite/shared_library.h
@@ -36,6 +36,8 @@
     return reinterpret_cast<void*>(
         GetProcAddress(static_cast<HMODULE>(handle), symbol));
   }
+  // Warning: Unlike dlsym(RTLD_DEFAULT), this does not search for the symbol
+  // in dependent DLLs.
   static inline void* GetSymbol(const char* symbol) {
     return reinterpret_cast<void*>(GetProcAddress(nullptr, symbol));
   }
diff --git a/tensorflow/lite/swift/TensorFlowLiteSwift.podspec b/tensorflow/lite/swift/TensorFlowLiteSwift.podspec
index 8af52ef..8e9183a 100644
--- a/tensorflow/lite/swift/TensorFlowLiteSwift.podspec
+++ b/tensorflow/lite/swift/TensorFlowLiteSwift.podspec
@@ -1,6 +1,6 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteSwift'
-  s.version          = '2.3.0'
+  s.version          = '2.4.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
@@ -21,9 +21,6 @@
   tfl_dir = 'tensorflow/lite/'
   swift_dir = tfl_dir + 'experimental/swift/'
 
-  tfl_dir = 'tensorflow/lite/'
-  swift_dir = tfl_dir + 'experimental/swift/'
-
   s.default_subspec = 'Core'
 
   s.subspec 'Core' do |core|
@@ -57,6 +54,7 @@
       ts.resources = [
         tfl_dir + 'testdata/add.bin',
         tfl_dir + 'testdata/add_quantized.bin',
+        tfl_dir + 'testdata/multi_add.bin',
       ]
     end
   end
diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index dbbb69b..ba36ff3 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -12,7 +12,6 @@
 load("//tensorflow:tensorflow.bzl", "pybind_extension")
 load(
     "//tensorflow:tensorflow.bzl",
-    "py_test",  # @unused
     "tf_cc_binary",
     "tf_cc_test",
 )
@@ -232,6 +231,7 @@
         ":split",
         ":test_runner",
         "@com_google_absl//absl/strings",
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite:builtin_op_data",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:string_util",
@@ -240,6 +240,8 @@
         "//tensorflow/lite/kernels:reference_ops",
         "//tensorflow/lite/kernels:test_delegate_providers_lib",
         "//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
+        "//tensorflow/lite/kernels/parse_example",
+        "//tensorflow/lite/kernels/perception:perception_ops",
         "//tensorflow/lite/tools/evaluation:utils",
     ] + select({
         "//tensorflow:ios": [],
diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py
index bc87fde..ed08373 100644
--- a/tensorflow/lite/testing/generate_examples.py
+++ b/tensorflow/lite/testing/generate_examples.py
@@ -91,6 +91,10 @@
     help=("Comma-separated list of test set names to generate. "
           "If not specified, a test set is selected by parsing the name of "
           "'zip_to_output' file."))
+parser.add_argument(
+    "--mlir_quantizer",
+    action="store_true",
+    help=("Whether the new MLIR quantizer is being used."))
 
 
 # Toco binary path provided by the generate rule.
@@ -116,6 +120,7 @@
   options.tflite_convert_function = toco_convert.toco_convert
   options.no_tests_limit = FLAGS.no_tests_limit
   options.no_conversion_report = FLAGS.no_conversion_report
+  options.mlir_quantizer = FLAGS.mlir_quantizer
 
   if FLAGS.test_sets:
     test_sets = FLAGS.test_sets.split(",")
diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py
index 009a407..bb7e9db 100644
--- a/tensorflow/lite/testing/generate_examples_lib.py
+++ b/tensorflow/lite/testing/generate_examples_lib.py
@@ -238,6 +238,7 @@
     # TODO(juhoha): Separate the state from the options.
     self.multi_gen_state = None
     self.use_experimental_converter = False
+    self.mlir_quantizer = False
 
 
 def _prepare_dir(options):
@@ -273,7 +274,10 @@
   else:
     # Remove suffixes to extract the test name from the output name.
     test_name = re.sub(
-        r"(_(|toco-flex|forward-compat|edgetpu))?\.zip$", "", out, count=1)
+        r"(_(|toco-flex|forward-compat|edgetpu|mlir-quant))?\.zip$",
+        "",
+        out,
+        count=1)
 
   test_function_name = "make_%s_tests" % test_name
   test_function = get_test_function(test_function_name)
@@ -313,7 +317,10 @@
 
       # Remove suffix and set test_name to run proper test generation function.
       multi_gen_state.test_name = re.sub(
-          r"(_(|toco-flex|forward-compat))?$", "", test_name, count=1)
+          r"(_(|toco-flex|forward-compat|mlir-quant))?$",
+          "",
+          test_name,
+          count=1)
       # Set label base path to write test data files with proper path.
       multi_gen_state.label_base_path = os.path.join(
           os.path.dirname(zip_path), test_name + ".zip")
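Editor's note: the regexes above gain an `mlir-quant` alternative so that zip targets named with the new suffix map back to the same test-generation function. A quick self-contained illustration of the suffix stripping (file names are made up):

```
import re

for out in ("conv_mlir-quant.zip", "conv_forward-compat.zip", "conv.zip"):
  test_name = re.sub(
      r"(_(|toco-flex|forward-compat|edgetpu|mlir-quant))?\.zip$",
      "", out, count=1)
  print(out, "->", test_name)  # every variant resolves to "conv"
```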
diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc
index 9ba0cff..60eff92 100644
--- a/tensorflow/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/lite/testing/generated_examples_zip_test.cc
@@ -255,7 +255,7 @@
   size_t pos = 0;
   int added = 0;
   while (true) {
-    size_t end_pos = manifest.find("\n", pos);
+    size_t end_pos = manifest.find('\n', pos);
     if (end_pos == string::npos) break;
     string filename = manifest.substr(pos, end_pos - pos);
     test_paths->push_back(dir + "/" + filename);
@@ -294,13 +294,13 @@
   string test_path_and_label = GetParam();
   string test_path = test_path_and_label;
   string label = test_path_and_label;
-  size_t end_pos = test_path_and_label.find(" ");
+  size_t end_pos = test_path_and_label.find(' ');
   if (end_pos != string::npos) {
     test_path = test_path_and_label.substr(0, end_pos);
     label = test_path_and_label.substr(end_pos + 1);
   }
   string tflite_test_case = test_path + "_tests.txt";
-  string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
+  string tflite_dir = test_path.substr(0, test_path.find_last_of('/'));
   string test_name = label.substr(label.find_last_of('/'));
 
   std::ifstream tflite_stream(tflite_test_case);
diff --git a/tensorflow/lite/testing/model_coverage/README.md b/tensorflow/lite/testing/model_coverage/README.md
new file mode 100644
index 0000000..da4635d
--- /dev/null
+++ b/tensorflow/lite/testing/model_coverage/README.md
@@ -0,0 +1,21 @@
+# TensorFlow Lite model coverage tests
+
+Various conversion tests on popular mobile models.
+
+
+## Golden values
+
+Some tests rely on pre-computed golden values. The main goal is to detect
+changes affecting unintended parts of TFLite.
+
+Should a golden value test fail after an intended change, the golden values can
+be updated with the following command:
+
+```
+bazel run //third_party/tensorflow/lite/testing/model_coverage:<target> --test_output=all -- --update_goldens
+```
+
+Note the use of `bazel run` instead of `bazel test` and the addition of the
+`--update_goldens` flag.
+
+The updated golden data files must then be included in the change list.
diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
index 0235de9..580e47e 100644
--- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
@@ -28,6 +28,7 @@
 from google.protobuf.message import DecodeError
 from tensorflow.core.framework import graph_pb2 as _graph_pb2
 from tensorflow.lite.python import convert_saved_model as _convert_saved_model
+from tensorflow.lite.python import interpreter as _interpreter
 from tensorflow.lite.python import lite as _lite
 from tensorflow.lite.python import util as _util
 from tensorflow.python.client import session as _session
@@ -37,12 +38,21 @@
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 from tensorflow.python.lib.io import file_io as _file_io
 from tensorflow.python.platform import resource_loader as _resource_loader
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import load as _load
 from tensorflow.python.saved_model import loader as _loader
 from tensorflow.python.saved_model import signature_constants as _signature_constants
 from tensorflow.python.saved_model import tag_constants as _tag_constants
 
 
+_GOLDENS_UPDATE_WARNING = """
+  Golden file update requested!
+  This test is now going to write new golden files.
+
+  Make sure to package the updates together with your CL.
+"""
+
+
 def get_filepath(filename, base_dir=None):
   """Returns the full path of the filename.
 
@@ -59,6 +69,17 @@
                       base_dir, filename)
 
 
+def get_golden_filepath(name):
+  """Returns the full path to a golden values file.
+
+  Args:
+    name: the name of the golden data, usually the same as the test name.
+  """
+  goldens_directory = os.path.join(_resource_loader.get_data_files_path(),
+                                   "testdata", "golden")
+  return os.path.join(goldens_directory, "%s.npy.golden" % name)
+
+
 def get_image(size):
   """Returns an image loaded into an np.ndarray with dims [1, size, size, 3].
 
@@ -150,7 +171,9 @@
     raise ValueError("Could not find int16 activations.")
 
 
-def _get_tflite_interpreter(tflite_model, input_shapes_resize=None):
+def _get_tflite_interpreter(tflite_model,
+                            input_shapes_resize=None,
+                            custom_op_registerers=None):
   """Creates a TFLite interpreter with resized input tensors.
 
   Args:
@@ -158,11 +181,15 @@
     input_shapes_resize: A map where the key is the input tensor name and the
       value is the shape of the input tensor. This resize happens after model
       conversion, prior to calling allocate tensors. (default None)
+    custom_op_registerers: Op registerers for custom ops.
 
   Returns:
     lite.Interpreter
   """
-  interpreter = _lite.Interpreter(model_content=tflite_model)
+  if custom_op_registerers is None:
+    custom_op_registerers = []
+  interpreter = _interpreter.InterpreterWithCustomOps(
+      model_content=tflite_model, custom_op_registerers=custom_op_registerers)
   if input_shapes_resize:
     input_details = interpreter.get_input_details()
     input_details_map = {
@@ -174,17 +201,19 @@
   return interpreter
 
 
-def _get_input_data_map(tflite_model, input_data):
+def _get_input_data_map(tflite_model, input_data, custom_op_registerers=None):
   """Generates a map of input data based on the TFLite model.
 
   Args:
     tflite_model: Serialized TensorFlow Lite model.
     input_data: List of np.ndarray.
+    custom_op_registerers: Op registerers for custom ops.
 
   Returns:
     {str: [np.ndarray]}.
   """
-  interpreter = _get_tflite_interpreter(tflite_model)
+  interpreter = _get_tflite_interpreter(
+      tflite_model, custom_op_registerers=custom_op_registerers)
   interpreter.allocate_tensors()
   input_details = interpreter.get_input_details()
   return {
@@ -196,7 +225,8 @@
 def _generate_random_input_data(tflite_model,
                                 seed=None,
                                 input_data_range=None,
-                                input_shapes_resize=None):
+                                input_shapes_resize=None,
+                                custom_op_registerers=None):
   """Generates input data based on the input tensors in the TFLite model.
 
   Args:
@@ -210,11 +240,15 @@
     input_shapes_resize: A map where the key is the input tensor name and the
       value is the shape of the input tensor. This resize happens after model
       conversion, prior to calling allocate tensors. (default None)
+    custom_op_registerers: Op registerers for custom ops.
 
   Returns:
     ([np.ndarray], {str : [np.ndarray]}).
   """
-  interpreter = _get_tflite_interpreter(tflite_model, input_shapes_resize)
+  interpreter = _get_tflite_interpreter(
+      tflite_model,
+      input_shapes_resize,
+      custom_op_registerers=custom_op_registerers)
   interpreter.allocate_tensors()
   input_details = interpreter.get_input_details()
 
@@ -234,11 +268,15 @@
             ) * val + input_data_range[input_tensor["name"]][0]
     input_data.append(np.array(val, dtype=input_tensor["dtype"]))
 
-  input_data_map = _get_input_data_map(tflite_model, input_data)
+  input_data_map = _get_input_data_map(
+      tflite_model, input_data, custom_op_registerers=custom_op_registerers)
   return input_data, input_data_map
 
 
-def _evaluate_tflite_model(tflite_model, input_data, input_shapes_resize=None):
+def _evaluate_tflite_model(tflite_model,
+                           input_data,
+                           input_shapes_resize=None,
+                           custom_op_registerers=None):
   """Returns evaluation of input data on TFLite model.
 
   Args:
@@ -247,11 +285,15 @@
     input_shapes_resize: A map where the key is the input tensor name and the
       value is the shape of the input tensor. This resize happens after model
       conversion, prior to calling allocate tensors. (default None)
+    custom_op_registerers: Op registerers for custom ops.
 
   Returns:
     List of np.ndarray.
   """
-  interpreter = _get_tflite_interpreter(tflite_model, input_shapes_resize)
+  interpreter = _get_tflite_interpreter(
+      tflite_model,
+      input_shapes_resize,
+      custom_op_registerers=custom_op_registerers)
   interpreter.allocate_tensors()
 
   input_details = interpreter.get_input_details()
@@ -383,6 +425,31 @@
     np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
 
 
+def _compare_tf_tflite_results(tf_results,
+                               tflite_results,
+                               tflite_labels,
+                               tolerance=5):
+  """Compare the result of TF and TFLite model.
+
+  Args:
+    tf_results: results returned by the TF model.
+    tflite_results: results returned by the TFLite model.
+    tflite_labels: names of the output tensors in the TFLite model.
+    tolerance: Decimal place to check accuracy to. (default 5).
+  """
+  # Convert the output TensorFlow results into an ordered list.
+  if isinstance(tf_results, dict):
+    if len(tf_results) == 1:
+      tf_results = [tf_results[list(tf_results.keys())[0]]]
+    else:
+      tf_results = [tf_results[tflite_label] for tflite_label in tflite_labels]
+  else:
+    tf_results = [tf_results]
+
+  for tf_result, tflite_result in zip(tf_results, tflite_results):
+    np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+
+
 def compare_models_v2(tflite_model,
                       tf_eval_func,
                       input_data=None,
@@ -424,17 +491,84 @@
   tflite_results, tflite_labels = _evaluate_tflite_model(
       tflite_model, input_data)
 
-  # Convert the output TensorFlow results into an ordered list.
-  if isinstance(tf_results, dict):
-    if len(tf_results) == 1:
-      tf_results = [tf_results[list(tf_results.keys())[0]]]
-    else:
-      tf_results = [tf_results[tflite_label] for tflite_label in tflite_labels]
-  else:
-    tf_results = [tf_results]
+  _compare_tf_tflite_results(tf_results, tflite_results, tflite_labels,
+                             tolerance)
 
-  for tf_result, tflite_result in zip(tf_results, tflite_results):
-    np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+
+def compare_tflite_keras_models_v2(tflite_model,
+                                   keras_model,
+                                   input_data=None,
+                                   input_data_range=None,
+                                   tolerance=5,
+                                   custom_op_registerers=None):
+  """Similar to compare_models_v2 but accept Keras model.
+
+  Unless the input data is provided, the models are compared with random data.
+  Currently only 1 input and 1 output are supported by this function.
+
+  Args:
+    tflite_model: Serialized TensorFlow Lite model.
+    keras_model: Keras model to evaluate.
+    input_data: np.ndarray to pass into models during inference. (default None).
+    input_data_range: A map where the key is the input tensor name and the value
+      is a tuple (min_val, max_val) which specifies the value range of
+      the corresponding input tensor. For example, '{'input1': (1, 5)}' means to
+      generate a random value for tensor `input1` within range [1.0, 5.0)
+      (half-inclusive). (default None)
+    tolerance: Decimal place to check accuracy to. (default 5)
+    custom_op_registerers: Op registerers for custom ops.
+  """
+  # Generate random input data if not provided.
+  if input_data is None:
+    input_data, _ = _generate_random_input_data(
+        tflite_model=tflite_model,
+        input_data_range=input_data_range,
+        custom_op_registerers=custom_op_registerers)
+
+  if len(input_data) > 1:
+    tf_results = keras_model.predict(input_data)
+  else:
+    tf_results = keras_model.predict(input_data[0])
+  tflite_results, tflite_labels = _evaluate_tflite_model(
+      tflite_model, input_data, custom_op_registerers=custom_op_registerers)
+
+  _compare_tf_tflite_results(tf_results, tflite_results, tflite_labels,
+                             tolerance)
+
+
+def compare_model_golden(tflite_model,
+                         input_data,
+                         golden_name,
+                         update_golden=False,
+                         tolerance=5):
+  """Compares the output of a TFLite model against pre-existing golden values.
+
+  Args:
+    tflite_model: Serialized TensorFlow Lite model.
+    input_data: np.ndarray to pass into models during inference.
+    golden_name: Name of the file containing the (expected) golden values.
+    update_golden: Whether to update the golden values with the model output
+      instead of comparing againts them. This should only be done when a change
+      in TFLite warrants it.
+    tolerance: Decimal place to check accuracy to. (default 5).
+  """
+  tflite_results, _ = _evaluate_tflite_model(tflite_model, input_data)
+  golden_file = get_golden_filepath(golden_name)
+  if update_golden:
+    logging.warning(_GOLDENS_UPDATE_WARNING)
+    logging.warning("Updating golden values in file %s.", golden_file)
+    if not os.path.exists(golden_file):
+      golden_relative_path = os.path.relpath(
+          golden_file, _resource_loader.get_root_dir_with_all_resources())
+      logging.warning(
+          "Golden file not found. Manually create it first:\ntouch %r",
+          golden_relative_path)
+
+    with open(golden_file, "wb") as f:
+      np.save(f, tflite_results, allow_pickle=False)
+  else:
+    golden_data = np.load(golden_file, allow_pickle=False)
+    np.testing.assert_almost_equal(golden_data, tflite_results, tolerance)
 
 
 def test_frozen_graph_quant(filename,
@@ -643,10 +777,22 @@
       input_data_range=input_data_range)
 
 
-def test_saved_model_v2_quant_float16(directory, **kwargs):
-  """Validates the TensorFlow SavedModel converts to a TFLite model."""
+def _test_conversion_quant_float16(converter,
+                                   input_data,
+                                   golden_name=None,
+                                   update_golden=False,
+                                   **kwargs):
+  """Validates conversion with float16 quantization.
 
-  converter = _lite.TFLiteConverterV2.from_saved_model(directory)
+  Args:
+    converter: TFLite converter instance for the model to convert.
+    input_data: np.ndarray to pass into models during inference.
+    golden_name: Optional name of the golden values to compare the model
+      output against.
+    update_golden: Whether to update the golden values with the model output
+      instead of comparing against them.
+    **kwargs: Additional arguments to be passed into the converter.
+  """
   tflite_model_float = _convert(converter, version=2, **kwargs)
 
   interpreter_float = _get_tflite_interpreter(tflite_model_float)
@@ -678,6 +824,63 @@
     raise ValueError("--post_training_quantize flag was unable to quantize the "
                      "graph as expected.")
 
+  if golden_name:
+    compare_model_golden(tflite_model_quant, input_data, golden_name,
+                         update_golden)
+
+
+def test_saved_model_v2_quant_float16(directory,
+                                      input_data,
+                                      golden_name=None,
+                                      update_golden=False,
+                                      **kwargs):
+  """Validates conversion of a saved model to TFLite with float16 quantization.
+
+  Args:
+    directory: SavedModel directory to convert.
+    input_data: np.ndarray to pass into models during inference.
+    golden_name: Optional name of the golden values to compare the model
+      output against.
+    update_golden: Whether to update the golden values with the model output
+      instead of comparing against them.
+    **kwargs: Additional arguments to be passed into the converter.
+  """
+  converter = _lite.TFLiteConverterV2.from_saved_model(directory)
+  _test_conversion_quant_float16(converter, input_data, golden_name,
+                                 update_golden, **kwargs)
+
+
+def test_frozen_graph_quant_float16(filename,
+                                    input_arrays,
+                                    output_arrays,
+                                    input_data,
+                                    input_shapes=None,
+                                    golden_name=None,
+                                    update_golden=False,
+                                    **kwargs):
+  """Validates conversion of a frozen graph to TFLite with float16 quantization.
+
+  Args:
+    filename: Full filepath of file containing frozen GraphDef.
+    input_arrays: List of input tensors to freeze graph with.
+    output_arrays: List of output tensors to freeze graph with.
+    input_data: np.ndarray to pass into models during inference.
+    input_shapes: Dict of strings representing input tensor names to list of
+      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+      Automatically determined when input shapes is None (e.g., {"foo" : None}).
+        (default None)
+    golden_name: Optional name of the golden values to compare the model
+      output against.
+    update_golden: Whether to update the golden values with the model output
+      instead of comparing against them.
+    **kwargs: Additional arguments to be passed into the converter.
+  """
+  converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays,
+                                                      output_arrays,
+                                                      input_shapes)
+  _test_conversion_quant_float16(converter, input_data,
+                                 golden_name, update_golden, **kwargs)
+
 
 def test_keras_model(filename,
                      input_arrays=None,
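Editor's note: taken together, the new helpers let a model-coverage test pin its float16-quantized output to a golden file. A hypothetical end-to-end use (module path, SavedModel directory, and input shape are illustrative; the golden name matches the file added under testdata/golden/ below):

```
# Sketch only: the SavedModel directory and input shape are placeholders.
import numpy as np
from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage

input_data = [np.random.random_sample((1, 224, 224, 3)).astype(np.float32)]

# Converts the SavedModel with float16 quantization and compares the quantized
# model's output against the stored golden values; pass update_golden=True
# (e.g. via --update_goldens) to rewrite the golden file instead.
model_coverage.test_saved_model_v2_quant_float16(
    "/tmp/mobilenet_v2_saved_model",
    input_data,
    golden_name="tf2_mobilenet_v2_quant_float16",
    update_golden=False)
```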
diff --git a/tensorflow/lite/testing/model_coverage/testdata/golden/BUILD b/tensorflow/lite/testing/model_coverage/testdata/golden/BUILD
new file mode 100644
index 0000000..7f65171
--- /dev/null
+++ b/tensorflow/lite/testing/model_coverage/testdata/golden/BUILD
@@ -0,0 +1,11 @@
+# TensorFlow Lite quantization test goldens.
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "goldens",
+    srcs = glob(["*.npy.golden"]),
+)
diff --git a/tensorflow/lite/testing/model_coverage/testdata/golden/mobilenet_v1_quant_float16.npy.golden b/tensorflow/lite/testing/model_coverage/testdata/golden/mobilenet_v1_quant_float16.npy.golden
new file mode 100644
index 0000000..949387a
--- /dev/null
+++ b/tensorflow/lite/testing/model_coverage/testdata/golden/mobilenet_v1_quant_float16.npy.golden
Binary files differ
diff --git a/tensorflow/lite/testing/model_coverage/testdata/golden/tf2_mobilenet_v2_quant_float16.npy.golden b/tensorflow/lite/testing/model_coverage/testdata/golden/tf2_mobilenet_v2_quant_float16.npy.golden
new file mode 100644
index 0000000..6b07c31
--- /dev/null
+++ b/tensorflow/lite/testing/model_coverage/testdata/golden/tf2_mobilenet_v2_quant_float16.npy.golden
Binary files differ
diff --git a/tensorflow/lite/testing/nnapi_example.cc b/tensorflow/lite/testing/nnapi_example.cc
index a847ffa..a566074 100644
--- a/tensorflow/lite/testing/nnapi_example.cc
+++ b/tensorflow/lite/testing/nnapi_example.cc
@@ -31,7 +31,7 @@
 #include "tensorflow/lite/testing/tflite_driver.h"
 
 std::string dirname(const std::string& s) {
-  return s.substr(0, s.find_last_of("/"));
+  return s.substr(0, s.find_last_of('/'));
 }
 
 bool Interpret(const char* examples_filename, bool use_nnapi) {
diff --git a/tensorflow/lite/testing/op_tests/depth_to_space.py b/tensorflow/lite/testing/op_tests/depth_to_space.py
index 9693a66..c4647e1 100644
--- a/tensorflow/lite/testing/op_tests/depth_to_space.py
+++ b/tensorflow/lite/testing/op_tests/depth_to_space.py
@@ -28,9 +28,15 @@
   """Make a set of tests to do depth_to_space."""
 
   test_parameters = [{
-      "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64],
+      "dtype": [tf.int32, tf.uint8, tf.int64],
       "input_shape": [[2, 3, 4, 16]],
       "block_size": [2, 4],
+      "fully_quantize": [False],
+  }, {
+      "dtype": [tf.float32],
+      "input_shape": [[2, 3, 4, 16]],
+      "block_size": [2, 4],
+      "fully_quantize": [True, False],
   }]
 
   def build_graph(parameters):
@@ -43,8 +49,15 @@
     return [input_tensor], [out]
 
   def build_inputs(parameters, sess, inputs, outputs):
-    input_values = create_tensor_data(parameters["dtype"],
-                                      parameters["input_shape"])
+    if not parameters["fully_quantize"]:
+      input_values = create_tensor_data(parameters["dtype"],
+                                        parameters["input_shape"])
+    else:
+      input_values = create_tensor_data(
+          parameters["dtype"],
+          parameters["input_shape"],
+          min_value=-1,
+          max_value=1)
     return [input_values], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_values])))
 
diff --git a/tensorflow/lite/testing/selective_build_test.cc b/tensorflow/lite/testing/selective_build_test.cc
index c3a0cf2..f614d2e 100644
--- a/tensorflow/lite/testing/selective_build_test.cc
+++ b/tensorflow/lite/testing/selective_build_test.cc
@@ -18,8 +18,8 @@
 #include <gtest/gtest.h>
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/create_op_resolver.h"
 #include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/java/src/main/native/op_resolver.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/model_builder.h"
 
@@ -29,7 +29,7 @@
       tflite::FlatBufferModel::BuildFromFile(filename.c_str());
 
   // Build the interpreter
-  std::unique_ptr<OpResolver> resolver = CreateOpResolver();
+  std::unique_ptr<MutableOpResolver> resolver = CreateOpResolver();
   std::unique_ptr<tflite::Interpreter> interpreter;
   if (tflite::InterpreterBuilder(*model, *resolver)(&interpreter) !=
       kTfLiteOk) {
diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc
index 329df93..3d1fac2 100644
--- a/tensorflow/lite/testing/tflite_driver.cc
+++ b/tensorflow/lite/testing/tflite_driver.cc
@@ -21,11 +21,14 @@
 
 #include "absl/strings/escaping.h"
 #include "tensorflow/lite/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
 #if !defined(__APPLE__)
 #include "tensorflow/lite/delegates/flex/delegate.h"
 #endif
 #include "tensorflow/lite/kernels/custom_ops_register.h"
 #include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
+#include "tensorflow/lite/kernels/parse_example/parse_example.h"
+#include "tensorflow/lite/kernels/perception/perception_ops.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/kernels/register_ref.h"
 #include "tensorflow/lite/kernels/test_delegate_providers.h"
@@ -79,7 +82,7 @@
 }
 
 bool IsQuantized(const TfLiteTensor& tensor) {
-  if (tensor.type != kTfLiteInt8 && tensor.type != kTfLiteInt16) return false;
+  if (tensor.quantization.type == kTfLiteNoQuantization) return false;
 
   if (tensor.quantization.params != nullptr) {
     auto* quantization =
@@ -370,6 +373,8 @@
     ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
         reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
     tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
+    tflite::ops::custom::AddParseExampleOp(buildinop_resolver_);
+    tflite::ops::custom::AddPerceptionOps(buildinop_resolver_);
   }
 
   switch (delegate_type) {
diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py
index 48c19c4..5216c1f 100644
--- a/tensorflow/lite/testing/toco_convert.py
+++ b/tensorflow/lite/testing/toco_convert.py
@@ -115,6 +115,7 @@
           graphdef_file.name, input_arrays, output_tensors, input_shapes)
 
       converter.experimental_new_converter = options.use_experimental_converter
+      converter._experimental_new_quantizer = options.mlir_quantizer  # pylint: disable=protected-access
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
 
       if fully_quantize:
diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py
index 1b44604..e7ade88 100644
--- a/tensorflow/lite/testing/zip_test_utils.py
+++ b/tensorflow/lite/testing/zip_test_utils.py
@@ -368,6 +368,11 @@
           "fully_quantize", False) or param_dict.get("quant_16x8", False)):
         continue
 
+      # Skips the new quantizer tests when `fully_quantize` is set to false
+      # or it is not set.
+      if options.mlir_quantizer and not param_dict.get("fully_quantize", False):
+        continue
+
       def generate_inputs_outputs(tflite_model_binary,
                                   min_value=0,
                                   max_value=255):
diff --git a/tensorflow/lite/toco/logging/conversion_log_util.cc b/tensorflow/lite/toco/logging/conversion_log_util.cc
index 55afa13..75cb108b 100644
--- a/tensorflow/lite/toco/logging/conversion_log_util.cc
+++ b/tensorflow/lite/toco/logging/conversion_log_util.cc
@@ -214,13 +214,13 @@
   size_t pos = error_message.find(s1);
   if (pos != std::string::npos) {
     // Find the terminate point for flex op list.
-    auto end = error_message.find(".", pos);
+    auto end = error_message.find('.', pos);
     pruned_message.append(error_message.substr(pos, end - pos + 1));
   }
   pos = error_message.find(s2);
   if (pos != std::string::npos) {
     // Find the terminate point for custom op list.
-    auto end = error_message.find(".", pos);
+    auto end = error_message.find('.', pos);
     pruned_message.append(error_message.substr(pos, end - pos + 1));
   }
   return pruned_message;
diff --git a/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
index 7d83a9d..0d8901f 100644
--- a/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
+++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
@@ -265,7 +265,7 @@
       // Assuming the node name has a pattern like:
       // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
       // CELLNAME as the cluster name.
-      size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1;
+      size_t cell_pos = node.name().rfind('/', weights_pos - 2) + 1;
       std::string cell_name =
           node.name().substr(cell_pos, weights_pos - cell_pos - 1);
       cluster = std::unique_ptr<SvdfCluster>(new SvdfCluster);
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index 54f6e33..eb19a7e 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -1074,7 +1074,7 @@
     // Check name.  Either "name_with_suffix_8", "name_with_port:3", but not
     // "name_with_both:3_8".
     const std::string& name = array_entry.first;
-    auto colon_pos = name.find_first_of(":");
+    auto colon_pos = name.find_first_of(':');
     if (colon_pos != std::string::npos) {
       CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
                std::string::npos)
diff --git a/tensorflow/lite/tools/benchmark/android/jni/benchmark_model_jni.cc b/tensorflow/lite/tools/benchmark/android/jni/benchmark_model_jni.cc
index 91bad6c..190130e 100644
--- a/tensorflow/lite/tools/benchmark/android/jni/benchmark_model_jni.cc
+++ b/tensorflow/lite/tools/benchmark/android/jni/benchmark_model_jni.cc
@@ -62,9 +62,7 @@
 }  // namespace benchmark
 }  // namespace tflite
 
-#ifdef __cplusplus
 extern "C" {
-#endif
 
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_benchmark_BenchmarkModel_nativeRun(JNIEnv* env,
@@ -90,6 +88,4 @@
   env->ReleaseStringUTFChars(args_obj, args_chars);
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc
index 354420d..b36eb93 100644
--- a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc
+++ b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc
@@ -18,9 +18,7 @@
 #include "tensorflow/core/util/stats_calculator.h"
 #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
 
-#ifdef __cplusplus
 extern "C" {
-#endif  // __cplusplus
 
 // -----------------------------------------------------------------------------
 // C APIs corresponding to tflite::benchmark::BenchmarkResults type.
@@ -179,6 +177,4 @@
   return benchmark_model->benchmark_model->AddListener(listener->adapter.get());
 }
 
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
+}  // extern "C"
diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
index 923a0fa..59ad977 100644
--- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
+++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
@@ -40,26 +40,12 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "tensorflow/lite/c/c_api_types.h"  // IWYU pragma: export
+
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
-typedef enum TfLiteStatus {
-  kTfLiteOk = 0,
-
-  // Generally referring to an error in the runtime (i.e. interpreter)
-  kTfLiteError = 1,
-
-  // Generally referring to an error from a TfLiteDelegate itself.
-  kTfLiteDelegateError = 2,
-
-  // Generally referring to an error in applying a delegate due to
-  // incompatibility between runtime and delegate, e.g., this error is returned
-  // when trying to apply a TfLite delegate onto a model graph that's already
-  // immutable.
-  kTfLiteApplicationError = 3
-} TfLiteStatus;
-
 // The list of external context types known to TF Lite. This list exists solely
 // to avoid conflicts and to ensure ops can share the external contexts they
 // need. Access to the external contexts is controlled by one of the
@@ -254,22 +240,6 @@
     }                                      \
   } while (0)
 
-// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
-// library.
-#ifdef SWIG
-#define TFL_CAPI_EXPORT
-#else
-#if defined(_WIN32)
-#ifdef TFL_COMPILE_LIBRARY
-#define TFL_CAPI_EXPORT __declspec(dllexport)
-#else
-#define TFL_CAPI_EXPORT __declspec(dllimport)
-#endif  // TFL_COMPILE_LIBRARY
-#else
-#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
-#endif  // _WIN32
-#endif  // SWIG
-
 // Single-precision complex data type compatible with the C99 definition.
 typedef struct TfLiteComplex64 {
   float re, im;  // real and imaginary parts, respectively.
@@ -285,24 +255,6 @@
   uint16_t data;
 } TfLiteFloat16;
 
-// Types supported by tensor
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-  kTfLiteInt8 = 9,
-  kTfLiteFloat16 = 10,
-  kTfLiteFloat64 = 11,
-  kTfLiteComplex128 = 12,
-  kTfLiteUInt64 = 13,
-} TfLiteType;
-
 // Return the name of a given type, for error reporting purposes.
 const char* TfLiteTypeGetName(TfLiteType type);
 
@@ -319,22 +271,12 @@
 typedef struct TfLiteQuantization {
   // The type of quantization held by params.
   TfLiteQuantizationType type;
-  // Holds a reference to one of the quantization param structures specified
-  // below.
+  // Holds an optional reference to a quantization param structure. The actual
+  // type depends on the value of the `type` field (see the comment there for
+  // the values and corresponding types).
   void* params;
 } TfLiteQuantization;
 
-// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
-// If per-layer quantization is specified this field will still be populated in
-// addition to TfLiteAffineQuantization.
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct TfLiteQuantizationParams {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
 // Parameters for asymmetric quantization across a dimension (i.e per output
 // channel quantization).
 // quantized_dimension specifies which dimension the scales and zero_points
@@ -536,7 +478,7 @@
   // WARNING: This is an experimental interface that is subject to change.
   struct TfLiteDelegate* delegate;
 } TfLiteNode;
-#else  // defined(TF_LITE_STATIC_MEMORY)?
+#else   // defined(TF_LITE_STATIC_MEMORY)?
 // NOTE: This flag is opt-in only at compile time.
 //
 // Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
diff --git a/tensorflow/lite/tools/benchmark/experimental/firebase/android/jni/benchmark_model_jni.cc b/tensorflow/lite/tools/benchmark/experimental/firebase/android/jni/benchmark_model_jni.cc
index 97cba27..26ce0ff 100644
--- a/tensorflow/lite/tools/benchmark/experimental/firebase/android/jni/benchmark_model_jni.cc
+++ b/tensorflow/lite/tools/benchmark/experimental/firebase/android/jni/benchmark_model_jni.cc
@@ -247,9 +247,7 @@
 }  // namespace benchmark
 }  // namespace tflite
 
-#ifdef __cplusplus
 extern "C" {
-#endif
 
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_benchmark_firebase_BenchmarkModel_nativeRun(
@@ -263,6 +261,4 @@
   env->ReleaseStringUTFChars(library_dir, lib_dir);
 }
 
-#ifdef __cplusplus
 }  // extern "C"
-#endif  // __cplusplus
diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake
index ba6714d..5defaa0 100644
--- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake
+++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake
@@ -22,7 +22,7 @@
 OverridableFetchContent_Declare(
   xnnpack
   GIT_REPOSITORY https://github.com/google/XNNPACK
-  GIT_TAG 6eaa1521288d268dd4cceca4ae5c018cf009179b
+  GIT_TAG 094e692629d57ddb932fcc993193626f60daa61b
   GIT_PROGRESS TRUE
   PREFIX "${CMAKE_BINARY_DIR}"
   SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack"
diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc
index 5656221..730a7d5 100644
--- a/tensorflow/lite/tools/evaluation/utils.cc
+++ b/tensorflow/lite/tools/evaluation/utils.cc
@@ -75,7 +75,7 @@
     while ((ent = readdir(dir)) != nullptr) {
       if (ent->d_type == DT_DIR) continue;
       std::string filename(std::string(ent->d_name));
-      size_t lastdot = filename.find_last_of(".");
+      size_t lastdot = filename.find_last_of('.');
       std::string ext = lastdot != std::string::npos ? filename.substr(lastdot)
                                                      : std::string();
       std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py
index 8a9031f..3fa9b1f 100644
--- a/tensorflow/lite/tools/flatbuffer_utils.py
+++ b/tensorflow/lite/tools/flatbuffer_utils.py
@@ -126,6 +126,8 @@
     subgraph.name = None
     for tensor in subgraph.tensors:
       tensor.name = None
+  # Also clear signature_defs, since they are useless once names are stripped.
+  model.signatureDefs = None
 
 
 def randomize_weights(model, random_seed=0):
diff --git a/tensorflow/lite/tools/flatbuffer_utils_test.py b/tensorflow/lite/tools/flatbuffer_utils_test.py
index 129e027..cca1f09 100644
--- a/tensorflow/lite/tools/flatbuffer_utils_test.py
+++ b/tensorflow/lite/tools/flatbuffer_utils_test.py
@@ -89,6 +89,9 @@
     # Validate the description
     self.assertIsNotNone(initial_model.description)
     self.assertIsNone(final_model.description)
+    self.assertIsNotNone(initial_model.signatureDefs)
+    self.assertIsNone(final_model.signatureDefs)
+
     # Validate the main subgraph's name, inputs, outputs, operators and tensors
     initial_subgraph = initial_model.subgraphs[0]
     final_subgraph = final_model.subgraphs[0]
diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile
index 4debaae..99fe540 100644
--- a/tensorflow/lite/tools/make/Makefile
+++ b/tensorflow/lite/tools/make/Makefile
@@ -190,6 +190,7 @@
 $(wildcard tensorflow/lite/kernels/*test_main.cc) \
 $(wildcard tensorflow/lite/kernels/*test_util*.cc) \
 $(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/*/mock*.c) \
+tensorflow/lite/create_op_resolver_with_selected_ops.cc \
 tensorflow/lite/tflite_with_xnnpack.cc \
 $(MINIMAL_SRCS)
 
@@ -345,14 +346,18 @@
 # The target that's compiled for micro-controllers
 micro: $(LIB_PATH)
 
-# Hack for generating schema file bypassing flatbuffer parsing
+# Hack for generating schema files bypassing flatbuffer parsing
 tensorflow/lite/schema/schema_generated.h:
 	@cp -u tensorflow/lite/schema/schema_generated.h.oss tensorflow/lite/schema/schema_generated.h
+	@cp -u tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h.oss tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
 
 # Gathers together all the objects we've compiled into a single '.a' archive.
 $(LIB_PATH): tensorflow/lite/schema/schema_generated.h $(LIB_OBJS)
 	@mkdir -p $(dir $@)
 	$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
+$(LIB_PATH): tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h $(LIB_OBJS)
+	@mkdir -p $(dir $@)
+	$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
 
 lib: $(LIB_PATH)
 
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index 88015d7..4eae02c 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -321,6 +321,7 @@
         "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
         "//tensorflow/lite/tools/optimize:testdata/concat.bin",
         "//tensorflow/lite/tools/optimize:testdata/fc.bin",
+        "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
         "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
         "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
         "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 45dff78..d5ef945 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -106,6 +106,13 @@
       property.version = 2;
       property.quantizable_int16 = false;
       break;
+    case BuiltinOperator_DEPTH_TO_SPACE:
+      property.inputs = {{0, {}}};
+      property.outputs = {{0, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 2;
+      property.quantizable_int16 = false;
+      break;
     case BuiltinOperator_SPLIT:
       // We skip input 0 since it is the split dim which is not real valued.
       property.inputs = {{1, {}}};
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 713bafd..43e3d3b 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -218,6 +218,41 @@
   return (int8check || int16check);
 }
 
+// Checks whether a new quantize op is needed for a model input. It is not
+// needed if the input's only consumer is a QUANTIZE op with matching params.
+bool InputQuantizeRequired(const ModelT* model, const SubGraphT* subgraph,
+                           int32_t input_idx) {
+  std::vector<OperatorT*> quantize_ops;
+  for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
+    OperatorT* op = subgraph->operators[op_idx].get();
+    if (std::find(op->inputs.begin(), op->inputs.end(), input_idx) !=
+        op->inputs.end()) {
+      const BuiltinOperator op_code =
+          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
+      if (op_code != BuiltinOperator_QUANTIZE) {
+        return true;
+      }
+      quantize_ops.push_back(op);
+    }
+  }
+  if (quantize_ops.size() == 1) {
+    const auto* tensor = subgraph->tensors[input_idx].get();
+    const auto* op = quantize_ops[0];
+    const int32_t output_idx = op->outputs[0];
+    const auto output_type = subgraph->tensors[output_idx]->type;
+    const float output_scale =
+        subgraph->tensors[output_idx]->quantization->scale[0];
+    const int64_t output_zero_point =
+        subgraph->tensors[output_idx]->quantization->zero_point[0];
+    if (output_type == tensor->type &&
+        output_scale == tensor->quantization->scale[0] &&
+        output_zero_point == tensor->quantization->zero_point[0]) {
+      return false;
+    }
+  }
+  return true;
+}
+
 // Sets the input type, adding a Leading Op node at the start of the model if
 // necessary.
 // Returns the new input tensor index.
@@ -258,6 +293,13 @@
           leading_op_name, tensor->shape, tensor->shape_signature, input_type,
           scale, zero_point + 128, &leading_op_input);
     }
+
+    // If an equivalent quantize op already consumes this input, don't add one.
+    if (!InputQuantizeRequired(model, subgraph, tensor_idx)) {
+      subgraph->tensors[tensor_idx] = std::move(leading_op_input);
+      return tensor_idx;
+    }
+
     const int32_t leading_op_input_idx = subgraph->tensors.size();
     subgraph->tensors.push_back(std::move(leading_op_input));
 
@@ -963,6 +1005,11 @@
             EnumNameBuiltinOperator(op_code));
         quantization_not_supported = true;
       } else if (!property.quantizable && !allow_float) {
+        if (op_code == BuiltinOperator_DEQUANTIZE &&
+            std::find(subgraph->outputs.begin(), subgraph->outputs.end(),
+                      op->outputs[0]) != subgraph->outputs.end()) {
+          continue;
+        }
         TF_LITE_REPORT_ERROR(error_reporter,
                              "Quantization not yet supported for op: '%s'.\n",
                              EnumNameBuiltinOperator(op_code));
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index 9afd163..92df971 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -1639,6 +1639,64 @@
             transpose_output->quantization->zero_point[0]);
 }
 
+class QuantizeQatTest : public QuantizeModelTest {
+ protected:
+  QuantizeQatTest() {
+    input_model_ = ReadModel(internal::kQatModelWithFc);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+};
+
+TEST_F(QuantizeQatTest, VerifySingleQuantize) {
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
+      TensorType_INT8, &error_reporter_);
+  ASSERT_EQ(kTfLiteOk, status);
+
+  const auto& subgraph = model_.subgraphs[0];
+  auto op = subgraph->operators[0].get();
+  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
+            BuiltinOperator_QUANTIZE);
+  op = subgraph->operators[1].get();
+  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
+            BuiltinOperator_RESHAPE);
+  op = subgraph->operators[2].get();
+  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
+            BuiltinOperator_FULLY_CONNECTED);
+
+  ASSERT_EQ(op->inputs.size(), 3);
+  ASSERT_EQ(op->outputs.size(), 1);
+
+  auto qat_graph = readonly_model_->subgraphs()->Get(0);
+  // Verify the FC input and weight are quantized.
+  ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[0])->type(), TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
+  ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[1])->type(), TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);
+
+  // Verify the FC bias is quantized to int32.
+  ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[2])->type(), TensorType_INT32);
+  EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);
+
+  // The output of FC should be quantized.
+  ASSERT_EQ(qat_graph->tensors()->Get(op->outputs[0])->type(), TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
+
+  // Check op codes and versioning.
+  EXPECT_EQ(model_.operator_codes.size(), 4);
+  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
+            BuiltinOperator_QUANTIZE);
+  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
+            BuiltinOperator_RESHAPE);
+  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[2].get()),
+            BuiltinOperator_FULLY_CONNECTED);
+  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[3].get()),
+            BuiltinOperator_DEQUANTIZE);
+  EXPECT_EQ(model_.operator_codes[1]->version, 1);
+  EXPECT_EQ(model_.operator_codes[2]->version, 4);
+}
+
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 29f5c1f..0dba966 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -356,6 +356,8 @@
       model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2;
     } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
       model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3;
+    } else if (op_code == BuiltinOperator_BATCH_MATMUL) {
+      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1;
     } else if (op_code == BuiltinOperator_SVDF) {
       model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2;
     } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter.cc b/tensorflow/lite/tools/optimize/sparsity/format_converter.cc
index d6a80f5..c5a7778 100644
--- a/tensorflow/lite/tools/optimize/sparsity/format_converter.cc
+++ b/tensorflow/lite/tools/optimize/sparsity/format_converter.cc
@@ -14,13 +14,9 @@
 ==============================================================================*/
 #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
 
-#include <algorithm>
 #include <cstdint>
-#include <iostream>
 #include <vector>
 
-#include "tensorflow/lite/c/common.h"
-
 namespace tflite {
 namespace optimize {
 namespace sparsity {
@@ -261,7 +257,8 @@
 
 template <typename T>
 void FormatConverter<T>::Populate(const T* src_data, std::vector<int> indices,
-                                  int level, int prev_idx, int* src_data_ptr) {
+                                  int level, int prev_idx, int* src_data_ptr,
+                                  T* dest_data) {
   if (level == indices.size()) {
     int orig_rank = dense_shape_.size();
     std::vector<int> orig_idx;
@@ -279,7 +276,8 @@
           orig_idx[orig_dim] * block_size_[block_idx] + indices[i];
     }
 
-    data_[GetFlattenedIndex(orig_idx, dense_shape_)] = src_data[*src_data_ptr];
+    dest_data[GetFlattenedIndex(orig_idx, dense_shape_)] =
+        src_data[*src_data_ptr];
 
     *src_data_ptr = *src_data_ptr + 1;
     return;
@@ -291,7 +289,7 @@
     for (int i = 0; i < shape_of_level; i++) {
       indices[level] = i;
       Populate(src_data, indices, level + 1, prev_idx * shape_of_level + i,
-               src_data_ptr);
+               src_data_ptr, dest_data);
     }
   } else {
     const auto& array_segments = dim_metadata_[metadata_idx];
@@ -299,7 +297,7 @@
     for (int i = array_segments[prev_idx]; i < array_segments[prev_idx + 1];
          i++) {
       indices[level] = array_indices[i];
-      Populate(src_data, indices, level + 1, i, src_data_ptr);
+      Populate(src_data, indices, level + 1, i, src_data_ptr, dest_data);
     }
   }
 }
@@ -312,7 +310,32 @@
   int total_rank = traversal_order_.size();
   int src_data_ptr = 0;
   std::vector<int> indices(total_rank);
-  Populate(src_data, indices, 0, 0, &src_data_ptr);
+  Populate(src_data, indices, 0, 0, &src_data_ptr, data_.data());
+
+  return kTfLiteOk;
+}
+
+template <typename T>
+TfLiteStatus FormatConverter<T>::SparseToDense(const T* src_data,
+                                               const size_t dest_size,
+                                               T* dest_data,
+                                               TfLiteContext* context) {
+  if (dest_size != dense_size_) {
+    TF_LITE_MAYBE_KERNEL_LOG(
+        context, "unexpected buffer size for densified data, expected %lld.\n",
+        dense_size_);
+    return kTfLiteError;
+  }
+
+  // For types like Eigen::half, we cannot do a simple memset() with 0 values.
+  for (auto i = 0; i < dest_size; i++) {
+    dest_data[i] = T(0);
+  }
+
+  const int total_rank = traversal_order_.size();
+  int src_data_ptr = 0;
+  std::vector<int> indices(total_rank);
+  Populate(src_data, indices, 0, 0, &src_data_ptr, dest_data);
 
   return kTfLiteOk;
 }
diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter.h b/tensorflow/lite/tools/optimize/sparsity/format_converter.h
index 46e7d93..1ac324c 100644
--- a/tensorflow/lite/tools/optimize/sparsity/format_converter.h
+++ b/tensorflow/lite/tools/optimize/sparsity/format_converter.h
@@ -15,7 +15,6 @@
 #ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_SPARSITY_FORMAT_CONVERTER_H_
 #define TENSORFLOW_LITE_TOOLS_OPTIMIZE_SPARSITY_FORMAT_CONVERTER_H_
 
-#include <memory>
 #include <vector>
 
 #include "third_party/eigen3/Eigen/Core"
@@ -54,18 +53,28 @@
   FormatConverter(const std::vector<int>& shape,
                   const TfLiteSparsity& sparsity);
 
-  std::vector<T> GetData() { return data_; }
-  std::vector<std::vector<int>> GetDimMetadata() { return dim_metadata_; }
+  const std::vector<T>& GetData() { return data_; }
+  const std::vector<std::vector<int>>& GetDimMetadata() {
+    return dim_metadata_;
+  }
 
+  // Method for dense-to-sparse conversion. Call GetData() afterwards to get
+  // the compressed data.
   TfLiteStatus DenseToSparse(const T* src_data);
 
+  // Method for sparse-to-dense conversion. Call GetData() afterwards to get
+  // the decompressed data.
   TfLiteStatus SparseToDense(const T* src_data);
+  // Method for sparse-to-dense conversion into a caller-provided buffer. There
+  // is no need to call GetData() with this method.
+  TfLiteStatus SparseToDense(const T* src_data, const size_t dest_size,
+                             T* dest_data, TfLiteContext* context = nullptr);
 
  private:
   // A recursive function to fetch data from the compressed src_data buffer and
   // populate the dense buffer.
   void Populate(const T* src_data, std::vector<int> indices, int level,
-                int prev_idx, int* src_data_ptr);
+                int prev_idx, int* src_data_ptr, T* dest_data);
 
   // Check if val is equal to zero.
   bool IsZero(const T val);
@@ -76,7 +85,7 @@
   // tensor with (2, 2) block has blocked_shape (2, 2).
   std::vector<int> blocked_shape_;
   // Total number of elements in the dense tensor.
-  uint64_t dense_size_;
+  size_t dense_size_;
   // Has n(original dimension)+k(block_dimension) elements.
   std::vector<int> traversal_order_;
   // Format of each dimension in the traversal order.
diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc b/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc
index 96919d2..ddf3477 100644
--- a/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc
+++ b/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc
@@ -31,19 +31,24 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0 = {3};
   const std::vector<int> dm1 = {4};
   EXPECT_EQ(dm0, dim_metadata[0]);
   EXPECT_EQ(dm1, dim_metadata[2]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestS0D1) {
@@ -55,7 +60,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 2};
   const std::vector<int> dm0_1 = {0, 2};
   const std::vector<int> dm1 = {4};
@@ -63,13 +68,18 @@
   EXPECT_EQ(dm0_1, dim_metadata[1]);
   EXPECT_EQ(dm1, dim_metadata[2]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 0, 9, 8, 5, 0, 0, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestD0S1) {
@@ -81,7 +91,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0 = {3};
   const std::vector<int> dm1_0 = {0, 3, 3, 5};
   const std::vector<int> dm1_1 = {0, 2, 3, 0, 3};
@@ -89,13 +99,18 @@
   EXPECT_EQ(dm1_0, dim_metadata[2]);
   EXPECT_EQ(dm1_1, dim_metadata[3]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 9, 8, 5, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestS0S1) {
@@ -107,7 +122,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 2};
   const std::vector<int> dm0_1 = {0, 2};
   const std::vector<int> dm1_0 = {0, 3, 5};
@@ -117,13 +132,18 @@
   EXPECT_EQ(dm1_0, dim_metadata[2]);
   EXPECT_EQ(dm1_1, dim_metadata[3]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 9, 8, 5, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestD1D0) {
@@ -135,19 +155,24 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0 = {4};
   const std::vector<int> dm1 = {3};
   EXPECT_EQ(dm0, dim_metadata[0]);
   EXPECT_EQ(dm1, dim_metadata[2]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 0, 5, 0, 0, 0, 9, 0, 0, 8, 0, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestS1D0) {
@@ -159,7 +184,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 3};
   const std::vector<int> dm0_1 = {0, 2, 3};
   const std::vector<int> dm1 = {3};
@@ -167,13 +192,18 @@
   EXPECT_EQ(dm0_1, dim_metadata[1]);
   EXPECT_EQ(dm1, dim_metadata[2]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 0, 5, 9, 0, 0, 8, 0, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestD1S0) {
@@ -185,7 +215,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0 = {4};
   const std::vector<int> dm1_0 = {0, 2, 2, 3, 5};
   const std::vector<int> dm1_1 = {0, 2, 0, 0, 2};
@@ -193,13 +223,18 @@
   EXPECT_EQ(dm1_0, dim_metadata[2]);
   EXPECT_EQ(dm1_1, dim_metadata[3]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 5, 9, 8, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, SimpleTestS1S0) {
@@ -211,7 +246,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 3};
   const std::vector<int> dm0_1 = {0, 2, 3};
   const std::vector<int> dm1_0 = {0, 2, 3, 5};
@@ -221,13 +256,18 @@
   EXPECT_EQ(dm1_0, dim_metadata[2]);
   EXPECT_EQ(dm1_1, dim_metadata[3]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 5, 9, 8, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, 3DTestS0D1S2) {
@@ -239,7 +279,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 2};
   const std::vector<int> dm0_1 = {0, 2};
   const std::vector<int> dm1 = {2};
@@ -252,13 +292,18 @@
   EXPECT_EQ(dm2_0, dim_metadata[4]);
   EXPECT_EQ(dm2_1, dim_metadata[5]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 9, 8, 5, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, 3DTestD0D1S2) {
@@ -270,7 +315,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0 = {3};
   const std::vector<int> dm1 = {2};
   const std::vector<int> dm2_0 = {0, 1, 3, 3, 3, 4, 5};
@@ -281,13 +326,18 @@
   EXPECT_EQ(dm2_0, dim_metadata[4]);
   EXPECT_EQ(dm2_1, dim_metadata[5]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {6, 9, 8, 5, 7};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, 3DTestS0S1S2) {
@@ -300,7 +350,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 2};
   const std::vector<int> dm0_1 = {0, 2};
   const std::vector<int> dm1_0 = {0, 2, 5};
@@ -314,13 +364,18 @@
   EXPECT_EQ(dm2_0, dim_metadata[4]);
   EXPECT_EQ(dm2_1, dim_metadata[5]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 7, 5, 2, 4, 8, 3, 9};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, 3DTestS0S2S1) {
@@ -333,7 +388,7 @@
   FormatConverter<int> converter(dense_shape, traversal_order, format);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0_0 = {0, 2};
   const std::vector<int> dm0_1 = {0, 2};
   const std::vector<int> dm1_0 = {0, 2, 5};
@@ -347,13 +402,18 @@
   EXPECT_EQ(dm2_0, dim_metadata[4]);
   EXPECT_EQ(dm2_1, dim_metadata[5]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 7, 5, 2, 4, 8, 3, 9};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, BlockTestD0D1) {
@@ -369,21 +429,26 @@
                                  block_size, block_map);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm = {2};
   EXPECT_EQ(dm, dim_metadata[0]);
   EXPECT_EQ(dm, dim_metadata[2]);
   EXPECT_EQ(dm, dim_metadata[4]);
   EXPECT_EQ(dm, dim_metadata[6]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 0, 0, 4, 2, 3, 0, 0,
                                           0, 0, 0, 0, 5, 0, 0, 6};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 // BCSR
@@ -400,7 +465,7 @@
                                  block_size, block_map);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm0 = {4};
   const std::vector<int> dm2 = {2};
   const std::vector<int> dm1_0 = {0, 2, 3, 4, 5};
@@ -410,13 +475,18 @@
   EXPECT_EQ(dm1_1, dim_metadata[3]);
   EXPECT_EQ(dm2, dim_metadata[4]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 0, 2, 3, 0, 4, 5, 0, 0, 6};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 // BCSR
@@ -433,7 +503,7 @@
                                  block_size, block_map);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm = {2};
   const std::vector<int> dm1_0 = {0, 2, 3};
   const std::vector<int> dm1_1 = {0, 1, 1};
@@ -443,13 +513,18 @@
   EXPECT_EQ(dm, dim_metadata[4]);
   EXPECT_EQ(dm, dim_metadata[6]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 // BCSC
@@ -466,7 +541,7 @@
                                  block_size, block_map);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm = {2};
   const std::vector<int> dm1_0 = {0, 1, 3};
   const std::vector<int> dm1_1 = {0, 0, 1};
@@ -476,13 +551,18 @@
   EXPECT_EQ(dm, dim_metadata[4]);
   EXPECT_EQ(dm, dim_metadata[6]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 0, 0, 4, 2, 0, 3, 0, 5, 0, 0, 6};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 // BCSR with last block being empty
@@ -499,7 +579,7 @@
                                  block_size, block_map);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm = {2};
   const std::vector<int> dm1_0 = {0, 2, 2};
   const std::vector<int> dm1_1 = {0, 1};
@@ -509,13 +589,18 @@
   EXPECT_EQ(dm, dim_metadata[4]);
   EXPECT_EQ(dm, dim_metadata[6]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 0, 0, 4, 2, 3, 0, 0};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 
 TEST(FormatConverterTest, BlockTestD0S1ColMajorBlock) {
@@ -532,7 +617,7 @@
                                  block_size, block_map);
   converter.DenseToSparse(dense_values.data());
 
-  const auto dim_metadata = converter.GetDimMetadata();
+  const auto& dim_metadata = converter.GetDimMetadata();
   const std::vector<int> dm = {2};
   const std::vector<int> dm1_0 = {0, 3, 4};
   const std::vector<int> dm1_1 = {0, 1, 2, 1};
@@ -542,14 +627,19 @@
   EXPECT_EQ(dm, dim_metadata[4]);
   EXPECT_EQ(dm, dim_metadata[6]);
 
-  const auto data = converter.GetData();
+  const auto& data = converter.GetData();
   const std::vector<int> expected_data = {1, 1, 0, 0, 2, 2, 3, 3,
                                           0, 0, 4, 4, 5, 0, 0, 0};
   EXPECT_EQ(expected_data, data);
 
   converter.SparseToDense(expected_data.data());
-  const auto data_back = converter.GetData();
+  const auto& data_back = converter.GetData();
   EXPECT_EQ(data_back, dense_values);
+
+  std::vector<int> dense_data(dense_values.size());
+  converter.SparseToDense(expected_data.data(), dense_data.size(),
+                          dense_data.data(), nullptr);
+  EXPECT_EQ(dense_data, dense_values);
 }
 }  // namespace
 }  // namespace sparsity
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc
index 5565fc4..433a10b 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/lite/tools/optimize/test_util.cc
@@ -73,6 +73,7 @@
 const char* kSvdfQuantized = "svdf_quantized.bin";
 
 const char* kModelWithUnpack = "unpack.bin";
+const char* kQatModelWithFc = "fc_qat.bin";
 
 int FailOnErrorReporter::Report(const char* format, va_list args) {
   char buf[1024];
diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h
index 4341a67..6229729 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/lite/tools/optimize/test_util.h
@@ -116,6 +116,9 @@
 // Test model with an unpack op.
 extern const char* kModelWithUnpack;
 
+// Test QAT model with an FC op.
+extern const char* kQatModelWithFc;
+
 // An error reporter that fails on testing.
 class FailOnErrorReporter : public ErrorReporter {
  public:
diff --git a/tensorflow/lite/tools/optimize/testdata/fc_qat.bin b/tensorflow/lite/tools/optimize/testdata/fc_qat.bin
new file mode 100644
index 0000000..f121f7e
--- /dev/null
+++ b/tensorflow/lite/tools/optimize/testdata/fc_qat.bin
Binary files differ
diff --git a/tensorflow/lite/tools/serialization/BUILD b/tensorflow/lite/tools/serialization/BUILD
index ceb11e2..5472dbe 100644
--- a/tensorflow/lite/tools/serialization/BUILD
+++ b/tensorflow/lite/tools/serialization/BUILD
@@ -35,6 +35,7 @@
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs_with_reflection",
         "//tensorflow/lite/schema:schema_utils",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
 
@@ -67,6 +68,7 @@
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/kernels:subgraph_test_util",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/testing:util",
         "@com_google_googletest//:gtest",
diff --git a/tensorflow/lite/tools/serialization/README.md b/tensorflow/lite/tools/serialization/README.md
new file mode 100644
index 0000000..bd6c91e
--- /dev/null
+++ b/tensorflow/lite/tools/serialization/README.md
@@ -0,0 +1,63 @@
+# TFLite Serialization Tool
+
+**NOTE:** This tool is intended for advanced users only, and should be used with
+care.
+
+The (C++) serialization library generates and writes a TFLite flatbuffer given
+an `Interpreter` or `Subgraph`. Example use-cases include authoring models with
+the `Interpreter` API, or updating models on-device (by modifying `tensor.data`
+for relevant tensors).
+
+## Serialization
+
+### Writing flatbuffer to file
+
+To write a TFLite model from an `Interpreter` (see `lite/interpreter.h`):
+
+```
+std::unique_ptr<tflite::Interpreter> interpreter;
+// ...build/modify interpreter...
+tflite::ModelWriter writer(interpreter.get());
+std::string filename = "/tmp/model.tflite";
+writer.Write(filename);
+```
+
+Note that the above API does not support custom I/O tensors or custom ops yet.
+However, it does support models with control flow.
+
+To generate/write a flatbuffer for a particular `Subgraph` (see
+`lite/core/subgraph.h`) you can use `SubgraphWriter`.
+
+```
+std::unique_ptr<tflite::Interpreter> interpreter;
+// ...build/modify interpreter...
+// The number of subgraphs can be obtained by:
+// const int num_subgraphs = interpreter->subgraphs_size();
+// Note that 0 <= subgraph_index < num_subgraphs.
+tflite::SubgraphWriter writer(interpreter->subgraph(subgraph_index));
+std::string filename = "/tmp/model.tflite";
+writer.Write(filename);
+```
+
+`SubgraphWriter` supports custom ops and/or custom I/O tensors.
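+
+For example, a minimal sketch of serializing only part of a subgraph (using
+`SetCustomInputOutput` and `SetUnusedTensors` from `writer_lib.h`; the tensor
+and execution-plan indices below are placeholders for a concrete model):
+
+```
+tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
+// Only serialize the ops in the given execution plan, with custom I/O tensors.
+writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
+                            /*execution_plan=*/{1});
+writer.SetUnusedTensors({0, 1});
+writer.Write("/tmp/partial_model.tflite");
+```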
+
+### Generating flatbuffer in-memory
+
+Both `ModelWriter` and `SubgraphWriter` support a `GetBuffer` method to return
+the generated flatbuffer in-memory:
+
+```
+std::unique_ptr<uint8_t[]> output_buffer;
+size_t output_buffer_size;
+tflite::ModelWriter writer(interpreter.get());
+writer.GetBuffer(&output_buffer, &output_buffer_size);
+```
+
+## De-serialization
+
+The flatbuffers written as above can be de-serialized just like any other
+TFLite model. For example:
+
+```
+std::unique_ptr<FlatBufferModel> model =
+    FlatBufferModel::BuildFromFile(filename);
+tflite::ops::builtin::BuiltinOpResolver resolver;
+InterpreterBuilder builder(*model, resolver);
+std::unique_ptr<Interpreter> new_interpreter;
+builder(&new_interpreter);
+```
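+
+Once built, the deserialized interpreter can be used like any other. A minimal
+sketch, assuming the model's inputs have already been resized as needed:
+
+```
+new_interpreter->AllocateTensors();
+// ...fill inputs, e.g. via new_interpreter->typed_input_tensor<float>(0)...
+new_interpreter->Invoke();
+// ...read outputs, e.g. via new_interpreter->typed_output_tensor<float>(0)...
+```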
diff --git a/tensorflow/lite/tools/serialization/writer.cc b/tensorflow/lite/tools/serialization/writer.cc
index fb81679..e52114b 100644
--- a/tensorflow/lite/tools/serialization/writer.cc
+++ b/tensorflow/lite/tools/serialization/writer.cc
@@ -34,7 +34,7 @@
   std::unique_ptr<tflite::Interpreter> interpreter;
   tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
   tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
-  tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
+  tflite::ModelWriter writer(interpreter.get());
   writer.Write(argv[2]);
 
   return 0;
diff --git a/tensorflow/lite/tools/serialization/writer_lib.cc b/tensorflow/lite/tools/serialization/writer_lib.cc
index 0d831f5..7270da5 100644
--- a/tensorflow/lite/tools/serialization/writer_lib.cc
+++ b/tensorflow/lite/tools/serialization/writer_lib.cc
@@ -29,6 +29,41 @@
 #include "tensorflow/lite/version.h"
 
 namespace tflite {
+namespace {
+
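+// Helpers shared by SubgraphWriter and ModelWriter so both serialize opcode
+// tables, buffers, and files the same way.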
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
+CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
+                      std::vector<OpCode>* opcodes) {
+  std::vector<flatbuffers::Offset<OperatorCode>> codes;
+  for (const auto& it : *opcodes) {
+    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+    codes.push_back(CreateOperatorCodeDirect(
+        *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+  }
+  return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
+}
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
+                  std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
+  std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
+  for (auto buffer : *buffers) {
+    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+  }
+  return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
+}
+
+TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
+  FILE* fp = fopen(filename.c_str(), "wb");
+  if (!fp) return kTfLiteError;
+
+  const int result_size = fwrite(data, 1, size, fp);
+  fclose(fp);
+  if (result_size != size) return kTfLiteError;
+
+  return kTfLiteOk;
+}
 
 std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
     flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
@@ -39,6 +74,8 @@
   return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
 }
 
+}  // namespace
+
 template <class T_OUTPUT, class T_INPUT>
 flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
     flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
@@ -159,8 +196,8 @@
       // Allocate a buffer index
       int buffer_index = 0;  // This is null
       if (tensor->allocation_type == kTfLiteMmapRo) {
-        buffer_index = buffers_.size();
-        buffers_.push_back(std::make_pair(
+        buffer_index = buffers_->size();
+        buffers_->push_back(std::make_pair(
             reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
       }
       // Primitive type.
@@ -214,23 +251,12 @@
 
 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
 SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
-  std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
-  for (auto buffer : buffers_) {
-    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
-    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
-  }
-  return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
+  return ExportBuffersImpl(fbb, buffers_);
 }
 
 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
 SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
-  std::vector<flatbuffers::Offset<OperatorCode>> codes;
-  for (const auto& it : opcodes_) {
-    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
-    codes.push_back(CreateOperatorCodeDirect(
-        *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
-  }
-  return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
+  return CreateOpCodeTableImpl(fbb, opcodes_);
 }
 
 template <class T>
@@ -254,19 +280,9 @@
                                        size_t* size) {
   if (!out || !size) return kTfLiteError;
   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
-
   std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
-  {  // subgraph specific stuff
-    auto tensors = ExportTensors(&builder);
-    std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
-    std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
-    auto inputs = ExportVector<int32_t>(&builder, written_inputs);
-    auto outputs = ExportVector<int32_t>(&builder, written_outputs);
+  subgraphs_as_vector.push_back(PopulateAndGetOffset(&builder));
 
-    auto ops = ExportOperators(&builder);
-    subgraphs_as_vector.push_back(
-        CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
-  }
   flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
       buffers = ExportBuffers(&builder);
 
@@ -284,21 +300,23 @@
   return kTfLiteOk;
 }
 
+flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
+    flatbuffers::FlatBufferBuilder* builder) {
+  auto tensors = ExportTensors(builder);
+  std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
+  std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
+  auto inputs = ExportVector<int32_t>(builder, written_inputs);
+  auto outputs = ExportVector<int32_t>(builder, written_outputs);
+
+  auto ops = ExportOperators(builder);
+  return CreateSubGraph(*builder, tensors, inputs, outputs, ops, /* name */ 0);
+}
+
 TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
   std::unique_ptr<uint8_t[]> buffer;
   size_t size;
   TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
-
-  FILE* fp = fopen(filename.c_str(), "wb");
-  if (!fp) return kTfLiteError;
-
-  if (fwrite(buffer.get(), 1, size, fp) != size) {
-    fclose(fp);
-    return kTfLiteError;
-  }
-  if (fclose(fp)) return kTfLiteError;
-
-  return kTfLiteOk;
+  return WriteImpl(filename, buffer.get(), size);
 }
 
 TfLiteStatus SubgraphWriter::RegisterCustomWriter(
@@ -377,4 +395,50 @@
   return kTfLiteOk;
 }
 
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
+  return ExportBuffersImpl(fbb, &buffers_);
+}
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
+ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
+  return CreateOpCodeTableImpl(fbb, &opcodes_);
+}
+
+TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
+                                    size_t* size) {
+  if (!out || !size) return kTfLiteError;
+  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
+
+  std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
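+  // Serialize every subgraph with a SubgraphWriter that shares this
+  // ModelWriter's buffer and opcode tables, so the written model has a single
+  // global set of buffers and operator codes.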
+  for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
+    SubgraphWriter writer(interpreter_->subgraph(i), &buffers_, &opcodes_,
+                          &builtin_op_to_opcode_);
+    subgraphs_as_vector.push_back(writer.PopulateAndGetOffset(&builder));
+  }
+
+  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+      buffers = ExportBuffers(&builder);
+
+  auto description = builder.CreateString("Exported from Subgraph.");
+
+  auto op_codes = CreateOpCodeTable(&builder);
+  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+                           builder.CreateVector(subgraphs_as_vector),
+                           description, buffers);
+  ::tflite::FinishModelBuffer(builder, model);
+  const uint8_t* buffer = builder.GetBufferPointer();
+  *size = builder.GetSize();
+  (*out).reset(new uint8_t[*size]);
+  memcpy(out->get(), buffer, *size);
+  return kTfLiteOk;
+}
+
+TfLiteStatus ModelWriter::Write(const std::string& filename) {
+  std::unique_ptr<uint8_t[]> buffer;
+  size_t size;
+  TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+  return WriteImpl(filename, buffer.get(), size);
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/serialization/writer_lib.h b/tensorflow/lite/tools/serialization/writer_lib.h
index a18a3dd..3119278 100644
--- a/tensorflow/lite/tools/serialization/writer_lib.h
+++ b/tensorflow/lite/tools/serialization/writer_lib.h
@@ -12,37 +12,72 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// Writes a flatbuffer of a currently loaded TensorFlow Lite subgraph.
-//
-// Usage:
-//  From command line:
-//   bazel run third_party/tensorflow/lite/experimental/writer:writer
-//     -- foo.tflite foo.out.tflite
-//
-// From C++
-//   std::unique_ptr<Interpreter> interpreter;
-//   // Build Interpreter however
-//   // ... <omitted>
-//   SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite");
+// Library to write a flatbuffer of a currently loaded TFLite model/subgraph.
+
 #ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
 #define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
 #include <iostream>
 #include <unordered_map>
 
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/lite/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/schema/reflection/schema_generated.h"
 #include "tensorflow/lite/tools/serialization/enum_mapping.h"
 #include "tensorflow/lite/version.h"
 
 namespace tflite {
 
+struct OpCode {
+  int builtin;
+  std::string custom;
+};
+
+// Handles writing a full TFLite model (with one or more subgraphs) to the
+// serialized TFLite file format.
+// TODO(b/174708523): Support custom I/O or unused tensors later.
+class ModelWriter {
+ public:
+  // Construct a writer for the specified `interpreter`. Then, use
+  // .Write() or .GetBuffer(...) to extract the data.
+  explicit ModelWriter(Interpreter* interpreter) : interpreter_(interpreter) {
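+    // Buffer 0 is always the empty (null) buffer in a TFLite flatbuffer.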
+    buffers_.push_back(std::make_pair(nullptr, 0));
+  }
+
+  // Get a buffer and size of a serialized flatbuffer.
+  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+  // Write the serialized flatbuffer to the prescribed `filename`.
+  TfLiteStatus Write(const std::string& filename);
+
+ private:
+  template <class T>
+  using Offset = flatbuffers::Offset<T>;
+  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+      flatbuffers::FlatBufferBuilder* fbb);
+  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+      flatbuffers::FlatBufferBuilder* fbb);
+
+  // ModelWriter does not take ownership of this object.
+  Interpreter* const interpreter_;
+
+  // This data corresponds to the overall model (rather than individual
+  // subgraphs), so we define the common fields here.
+  // Keeps track of byte buffers.
+  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+  // List of used opcodes
+  std::vector<OpCode> opcodes_;
+  absl::flat_hash_map<int, int> builtin_op_to_opcode_;
+};
+
 // Handles writing TensorFlow Lite running subgraph to a serialized TF lite
 // file format.
+// TODO(b/174708523): Reconcile into ModelWriter?
 class SubgraphWriter {
  public:
+  friend class ModelWriter;
+
   typedef flatbuffers::Offset<Operator> (*CustomWriter)(
       flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
@@ -55,7 +90,10 @@
         inputs_(subgraph->inputs()),
         outputs_(subgraph->outputs()),
         execution_plan_(subgraph->execution_plan()) {
-    buffers_.push_back(std::make_pair(nullptr, 0));
+    buffers_ = &buffers_data_;
+    opcodes_ = &opcodes_data_;
+    builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
+    buffers_->push_back(std::make_pair(nullptr, 0));
   }
 
   // Get a buffer and size of a serialized flatbuffer.
@@ -77,6 +115,28 @@
                                     const std::vector<int>& execution_plan);
 
  private:
+  // Used by ModelWriter.
+  explicit SubgraphWriter(
+      Subgraph* subgraph,
+      std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
+      std::vector<OpCode>* external_opcodes,
+      absl::flat_hash_map<int, int>* external_builtin_op_to_opcode)
+      : subgraph_(subgraph),
+        inputs_(subgraph->inputs()),
+        outputs_(subgraph->outputs()),
+        execution_plan_(subgraph->execution_plan()) {
+    buffers_ = external_buffers;
+    opcodes_ = external_opcodes;
+    builtin_op_to_opcode_ = external_builtin_op_to_opcode;
+    buffers_->push_back(std::make_pair(nullptr, 0));
+  }
+
+  // Used by ModelWriter to populate data specific to this subgraph.
+  // Global data (like opcodes & buffers) is accumulated into buffers_,
+  // opcodes_, etc. and written into the flatbuffer by ModelWriter.
+  flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
+      flatbuffers::FlatBufferBuilder* builder);
+
   template <class T>
   using Offset = flatbuffers::Offset<T>;
   template <class T_OUTPUT, class T_INPUT>
@@ -102,11 +162,11 @@
 
   int GetOpCodeForBuiltin(int builtin_op_index) {
     // auto it = builtin_op_to_opcode_.find(builtin_op_index);
-    std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
-        builtin_op_to_opcode_.insert(
-            std::make_pair(builtin_op_index, opcodes_.size()));
+    std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result =
+        builtin_op_to_opcode_->insert(
+            std::make_pair(builtin_op_index, opcodes_->size()));
     if (result.second) {
-      opcodes_.push_back({builtin_op_index, ""});
+      opcodes_->push_back({builtin_op_index, ""});
     }
     return result.first->second;
   }
@@ -114,9 +174,9 @@
   int GetOpCodeForCustom(const std::string& custom_name) {
     std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
         custom_op_to_opcode_.insert(
-            std::make_pair(custom_name, opcodes_.size()));
+            std::make_pair(custom_name, opcodes_->size()));
     if (result.second) {
-      opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+      opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
     }
     return result.first->second;
   }
@@ -129,22 +189,26 @@
   std::vector<int> outputs_;
   // Order of nodes to be written.
   std::vector<int> execution_plan_;
-  // Keep track of byte buffers
-  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
   // List of op codes and mappings from builtin or custom op to opcode
-  struct OpCode {
-    int builtin;
-    std::string custom;
-  };
   std::set<int> unused_tensors_;
   // For every tensor index in the subgraph, the index in the written.
   // This is different due to temporary and unused tensors not being written.
   std::vector<int> tensor_to_written_tensor_;
-  // List of used opcodes
-  std::vector<OpCode> opcodes_;
-  std::unordered_map<int, int> builtin_op_to_opcode_;
   std::unordered_map<std::string, int> custom_op_to_opcode_;
   std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+
+  // We use pointers for these, since they may be provided by ModelWriter.
+  // Keep track of byte buffers
+  std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
+  // List of used opcodes
+  std::vector<OpCode>* opcodes_;
+  absl::flat_hash_map<int, int>* builtin_op_to_opcode_;
+
+  // These are used if SubgraphWriter is being used directly.
+  std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
+  // List of used opcodes
+  std::vector<OpCode> opcodes_data_;
+  absl::flat_hash_map<int, int> builtin_op_to_opcode_data_;
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/serialization/writer_lib_test.cc b/tensorflow/lite/tools/serialization/writer_lib_test.cc
index 189b4bc..3f73f3c 100644
--- a/tensorflow/lite/tools/serialization/writer_lib_test.cc
+++ b/tensorflow/lite/tools/serialization/writer_lib_test.cc
@@ -15,21 +15,47 @@
 
 #include "tensorflow/lite/tools/serialization/writer_lib.h"
 
+#include <cstdlib>
 #include <numeric>
 #include <sstream>
+#include <string>
+#include <tuple>
 
 #include <gtest/gtest.h>
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/subgraph_test_util.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/testing/util.h"
 
 namespace tflite {
-// Make an interpreter that has no tensors and no nodes
-// TODO(b/113731921): add more tests.
-TEST(Writer, FloatModelTest) {
+
+using subgraph_test_util::CheckIntTensor;
+using subgraph_test_util::FillIntTensor;
+
+std::string CreateFilePath(const std::string& file_name) {
+  return std::string(getenv("TEST_TMPDIR")) + file_name;
+}
+
+// The bool param indicates whether we use SubgraphWriter (true) or
+// ModelWriter (false) for the test.
+class SingleSubgraphTest : public ::testing::TestWithParam<bool> {
+ protected:
+  void WriteToFile(Interpreter* interpreter, const std::string& filename,
+                   bool use_subgraph_writer) {
+    if (use_subgraph_writer) {
+      SubgraphWriter writer(&interpreter->primary_subgraph());
+      CHECK_EQ(writer.Write(filename), kTfLiteOk);
+    } else {
+      ModelWriter writer(interpreter);
+      CHECK_EQ(writer.Write(filename), kTfLiteOk);
+    }
+  }
+};
+
+TEST_P(SingleSubgraphTest, InvalidDestinations) {
   Interpreter interpreter;
   interpreter.AddTensors(3);
   float foo[] = {1, 2, 3};
@@ -52,10 +78,53 @@
   interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                     reinterpret_cast<void*>(builtin_data), reg);
 
-  SubgraphWriter writer(&interpreter.primary_subgraph());
-  writer.Write("/tmp/test_float.tflite");
+  // Check if invalid filename is handled gracefully.
+  if (GetParam()) {
+    SubgraphWriter writer(&interpreter.primary_subgraph());
+    CHECK_EQ(writer.Write(""), kTfLiteError);
+  } else {
+    ModelWriter writer(&interpreter);
+    CHECK_EQ(writer.Write(""), kTfLiteError);
+  }
+
+  // Check if invalid buffer is handled gracefully.
+  size_t size;
+  if (GetParam()) {
+    SubgraphWriter writer(&interpreter.primary_subgraph());
+    CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
+  } else {
+    ModelWriter writer(&interpreter);
+    CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
+  }
+}
+
+TEST_P(SingleSubgraphTest, FloatModelTest) {
+  Interpreter interpreter;
+  interpreter.AddTensors(3);
+  float foo[] = {1, 2, 3};
+  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+                                           TfLiteQuantization());
+  interpreter.SetTensorParametersReadOnly(
+      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
+      reinterpret_cast<char*>(foo), sizeof(foo));
+  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+                                           TfLiteQuantization());
+  interpreter.SetInputs({0, 1});
+  interpreter.SetOutputs({2});
+  const char* initial_data = "";
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  TfLiteAddParams* builtin_data =
+      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+  builtin_data->activation = kTfLiteActNone;
+  builtin_data->pot_scale_int16 = false;
+  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+                                    reinterpret_cast<void*>(builtin_data), reg);
+
+  const std::string test_file = CreateFilePath("test_float.tflite");
+  WriteToFile(&interpreter, test_file, GetParam());
   std::unique_ptr<FlatBufferModel> model =
-      FlatBufferModel::BuildFromFile("/tmp/test_float.tflite");
+      FlatBufferModel::BuildFromFile(test_file.c_str());
   InterpreterBuilder builder(*model, resolver);
   std::unique_ptr<Interpreter> new_interpreter;
   builder(&new_interpreter);
@@ -63,7 +132,7 @@
 }
 
 // Tests writing only a portion of the subgraph.
-TEST(Writer, CustomInputOutputTest) {
+TEST_P(SingleSubgraphTest, CustomInputOutputTest) {
   Interpreter interpreter;
   interpreter.AddTensors(4);
   constexpr float kFoo[] = {1, 2, 3};
@@ -94,22 +163,23 @@
   interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
 
   // Only write the second op.
+  const std::string test_file = CreateFilePath("test_custom.tflite");
   SubgraphWriter writer(&interpreter.primary_subgraph());
   EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                                         /*execution_plan=*/{1}),
             kTfLiteOk);
   writer.SetUnusedTensors({0, 1});
-  writer.Write("/tmp/test_custom.tflite");
+  writer.Write(test_file);
 
   std::unique_ptr<FlatBufferModel> model =
-      FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite");
+      FlatBufferModel::BuildFromFile(test_file.c_str());
   InterpreterBuilder builder(*model, resolver);
   std::unique_ptr<Interpreter> new_interpreter;
   builder(&new_interpreter);
   ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
 }
 
-TEST(Writer, CustomInputOutputErrorCasesTest) {
+TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
   Interpreter interpreter;
   interpreter.AddTensors(5);
   constexpr float kFoo[] = {1, 2, 3};
@@ -160,7 +230,7 @@
             kTfLiteOk);
 }
 
-TEST(Writer, PerTensorQuantizedModelTest) {
+TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
   Interpreter interpreter;
   interpreter.AddTensors(3);
   interpreter.SetTensorParametersReadWrite(
@@ -181,16 +251,18 @@
   interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                     reinterpret_cast<void*>(builtin_data), reg);
 
-  SubgraphWriter writer(&interpreter.primary_subgraph());
-  writer.Write("/tmp/test_uint8.tflite");
+  const std::string test_file = CreateFilePath("test_uint8.tflite");
+  WriteToFile(&interpreter, test_file, GetParam());
   std::unique_ptr<FlatBufferModel> model =
-      FlatBufferModel::BuildFromFile("/tmp/test_uint8.tflite");
+      FlatBufferModel::BuildFromFile(test_file.c_str());
   InterpreterBuilder builder(*model, resolver);
   std::unique_ptr<Interpreter> new_interpreter;
   builder(&new_interpreter);
   CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
 }
 
+INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
+
 struct ReshapeTestPattern {
   int num_inputs;
   bool is_param_valid;
@@ -241,8 +313,8 @@
 
   SubgraphWriter writer(&interpreter.primary_subgraph());
   std::stringstream ss;
-  ss << "/tmp/test_reshape_" << param.num_inputs << param.is_param_valid
-     << ".tflite";
+  ss << CreateFilePath("test_reshape_") << param.num_inputs
+     << param.is_param_valid << ".tflite";
   std::string filename = ss.str();
   writer.Write(filename);
   std::unique_ptr<FlatBufferModel> model =
@@ -268,6 +340,57 @@
       std::string name = ss.str();
       return name;
     });
+
+class WhileTest : public subgraph_test_util::ControlFlowOpTest {};
+
+// The test builds a model that produces the i-th number of the triangular
+// number sequence: 1, 3, 6, 10, 15, 21, 28.
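+// After checking the results, the model is serialized with ModelWriter and the
+// deserialized flatbuffer is run again to verify the same outputs.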
+TEST_F(WhileTest, TestTriangularNumberSequence) {
+  const int kSeqNumber = 4;
+  const int kExpectedValue = 15;
+
+  interpreter_.reset(new Interpreter);
+  interpreter_->AddSubgraphs(2);
+  builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
+  builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
+  builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
+
+  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
+  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
+  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
+  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1});
+
+  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+  TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
+  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
+  TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
+  CheckIntTensor(output2, {1}, {kExpectedValue});
+
+  // Now serialize & deserialize the model into a new Interpreter.
+  ModelWriter writer(interpreter_.get());
+  const std::string test_file = CreateFilePath("test_while.tflite");
+  writer.Write(test_file);
+  std::unique_ptr<FlatBufferModel> model =
+      FlatBufferModel::BuildFromFile(test_file.c_str());
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  InterpreterBuilder builder(*model, resolver);
+  std::unique_ptr<Interpreter> new_interpreter;
+  builder(&new_interpreter);
+
+  // Check deserialized model.
+  new_interpreter->ResizeInputTensor(interpreter_->inputs()[0], {1});
+  new_interpreter->ResizeInputTensor(interpreter_->inputs()[1], {1});
+  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
+  FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[0]), {1});
+  FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[1]), {1});
+  ASSERT_EQ(new_interpreter->Invoke(), kTfLiteOk);
+  output1 = new_interpreter->tensor(interpreter_->outputs()[0]);
+  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
+  output2 = new_interpreter->tensor(interpreter_->outputs()[1]);
+  CheckIntTensor(output2, {1}, {kExpectedValue});
+}
+
 }  // namespace tflite
 
 int main(int argc, char** argv) {
diff --git a/tensorflow/lite/tools/serialization/writer_test.cc b/tensorflow/lite/tools/serialization/writer_test.cc
index ccaab76..2ad77df 100644
--- a/tensorflow/lite/tools/serialization/writer_test.cc
+++ b/tensorflow/lite/tools/serialization/writer_test.cc
@@ -35,7 +35,7 @@
   std::unique_ptr<tflite::Interpreter> interpreter;
   tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
   tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
-  tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
+  tflite::ModelWriter writer(interpreter.get());
   std::unique_ptr<uint8_t[]> output_buffer;
   size_t output_buffer_size;
   writer.GetBuffer(&output_buffer, &output_buffer_size);
diff --git a/tensorflow/lite/tools/test_utils.py b/tensorflow/lite/tools/test_utils.py
index 9840203..f2d0e29 100644
--- a/tensorflow/lite/tools/test_utils.py
+++ b/tensorflow/lite/tools/test_utils.py
@@ -196,6 +196,40 @@
   builder.PrependUOffsetTRelative(subgraph_offset)
   subgraphs_offset = builder.EndVector(1)
 
+  signature_method = builder.CreateString('my_method')
+  signature_key = builder.CreateString('my_key')
+  input_tensor_string = builder.CreateString('input_tensor')
+  output_tensor_string = builder.CreateString('output_tensor')
+
+  # Signature Inputs
+  schema_fb.TensorMapStart(builder)
+  schema_fb.TensorMapAddName(builder, input_tensor_string)
+  schema_fb.TensorMapAddTensorIndex(builder, 1)
+  input_tensor = schema_fb.TensorMapEnd(builder)
+
+  # Signature Outputs
+  schema_fb.TensorMapStart(builder)
+  schema_fb.TensorMapAddName(builder, output_tensor_string)
+  schema_fb.TensorMapAddTensorIndex(builder, 2)
+  output_tensor = schema_fb.TensorMapEnd(builder)
+
+  schema_fb.SignatureDefStartInputsVector(builder, 1)
+  builder.PrependUOffsetTRelative(input_tensor)
+  signature_inputs_offset = builder.EndVector(1)
+  schema_fb.SignatureDefStartOutputsVector(builder, 1)
+  builder.PrependUOffsetTRelative(output_tensor)
+  signature_outputs_offset = builder.EndVector(1)
+
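+  # Assemble a SignatureDef ('my_key' / 'my_method') mapping the named input
+  # and output tensors to tensor indices 1 and 2 of the subgraph built above.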
+  schema_fb.SignatureDefStart(builder)
+  schema_fb.SignatureDefAddKey(builder, signature_key)
+  schema_fb.SignatureDefAddMethodName(builder, signature_method)
+  schema_fb.SignatureDefAddInputs(builder, signature_inputs_offset)
+  schema_fb.SignatureDefAddOutputs(builder, signature_outputs_offset)
+  signature_offset = schema_fb.SignatureDefEnd(builder)
+  schema_fb.ModelStartSignatureDefsVector(builder, 1)
+  builder.PrependUOffsetTRelative(signature_offset)
+  signature_defs_offset = builder.EndVector(1)
+
   string4_offset = builder.CreateString('model_description')
   schema_fb.ModelStart(builder)
   schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION)
@@ -203,6 +237,7 @@
   schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
   schema_fb.ModelAddDescription(builder, string4_offset)
   schema_fb.ModelAddBuffers(builder, buffers_offset)
+  schema_fb.ModelAddSignatureDefs(builder, signature_defs_offset)
   model_offset = schema_fb.ModelEnd(builder)
   builder.Finish(model_offset)
   model = builder.Output()
diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc
index 546e88b..dcb154a 100644
--- a/tensorflow/lite/tools/verifier.cc
+++ b/tensorflow/lite/tools/verifier.cc
@@ -50,8 +50,13 @@
   }
 }
-// Returns the int32_t value pointed by ptr.
+// Returns the uint32_t value pointed to by ptr.
-const uint32_t* GetIntPtr(const char* ptr) {
-  return reinterpret_cast<const uint32_t*>(ptr);
+const uint32_t GetIntPtr(const char* ptr) {
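+  // FlatBuffers stores these offsets little-endian, so byte-swap on
+  // big-endian hosts to return the value in host byte order.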
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
+    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  return flatbuffers::EndianScalar(*reinterpret_cast<const uint32_t*>(ptr));
+#else
+  return *reinterpret_cast<const uint32_t*>(ptr);
+#endif
 }
 
 // Verifies flatbuffer format of the model contents and returns the in-memory
@@ -79,7 +84,7 @@
   }
   const char* buffer_ptr = reinterpret_cast<const char*>(buffer.data()->data());
 
-  uint32_t num_strings = *GetIntPtr(buffer_ptr);
+  uint32_t num_strings = GetIntPtr(buffer_ptr);
   if (num_strings > kMaxNumString) {
     ReportError(error_reporter,
                 "String tensor %s has invalid num of string set: %d",
@@ -100,7 +105,7 @@
   uint32_t prev_ptr = header_offsets;
   uint32_t offset = sizeof(int32_t);
 
-  if (*GetIntPtr(buffer_ptr + offset) != header_offsets) {
+  if (GetIntPtr(buffer_ptr + offset) != header_offsets) {
     ReportError(error_reporter,
                 "String tensor %s buffer initial offset must be: %d",
                 NameOrEmptyString(tensor.name()), header_offsets);
@@ -108,7 +113,7 @@
   }
   offset += sizeof(int32_t);
   for (int i = 1, end = num_strings; i <= end; i++, offset += sizeof(int32_t)) {
-    int string_offset = *GetIntPtr(buffer_ptr + offset);
+    int string_offset = GetIntPtr(buffer_ptr + offset);
     if (string_offset < static_cast<int>(prev_ptr) ||
         string_offset > static_cast<int>(buffer_size)) {
       ReportError(error_reporter,
@@ -117,7 +122,7 @@
       return false;
     }
   }
-  if (*GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) {
+  if (GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) {
     ReportError(error_reporter,
                 "String tensor %s buffer last offset must be %d",
                 NameOrEmptyString(tensor.name()), buffer_size);
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index e657a66..c7c08c8 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -620,10 +620,7 @@
     case BuiltinOperator_SELECT:
     case BuiltinOperator_RSQRT:
     case BuiltinOperator_SQUARED_DIFFERENCE:
-      if (op_sig.input_types.at(0) == TensorType_INT8) {
-        return 2;
-      }
-      return 1;
+    case BuiltinOperator_DEPTH_TO_SPACE:
     case BuiltinOperator_MIRROR_PAD:
       if (op_sig.input_types.at(0) == TensorType_INT8) {
         return 2;
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index bda02ec..3b418d9 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -61,6 +61,7 @@
               {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"},
               {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"},
               {{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"},
+              {{BuiltinOperator_BATCH_MATMUL, 4}, kPendingReleaseVersion},
               // Version one of the broadcast_to op won't be supported since it
               // was rolled back and the builtin op code number has been
               // changed because of the builtin op code shortage problem.
@@ -100,6 +101,7 @@
               {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
               {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},
               {{BuiltinOperator_DEPTH_TO_SPACE, 1}, "2.1.0"},
+              {{BuiltinOperator_DEPTH_TO_SPACE, 2}, kPendingReleaseVersion},
               {{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"},
               {{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"},
               {{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"},
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index bf8bd7a..ad834a3 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -4,6 +4,7 @@
 tensorflow/compat_template.__init__.py
 tensorflow/compat_template_v1.__init__.py
 tensorflow/compiler/mlir/glob_lit_test.bzl
+tensorflow/compiler/mlir/hlo/WORKSPACE
 tensorflow/go/op/wrappers.go
 tensorflow/lite/core/shims/BUILD
 tensorflow/lite/core/shims/c/builtin_op_data.h
@@ -18,6 +19,7 @@
 tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h
 tensorflow/lite/delegates/gpu/cl/serialization_generated.h
 tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
+tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
 tensorflow/lite/micro/build_def.bzl
 tensorflow/lite/schema/schema_generated.h
 tensorflow/python/autograph/core/config.py
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9985f27..96e01db 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -49,6 +49,7 @@
     "//tensorflow_models:__subpackages__",
     "//tensorflow_model_optimization:__subpackages__",
     "//third_party/py/cleverhans:__subpackages__",
+    "//third_party/py/launchpad:__subpackages__",
     "//third_party/py/reverb:__subpackages__",
     "//third_party/py/neural_structured_learning:__subpackages__",
     "//third_party/py/tensorflow_examples:__subpackages__",
@@ -123,16 +124,10 @@
         "//third_party/py/tensorflow_core:__subpackages__",
     ],
     deps = [
-        ":_pywrap_checkpoint_reader",
         ":_pywrap_events_writer",
-        ":_pywrap_kernel_registry",
         ":_pywrap_py_exception_registry",
         ":_pywrap_python_op_gen",
         ":_pywrap_quantize_training",
-        ":_pywrap_stat_summarizer",
-        ":_pywrap_tfprof",
-        ":_pywrap_transform_graph",
-        ":_pywrap_util_port",
         ":_pywrap_utils",
         ":array_ops",
         ":audio_ops_gen",
@@ -197,7 +192,6 @@
         ":tf_item",
         ":tf_optimizer",
         ":training",
-        ":util",
         ":weights_broadcast_ops",
         ":while_v2",
         "//tensorflow/core:protos_all_py",
@@ -238,6 +232,13 @@
         "//tensorflow/python/tpu:tpu_noestimator",
         "//tensorflow/python/training:saver_test_utils",
         "//tensorflow/python/types",
+        "//tensorflow/python/util",
+        "//tensorflow/python/util:_pywrap_checkpoint_reader",
+        "//tensorflow/python/util:_pywrap_kernel_registry",
+        "//tensorflow/python/util:_pywrap_stat_summarizer",
+        "//tensorflow/python/util:_pywrap_tfprof",
+        "//tensorflow/python/util:_pywrap_transform_graph",
+        "//tensorflow/python/util:_pywrap_util_port",
         "//third_party/py/numpy",
     ],
 )
@@ -249,8 +250,8 @@
     ],
     deps = [
         ":tf_decorator",
-        ":tf_stack",
         "//tensorflow/python/types",
+        "//tensorflow/python/util:tf_stack",
     ],
 )
 
@@ -273,28 +274,6 @@
 
 # TODO(gunan): Investigate making this action hermetic so we do not need
 # to run it locally.
-
-py_library(
-    name = "platform",
-    visibility = ["//visibility:public"],
-    deps = ["//tensorflow/python/platform"],
-)
-
-py_library(
-    name = "platform_benchmark",
-    deps = ["//tensorflow/python/platform:benchmark"],
-)
-
-py_library(
-    name = "platform_analytics",
-    deps = ["//tensorflow/python/platform:analytics"],
-)
-
-py_library(
-    name = "platform_test",
-    deps = ["//tensorflow/python/platform:test"],
-)
-
 cc_library(
     name = "cost_analyzer_lib",
     srcs = ["grappler/cost_analyzer.cc"],
@@ -337,12 +316,12 @@
     module_name = "_pywrap_cost_analyzer",
     deps = [
         ":cost_analyzer_headers",
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
         "//tensorflow/core/common_runtime/gpu:gpu_id",
+        "//tensorflow/python/lib/core:pybind11_status",
         "@pybind11",
     ],
 )
@@ -369,46 +348,33 @@
     ],
     module_name = "_pywrap_model_analyzer",
     deps = [
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_status",
         "@pybind11",
     ],
 )
 
-cc_library(
-    name = "numpy_lib",
-    srcs = ["lib/core/numpy.cc"],
-    hdrs = ["lib/core/numpy.h"],
-    deps = [
-        "//third_party/py/numpy:headers",
-        "//third_party/python_runtime:headers",
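+# Alias kept so existing ":util" references keep resolving now that the
+# library itself lives in //tensorflow/python/util.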
+alias(
+    name = "util",
+    actual = "//tensorflow/python/util:util",
+    visibility = visibility + [
+        "//tensorflow:__pkg__",
+        "//third_party/py/tensorflow_core:__subpackages__",
+        "//third_party/py/tf_agents:__subpackages__",
+        "//third_party/py/tfx:__subpackages__",
     ],
 )
 
-cc_library(
+alias(
+    name = "tf_decorator",
+    actual = "//tensorflow/python/util:tf_decorator",
+)
+
+alias(
     name = "bfloat16_lib",
-    srcs = ["lib/core/bfloat16.cc"],
-    hdrs = ["lib/core/bfloat16.h"],
-    deps = [
-        ":numpy_lib",
-        ":safe_ptr",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//third_party/python_runtime:headers",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_bfloat16",
-    srcs = ["lib/core/bfloat16_wrapper.cc"],
-    hdrs = ["lib/core/bfloat16.h"],
-    module_name = "_pywrap_bfloat16",
-    deps = [
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
+    actual = "//tensorflow/python/lib/core:bfloat16_lib",
 )
 
 # Necessary for the pywrap inclusion below.
@@ -429,125 +395,13 @@
         ":tfcompile_headers_lib",
         "@pybind11",
         "//third_party/python_runtime:headers",
-        ":pybind11_lib",
-        ":pybind11_status",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
         # The headers here cannot be brought in via cc_header_only_library
         "//tensorflow/compiler/aot:llvm_targets",
     ],
 )
 
-cc_library(
-    name = "ndarray_tensor_bridge",
-    srcs = ["lib/core/ndarray_tensor_bridge.cc"],
-    hdrs = ["lib/core/ndarray_tensor_bridge.h"],
-    visibility = tf_external_workspace_visible(
-        visibility + [
-            "//tensorflow:ndarray_tensor_allow_list",
-        ],
-    ),
-    deps = [
-        ":bfloat16_lib",
-        ":numpy_lib",
-        "//tensorflow/c:c_api_no_xla",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ],
-)
-
-cc_library(
-    name = "py_exception_registry",
-    srcs = ["lib/core/py_exception_registry.cc"],
-    hdrs = ["lib/core/py_exception_registry.h"],
-    deps = [
-        "//tensorflow/c:tf_status_headers",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//third_party/python_runtime:headers",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "pybind11_absl",
-    hdrs = ["lib/core/pybind11_absl.h"],
-    features = ["-parse_headers"],
-    visibility = tf_external_workspace_visible(visibility),
-    deps = [
-        "//tensorflow/core/platform:stringpiece",
-        "@pybind11",
-    ],
-)
-
-cc_library(
-    name = "pybind11_lib",
-    hdrs = ["lib/core/pybind11_lib.h"],
-    compatible_with = get_compatible_with_portable(),
-    features = ["-parse_headers"],
-    visibility = tf_external_workspace_visible(visibility),
-    deps = [
-        "@pybind11",
-    ],
-)
-
-cc_library(
-    name = "pybind11_status_headers",
-    hdrs = [
-        "lib/core/py_exception_registry.h",
-        "lib/core/pybind11_status.h",
-        "//tensorflow/c:headers",
-        "//tensorflow/c/eager:headers",
-    ],
-    features = [
-        "-parse_headers",
-    ],
-    visibility = tf_external_workspace_visible(visibility),
-    deps = [
-        "//tensorflow/c:tf_status_headers",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
-)
-
-cc_library(
-    name = "pybind11_status",
-    hdrs = [
-        "lib/core/py_exception_registry.h",
-        "lib/core/pybind11_status.h",
-        "//tensorflow/c:headers",
-    ],
-    features = ["-parse_headers"],
-    visibility = tf_external_workspace_visible(visibility),
-    deps = [
-        ":pybind11_status_headers",
-        "//tensorflow/core:lib",
-    ],
-)
-
-cc_library(
-    name = "pybind11_proto",
-    hdrs = ["lib/core/pybind11_proto.h"],
-    features = ["-parse_headers"],
-    visibility = tf_external_workspace_visible(visibility),
-    deps = [
-        "@com_google_absl//absl/strings",
-        "@pybind11",
-    ],
-)
-
-cc_library(
-    name = "kernel_registry",
-    srcs = ["util/kernel_registry.cc"],
-    hdrs = ["util/kernel_registry.h"],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ],
-    alwayslink = 1,
-)
-
 py_library(
     name = "pywrap_tf_session",
     srcs = ["client/pywrap_tf_session.py"],
@@ -563,8 +417,6 @@
     srcs = ["client/tf_session_wrapper.cc"],
     hdrs = [
         "client/tf_session_helper.h",
-        "lib/core/numpy.h",
-        "lib/core/safe_ptr.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
@@ -572,12 +424,14 @@
         "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
+        "//tensorflow/python/lib/core:numpy_hdr",
+        "//tensorflow/python/lib/core:safe_ptr_hdr",
     ],
     module_name = "_pywrap_tf_session",
     deps = [
-        ":pybind11_lib",
-        ":pybind11_status",
-        ":safe_pyobject_ptr",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
         "//tensorflow/core/framework:pywrap_required_hdrs",
         "//third_party/py/numpy:headers",
         "//tensorflow/c:pywrap_required_hdrs",
@@ -604,46 +458,9 @@
     ),
 )
 
-tf_python_pybind_extension(
-    name = "_pywrap_tfprof",
-    srcs = ["util/tfprof_wrapper.cc"],
-    module_name = "_pywrap_tfprof",
-    deps = [
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:lib_headers_for_pybind",
-        "//tensorflow/core/profiler/internal:print_model_analysis_hdr",
-        "//third_party/eigen3",
-        "//third_party/python_runtime:headers",
-        "@com_google_absl//absl/strings",
-        "@pybind11",
-    ],
-)
-
-tf_python_pybind_extension(
+alias(
     name = "_pywrap_utils",
-    srcs = ["util/util_wrapper.cc"],
-    hdrs = ["util/util.h"],
-    module_name = "_pywrap_utils",
-    deps = [
-        ":pybind11_lib",
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_kernel_registry",
-    srcs = ["util/kernel_registry_wrapper.cc"],
-    hdrs = ["util/kernel_registry.h"],
-    module_name = "_pywrap_kernel_registry",
-    deps = [
-        ":pybind11_lib",
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:lib_headers_for_pybind",
-        "//tensorflow/core:protos_all_cc",
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
+    actual = "//tensorflow/python/util:_pywrap_utils",
 )
 
 tf_python_pybind_extension(
@@ -654,13 +471,13 @@
     hdrs = ["//tensorflow/core/common_runtime:quantize_training_hdrs"],
     module_name = "_pywrap_quantize_training",
     deps = [
-        ":pybind11_lib",
-        ":pybind11_proto",
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_proto",
+        "//tensorflow/python/lib/core:pybind11_status",
         "//third_party/python_runtime:headers",
         "@com_google_absl//absl/strings",
         "@pybind11",
@@ -668,53 +485,16 @@
 )
 
 tf_python_pybind_extension(
-    name = "_pywrap_stat_summarizer",
-    srcs = ["util/stat_summarizer_wrapper.cc"],
-    module_name = "_pywrap_stat_summarizer",
-    deps = [
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:protos_all_cc",
-        "//third_party/eigen3",
-        "//third_party/python_runtime:headers",
-        "@com_google_absl//absl/memory",
-        "@pybind11",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_tensor_float_32_execution",
-    srcs = ["util/tensor_float_32.cc"],
-    hdrs = ["//tensorflow/core/platform:tensor_float_32_hdr"],
-    compatible_with = get_compatible_with_portable(),
-    module_name = "_pywrap_tensor_float_32_execution",
-    deps = [
-        "@pybind11",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_util_port",
-    srcs = ["util/port_wrapper.cc"],
-    hdrs = ["//tensorflow/core/util:port_hdrs"],
-    module_name = "_pywrap_util_port",
-    deps = [
-        "//tensorflow/core/util:port",
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
-)
-
-tf_python_pybind_extension(
     name = "_pywrap_debug_events_writer",
     srcs = ["client/debug_events_writer_wrapper.cc"],
     module_name = "_pywrap_debug_events_writer",
     deps = [
-        ":pybind11_absl",
-        ":pybind11_proto",
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_proto",
+        "//tensorflow/python/lib/core:pybind11_status",
         "//third_party/python_runtime:headers",
         "@com_google_absl//absl/strings",
         "@pybind11",
@@ -726,75 +506,26 @@
     srcs = ["client/events_writer_wrapper.cc"],
     module_name = "_pywrap_events_writer",
     deps = [
-        ":pybind11_absl",
-        ":pybind11_proto",
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_proto",
+        "//tensorflow/python/lib/core:pybind11_status",
         "//third_party/python_runtime:headers",
         "@com_google_absl//absl/strings",
         "@pybind11",
     ],
 )
 
-tf_python_pybind_extension(
-    name = "_pywrap_transform_graph",
-    srcs = ["util/transform_graph_wrapper.cc"],
-    hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
-    module_name = "_pywrap_transform_graph",
-    deps = [
-        ":pybind11_status",
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:lib_headers_for_pybind",
-        "//tensorflow/core:protos_all_cc",
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_checkpoint_reader",
-    srcs = ["util/py_checkpoint_reader_wrapper.cc"],
-    hdrs = [
-        "lib/core/ndarray_tensor.h",
-        "lib/core/safe_ptr.h",
-        ":py_exception_registry_hdr",
-        "//tensorflow/c:checkpoint_reader_hdrs",
-        "//tensorflow/c:headers",
-        "//tensorflow/c/eager:headers",
-    ],
-    module_name = "_pywrap_checkpoint_reader",
-    deps = [
-        ":pybind11_lib",
-        ":pybind11_status",
-        ":safe_pyobject_ptr",
-        "//tensorflow/core:lib_headers_for_pybind",
-        "//tensorflow/core:op_gen_lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
-        "//third_party/py/numpy:headers",
-        "//third_party/python_runtime:headers",
-        "@com_google_absl//absl/strings",
-        "@pybind11",
-    ],
-)
-
-filegroup(
-    name = "py_exception_registry_hdr",
-    srcs = [
-        "lib/core/py_exception_registry.h",
-    ],
-    visibility = ["//visibility:public"],
-)
-
+# TODO(yanhuasun): Move this target and its source file back to the lib/core directory.
 tf_python_pybind_extension(
     name = "_pywrap_py_exception_registry",
-    srcs = ["lib/core/py_exception_registry_wrapper.cc"],
+    srcs = ["py_exception_registry_wrapper.cc"],
     hdrs = [
-        ":py_exception_registry_hdr",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
+        "//tensorflow/python/lib/core:py_exception_registry_hdr",
     ],
     module_name = "_pywrap_py_exception_registry",
     deps = [
@@ -814,7 +545,7 @@
     hdrs = ["//tensorflow/lite/toco/python:toco_python_api_hdrs"],
     module_name = "_pywrap_toco_api",
     deps = [
-        ":pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_lib",
         "//third_party/python_runtime:headers",
         "@pybind11",
     ],
@@ -824,188 +555,12 @@
 # targets that depend are relying on cpp_python_util to pull in safe_ptr's
 # third_party/tensorflow/c:c_api_no_xla dependency, which registers
 # ops/gradients, rather than depending on it themselves.)
-cc_library(
-    name = "cpp_python_util",
-    srcs = ["util/util.cc"],
-    hdrs = ["util/util.h"],
-    deps = [
-        ":safe_ptr",
-        ":safe_pyobject_ptr",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//third_party/python_runtime:headers",
-        "@com_google_absl//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "py_func_lib",
-    srcs = ["lib/core/py_func.cc"],
-    hdrs = ["lib/core/py_func.h"],
-    deps = [
-        ":ndarray_tensor",
-        ":ndarray_tensor_bridge",
-        ":numpy_lib",
-        ":py_util",
-        ":safe_ptr",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/eager:c_api",
-        "//tensorflow/c/eager:tfe_context_internal",
-        "//tensorflow/c/eager:tfe_tensorhandle_internal",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:script_ops_op_lib",
-        "//tensorflow/core/common_runtime/eager:context",
-        "//tensorflow/core/common_runtime/eager:tensor_handle",
-        "//tensorflow/python/eager:pywrap_tfe_lib",
-        "//third_party/py/numpy:headers",
-        "//third_party/python_runtime:headers",
-    ],
-    alwayslink = 1,
-)
-
 cc_header_only_library(
     name = "py_func_headers_lib",
     features = ["-parse_headers"],
     tags = ["no-ide"],
     deps = [
-        ":py_func_lib",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_py_func",
-    srcs = ["lib/core/py_func_wrapper.cc"],
-    module_name = "_pywrap_py_func",
-    deps = [
-        ":py_func_headers_lib",
-        "//third_party/python_runtime:headers",
-        "@pybind11",
-    ],
-)
-
-cc_library(
-    name = "safe_pyobject_ptr",
-    srcs = ["lib/core/safe_pyobject_ptr.cc"],
-    hdrs = ["lib/core/safe_pyobject_ptr.h"],
-    deps = [
-        "//third_party/python_runtime:headers",
-    ],
-)
-
-cc_library(
-    name = "safe_pyobject_ptr_required_hdrs",
-    textual_hdrs = ["lib/core/safe_pyobject_ptr.h"],
-)
-
-cc_library(
-    name = "safe_ptr",
-    srcs = [
-        "lib/core/safe_ptr.cc",
-        "//tensorflow/c/eager:headers",
-    ],
-    hdrs = ["lib/core/safe_ptr.h"],
-    deps = [
-        ":safe_pyobject_ptr",
-        "//tensorflow/c:c_api_no_xla",
-        "//third_party/python_runtime:headers",
-    ],
-)
-
-cc_library(
-    name = "ndarray_tensor_headers",
-    hdrs = [
-        "lib/core/bfloat16.h",
-        "lib/core/ndarray_tensor.h",
-        "lib/core/ndarray_tensor_bridge.h",
-        "lib/core/numpy.h",
-        "lib/core/safe_ptr.h",
-        "lib/core/safe_pyobject_ptr.h",
-        "//tensorflow/c:headers",
-        "//tensorflow/c/eager:headers",
-    ],
-    features = [
-        "-parse_headers",
-    ],
-    visibility = tf_external_workspace_visible(visibility + [
-        "//tensorflow:ndarray_tensor_allow_list",
-    ]),
-    deps = [
-        ":numpy_lib",
-        "//tensorflow/c:pywrap_required_hdrs",
-        "//tensorflow/c:tf_status_headers",
-        "//tensorflow/core:framework_internal_headers_lib",
-        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
-        "//third_party/py/numpy:headers",
-        "//third_party/python_runtime:headers",
-    ],
-)
-
-cc_library(
-    name = "ndarray_tensor",
-    srcs = ["lib/core/ndarray_tensor.cc"],
-    hdrs = ["lib/core/ndarray_tensor.h"],
-    visibility = tf_external_workspace_visible(visibility + [
-        "//tensorflow:ndarray_tensor_allow_list",
-    ]),
-    deps = [
-        ":bfloat16_lib",
-        ":ndarray_tensor_bridge",
-        ":numpy_lib",
-        ":safe_ptr",
-        "//tensorflow/c:c_api_internal",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c:tf_tensor_internal",
-        "//tensorflow/c/eager:tfe_context_internal",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-    ],
-)
-
-cc_library(
-    name = "py_seq_tensor",
-    srcs = ["lib/core/py_seq_tensor.cc"],
-    hdrs = ["lib/core/py_seq_tensor.h"],
-    features = ["-parse_headers"],
-    deps = [
-        ":ndarray_tensor",
-        ":ndarray_tensor_bridge",
-        ":numpy_lib",
-        ":py_util",
-        ":safe_ptr",
-        "//tensorflow/c:tensor_interface",
-        "//tensorflow/c:tf_tensor_internal",
-        "//tensorflow/c/eager:c_api_internal",
-        "//tensorflow/c/eager:tfe_context_internal",
-        "//tensorflow/c/eager:tfe_tensorhandle_internal",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//third_party/python_runtime:headers",  # build_cleaner: keep; DNR: b/35864863
-    ],
-)
-
-cc_library(
-    name = "py_util",
-    srcs = ["lib/core/py_util.cc"],
-    hdrs = ["lib/core/py_util.h"],
-    deps = [
-        "//tensorflow/core:lib",
-        "//tensorflow/core:script_ops_op_lib",
-        "//tensorflow/core/platform:logging",
-        "//third_party/python_runtime:headers",
-    ],
-)
-
-cc_library(
-    name = "py_record_reader_lib",
-    srcs = ["lib/io/py_record_reader.cc"],
-    hdrs = ["lib/io/py_record_reader.h"],
-    deps = [
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
+        "//tensorflow/python/lib/core:py_func_lib",
     ],
 )
 
@@ -1046,50 +601,7 @@
         ":framework_for_generated_wrappers",
         ":io_ops",
         ":platform",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "decorator_utils_test",
-    srcs = ["util/decorator_utils_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":platform",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "deprecation_test",
-    srcs = ["util/deprecation_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":platform",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "dispatch_test",
-    srcs = ["util/dispatch_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":platform",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "keyword_args_test",
-    srcs = ["util/keyword_args_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -1132,9 +644,9 @@
     srcs = ["framework/python_op_gen_wrapper.cc"],
     module_name = "_pywrap_python_op_gen",
     deps = [
-        ":pybind11_absl",
-        ":pybind11_lib",
         ":python_op_gen_headers_lib",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_lib",
         "//third_party/python_runtime:headers",
         "@pybind11",
     ],
@@ -1235,22 +747,22 @@
     ],
     srcs_version = "PY2AND3",
     deps = [
-        ":_pywrap_checkpoint_reader",
         ":_pywrap_debug_events_writer",
         ":_pywrap_events_writer",
-        ":_pywrap_kernel_registry",
+        "//tensorflow/python/util:_pywrap_kernel_registry",
         ":_pywrap_py_exception_registry",
-        ":_pywrap_py_func",  # TODO(b/142001480): remove once the bug is fixed.
+        "//tensorflow/python/lib/core:_pywrap_py_func",  # TODO(b/142001480): remove once the bug is fixed.
         ":_pywrap_python_api_dispatcher",
         ":_pywrap_python_api_info",
         ":_pywrap_python_api_parameter_converter",
         ":_pywrap_python_op_gen",
         ":_pywrap_quantize_training",
         "//tensorflow/python/platform:_pywrap_stacktrace_handler",
-        ":_pywrap_stat_summarizer",
-        ":_pywrap_tfprof",
-        ":_pywrap_transform_graph",
-        ":_pywrap_util_port",
+        "//tensorflow/python/util:_pywrap_checkpoint_reader",
+        "//tensorflow/python/util:_pywrap_stat_summarizer",
+        "//tensorflow/python/util:_pywrap_tfprof",
+        "//tensorflow/python/util:_pywrap_transform_graph",
+        "//tensorflow/python/util:_pywrap_util_port",
         ":_pywrap_utils",
         ":composite_tensor",
         ":config",
@@ -1272,7 +784,7 @@
         ":tensor_spec",
         ":tensor_util",
         ":type_spec",
-        ":util",
+        "//tensorflow/python/util:util",
         "//third_party/py/numpy",
         "@six_archive//:six",
         "//tensorflow/python/eager:context",
@@ -1319,7 +831,7 @@
     srcs = ["framework/device_spec.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -1347,9 +859,9 @@
     srcs_version = "PY2AND3",
     deps = [
         ":_dtypes",
-        ":_pywrap_bfloat16",
         ":pywrap_tensorflow",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/lib/core:_pywrap_bfloat16",
     ],
 )
 
@@ -1365,7 +877,7 @@
         ":c_api_util",
         ":error_interpolation",
         ":pywrap_tf_session",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -1391,10 +903,10 @@
         ":graph_to_function_def",
         ":op_def_registry",
         ":pywrap_tf_session",
-        ":util",
         ":variable_scope",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -1485,8 +997,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":pywrap_tf_session",
-        ":util",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -1501,8 +1013,8 @@
         ":op_def_registry",
         ":platform",
         ":tensor_shape",
-        ":util",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -1512,9 +1024,9 @@
     srcs = ["framework/op_def_registry.cc"],
     module_name = "_op_def_registry",
     deps = [
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_status",
         "@pybind11",
     ],
 )
@@ -1534,8 +1046,8 @@
     srcs = ["framework/py_context_manager.cc"],
     hdrs = ["framework/py_context_manager.h"],
     deps = [
-        ":safe_pyobject_ptr",
         "//tensorflow/core:lib",  # for core/platform/logging.h
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
         "//third_party/python_runtime:headers",
     ],
 )
@@ -1567,9 +1079,9 @@
     srcs = ["framework/op_def_util.cc"],
     hdrs = ["framework/op_def_util.h"],
     deps = [
-        ":cpp_python_util",
-        ":safe_pyobject_ptr",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/util:cpp_python_util",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/strings",
     ],
@@ -1590,17 +1102,17 @@
     ],
     hdrs = [
         "framework/op_def_util.h",
-        "lib/core/safe_ptr.h",
-        "util/util.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
+        "//tensorflow/python/lib/core:safe_ptr_hdr",
+        "//tensorflow/python/util:util_hdr",
     ],
     module_name = "_op_def_util",
     deps = [
-        ":pybind11_status",
-        ":safe_pyobject_ptr",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:status",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
         "//third_party/python_runtime:headers",
         "@com_google_absl//absl/strings",
         "@pybind11",
@@ -1619,16 +1131,16 @@
     srcs = ["framework/python_api_parameter_converter.cc"],
     hdrs = ["framework/python_api_parameter_converter.h"],
     deps = [
-        ":cpp_python_util",
         ":op_def_util_cc",
         ":python_api_info",
         ":python_tensor_converter",
-        ":safe_pyobject_ptr",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:status",
         "//tensorflow/python/eager:pywrap_tfe_lib",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/util:cpp_python_util",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/strings",
     ],
@@ -1643,7 +1155,6 @@
         "framework/python_api_info.h",
         "framework/python_api_parameter_converter.h",
         "framework/python_tensor_converter.h",
-        "lib/core/numpy.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
         "//tensorflow/c/experimental/ops:pywrap_required_hdrs",
@@ -1651,10 +1162,11 @@
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/python/eager:pywrap_required_hdrs",
+        "//tensorflow/python/lib/core:numpy_hdr",
     ],
     module_name = "_pywrap_python_api_parameter_converter",
     deps = [
-        ":safe_pyobject_ptr_required_hdrs",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
@@ -1701,15 +1213,15 @@
     srcs = ["framework/python_api_info.cc"],
     hdrs = ["framework/python_api_info.h"],
     deps = [
-        ":cpp_python_util",
         ":op_def_util_cc",
         ":python_tensor_converter",
-        ":safe_pyobject_ptr",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:status",
         "//tensorflow/python/eager:pywrap_tfe_lib",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/util:cpp_python_util",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/strings",
     ],
@@ -1723,7 +1235,6 @@
         "framework/op_def_util.h",
         "framework/python_api_info.h",
         "framework/python_tensor_converter.h",
-        "lib/core/numpy.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
         "//tensorflow/c/experimental/ops:pywrap_required_hdrs",
@@ -1731,10 +1242,11 @@
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/python/eager:pywrap_required_hdrs",
+        "//tensorflow/python/lib/core:numpy_hdr",
     ],
     module_name = "_pywrap_python_api_info",
     deps = [
-        ":safe_pyobject_ptr_required_hdrs",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
@@ -1781,9 +1293,9 @@
     srcs = ["framework/python_api_dispatcher.cc"],
     hdrs = ["framework/python_api_dispatcher.h"],
     deps = [
-        ":cpp_python_util",
-        ":safe_pyobject_ptr",
         "//tensorflow/core/platform:logging",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/util:cpp_python_util",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
@@ -1798,7 +1310,7 @@
     hdrs = ["framework/python_api_dispatcher.h"],
     module_name = "_pywrap_python_api_dispatcher",
     deps = [
-        ":safe_pyobject_ptr_required_hdrs",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@pybind11",
     ],
@@ -1820,11 +1332,11 @@
     srcs = ["framework/python_tensor_converter.cc"],
     hdrs = ["framework/python_tensor_converter.h"],
     deps = [
-        ":cpp_python_util",
-        ":safe_pyobject_ptr",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/python/eager:pywrap_tfe_lib",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/util:cpp_python_util",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/strings",
     ],
@@ -1836,7 +1348,6 @@
     srcs = ["framework/python_tensor_converter_wrapper.cc"],
     hdrs = [
         "framework/python_tensor_converter.h",
-        "lib/core/numpy.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
         "//tensorflow/c/experimental/ops:pywrap_required_hdrs",
@@ -1844,10 +1355,11 @@
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/python/eager:pywrap_required_hdrs",
+        "//tensorflow/python/lib/core:numpy_hdr",
     ],
     module_name = "_pywrap_python_tensor_converter",
     deps = [
-        ":safe_pyobject_ptr_required_hdrs",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
@@ -1908,7 +1420,6 @@
         ":tf2",
         ":traceable_stack",
         ":type_spec",
-        ":util",
         ":versions",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
@@ -1916,6 +1427,7 @@
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/eager:tape",
         "//tensorflow/python/profiler:traceme",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -1953,9 +1465,9 @@
         ":tensor_conversion_registry",
         ":tensor_shape",
         ":type_spec",
-        ":util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/types",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -1980,9 +1492,9 @@
         ":sparse_tensor",
         ":tensor_array_ops",
         ":tensor_shape",
-        ":util",
         ":variable_scope",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2012,7 +1524,7 @@
         ":framework_ops",
         ":sparse_tensor",
         ":tensor_array_ops",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2043,8 +1555,7 @@
     srcs_version = "PY2AND3",
     deps = [
         ":framework_ops",
-        ":util",
-        "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2060,7 +1571,7 @@
         ":constant_op",
         ":platform",
         ":test_ops",
-        ":util",
+        "//tensorflow/python/util:util",
     ] + tf_additional_xla_deps_py(),
 )
 
@@ -2079,8 +1590,9 @@
     srcs_version = "PY2AND3",
     deps = [
         ":platform",
-        ":tf_stack",
-        ":util",
+        "//tensorflow/python/util",
+        # TODO(mdan): Remove this once the transitive dependency is fixed.
+        "//tensorflow/python/util:tf_stack",
     ],
 )
 
@@ -2158,9 +1670,9 @@
     deps = [
         ":dtypes",
         ":tf2",
-        ":util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:monitoring",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2172,7 +1684,7 @@
     deps = [
         ":dtypes",
         ":tensor_shape",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -2186,7 +1698,7 @@
         ":dtypes",
         ":tensor_shape",
         ":type_spec",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -2198,9 +1710,9 @@
     deps = [
         ":errors",
         ":tensor_shape",
-        ":util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/types",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2210,7 +1722,7 @@
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2260,7 +1772,6 @@
         ":session",
         ":tensor_array_ops",
         ":training",
-        ":util",
         ":variables",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:backprop",
@@ -2269,6 +1780,7 @@
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/ops/ragged:ragged_tensor_value",
         "//tensorflow/python/platform:_pywrap_stacktrace_handler",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
         "@six_archive//:six",
@@ -2330,8 +1842,8 @@
         ":framework_ops",
         ":framework_test_combinations_lib",
         ":tf2",
-        ":util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2340,7 +1852,7 @@
     srcs = ["framework/test_combinations.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":util",
+        "//tensorflow/python/util",
         "@absl_py//absl/testing:parameterized",
     ],
 )
@@ -2357,11 +1869,6 @@
 )
 
 py_library(
-    name = "client_testlib",
-    deps = ["//tensorflow/python/platform:client_testlib"],
-)
-
-py_library(
     name = "memory_checker",
     srcs = [
         "framework/memory_checker.py",
@@ -2483,10 +1990,10 @@
         ":sparse_tensor",
         ":tensor_array_ops",
         ":tensor_shape",
-        ":util",
         ":variable_scope",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2512,6 +2019,7 @@
     python_version = "PY3",
     shard_count = 10,
     tags = [
+        "no_rocm",
         "noasan",
         "optonly",
     ],
@@ -2557,6 +2065,7 @@
     srcs = ["framework/importer_test.py"],
     main = "framework/importer_test.py",
     python_version = "PY3",
+    tags = ["no_rocm"],
     deps = [
         ":array_ops",
         ":client_testlib",
@@ -2621,7 +2130,7 @@
         ":platform_test",
         ":test_ops",
         ":traceable_stack",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -2699,13 +2208,13 @@
         ":resources",
         ":test_ops",
         ":test_ops_2",
-        ":util",
         ":variable_scope",
         ":variables",
         ":while_v2",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3153,7 +2662,7 @@
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "batch_ops_test",
     size = "small",
     srcs = ["ops/batch_ops_test.py"],
@@ -3390,7 +2899,7 @@
         ":sparse_tensor",
         ":tensor_shape",
         ":tensor_util",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -3403,7 +2912,7 @@
     deps = [
         ":bitwise_ops_gen",
         ":framework",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3453,7 +2962,7 @@
         ":framework",
         ":framework_for_generated_wrappers",
         ":set_ops_gen",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3480,7 +2989,7 @@
         ":math_ops",
         ":sparse_tensor",
         ":tensor_util",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -3638,9 +3147,9 @@
         ":tensor_array_ops",
         ":tensor_shape",
         ":tf2",
-        ":tf_should_use",
-        ":util",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/util",
+        "//tensorflow/python/util:tf_should_use",
         "@six_archive//:six",
     ],
 )
@@ -3662,11 +3171,11 @@
         ":control_flow_util",
         ":control_flow_v2_func_graphs",
         ":framework_ops",
-        ":util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3687,7 +3196,7 @@
         ":control_flow_util",
         ":control_flow_util_v2",
         ":framework_ops",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3758,10 +3267,10 @@
         ":graph_to_function_def",
         ":handle_data_util",
         ":pywrap_tensorflow",
-        ":util",
         "//tensorflow/python/compat",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3790,8 +3299,8 @@
         ":tensor_array_ops",
         ":tensor_shape",
         ":tensor_util",
-        ":util",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3901,7 +3410,7 @@
         ":framework_ops",
         ":protos_all_py",
         ":pywrap_tf_session",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3954,8 +3463,8 @@
         ":random_grad",
         ":tensor_array_ops",
         ":unconnected_gradients",
-        ":util",
         "//tensorflow/python/ops/linalg/sparse",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -3980,11 +3489,11 @@
         ":resource_variable_ops",
         ":tensor_util",
         ":unconnected_gradients",
-        ":util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:backprop_util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -4026,7 +3535,7 @@
     srcs = ["ops/unconnected_gradients.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4074,8 +3583,8 @@
         ":nn_ops_gen",
         ":random_ops",
         ":string_ops",
-        ":util",
         ":variables",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -4092,7 +3601,7 @@
         ":linalg_ops_impl",
         ":math_ops",
         ":random_ops",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -4110,7 +3619,7 @@
         ":math_ops",
         ":random_ops",
         ":stateless_random_ops",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -4211,8 +3720,8 @@
         ":logging_ops_gen",
         ":platform",
         ":string_ops",
-        ":util",
         "//tensorflow/python/compat",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4229,8 +3738,8 @@
         ":math_ops",
         ":sparse_tensor",
         ":string_ops",
-        ":util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -4280,9 +3789,9 @@
         ":state_ops",
         ":state_ops_gen",
         ":tensor_shape",
-        ":util",
         "//tensorflow/python/compat",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -4296,7 +3805,7 @@
         ":control_flow_ops",
         ":framework_for_generated_wrappers",
         ":math_ops",
-        ":tf_should_use",
+        "//tensorflow/python/util:tf_should_use",
     ],
 )
 
@@ -4314,11 +3823,11 @@
         ":pywrap_tf_session",
         ":resource_variable_ops_gen",
         ":tensor_shape",
-        ":util",
         ":variables",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4333,8 +3842,8 @@
         ":framework_ops",
         ":resource_variable_ops_gen",
         ":tensor_array_ops",
-        ":util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4378,9 +3887,9 @@
         ":nn_ops",
         ":nn_ops_gen",
         ":sparse_ops",
-        ":util",
         ":variables",
         "//tensorflow/python/platform:device_context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4558,9 +4067,9 @@
         ":math_ops",
         ":rnn_cell",
         ":tensor_array_ops",
-        ":util",
         ":variable_scope",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4581,11 +4090,11 @@
         ":nn_ops",
         ":partitioned_variables",
         ":random_ops",
-        ":util",
         ":variable_scope",
         ":variables",
         "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl",
         "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_wrapper_impl",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4594,10 +4103,10 @@
     srcs = ["ops/script_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":_pywrap_py_func",
         ":array_ops",
         ":framework_for_generated_wrappers",
         ":script_ops_gen",
+        "//tensorflow/python/lib/core:_pywrap_py_func",
         "//third_party/py/numpy",
     ],
 )
@@ -4621,7 +4130,7 @@
         ":array_ops",
         ":data_flow_ops_gen",
         ":framework_for_generated_wrappers",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4651,7 +4160,7 @@
         ":framework_for_generated_wrappers",
         ":math_ops",
         ":sparse_ops_gen",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -4749,11 +4258,11 @@
         ":sets",
         ":sparse_ops",
         ":state_ops",
-        ":util",
         ":variable_scope",
         ":variables",
         ":weights_broadcast_ops",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4866,7 +4375,6 @@
         ":template",
         ":tensor_array_grad",
         ":tensor_array_ops",
-        ":util",
         ":variable_scope",
         ":variables",
         "//tensorflow/python/compiler",
@@ -4877,6 +4385,7 @@
         "//tensorflow/python/ops/ragged",
         "//tensorflow/python/ops/structured",
         "//tensorflow/python/training/experimental:loss_scaling_gradient_tape",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4898,7 +4407,7 @@
         ":resource_variable_ops_gen",
         ":state_ops_gen",
         ":tensor_shape",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4910,7 +4419,7 @@
         ":framework",
         ":framework_for_generated_wrappers",
         ":string_ops_gen",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4932,10 +4441,10 @@
         ":summary_ops_gen",
         ":tensor_util",
         ":training_util",
-        ":util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:profiler",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -4947,10 +4456,10 @@
     deps = [
         ":framework_for_generated_wrappers",
         ":platform",
-        ":util",
         ":variable_scope",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -4981,8 +4490,8 @@
         ":tensor_shape",
         ":tensor_util",
         ":tf2",
-        ":tf_should_use",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_should_use",
     ],
 )
 
@@ -4999,10 +4508,10 @@
         ":resource_variable_ops",
         ":tensor_shape",
         ":tf2",
-        ":util",
         ":variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:monitoring",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -5019,11 +4528,11 @@
         ":math_ops",
         ":state_ops",
         ":tensor_shape",
-        ":tf_should_use",
-        ":util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/util",
+        "//tensorflow/python/util:tf_should_use",
     ],
 )
 
@@ -5060,7 +4569,7 @@
     srcs_version = "PY2AND3",
     deps = [
         ":user_ops_gen",
-        ":util",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -5107,11 +4616,11 @@
         ":tensor_array_grad",
         ":tensor_array_ops",
         ":training",
-        ":util",
         ":variable_scope",
         ":variables",
         ":while_v2",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -5377,6 +4886,7 @@
     srcs = ["ops/nn_fused_batchnorm_test.py"],
     python_version = "PY3",
     shard_count = 24,
+    tags = ["no_rocm"],
     deps = [
         ":array_ops",
         ":client_testlib",
@@ -5526,171 +5036,14 @@
         ":platform",
         ":session",
         ":session_ops",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
 )
 
-# Leaf library: may not depend on anything else inside TensorFlow.
-py_strict_library(
-    name = "tf_export",
-    srcs = ["util/tf_export.py"],
-    compatible_with = get_compatible_with_portable(),
-    srcs_version = "PY2AND3",
-    visibility = ["//tensorflow:__subpackages__"],
-    deps = [
-        ":tf_decorator",
-    ],
-)
-
-tf_py_test(
-    name = "tf_export_test",
-    srcs = ["util/tf_export_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":platform",
-        ":util",
-    ],
-)
-
-# Leaf library: may not depend on anything else inside TensorFlow.
-# TODO(mdan): Move this utility outside of TF.
-py_strict_library(
-    name = "tf_decorator",
-    srcs = [
-        "util/tf_contextlib.py",
-        "util/tf_decorator.py",
-        "util/tf_inspect.py",
-    ],
-    compatible_with = get_compatible_with_portable(),
-    srcs_version = "PY2AND3",
-    visibility = [
-        "//tensorflow:__subpackages__",
-        # TODO(mdan): Remove these dependencies.
-        "//third_party/py/tf_slim:__subpackages__",
-        "//learning/deepmind/research/language/translation/lm:__subpackages__",
-    ],
-    deps = [
-        "@six_archive//:six",
-    ],
-)
-
 # Note: this is a heavyweight library specialized for TensorFlow graphs. Do not use for
 # other purposes.
-py_strict_library(
-    name = "tf_stack",
-    srcs = ["util/tf_stack.py"],
-    srcs_version = "PY2AND3",
-    # TODO(mdan): Remove public visibility.
-    visibility = ["//visibility:public"],
-    deps = [
-        ":_tf_stack",
-        "@six_archive//:six",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_tf_stack",
-    srcs = ["util/tf_stack.cc"],
-    hdrs = [
-        "//tensorflow/c:headers",
-        "//tensorflow/c/eager:headers",
-    ],
-    # TODO(b/138203821): change to "util._tf_stack" once the bug is fixed.
-    module_name = "_tf_stack",
-    deps = [
-        ":stack_trace",
-        "//tensorflow/c:pywrap_required_hdrs",
-        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
-        "//tensorflow/core/framework:pywrap_required_hdrs",
-        "//tensorflow/core/platform:path",
-        "//third_party/python_runtime:headers",  # buildcleaner: keep
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/hash",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:span",
-        "@pybind11",
-    ],
-)
-
-tf_py_test(
-    name = "tf_stack_test",
-    srcs = ["util/tf_stack_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":tf_export",
-        ":tf_stack",
-    ],
-)
-
-cc_library(
-    name = "stack_trace",
-    srcs = ["util/stack_trace.cc"],
-    hdrs = ["util/stack_trace.h"],
-    deps = [
-        "//tensorflow/core/platform:str_util",
-        "//tensorflow/core/platform:stringpiece",
-        "//tensorflow/core/util:abstract_stack_trace",
-        "//third_party/python_runtime:headers",  # buildcleaner: keep
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/types:optional",
-    ],
-)
-
-cc_library(
-    name = "function_parameter_canonicalizer",
-    srcs = ["util/function_parameter_canonicalizer.cc"],
-    hdrs = ["util/function_parameter_canonicalizer.h"],
-    deps = [
-        ":py_util",
-        ":safe_pyobject_ptr",
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/core/platform:macros",
-        "//third_party/python_runtime:headers",  # buildcleaner: keep
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/types:span",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_function_parameter_canonicalizer_binding_for_test",
-    testonly = True,
-    srcs = ["util/function_parameter_canonicalizer_binding_for_test.cc"],
-    hdrs = [
-        "util/function_parameter_canonicalizer.h",
-    ],
-    module_name = "_function_parameter_canonicalizer_binding_for_test",
-    deps = [
-        ":safe_pyobject_ptr_required_hdrs",
-        "//tensorflow/core:lib",
-        "//third_party/python_runtime:headers",  # buildcleaner: keep
-        "@com_google_absl//absl/types:span",
-        "@pybind11",
-    ],
-)
-
-tf_py_test(
-    name = "function_parameter_canonicalizer_test",
-    srcs = ["util/function_parameter_canonicalizer_test.py"],
-    python_version = "PY3",
-    tags = [
-        "no_pip",  # b/168621686
-        "no_windows",  # b/169275019
-    ],
-    deps = [
-        ":_function_parameter_canonicalizer_binding_for_test",
-        ":client_testlib",
-    ],
-)
 
 py_library(
     name = "global_test_configuration",
@@ -5699,201 +5052,15 @@
            tf_enable_mlir_bridge(),
 )
 
-py_library(
-    name = "util",
-    srcs = glob(
-        ["util/**/*.py"],
-        exclude = [
-            "util/example_parser*",
-            "util/tf_contextlib.py",
-            "util/tf_should_use.py",
-            "util/tf_export.py",
-            "util/tf_stack.py",
-            "util/tf_decorator.py",
-            "util/**/*_test.py",
-        ],
-    ),
-    compatible_with = get_compatible_with_portable(),
-    srcs_version = "PY2AND3",
-    visibility = visibility + [
-        "//tensorflow:__pkg__",
-        "//third_party/py/tensorflow_core:__subpackages__",
-        "//third_party/py/tf_agents:__subpackages__",
-        "//third_party/py/tfx:__subpackages__",
-    ],
-    deps = [
-        ":_pywrap_tensor_float_32_execution",
-        # global_test_configuration is added here because all major tests depend on this
-        # library. It isn't possible to add these test dependencies via tensorflow.bzl's
-        # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
-        ":global_test_configuration",
-        ":tf_decorator",
-        ":tf_export",
-        "@org_python_pypi_backports_weakref",
-        "@com_google_protobuf//:protobuf_python",
-        "//third_party/py/numpy",
-        "@six_archive//:six",
-        "@wrapt",
-        "//tensorflow/tools/docs:doc_controls",
-        "//tensorflow/tools/compatibility:all_renames_v2",
-    ],
-)
-
-tf_py_test(
-    name = "object_identity_test",
-    size = "small",
-    srcs = ["util/object_identity_test.py"],
-    python_version = "PY3",
-)
-
-# Placeholder for intenal nest_test comments.
-tf_py_test(
-    name = "util_nest_test",
-    size = "small",
-    srcs = ["util/nest_test.py"],
-    main = "util/nest_test.py",
-    python_version = "PY3",
-    deps = [":util_nest_test_main_lib"],
-)
-
-py_library(
-    name = "util_nest_test_main_lib",
-    testonly = True,
-    srcs = ["util/nest_test.py"],
-    deps = [
-        ":array_ops",
-        ":client_testlib",
-        ":framework",
-        ":framework_for_generated_wrappers",
-        ":math_ops",
-        ":util",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
-tf_py_test(
-    name = "util_serialization_test",
-    size = "small",
-    srcs = ["util/serialization_test.py"],
-    main = "util/serialization_test.py",
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "function_utils_test",
-    srcs = ["util/function_utils_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "tf_contextlib_test",
-    size = "small",
-    srcs = ["util/tf_contextlib_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-    ],
-)
-
-tf_py_test(
-    name = "tf_decorator_test",
-    size = "small",
-    srcs = ["util/tf_decorator_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-    ],
-)
-
-py_library(
-    name = "tf_should_use",
-    srcs = ["util/tf_should_use.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":framework_ops",
-        ":util",
-        "//tensorflow/python/eager:context",
-        "@six_archive//:six",
-    ],
-)
-
-tf_py_test(
-    name = "tf_should_use_test",
-    size = "small",
-    srcs = ["util/tf_should_use_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":tf_should_use",
-    ],
-)
-
-tf_py_test(
-    name = "tf_inspect_test",
-    size = "small",
-    srcs = ["util/tf_inspect_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-    ],
-)
-
-py_library(
-    name = "util_example_parser_configuration",
-    srcs = ["util/example_parser_configuration.py"],
-    srcs_version = "PY2AND3",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":framework",
-        ":framework_for_generated_wrappers",
-        "//tensorflow/core:protos_all_py",
-    ],
-)
-
-tf_py_test(
-    name = "lock_util_test",
-    size = "small",
-    srcs = ["util/lock_util_test.py"],
-    main = "util/lock_util_test.py",
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
-tf_py_test(
-    name = "module_wrapper_test",
-    size = "small",
-    srcs = ["util/module_wrapper_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":util",
-        "//tensorflow/tools/compatibility:all_renames_v2",
-        "@six_archive//:six",
-    ],
-)
+# `tree.compat` requires a visibility exception to test against `nest_test`,
+# to facilitate convergence between `tree.compat` and `nest`.
 
 tf_proto_library(
     name = "protos_all",
     srcs = glob(
         ["**/*.proto"],
         exclude = [
-            "util/protobuf/compare_test.proto",
+            "//tensorflow/python/util:compare_test_proto_src",
             "framework/cpp_shape_inference.proto",
         ],
     ),
@@ -5902,13 +5069,6 @@
 )
 
 tf_proto_library(
-    name = "compare_test_proto",
-    testonly = 1,
-    srcs = ["util/protobuf/compare_test.proto"],
-    cc_api_version = 2,
-)
-
-tf_proto_library(
     name = "cpp_shape_inference_proto",
     srcs = ["framework/cpp_shape_inference.proto"],
     cc_api_version = 2,
@@ -5918,37 +5078,6 @@
 )
 
 tf_py_test(
-    name = "protobuf_compare_test",
-    size = "small",
-    srcs = ["util/protobuf/compare_test.py"],
-    main = "util/protobuf/compare_test.py",
-    python_version = "PY3",
-    tags = ["no_pip"],  # compare_test_pb2 proto is not available in pip.
-    deps = [
-        ":compare_test_proto_py",
-        ":platform_test",
-        ":util",
-        "@six_archive//:six",
-    ],
-)
-
-tf_py_test(
-    name = "util_example_parser_configuration_test",
-    size = "small",
-    srcs = ["util/example_parser_configuration_test.py"],
-    main = "util/example_parser_configuration_test.py",
-    python_version = "PY3",
-    deps = [
-        ":array_ops",
-        ":client",
-        ":client_testlib",
-        ":framework_for_generated_wrappers",
-        ":parsing_ops",
-        ":util_example_parser_configuration",
-    ],
-)
-
-tf_py_test(
     name = "events_writer_test",
     size = "small",
     srcs = ["client/events_writer_test.py"],
@@ -5958,7 +5087,7 @@
         ":framework_test_lib",
         ":lib",
         ":platform_test",
-        ":util",
+        "//tensorflow/python/util",
     ],
 )
 
@@ -5977,11 +5106,11 @@
     srcs = ["client/device_lib_wrapper.cc"],
     module_name = "_pywrap_device_lib",
     deps = [
-        ":pybind11_proto",
-        ":pybind11_status",
         "//tensorflow/core:framework_internal_headers_lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//tensorflow/python/lib/core:pybind11_proto",
+        "//tensorflow/python/lib/core:pybind11_status",
         "//third_party/python_runtime:headers",
         "@pybind11",
     ],
@@ -6071,24 +5200,27 @@
         ":bfloat16_lib",
         ":cost_analyzer_lib",
         ":model_analyzer_lib",
-        ":cpp_python_util",
-        ":function_parameter_canonicalizer",
-        ":kernel_registry",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
+        "//tensorflow/python/util:cpp_python_util",
+        "//tensorflow/python/util:function_parameter_canonicalizer",
+        "//tensorflow/python/util:kernel_registry",
         ":numpy_lib",
         ":safe_ptr",
         ":py_exception_registry",
-        ":py_func_lib",
-        ":py_record_reader_lib",
-        ":pybind11_absl",
-        ":pybind11_lib",
-        ":pybind11_status",
-        ":pybind11_proto",
+        "//tensorflow/python/lib/core:py_func_lib",
+        "//tensorflow/python/lib/io:py_record_reader_lib",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:pybind11_proto",
         ":python_api_dispatcher",
         ":python_api_info",
         ":python_api_parameter_converter",
         ":python_op_gen",
         ":python_tensor_converter",
-        ":safe_pyobject_ptr",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
         ":tf_session_helper",
         "//third_party/python_runtime:headers",
         "//tensorflow/c:c_api",
@@ -6105,9 +5237,6 @@
         "//tensorflow/c/eager:mnist_gradients_testutil",
         "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
         "//tensorflow/core/data/service:server_lib",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:grappler_item_builder",
         "//tensorflow/core/grappler/clusters:cluster",
@@ -6148,13 +5277,13 @@
     srcs = [
         ":bfloat16_lib",  # bfloat16
         ":cost_analyzer_lib",  # cost_analyzer
-        ":cpp_python_util",  # util
-        ":kernel_registry",  # kernel_registry
+        "//tensorflow/python/util:cpp_python_util",
+        "//tensorflow/python/util:kernel_registry",
         ":model_analyzer_lib",  # model_analyzer
         ":ndarray_tensor",  # checkpoint_reader
         ":numpy_lib",  # checkpoint_reader
         ":py_exception_registry",  # py_exception_registry
-        ":py_func_lib",  # py_func
+        "//tensorflow/python/lib/core:py_func_lib",
         ":python_api_dispatcher",  # python_api_dispatcher
         ":python_api_info",  # python_api_info
         ":python_api_parameter_converter",  # python_api_parameter_converter
@@ -6313,52 +5442,6 @@
 
 # ** Targets for Windows build (end) **
 
-tf_python_pybind_extension(
-    name = "_pywrap_file_io",
-    srcs = ["lib/io/file_io_wrapper.cc"],
-    module_name = "_pywrap_file_io",
-    deps = [
-        ":pybind11_absl",
-        ":pybind11_status",
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:protos_all_cc",
-        "@pybind11",
-    ],
-)
-
-py_library(
-    name = "lib",
-    srcs = [
-        "lib/io/file_io.py",
-        "lib/io/python_io.py",
-        "lib/io/tf_record.py",
-    ],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":_pywrap_file_io",
-        ":_pywrap_record_io",
-        ":errors",
-        ":pywrap_tensorflow",
-        ":util",
-        "@six_archive//:six",
-    ],
-)
-
-tf_python_pybind_extension(
-    name = "_pywrap_record_io",
-    srcs = ["lib/io/record_io_wrapper.cc"],
-    module_name = "_pywrap_record_io",
-    deps = [
-        ":pybind11_absl",
-        ":pybind11_status",
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:lib_headers_for_pybind",
-        "//tensorflow/core/platform:types",
-        "@com_google_absl//absl/memory",
-        "@pybind11",
-    ],
-)
-
 py_library(
     name = "session",
     srcs = ["client/session.py"],
@@ -6373,7 +5456,7 @@
         ":platform",
         ":pywrap_tensorflow",
         ":session_ops",
-        ":util",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
         "@wrapt",
     ],
@@ -6427,8 +5510,8 @@
         ":platform_test",
         ":state_ops",
         ":training",
-        ":util",
         ":variables",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -6458,8 +5541,8 @@
         ":platform_test",
         ":state_ops",
         ":training",
-        ":util",
         ":variables",
+        "//tensorflow/python/util",
         "//third_party/py/numpy",
     ],
 )
@@ -6504,7 +5587,7 @@
         ":math_ops",
         ":platform_test",
         ":training",
-        ":util",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -6592,47 +5675,6 @@
     ],
 )
 
-tf_py_test(
-    name = "bfloat16_test",
-    size = "small",
-    srcs = ["lib/core/bfloat16_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":lib",
-        ":pywrap_tensorflow",
-    ],
-)
-
-tf_py_test(
-    name = "file_io_test",
-    size = "small",
-    srcs = ["lib/io/file_io_test.py"],
-    python_version = "PY3",
-    tags = [
-        "no_rocm",
-        "no_windows",
-    ],
-    deps = [
-        ":client_testlib",
-        ":errors",
-        ":lib",
-    ],
-)
-
-tf_py_test(
-    name = "tf_record_test",
-    size = "small",
-    srcs = ["lib/io/tf_record_test.py"],
-    python_version = "PY3",
-    deps = [
-        ":client_testlib",
-        ":errors",
-        ":lib",
-        ":util",
-    ],
-)
-
 py_library(
     name = "summary_op_util",
     srcs = ["ops/summary_op_util.py"],
@@ -6669,9 +5711,9 @@
         ":summary_op_util",
         ":summary_ops_gen",
         ":summary_ops_v2",
-        ":util",
         "//tensorflow/python/distribute:summary_op_util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util",
         "@six_archive//:six",
     ],
 )
@@ -7084,7 +6126,7 @@
     ],
     module_name = "_pywrap_tf_item",
     deps = [
-        ":pybind11_status",
+        "//tensorflow/python/lib/core:pybind11_status",
         "@pybind11",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
         "//tensorflow/core:framework_headers_lib",
@@ -7160,12 +6202,12 @@
     ],
     module_name = "_pywrap_tf_cluster",
     deps = [
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
         "//tensorflow/core/common_runtime/gpu:gpu_id",
+        "//tensorflow/python/lib/core:pybind11_status",
         "@com_google_absl//absl/types:span",
         "@pybind11",
     ],
@@ -7222,12 +6264,12 @@
     ],
     module_name = "_pywrap_tf_optimizer",
     deps = [
-        ":pybind11_status",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
         "//tensorflow/core/common_runtime/gpu:gpu_id",
+        "//tensorflow/python/lib/core:pybind11_status",
         "@pybind11",
     ],
 )
@@ -7556,18 +6598,18 @@
     name = "_pywrap_mlir",
     srcs = ["mlir_wrapper.cc"],
     hdrs = [
-        "lib/core/safe_ptr.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
+        "//tensorflow/python/lib/core:safe_ptr_hdr",
     ],
     module_name = "_pywrap_mlir",
     deps = [
-        ":pybind11_lib",
-        ":pybind11_status",
-        ":safe_pyobject_ptr",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:status",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
         "//third_party/python_runtime:headers",
         "@com_google_absl//absl/container:fixed_array",
         "@pybind11",
@@ -7577,14 +6619,7 @@
 cc_library(
     name = "unified_api_pywrap_required_headers",
     textual_hdrs = [
-        "lib/core/numpy.h",
-        "lib/core/py_exception_registry.h",
-        "lib/core/pybind11_status.h",
-        "lib/core/bfloat16.h",
-        "lib/core/ndarray_tensor.h",
-        "lib/core/ndarray_tensor_bridge.h",
-        "lib/core/safe_ptr.h",
-        "lib/core/safe_pyobject_ptr.h",
+        "//tensorflow/python/lib/core:basic_hdrs",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
@@ -7620,10 +6655,6 @@
     name = "_pywrap_tfe",
     srcs = ["tfe_wrapper.cc"],
     hdrs = [
-        "lib/core/numpy.h",
-        "lib/core/safe_ptr.h",
-        "util/util.h",
-        ":py_exception_registry_hdr",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
@@ -7632,16 +6663,20 @@
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/python/eager:pywrap_required_hdrs",
+        "//tensorflow/python/lib/core:numpy_hdr",
+        "//tensorflow/python/lib/core:py_exception_registry_hdr",
+        "//tensorflow/python/lib/core:safe_ptr_hdr",
+        "//tensorflow/python/util:util_hdr",
     ],
     module_name = "_pywrap_tfe",
     # Only include TensorFlow header-only targets here.
     # If a cc_library needs to depend on TensorFlow .cc files through srcs or
     # deps, then you can use cc_header_only_library to keep only headers.
     deps = [
-        ":safe_pyobject_ptr",
-        ":pybind11_lib",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/lib/core:pybind11_lib",
         "//third_party/py/numpy:headers",
-        ":pybind11_status",
+        "//tensorflow/python/lib/core:pybind11_status",
         "//tensorflow/core/framework:pywrap_required_hdrs",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/hash",
@@ -7701,22 +6736,22 @@
 tf_python_pybind_extension(
     name = "_pywrap_parallel_device",
     srcs = [
-        "lib/core/safe_ptr.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager/parallel_device:headers",
         "//tensorflow/c/eager/parallel_device:sources",
         "//tensorflow/python/distribute/parallel_device:pywrap_parallel_device.cc",
+        "//tensorflow/python/lib/core:safe_ptr_hdr",
     ],
     module_name = "_pywrap_parallel_device",
     visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
     deps = [
-        ":pybind11_lib",
-        ":pybind11_status",
-        ":safe_pyobject_ptr",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
         "//third_party/python_runtime:headers",
         "@pybind11",
     ],
@@ -7757,6 +6792,112 @@
 )
 
 alias(
+    name = "platform_benchmark",
+    actual = "//tensorflow/python/platform:benchmark",
+)
+
+alias(
+    name = "platform_analytics",
+    actual = "//tensorflow/python/platform:analytics",
+)
+
+py_library(
+    name = "platform_test",
+    deps = ["//tensorflow/python/platform:test"],
+)
+
+alias(
+    name = "platform",
+    actual = "//tensorflow/python/platform:platform",
+    visibility = ["//visibility:public"],
+)
+
+alias(
+    name = "client_testlib",
+    actual = "//tensorflow/python/platform:client_testlib",
+)
+
+alias(
+    name = "pybind11_absl",
+    actual = "//tensorflow/python/lib/core:pybind11_absl",
+)
+
+alias(
+    name = "pybind11_proto",
+    actual = "//tensorflow/python/lib/core:pybind11_proto",
+)
+
+alias(
+    name = "py_func_lib",
+    actual = "//tensorflow/python/lib/core:py_func_lib",
+)
+
+alias(
+    name = "py_seq_tensor",
+    actual = "//tensorflow/python/lib/core:py_seq_tensor",
+)
+
+alias(
+    name = "py_util",
+    actual = "//tensorflow/python/lib/core:py_util",
+)
+
+alias(
+    name = "py_record_reader_lib",
+    actual = "//tensorflow/python/lib/io:py_record_reader_lib",
+)
+
+alias(
+    name = "numpy_lib",
+    actual = "//tensorflow/python/lib/core:numpy_lib",
+)
+
+alias(
+    name = "py_exception_registry",
+    actual = "//tensorflow/python/lib/core:py_exception_registry",
+)
+
+alias(
+    name = "pybind11_lib",
+    actual = "//tensorflow/python/lib/core:pybind11_lib",
+)
+
+alias(
+    name = "pybind11_status_headers",
+    actual = "//tensorflow/python/lib/core:pybind11_status_headers",
+)
+
+alias(
+    name = "pybind11_status",
+    actual = "//tensorflow/python/lib/core:pybind11_status",
+)
+
+alias(
+    name = "lib",
+    actual = "//tensorflow/python/lib/io:lib",
+)
+
+alias(
+    name = "safe_ptr",
+    actual = "//tensorflow/python/lib/core:safe_ptr",
+)
+
+alias(
+    name = "ndarray_tensor",
+    actual = "//tensorflow/python/lib/core:ndarray_tensor",
+)
+
+alias(
+    name = "ndarray_tensor_bridge",
+    actual = "//tensorflow/python/lib/core:ndarray_tensor_bridge",
+)
+
+alias(
+    name = "ndarray_tensor_headers",
+    actual = "//tensorflow/python/lib/core:ndarray_tensor_headers",
+)
+
+alias(
     name = "basic_session_run_hooks",
     actual = "//tensorflow/python/training:basic_session_run_hooks",
 )
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 22b4884..6efba38 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -36,9 +36,9 @@
 
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
-from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 
 from tensorflow.python.eager import context
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 
 # pylint: enable=wildcard-import
 
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 089c484..a021480 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -167,31 +167,35 @@
   """Remaps generated code to code it originated from."""
 
   def __init__(self, converted_fn):
+    super().__init__()
     self._source_map = converted_fn.ag_source_map
+    # This may be called repeatedly: once on entry, by the superclass, then by
+    # each child context manager.
+    self._cached_map = None
 
   def get_effective_source_map(self):
-    effective_source_map = self._effective_source_map
-    if effective_source_map is None:
-      if self.parent is not None:
-        parent_map = self.parent.get_effective_source_map()
+    if self._cached_map is not None:
+      return self._cached_map
+
+    parent_map = self.parent.get_effective_source_map()
+
+    effective_source_map = {}
+    for loc, origin in self._source_map.items():
+      effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename,
+                                                          origin.loc.lineno,
+                                                          origin.function_name)
+
+    for key, value in parent_map.items():
+      filename, lineno, _ = value
+      value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
+      if value_loc in self._source_map:
+        origin = self._source_map[value_loc]
+        effective_source_map[key] = (origin.loc.filename, origin.loc.lineno,
+                                     origin.function_name)
       else:
-        parent_map = {}
+        effective_source_map[key] = value
 
-      effective_source_map = {}
-      for loc, origin in self._source_map.items():
-        effective_source_map[(loc.filename, loc.lineno)] = (
-            origin.loc.filename, origin.loc.lineno, origin.function_name)
-
-      for key, value in parent_map.items():
-        filename, lineno, _ = value
-        value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
-        if value_loc in self._source_map:
-          origin = self._source_map[value_loc]
-          effective_source_map[key] = (
-              origin.loc.filename, origin.loc.lineno, origin.function_name)
-        else:
-          effective_source_map[key] = value
-      self._effective_source_map = effective_source_map
+    self._cached_map = effective_source_map
     return effective_source_map
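
The rewrite above replaces the lazily populated `_effective_source_map` with an
explicit `_cached_map`: the frame first translates its own generated locations,
then re-maps any parent entry whose target is itself generated code this frame
can translate further, and caches the merged result. A minimal standalone
sketch of that merge rule, using plain ("file", line) tuples in place of
origin_info.LineLocation and hypothetical filenames (this is not the AutoGraph
implementation itself):

    def merge_source_maps(own_map, parent_map):
        # own_map: generated (file, line) -> (original file, line, function).
        effective = dict(own_map)
        for key, value in parent_map.items():
            filename, lineno, _ = value
            if (filename, lineno) in own_map:
                # The parent resolved to code that is itself generated here;
                # translate it one level further.
                effective[key] = own_map[(filename, lineno)]
            else:
                effective[key] = value
        return effective

    own = {("gen_1.py", 10): ("user.py", 3, "f")}
    parent = {("gen_2.py", 7): ("gen_1.py", 10, "tf__f")}
    assert merge_source_maps(own, parent)[("gen_2.py", 7)] == ("user.py", 3, "f")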
 
 
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 2585683..2b1c0a3 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -153,9 +153,6 @@
   # The check for __code__ below is because isgeneratorfunction crashes
   # without one.
   if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
-    logging.warn(
-        'Entity %s appears to be a generator function. It will not be converted'
-        ' by AutoGraph.', o)
     logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
     return True
 
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index cbe4477..ed35b4e 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -60,9 +60,13 @@
 def islambda(f):
   if not tf_inspect.isfunction(f):
     return False
-  if not hasattr(f, '__name__'):
+  # TODO(mdan): Look into checking only the code object.
+  if not (hasattr(f, '__name__') and hasattr(f, '__code__')):
     return False
-  return f.__name__ == '<lambda>'
+  # Some wrappers can rename the function, but changing the name of the
+  # code object is harder.
+  return (
+      (f.__name__ == '<lambda>') or (f.__code__.co_name == '<lambda>'))
 
 
 def isnamedtuple(f):
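
The extra `__code__.co_name` check above matters because wrappers may rewrite a
function object's `__name__`, while the compiled code object keeps
`co_name == '<lambda>'`; the test added in the next file exercises exactly this.
A standalone illustration in plain Python, with no AutoGraph involved:

    def rename(f, new_name):
        # Some decorators rename the function object they return.
        f.__name__ = new_name
        return f

    g = rename(lambda x: x + 1, 'g')
    print(g.__name__)           # 'g'
    print(g.__code__.co_name)   # '<lambda>' -- still recognizable as a lambda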
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index 890f9e3..fbfa01f 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -105,6 +105,11 @@
     self.assertTrue(inspect_utils.islambda(lambda x: x))
     self.assertFalse(inspect_utils.islambda(test_fn))
 
+  def test_islambda_renamed_lambda(self):
+    l = lambda x: 1
+    l.__name__ = 'f'
+    self.assertTrue(inspect_utils.islambda(l))
+
   def test_isnamedtuple(self):
     nt = collections.namedtuple('TestNamedTuple', ['a', 'b'])
 
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
index b35b1d2..639e0dd 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
@@ -31,7 +31,9 @@
 from __future__ import division
 from __future__ import print_function
 
-from typing import Any, Callable, Tuple
+import itertools
+
+from typing import Any, Callable, Dict, Set
 
 import gast
 
@@ -39,6 +41,7 @@
 from tensorflow.python.autograph.pyct import cfg
 from tensorflow.python.autograph.pyct import qual_names
 from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
 from tensorflow.python.autograph.pyct.static_analysis import annos
 
 
@@ -118,12 +121,20 @@
     """Resolves the return type of a unary operation."""
     raise NotImplementedError('subclasses must implement')
 
-  def res_binop(self, ns, types_ns, node, left, right):
+  def res_unop(self, ns, types_ns, node, opnd):
     """Resolves the return type of a unary operation."""
     raise NotImplementedError('subclasses must implement')
 
+  def res_binop(self, ns, types_ns, node, left, right):
+    """Resolves the return type of a binary operation."""
+    raise NotImplementedError('subclasses must implement')
 
-class _SymbolTable(object):
+  def res_list_literal(self, ns, elt_types):
+    """Resolves the type of a list literal from its elements."""
+    raise NotImplementedError('subclasses must implement')
+
+
+class _TypeMap(object):
   """Abstraction for the state of the CFG walk for type inference.
 
   This is a value type. Only implements the strictly necessary operators.
@@ -135,7 +146,7 @@
 
   def __init__(self, init_from=None):
     if init_from:
-      assert isinstance(init_from, _SymbolTable)
+      assert isinstance(init_from, _TypeMap)
       self.types = {
           s: set(other_types) for s, other_types in init_from.types.items()
       }
@@ -152,8 +163,8 @@
     return not self.__eq__(other)
 
   def __or__(self, other):
-    assert isinstance(other, _SymbolTable)
-    result = _SymbolTable(self)
+    assert isinstance(other, _TypeMap)
+    result = _TypeMap(self)
     for s, other_types in other.types.items():
       if s not in result.types:
         self_types = set()
@@ -192,13 +203,22 @@
     print(a)  # a = int; side effect of f() accounted for
   """
 
-  def __init__(self, resolver, scope, namespace, closure_types, types_in):
+  def __init__(self,
+               resolver: Resolver,
+               scope: activity.Scope,
+               namespace: Dict[qual_names.QN, Any],
+               closure_types: Dict[qual_names.QN, Set[Any]],
+               types_in: _TypeMap):
     self.resolver = resolver
     self.scope = scope
     self.namespace = namespace
     self.closure_types = closure_types
     self.types_in = types_in
     self.new_symbols = {}
+
+    # rvalue type. This property is set when encountering an assign operation,
+    # so that visiting nodes with Store ctx (typically found on left side of
+    # assignments) can infer the type they should receive.
     self.rtype = None
 
   def visit(self, node):
@@ -221,36 +241,36 @@
       self._check_set(types)
     return types
 
-  def visit_Tuple(self, node):
-    if isinstance(node.ctx, gast.Load):
-      for elt in node.elts:
-        self.visit(elt)
-      # TODO(mdan): Parameterize it.
-      return {Tuple}
-
+  def _apply_unpacking(self, node):
     assert isinstance(node.ctx, gast.Store)
-
     if self.rtype is not None:
       original_stype = self.rtype
       # TODO(mdan): Find a better way to express unpacking.
       i_type = self.resolver.res_value(self.namespace, 0)
       for i, elt in enumerate(node.elts):
-        self.rtype = self.resolver.res_subscript(
+        self.rtype = self.resolver.res_slice(
             self.namespace, self.types_in.types, i, original_stype, i_type)
         self.visit(elt)
       self.rtype = original_stype
       return original_stype
-
     return None
 
+  def visit_Tuple(self, node):
+    if isinstance(node.ctx, gast.Load):
+      elt_types = ()
+      for elt in node.elts:
+        types_ = self.visit(elt)
+        if types_ is None:
+          return None
+        elt_types += (types_,)
+      return set(itertools.product(*elt_types))
+    return self._apply_unpacking(node)
+
   def visit_List(self, node):
     if isinstance(node.ctx, gast.Load):
-      el_types = []
-      for elt in node.elts:
-        el_types.append(self.visit(elt))
-      return {list}
-
-    raise NotImplementedError('list unpacking')
+      elt_types = tuple(self.visit(elt) for elt in node.elts)
+      return self.resolver.res_list_literal(self.namespace, elt_types)
+    return self._apply_unpacking(node)
 
   def visit_Set(self, node):
     raise NotImplementedError()
@@ -442,7 +462,7 @@
     if val_types is None or slice_types is None:
       return None
 
-    types = self.resolver.res_subscript(
+    types = self.resolver.res_slice(
         self.namespace, self.types_in.types, node, val_types, slice_types)
 
     if __debug__:
@@ -480,6 +500,20 @@
 
     return types
 
+  def visit_UnaryOp(self, node):
+    opnd_types = self.visit(node.operand)
+
+    if opnd_types is None:
+      return None
+
+    types = self.resolver.res_unop(
+        self.namespace, self.types_in.types, node, opnd_types)
+
+    if __debug__:
+      self._check_set(types)
+
+    return types
+
 
 class Analyzer(cfg.GraphVisitor):
   """CFG visitor that propagates type information across statements."""
@@ -504,13 +538,13 @@
         n: t for n, t in closure_types.items() if n not in scope.bound
     }
     if context_types:
-      self.context_types = _SymbolTable()
+      self.context_types = _TypeMap()
       self.context_types.types = context_types
     else:
       self.context_types = None
 
   def init_state(self, _):
-    return _SymbolTable()
+    return _TypeMap()
 
   def _update_closure_types(self, ast_node, types):
     existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)
@@ -528,13 +562,13 @@
   def visit_node(self, node):
     prev_types_out = self.out[node]
 
-    types_in = _SymbolTable()
+    types_in = _TypeMap()
     for n in node.prev:
       types_in |= self.out[n]
     if (self.context_types is not None) and (node is self.graph.entry):
       types_in |= self.context_types
 
-    types_out = _SymbolTable(types_in)
+    types_out = _TypeMap(types_in)
     ast_node = node.ast_node
 
     inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
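
Two behavioral notes on the changes above: a Load-context tuple literal now
yields a set of candidate tuple types (the cross product of its elements' type
sets) instead of the bare `Tuple` marker, and list literals are delegated to
the new `res_list_literal` hook. The snippet below only illustrates the
cross-product step with hypothetical element types; it is not the analyzer
itself:

    import itertools

    # Candidate type sets inferred for the elements of a tuple literal (a, b).
    elt_types = ({int}, {float, str})

    tuple_types = set(itertools.product(*elt_types))
    print(tuple_types)  # {(int, float), (int, str)} -- one candidate per combination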
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
index 5648f8d..861e62b 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, List
 
 from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import cfg
@@ -171,7 +171,7 @@
     node, _ = tr.transform(test_fn, None)
     fn_body = node.body
 
-    self.assertTypes(fn_body[0].body[0].value, Tuple)
+    self.assertTypes(fn_body[0].body[0].value, (('x_type', 'y_type'),))
     self.assertTypes(fn_body[0].body[0].value.elts[0], 'x_type')
     self.assertTypes(fn_body[0].body[0].value.elts[1], 'y_type')
 
@@ -656,7 +656,7 @@
       def res_value(self, ns, value):
         return {int}
 
-      def res_subscript(self, ns, types_ns, node, value, slice_):
+      def res_slice(self, ns, types_ns, node, value, slice_):
         test_self.assertSetEqual(value, {list})
         test_self.assertSetEqual(slice_, {int})
         return {str}
@@ -683,7 +683,7 @@
       def res_value(self, ns, value):
         return {int}
 
-      def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
         test_self.assertIn(node_or_slice, (0, 1))
         test_self.assertSetEqual(value, {list})
         test_self.assertSetEqual(slice_, {int})
@@ -699,7 +699,7 @@
     node, _ = TestTranspiler(Resolver).transform(test_fn, None)
     fn_body = node.body
 
-    self.assertTypes(fn_body[1].value, Tuple)
+    self.assertTypes(fn_body[1].value, ((float, str),))
     self.assertTypes(fn_body[1].value.elts[0], float)
     self.assertTypes(fn_body[1].value.elts[1], str)
 
@@ -751,6 +751,196 @@
     self.assertTypes(fn_body[0].value.left, list)
     self.assertTypes(fn_body[0].value.right, list)
 
+  def test_unop(self):
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {list}
+
+      def res_unop(self, ns, types_ns, node, opnd):
+        return {float}
+
+    def test_fn(a):
+      return -a
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[0].value, float)
+    self.assertTypes(fn_body[0].value.operand, list)
+
+  def test_tuple_literal(self):
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {int}
+
+    def test_fn(a, b):
+      return a, b
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[0].value, ((int, int),))
+    self.assertTypes(fn_body[0].value.elts[0], int)
+    self.assertTypes(fn_body[0].value.elts[1], int)
+
+  def test_list_literal(self):
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {int}
+
+      def res_list_literal(self, ns, elt_types):
+        all_types = set()
+        for s in elt_types:
+          all_types |= s
+        return {List[t] for t in all_types}
+
+    def test_fn(a, b):
+      return [a, b]
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[0].value, List[int])
+    self.assertTypes(fn_body[0].value.elts[0], int)
+    self.assertTypes(fn_body[0].value.elts[1], int)
+
+  def test_tuple_unpacking_syntactic(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        if name == qual_names.QN('a'):
+          return {int}
+        else:
+          return {float}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a, b):
+      c, d = a, b
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
+  def test_tuple_unpacking_operational(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {(int, float)}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a):
+      c, d = a
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
+  def test_list_expansion_syntactic(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        if name == qual_names.QN('a'):
+          return {int}
+        else:
+          return {float}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a, b):
+      [c, d] = a, b
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    # TODO(mdan): Whether it's List or Tuple might be open for interpretation.
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
+  def test_list_expansion_operational(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        if name == qual_names.QN('a'):
+          return {int}
+        else:
+          return {float}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a, b):
+      [c, d] = a, b
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    # TODO(mdan): Whether it's List or Tuple might be open for interpretation.
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
 
 if __name__ == '__main__':
   test.main()
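
The `res_slice` hooks in the new tests all follow the same unpacking pattern:
given the candidate tuple types of the right-hand side and an element index,
they select that index's component from each candidate. A tiny standalone
illustration of that selection step, with hypothetical types:

    value = {(int, float)}           # candidate types of the unpacked value
    node_or_slice = 1                # position being assigned during unpacking
    elem_types = {t[node_or_slice] for t in value}
    print(elem_types)                # {<class 'float'>}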
diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc
index d8399a4..3063813 100644
--- a/tensorflow/python/client/tf_session_wrapper.cc
+++ b/tensorflow/python/client/tf_session_wrapper.cc
@@ -711,6 +711,18 @@
       },
       py::return_value_policy::reference);
 
+  m.def(
+      "TF_LoadPluggableDeviceLibrary",
+      [](const char* library_filename) {
+        tensorflow::Safe_TF_StatusPtr status =
+            tensorflow::make_safe(TF_NewStatus());
+        auto output =
+            TF_LoadPluggableDeviceLibrary(library_filename, status.get());
+        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+        return output;
+      },
+      py::return_value_policy::reference);
+
   m.def("TF_GetOpList", [](TF_Library* lib_handle) {
     TF_Buffer output_buffer = TF_GetOpList(lib_handle);
     return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
@@ -720,6 +732,11 @@
 
   m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
         py::call_guard<py::gil_scoped_release>());
+
+  m.def("TF_PluggableDeviceLibraryHandle",
+        TF_DeletePluggableDeviceLibraryHandle,
+        py::call_guard<py::gil_scoped_release>());
+
   m.def("TF_AddControlInput", TF_AddControlInput);
   m.def(
       "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index bc0ae54..51e6e9f 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 12, 7)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 12, 22)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
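
For reference, the horizon is consumed through `forward_compatible()` and can
be pinned in tests with the `forward_compatibility_horizon()` context manager;
the dates below are illustrative:

    from tensorflow.python.compat import compat

    # Gate newer graph-construction behavior on the moving horizon.
    if compat.forward_compatible(2020, 12, 22):
        pass  # safe to emit the newer op / attribute
    else:
        pass  # fall back to the older graph construction

    # Tests can pin the horizon explicitly.
    with compat.forward_compatibility_horizon(2020, 12, 23):
        assert compat.forward_compatible(2020, 12, 22)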
 
diff --git a/tensorflow/python/compiler/mlir/BUILD b/tensorflow/python/compiler/mlir/BUILD
index 7e19379..1e4316d 100644
--- a/tensorflow/python/compiler/mlir/BUILD
+++ b/tensorflow/python/compiler/mlir/BUILD
@@ -11,7 +11,7 @@
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:pywrap_mlir",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD
index 6b3f32c..9237ad1 100644
--- a/tensorflow/python/compiler/tensorrt/BUILD
+++ b/tensorflow/python/compiler/tensorrt/BUILD
@@ -97,7 +97,6 @@
     tags = [
         "no_cuda_on_cpu_tap",
         "no_pip",
-        "no_rocm",
         "no_windows",
         "nomac",
     ],
diff --git a/tensorflow/python/compiler/xla/jit_compile_test.py b/tensorflow/python/compiler/xla/jit_compile_test.py
index 7b71573..9ec0ffe 100644
--- a/tensorflow/python/compiler/xla/jit_compile_test.py
+++ b/tensorflow/python/compiler/xla/jit_compile_test.py
@@ -37,20 +37,14 @@
 
       xla_func = def_function.function(fn, jit_compile=True)
       inputs = array_ops.placeholder(dtypes.float32, [5])
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        x = xla_func(inputs, 1)
-        with session.Session(graph=g) as sess:
-          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
-          self.assertTrue(x.graph.as_graph_def().library.function[0]
-                          .attr["_XlaMustCompile"].b)
-          self.assertAllClose([2, 3, 3, 4, 4], y)
+      x = xla_func(inputs, 1)
+      with session.Session(graph=g) as sess:
+        y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+        self.assertTrue(x.graph.as_graph_def().library.function[0]
+                        .attr["_XlaMustCompile"].b)
+        self.assertAllClose([2, 3, 3, 4, 4], y)
 
   def testDerivative(self):
-    # XLA support is not yet enabled for TF ROCm
-    if test.is_built_with_rocm():
-      return
-
     def fn(x, a):
       return 2 * x + a
 
@@ -81,14 +75,12 @@
 
       xla_func = def_function.function(fn, jit_compile=True)
       inputs = array_ops.placeholder(dtypes.int32, [5])
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        x = xla_func(inputs, 1)
-        with session.Session(graph=g) as sess:
-          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
-          self.assertTrue(x.graph.as_graph_def().library.function[0]
-                          .attr["_XlaMustCompile"].b)
-          self.assertAllClose([2, 3, 3, 4, 4], y)
+      x = xla_func(inputs, 1)
+      with session.Session(graph=g) as sess:
+        y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+        self.assertTrue(x.graph.as_graph_def().library.function[0]
+                        .attr["_XlaMustCompile"].b)
+        self.assertAllClose([2, 3, 3, 4, 4], y)
 
   # Checking that we crash on an unsupported operation lets us test that the XLA
   # compiler was actually invoked.
@@ -101,12 +93,10 @@
       xla_func = def_function.function(fn, jit_compile=True)
       inputs = array_ops.placeholder(dtypes.float32, [5])
       x = xla_func(inputs)
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                    "not compilable"):
-          with session.Session(graph=g) as sess:
-            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+      with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                  "not compilable"):
+        with session.Session(graph=g) as sess:
+          sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index d69bf6a..fae8051 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -56,9 +56,9 @@
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tf2",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/data/util:convert",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 0448dcc..46af724 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1784,6 +1784,12 @@
     ...     num_parallel_calls=tf.data.AUTOTUNE,
     ...     deterministic=False)
 
+    The order of elements yielded by this transformation is deterministic if
+    `deterministic=True`. If `map_func` contains stateful operations and
+    `num_parallel_calls > 1`, the order in which that state is accessed is
+    undefined, so the values of output elements may not be deterministic
+    regardless of the `deterministic` flag value.
+
     Args:
       map_func: A function mapping a dataset element to another dataset element.
       num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
@@ -1792,11 +1798,10 @@
         `tf.data.AUTOTUNE` is used, then the number of parallel
         calls is set dynamically based on available CPU.
       deterministic: (Optional.) A boolean controlling whether determinism
-        should be traded for performance by allowing elements to be produced out
+        should be traded for performance by allowing elements to be yielded out
         of order.  If `deterministic` is `None`, the
         `tf.data.Options.experimental_deterministic` dataset option (`True` by
-        default) is used to decide whether to produce elements
-        deterministically.
+        default) is used to decide whether to run deterministically.
 
     Returns:
       Dataset: A `Dataset`.
@@ -1925,8 +1930,7 @@
         should be traded for performance by allowing elements to be produced out
         of order.  If `deterministic` is `None`, the
         `tf.data.Options.experimental_deterministic` dataset option (`True` by
-        default) is used to decide whether to produce elements
-        deterministically.
+        default) is used to decide whether to run deterministically.
 
     Returns:
       Dataset: A `Dataset`.
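The paragraph added to the `map` docstring describes how `deterministic` interacts with a stateful `map_func`. A minimal sketch of that interaction (assuming TensorFlow 2.x; the counter variable is purely illustrative):

import tensorflow as tf

counter = tf.Variable(0, dtype=tf.int64)

def stateful_fn(x):
  counter.assign_add(1)  # stateful: the order of these updates is undefined
  return x + counter     # so the produced values can differ between runs

ds = tf.data.Dataset.range(5).map(
    stateful_fn,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=True)  # element order is still fixed
print(list(ds.as_numpy_iterator()))
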
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 9f9bcde..69f4da5 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -37,8 +37,8 @@
 
 import six as _six
 
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.framework import sparse_tensor as _sparse_tensor
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util.compat import collections_abc as _collections_abc
 
 
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 36fdb20..b4a6ce6 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -734,7 +734,10 @@
     size = "medium",
     srcs = ["lib/check_numerics_callback_test.py"],
     python_version = "PY3",
-    tags = ["no_windows"],
+    tags = [
+        "no_mac",  # TODO(b/175322370): Detected Infinity or NaN in output 0 of graph op "RealDiv"
+        "no_windows",
+    ],
     deps = [
         ":check_numerics_callback",
         "//tensorflow/python:framework_test_lib",
@@ -1116,7 +1119,6 @@
     srcs = ["cli/debugger_cli_common_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    tags = ["no_rocm"],
     deps = [
         ":debugger_cli_common",
         "//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 6fb015e..0df1dce 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -48,10 +48,10 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:executor",
+        "//tensorflow/python/util:tf_export",
         "//tensorflow/tools/docs:doc_controls",
         "@enum34_archive//:enum",
         "@six_archive//:six",
@@ -162,7 +162,6 @@
     srcs = ["distribute_lib_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    tags = ["no_rocm"],
     deps = [
         ":combinations",
         ":distribute_lib",
@@ -255,7 +254,6 @@
         "//tensorflow/python:pywrap_tfe",
         "//tensorflow/python:summary_ops_v2",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
@@ -263,6 +261,7 @@
         "//tensorflow/python/autograph/impl",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -273,6 +272,7 @@
         ":device_util",
         ":distribute_lib",
         ":reduce_util",
+        ":sharded_variable",
         ":shared_variable_creator",
         ":tpu_values",
         ":values",
@@ -286,7 +286,6 @@
         "//tensorflow/python:pywrap_tfe",
         "//tensorflow/python:summary_ops_v2",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
@@ -294,6 +293,7 @@
         "//tensorflow/python/autograph/impl",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -345,12 +345,12 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -408,11 +408,11 @@
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -715,7 +715,6 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:type_spec",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
@@ -725,6 +724,7 @@
         "//tensorflow/python/training/saving:saveable_object_util",
         "//tensorflow/python/training/tracking:base",
         "//tensorflow/python/types",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -836,12 +836,12 @@
         ":tpu_strategy",
         "//tensorflow/python:platform",
         "//tensorflow/python:tf2",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:remote",
         "//tensorflow/python/tpu:tpu_lib",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -1087,13 +1087,13 @@
         "//tensorflow/python:partitioned_variables",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:type_spec",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/saved_model:save_context",
         "//tensorflow/python/training/saving:saveable_object_util",
         "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -1119,6 +1119,7 @@
         "//tensorflow/python/compat:v2_compat",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/module",
+        "//tensorflow/python/saved_model:load",
         "//tensorflow/python/saved_model:loader",
         "//tensorflow/python/saved_model:save",
         "//tensorflow/python/saved_model:signature_constants",
@@ -1168,6 +1169,7 @@
         "multi_and_single_gpu",
         "no_cuda_asan",  # times out
         "no_rocm",
+        "noasan",  # b/175816710
         "notsan",  # b/168645872
     ],
     tpu_tags = [
@@ -1658,6 +1660,11 @@
     srcs = ["multi_process_runner_test.py"],
     python_version = "PY3",
     shard_count = 12,
+    tags = [
+        "noasan",
+        "nomsan",
+        "notsan",
+    ],  # b/175904958
     deps = [
         ":multi_process_runner",
         ":multi_worker_test_base",
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index c0e67a4..159fc33 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -994,7 +994,7 @@
           "currently.")
     self._strategy = strategy
     self._strategy.extended._used_with_coordinator = True
-    self.cluster = Cluster(strategy)
+    self._cluster = Cluster(strategy)
 
   @property
   def strategy(self):
@@ -1067,7 +1067,7 @@
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
     with self.strategy.scope():
-      return self.cluster.schedule(fn, args=args, kwargs=kwargs)
+      return self._cluster.schedule(fn, args=args, kwargs=kwargs)
 
   def join(self):
     """Blocks until all the scheduled functions have finished execution.
@@ -1088,7 +1088,7 @@
         previously scheduled function since the last time an error was thrown or
         since the beginning of the program.
     """
-    self.cluster.join()
+    self._cluster.join()
 
   def done(self):
     """Returns whether all the scheduled functions have finished execution.
@@ -1106,7 +1106,7 @@
         previously scheduled function since the last time an error was thrown or
         since the beginning of the program.
     """
-    return self.cluster.done()
+    return self._cluster.done()
 
   def create_per_worker_dataset(self, dataset_fn):
     """Create dataset on workers by calling `dataset_fn` on worker devices.
@@ -1168,7 +1168,7 @@
       iterators (that are on the workers).
     """
     input_workers = input_lib.InputWorkers([
-        (w.device_name, [w.device_name]) for w in self.cluster.workers
+        (w.device_name, [w.device_name]) for w in self._cluster.workers
     ])
 
     return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)
@@ -1191,7 +1191,7 @@
       objects.
     """
     results = []
-    for w in self.cluster.workers:
+    for w in self._cluster.workers:
       results.append(w._create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
     return PerWorkerValues(tuple(results))
 
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py
index 4c5a8e2..ea98181 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py
@@ -162,7 +162,7 @@
         if test_join:
           ps_coordinator.join()
         if test_schedule:
-          while ps_coordinator.cluster._closure_queue._error is None:
+          while ps_coordinator._cluster._closure_queue._error is None:
             time.sleep(1)
           ps_coordinator.schedule(worker_fn)
       except errors.UnavailableError:
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
index 0439863..819aec8 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
@@ -575,7 +575,7 @@
 
   def testDatasetsShuffledDifferently(self):
     # This test requires at least two workers in the cluster.
-    self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
+    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)
 
     random_seed.set_random_seed(None)
 
@@ -587,12 +587,12 @@
 
     # Get elements from the first two iterators.
     iterator_1 = distributed_iterator._values[0]
-    iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
+    iterator_1._rebuild_on(self.coordinator._cluster.workers[0])
     iterator_1 = iterator_1.fetch()
     elements_in_iterator_1 = [e.numpy() for e in iterator_1]
 
     iterator_2 = distributed_iterator._values[1]
-    iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
+    iterator_2._rebuild_on(self.coordinator._cluster.workers[1])
     iterator_2 = iterator_2.fetch()
     elements_in_iterator_2 = [e.numpy() for e in iterator_2]
 
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index 17472e0..0089184 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -175,7 +175,8 @@
     model.schedule_training_functions(4)
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 2 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._restart(downtime_secs=2, job="worker")
@@ -356,7 +357,8 @@
 
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
@@ -380,7 +382,8 @@
 
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py
index 8496f3c..147fc81 100644
--- a/tensorflow/python/distribute/integration_test/saved_model_test.py
+++ b/tensorflow/python/distribute/integration_test/saved_model_test.py
@@ -611,6 +611,26 @@
     # ShardedVariable loading only works in v1.
     self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
 
+    with self.assertRaisesWithLiteralMatch(
+        ValueError, "Loading `ShardedVariable` is not supported"):
+      with strategy.scope():
+        tf.saved_model.load(model_dir)
+
+    with self.assertRaisesWithLiteralMatch(
+        ValueError, "Loading `ShardedVariable` is not supported"):
+      tf.saved_model.load(model_dir)
+
+  def test_load_with_partitioner_raises_error(self):
+    model = self.Model()
+    model_dir = self.get_temp_dir()
+    tf.saved_model.save(model, model_dir)
+
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, tf1.fixed_size_partitioner(2))
+    with self.assertRaisesRegex(ValueError, "`variable_partitioner`"):
+      with strategy.scope():
+        tf.saved_model.load(model_dir)
+
 
 if __name__ == "__main__":
   # TODO(b/172304955): enable logical devices.
diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py
index 14fe8a4..12084fe 100644
--- a/tensorflow/python/distribute/multi_process_lib.py
+++ b/tensorflow/python/distribute/multi_process_lib.py
@@ -98,23 +98,27 @@
   """
   # TODO(b/150264776): This does not work with Windows. Find a solution.
   if sys.argv[0].endswith('.py'):
-    path = None
-    # If all we have is a python module path, we'll need to make a guess for
-    # the actual executable path.
-    if 'bazel-out' in sys.argv[0]:
-      # Guess the binary path under bazel. For target
-      # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
-      # argv[0] is in the form of
-      # /.../tensorflow/python/distribute/input_lib_test.py
-      # and the binary is
-      # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
-      org_tensorflow_base = sys.argv[0][:sys.argv[0].rfind('/org_tensorflow')]
-      binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
-      possible_path = os.path.join(org_tensorflow_base, 'org_tensorflow',
-                                   binary)
-      logging.info('Guessed test binary path: %s', possible_path)
-      if os.access(possible_path, os.X_OK):
-        path = possible_path
+    def guess_path(package_root):
+      # If all we have is a python module path, we'll need to make a guess for
+      # the actual executable path.
+      if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
+        # Guess the binary path under bazel. For target
+        # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
+        # argv[0] is in the form of
+        # /.../tensorflow/python/distribute/input_lib_test.py
+        # and the binary is
+        # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
+        package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
+        binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
+        possible_path = os.path.join(package_root_base, package_root,
+                                     binary)
+        logging.info('Guessed test binary path: %s', possible_path)
+        if os.access(possible_path, os.X_OK):
+          return possible_path
+        return None
+    path = guess_path('org_tensorflow')
+    if not path:
+      path = guess_path('org_keras')
     if path is None:
       logging.error(
           'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
diff --git a/tensorflow/python/distribute/packed_distributed_variable.py b/tensorflow/python/distribute/packed_distributed_variable.py
index 4c9433d..a158411 100644
--- a/tensorflow/python/distribute/packed_distributed_variable.py
+++ b/tensorflow/python/distribute/packed_distributed_variable.py
@@ -282,6 +282,10 @@
     with ops.device(self._device):
       return self._var.handle
 
+  def on_device_handle(self):
+    with ops.device(self._device):
+      return self._var.get_var_on_current_device().handle
+
   @property
   def op(self):
     with ops.device(self._device):
diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py
index 01a7c30..c3e1d3f 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_v2.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py
@@ -560,7 +560,13 @@
     name = kwargs.get("name", None)
     initial_value = kwargs.get("initial_value", None)
     if initial_value is None:
-      raise ValueError("initial_value must be specified.")
+      raise ValueError(
+          "It looks like you are using `ParameterServerStrategy` with a "
+          "`variable_partitioner`, and trying to create a variable without "
+          "specifying `initial_value`. This is not allowed. Please specify the "
+          "`initial_value`. This can also happen if you are trying to load a "
+          "saved_model within a `ParameterServerStrategy` scope. Loading a "
+          "saved_model with `variable_partitioner` is not supported.")
 
     # Two cases where initial_value can be a callable:
     #   1. initial_value is passed as a callable, e.g, an `initializer` class.
diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py
index 553d82e..5b56af7 100644
--- a/tensorflow/python/distribute/sharded_variable.py
+++ b/tensorflow/python/distribute/sharded_variable.py
@@ -28,6 +28,7 @@
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.saved_model import revived_types
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
@@ -500,3 +501,21 @@
   return embedding_ops.embedding_lookup(params.variables, ids,
                                         partition_strategy, name,
                                         validate_indices, max_norm)
+
+
+def _raise_when_load(_):
+  # We don't have serialization and deserialization mechanisms for
+  # `ShardedVariable` in 2.x style save/load yet.
+  raise ValueError('Loading `ShardedVariable` is not supported')
+
+
+revived_types.register_revived_type(
+    '_tf_distribute_sharded_variable',
+    lambda obj: isinstance(obj, ShardedVariable),
+    versions=[
+        revived_types.VersionedTypeRegistration(
+            object_factory=_raise_when_load,
+            version=0,
+            min_producer_version=0,
+            min_consumer_version=0)
+    ])
diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py
index a020a85..822012b 100644
--- a/tensorflow/python/distribute/sharded_variable_test.py
+++ b/tensorflow/python/distribute/sharded_variable_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import save
 from tensorflow.python.saved_model import signature_constants
@@ -300,6 +301,19 @@
     # Continue using root.train for training
     self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
 
+  def test_load_raises_error(self):
+    root = tracking.AutoTrackable()
+    v1 = variables_lib.Variable([3.])
+    v2 = variables_lib.Variable([2.])
+    root.v = sharded_variable.ShardedVariable([v1, v2])
+
+    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
+    save.save(root, save_dir)
+
+    with self.assertRaisesWithLiteralMatch(
+        ValueError, 'Loading `ShardedVariable` is not supported'):
+      load.load(save_dir)
+
   def test_validation_errors(self):
     with self.assertRaisesRegex(ValueError, 'Expected a list of '):
       sharded_variable.ShardedVariable(
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index 239882c..9f5fdb0 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -167,6 +167,19 @@
 @parameterized.named_parameters([("PackedVar", True), ("", False)])
 class TPUStrategyTest(test.TestCase, parameterized.TestCase):
 
+  def test_handle_in_cross_replica_context(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+    with strategy.scope():
+      v = variables.Variable(1.0)
+
+    @def_function.function
+    def func():
+      self.assertEndsWith(v.handle.device, "device:TPU:0")
+      return v + 1.0
+
+    ret = func()
+    self.assertAllEqual(ret, 2.0)
+
   def test_function_compile_with_xla(self, enable_packed_var):
     strategy = get_tpu_strategy(enable_packed_var)
     with strategy.scope():
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index 3094f74..dbe1e1f 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -115,7 +115,11 @@
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
-      return self._get_on_device_or_primary().handle
+      var = self._get_on_device_or_primary()
+      if isinstance(var, packed.PackedVarAndDevice):
+        return var.on_device_handle()
+      else:
+        return var.handle
     else:
       is_packed = self._packed_var is not None
       val = self._values
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index e48a02a..ff77ab4 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -64,16 +64,16 @@
         "//tensorflow/core/platform:types",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/util:abstract_stack_trace",
-        "//tensorflow/python:cpp_python_util",
         "//tensorflow/python:ndarray_tensor",
         "//tensorflow/python:ndarray_tensor_bridge",
         "//tensorflow/python:numpy_lib",
         "//tensorflow/python:py_exception_registry",
-        "//tensorflow/python:py_seq_tensor",
-        "//tensorflow/python:py_util",
-        "//tensorflow/python:safe_ptr",
-        "//tensorflow/python:safe_pyobject_ptr",
-        "//tensorflow/python:stack_trace",
+        "//tensorflow/python/lib/core:py_seq_tensor",
+        "//tensorflow/python/lib/core:py_util",
+        "//tensorflow/python/lib/core:safe_ptr",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//tensorflow/python/util:cpp_python_util",
+        "//tensorflow/python/util:stack_trace",
         "//third_party/py/numpy:headers",
         "//third_party/python_runtime:headers",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -232,10 +232,10 @@
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_experimental",
         "//tensorflow/c/eager:custom_device_testutil",
-        "//tensorflow/python:cpp_python_util",
         "//tensorflow/python:pybind11_lib",
-        "//tensorflow/python:pybind11_status",
-        "//tensorflow/python:safe_ptr",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:safe_ptr",
+        "//tensorflow/python/util:cpp_python_util",
         "//third_party/python_runtime:headers",
         "@pybind11",
     ],
@@ -1022,6 +1022,7 @@
     shard_count = 2,
     tags = [
         "no_oss",  # This test launches local server.
+        "notap",  # b/175813762
         "optonly",  # times out
     ],
     deps = [
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index be121bf..3d6faf2 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -28,7 +28,6 @@
 import six
 
 from tensorflow.python import pywrap_tfe
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
@@ -49,6 +48,7 @@
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
@@ -1344,7 +1344,10 @@
                                  parallel_iterations=parallel_iterations)
     new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
     if output is None:
-      output = array_ops.zeros(new_shape)
+      # Note that this block is returning zeros when it could use `None` to
+      # represent unconnected gradients. This is to maintain compatibility with
+      # the previous behavior, which ignored `unconnected_gradients`.
+      output = array_ops.zeros(new_shape, target.dtype)
       if rewrap_as_ndarray:
         output = np_arrays.tensor_to_ndarray(output)
       return output
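The change above makes the zero Jacobian that is filled in for unconnected outputs use the target's dtype rather than the float32 default. A minimal sketch of the behavior the new test below checks (assuming TensorFlow 2.x):

import tensorflow as tf

@tf.function
def f(x):
  del x
  return tf.constant([[1.]], dtype=tf.float64)  # output unconnected to x

x = tf.constant([[2.]], dtype=tf.float64)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
jac = tape.batch_jacobian(y, x)
print(jac.dtype)  # float64: the filled-in zeros now match the target dtype
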
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 0063b7f..bdc2bae 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1961,6 +1961,31 @@
       f = def_function.function(f)
     self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
 
+  @parameterized.parameters((True,), (False))
+  def test_zeros_type_correct(self, use_pfor):
+    for dtype in [dtypes.float32, dtypes.float64]:
+      @def_function.function
+      def f(x):
+        del x
+        return constant_op.constant([[1.]], dtype=dtype)  # pylint: disable=cell-var-from-loop
+
+      with backprop.GradientTape(persistent=True) as tape:
+        x = constant_op.constant([[2.]], dtype=dtype)
+        tape.watch(x)
+        y = f(x)
+      jac = tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor)
+      self.assertEqual(dtype, jac.dtype)
+      self.assertAllClose([[[0.]]], jac)
+
+      with backprop.GradientTape(persistent=True) as tape:
+        x = constant_op.constant([[2.]], dtype=dtype)
+        tape.watch(x)
+        y = f(x)
+      jac = tape.batch_jacobian(y, x, unconnected_gradients='zero',
+                                experimental_use_pfor=use_pfor)
+      self.assertEqual(dtype, jac.dtype)
+      self.assertAllClose([[[0.]]], jac)
+
 
 class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py
index 573c8bc2..0a47209 100644
--- a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py
+++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py
@@ -266,8 +266,8 @@
 
   def _report(self, label, start, num_iters, device, batch_size, data_format,
               num_replicas=1):
-    resnet50_test_util.report(self, label, start, num_iters, device,
-                              batch_size, data_format, num_replicas=1)
+    resnet50_test_util.report(self, label, start, num_iters, device, batch_size,
+                              data_format, num_replicas)
 
   def _train_batch_sizes(self):
     """Choose batch sizes based on GPU capability."""
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 046a09f..5e004ad 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -1245,12 +1245,17 @@
   def invoking_op_callbacks(self, value):
     self._thread_local_data.invoking_op_callbacks = value
 
-  def _initialize_physical_devices(self):
-    """Get local devices visible to the system."""
+  def _initialize_physical_devices(self, reinitialize=False):
+    """Gets local devices visible to the system.
+
+    Args:
+      reinitialize: If True, reinitializes self._physical_devices so that
+        dynamically registered devices are also visible to the Python front end.
+    """
     # We lazily initialize self._physical_devices since we do not want to do
     # this in the constructor, as the backend may not be initialized yet.
     with self._device_lock:
-      if self._physical_devices is not None:
+      if not reinitialize and self._physical_devices is not None:
         return
 
       devs = pywrap_tfe.TF_ListPhysicalDevices()
@@ -1269,6 +1274,12 @@
     # Import device settings that may have been passed into the constructor
     self._import_config()
 
+  def reinitialize_physical_devices(self):
+    """Gets local devices visible to the system."""
+    # Reinitialize the physical device list after registering
+    # the pluggable device.
+    self._initialize_physical_devices(True)
+
   def list_physical_devices(self, device_type=None):
     """List local devices visible to the system.
 
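`_initialize_physical_devices` can now rebuild the cached device list, so devices registered after startup (e.g. pluggable devices) become visible to Python. A short sketch of the public surface that reads this cache (assuming TensorFlow 2.x):

import tensorflow as tf

print(tf.config.list_physical_devices())       # every device type
print(tf.config.list_physical_devices('CPU'))  # filtered by device_type
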
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index ec33d31..44336a7 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -51,84 +51,94 @@
 
 FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
 FREQUENT_TRACING_WARNING_THRESHOLD = 5
+FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
 
 
-class _CallCounter(object):
+class _FrequentTracingDetector(object):
   """Class keeping track of how many recent calls triggered tracing."""
 
-  __slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
+  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
 
-  def __init__(self, max_call_history):
-    self._max_call_history = max_call_history
+  def __init__(self):
     self._calls_per_tracings = []
-    self.call_count = 0
+    self._total_warning_count = 0
+    self._call_count = 0
 
-  def called_with_tracing(self):
-    self.call_count += 1
+  def called_with_tracing(self, function_name, omit_warning):
+    """Updates the list of most recent calls' tracing information.
+
+    Warns the user when recent calls caused retracing too often.
+
+    Args:
+      function_name: the Python function being traced.
+      omit_warning: If `True`, this call will not warn the user even if
+        retracing happens too often.
+    """
+    self._call_count += 1
     self._calls_per_tracings.append(1)
 
     while self._calls_per_tracings:
-      if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
-        self.call_count -= self._calls_per_tracings.pop(0)
+      if (self._call_count - self._calls_per_tracings[0] >
+          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
+        self._call_count -= self._calls_per_tracings.pop(0)
       else:
         break
 
+    if (omit_warning or self._total_warning_count >=
+        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
+      return
+    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
+      self._total_warning_count += 1
+      logging.warning(
+          "{} out of the last {} calls to {} triggered tf.function "
+          "retracing. Tracing is expensive and the excessive number of "
+          "tracings could be due to (1) creating @tf.function repeatedly in "
+          "a loop, (2) passing tensors with different shapes, (3) passing "
+          "Python objects instead of tensors. For (1), please define your "
+          "@tf.function outside of the loop. For (2), @tf.function has "
+          "experimental_relax_shapes=True option that relaxes argument "
+          "shapes that can avoid unnecessary retracing. For (3), please "
+          "refer to "
+          "https://www.tensorflow.org/guide/function#controlling_retracing"
+          " and https://www.tensorflow.org/api_docs/python/tf/function for "
+          " more details.".format(
+              len(self._calls_per_tracings), self._call_count, function_name))
+
   def called_without_tracing(self):
     # We don't count tracing when users load a concrete function directly or
     # call get_concrete_function, so the first call may not be a tracing call.
     if not self._calls_per_tracings:
       self._calls_per_tracings = [0]
     self._calls_per_tracings[-1] += 1
-    self.call_count += 1
-
-  def get_tracing_count(self):
-    return len(self._calls_per_tracings)
+    self._call_count += 1
 
 
-class _FrequentTracingDetector(object):
-  """Class for frequent retracing detection and warning."""
+class _FrequentTracingDetectorManager(object):
+  """Class for the management of all _FrequentTracingDetector objects."""
 
-  __slots__ = ["_counters", "_lock"]
+  __slots__ = ["_detectors", "_lock"]
 
   def __init__(self):
-    self._counters = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
+    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
     self._lock = threading.Lock()
 
-  def _get_counter(self, key):
-    if key not in self._counters:
-      self._counters[key] = _CallCounter(
-          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
-    return self._counters[key]
+  def _get_detector(self, key):
+    if key not in self._detectors:
+      self._detectors[key] = _FrequentTracingDetector()
+    return self._detectors[key]
 
   def called_without_tracing(self, key):
     with self._lock:
-      counter = self._get_counter(key)
-      counter.called_without_tracing()
+      detector = self._get_detector(key)
+      detector.called_without_tracing()
 
   def called_with_tracing(self, key, function_name, omit_warning):
     with self._lock:
-      counter = self._get_counter(key)
-      counter.called_with_tracing()
-      if omit_warning:
-        return
-      if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
-        logging.warning(
-            "{} out of the last {} calls to {} triggered tf.function "
-            "retracing. Tracing is expensive and the excessive number of "
-            "tracings could be due to (1) creating @tf.function repeatedly in "
-            "a loop, (2) passing tensors with different shapes, (3) passing "
-            "Python objects instead of tensors. For (1), please define your "
-            "@tf.function outside of the loop. For (2), @tf.function has "
-            "experimental_relax_shapes=True option that relaxes argument "
-            "shapes that can avoid unnecessary retracing. For (3), please "
-            "refer to "
-            "https://www.tensorflow.org/guide/function#controlling_retracing"
-            " and https://www.tensorflow.org/api_docs/python/tf/function for "
-            " more details.".format(counter.get_tracing_count(),
-                                    counter.call_count, function_name))
+      detector = self._get_detector(key)
+      detector.called_with_tracing(function_name, omit_warning)
 
 
-_frequent_tracing_detector = _FrequentTracingDetector()
+_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
 
 
 class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
@@ -794,10 +804,10 @@
 
     if context.executing_eagerly():
       if without_tracing:
-        _frequent_tracing_detector.called_without_tracing(
+        _frequent_tracing_detector_manager.called_without_tracing(
             self._key_for_call_stats)
       else:
-        _frequent_tracing_detector.called_with_tracing(
+        _frequent_tracing_detector_manager.called_with_tracing(
             self._key_for_call_stats, self._python_function,
             self._omit_frequent_tracing_warning)
 
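With this refactoring the retracing warning is emitted by `_FrequentTracingDetector.called_with_tracing` and capped at FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR (2) warnings per function. A minimal sketch of what triggers it (assuming TensorFlow 2.x):

import tensorflow as tf

@tf.function
def my_func(x):
  return x

# Each distinct Python int is a new input signature, so every call retraces;
# once enough recent calls have retraced, the "Tracing is expensive" warning
# is logged, and with this change at most twice for my_func.
for i in range(10):
  my_func(i)

# Passing tensors reuses a single trace and avoids the warning entirely.
for i in range(10):
  my_func(tf.constant(i))
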
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 03970d8..2b1dad4 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -956,6 +956,18 @@
       self.assertLen(logs.output, 1)
       self.assertIn('Tracing is expensive', logs.output[0])
 
+  def test_retracing_warning_limits(self):
+
+    @def_function.function
+    def my_func(x):
+      return x
+
+    with self.assertLogs(level='WARN') as logs:
+      for i in range(10):
+        my_func(i)
+
+      self.assertLen(logs.output, 2)
+
   def test_experimental_get_tracing_count_function(self):
 
     @def_function.function
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 281ff14..ead508c 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -153,6 +153,9 @@
   @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns'
                                  ' wrong status type')
   def testUnsupportedOps(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
+
     with ops.device('device:{}:0'.format(self.device)):
 
       def fn(x):
@@ -167,6 +170,136 @@
                                   'not compilable'):
         xla_func(inputs)
 
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonLocationInMetadata(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def fn(x, y):
+        return x + y
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      self.assertIn('def_function_xla_jit_test',
+                    fn.experimental_get_compiler_ir(inputs, inputs)())
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonLocationNestedInMetadata(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def f(x, y):
+        return x + y
+
+      @def_function.function(jit_compile=True)
+      def g(x, y):
+        return f(x, y)
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      self.assertIn('def_function_xla_jit_test',
+                    g.experimental_get_compiler_ir(inputs, inputs)())
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonStackTrace(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
+
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def fn(x):
+        return array_ops.unique(x).y  # COMMENT2
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT2'):
+        fn(inputs)
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonStackTraceControlFlow(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
+
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        x = ops.convert_to_tensor(x)
+
+        def body(i, a):
+          return i + 1 + array_ops.unique([i]).y[0], \
+              control_flow_ops.cond(i > 2, lambda: a + (x**2), lambda: a + 3)
+
+        return control_flow_ops.while_loop(
+            lambda i, *_: i < 10,
+            body, (constant_op.constant(0), constant_op.constant(3.)),
+            maximum_iterations=10)[1]
+
+      with self.assertRaisesRegex(errors.InvalidArgumentError, r'\.y\[0\]'):
+        f(constant_op.constant(100.0))
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonStackTraceUncompiledWithinCompiled(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
+
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function
+      def fn(x):
+        return array_ops.unique(x).y  # COMMENT3
+
+      @def_function.function(jit_compile=True)
+      def outer(x):
+        return fn(x)
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT3'):
+        outer(inputs)
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonStackTraceCompiledWithinUncompiled(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
+
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def fn(x):
+        return array_ops.unique(x).y  # COMMENT1
+
+      @def_function.function
+      def outer(x):
+        return fn(x)
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT1'):
+        outer(inputs)
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
+                                 'support stack traces')
+  def testPythonStackTraceCompiledWithinCompiled(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
+
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def fn(x):
+        return array_ops.unique(x).y  # COMMENT4
+
+      @def_function.function
+      def outer(x):
+        return fn(x)
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT4'):
+        outer(inputs)
+
   def testFunctionGradient(self):
     with ops.device('device:{}:0'.format(self.device)):
       v = resource_variable_ops.ResourceVariable(2.0)
@@ -243,6 +376,8 @@
   @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns '
                                  ' wrong status type')
   def testMethodCompilationUnsupportedFunc(self):
+    if 'tpu' in self.device.lower():
+      self.skipTest('XLA TPU supports tf.unique')
 
     with ops.device('device:{}:0'.format(self.device)):
 
@@ -727,6 +862,15 @@
       self.assertEqual(out.shape[0], 50)
       self.assertEqual(out.shape[1], 2)
 
+  def testTfAssert(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        control_flow_ops.Assert(x == 1, ['Wrong value'])
+
+      f(constant_op.constant(1))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
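The new location/stack-trace tests above rely on `experimental_get_compiler_ir`, which returns the XLA program compiled for a concrete input signature. A minimal sketch (assuming TensorFlow 2.x with XLA available on the local device):

import tensorflow as tf

@tf.function(jit_compile=True)
def fn(x, y):
  return x + y

x = tf.constant([1, 2, 2, 3, 3])
# The returned callable produces the compiler IR (HLO text by default); the
# tests above assert that it now carries the Python source location of fn.
print(fn.experimental_get_compiler_ir(x, x)())
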
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index 2660472..47c6f6b 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -50,12 +50,9 @@
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
 
-
 _X11_35_DERIVATIVES = [
-    1.1 ** 3.5,
-    3.5 * 1.1 ** 2.5,
-    3.5 * 2.5 * 1.1 ** 1.5,
-    3.5 * 2.5 * 1.5 * 1.1 ** 0.5]
+    1.1**3.5, 3.5 * 1.1**2.5, 3.5 * 2.5 * 1.1**1.5, 3.5 * 2.5 * 1.5 * 1.1**0.5
+]
 
 
 # TODO(allenl): Move this somewhere useful once forward gradients are stable.
@@ -83,8 +80,8 @@
       jac_columns.append(
           nest.map_structure(
               functools.partial(array_ops.reshape, shape=[-1]),
-              _jvp(f, primals,
-                   nest.pack_sequence_as(primals, tangent_mask))[1]))
+              _jvp(f, primals, nest.pack_sequence_as(primals,
+                                                     tangent_mask))[1]))
     jac_flat.append(array_ops.stack(jac_columns, axis=1))
     tangent_mask[primal_index] = array_ops.zeros_like(primal)
   return nest.pack_sequence_as(primals, jac_flat)
@@ -129,15 +126,18 @@
   """Return a function which computes the gradient of `f` in forward mode."""
 
   def _f(*params):
+
     def _single_jvp(param_mask):
-      with forwardprop.ForwardAccumulator(primals=[params[argnums]],
-                                          tangents=param_mask) as acc:
+      with forwardprop.ForwardAccumulator(
+          primals=[params[argnums]], tangents=param_mask) as acc:
         primals_out = f(*params)
       return acc.jvp(primals_out)
+
     # Building up a function to run with pfor takes a bit too long since we're
     # only running it a handful of times.
-    return _vectorize_parameters(_single_jvp, [params[argnums]],
-                                 use_pfor=False, dtype=f_out_dtypes)
+    return _vectorize_parameters(
+        _single_jvp, [params[argnums]], use_pfor=False, dtype=f_out_dtypes)
+
   return _f
 
 
@@ -159,8 +159,10 @@
   def _wrapper(index):
     full_onehot = array_ops.one_hot(index, total_size)
     split_onehot = array_ops.split(full_onehot, parameter_sizes)
-    tangents = [array_ops.reshape(v, array_ops.shape(param))
-                for param, v in zip(params, split_onehot)]
+    tangents = [
+        array_ops.reshape(v, array_ops.shape(param))
+        for param, v in zip(params, split_onehot)
+    ]
     return f(tangents)
 
   if use_pfor:
@@ -188,7 +190,9 @@
   """
   return _vectorize_parameters(
       functools.partial(_hvp, f, params),
-      params, use_pfor=use_pfor, dtype=dtype)
+      params,
+      use_pfor=use_pfor,
+      dtype=dtype)
 
 
 def _test_gradients(testcase,
@@ -335,8 +339,7 @@
       execution_count = getattr(self, "_execution_count", 0)
       self._execution_count = execution_count + 1
       x = array_ops.zeros([execution_count])
-      with forwardprop.ForwardAccumulator(
-          x, array_ops.ones_like(x)) as acc:
+      with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
         y = x + x
       self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
 
@@ -353,11 +356,9 @@
   def testMultipleWatchesAdd(self):
     x = constant_op.constant(-2.)
     with self.assertRaisesRegex(ValueError, "multiple times"):
-      with forwardprop.ForwardAccumulator(
-          [x, x], [1., 2.]):
+      with forwardprop.ForwardAccumulator([x, x], [1., 2.]):
         pass
-    with forwardprop.ForwardAccumulator(
-        [x], [3.]) as acc:
+    with forwardprop.ForwardAccumulator([x], [3.]) as acc:
       self.assertAllClose(3., acc.jvp(x))
       acc._watch(x, constant_op.constant(10.))
       self.assertAllClose(13., acc.jvp(x))
@@ -452,8 +453,10 @@
 
     @custom_gradient.custom_gradient
     def f(unused_x):
+
       def grad(unused_dy):
         raise ValueError("test_error_string")
+
       return 1., grad
 
     c = constant_op.constant(1.)
@@ -462,22 +465,15 @@
       with self.assertRaisesRegex(ValueError, "test_error_string"):
         f(c)
 
-  @parameterized.named_parameters(
-      [("EluM5", -0.5, nn_ops.elu),
-       ("EluP5", [0.5], nn_ops.elu),
-       ("SwishP5", 0.5, nn_impl.swish),
-       ("SwishM5", [-0.5], nn_impl.swish)])
+  @parameterized.named_parameters([("EluM5", -0.5, nn_ops.elu),
+                                   ("EluP5", [0.5], nn_ops.elu),
+                                   ("SwishP5", 0.5, nn_impl.swish),
+                                   ("SwishM5", [-0.5], nn_impl.swish)])
   def testElementwiseNNOps(self, value, op_fn):
     _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)
 
   def testFusedBatchNormGradsInference(self):
 
-    if test.is_built_with_rocm():
-      # This test was added recently and has been failing on the ROCm
-      # platform, since it was added.
-      # TODO(rocm): do root cause analysis of test failure and fix it.
-      self.skipTest("Test fails on ROCm platform, needs further analysis")
-
     x_shape = [4, 10, 10, 2]
     increment = 3. / math_ops.reduce_prod(
         constant_op.constant(x_shape, dtype=dtypes.float32))
@@ -489,11 +485,16 @@
     epsilon = 0.001
 
     def _bn_fused(x_arg, scale_arg, offset_arg):
-      return nn_impl.fused_batch_norm(x_arg, scale_arg, offset_arg,
-                                      mean, variance,
-                                      epsilon=epsilon, is_training=False)[0]
-    _test_gradients(self, _bn_fused, [x, scale, offset],
-                    order=2, atol=1e-2)
+      return nn_impl.fused_batch_norm(
+          x_arg,
+          scale_arg,
+          offset_arg,
+          mean,
+          variance,
+          epsilon=epsilon,
+          is_training=False)[0]
+
+    _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-2)
 
   def testPushPopAccumulatorState(self):
     # Note that this example is somewhat contrived. push_forwardprop_state is
@@ -519,22 +520,25 @@
       output = f(c)
       self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
 
-  @parameterized.named_parameters(
-      [("Order{}".format(order), order, expected)
-       for order, expected in enumerate(_X11_35_DERIVATIVES)])
+  @parameterized.named_parameters([
+      ("Order{}".format(order), order, expected)
+      for order, expected in enumerate(_X11_35_DERIVATIVES)
+  ])
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testHigherOrderPureForward(self, order, expected):
 
     def _forwardgrad(f):
+
       def _compute_forwardgrad(primal):
         tangent = constant_op.constant(1.)
         with forwardprop.ForwardAccumulator(primal, tangent) as acc:
           primal_out = f(primal)
         return acc.jvp(primal_out)
+
       return _compute_forwardgrad
 
     def _forward(x):
-      return x ** 3.5
+      return x**3.5
 
     f = _forward
     primal = constant_op.constant(1.1)
@@ -542,26 +546,25 @@
       f = _forwardgrad(f)
     self.assertAllClose(expected, f(primal))
 
-  @parameterized.named_parameters(
-      [("Function", def_function.function),
-       ("NoFunction", lambda f: f)])
+  @parameterized.named_parameters([("Function", def_function.function),
+                                   ("NoFunction", lambda f: f)])
   def testGradPureForward(self, decorator):
 
     @decorator
     def f(x):
-      return x ** 3.5
+      return x**3.5
 
     primal = constant_op.constant(1.1)
-    with forwardprop.ForwardAccumulator(
-        primal, constant_op.constant(1.)) as outer_acc:
-      with forwardprop.ForwardAccumulator(
-          primal, constant_op.constant(1.)) as acc:
+    with forwardprop.ForwardAccumulator(primal,
+                                        constant_op.constant(1.)) as outer_acc:
+      with forwardprop.ForwardAccumulator(primal,
+                                          constant_op.constant(1.)) as acc:
         primal_out = f(primal)
     inner_jvp = acc.jvp(primal_out)
     outer_jvp = outer_acc.jvp(inner_jvp)
-    self.assertAllClose(1.1 ** 3.5, primal_out)
-    self.assertAllClose(3.5 * 1.1 ** 2.5, inner_jvp)
-    self.assertAllClose(3.5 * 2.5 * 1.1 ** 1.5, outer_jvp)
+    self.assertAllClose(1.1**3.5, primal_out)
+    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
+    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)
     self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
@@ -571,18 +574,18 @@
     inner_jvp = constant_op.constant(3.)
     with forwardprop.ForwardAccumulator(
         [primal_in, inner_jvp],
-        [constant_op.constant(2.), constant_op.constant(4.)]) as outer_acc:
-      with forwardprop.ForwardAccumulator(
-          primal_in, inner_jvp) as inner_acc:
+        [constant_op.constant(2.),
+         constant_op.constant(4.)]) as outer_acc:
+      with forwardprop.ForwardAccumulator(primal_in, inner_jvp) as inner_acc:
         packed_input_indices, packed_input_tangents = (
             forwardprop_util.pack_tangents([primal_in]))
         self.assertAllClose([3., 2., 4.], packed_input_tangents)
         expected_indices = (
             # inner_acc watches primal_in
-            ((0, 1),),
+            (
+                (0, 1),),
             # outer_acc watches primal_in and inner_jvp
-            ((0, 2),
-             (1, 3)))
+            ((0, 2), (1, 3)))
         self.assertAllEqual(expected_indices, packed_input_indices)
         primal_out = primal_in * two
         self.assertAllClose(6., inner_acc.jvp(primal_out))
@@ -597,15 +600,16 @@
 
     @def_function.function
     def take_gradients():
+
       @def_function.function
       def f(x):
-        return x ** 3.5
+        return x**3.5
 
       primal = constant_op.constant(1.1)
       with forwardprop.ForwardAccumulator(
           primal, constant_op.constant(1.)) as outer_acc:
-        with forwardprop.ForwardAccumulator(
-            primal, constant_op.constant(1.)) as acc:
+        with forwardprop.ForwardAccumulator(primal,
+                                            constant_op.constant(1.)) as acc:
           primal_out = f(primal)
       inner_jvp = acc.jvp(primal_out)
       outer_jvp = outer_acc.jvp(inner_jvp)
@@ -613,9 +617,9 @@
       return primal_out, inner_jvp, outer_jvp
 
     primal_out, inner_jvp, outer_jvp = take_gradients()
-    self.assertAllClose(1.1 ** 3.5, primal_out)
-    self.assertAllClose(3.5 * 1.1 ** 2.5, inner_jvp)
-    self.assertAllClose(3.5 * 2.5 * 1.1 ** 1.5, outer_jvp)
+    self.assertAllClose(1.1**3.5, primal_out)
+    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
+    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)
 
   def testFunctionGrad(self):
 
@@ -623,11 +627,7 @@
     def f(x):
       return math_ops.reduce_prod(math_ops.tanh(x)**2)
 
-    _test_gradients(
-        self,
-        f,
-        [constant_op.constant([1., 2.])],
-        order=3)
+    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)
 
   def testReusingJVP(self):
     m1 = random_ops.random_uniform((256, 2096))
@@ -642,8 +642,8 @@
       result2 = matmul(m2, m2, transpose_b=True)
 
     def _expected(mat, tangent):
-      return (math_ops.matmul(tangent, mat, transpose_b=True)
-              + math_ops.matmul(mat, tangent, transpose_b=True))
+      return (math_ops.matmul(tangent, mat, transpose_b=True) +
+              math_ops.matmul(mat, tangent, transpose_b=True))
 
     self.assertAllClose(result1, result2)
     self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
@@ -693,19 +693,16 @@
     with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
       with backprop.GradientTape() as tape:
         self.assertFalse(tape_lib.should_record_backprop([c]))
-        self.assertEqual(1,
-                         pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
+        self.assertEqual(1, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
         tape.watch(c)
-        self.assertEqual(2,
-                         pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
+        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
         self.assertTrue(tape_lib.should_record_backprop([c]))
         with tape_lib.stop_recording():
           self.assertEqual(0,
                            pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
           self.assertFalse(tape_lib.should_record_backprop([c]))
           d = c * 2.
-        self.assertEqual(2,
-                         pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
+        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
         self.assertTrue(tape_lib.should_record_backprop([c]))
         self.assertFalse(tape_lib.should_record_backprop([d]))
         self.assertIsNone(acc.jvp(d))
@@ -728,11 +725,11 @@
         self.assertIsNone(tape.gradient(d, c))
         self.assertIsNone(tape.gradient(e, c))
         tape_lib.record_operation_forwardprop_only(
-            "CustomForwardMul", [d], [c, two],
-            lambda dd: (two * dd, c * dd), None)
-        tape_lib.record_operation_backprop_only(
-            "CustomBackwardMul", [e], [c, three],
-            lambda de: (three * de, c * de))
+            "CustomForwardMul", [d], [c, two], lambda dd: (two * dd, c * dd),
+            None)
+        tape_lib.record_operation_backprop_only("CustomBackwardMul", [e],
+                                                [c, three], lambda de:
+                                                (three * de, c * de))
         self.assertAllClose(4., acc.jvp(d))
         self.assertIsNone(acc.jvp(e))
         self.assertIsNone(tape.gradient(d, c))
@@ -749,16 +746,17 @@
   def testVariableReadInFunction(self):
     v = variables.Variable(1.)
     with forwardprop.ForwardAccumulator(v, 11.) as acc:
+
       @def_function.function
       def f():
         return v.read_value(), 2. * v.read_value()
+
       result = f()
       self.assertAllClose((1.0, 2.), result)
       self.assertAllClose((11., 22.), acc.jvp(result))
 
-  @parameterized.named_parameters(
-      [("ForwardPropFirst", True),
-       ("TapeFirst", False)])
+  @parameterized.named_parameters([("ForwardPropFirst", True),
+                                   ("TapeFirst", False)])
   def testForwardOverBackwardMemoryEfficiency(self, forward_prop_first):
     # Watching depends on nesting, not creation order
     c = constant_op.constant(1.)
@@ -788,9 +786,8 @@
     finally:
       gc.enable()
 
-  @parameterized.named_parameters(
-      [("ForwardPropFirst", True),
-       ("TapeFirst", False)])
+  @parameterized.named_parameters([("ForwardPropFirst", True),
+                                   ("TapeFirst", False)])
   def testBackwardOverForward(self, forward_prop_first):
     c = constant_op.constant(1.)
     # Watching depends on nesting, not creation order
@@ -805,8 +802,7 @@
         tape.watch(c)
         d = math_ops.cos(c)
         self.assertTrue(tape_lib.should_record_backprop((acc.jvp(d),)))
-      self.assertAllClose(-.1 * math_ops.cos(1.),
-                          tape.gradient(acc.jvp(d), c))
+      self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c))
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingWithJVPIndices(self):
@@ -816,11 +812,10 @@
       self.assertAllClose([10.], packed_input_tangents)
       d = constant_op.constant(2.)
       d_tangent = constant_op.constant(3.)
-      tape_lib.record_operation_forwardprop_only(
-          "FunctionWithInlineJVPs",
-          [d] + [d_tangent],
-          [c] + packed_input_tangents,
-          None, (((0, 1),),))
+      tape_lib.record_operation_forwardprop_only("FunctionWithInlineJVPs",
+                                                 [d] + [d_tangent],
+                                                 [c] + packed_input_tangents,
+                                                 None, (((0, 1),),))
       self.assertAllClose(3., acc.jvp(d))
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
@@ -829,26 +824,19 @@
     d = constant_op.constant(2.)
     e = constant_op.constant(3.)
     with forwardprop.ForwardAccumulator(c, 10.) as acc:
-      tape_lib.record_operation(
-          "ForwardIsSpecial",
-          [d], [c],
-          None, lambda jvp: [-2. * jvp])
+      tape_lib.record_operation("ForwardIsSpecial", [d], [c], None,
+                                lambda jvp: [-2. * jvp])
       self.assertAllClose(-20., acc.jvp(d))
-      tape_lib.record_operation(
-          "ForwardIsSpecial2",
-          [], [],
-          None, lambda: [])
-      tape_lib.record_operation(
-          "ForwardIsSpecial3",
-          [e], [d],
-          None, lambda x: [x])
+      tape_lib.record_operation("ForwardIsSpecial2", [], [], None, lambda: [])
+      tape_lib.record_operation("ForwardIsSpecial3", [e], [d], None,
+                                lambda x: [x])
       self.assertAllClose(-20., acc.jvp(e))
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testVariableWatched(self):
     v = variables.Variable([1., 2., 3.])
-    with forwardprop.ForwardAccumulator(
-        v, constant_op.constant([.1, -.2, .3])) as acc:
+    with forwardprop.ForwardAccumulator(v, constant_op.constant([.1, -.2,
+                                                                 .3])) as acc:
       self.assertAllClose([.1, -.2, .3], acc.jvp(v))
       x = v * 2.
       self.assertAllClose([.2, -.4, .6], acc.jvp(x))
@@ -878,8 +866,9 @@
       def compute_jvps(self):
         if self._v is None:
           self._v = variables.Variable([1., 2., 3.])
-        with forwardprop.ForwardAccumulator(
-            self._v, constant_op.constant([.1, -.2, .3])) as acc:
+        with forwardprop.ForwardAccumulator(self._v,
+                                            constant_op.constant([.1, -.2,
+                                                                  .3])) as acc:
           x = self._v * 2.
           x2 = self._v + .1
         return acc.jvp((self._v, x, x2))
@@ -898,6 +887,7 @@
     self.assertAllClose(3., acc.jvp(y))
 
   def testIndexSlicesGradInFunction(self):
+
     @def_function.function
     def f(a):
       return array_ops.gather(a, 0)
@@ -983,17 +973,14 @@
   def testOfFunctionWhile(self):
     y = constant_op.constant(1.)
     with forwardprop.ForwardAccumulator(y, 1.) as acc:
-      self.assertAllClose(
-          10., acc.jvp(_has_loop(constant_op.constant(5), y)))
+      self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y)))
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testOfFunctionCond(self):
     y = constant_op.constant(1.)
     with forwardprop.ForwardAccumulator(y, 1.) as acc:
-      self.assertAllClose(
-          3., acc.jvp(_has_cond(constant_op.constant(5), y)))
-      self.assertAllClose(
-          0., acc.jvp(_has_cond(constant_op.constant(0), y)))
+      self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y)))
+      self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y)))
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testInFunctionWhile(self):
@@ -1024,15 +1011,18 @@
 
     hessian_eager, = _forward_over_back_hessian(
         _f, [constant_op.constant(x_value)],
-        use_pfor=False, dtype=[dtypes.float32])
+        use_pfor=False,
+        dtype=[dtypes.float32])
     self.assertAllClose(hess_value, hessian_eager)
     hessian_function, = def_function.function(_forward_over_back_hessian)(
         _f, [constant_op.constant(x_value)],
-        use_pfor=False, dtype=[dtypes.float32])
+        use_pfor=False,
+        dtype=[dtypes.float32])
     self.assertAllClose(hess_value, hessian_function)
     hessian_pfor, = def_function.function(_forward_over_back_hessian)(
         _f, [constant_op.constant(x_value)],
-        use_pfor=True, dtype=[dtypes.float32])
+        use_pfor=True,
+        dtype=[dtypes.float32])
     self.assertAllClose(hess_value, hessian_pfor)
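
The reformatted tests above all drive `forwardprop.ForwardAccumulator` (public counterpart `tf.autodiff.ForwardAccumulator`), and the Hessian tests compose it with a backward `GradientTape` (forward-over-backward). A minimal sketch of that composition, assuming TF 2.x eager execution:

import tensorflow as tf

x = tf.constant(1.1)
# Forward-over-backward: wrapping a backward GradientTape in a
# ForwardAccumulator yields a Hessian-vector product without materializing
# the full Hessian.
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(1.)) as acc:
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = x ** 3.5
  grad = tape.gradient(y, x)   # 3.5 * x ** 2.5
hvp = acc.jvp(grad)            # 3.5 * 2.5 * x ** 1.5, times the tangent (1.0)
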
 
 
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index a41dc8e..828af8f 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -33,7 +33,6 @@
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
-from tensorflow.python import _pywrap_utils
 from tensorflow.python import pywrap_tfe
 from tensorflow.python.client import pywrap_tf_session
 from tensorflow.python.eager import backprop
@@ -67,6 +66,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import trace
 from tensorflow.python.saved_model import save_context
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lazy_loader
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index 0fb78cb..712aae9 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -103,12 +103,12 @@
       return i + variable_b, c
 
     rets = remote_output(constant_op.constant([1]))
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
     self.assertEqual(rets[0].backing_device,
                      '/job:localhost/replica:0/task:0/device:CPU:0')
     self.assertEqual(rets[1].backing_device,
                      '/job:worker/replica:0/task:0/device:CPU:0')
-    self.assertAllEqual(rets[0].numpy(), [2])
-    self.assertAllEqual(rets[1].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
 
@@ -390,6 +390,30 @@
 
     self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  def testMultiDeviceFunctionExecutionOrderingWithPackedInput(self):
+    shape = [2]
+    with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
+      # Send 20 remote requests to simulate heavy load on worker:2.
+      unused_values = []
+      for _ in range(20):
+        unused_values.append(array_ops.zeros(shape))
+      func_input = array_ops.zeros(shape)
+
+    packed_input = ops.pack_eager_tensors([func_input])
+
+    @def_function.function
+    def func(packed_input):
+      # When worker:2 receives the component function request, packed_input
+      # should be ready on worker:2.
+      with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
+        ret = packed_input + constant_op.constant(1.0)
+      return ret + constant_op.constant(1.0)
+
+    # Run the function on worker:1.
+    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
+      self.assertAllEqual(func(packed_input).numpy(),
+                          array_ops.ones(shape).numpy() * 2)
+
   def testMultiDeviceFunctionWithPackedVariable(self):
     with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
       var0 = resource_variable_ops.ResourceVariable(1.0)
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index b3c6e06..e1293a6 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -84,7 +84,6 @@
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
         "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
@@ -92,6 +91,7 @@
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/tracking",
         "//tensorflow/python/training/tracking:data_structures",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index 3f6ab98..fe78637 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -45,7 +45,9 @@
     "CollectiveReduce",
     "CollectiveReduceV2",
     "CollectiveBcastSend",
+    "CollectiveBcastSendV2",
     "CollectiveBcastRecv",
+    "CollectiveBcastRecvV2",
     "NcclAllReduce",
     # We do not add "Send" here since we want it to be added as a control output
     # in order to avoid being pruned.
diff --git a/tensorflow/python/framework/composite_tensor.py b/tensorflow/python/framework/composite_tensor.py
index 22dbe7c..531a8a5 100644
--- a/tensorflow/python/framework/composite_tensor.py
+++ b/tensorflow/python/framework/composite_tensor.py
@@ -22,8 +22,8 @@
 
 import six
 
-from tensorflow.python import _pywrap_utils
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index 2691665..a9bb996 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import _pywrap_tensor_float_32_execution
 from tensorflow.python.eager import context
+from tensorflow.python.util import _pywrap_tensor_float_32_execution
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 343856b..6bed1b5 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -169,9 +169,9 @@
 
   Note: All eager `tf.Tensor` values are immutable (in contrast to
   `tf.Variable`). There is nothing especially _constant_ about the value
-  returned from `tf.constant`. This function it is not fundamentally different
-  from `tf.convert_to_tensor`. The name `tf.constant` comes from the `value`
-  being embeded in a `Const` node in the `tf.Graph`. `tf.constant` is useful
+  returned from `tf.constant`. This function is not fundamentally different from
+  `tf.convert_to_tensor`. The name `tf.constant` comes from the `value` being
+  embedded in a `Const` node in the `tf.Graph`. `tf.constant` is useful
   for asserting that the value can be embedded that way.
 
   If the argument `dtype` is not specified, then the type is inferred from
@@ -188,7 +188,7 @@
     array([[1, 2, 3],
            [4, 5, 6]])>
 
-  If `dtype` is specified the resulting tensor values are cast to the requested
+  If `dtype` is specified, the resulting tensor values are cast to the requested
   `dtype`.
 
   >>> tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.float64)
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 9eeae83..51eca8e 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -25,8 +25,8 @@
 # protobuf errors where a file is defined twice on MacOS.
 # pylint: disable=invalid-import-order,g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
-from tensorflow.python import _pywrap_bfloat16
 from tensorflow.python import _dtypes
+from tensorflow.python.lib.core import _pywrap_bfloat16
 from tensorflow.python.util.tf_export import tf_export
 
 _np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
index 1847b09..c0a5e8e 100644
--- a/tensorflow/python/framework/errors_test.py
+++ b/tensorflow/python/framework/errors_test.py
@@ -23,10 +23,10 @@
 import warnings
 
 from tensorflow.core.lib.core import error_codes_pb2
-from tensorflow.python import _pywrap_file_io
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.lib.io import _pywrap_file_io
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 570e4af..3ac7025 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -389,13 +389,9 @@
     variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
     variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access
 
-    collections_ref = {}
-    parent_collections_ref = ops.get_default_graph()._collections  # pylint: disable=protected-access
-    for key in variable_keys:
-      if key not in parent_collections_ref:
-        parent_collections_ref[key] = collections_ref[key] = []
-      else:
-        collections_ref[key] = parent_collections_ref[key]
+    parent_graph = ops.get_default_graph()
+    collections_ref = {
+        key: parent_graph.get_collection_ref(key) for key in variable_keys}
 
     temp_graph = func_graph_from_py_func(
         self._func,
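
The simplified `collections_ref` construction above leans on `Graph.get_collection_ref`, which creates the named collection on first access and always returns the same mutable list; a quick illustration of that behavior:

from tensorflow.python.framework import ops

g = ops.Graph()
ref = g.get_collection_ref("my_collection")   # created empty on first access
ref.append("item")
# Later lookups return the same list object, so mutations are shared.
assert g.get_collection_ref("my_collection") is ref
assert g.get_collection("my_collection") == ["item"]
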
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 69aa38d..243d33a 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -18,16 +18,21 @@
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
+
 from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.ops import resource_variable_ops
 
 
 def function_def_to_graph(fdef, input_shapes=None):
@@ -84,6 +89,9 @@
         func_graph.get_operation_by_name(fdef.control_ret[ret_name])
         for ret_name in fdef.signature.control_output
     ]
+
+    _set_handle_data(func_graph, fdef)
+
     for node in graph_def.node:
       output_shapes = node.attr.get("_output_shapes", None)
       if output_shapes is not None:
@@ -264,3 +272,19 @@
     return 1
   else:
     raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
+
+
+def _set_handle_data(func_graph, fdef):
+  """Adds handle data for resource type inputs and outputs."""
+  for tensor, arg_def in itertools.chain(
+      zip(func_graph.inputs, fdef.signature.input_arg),
+      zip(func_graph.outputs, fdef.signature.output_arg)):
+    if arg_def.handle_data:
+      shape_and_dtype = arg_def.handle_data[0]
+      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
+      handle_data.is_set = True
+      handle_data.shape_and_type.append(
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
+              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
+      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
+          tensor, handle_data, True)
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ea37607..4c9b612 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1168,6 +1168,12 @@
 
 class FunctionsFromProtos(test.TestCase):
 
+  def stripInternalFunctionDefAnnotations(self, f_def):
+    result = function_pb2.FunctionDef()
+    result.CopyFrom(f_def)
+    result.attr.pop("_construction_context", None)
+    return result
+
   def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
     if new_func is None:
       # Make a copy of func.definition to avoid any bugs masked by using the
@@ -1177,7 +1183,9 @@
       fdef = function_pb2.FunctionDef.FromString(serialized_fdef)
       new_func = function._from_definition(fdef, grad_func=grad_func)
     self.assertEqual(func.name, new_func.name)
-    self.assertEqual(func.definition, new_func.definition)
+    self.assertEqual(
+        self.stripInternalFunctionDefAnnotations(func.definition),
+        self.stripInternalFunctionDefAnnotations(new_func.definition))
     self.assertEqual(func.grad_func_name, new_func.grad_func_name)
     self.assertEqual(func.declared_input_types, new_func.declared_input_types)
     self.assertEqual(func.captured_inputs, new_func.captured_inputs)
@@ -1213,7 +1221,9 @@
     new_func = function._from_definition(Foo.definition)
 
     self.assertEqual(Foo.name, new_func.name)
-    self.assertEqual(Foo.definition, new_func.definition)
+    self.assertEqual(
+        self.stripInternalFunctionDefAnnotations(Foo.definition),
+        self.stripInternalFunctionDefAnnotations(new_func.definition))
     self.assertEqual(Foo.grad_func_name, new_func.grad_func_name)
 
     # Captured inputs are added as regular inputs to the function definition
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index d1a0c26..73ef3f7 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -27,6 +27,7 @@
 
 from tensorflow.python import _pywrap_python_op_gen
 from tensorflow.python.client import pywrap_tf_session as py_tf
+from tensorflow.python.eager import context
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
@@ -159,6 +160,45 @@
         library_location)
 
 
+def load_pluggable_device_library(library_location):
+  """Loads a TensorFlow PluggableDevice plugin.
+
+  "library_location" can be a path to a specific shared object, or a folder.
+  If it is a folder, all shared objects will be loaded. When the library is
+  loaded, devices/kernels registered in the library via the StreamExecutor C API
+  and the Kernel/Op Registration C API are made available in the TensorFlow process.
+
+  Args:
+    library_location: Path to the plugin or folder of plugins. Relative or
+      absolute filesystem path to a dynamic library file or folder.
+
+  Raises:
+    OSError: When the file to be loaded is not found.
+    RuntimeError: When unable to load the library.
+  """
+  if os.path.exists(library_location):
+    if os.path.isdir(library_location):
+      directory_contents = os.listdir(library_location)
+
+      pluggable_device_libraries = [
+          os.path.join(library_location, f)
+          for f in directory_contents
+          if _is_shared_object(f)
+      ]
+    else:
+      pluggable_device_libraries = [library_location]
+
+    for lib in pluggable_device_libraries:
+      py_tf.TF_LoadPluggableDeviceLibrary(lib)
+    # Reinitialize the physical devices list after plugin registration.
+    context.context().reinitialize_physical_devices()
+  else:
+    raise OSError(
+        errno.ENOENT,
+        'The file or folder to load pluggable device libraries from does not '
+        'exist.', library_location)
+
+
 @tf_export('experimental.register_filesystem_plugin')
 def register_filesystem_plugin(plugin_location):
   """Loads a TensorFlow FileSystem plugin.
diff --git a/tensorflow/python/framework/memory_checker_test.py b/tensorflow/python/framework/memory_checker_test.py
index bed6aac..86311e7 100644
--- a/tensorflow/python/framework/memory_checker_test.py
+++ b/tensorflow/python/framework/memory_checker_test.py
@@ -108,6 +108,18 @@
     with self.assertRaises(AssertionError):
       memory_checker.assert_no_leak_if_all_possibly_except_one()
 
+  def testLeak4(self):
+    helper = _memory_checker_test_helper.MemoryCheckerTestHelper()
+
+    with MemoryChecker() as memory_checker:
+      for i in range(10):
+        helper.list_push_back(i)
+        memory_checker.record_snapshot()
+
+    memory_checker.report()
+    with self.assertRaises(AssertionError):
+      memory_checker.assert_no_leak_if_all_possibly_except_one()
+
   def testNoNewPythonObjectsEmpty(self):
     self.skipTest('TODO(b/150324603): Flaky test.')
     with MemoryChecker() as memory_checker:
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 016af65..f6d93e2 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -26,13 +26,13 @@
 from tensorflow.core.framework import tensor_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import op_callbacks
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util import tf_contextlib
 
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 376122b..8f0d18a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -50,6 +50,7 @@
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -3292,18 +3293,18 @@
             continue
           # TODO(b/141471245): Fix the inconsistency when inputs of func graph
           # are appended during gradient computation of while/cond.
-          for input_tensor, _ in zip(func_graph_inputs,
-                                     function_def.signature.input_arg):
+          for input_tensor, arg_def in zip(func_graph_inputs,
+                                           function_def.signature.input_arg):
+            input_shapes.list.shape.add().CopyFrom(
+                input_tensor.get_shape().as_proto())
             if input_tensor.dtype == dtypes.resource:
-              # TODO(allenl): Save and restore handle data, then save the
-              # resource placeholder's shape. Right now some shape functions get
-              # confused if we set the shape of the resource placeholder (to a
-              # scalar of course) and there isn't any handle data.
-              input_shapes.list.shape.add().CopyFrom(
-                  tensor_shape.TensorShape(None).as_proto())
-            else:
-              input_shapes.list.shape.add().CopyFrom(
-                  input_tensor.get_shape().as_proto())
+              _copy_handle_data_to_arg_def(input_tensor, arg_def)
+
+          for output_tensor, arg_def in zip(func_graph.outputs,
+                                            function_def.signature.output_arg):
+            if output_tensor.dtype == dtypes.resource:
+              _copy_handle_data_to_arg_def(output_tensor, arg_def)
+
           for node in function_def.node_def:
             try:
               op = func_graph.get_operation_by_name(node.name)
@@ -5332,13 +5333,12 @@
 def control_dependencies(control_inputs):
   """Wrapper for `Graph.control_dependencies()` using the default graph.
 
-  See `tf.Graph.control_dependencies`
-  for more details.
+  See `tf.Graph.control_dependencies` for more details.
 
   Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
-  this method, as code executes in the expected order.* Only use
-  `tf.control_dependencies` when working with v1-style code or in a graph
-  context such as inside `Dataset.map`.
+  this method, as ops execute in the expected order thanks to automatic control
+  dependencies.* Only use `tf.control_dependencies` when working with v1
+  `tf.Graph` code.
 
   When eager execution is enabled, any callable object in the `control_inputs`
   list will be called.
@@ -6979,3 +6979,22 @@
 
   if graph.building_function and hasattr(graph, "outer_graph"):
     return _get_enclosing_context(graph.outer_graph)
+
+
+def get_resource_handle_data(graph_op):
+  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck
+
+  handle_data = pywrap_tf_session.GetHandleShapeAndType(
+      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
+
+  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+      compat.as_bytes(handle_data))
+
+
+def _copy_handle_data_to_arg_def(tensor, arg_def):
+  handle_data = get_resource_handle_data(tensor)
+  if handle_data.shape_and_type:
+    shape_and_type = handle_data.shape_and_type[0]
+    proto = arg_def.handle_data.add()
+    proto.dtype = shape_and_type.dtype
+    proto.shape.CopyFrom(handle_data.shape_and_type[0].shape)
diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py
index 83569cd..6041a98 100644
--- a/tensorflow/python/framework/registry.py
+++ b/tensorflow/python/framework/registry.py
@@ -23,9 +23,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import traceback
+
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
-from tensorflow.python.util import tf_stack
 
 
 # Registry mechanism below is based on mapreduce.python.mrpython.Register.
@@ -65,8 +66,8 @@
     logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
     # stack trace is [this_function, Register(), user_function,...]
     # so the user function is #2.
-    stack = tf_stack.extract_stack(limit=3)
-    stack_index = min(2, len(stack)-1)
+    stack = traceback.extract_stack(limit=3)
+    stack_index = min(2, len(stack) - 1)
     if stack_index >= 0:
       location_tag = stack[stack_index]
     else:
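
The registry now uses the stdlib `traceback` module instead of `tf_stack`; a small hedged sketch of what the call returns (helper name illustrative):

import traceback

def registration_sites():
  # FrameSummary entries are ordered outermost-first; limit=3 keeps only the
  # three innermost frames. Each entry exposes .filename, .lineno, .name and
  # .line, which is what the registry stores as the location tag.
  stack = traceback.extract_stack(limit=3)
  return [(f.filename, f.lineno, f.name) for f in stack]
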
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 28e21a8..d96a137 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -23,7 +23,6 @@
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
-from tensorflow.python import _pywrap_utils
 from tensorflow.python import tf2
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
@@ -35,6 +34,7 @@
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import gen_sparse_ops
 from tensorflow.python.types import internal
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util.tf_export import tf_export
 
 # pylint: disable=protected-access
diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py
index 20b776d..68c2320 100644
--- a/tensorflow/python/framework/tensor_spec.py
+++ b/tensorflow/python/framework/tensor_spec.py
@@ -20,12 +20,12 @@
 
 import numpy as np
 
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import type_spec
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util.tf_export import tf_export
 
 
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 6c4c985..ee8db09 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -809,6 +809,41 @@
   This function attempts to partially evaluate the given tensor, and
   returns its value as a numpy ndarray if this succeeds.
 
+  Example usage:
+
+  >>> a = tf.constant(10)
+  >>> tf.get_static_value(a)
+  10
+  >>> b = tf.constant(20)
+  >>> tf.get_static_value(tf.add(a, b))
+  30
+
+  >>> # `tf.Variable` is not supported.
+  >>> c = tf.Variable(30)
+  >>> print(tf.get_static_value(c))
+  None
+
+  Using the `partial` option is most relevant when calling `get_static_value`
+  inside a `tf.function`. Setting it to `True` returns the partial result, with
+  `None` for any values that cannot be evaluated. For example:
+
+  ```python
+  class Foo(object):
+    def __init__(self):
+      self.a = tf.Variable(1)
+      self.b = tf.constant(2)
+
+    @tf.function
+    def bar(self, partial):
+      packed = tf.raw_ops.Pack(values=[self.a, self.b])
+      static_val = tf.get_static_value(packed, partial=partial)
+      tf.print(static_val)
+
+  f = Foo()
+  f.bar(partial=True)  # `array([None, array(2, dtype=int32)], dtype=object)`
+  f.bar(partial=False)  # `None`
+  ```
+
   Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
   will no longer be possible to feed a different value for `tensor`. This allows
   the result of this function to influence the graph that is constructed, and
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4ba6798..6b54f03 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -43,7 +43,6 @@
 
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python import _pywrap_util_port
 from tensorflow.python import tf2
 from tensorflow.python.client import device_lib
 from tensorflow.python.client import pywrap_tf_session
@@ -82,6 +81,7 @@
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
+from tensorflow.python.util import _pywrap_util_port
 from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
@@ -616,12 +616,12 @@
     The wrapped function
   """
 
-  def wrapper(self, *args, **kwargs):
+  def wrapper(*args, **kwargs):
     output_all_intermediates_old = \
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
     control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
     try:
-      return fn(self, *args, **kwargs)
+      return fn(*args, **kwargs)
     finally:
       control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
           output_all_intermediates_old
@@ -2579,6 +2579,12 @@
     self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
 
     msgs = [msg]
+    # np.allclose does not always work for our custom bfloat16 extension type
+    # when type promotions are involved, so we first cast any bfloat16 arrays
+    # to float32.
+    a_dtype = a.dtype
+    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
+    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
     if not np.allclose(a, b, rtol=rtol, atol=atol):
       # Adds more details to np.testing.assert_allclose.
       #
@@ -2602,7 +2608,7 @@
       msgs.append("not close rhs = {}".format(y))
       msgs.append("not close dif = {}".format(np.abs(x - y)))
       msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
-      msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
+      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
       # TODO(xpan): There seems to be a bug:
       # tensorflow/compiler/tests:binary_ops_test pass with float32
       # nan even though the equal_nan is False by default internally.
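
The bfloat16 workaround above amounts to casting before comparing; roughly:

import numpy as np
from tensorflow.python.framework import dtypes

bf16 = dtypes.bfloat16.as_numpy_dtype
a = np.array([1.0, 2.0], dtype=bf16)
b = np.array([1.0, 2.0], dtype=np.float32)
# np.allclose can misbehave when the custom bfloat16 extension type is
# promoted, so compare in float32 instead.
assert np.allclose(a.astype(np.float32), b, rtol=1e-6, atol=1e-6)
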
diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py
index 857d021..2dccc34 100644
--- a/tensorflow/python/framework/traceable_stack.py
+++ b/tensorflow/python/framework/traceable_stack.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.util import tf_stack
+import inspect
 
 
 class TraceableObject(object):
@@ -51,26 +51,20 @@
       TraceableObject.HEURISTIC_USED if the offset was larger than the stack,
       and TraceableObject.FAILURE if the stack was empty.
     """
-    # Offset is defined in "Args" as relative to the caller.  We are one frame
+    retcode = self.SUCCESS
+    frame = inspect.currentframe()
+    # Offset is defined in "Args" as relative to the caller. We are one frame
     # beyond the caller.
-    local_offset = offset + 1
-
-    frame_records = tf_stack.extract_stack(
-        limit=local_offset + 1)
-    if not frame_records:
-      return self.FAILURE
-    if len(frame_records) > local_offset:
-      frame = frame_records[len(frame_records) - (local_offset + 1)]
-      self.filename = frame.filename
-      self.lineno = frame.lineno
-      return self.SUCCESS
-    else:
-      # If the offset is too large then we use the largest offset possible,
-      # meaning we use the outermost stack frame at index 0.
-      frame = frame_records[0]
-      self.filename = frame.filename
-      self.lineno = frame.lineno
-      return self.HEURISTIC_USED
+    for _ in range(offset + 1):
+      parent = frame.f_back
+      if parent is None:
+        # If the offset is too large then we use the largest offset possible.
+        retcode = self.HEURISTIC_USED
+        break
+      frame = parent
+    self.filename = frame.f_code.co_filename
+    self.lineno = frame.f_lineno
+    return retcode
 
   def copy_metadata(self):
     """Return a TraceableObject like this one, but without the object."""
diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py
index fa48ca8..80b6933 100644
--- a/tensorflow/python/framework/type_spec.py
+++ b/tensorflow/python/framework/type_spec.py
@@ -20,16 +20,16 @@
 
 import abc
 import collections
-
 import re
+
 import numpy as np
 import six
 
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 612a385..f58a875 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -236,6 +236,7 @@
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":activations",
         ":backend",
         ":losses",
         "//tensorflow/python:array_ops",
@@ -581,7 +582,6 @@
     python_version = "PY3",
     shard_count = 8,
     tags = [
-        "no_rocm",
         "notsan",  # b/67509773
     ],
     deps = [
diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD
index 4d23718..fbd9de7 100644
--- a/tensorflow/python/keras/applications/BUILD
+++ b/tensorflow/python/keras/applications/BUILD
@@ -41,7 +41,6 @@
     deps = [
         "//tensorflow/python:lib",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/keras:activations",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras:models",
@@ -49,6 +48,7 @@
         "//tensorflow/python/keras/layers",
         "//tensorflow/python/keras/utils:data_utils",
         "//tensorflow/python/keras/utils:layer_utils",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/python/keras/applications/mobilenet_v3.py b/tensorflow/python/keras/applications/mobilenet_v3.py
index ab396a2..055d277 100644
--- a/tensorflow/python/keras/applications/mobilenet_v3.py
+++ b/tensorflow/python/keras/applications/mobilenet_v3.py
@@ -61,7 +61,7 @@
   The following table describes the performance of MobileNets:
   ------------------------------------------------------------------------
   MACs stands for Multiply Adds
-  
+
   |Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1 CPU(ms)|
   |---|---|---|---|---|
   | mobilenet_v3_large_1.0_224              | 217 | 5.4 |   75.6   |   51.2  |
@@ -77,11 +77,6 @@
 
   Optionally loads weights pre-trained on ImageNet.
 
-  Note: each Keras Application expects a specific kind of input preprocessing.
-  For MobileNetV3, call
-  `tf.keras.applications.mobilenet_v3.preprocess_input` on your
-  inputs before passing them to the model.
-
   Arguments:
     input_shape: Optional shape tuple, to be specified if you would
       like to use a model with an input image resolution that is not
@@ -136,6 +131,10 @@
       on the "top" layer. Ignored unless `include_top=True`. Set
       `classifier_activation=None` to return the logits of the "top" layer.
 
+  Call arguments:
+    inputs: A floating point `numpy.array` or a `tf.Tensor`, 4D with 3 color
+      channels, with values in the range [0, 255].
+
   Returns:
     A `keras.Model` instance.
 
@@ -555,6 +554,24 @@
 
 @keras_export('keras.applications.mobilenet_v3.preprocess_input')
 def preprocess_input(x, data_format=None):  # pylint: disable=unused-argument
+  """A placeholder method for backward compatibility.
+
+  The preprocessing logic has been included in the mobilenet_v3 model
+  implementation. Users are no longer required to call this method to normalize
+  the input data. This method does nothing and is only kept as a placeholder
+  to align the API surface between the old and new versions of the model.
+
+  Args:
+    x: A floating point `numpy.array` or a `tf.Tensor`.
+    data_format: Optional data format of the image tensor/array. Defaults to
+      None, in which case the global setting
+      `tf.keras.backend.image_data_format()` is used (unless you changed it,
+      it defaults to "channels_last").
+
+  Returns:
+    Unchanged `numpy.array` or `tf.Tensor`.
+  """
+
   return x
 
 
@@ -563,8 +580,4 @@
   return imagenet_utils.decode_predictions(preds, top=top)
 
 
-preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='',
-    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
-    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
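
Since preprocessing now lives inside the MobileNetV3 graph, callers can feed raw [0, 255] images directly; a minimal sketch with random weights and a hypothetical input:

import numpy as np
import tensorflow as tf

# Rescaling is built into the model, so no call to
# tf.keras.applications.mobilenet_v3.preprocess_input is needed.
model = tf.keras.applications.MobileNetV3Small(
    input_shape=(224, 224, 3), weights=None)
images = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")
preds = model(images)
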
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 8513105..a4e40a9 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -93,8 +93,8 @@
       backend.variable(input_a, dtype=dtype),
       backend.variable(input_b, dtype=dtype), *keras_args, **keras_kwargs)
   keras_output = backend.eval(keras_output)
-  np_output = np_op(input_a.astype(dtype), input_b.astype(dtype),
-                    *np_args, **np_kwargs)
+  np_output = np_op(
+      input_a.astype(dtype), input_b.astype(dtype), *np_args, **np_kwargs)
   try:
     np.testing.assert_allclose(keras_output, np_output, atol=1e-4)
   except AssertionError:
@@ -425,19 +425,31 @@
         (backend.argmax, np.argmax),
     ]
     for keras_op, np_op in ops_to_test:
-      compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
-                                       keras_kwargs={'axis': 1},
-                                       np_kwargs={'axis': 1})
-      compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
-                                       keras_kwargs={'axis': -1},
-                                       np_kwargs={'axis': -1})
+      compare_single_input_op_to_numpy(
+          keras_op,
+          np_op,
+          input_shape=(4, 7, 5),
+          keras_kwargs={'axis': 1},
+          np_kwargs={'axis': 1})
+      compare_single_input_op_to_numpy(
+          keras_op,
+          np_op,
+          input_shape=(4, 7, 5),
+          keras_kwargs={'axis': -1},
+          np_kwargs={'axis': -1})
       if 'keepdims' in tf_inspect.getargspec(keras_op).args:
-        compare_single_input_op_to_numpy(keras_op, np_op,
-                                         input_shape=(4, 7, 5),
-                                         keras_kwargs={'axis': 1,
-                                                       'keepdims': True},
-                                         np_kwargs={'axis': 1,
-                                                    'keepdims': True})
+        compare_single_input_op_to_numpy(
+            keras_op,
+            np_op,
+            input_shape=(4, 7, 5),
+            keras_kwargs={
+                'axis': 1,
+                'keepdims': True
+            },
+            np_kwargs={
+                'axis': 1,
+                'keepdims': True
+            })
 
   def test_elementwise_ops(self):
     ops_to_test = [
@@ -457,9 +469,8 @@
         (backend.log, np.log),
     ]
     for keras_op, np_op in ops_to_test:
-      compare_single_input_op_to_numpy(keras_op, np_op,
-                                       input_shape=(4, 7),
-                                       negative_values=False)
+      compare_single_input_op_to_numpy(
+          keras_op, np_op, input_shape=(4, 7), negative_values=False)
 
     compare_single_input_op_to_numpy(
         backend.clip,
@@ -489,9 +500,8 @@
         (backend.minimum, np.minimum),
     ]
     for keras_op, np_op in ops_to_test:
-      compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                     input_shape_a=(4, 7),
-                                     input_shape_b=(4, 7))
+      compare_two_inputs_op_to_numpy(
+          keras_op, np_op, input_shape_a=(4, 7), input_shape_b=(4, 7))
 
   def test_relu(self):
     x = ops.convert_to_tensor_v2_with_dispatch([[-4, 0], [2, 7]], 'float32')
@@ -713,19 +723,14 @@
         shape[2] += padding[1][0] + padding[1][1]
         shape[3] += padding[2][0] + padding[2][1]
         y = np.zeros(tuple(shape))
-        y[:,
-          padding[0][0]:-padding[0][1],
-          padding[1][0]:-padding[1][1],
-          padding[2][0]:-padding[2][1],
-          :] = x
+        y[:, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1],
+          padding[2][0]:-padding[2][1], :] = x
       else:
         shape[2] += padding[0][0] + padding[0][1]
         shape[3] += padding[1][0] + padding[1][1]
         shape[4] += padding[2][0] + padding[2][1]
         y = np.zeros(tuple(shape))
-        y[:, :,
-          padding[0][0]:-padding[0][1],
-          padding[1][0]:-padding[1][1],
+        y[:, :, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1],
           padding[2][0]:-padding[2][1]] = x
       return y
 
@@ -753,18 +758,14 @@
   def test_bias_add(self):
     keras_op = backend.bias_add
     np_op = np.add
-    compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                   input_shape_a=(4, 7),
-                                   input_shape_b=(7,))
-    compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                   input_shape_a=(4, 3, 7),
-                                   input_shape_b=(7,))
-    compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                   input_shape_a=(4, 3, 5, 7),
-                                   input_shape_b=(7,))
-    compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                   input_shape_a=(4, 3, 5, 2, 7),
-                                   input_shape_b=(7,))
+    compare_two_inputs_op_to_numpy(
+        keras_op, np_op, input_shape_a=(4, 7), input_shape_b=(7,))
+    compare_two_inputs_op_to_numpy(
+        keras_op, np_op, input_shape_a=(4, 3, 7), input_shape_b=(7,))
+    compare_two_inputs_op_to_numpy(
+        keras_op, np_op, input_shape_a=(4, 3, 5, 7), input_shape_b=(7,))
+    compare_two_inputs_op_to_numpy(
+        keras_op, np_op, input_shape_a=(4, 3, 5, 2, 7), input_shape_b=(7,))
 
     with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
       x = backend.variable((3, 4))
@@ -787,12 +788,10 @@
         b = b.reshape((1, b.shape[0], 1, 1))
       return x + b
 
-    compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                   input_shape_a=(4, 3, 7),
-                                   input_shape_b=(3,))
-    compare_two_inputs_op_to_numpy(keras_op, np_op,
-                                   input_shape_a=(4, 3, 5, 7),
-                                   input_shape_b=(3,))
+    compare_two_inputs_op_to_numpy(
+        keras_op, np_op, input_shape_a=(4, 3, 7), input_shape_b=(3,))
+    compare_two_inputs_op_to_numpy(
+        keras_op, np_op, input_shape_a=(4, 3, 5, 7), input_shape_b=(3,))
 
   def test_pool2d(self):
     val = np.random.random((10, 3, 10, 10))
@@ -847,8 +846,6 @@
       y = backend.pool2d(x, (2, 2), strides=(2, 2), pool_mode='other')
 
   def test_pool3d(self):
-    if test.is_built_with_rocm():
-      self.skipTest('Pooling with 3D tensors is not supported in ROCm')
     val = np.random.random((10, 3, 10, 10, 10))
     x = backend.variable(val)
     y = backend.pool3d(
@@ -938,18 +935,16 @@
           kernel_sizes = (kernel_size,) * dim
           strides = (stride,) * dim
 
-          output_shape = tuple([(i - kernel_size + stride) // stride
-                                for i in input_spatial_shape])
+          output_shape = tuple([
+              (i - kernel_size + stride) // stride for i in input_spatial_shape
+          ])
 
           kernel_shape = (np.prod(output_shape),
-                          np.prod(kernel_sizes) * channels_in,
-                          filters)
+                          np.prod(kernel_sizes) * channels_in, filters)
 
           kernel = np.random.normal(
-              0,
-              1,
-              output_shape + (channels_in, np.prod(kernel_sizes), filters)
-          )
+              0, 1,
+              output_shape + (channels_in, np.prod(kernel_sizes), filters))
 
           kernel_cf = np.reshape(kernel, kernel_shape)
           kernel_cf = backend.variable(kernel_cf)
@@ -957,14 +952,14 @@
           conv_cf = backend.local_conv(inputs_cf, kernel_cf, kernel_sizes,
                                        strides, output_shape, 'channels_first')
 
-          inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) +
-                                   [1])
+          inputs_cl = np.transpose(inputs,
+                                   [0, 2] + list(range(3, dim + 2)) + [1])
           inputs_cl = backend.variable(inputs_cl)
 
           kernel_cl = np.reshape(
-              np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]),
-              kernel_shape
-          )
+              np.transpose(kernel,
+                           list(range(dim)) + [dim + 1, dim, dim + 2]),
+              kernel_shape)
           kernel_cl = backend.variable(kernel_cl)
 
           conv_cl = backend.local_conv(inputs_cl, kernel_cl, kernel_sizes,
@@ -975,18 +970,13 @@
 
           self.assertAllCloseAccordingToType(
               conv_cf,
-              np.transpose(conv_cl,
-                           [0, dim + 1] + list(range(1, dim + 1))),
-              atol=1e-5
-          )
+              np.transpose(conv_cl, [0, dim + 1] + list(range(1, dim + 1))),
+              atol=1e-5)
 
   @parameterized.named_parameters(
       ('local_conv1d', (5, 6), (3,), (1,), (3,)),
       ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3)))
-  def test_local_conv_1d_and_2d(self,
-                                input_shape,
-                                kernel_sizes,
-                                strides,
+  def test_local_conv_1d_and_2d(self, input_shape, kernel_sizes, strides,
                                 output_shape):
     filters = 3
     batch_size = 2
@@ -994,9 +984,9 @@
     inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
     inputs = backend.variable(inputs)
 
-    kernel = np.random.normal(0, 1, (np.prod(output_shape),
-                                     np.prod(kernel_sizes) * input_shape[-1],
-                                     filters))
+    kernel = np.random.normal(0, 1,
+                              (np.prod(output_shape), np.prod(kernel_sizes) *
+                               input_shape[-1], filters))
     kernel = backend.variable(kernel)
 
     local_conv = backend.local_conv(inputs, kernel, kernel_sizes, strides,
@@ -1225,12 +1215,33 @@
     mask = backend.variable(np_mask)
 
     kwargs_list = [
-        {'go_backwards': False, 'mask': None},
-        {'go_backwards': False, 'mask': None, 'unroll': True},
-        {'go_backwards': True, 'mask': None},
-        {'go_backwards': True, 'mask': None, 'unroll': True},
-        {'go_backwards': False, 'mask': mask},
-        {'go_backwards': False, 'mask': mask, 'unroll': True},
+        {
+            'go_backwards': False,
+            'mask': None
+        },
+        {
+            'go_backwards': False,
+            'mask': None,
+            'unroll': True
+        },
+        {
+            'go_backwards': True,
+            'mask': None
+        },
+        {
+            'go_backwards': True,
+            'mask': None,
+            'unroll': True
+        },
+        {
+            'go_backwards': False,
+            'mask': mask
+        },
+        {
+            'go_backwards': False,
+            'mask': mask,
+            'unroll': True
+        },
     ]
     for i, kwargs in enumerate(kwargs_list):
       last_output, outputs, new_states = backend.rnn(rnn_fn, inputs,
@@ -1319,12 +1330,33 @@
     mask = backend.variable(np_mask)
 
     kwargs_list = [
-        {'go_backwards': False, 'mask': None},
-        {'go_backwards': False, 'mask': None, 'unroll': True},
-        {'go_backwards': True, 'mask': None},
-        {'go_backwards': True, 'mask': None, 'unroll': True},
-        {'go_backwards': False, 'mask': mask},
-        {'go_backwards': False, 'mask': mask, 'unroll': True},
+        {
+            'go_backwards': False,
+            'mask': None
+        },
+        {
+            'go_backwards': False,
+            'mask': None,
+            'unroll': True
+        },
+        {
+            'go_backwards': True,
+            'mask': None
+        },
+        {
+            'go_backwards': True,
+            'mask': None,
+            'unroll': True
+        },
+        {
+            'go_backwards': False,
+            'mask': mask
+        },
+        {
+            'go_backwards': False,
+            'mask': mask,
+            'unroll': True
+        },
     ]
     for i, kwargs in enumerate(kwargs_list):
       last_output, outputs, new_states = backend.rnn(rnn_fn, inputs,
@@ -1394,8 +1426,8 @@
     def step_function(inputs, states):
       return inputs, [s + 1 for s in states]
 
-    inputs_vals = np.random.random((num_samples, num_timesteps,
-                                    state_and_io_size))
+    inputs_vals = np.random.random(
+        (num_samples, num_timesteps, state_and_io_size))
     initial_state_vals = np.random.random((num_samples, state_and_io_size))
     # masking of two last timesteps for second sample only
     mask_vals = np.ones((num_samples, num_timesteps))
@@ -1785,29 +1817,34 @@
     depth = 6
     seq_len_0 = 5
     input_prob_matrix_0 = np.asarray(
-        [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
-         [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
-         [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
-         [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
-         [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
-         # Random entry added in at time=5
-         [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]],
+        [
+            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
+            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
+            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
+            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
+            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
+            # Random entry added in at time=5
+            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
+        ],
         dtype=np.float32)
 
     # len max_time_steps array of batch_size x depth matrices
-    inputs = ([input_prob_matrix_0[t, :][np.newaxis, :]
-               for t in range(seq_len_0)] +  # Pad to max_time_steps = 8
-              2 * [np.zeros((1, depth), dtype=np.float32)])
+    inputs = (
+        [input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
+        ] +  # Pad to max_time_steps = 8
+        2 * [np.zeros((1, depth), dtype=np.float32)])
 
     inputs = backend.variable(np.asarray(inputs).transpose((1, 0, 2)))
 
     # batch_size length vector of sequence_lengths
     input_length = backend.variable(np.array([seq_len_0], dtype=np.int32))
     # batch_size length vector of negative log probabilities
-    log_prob_truth = np.array([
-        -3.5821197,  # output beam 0
-        -3.777835    # output beam 1
-    ], np.float32)[np.newaxis, :]
+    log_prob_truth = np.array(
+        [
+            -3.5821197,  # output beam 0
+            -3.777835  # output beam 1
+        ],
+        np.float32)[np.newaxis, :]
 
     decode_truth = [
         np.array([1, 0, -1, -1, -1, -1, -1]),
@@ -1866,9 +1903,9 @@
 
       labels = np.asarray([[0, 1, 2, 1, 0]])
       inputs = np.asarray(
-          [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], [
-              0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436
-          ], [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+          [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+            [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
+            [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
             [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
             [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]]
           ],
@@ -1975,12 +2012,12 @@
 
     x_ph = backend.placeholder(ndim=2)
     v = backend.variable(np.ones((4, 2)))
-    output = x_ph ** 2 + v
+    output = x_ph**2 + v
     new_v = v + x_ph
     f = backend.function(x_ph, output, updates=[(v, new_v)])
     input_val = np.random.random((4, 2))
     result = f(input_val)
-    self.assertAllClose(result, input_val ** 2 + 1)
+    self.assertAllClose(result, input_val**2 + 1)
     self.assertAllClose(backend.get_value(v), np.ones((4, 2)) + input_val)
 
 
diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD
index 1e249d3..3087db3 100644
--- a/tensorflow/python/keras/benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/BUILD
@@ -100,6 +100,17 @@
     ],
 )
 
+py_test(
+    name = "benchmark_util_test",
+    srcs = ["benchmark_util_test.py"],
+    python_version = "PY3",
+    tags = COMMON_TAGS,
+    deps = [
+        ":benchmark_util",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
 cuda_py_test(
     name = "bidirectional_lstm_benchmark_test",
     srcs = ["keras_examples_benchmarks/bidirectional_lstm_benchmark_test.py"],
@@ -210,8 +221,23 @@
     ],
 )
 
+py_test(
+    name = "optimizer_benchmarks_test",
+    srcs = ["optimizer_benchmarks_test.py"],
+    python_version = "PY3",
+    tags = COMMON_TAGS + [
+        "no_oss_py38",  # TODO(b/162044699)
+    ],
+    deps = [
+        ":benchmark_util",
+        ":profiler_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/optimizer_v2",
+    ],
+)
+
 # Run memory profiler on Keras model.
-# Please make sure `meomry_profiler` is installed.
+# Please make sure `memory_profiler` is installed.
 # To run the memory profiler:
 # With CPU:
 #   bazel run -c opt model_memory_profile -- --model=YOUR_MODEL_NAME
diff --git a/tensorflow/python/keras/benchmarks/benchmark_util.py b/tensorflow/python/keras/benchmarks/benchmark_util.py
index 93aea7c..7c9b5ee 100644
--- a/tensorflow/python/keras/benchmarks/benchmark_util.py
+++ b/tensorflow/python/keras/benchmarks/benchmark_util.py
@@ -25,6 +25,45 @@
 from tensorflow.python.keras.benchmarks import distribution_util
 
 
+def get_benchmark_name(name):
+  """Split the suffix of the benchmark name.
+
+  For example, for the name = 'benchmark_layer_call__Conv2D_small_shape',
+  the return value is ['Conv2D', 'small', 'shape'].
+
+  This is to generate the metadata of the benchmark test.
+
+  Arguments:
+    name: A string, the benchmark name.
+
+  Returns:
+    A list of strings of the suffix in the benchmark name.
+  """
+  if '__' not in name or '_' not in name:
+    raise ValueError('The format of the benchmark name is wrong.')
+  return name.split('__')[-1].split('_')
+
+
+def generate_benchmark_params_cpu_gpu(*params_list):
+  """Extend the benchmark names with CPU and GPU suffix.
+
+  Arguments:
+    *params_list: A list of tuples represents the benchmark parameters.
+
+  Returns:
+    A list of strings with the benchmark name extended with CPU and GPU suffix.
+  """
+  benchmark_params = []
+  for params in params_list:
+    benchmark_params.extend([
+        ((param[0] + '_CPU',) + param[1:]) for param in params
+    ])
+    benchmark_params.extend([
+        ((param[0] + '_GPU',) + param[1:]) for param in params
+    ])
+  return benchmark_params
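
For reference, a minimal usage sketch of the two helpers above; the parameter
tuples are illustrative only:

    params = [("Adam", "adam", 10), ("SGD", "sgd", 10)]
    generate_benchmark_params_cpu_gpu(params)
    # -> [("Adam_CPU", "adam", 10), ("SGD_CPU", "sgd", 10),
    #     ("Adam_GPU", "adam", 10), ("SGD_GPU", "sgd", 10)]
    get_benchmark_name("benchmark_optimizer__Adam_CPU")
    # -> ["Adam", "CPU"]
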
+
+
 class TimerCallBack(tf.keras.callbacks.Callback):
   """Callback for logging time in each epoch or batch."""
 
diff --git a/tensorflow/python/keras/benchmarks/benchmark_util_test.py b/tensorflow/python/keras/benchmarks/benchmark_util_test.py
new file mode 100644
index 0000000..7e959a1
--- /dev/null
+++ b/tensorflow/python/keras/benchmarks/benchmark_util_test.py
@@ -0,0 +1,53 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for benchmark utitilies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras.benchmarks import benchmark_util
+
+
+class BenchmarkUtilTest(tf.test.TestCase):
+
+  def test_get_benchmark_name(self):
+    name = "benchmark_layer_call__Conv2D_small_shape"
+    expected = ["Conv2D", "small", "shape"]
+    out = benchmark_util.get_benchmark_name(name)
+    self.assertAllEqual(out, expected)
+
+  def test_generate_benchmark_params_cpu_gpu(self):
+    adam_opt = tf.keras.optimizers.Adam()
+    sgd_opt = tf.keras.optimizers.SGD()
+    params = [
+        ("Adam", adam_opt, 10),
+        ("SGD", sgd_opt, 10),
+    ]
+    expected = [
+        ("Adam_CPU", adam_opt, 10),
+        ("SGD_CPU", sgd_opt, 10),
+        ("Adam_GPU", adam_opt, 10),
+        ("SGD_GPU", sgd_opt, 10),
+    ]
+
+    out = benchmark_util.generate_benchmark_params_cpu_gpu(params)
+    self.assertAllEqual(out, expected)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
index 4161081..a3f90de 100644
--- a/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
@@ -62,5 +62,6 @@
     deps = [
         ":layer_benchmarks_test_base",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:benchmark_util",
     ],
 )
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
index 0fc9015..850e7ef 100644
--- a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
@@ -19,17 +19,15 @@
 from __future__ import print_function
 
 import functools
+import numpy as np
 import six
 
 import tensorflow as tf
+from tensorflow.python.keras.benchmarks import benchmark_util
 from tensorflow.python.keras.benchmarks.layer_benchmarks import layer_benchmarks_test_base
 from tensorflow.python.platform import benchmark
 
 
-def _get_benchmark_name(name):
-  return name.split("__")[-1].split("_")
-
-
 def _get_metadata(name):
   return {
       "model_name": "ideal_layers",
@@ -37,14 +35,14 @@
   }
 
 
-def _generate_benchmark_params(*params_list):
-  benchmark_params = []
-  for params in params_list:
-    benchmark_params.extend(
-        [((param[0] + "_CPU",) + param[1:]) for param in params])
-    benchmark_params.extend(
-        [((param[0] + "_GPU",) + param[1:]) for param in params])
-  return benchmark_params
+def _get_input_data(inputs):
+  if "input_shape" in inputs:
+    return tf.ones(inputs["input_shape"])
+  elif "input" in inputs:
+    return inputs["input"]
+  else:
+    raise ValueError("Please specificy either `input_shape` or `input`"
+                     "for the benchmark test")
 
 
 def _layer_call_backward(layer, x):
@@ -63,73 +61,104 @@
   # the benchmark name. It must follow the convention of
   # "{layer_name}_{small|normal|large}_shape" to make it compatible with
   # `self.report_benchmark()` method.
-  _benchmark_parameters = _generate_benchmark_params([
-      ("Conv2D_small_shape", tf.keras.layers.Conv2D,
-       {"filters": 1, "kernel_size": 1, "activation": "relu"},
-       (1, 1, 1, 1), 10000),
-      ("Conv2D_normal_shape", tf.keras.layers.Conv2D,
-       {"filters": 1, "kernel_size": 1, "activation": "relu"},
-       (64, 28, 28, 3), 10000),
-      ("LSTM_small_shape", tf.keras.layers.LSTM,
-       {"units": 1}, (1, 1, 1), 10000),
-      ("LSTM_normal_shape", tf.keras.layers.LSTM,
-       {"units": 4}, (32, 10, 8), 10000),
+  _benchmark_parameters = benchmark_util.generate_benchmark_params_cpu_gpu([
+      ("Conv2D_small_shape", tf.keras.layers.Conv2D, {
+          "filters": 1,
+          "kernel_size": 1,
+          "activation": "relu"
+      }, {
+          "input_shape": (1, 1, 1, 1)
+      }, 10),
+      ("Conv2D_normal_shape", tf.keras.layers.Conv2D, {
+          "filters": 1,
+          "kernel_size": 1,
+          "activation": "relu"
+      }, {
+          "input_shape": (64, 28, 28, 3)
+      }, 10),
+      ("LSTM_small_shape", tf.keras.layers.LSTM, {
+          "units": 1
+      }, {
+          "input_shape": (1, 1, 1)
+      }, 10),
+      ("LSTM_normal_shape", tf.keras.layers.LSTM, {
+          "units": 4
+      }, {
+          "input_shape": (32, 10, 8)
+      }, 10),
+      ("Embedding_small_shape", tf.keras.layers.Embedding, {
+          "input_dim": 1,
+          "output_dim": 1,
+          "input_length": 1
+      }, {
+          "input": np.random.randint(1, size=(1, 1))
+      }, 10),
+      ("Embedding_normal_shape", tf.keras.layers.Embedding, {
+          "input_dim": 1000,
+          "output_dim": 64,
+          "input_length": 10
+      }, {
+          "input": np.random.randint(1000, size=(32, 10))
+      }, 10),
   ])
 
-  def benchmark_layer_call(self, layer_cls, layer_args, input_shape, num_iters):
+  def benchmark_layer_call(self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
 
     fn = functools.partial(layer, x)
-    name = _get_benchmark_name(self._get_name())
+    name = benchmark_util.get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_with_function(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(layer.call)
 
     fn = functools.partial(layer, x)
-    name = _get_benchmark_name(self._get_name())
+    name = benchmark_util.get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.function"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_with_xla(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
+    name = benchmark_util.get_benchmark_name(self._get_name())
+    # TODO(b/173461426)
+    if layer_cls is tf.keras.layers.Embedding and name[-1] == "GPU":
+      return
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(
         layer.call, jit_compile=True)
 
     fn = functools.partial(layer, x)
-    name = _get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.xla"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_backward(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
 
     fn = functools.partial(_layer_call_backward, layer, x)
-    name = _get_benchmark_name(self._get_name())
+    name = benchmark_util.get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.backward"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_backward_with_function(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(layer.call)
 
     fn = functools.partial(_layer_call_backward, layer, x)
-    name = _get_benchmark_name(self._get_name())
+    name = benchmark_util.get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.backward.function"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
@@ -139,29 +168,54 @@
     benchmark.ParameterizedBenchmark,
     layer_benchmarks_test_base.LayerBenchmarksBase)):
 
-  _benchmark_parameters = _generate_benchmark_params([
-      ("Conv2D_small_shape", tf.keras.layers.Conv2D,
-       {"filters": 1, "kernel_size": 1, "activation": "relu"},
-       (1, 1, 1, 1), 10000),
-      ("Conv2D_normal_shape", tf.keras.layers.Conv2D,
-       {"filters": 1, "kernel_size": 1, "activation": "relu"},
-       (64, 28, 28, 3), 10000),
+  _benchmark_parameters = benchmark_util.generate_benchmark_params_cpu_gpu([
+      ("Conv2D_small_shape", tf.keras.layers.Conv2D, {
+          "filters": 1,
+          "kernel_size": 1,
+          "activation": "relu"
+      }, {
+          "input_shape": (1, 1, 1, 1)
+      }, 10000),
+      ("Conv2D_normal_shape", tf.keras.layers.Conv2D, {
+          "filters": 1,
+          "kernel_size": 1,
+          "activation": "relu"
+      }, {
+          "input_shape": (64, 28, 28, 3)
+      }, 10000),
       # TODO(b/153480400)
       # ("LSTM_small_shape", tf.keras.layers.LSTM,
-      #  {"units": 1}, (1, 1, 1), 10000),
+      #  {"units": 1}, {"input_shape": (1, 1, 1)}, 10000),
       # ("LSTM_normal_shape", tf.keras.layers.LSTM,
-      #  {"units": 4}, (32, 10, 8), 10000),
+      #  {"units": 4}, {"input_shape": (32, 10, 8)}, 10000),
+      ("Embedding_small_shape", tf.keras.layers.Embedding, {
+          "input_dim": 1,
+          "output_dim": 1,
+          "input_length": 1
+      }, {
+          "input": np.random.randint(1, size=(1, 1))
+      }, 10),
+      ("Embedding_normal_shape", tf.keras.layers.Embedding, {
+          "input_dim": 1000,
+          "output_dim": 64,
+          "input_length": 10
+      }, {
+          "input": np.random.randint(1000, size=(32, 10))
+      }, 10),
   ])
 
   def benchmark_layer_call_backward_with_xla(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
+    name = benchmark_util.get_benchmark_name(self._get_name())
+    # TODO(b/173461426)
+    if layer_cls is tf.keras.layers.Embedding and name[-1] == "GPU":
+      return
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(
         layer.call, jit_compile=True)
 
     fn = functools.partial(_layer_call_backward, layer, x)
-    name = _get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.backward.xla"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
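
For reference, a minimal sketch of how the new `inputs` dictionaries above are
resolved by `_get_input_data` (the values are illustrative):

    import numpy as np
    import tensorflow as tf

    _get_input_data({"input_shape": (1, 1, 1, 1)})
    # -> tf.ones((1, 1, 1, 1))
    _get_input_data({"input": np.random.randint(1000, size=(32, 10))})
    # -> the given integer array, passed through unchanged, since Embedding
    #    layers need integer ids rather than an all-ones float tensor.
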
diff --git a/tensorflow/python/keras/benchmarks/optimizer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/optimizer_benchmarks_test.py
new file mode 100644
index 0000000..a5ba771
--- /dev/null
+++ b/tensorflow/python/keras/benchmarks/optimizer_benchmarks_test.py
@@ -0,0 +1,86 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark tests for Keras optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras.benchmarks import benchmark_util
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.platform.benchmark import ParameterizedBenchmark
+
+
+def bidirect_imdb_lstm_config():
+  """Bidirectional LSTM model and IMDB data."""
+
+  def model_fn():
+    inputs = tf.keras.Input(shape=(None,), dtype="int32")
+    x = tf.keras.layers.Embedding(20000, 128)(inputs)
+    x = tf.keras.layers.Bidirectional(
+        tf.keras.layers.LSTM(64, return_sequences=True))(
+            x)
+    x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(x)
+    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
+    model = tf.keras.Model(inputs, outputs)
+    return model
+
+  (x_train, y_train), _ = tf.keras.datasets.imdb.load_data(num_words=20000)
+  x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=200)
+
+  return model_fn, x_train, y_train
+
+
+class KerasOptimizerBenchmark(
+    tf.test.Benchmark, metaclass=ParameterizedBenchmark):
+  """Keras optimizer benchmarks."""
+
+  # The parameter of each benchmark test is a tuple, and the first element
+  # is the optimizer name.
+  _benchmark_parameters = benchmark_util.generate_benchmark_params_cpu_gpu([
+      ("Adam", tf.keras.optimizers.Adam(), 10),
+      ("NonFusedAdam", adam.NonFusedAdam(), 10),
+  ])
+
+  def benchmark_optimizer(self, optimizer, num_iters):
+    """Optimizer benchmark with Bidirectional LSTM model on IMDB data.
+
+    Arguments:
+      optimizer: The optimizer instance to be benchmarked.
+      num_iters: The number of iterations to run for performance measurement.
+    """
+    model, train_x, train_y = bidirect_imdb_lstm_config()
+    metrics, wall_time, extras = benchmark_util.measure_performance(
+        model,
+        x=train_x,
+        y=train_y,
+        batch_size=512,
+        optimizer=optimizer,
+        loss="binary_crossentropy",
+        metrics=["accuracy"])
+    name = benchmark_util.get_benchmark_name(self._get_name())
+    metadata = {
+        "implementation": name[0],
+        "model_name": "optimizers",
+        "parameters": "lstm.512",
+    }
+    extras.update(metadata)
+    self.report_benchmark(
+        iters=num_iters, wall_time=wall_time, metrics=metrics, extras=extras)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index bfe9169..96d9031 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -72,6 +72,7 @@
   requests = None
 
 
+# Note: `configure_callbacks` is only used in TF1.
 def configure_callbacks(callbacks,
                         model,
                         do_validation=False,
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 98f30e7..4b563db 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -808,7 +808,6 @@
     tags = [
         "multi_and_single_gpu",
         "no_cuda_asan",  # times out
-        "no_rocm",
     ],
     xla_tags = [
         "no_cuda_asan",  # times out
@@ -828,7 +827,6 @@
     shard_count = 7,
     tags = [
         "multi_and_single_gpu",
-        "no_rocm",
     ],
     xla_tags = [
         "no_cuda_asan",  # times out
@@ -897,11 +895,11 @@
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:summary_ops_v2",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:training_lib",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/training/tracking:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/keras/distribute/sidecar_evaluator.py b/tensorflow/python/keras/distribute/sidecar_evaluator.py
index 663cfef..ce576b8 100644
--- a/tensorflow/python/keras/distribute/sidecar_evaluator.py
+++ b/tensorflow/python/keras/distribute/sidecar_evaluator.py
@@ -19,6 +19,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import re
+
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -32,6 +34,26 @@
 _ITERATIONS_UNINITIALIZED = -1
 
 
+def list_checkpoint_attributes(ckpt_dir_or_file):
+  """Lists all the attributes in a checkpoint.
+
+  Checkpoint keys are paths in a checkpoint graph, and an attribute is the
+  first element in the path. For example, for the checkpoint key
+  "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE", `optimizer` is the attribute.
+  The attribute is also used to save/restore a variable in a checkpoint,
+  e.g. tf.train.Checkpoint(optimizer=optimizer, model=model).
+
+  Args:
+    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
+
+  Returns:
+    Set of attributes in a checkpoint.
+  """
+  reader = checkpoint_utils.load_checkpoint(ckpt_dir_or_file)
+  variable_map = reader.get_variable_to_shape_map()
+  return {name.split('/')[0] for name in variable_map.keys()}
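
For illustration, a minimal sketch of the attribute extraction performed above
(the checkpoint keys are made up):

    keys = ['optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
            'model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE',
            'layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE']
    attributes = {k.split('/')[0] for k in keys}
    # -> {'optimizer', 'model', 'layer_with_weights-0'}
    # A weights-only checkpoint written by the ModelCheckpoint callback
    # typically has 'layer_with_weights-*' attributes but no 'model'
    # attribute, which is what the load_weights fallback below relies on.
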
+
+
 class SidecarEvaluator(object):
   """A class designed for a dedicated evaluator task.
 
@@ -148,6 +170,21 @@
         # `expect_partial` because the checkpoint can have other `Trackable`s
         # such as `optimizer`.
         checkpoint.restore(latest_checkpoint).expect_partial()
+        checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
+        # The checkpoint should contain model and optimizer attributes for
+        # SidecarEvaluator to work. However, the model weights saved by the
+        # ModelCheckpoint callback do not contain the model as an attribute.
+        # To make SidecarEvaluator work in this case as well, if the model
+        # attribute is not found but a layer_with_weights attribute is, use
+        # model.load_weights to load the model's weights, while
+        # self._iterations is still restored from the checkpoint variable.
+        if 'model' not in checkpoint_attributes:
+          for attribute in checkpoint_attributes:
+            # Check whether the checkpoint has the required attributes for
+            # model.load_weights to work.
+            if re.match(r'^layer_with_weights-\d+', attribute) is not None:
+              self.model.load_weights(latest_checkpoint)
+              break
       except (errors_impl.OpError,) as e:
         # A couple errors can happen here with the coordinator racing to write
         # checkpoint:
diff --git a/tensorflow/python/keras/distribute/sidecar_evaluator_test.py b/tensorflow/python/keras/distribute/sidecar_evaluator_test.py
index af35c92..9e75f1c 100644
--- a/tensorflow/python/keras/distribute/sidecar_evaluator_test.py
+++ b/tensorflow/python/keras/distribute/sidecar_evaluator_test.py
@@ -20,7 +20,6 @@
 from __future__ import print_function
 
 import os
-import unittest
 
 from absl import logging
 import numpy as np
@@ -36,6 +35,8 @@
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.tracking import util as tracking_util
 
+_BATCH_SIZE = 32
+
 
 class SidecarEvaluatorTest(test.TestCase):
 
@@ -130,7 +131,6 @@
 
     self.assertSummaryEventsWritten(log_dir)
 
-  @unittest.skip('b/172976255')
   def testSidecarEvaluatorOutputsSummarySavedWithCallback(self):
     checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
     log_dir = os.path.join(self.get_temp_dir(), 'summary')
@@ -139,7 +139,7 @@
     data = np.random.random((1000, 32))
     labels = np.random.random((1000, 10))
     dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
-    dataset = dataset.batch(32)
+    dataset = dataset.batch(_BATCH_SIZE)
     save_callback = keras.callbacks.ModelCheckpoint(
         filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
         save_weights_only=True)
@@ -152,17 +152,22 @@
     # Create a new model used for evaluation.
     eval_model = self.createTestModel(compile_model=True)
     # Have a sidecar_evaluator evaluate once.
-    sidecar_evaluator_lib.SidecarEvaluator(
+    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
         eval_model,
         data=dataset,
         checkpoint_dir=checkpoint_dir,
         log_dir=log_dir,
-        max_evaluations=1).start()
+        max_evaluations=1)
+    sidecar_evaluator.start()
+
     # Eval model has been restored to the same state as the original model, so
     # their weights should match. If not, restoration of the model didn't
     # work.
     self.assertModelsSameVariables(model, eval_model)
 
+    # Check that the iterations variable is restored.
+    self.assertEqual(sidecar_evaluator._iterations.numpy(), _BATCH_SIZE)
+
     self.assertSummaryEventsWritten(log_dir)
 
 
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index ef26a14..185f15a 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -1,7 +1,10 @@
 # Description:
 #   Contains the Keras engine API (internal TensorFlow version).
 
+# buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+# buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -80,6 +83,8 @@
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/ops/ragged:ragged_util",
         "//tensorflow/python/profiler:trace",
+        "//tensorflow/python/saved_model:constants",
+        "//tensorflow/python/saved_model:loader",
         "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training/tracking:data_structures",
         "//tensorflow/tools/docs:doc_controls",
@@ -135,7 +140,7 @@
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:tf2",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/autograph/core",
@@ -389,7 +394,6 @@
     shard_count = 20,
     tags = [
         "manual",
-        "no_rocm",
         "nomac",  # TODO(mihaimaruseac): b/127695564
         "notsan",
     ],
@@ -514,7 +518,6 @@
     python_version = "PY3",
     shard_count = 30,
     tags = [
-        "no_rocm",
         "nomac",  # TODO(mihaimaruseac): b/127695564
     ],
     deps = [
@@ -604,6 +607,7 @@
     shard_count = 8,
     tags = [
         "no-internal-py3",
+        "no_rocm",
         "nomac",  # TODO(mihaimaruseac): b/127695564
     ],
     deps = [
@@ -658,7 +662,6 @@
     python_version = "PY3",
     shard_count = 8,
     tags = [
-        "no_rocm",
         "nomac",  # TODO(mihaimaruseac): b/127695564
     ],
     deps = [
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 3f5f4ff..894751c 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -924,6 +924,7 @@
         for `inputs` by the previous layer (if `input` did come from
         a layer that generated a corresponding mask, i.e. if it came from
         a Keras layer with masking support.
+      - If the layer is not built, the method will call `build`.
 
     Raises:
       ValueError: if the layer's `call` method returns None (an invalid value).
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index eff921e..474dce8 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -178,6 +178,10 @@
       origin_layer = layer._keras_history[0]
       if isinstance(origin_layer, input_layer.InputLayer):
         layer = origin_layer
+        logging.warning(
+            'Please add `keras.layers.InputLayer` instead of `keras.Input` '
+            'to a Sequential model. `keras.Input` is intended to be used by '
+            'a Functional model.')
 
     if isinstance(layer, module.Module):
       if not isinstance(layer, base_layer.Layer):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 9dcac8c..91c1182 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -55,6 +55,7 @@
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.saving import hdf5_format
 from tensorflow.python.keras.saving import save
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import json_utils
 from tensorflow.python.keras.saving.saved_model import model_serialization
 from tensorflow.python.keras.utils import generic_utils
@@ -72,6 +73,8 @@
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import trace
+from tensorflow.python.saved_model import constants as sm_constants
+from tensorflow.python.saved_model import loader_impl as sm_loader
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import py_checkpoint_reader
 from tensorflow.python.training.tracking import base as trackable
@@ -2114,7 +2117,7 @@
     """
     self._assert_weights_created()
     filepath = path_to_string(filepath)
-    filepath_is_h5 = _is_hdf5_filepath(filepath)
+    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
     if save_format is None:
       if filepath_is_h5:
         save_format = 'h5'
@@ -2202,7 +2205,8 @@
     Arguments:
         filepath: String, path to the weights file to load. For weight files in
             TensorFlow format, this is the file prefix (the same as was passed
-            to `save_weights`).
+            to `save_weights`). This can also be a path to a SavedModel
+            saved from `model.save`.
         by_name: Boolean, whether to load weights by name or by topological
             order. Only topological loading is supported for weight files in
             TensorFlow format.
@@ -2229,7 +2233,7 @@
     """
     if dist_utils.is_tpu_strategy(self._distribution_strategy):
       if (self._distribution_strategy.extended.steps_per_run > 1 and
-          (not _is_hdf5_filepath(filepath))):
+          (not saving_utils.is_hdf5_filepath(filepath))):
         raise ValueError('Load weights is not yet supported with TPUStrategy '
                          'with steps_per_run greater than 1.')
     if skip_mismatch and not by_name:
@@ -2237,16 +2241,7 @@
           'When calling model.load_weights, skip_mismatch can only be set to '
           'True when by_name is True.')
 
-    filepath = path_to_string(filepath)
-    if _is_hdf5_filepath(filepath):
-      save_format = 'h5'
-    else:
-      try:
-        py_checkpoint_reader.NewCheckpointReader(filepath)
-        save_format = 'tf'
-      except errors_impl.DataLossError:
-        # The checkpoint is not readable in TensorFlow format. Try HDF5.
-        save_format = 'h5'
+    filepath, save_format = _detect_save_format(filepath)
     if save_format == 'tf':
       status = self._trackable_saver.restore(filepath, options)
       if by_name:
@@ -2851,6 +2846,40 @@
     raise RuntimeError(error_msg)
 
 
-def _is_hdf5_filepath(filepath):
-  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
-          filepath.endswith('.hdf5'))
+def _detect_save_format(filepath):
+  """Returns path to weights file and save format."""
+
+  filepath = path_to_string(filepath)
+  if saving_utils.is_hdf5_filepath(filepath):
+    return filepath, 'h5'
+
+  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
+  # directory. It's possible for filepath to be both a prefix and directory.
+  # Prioritize checkpoint over SavedModel.
+  if _is_readable_tf_checkpoint(filepath):
+    save_format = 'tf'
+  elif sm_loader.contains_saved_model(filepath):
+    ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
+                             sm_constants.VARIABLES_FILENAME)
+    if _is_readable_tf_checkpoint(ckpt_path):
+      filepath = ckpt_path
+      save_format = 'tf'
+    else:
+      raise ValueError('Unable to load weights. filepath {} appears to be a '
+                       'SavedModel directory, but checkpoint either doesn\'t '
+                       'exist, or is incorrectly formatted.'.format(filepath))
+  else:
+    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
+    # doesn't have the hdf5/keras extensions.
+    save_format = 'h5'
+  return filepath, save_format
+
+
+def _is_readable_tf_checkpoint(filepath):
+  try:
+    py_checkpoint_reader.NewCheckpointReader(filepath)
+    return True
+  except errors_impl.DataLossError:
+    # The checkpoint is not readable in TensorFlow format.
+    return False
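
A rough sketch of the dispatch that `_detect_save_format` performs (the paths
are illustrative, not real files):

    _detect_save_format('weights.h5')      # -> ('weights.h5', 'h5')
    # A readable TensorFlow checkpoint prefix takes priority.
    _detect_save_format('ckpt/weights')    # -> ('ckpt/weights', 'tf')
    # A SavedModel directory is redirected to its variables checkpoint.
    _detect_save_format('saved_model_dir')
    # -> ('saved_model_dir/variables/variables', 'tf')
    # Anything else falls back to the 'h5' format.
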
+
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 0faafa6..576e8c8 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -55,6 +55,7 @@
 from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import model_serialization
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import layer_utils
@@ -229,7 +230,7 @@
     """
     if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
       if (self._distribution_strategy.extended.steps_per_run > 1 and
-          (not training_lib._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
+          (not saving_utils.is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
         raise ValueError('Load weights is not yet supported with TPUStrategy '
                          'with steps_per_run greater than 1.')
     return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD
index 8266173..e736913 100644
--- a/tensorflow/python/keras/feature_column/BUILD
+++ b/tensorflow/python/keras/feature_column/BUILD
@@ -45,10 +45,10 @@
     deps = [
         ":base_feature_layer",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/feature_column:feature_column_v2",
         "//tensorflow/python/keras:backend",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -61,8 +61,8 @@
         ":base_feature_layer",
         ":dense_features",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/feature_column:feature_column_v2",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -125,9 +125,9 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/feature_column:feature_column_v2",
         "//tensorflow/python/keras:backend",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/keras/initializers/initializers_v2.py b/tensorflow/python/keras/initializers/initializers_v2.py
index 0e4fd66..1eaf0af 100644
--- a/tensorflow/python/keras/initializers/initializers_v2.py
+++ b/tensorflow/python/keras/initializers/initializers_v2.py
@@ -19,12 +19,23 @@
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import init_ops_v2
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.util.tf_export import keras_export
 
+_PARTITION_SHAPE = 'partition_shape'
+_PARTITION_OFFSET = 'partition_offset'
+
 
 @keras_export('keras.initializers.Initializer')
 class Initializer(object):
@@ -308,7 +319,7 @@
 @keras_export('keras.initializers.TruncatedNormal',
               'keras.initializers.truncated_normal',
               v1=[])
-class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer):
+class TruncatedNormal(Initializer):
   """Initializer that generates a truncated normal distribution.
 
   Also available via the shortcut function
@@ -338,6 +349,12 @@
       always produce the same random tensor for a given shape and dtype.
   """
 
+  def __init__(self, mean=0.0, stddev=0.05, seed=None):
+    self.mean = mean
+    self.stddev = stddev
+    self.seed = seed
+    self._random_generator = _RandomGenerator(seed)
+
   def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to random normal values (truncated).
 
@@ -349,14 +366,25 @@
         `tf.keras.backend.set_floatx(float_dtype)`)
       **kwargs: Additional keyword arguments.
     """
-    return super(TruncatedNormal, self).__call__(
-        shape, dtype=_get_dtype(dtype), **kwargs)
+    _validate_kwargs(self.__class__.__name__, kwargs)
+    dtype = _assert_float_dtype(_get_dtype(dtype))
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
+    return self._random_generator.truncated_normal(shape, self.mean,
+                                                   self.stddev, dtype)
+
+  def get_config(self):
+    return {
+        'mean': self.mean,
+        'stddev': self.stddev,
+        'seed': self.seed
+    }
 
 
 @keras_export('keras.initializers.VarianceScaling',
               'keras.initializers.variance_scaling',
               v1=[])
-class VarianceScaling(init_ops_v2.VarianceScaling, Initializer):
+class VarianceScaling(Initializer):
   """Initializer capable of adapting its scale to the shape of weights tensors.
 
   Also available via the shortcut function
@@ -395,6 +423,28 @@
       always produce the same random tensor for a given shape and dtype.
   """
 
+  def __init__(self,
+               scale=1.0,
+               mode='fan_in',
+               distribution='truncated_normal',
+               seed=None):
+    if scale <= 0.:
+      raise ValueError('`scale` must be a positive float.')
+    if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
+      raise ValueError('Invalid `mode` argument:', mode)
+    distribution = distribution.lower()
+    # Compatibility with keras-team/keras.
+    if distribution == 'normal':
+      distribution = 'truncated_normal'
+    if distribution not in {'uniform', 'truncated_normal',
+                            'untruncated_normal'}:
+      raise ValueError('Invalid `distribution` argument:', distribution)
+    self.scale = scale
+    self.mode = mode
+    self.distribution = distribution
+    self.seed = seed
+    self._random_generator = _RandomGenerator(seed)
+
   def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
@@ -406,14 +456,42 @@
         `tf.keras.backend.set_floatx(float_dtype)`)
       **kwargs: Additional keyword arguments.
     """
-    return super(VarianceScaling, self).__call__(
-        shape, dtype=_get_dtype(dtype), **kwargs)
+    _validate_kwargs(self.__class__.__name__, kwargs)
+    dtype = _assert_float_dtype(_get_dtype(dtype))
+    scale = self.scale
+    fan_in, fan_out = _compute_fans(shape)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
+    if self.mode == 'fan_in':
+      scale /= max(1., fan_in)
+    elif self.mode == 'fan_out':
+      scale /= max(1., fan_out)
+    else:
+      scale /= max(1., (fan_in + fan_out) / 2.)
+    if self.distribution == 'truncated_normal':
+      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+      stddev = math.sqrt(scale) / .87962566103423978
+      return self._random_generator.truncated_normal(shape, 0.0, stddev, dtype)
+    elif self.distribution == 'untruncated_normal':
+      stddev = math.sqrt(scale)
+      return self._random_generator.random_normal(shape, 0.0, stddev, dtype)
+    else:
+      limit = math.sqrt(3.0 * scale)
+      return self._random_generator.random_uniform(shape, -limit, limit, dtype)
+
+  def get_config(self):
+    return {
+        'scale': self.scale,
+        'mode': self.mode,
+        'distribution': self.distribution,
+        'seed': self.seed
+    }
 
 
 @keras_export('keras.initializers.Orthogonal',
               'keras.initializers.orthogonal',
               v1=[])
-class Orthogonal(init_ops_v2.Orthogonal, Initializer):
+class Orthogonal(Initializer):
   """Initializer that generates an orthogonal matrix.
 
   Also available via the shortcut function `tf.keras.initializers.orthogonal`.
@@ -449,6 +527,11 @@
       ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
   """
 
+  def __init__(self, gain=1.0, seed=None):
+    self.gain = gain
+    self.seed = seed
+    self._random_generator = _RandomGenerator(seed)
+
   def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to an orthogonal matrix.
 
@@ -460,14 +543,39 @@
        (via `tf.keras.backend.set_floatx(float_dtype)`)
       **kwargs: Additional keyword arguments.
     """
-    return super(Orthogonal, self).__call__(
-        shape, dtype=_get_dtype(dtype), **kwargs)
+    _validate_kwargs(self.__class__.__name__, kwargs, support_partition=False)
+    dtype = _assert_float_dtype(_get_dtype(dtype))
+    # Check the shape
+    if len(shape) < 2:
+      raise ValueError('The tensor to initialize must be '
+                       'at least two-dimensional')
+    # Flatten the input shape with the last dimension remaining
+    # its original shape so it works for conv2d
+    num_rows = 1
+    for dim in shape[:-1]:
+      num_rows *= dim
+    num_cols = shape[-1]
+    flat_shape = (max(num_cols, num_rows), min(num_cols, num_rows))
+
+    # Generate a random matrix
+    a = self._random_generator.random_normal(flat_shape, dtype=dtype)
+    # Compute the qr factorization
+    q, r = gen_linalg_ops.qr(a, full_matrices=False)
+    # Make Q uniform
+    d = array_ops.tensor_diag_part(r)
+    q *= math_ops.sign(d)
+    if num_rows < num_cols:
+      q = array_ops.matrix_transpose(q)
+    return self.gain * array_ops.reshape(q, shape)
+
+  def get_config(self):
+    return {'gain': self.gain, 'seed': self.seed}
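
A quick sanity check of the construction above, sketched with the public alias
(the shape is illustrative): when there are at least as many rows as columns,
the returned matrix has orthonormal columns.

    import tensorflow as tf

    init = tf.keras.initializers.Orthogonal(seed=1)
    w = init(shape=(8, 4))
    # w has orthonormal columns, so w^T w is close to the 4x4 identity.
    print(tf.linalg.matmul(w, w, transpose_a=True))
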
 
 
 @keras_export('keras.initializers.Identity',
               'keras.initializers.identity',
               v1=[])
-class Identity(init_ops_v2.Identity, Initializer):
+class Identity(Initializer):
   """Initializer that generates the identity matrix.
 
   Also available via the shortcut function `tf.keras.initializers.identity`.
@@ -488,6 +596,9 @@
     gain: Multiplicative factor to apply to the identity matrix.
   """
 
+  def __init__(self, gain=1.0):
+    self.gain = gain
+
   def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to a 2D identity matrix.
 
@@ -499,8 +610,16 @@
        (via `tf.keras.backend.set_floatx(float_dtype)`)
       **kwargs: Additional keyword arguments.
     """
-    return super(Identity, self).__call__(
-        shape, dtype=_get_dtype(dtype), **kwargs)
+    _validate_kwargs(self.__class__.__name__, kwargs, support_partition=False)
+    dtype = _assert_float_dtype(_get_dtype(dtype))
+    if len(shape) != 2:
+      raise ValueError(
+          'Identity matrix initializer can only be used for 2D matrices.')
+    initializer = linalg_ops.eye(*shape, dtype=dtype)
+    return self.gain * initializer
+
+  def get_config(self):
+    return {'gain': self.gain}
 
 
 @keras_export('keras.initializers.GlorotUniform',
@@ -765,3 +884,98 @@
   if dtype is None:
     dtype = backend.floatx()
   return dtypes.as_dtype(dtype)
+
+
+def _assert_float_dtype(dtype):
+  """Validate and return floating point type based on `dtype`.
+
+  `dtype` must be a floating point type.
+
+  Args:
+    dtype: The data type to validate.
+
+  Returns:
+    Validated type.
+
+  Raises:
+    ValueError: if `dtype` is not a floating point type.
+  """
+  dtype = dtypes.as_dtype(dtype)
+  if not dtype.is_floating:
+    raise ValueError('Expected floating point type, got %s.' % dtype)
+  return dtype
+
+
+class _RandomGenerator(object):
+  """Random generator that selects appropriate random ops."""
+
+  def __init__(self, seed=None):
+    super(_RandomGenerator, self).__init__()
+    if seed is not None:
+      # Stateless random ops require a seed of two ints.
+      self.seed = [seed, 0]
+    else:
+      self.seed = None
+
+  def random_normal(self, shape, mean=0.0, stddev=1, dtype=dtypes.float32):
+    """A deterministic random normal if seed is passed."""
+    if self.seed:
+      op = stateless_random_ops.stateless_random_normal
+    else:
+      op = random_ops.random_normal
+    return op(
+        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
+
+  def random_uniform(self, shape, minval, maxval, dtype):
+    """A deterministic random uniform if seed is passed."""
+    if self.seed:
+      op = stateless_random_ops.stateless_random_uniform
+    else:
+      op = random_ops.random_uniform
+    return op(
+        shape=shape, minval=minval, maxval=maxval, dtype=dtype, seed=self.seed)
+
+  def truncated_normal(self, shape, mean, stddev, dtype):
+    """A deterministic truncated normal if seed is passed."""
+    if self.seed:
+      op = stateless_random_ops.stateless_truncated_normal
+    else:
+      op = random_ops.truncated_normal
+    return op(
+        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
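
The op selection above makes seeded generators deterministic. A brief sketch
(the shapes are illustrative):

    g = _RandomGenerator(seed=42)
    a = g.random_normal((2, 2))
    b = g.random_normal((2, 2))
    # With a seed, the stateless ops are used, so a and b are identical.

    g = _RandomGenerator()
    a = g.random_normal((2, 2))
    b = g.random_normal((2, 2))
    # Without a seed, the stateful ops are used, so a and b generally differ.
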
+
+
+def _compute_fans(shape):
+  """Computes the number of input and output units for a weight shape.
+
+  Args:
+    shape: Integer shape tuple or TF tensor shape.
+
+  Returns:
+    A tuple of integer scalars (fan_in, fan_out).
+  """
+  if len(shape) < 1:  # Just to avoid errors for constants.
+    fan_in = fan_out = 1
+  elif len(shape) == 1:
+    fan_in = fan_out = shape[0]
+  elif len(shape) == 2:
+    fan_in = shape[0]
+    fan_out = shape[1]
+  else:
+    # Assuming convolution kernels (2D, 3D, or more).
+    # kernel shape: (..., input_depth, depth)
+    receptive_field_size = 1
+    for dim in shape[:-2]:
+      receptive_field_size *= dim
+    fan_in = shape[-2] * receptive_field_size
+    fan_out = shape[-1] * receptive_field_size
+  return int(fan_in), int(fan_out)
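
To make the fan computation and the scaling logic above concrete, a small
worked example (the kernel shape and numbers are illustrative):

    # A 3x3 Conv2D kernel with 16 input and 32 output channels:
    #   receptive_field_size = 3 * 3 = 9
    #   fan_in  = 16 * 9 = 144
    #   fan_out = 32 * 9 = 288
    fan_in, fan_out = _compute_fans((3, 3, 16, 32))  # -> (144, 288)

    # With VarianceScaling(scale=1.0, mode='fan_avg', distribution='uniform'),
    # i.e. Glorot/Xavier uniform, the sampling limit becomes
    #   limit = sqrt(3.0 * scale / ((fan_in + fan_out) / 2))
    #         = sqrt(6.0 / (144 + 288)) ~= 0.118
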
+
+
+def _validate_kwargs(cls_name, kwargs, support_partition=True):
+  for kwarg in kwargs:
+    if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]:
+      raise TypeError('Unknown keyword argument: %s' % kwarg)
+    elif not support_partition:
+      raise ValueError('%s initializer doesn\'t support partition-related '
+                       'arguments' % cls_name)
diff --git a/tensorflow/python/keras/integration_test/forwardprop_test.py b/tensorflow/python/keras/integration_test/forwardprop_test.py
index a93dd23..af9b34e 100644
--- a/tensorflow/python/keras/integration_test/forwardprop_test.py
+++ b/tensorflow/python/keras/integration_test/forwardprop_test.py
@@ -118,7 +118,9 @@
   """
   return _vectorize_parameters(
       functools.partial(_hvp, f, params),
-      params, use_pfor=use_pfor, dtype=dtype)
+      params,
+      use_pfor=use_pfor,
+      dtype=dtype)
 
 
 def _test_gradients(testcase,
@@ -173,7 +175,10 @@
                   2. / tf.size(v, out_type=tf.float32),
                   dtype=tf.float32), v.shape))
     _test_gradients(
-        self, layer, [input_value], atol=atol,
+        self,
+        layer,
+        [input_value],
+        atol=atol,
         # These are linear, so second-order is pretty boring.
         order=2)
 
@@ -189,8 +194,10 @@
       input_value = tf.constant(value, dtype=tf.float32)
       layer.build(input_value.shape)
       _test_gradients(
-          self, functools.partial(layer, training=training), [input_value],
-          order=2, atol=1e-3)
+          self,
+          functools.partial(layer, training=training), [input_value],
+          order=2,
+          atol=1e-3)
 
   @parameterized.named_parameters([
       ("NonFused", [[0.1], [0.2], [-0.3]],
@@ -205,8 +212,8 @@
         input_value = tf.constant(value, dtype=tf.float32)
         tape.watch(input_value)
         output = layer(input_value, training=training)
-      jac_back = tape.jacobian(
-          output, [input_value] + layer.trainable_variables)
+      jac_back = tape.jacobian(output,
+                               [input_value] + layer.trainable_variables)
       jac_forward = _jacfwd(
           lambda *args: layer(args[0], training=training),  # pylint:disable=cell-var-from-loop
           [input_value] + layer.trainable_variables)
@@ -218,12 +225,6 @@
                                    ("NoFunction", lambda f: f)])
   def testVariablesHVP(self, decorator):
 
-    if tf.test.is_built_with_rocm():
-      # TODO(rocm)
-      # This test was recently added and has never passed on the
-      # ROCm platform. Remove this skip once the test is passing again
-      self.skipTest("NoFunction decorator test fails on the ROCm platform")
-
     class _Model(tf.Module):
 
       def __init__(self):
@@ -240,6 +241,7 @@
         return self._second_dense(x)
 
     model = _Model()
+
     def _loss():
       input_value = tf.constant([[-0.5, 1.], [0.5, -1.]])
       target = tf.constant([[-1.], [2.]])
@@ -251,8 +253,8 @@
         loss = _loss()
       vector = tape.gradient(loss, model.trainable_variables)
       variable_input_fn = lambda unused_variables: _loss()
-      forward_over_back_hvp, = _hvp(
-          variable_input_fn, [model.trainable_variables], [vector])
+      forward_over_back_hvp, = _hvp(variable_input_fn,
+                                    [model.trainable_variables], [vector])
       with tf.GradientTape(persistent=True) as tape:
         tape.watch(model.trainable_variables)
         loss = _loss()
@@ -260,6 +262,7 @@
       back_over_back_hvp = tape.gradient(
           first_grads, model.trainable_variables, output_gradients=vector)
       return forward_over_back_hvp, back_over_back_hvp
+
     self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5)
 
   def testEmbeddingLayerInFunction(self):
@@ -288,9 +291,7 @@
 
 class HessianTests(tf.test.TestCase, parameterized.TestCase):
 
-  @parameterized.named_parameters(
-      [("PFor", True),
-       ("MapFn", False)])
+  @parameterized.named_parameters([("PFor", True), ("MapFn", False)])
   def testHessianOfVariables(self, use_pfor):
     model = tf.keras.layers.Dense(1)
     model.build([None, 2])
diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD
index 54e815d..1caff66 100644
--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD
@@ -750,7 +750,6 @@
     python_version = "PY3",
     shard_count = 4,
     tags = [
-        "no_rocm",
         "notsan",  # http://b/62136390
     ],
     deps = [
@@ -769,7 +768,6 @@
     python_version = "PY3",
     shard_count = 4,
     tags = [
-        "no_rocm",
         "noasan",  # times out b/63678675
         "notsan",  # http://b/62189182
     ],
diff --git a/tensorflow/python/keras/layers/gru_v2_test.py b/tensorflow/python/keras/layers/gru_v2_test.py
index 0422ce1..80776fa 100644
--- a/tensorflow/python/keras/layers/gru_v2_test.py
+++ b/tensorflow/python/keras/layers/gru_v2_test.py
@@ -34,6 +34,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -594,6 +595,7 @@
       outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
 
+  @tf_test_util.enable_output_all_intermediates
   def test_v1_session_behavior(self):
     with ops.get_default_graph().as_default():
       # See b/139132348 for more details.
diff --git a/tensorflow/python/keras/layers/legacy_rnn/BUILD b/tensorflow/python/keras/layers/legacy_rnn/BUILD
index 2da1445..99a8d11 100644
--- a/tensorflow/python/keras/layers/legacy_rnn/BUILD
+++ b/tensorflow/python/keras/layers/legacy_rnn/BUILD
@@ -41,6 +41,7 @@
         "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras/engine:input_spec",
         "//tensorflow/python/keras/legacy_tf_layers:layers_base",
+        "//tensorflow/python/keras/saving",
         "//tensorflow/python/keras/utils:tf_utils",
         "//tensorflow/python/training/tracking:base",
     ],
diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py
index 91f3aed..8b44761 100644
--- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py
+++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py
@@ -346,6 +346,12 @@
   def get_config(self):  # pylint: disable=useless-super-delegation
     return super(RNNCell, self).get_config()
 
+  @property
+  def _use_input_spec_as_call_signature(self):
+    # We do not store the shape information for the state argument in the call
+    # function for legacy RNN cells, so do not generate an input signature.
+    return False
+
 
 class LayerRNNCell(RNNCell):
   """Subclass of RNNCells that act like proper `tf.Layer` objects.
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index b0d287e..69ec0af 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Locally-connected layers.
-"""
+"""Locally-connected layers."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -60,79 +59,61 @@
   ```
 
   Arguments:
-      filters: Integer, the dimensionality of the output space
-          (i.e. the number of output filters in the convolution).
-      kernel_size: An integer or tuple/list of a single integer,
-          specifying the length of the 1D convolution window.
-      strides: An integer or tuple/list of a single integer,
-          specifying the stride length of the convolution.
-          Specifying any stride value != 1 is incompatible with specifying
-          any `dilation_rate` value != 1.
-      padding: Currently only supports `"valid"` (case-insensitive).
-          `"same"` may be supported in the future.
-          `"valid"` means no padding.
-      data_format: A string,
-          one of `channels_last` (default) or `channels_first`.
-          The ordering of the dimensions in the inputs.
-          `channels_last` corresponds to inputs with shape
-          `(batch, length, channels)` while `channels_first`
-          corresponds to inputs with shape
-          `(batch, channels, length)`.
-          It defaults to the `image_data_format` value found in your
-          Keras config file at `~/.keras/keras.json`.
-          If you never set it, then it will be "channels_last".
-      activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of output filters in the convolution).
+      kernel_size: An integer or tuple/list of a single integer, specifying the
+        length of the 1D convolution window.
+      strides: An integer or tuple/list of a single integer, specifying the
+        stride length of the convolution.
+      padding: Currently only supports `"valid"` (case-insensitive). `"same"`
+        may be supported in the future. `"valid"` means no padding.
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, length,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, length)`. It defaults to the `image_data_format`
+        value found in your Keras config file at `~/.keras/keras.json`. If you
+        never set it, then it will be "channels_last".
+      activation: Activation function to use. If you don't specify anything, no
+        activation is applied
           (ie. "linear" activation: `a(x) = x`).
       use_bias: Boolean, whether the layer uses a bias vector.
       kernel_initializer: Initializer for the `kernel` weights matrix.
       bias_initializer: Initializer for the bias vector.
-      kernel_regularizer: Regularizer function applied to
-          the `kernel` weights matrix.
+      kernel_regularizer: Regularizer function applied to the `kernel` weights
+        matrix.
       bias_regularizer: Regularizer function applied to the bias vector.
-      activity_regularizer: Regularizer function applied to
-          the output of the layer (its "activation")..
+      activity_regularizer: Regularizer function applied to the output of the
+        layer (its "activation").
       kernel_constraint: Constraint function applied to the kernel matrix.
       bias_constraint: Constraint function applied to the bias vector.
-      implementation: implementation mode, either `1`, `2`, or `3`.
-          `1` loops over input spatial locations to perform the forward pass.
-          It is memory-efficient but performs a lot of (small) ops.
-
-          `2` stores layer weights in a dense but sparsely-populated 2D matrix
-          and implements the forward pass as a single matrix-multiply. It uses
-          a lot of RAM but performs few (large) ops.
-
-          `3` stores layer weights in a sparse tensor and implements the forward
-          pass as a single sparse matrix-multiply.
-
+      implementation: implementation mode, either `1`, `2`, or `3`. `1` loops
+        over input spatial locations to perform the forward pass. It is
+        memory-efficient but performs a lot of (small) ops.  `2` stores layer
+        weights in a dense but sparsely-populated 2D matrix and implements the
+        forward pass as a single matrix-multiply. It uses a lot of RAM but
+        performs few (large) ops.  `3` stores layer weights in a sparse tensor
+        and implements the forward pass as a single sparse matrix-multiply.
           How to choose:
-
           `1`: large, dense models,
           `2`: small models,
-          `3`: large, sparse models,
-
-          where "large" stands for large input/output activations
-          (i.e. many `filters`, `input_filters`, large `input_size`,
-          `output_size`), and "sparse" stands for few connections between inputs
-          and outputs, i.e. small ratio
-          `filters * input_filters * kernel_size / (input_size * strides)`,
-          where inputs to and outputs of the layer are assumed to have shapes
-          `(input_size, input_filters)`, `(output_size, filters)`
-          respectively.
-
-          It is recommended to benchmark each in the setting of interest to pick
-          the most efficient one (in terms of speed and memory usage). Correct
-          choice of implementation can lead to dramatic speed improvements (e.g.
-          50X), potentially at the expense of RAM.
-
-          Also, only `padding="valid"` is supported by `implementation=1`.
-
+          `3`: large, sparse models,  where "large" stands for large
+            input/output activations (i.e. many `filters`, `input_filters`,
+            large `input_size`, `output_size`), and "sparse" stands for few
+            connections between inputs and outputs, i.e. small ratio `filters *
+            input_filters * kernel_size / (input_size * strides)`, where inputs
+            to and outputs of the layer are assumed to have shapes `(input_size,
+            input_filters)`, `(output_size, filters)` respectively.  It is
+            recommended to benchmark each in the setting of interest to pick the
+            most efficient one (in terms of speed and memory usage). Correct
+            choice of implementation can lead to dramatic speed improvements
+            (e.g. 50X), potentially at the expense of RAM.  Also, only
+            `padding="valid"` is supported by `implementation=1`.
   Input shape:
       3D tensor with shape: `(batch_size, steps, input_dim)`
-
   Output shape:
-      3D tensor with shape: `(batch_size, new_steps, filters)`
-      `steps` value might have changed due to padding or strides.
+      3D tensor with shape: `(batch_size, new_steps, filters)`. `steps` value
+        might have changed due to padding or strides.
   """
 
   def __init__(self,
@@ -159,8 +140,8 @@
     self.padding = conv_utils.normalize_padding(padding)
     if self.padding != 'valid' and implementation == 1:
       raise ValueError('Invalid border mode for LocallyConnected1D '
-                       '(only "valid" is supported if implementation is 1): '
-                       + padding)
+                       '(only "valid" is supported if implementation is 1): ' +
+                       padding)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.activation = activations.get(activation)
     self.use_bias = use_bias
@@ -182,10 +163,13 @@
       input_dim, input_length = input_shape[2], input_shape[1]
 
     if input_dim is None:
-      raise ValueError('Axis 2 of input should be fully-defined. '
-                       'Found shape:', input_shape)
-    self.output_length = conv_utils.conv_output_length(
-        input_length, self.kernel_size[0], self.padding, self.strides[0])
+      raise ValueError(
+          'Axis 2 of input should be fully-defined. '
+          'Found shape:', input_shape)
+    self.output_length = conv_utils.conv_output_length(input_length,
+                                                       self.kernel_size[0],
+                                                       self.padding,
+                                                       self.strides[0])
 
     if self.implementation == 1:
       self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
@@ -200,17 +184,18 @@
 
     elif self.implementation == 2:
       if self.data_format == 'channels_first':
-        self.kernel_shape = (input_dim, input_length,
-                             self.filters, self.output_length)
+        self.kernel_shape = (input_dim, input_length, self.filters,
+                             self.output_length)
       else:
-        self.kernel_shape = (input_length, input_dim,
-                             self.output_length, self.filters)
+        self.kernel_shape = (input_length, input_dim, self.output_length,
+                             self.filters)
 
-      self.kernel = self.add_weight(shape=self.kernel_shape,
-                                    initializer=self.kernel_initializer,
-                                    name='kernel',
-                                    regularizer=self.kernel_regularizer,
-                                    constraint=self.kernel_constraint)
+      self.kernel = self.add_weight(
+          shape=self.kernel_shape,
+          initializer=self.kernel_initializer,
+          name='kernel',
+          regularizer=self.kernel_regularizer,
+          constraint=self.kernel_constraint)
 
       self.kernel_mask = get_locallyconnected_mask(
           input_shape=(input_length,),
@@ -232,8 +217,7 @@
               padding=self.padding,
               filters_in=input_dim,
               filters_out=self.filters,
-              data_format=self.data_format)
-      )
+              data_format=self.data_format))
 
       self.kernel = self.add_weight(
           shape=(len(self.kernel_idxs),),
@@ -243,8 +227,8 @@
           constraint=self.kernel_constraint)
 
     else:
-      raise ValueError('Unrecognized implementation mode: %d.'
-                       % self.implementation)
+      raise ValueError('Unrecognized implementation mode: %d.' %
+                       self.implementation)
 
     if self.use_bias:
       self.bias = self.add_weight(
@@ -292,8 +276,8 @@
                                         self.compute_output_shape(inputs.shape))
 
     else:
-      raise ValueError('Unrecognized implementation mode: %d.'
-                       % self.implementation)
+      raise ValueError('Unrecognized implementation mode: %d.' %
+                       self.implementation)
 
     if self.use_bias:
       output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -367,87 +351,71 @@
   ```
 
   Arguments:
-      filters: Integer, the dimensionality of the output space
-          (i.e. the number of output filters in the convolution).
-      kernel_size: An integer or tuple/list of 2 integers, specifying the
-          width and height of the 2D convolution window.
-          Can be a single integer to specify the same value for
-          all spatial dimensions.
-      strides: An integer or tuple/list of 2 integers,
-          specifying the strides of the convolution along the width and height.
-          Can be a single integer to specify the same value for
-          all spatial dimensions.
-      padding: Currently only support `"valid"` (case-insensitive).
-          `"same"` will be supported in future.
-          `"valid"` means no padding.
-      data_format: A string,
-          one of `channels_last` (default) or `channels_first`.
-          The ordering of the dimensions in the inputs.
-          `channels_last` corresponds to inputs with shape
-          `(batch, height, width, channels)` while `channels_first`
-          corresponds to inputs with shape
-          `(batch, channels, height, width)`.
-          It defaults to the `image_data_format` value found in your
-          Keras config file at `~/.keras/keras.json`.
-          If you never set it, then it will be "channels_last".
-      activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of output filters in the convolution).
+      kernel_size: An integer or tuple/list of 2 integers, specifying the width
+        and height of the 2D convolution window. Can be a single integer to
+        specify the same value for all spatial dimensions.
+      strides: An integer or tuple/list of 2 integers, specifying the strides of
+        the convolution along the width and height. Can be a single integer to
+        specify the same value for all spatial dimensions.
+      padding: Currently only supports `"valid"` (case-insensitive). `"same"`
+        will be supported in the future. `"valid"` means no padding.
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, height, width,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, height, width)`. It defaults to the
+        `image_data_format` value found in your Keras config file at
+        `~/.keras/keras.json`. If you never set it, then it will be
+        "channels_last".
+      activation: Activation function to use. If you don't specify anything, no
+        activation is applied
           (ie. "linear" activation: `a(x) = x`).
       use_bias: Boolean, whether the layer uses a bias vector.
       kernel_initializer: Initializer for the `kernel` weights matrix.
       bias_initializer: Initializer for the bias vector.
-      kernel_regularizer: Regularizer function applied to
-          the `kernel` weights matrix.
+      kernel_regularizer: Regularizer function applied to the `kernel` weights
+        matrix.
       bias_regularizer: Regularizer function applied to the bias vector.
-      activity_regularizer: Regularizer function applied to
-          the output of the layer (its "activation").
+      activity_regularizer: Regularizer function applied to the output of the
+        layer (its "activation").
       kernel_constraint: Constraint function applied to the kernel matrix.
       bias_constraint: Constraint function applied to the bias vector.
-      implementation: implementation mode, either `1`, `2`, or `3`.
-          `1` loops over input spatial locations to perform the forward pass.
-          It is memory-efficient but performs a lot of (small) ops.
-
-          `2` stores layer weights in a dense but sparsely-populated 2D matrix
-          and implements the forward pass as a single matrix-multiply. It uses
-          a lot of RAM but performs few (large) ops.
-
-          `3` stores layer weights in a sparse tensor and implements the forward
-          pass as a single sparse matrix-multiply.
-
+      implementation: implementation mode, either `1`, `2`, or `3`. `1` loops
+        over input spatial locations to perform the forward pass. It is
+        memory-efficient but performs a lot of (small) ops.  `2` stores layer
+        weights in a dense but sparsely-populated 2D matrix and implements the
+        forward pass as a single matrix-multiply. It uses a lot of RAM but
+        performs few (large) ops.  `3` stores layer weights in a sparse tensor
+        and implements the forward pass as a single sparse matrix-multiply.
           How to choose:
-
           `1`: large, dense models,
           `2`: small models,
-          `3`: large, sparse models,
-
-          where "large" stands for large input/output activations
-          (i.e. many `filters`, `input_filters`, large `np.prod(input_size)`,
-          `np.prod(output_size)`), and "sparse" stands for few connections
-          between inputs and outputs, i.e. small ratio
-          `filters * input_filters * np.prod(kernel_size) / (np.prod(input_size)
-          * np.prod(strides))`, where inputs to and outputs of the layer are
-          assumed to have shapes `input_size + (input_filters,)`,
-          `output_size + (filters,)` respectively.
-
-          It is recommended to benchmark each in the setting of interest to pick
-          the most efficient one (in terms of speed and memory usage). Correct
-          choice of implementation can lead to dramatic speed improvements (e.g.
-          50X), potentially at the expense of RAM.
-
-          Also, only `padding="valid"` is supported by `implementation=1`.
-
+          `3`: large, sparse models,  where "large" stands for large
+            input/output activations (i.e. many `filters`, `input_filters`,
+            large `np.prod(input_size)`, `np.prod(output_size)`), and "sparse"
+            stands for few connections between inputs and outputs, i.e. small
+            ratio `filters * input_filters * np.prod(kernel_size) /
+            (np.prod(input_size) * np.prod(strides))`, where inputs to and
+            outputs of the layer are assumed to have shapes `input_size +
+            (input_filters,)`, `output_size + (filters,)` respectively.  It is
+            recommended to benchmark each in the setting of interest to pick the
+            most efficient one (in terms of speed and memory usage). Correct
+            choice of implementation can lead to dramatic speed improvements
+            (e.g. 50X), potentially at the expense of RAM.  Also, only
+            `padding="valid"` is supported by `implementation=1`.
   Input shape:
-      4D tensor with shape:
-      `(samples, channels, rows, cols)` if data_format='channels_first'
-      or 4D tensor with shape:
-      `(samples, rows, cols, channels)` if data_format='channels_last'.
-
+      4D tensor with shape: `(samples, channels, rows, cols)` if
+        data_format='channels_first'
+      or 4D tensor with shape: `(samples, rows, cols, channels)` if
+        data_format='channels_last'.
   Output shape:
-      4D tensor with shape:
-      `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
-      or 4D tensor with shape:
-      `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
-      `rows` and `cols` values might have changed due to padding.
+      4D tensor with shape: `(samples, filters, new_rows, new_cols)` if
+        data_format='channels_first'
+      or 4D tensor with shape: `(samples, new_rows, new_cols, filters)` if
+        data_format='channels_last'. `rows` and `cols` values might have changed
+        due to padding.
   """
 
   def __init__(self,
@@ -474,8 +442,8 @@
     self.padding = conv_utils.normalize_padding(padding)
     if self.padding != 'valid' and implementation == 1:
       raise ValueError('Invalid border mode for LocallyConnected2D '
-                       '(only "valid" is supported if implementation is 1): '
-                       + padding)
+                       '(only "valid" is supported if implementation is 1): ' +
+                       padding)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.activation = activations.get(activation)
     self.use_bias = use_bias
@@ -510,10 +478,8 @@
     self.output_col = output_col
 
     if self.implementation == 1:
-      self.kernel_shape = (
-          output_row * output_col,
-          self.kernel_size[0] * self.kernel_size[1] * input_filter,
-          self.filters)
+      self.kernel_shape = (output_row * output_col, self.kernel_size[0] *
+                           self.kernel_size[1] * input_filter, self.filters)
 
       self.kernel = self.add_weight(
           shape=self.kernel_shape,
@@ -524,17 +490,18 @@
 
     elif self.implementation == 2:
       if self.data_format == 'channels_first':
-        self.kernel_shape = (input_filter, input_row, input_col,
-                             self.filters, self.output_row, self.output_col)
+        self.kernel_shape = (input_filter, input_row, input_col, self.filters,
+                             self.output_row, self.output_col)
       else:
         self.kernel_shape = (input_row, input_col, input_filter,
                              self.output_row, self.output_col, self.filters)
 
-      self.kernel = self.add_weight(shape=self.kernel_shape,
-                                    initializer=self.kernel_initializer,
-                                    name='kernel',
-                                    regularizer=self.kernel_regularizer,
-                                    constraint=self.kernel_constraint)
+      self.kernel = self.add_weight(
+          shape=self.kernel_shape,
+          initializer=self.kernel_initializer,
+          name='kernel',
+          regularizer=self.kernel_regularizer,
+          constraint=self.kernel_constraint)
 
       self.kernel_mask = get_locallyconnected_mask(
           input_shape=(input_row, input_col),
@@ -556,8 +523,7 @@
               padding=self.padding,
               filters_in=input_filter,
               filters_out=self.filters,
-              data_format=self.data_format)
-      )
+              data_format=self.data_format))
 
       self.kernel = self.add_weight(
           shape=(len(self.kernel_idxs),),
@@ -567,8 +533,8 @@
           constraint=self.kernel_constraint)
 
     else:
-      raise ValueError('Unrecognized implementation mode: %d.'
-                       % self.implementation)
+      raise ValueError('Unrecognized implementation mode: %d.' %
+                       self.implementation)
 
     if self.use_bias:
       self.bias = self.add_weight(
@@ -620,8 +586,8 @@
                                         self.compute_output_shape(inputs.shape))
 
     else:
-      raise ValueError('Unrecognized implementation mode: %d.'
-                       % self.implementation)
+      raise ValueError('Unrecognized implementation mode: %d.' %
+                       self.implementation)
 
     if self.use_bias:
       output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -687,10 +653,10 @@
   `strides`, `padding` and `data_format`.
 
   Arguments:
-    input_shape: tuple of size N: `(d_in1, ..., d_inN)`
-                 spatial shape of the input.
-    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
-                  / receptive field.
+    input_shape: tuple of size N: `(d_in1, ..., d_inN)` spatial shape of the
+      input.
+    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
+      receptive field.
     strides: tuple of size N, strides along each spatial dimension.
     padding: type of padding, string `"same"` or `"valid"`.
     data_format: a string, `"channels_first"` or `"channels_last"`.
@@ -710,8 +676,7 @@
       input_shape=input_shape,
       kernel_shape=kernel_shape,
       strides=strides,
-      padding=padding
-  )
+      padding=padding)
 
   ndims = int(mask.ndim / 2)
 
@@ -740,34 +705,26 @@
   reshapes to make `inputs` and `kernel` 2-D and `output` (N+2)-D.
 
   Arguments:
-      inputs: (N+2)-D tensor with shape
-          `(batch_size, channels_in, d_in1, ..., d_inN)`
-          or
-          `(batch_size, d_in1, ..., d_inN, channels_in)`.
+      inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ...,
+        d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`.
       kernel: the unshared weights for N-D convolution,
-          an (N+2)-D tensor of shape:
-          `(d_in1, ..., d_inN, channels_in, d_out2, ..., d_outN, channels_out)`
-          or
-          `(channels_in, d_in1, ..., d_inN, channels_out, d_out2, ..., d_outN)`,
-          with the ordering of channels and spatial dimensions matching
-          that of the input.
-          Each entry is the weight between a particular input and
-          output location, similarly to a fully-connected weight matrix.
-      kernel_mask: a float 0/1 mask tensor of shape:
-           `(d_in1, ..., d_inN, 1, d_out2, ..., d_outN, 1)`
-           or
-           `(1, d_in1, ..., d_inN, 1, d_out2, ..., d_outN)`,
-           with the ordering of singleton and spatial dimensions
-           matching that of the input.
-           Mask represents the connectivity pattern of the layer and is
-           precomputed elsewhere based on layer parameters: stride,
-           padding, and the receptive field shape.
+          an (N+2)-D tensor of shape: `(d_in1, ..., d_inN, channels_in, d_out2,
+            ..., d_outN, channels_out)` or `(channels_in, d_in1, ..., d_inN,
+            channels_out, d_out2, ..., d_outN)`, with the ordering of channels
+            and spatial dimensions matching that of the input. Each entry is the
+            weight between a particular input and output location, similarly to
+            a fully-connected weight matrix.
+      kernel_mask: a float 0/1 mask tensor of shape: `(d_in1, ..., d_inN, 1,
+        d_out2, ..., d_outN, 1)` or `(1, d_in1, ..., d_inN, 1, d_out2, ...,
+        d_outN)`, with the ordering of singleton and spatial dimensions matching
+        that of the input. Mask represents the connectivity pattern of the layer
+        and is precomputed elsewhere based on layer parameters: stride, padding,
+        and the receptive field shape.
       output_shape: a tuple of (N+2) elements representing the output shape:
-          `(batch_size, channels_out, d_out1, ..., d_outN)`
-          or
-          `(batch_size, d_out1, ..., d_outN, channels_out)`,
-          with the ordering of channels and spatial dimensions matching that of
-          the input.
+        `(batch_size, channels_out, d_out1, ..., d_outN)` or `(batch_size,
+        d_out1, ..., d_outN, channels_out)`, with the ordering of channels and
+        spatial dimensions matching that of the input.
 
   Returns:
       Output (N+2)-D tensor with shape `output_shape`.
@@ -778,8 +735,9 @@
   kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)
 
   output_flat = math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
-  output = K.reshape(output_flat,
-                     [K.shape(output_flat)[0],] + output_shape.as_list()[1:])
+  output = K.reshape(output_flat, [
+      K.shape(output_flat)[0],
+  ] + output_shape.as_list()[1:])
   return output
 
 
@@ -811,14 +769,16 @@
   """
   inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))
   output_flat = gen_sparse_ops.SparseTensorDenseMatMul(
-      a_indices=kernel_idxs, a_values=kernel, a_shape=kernel_shape,
-      b=inputs_flat, adjoint_b=True)
+      a_indices=kernel_idxs,
+      a_values=kernel,
+      a_shape=kernel_shape,
+      b=inputs_flat,
+      adjoint_b=True)
   output_flat_transpose = K.transpose(output_flat)
 
-  output_reshaped = K.reshape(
-      output_flat_transpose,
-      [K.shape(output_flat_transpose)[0],] + output_shape.as_list()[1:]
-  )
+  output_reshaped = K.reshape(output_flat_transpose, [
+      K.shape(output_flat_transpose)[0],
+  ] + output_shape.as_list()[1:])
   return output_reshaped
 
 
@@ -831,7 +791,7 @@
   Arguments:
     tensor: a tensor of shape `(d0, ..., d(N-1))`.
     split_dim: an integer from 1 to N-1, index of the dimension to group
-        dimensions before (excluding) and after (including).
+      dimensions before (excluding) and after (including).
 
   Returns:
     Tensor of shape
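The reflowed docstrings in this file describe the `implementation` modes of the locally-connected layers only in prose; a small usage sketch may make the trade-off concrete. It uses the public `tf.keras` API with invented shapes, so treat it as illustrative rather than part of the change:

import numpy as np
import tensorflow as tf

# The trade-off summarized in the docstring:
#   implementation=1: loop over spatial locations (memory-light, many small ops)
#   implementation=2: dense masked matmul (RAM-heavy, few large ops)
#   implementation=3: sparse matmul (suited to large, sparsely connected models)
x = np.random.rand(8, 32, 4).astype('float32')  # (batch, steps, channels)
for impl in (1, 2, 3):
  layer = tf.keras.layers.LocallyConnected1D(
      filters=6, kernel_size=3, implementation=impl)
  print(impl, layer(x).shape)  # each mode yields (8, 30, 6)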
diff --git a/tensorflow/python/keras/layers/lstm_v2_test.py b/tensorflow/python/keras/layers/lstm_v2_test.py
index 4c2bbad..1b71bd0 100644
--- a/tensorflow/python/keras/layers/lstm_v2_test.py
+++ b/tensorflow/python/keras/layers/lstm_v2_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import recurrent as rnn_v1
@@ -795,6 +796,7 @@
       outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
 
+  @tf_test_util.enable_output_all_intermediates
   def test_v1_session_behavior(self):
     with ops.get_default_graph().as_default():
       # See b/139132348 for more details.
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index 10d520c..ebfc9c0 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -34,16 +34,18 @@
 class GlobalPoolingTest(test.TestCase, parameterized.TestCase):
 
   def test_globalpooling_1d(self):
-    testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
-                             input_shape=(3, 4, 5))
-    testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
-                             kwargs={'data_format': 'channels_first'},
-                             input_shape=(3, 4, 5))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalMaxPooling1D, input_shape=(3, 4, 5))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalMaxPooling1D,
+        kwargs={'data_format': 'channels_first'},
+        input_shape=(3, 4, 5))
     testing_utils.layer_test(
         keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
-    testing_utils.layer_test(keras.layers.pooling.GlobalAveragePooling1D,
-                             kwargs={'data_format': 'channels_first'},
-                             input_shape=(3, 4, 5))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalAveragePooling1D,
+        kwargs={'data_format': 'channels_first'},
+        input_shape=(3, 4, 5))
 
   def test_globalpooling_1d_masking_support(self):
     model = keras.Sequential()
@@ -57,9 +59,9 @@
     self.assertAllClose(output[0], model_input[0, 0, :])
 
   def test_globalpooling_1d_with_ragged(self):
-    ragged_data = ragged_factory_ops.constant([
-        [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
-        [[1.0, 1.0], [2.0, 2.0]]], ragged_rank=1)
+    ragged_data = ragged_factory_ops.constant(
+        [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[1.0, 1.0], [2.0, 2.0]]],
+        ragged_rank=1)
     dense_data = ragged_data.to_tensor()
 
     inputs = keras.Input(shape=(None, 2), dtype='float32', ragged=True)
@@ -76,9 +78,10 @@
     self.assertAllEqual(output_ragged, output_dense)
 
   def test_globalpooling_2d_with_ragged(self):
-    ragged_data = ragged_factory_ops.constant([
-        [[[1.0], [1.0]], [[2.0], [2.0]], [[3.0], [3.0]]],
-        [[[1.0], [1.0]], [[2.0], [2.0]]]], ragged_rank=1)
+    ragged_data = ragged_factory_ops.constant(
+        [[[[1.0], [1.0]], [[2.0], [2.0]], [[3.0], [3.0]]],
+         [[[1.0], [1.0]], [[2.0], [2.0]]]],
+        ragged_rank=1)
     dense_data = ragged_data.to_tensor()
 
     inputs = keras.Input(shape=(None, 2, 1), dtype='float32', ragged=True)
@@ -94,9 +97,10 @@
     self.assertAllEqual(output_ragged, output_dense)
 
   def test_globalpooling_3d_with_ragged(self):
-    ragged_data = ragged_factory_ops.constant([
-        [[[[1.0]], [[1.0]]], [[[2.0]], [[2.0]]], [[[3.0]], [[3.0]]]],
-        [[[[1.0]], [[1.0]]], [[[2.0]], [[2.0]]]]], ragged_rank=1)
+    ragged_data = ragged_factory_ops.constant(
+        [[[[[1.0]], [[1.0]]], [[[2.0]], [[2.0]]], [[[3.0]], [[3.0]]]],
+         [[[[1.0]], [[1.0]]], [[[2.0]], [[2.0]]]]],
+        ragged_rank=1)
 
     inputs = keras.Input(shape=(None, 2, 1, 1), dtype='float32', ragged=True)
     out = keras.layers.GlobalAveragePooling3D()(inputs)
@@ -162,15 +166,19 @@
   def test_averagepooling_2d(self):
     testing_utils.layer_test(
         keras.layers.AveragePooling2D,
-        kwargs={'strides': (2, 2),
-                'padding': 'same',
-                'pool_size': (2, 2)},
+        kwargs={
+            'strides': (2, 2),
+            'padding': 'same',
+            'pool_size': (2, 2)
+        },
         input_shape=(3, 5, 6, 4))
     testing_utils.layer_test(
         keras.layers.AveragePooling2D,
-        kwargs={'strides': (2, 2),
-                'padding': 'valid',
-                'pool_size': (3, 3)},
+        kwargs={
+            'strides': (2, 2),
+            'padding': 'valid',
+            'pool_size': (3, 3)
+        },
         input_shape=(3, 5, 6, 4))
 
     # This part of the test can only run on GPU but doesn't appear
@@ -194,14 +202,14 @@
 class Pooling3DTest(test.TestCase, parameterized.TestCase):
 
   def test_maxpooling_3d(self):
-    if test.is_built_with_rocm():
-      self.skipTest('Pooling with 3D tensors is not supported in ROCm')
     pool_size = (3, 3, 3)
     testing_utils.layer_test(
         keras.layers.MaxPooling3D,
-        kwargs={'strides': 2,
-                'padding': 'valid',
-                'pool_size': pool_size},
+        kwargs={
+            'strides': 2,
+            'padding': 'valid',
+            'pool_size': pool_size
+        },
         input_shape=(3, 11, 12, 10, 4))
     testing_utils.layer_test(
         keras.layers.MaxPooling3D,
@@ -214,14 +222,14 @@
         input_shape=(3, 4, 11, 12, 10))
 
   def test_averagepooling_3d(self):
-    if test.is_built_with_rocm():
-      self.skipTest('Pooling with 3D tensors is not supported in ROCm')
     pool_size = (3, 3, 3)
     testing_utils.layer_test(
         keras.layers.AveragePooling3D,
-        kwargs={'strides': 2,
-                'padding': 'valid',
-                'pool_size': pool_size},
+        kwargs={
+            'strides': 2,
+            'padding': 'valid',
+            'pool_size': pool_size
+        },
         input_shape=(3, 11, 12, 10, 4))
     testing_utils.layer_test(
         keras.layers.AveragePooling3D,
@@ -242,8 +250,10 @@
       for stride in [1, 2]:
         testing_utils.layer_test(
             keras.layers.MaxPooling1D,
-            kwargs={'strides': stride,
-                    'padding': padding},
+            kwargs={
+                'strides': stride,
+                'padding': padding
+            },
             input_shape=(3, 5, 4))
     testing_utils.layer_test(
         keras.layers.MaxPooling1D,
@@ -255,8 +265,10 @@
       for stride in [1, 2]:
         testing_utils.layer_test(
             keras.layers.AveragePooling1D,
-            kwargs={'strides': stride,
-                    'padding': padding},
+            kwargs={
+                'strides': stride,
+                'padding': padding
+            },
             input_shape=(3, 5, 4))
 
     testing_utils.layer_test(
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index 33d201a..d13473a 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -58,11 +58,11 @@
         "//tensorflow/python:resources",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_spec",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/utils:tf_utils",
         "//tensorflow/python/ops/parallel_for:control_flow_ops",
         "//tensorflow/python/ops/ragged:ragged_functional_ops",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -81,11 +81,11 @@
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/utils:tf_utils",
         "//tensorflow/python/ops/ragged:ragged_array_ops",
         "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -105,11 +105,11 @@
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/utils:tf_utils",
         "//tensorflow/python/ops/ragged:ragged_functional_ops",
         "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -132,7 +132,6 @@
         "//tensorflow/python:stateless_random_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:variables",
         "//tensorflow/python/compat",
         "//tensorflow/python/eager:context",
@@ -140,6 +139,7 @@
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/engine:input_spec",
         "//tensorflow/python/keras/utils:control_flow_util",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -179,10 +179,10 @@
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras/engine",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -198,8 +198,8 @@
         ":index_lookup",
         ":table_utils",
         "//tensorflow/python:dtypes",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/keras/engine",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -242,7 +242,6 @@
         "//tensorflow/python:string_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras/engine",
@@ -250,6 +249,7 @@
         "//tensorflow/python/keras/utils:tf_utils",
         "//tensorflow/python/ops/ragged:ragged_functional_ops",
         "//tensorflow/python/ops/ragged:ragged_string_ops",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -272,13 +272,13 @@
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/engine:input_spec",
         "//tensorflow/python/keras/utils:layer_utils",
         "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -308,8 +308,8 @@
         ":index_lookup",
         ":table_utils",
         "//tensorflow/python:dtypes",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/keras/engine",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
index 2525848..3283374 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
@@ -56,24 +56,36 @@
           expected_output_shape=(None, expected_height, expected_width,
                                  channels))
 
-  @parameterized.named_parameters(
-      ('down_sample_bilinear_2_by_2', {'interpolation': 'bilinear'}, 2, 2),
-      ('down_sample_bilinear_3_by_2', {'interpolation': 'bilinear'}, 3, 2),
-      ('down_sample_nearest_2_by_2', {'interpolation': 'nearest'}, 2, 2),
-      ('down_sample_nearest_3_by_2', {'interpolation': 'nearest'}, 3, 2),
-      ('down_sample_area_2_by_2', {'interpolation': 'area'}, 2, 2),
-      ('down_sample_area_3_by_2', {'interpolation': 'area'}, 3, 2))
+  @parameterized.named_parameters(('down_sample_bilinear_2_by_2', {
+      'interpolation': 'bilinear'
+  }, 2, 2), ('down_sample_bilinear_3_by_2', {
+      'interpolation': 'bilinear'
+  }, 3, 2), ('down_sample_nearest_2_by_2', {
+      'interpolation': 'nearest'
+  }, 2, 2), ('down_sample_nearest_3_by_2', {
+      'interpolation': 'nearest'
+  }, 3, 2), ('down_sample_area_2_by_2', {
+      'interpolation': 'area'
+  }, 2, 2), ('down_sample_area_3_by_2', {
+      'interpolation': 'area'
+  }, 3, 2))
   def test_down_sampling(self, kwargs, expected_height, expected_width):
     with CustomObjectScope({'Resizing': image_preprocessing.Resizing}):
       self._run_test(kwargs, expected_height, expected_width)
 
-  @parameterized.named_parameters(
-      ('up_sample_bilinear_10_by_12', {'interpolation': 'bilinear'}, 10, 12),
-      ('up_sample_bilinear_12_by_12', {'interpolation': 'bilinear'}, 12, 12),
-      ('up_sample_nearest_10_by_12', {'interpolation': 'nearest'}, 10, 12),
-      ('up_sample_nearest_12_by_12', {'interpolation': 'nearest'}, 12, 12),
-      ('up_sample_area_10_by_12', {'interpolation': 'area'}, 10, 12),
-      ('up_sample_area_12_by_12', {'interpolation': 'area'}, 12, 12))
+  @parameterized.named_parameters(('up_sample_bilinear_10_by_12', {
+      'interpolation': 'bilinear'
+  }, 10, 12), ('up_sample_bilinear_12_by_12', {
+      'interpolation': 'bilinear'
+  }, 12, 12), ('up_sample_nearest_10_by_12', {
+      'interpolation': 'nearest'
+  }, 10, 12), ('up_sample_nearest_12_by_12', {
+      'interpolation': 'nearest'
+  }, 12, 12), ('up_sample_area_10_by_12', {
+      'interpolation': 'area'
+  }, 10, 12), ('up_sample_area_12_by_12', {
+      'interpolation': 'area'
+  }, 12, 12))
   def test_up_sampling(self, kwargs, expected_height, expected_width):
     with CustomObjectScope({'Resizing': image_preprocessing.Resizing}):
       self._run_test(kwargs, expected_height, expected_width)
@@ -112,8 +124,9 @@
         expected_output = np.reshape(expected_output, (1, 4, 4, 1))
         self.assertAllEqual(expected_output, output_image)
 
-  @parameterized.named_parameters(
-      ('reshape_bilinear_10_by_4', {'interpolation': 'bilinear'}, 10, 4))
+  @parameterized.named_parameters(('reshape_bilinear_10_by_4', {
+      'interpolation': 'bilinear'
+  }, 10, 4))
   def test_reshaping(self, kwargs, expected_height, expected_width):
     with CustomObjectScope({'Resizing': image_preprocessing.Resizing}):
       self._run_test(kwargs, expected_height, expected_width)
@@ -151,8 +164,8 @@
     kwargs = {'height': expected_height, 'width': expected_width}
     input_images = np.random.random(
         (num_samples, orig_height, orig_width, channels)).astype(np.float32)
-    expected_output = get_numpy_center_crop(
-        input_images, expected_height, expected_width)
+    expected_output = get_numpy_center_crop(input_images, expected_height,
+                                            expected_width)
     with testing_utils.use_gpu():
       testing_utils.layer_test(
           image_preprocessing.CenterCrop,
@@ -163,31 +176,27 @@
           expected_output_shape=(None, expected_height, expected_width,
                                  channels))
 
-  @parameterized.named_parameters(
-      ('center_crop_3_by_4', 3, 4),
-      ('center_crop_3_by_2', 3, 2))
+  @parameterized.named_parameters(('center_crop_3_by_4', 3, 4),
+                                  ('center_crop_3_by_2', 3, 2))
   def test_center_crop_aligned(self, expected_height, expected_width):
     with CustomObjectScope({'CenterCrop': image_preprocessing.CenterCrop}):
       self._run_test(expected_height, expected_width)
 
-  @parameterized.named_parameters(
-      ('center_crop_4_by_5', 4, 5),
-      ('center_crop_4_by_3', 4, 3))
+  @parameterized.named_parameters(('center_crop_4_by_5', 4, 5),
+                                  ('center_crop_4_by_3', 4, 3))
   def test_center_crop_mis_aligned(self, expected_height, expected_width):
     with CustomObjectScope({'CenterCrop': image_preprocessing.CenterCrop}):
       self._run_test(expected_height, expected_width)
 
-  @parameterized.named_parameters(
-      ('center_crop_4_by_6', 4, 6),
-      ('center_crop_3_by_2', 3, 2))
+  @parameterized.named_parameters(('center_crop_4_by_6', 4, 6),
+                                  ('center_crop_3_by_2', 3, 2))
   def test_center_crop_half_mis_aligned(self, expected_height, expected_width):
     with CustomObjectScope({'CenterCrop': image_preprocessing.CenterCrop}):
       self._run_test(expected_height, expected_width)
 
-  @parameterized.named_parameters(
-      ('center_crop_5_by_12', 5, 12),
-      ('center_crop_10_by_8', 10, 8),
-      ('center_crop_10_by_12', 10, 12))
+  @parameterized.named_parameters(('center_crop_5_by_12', 5, 12),
+                                  ('center_crop_10_by_8', 10, 8),
+                                  ('center_crop_10_by_12', 10, 12))
   def test_invalid_center_crop(self, expected_height, expected_width):
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 r'assertion failed'):
@@ -218,28 +227,23 @@
           expected_output_shape=(None, expected_height, expected_width,
                                  channels))
 
-  @parameterized.named_parameters(
-      ('random_crop_5_by_12', 5, 12),
-      ('random_crop_10_by_8', 10, 8),
-      ('random_crop_10_by_12', 10, 12))
+  @parameterized.named_parameters(('random_crop_5_by_12', 5, 12),
+                                  ('random_crop_10_by_8', 10, 8),
+                                  ('random_crop_10_by_12', 10, 12))
   def test_invalid_random_crop(self, expected_height, expected_width):
     with self.assertRaises(errors.InvalidArgumentError):
       with CustomObjectScope({'RandomCrop': image_preprocessing.RandomCrop}):
         self._run_test(expected_height, expected_width)
 
   def test_training_with_mock(self):
-    if test.is_built_with_rocm():
-      # TODO(rocm):
-      # re-enable this test once ROCm adds support for
-      # the StatefulUniformFullInt Op (on the GPU)
-      self.skipTest('Feature not supported on ROCm')
     np.random.seed(1337)
     height, width = 3, 4
     height_offset = np.random.randint(low=0, high=3)
     width_offset = np.random.randint(low=0, high=5)
     mock_offset = [0, height_offset, width_offset, 0]
     with test.mock.patch.object(
-        stateless_random_ops, 'stateless_random_uniform',
+        stateless_random_ops,
+        'stateless_random_uniform',
         return_value=mock_offset):
       with testing_utils.use_gpu():
         layer = image_preprocessing.RandomCrop(height, width)
@@ -249,15 +253,9 @@
                               width_offset:(width_offset + width), :]
         self.assertAllClose(expected_output, actual_output)
 
-  @parameterized.named_parameters(
-      ('random_crop_4_by_6', 4, 6),
-      ('random_crop_3_by_2', 3, 2))
+  @parameterized.named_parameters(('random_crop_4_by_6', 4, 6),
+                                  ('random_crop_3_by_2', 3, 2))
   def test_random_crop_output_shape(self, expected_height, expected_width):
-    if test.is_built_with_rocm():
-      # TODO(rocm):
-      # re-enable this test once ROCm adds support for
-      # the StatefulUniformFullInt Op (on the GPU)
-      self.skipTest('Feature not supported on ROCm')
     with CustomObjectScope({'RandomCrop': image_preprocessing.RandomCrop}):
       self._run_test(expected_height, expected_width)
 
@@ -283,8 +281,7 @@
     with testing_utils.use_gpu():
       layer = image_preprocessing.RandomCrop(height, width)
       actual_output = layer(inp, training=0)
-      resized_inp = image_ops.resize_images_v2(
-          inp, size=[5, 3])
+      resized_inp = image_ops.resize_images_v2(inp, size=[5, 3])
       expected_output = resized_inp[:, 1:4, :, :]
       self.assertAllClose(expected_output, actual_output)
 
@@ -310,7 +307,7 @@
 
   @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def test_rescaling_base(self):
-    kwargs = {'scale': 1./127.5, 'offset': -1.}
+    kwargs = {'scale': 1. / 127.5, 'offset': -1.}
     testing_utils.layer_test(
         image_preprocessing.Rescaling,
         kwargs=kwargs,
@@ -319,18 +316,18 @@
 
   @testing_utils.run_v2_only
   def test_rescaling_correctness_float(self):
-    layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1.)
+    layer = image_preprocessing.Rescaling(scale=1. / 127.5, offset=-1.)
     inputs = random_ops.random_uniform((2, 4, 5, 3))
     outputs = layer(inputs)
-    self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
+    self.assertAllClose(outputs.numpy(), inputs.numpy() * (1. / 127.5) - 1)
 
   @testing_utils.run_v2_only
   def test_rescaling_correctness_int(self):
-    layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1)
+    layer = image_preprocessing.Rescaling(scale=1. / 127.5, offset=-1)
     inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32')
     outputs = layer(inputs)
     self.assertEqual(outputs.dtype.name, 'float32')
-    self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
+    self.assertAllClose(outputs.numpy(), inputs.numpy() * (1. / 127.5) - 1)
 
   def test_config_with_custom_name(self):
     layer = image_preprocessing.Rescaling(0.5, name='rescaling')
@@ -426,11 +423,7 @@
 @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
 class RandomContrastTest(keras_parameterized.TestCase):
 
-  def _run_test(self,
-                lower,
-                upper,
-                expected_output=None,
-                mock_random=None):
+  def _run_test(self, lower, upper, expected_output=None, mock_random=None):
     np.random.seed(1337)
     num_samples = 2
     orig_height = 5
@@ -452,18 +445,16 @@
         actual_output = layer(inp, training=True)
         self.assertAllClose(expected_output, actual_output)
 
-  @parameterized.named_parameters(
-      ('random_contrast_2_by_5', 0.2, 0.5),
-      ('random_contrast_2_by_13', 0.2, 1.3),
-      ('random_contrast_5_by_2', 0.5, 0.2))
+  @parameterized.named_parameters(('random_contrast_2_by_5', 0.2, 0.5),
+                                  ('random_contrast_2_by_13', 0.2, 1.3),
+                                  ('random_contrast_5_by_2', 0.5, 0.2))
   def test_random_contrast(self, lower, upper):
     with CustomObjectScope(
         {'RandomContrast': image_preprocessing.RandomContrast}):
       self._run_test(lower, upper)
 
-  @parameterized.named_parameters(
-      ('random_contrast_amplitude_2', 0.2),
-      ('random_contrast_amplitude_5', 0.5))
+  @parameterized.named_parameters(('random_contrast_amplitude_2', 0.2),
+                                  ('random_contrast_amplitude_5', 0.5))
   def test_random_contrast_amplitude(self, amplitude):
     with CustomObjectScope(
         {'RandomContrast': image_preprocessing.RandomContrast}):
@@ -1002,8 +993,10 @@
     # pyformat: enable
     transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
     self._run_random_transform_with_mock(
-        transform_matrix, expected_output,
-        mode='constant', interpolation='nearest')
+        transform_matrix,
+        expected_output,
+        mode='constant',
+        interpolation='nearest')
 
     # Test up shift by 1.
     # pyformat: disable
@@ -1016,8 +1009,10 @@
     # pyformat: enable
     transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
     self._run_random_transform_with_mock(
-        transform_matrix, expected_output,
-        mode='constant', interpolation='nearest')
+        transform_matrix,
+        expected_output,
+        mode='constant',
+        interpolation='nearest')
 
     # Test left shift by 1.
     # pyformat: disable
@@ -1030,8 +1025,10 @@
     # pyformat: enable
     transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
     self._run_random_transform_with_mock(
-        transform_matrix, expected_output,
-        mode='constant', interpolation='nearest')
+        transform_matrix,
+        expected_output,
+        mode='constant',
+        interpolation='nearest')
 
     # Test right shift by 1.
     # pyformat: disable
@@ -1044,8 +1041,10 @@
     # pyformat: enable
     transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
     self._run_random_transform_with_mock(
-        transform_matrix, expected_output,
-        mode='constant', interpolation='nearest')
+        transform_matrix,
+        expected_output,
+        mode='constant',
+        interpolation='nearest')
 
 
 @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@@ -1193,8 +1192,7 @@
         self.assertAllEqual(expected_output, output_image)
 
   def test_random_zoom_inference(self):
-    with CustomObjectScope(
-        {'RandomZoom': image_preprocessing.RandomZoom}):
+    with CustomObjectScope({'RandomZoom': image_preprocessing.RandomZoom}):
       input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
       expected_output = input_images
       with testing_utils.use_gpu():
@@ -1239,7 +1237,8 @@
     with test.mock.patch.object(
         gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
       with test.mock.patch.object(
-          gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
+          gen_stateless_random_ops_v2,
+          'stateless_random_uniform_v2',
           return_value=mock_factor):
         with testing_utils.use_gpu():
           img = np.random.random((12, 5, 8, 3))
@@ -1254,8 +1253,8 @@
         layer = image_preprocessing.RandomHeight(factor=(1., 1.))
         # Return type of RandomHeight() is float32 if `interpolation` is not
         # set to `ResizeMethod.NEAREST_NEIGHBOR`; cast `layer` to desired dtype.
-        output_image = math_ops.cast(layer(np.expand_dims(input_image, axis=0)),
-                                     dtype=dtype)
+        output_image = math_ops.cast(
+            layer(np.expand_dims(input_image, axis=0)), dtype=dtype)
         # pyformat: disable
         expected_output = np.asarray([
             [0, 1, 2],
@@ -1333,7 +1332,8 @@
     with test.mock.patch.object(
         gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
       with test.mock.patch.object(
-          gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
+          gen_stateless_random_ops_v2,
+          'stateless_random_uniform_v2',
           return_value=mock_factor):
         with testing_utils.use_gpu():
           img = np.random.random((12, 8, 5, 3))
@@ -1348,8 +1348,8 @@
         layer = image_preprocessing.RandomWidth(factor=(1., 1.))
         # Return type of RandomWidth() is float32 if `interpolation` is not
         # set to `ResizeMethod.NEAREST_NEIGHBOR`; cast `layer` to desired dtype.
-        output_image = math_ops.cast(layer(np.expand_dims(input_image, axis=0)),
-                                     dtype=dtype)
+        output_image = math_ops.cast(
+            layer(np.expand_dims(input_image, axis=0)), dtype=dtype)
         # pyformat: disable
         expected_output = np.asarray([
             [0, 0.25, 0.75, 1],
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index b6b6ad0..a19219a 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -28,6 +28,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -629,33 +630,39 @@
 
   def test_bidirectional_statefulness(self):
     # Bidirectional and stateful
-    rnn = keras.layers.SimpleRNN
-    samples = 2
-    dim = 2
-    timesteps = 2
-    output_dim = 2
-    mode = 'sum'
+    def run_test():
+      rnn = keras.layers.SimpleRNN
+      samples = 2
+      dim = 2
+      timesteps = 2
+      output_dim = 2
+      mode = 'sum'
 
-    with self.cached_session():
-      x = np.random.random((samples, timesteps, dim))
-      target_dim = 2 * output_dim if mode == 'concat' else output_dim
-      y = np.random.random((samples, target_dim))
+      with self.cached_session():
+        x = np.random.random((samples, timesteps, dim))
+        target_dim = 2 * output_dim if mode == 'concat' else output_dim
+        y = np.random.random((samples, target_dim))
 
-      inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
-      bidi_rnn = keras.layers.Bidirectional(
-          rnn(output_dim, stateful=True), merge_mode=mode)
-      self.assertTrue(bidi_rnn.stateful)
-      output = bidi_rnn(inputs)
-      model = keras.models.Model(inputs, output)
+        inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
+        bidi_rnn = keras.layers.Bidirectional(
+            rnn(output_dim, stateful=True), merge_mode=mode)
+        self.assertTrue(bidi_rnn.stateful)
+        output = bidi_rnn(inputs)
+        model = keras.models.Model(inputs, output)
 
-      y_1 = model.predict(x, batch_size=1)
-      model.reset_states()
-      y_2 = model.predict(x, batch_size=1)
+        y_1 = model.predict(x, batch_size=1)
+        model.reset_states()
+        y_2 = model.predict(x, batch_size=1)
 
-      self.assertAllClose(y_1, y_2)
+        self.assertAllClose(y_1, y_2)
 
-      model.compile(loss='mse', optimizer='sgd')
-      model.fit(x, y, epochs=1, batch_size=1)
+        model.compile(loss='mse', optimizer='sgd')
+        model.fit(x, y, epochs=1, batch_size=1)
+
+    if context.executing_eagerly():
+      run_test()
+    else:
+      tf_test_util.enable_output_all_intermediates(run_test)()
 
   @parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
   def test_Bidirectional_merged_value(self, merge_mode):
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index d739c16..f7ea1e5 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Built-in loss functions.
-"""
+"""Built-in loss functions."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -92,8 +91,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op.
     """
     losses_utils.ReductionV2.validate(reduction)
@@ -122,15 +121,15 @@
         sparse loss functions such as sparse categorical crossentropy where
         shape = `[batch_size, d0, .. dN-1]`
       y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
-      sample_weight: Optional `sample_weight` acts as a
-        coefficient for the loss. If a scalar is provided, then the loss is
-        simply scaled by the given value. If `sample_weight` is a tensor of size
-        `[batch_size]`, then the total loss for each sample of the batch is
-        rescaled by the corresponding element in the `sample_weight` vector. If
-        the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
-        broadcasted to this shape), then each loss element of `y_pred` is scaled
+      sample_weight: Optional `sample_weight` acts as a coefficient for the
+        loss. If a scalar is provided, then the loss is simply scaled by the
+        given value. If `sample_weight` is a tensor of size `[batch_size]`, then
+        the total loss for each sample of the batch is rescaled by the
+        corresponding element in the `sample_weight` vector. If the shape of
+        `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted to
+        this shape), then each loss element of `y_pred` is scaled
         by the corresponding value of `sample_weight`. (Note on`dN-1`: all loss
-        functions reduce by 1 dimension, usually axis=-1.)
+          functions reduce by 1 dimension, usually axis=-1.)
 
     Returns:
       Weighted loss float `Tensor`. If `reduction` is `NONE`, this has
@@ -230,8 +229,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: (Optional) name for the loss.
       **kwargs: The keyword arguments that are passed on to `fn`.
     """
@@ -250,8 +249,7 @@
       Loss values per sample.
     """
     if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
-      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
-          y_pred, y_true)
+      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
     ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
     return ag_fn(y_true, y_pred, **self._fn_kwargs)
 
@@ -314,8 +312,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'mean_squared_error'.
     """
     super(MeanSquaredError, self).__init__(
@@ -373,8 +371,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'mean_absolute_error'.
     """
     super(MeanAbsoluteError, self).__init__(
@@ -433,8 +431,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to
         'mean_absolute_percentage_error'.
     """
@@ -494,8 +492,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to
         'mean_squared_logarithmic_error'.
     """
@@ -507,44 +505,64 @@
 class BinaryCrossentropy(LossFunctionWrapper):
   """Computes the cross-entropy loss between true labels and predicted labels.
 
-  Use this cross-entropy loss when there are only two label classes (assumed to
-  be 0 and 1). For each example, there should be a single floating-point value
-  per prediction.
+  Use this cross-entropy loss for binary (0 or 1) classification applications.
+  The loss function requires the following inputs:
 
-  In the snippet below, each of the four examples has only a single
-  floating-pointing value, and both `y_pred` and `y_true` have the shape
-  `[batch_size]`.
+  - `y_true` (true label): This is either 0 or 1.
+  - `y_pred` (predicted value): This is the model's prediction, i.e., a single
+    floating-point value which either represents a
+    [logit](https://en.wikipedia.org/wiki/Logit) (i.e., a value in [-inf, inf]
+    when `from_logits=True`) or a probability (i.e., a value in [0., 1.] when
+    `from_logits=False`).
 
-  Standalone usage:
+  **Recommended Usage:** (set `from_logits=True`)
 
-  >>> y_true = [[0., 1.], [0., 0.]]
-  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
-  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
-  >>> bce = tf.keras.losses.BinaryCrossentropy()
-  >>> bce(y_true, y_pred).numpy()
-  0.815
-
-  >>> # Calling with 'sample_weight'.
-  >>> bce(y_true, y_pred, sample_weight=[1, 0]).numpy()
-  0.458
-
-   >>> # Using 'sum' reduction type.
-  >>> bce = tf.keras.losses.BinaryCrossentropy(
-  ...     reduction=tf.keras.losses.Reduction.SUM)
-  >>> bce(y_true, y_pred).numpy()
-  1.630
-
-  >>> # Using 'none' reduction type.
-  >>> bce = tf.keras.losses.BinaryCrossentropy(
-  ...     reduction=tf.keras.losses.Reduction.NONE)
-  >>> bce(y_true, y_pred).numpy()
-  array([0.916 , 0.714], dtype=float32)
-
-  Usage with the `tf.keras` API:
+  With `tf.keras` API:
 
   ```python
-  model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy())
+  model.compile(
+    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
+    ....
+  )
   ```
+
+  As a standalone function:
+
+  >>> # Example 1: (batch_size = 1, number of samples = 4)
+  >>> y_true = [0, 1, 0, 0]
+  >>> y_pred = [-18.6, 0.51, 2.94, -12.8]
+  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+  >>> bce(y_true, y_pred).numpy()
+  0.865
+
+  >>> # Example 2: (batch_size = 2, number of samples = 4)
+  >>> y_true = [[0, 1], [0, 0]]
+  >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]]
+  >>> # Using default 'auto'/'sum_over_batch_size' reduction type.
+  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+  >>> bce(y_true, y_pred).numpy()
+  0.865
+  >>> # Using the 'sample_weight' argument
+  >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
+  0.243
+  >>> # Using 'sum' reduction type.
+  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
+  ...     reduction=tf.keras.losses.Reduction.SUM)
+  >>> bce(y_true, y_pred).numpy()
+  1.730
+  >>> # Using 'none' reduction type.
+  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
+  ...     reduction=tf.keras.losses.Reduction.NONE)
+  >>> bce(y_true, y_pred).numpy()
+  array([0.235, 1.496], dtype=float32)
+
+  **Default Usage:** (set `from_logits=False`)
+
+  >>> # Make the following updates to the above "Recommended Usage" section
+  >>> # 1. Set `from_logits=False`
+  >>> tf.keras.losses.BinaryCrossentropy() # OR ...(from_logits=False)
+  >>> # 2. Update `y_pred` to use probabilities instead of logits
+  >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]]
   """
 
   def __init__(self,
@@ -570,8 +588,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: (Optional) Name for the op. Defaults to 'binary_crossentropy'.
     """
     super(BinaryCrossentropy, self).__init__(
@@ -640,9 +658,9 @@
         default, we assume that `y_pred` encodes a probability distribution.
         **Note - Using from_logits=True is more numerically stable.**
       label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
-        meaning the confidence on label values are relaxed. e.g.
-        `label_smoothing=0.2` means that we will use a value of `0.1` for label
-        `0` and `0.9` for label `1`"
+        meaning the confidence on label values is relaxed. For example, if
+        `0.1`, use `0.1 / num_classes` for non-target labels and
+        `0.9 + 0.1 / num_classes` for target labels.
       reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
         loss. Default value is `AUTO`. `AUTO` indicates that the reduction
         option will be determined by the usage context. For almost all cases
@@ -650,8 +668,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'categorical_crossentropy'.
     """
     super(CategoricalCrossentropy, self).__init__(
@@ -727,8 +745,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to
         'sparse_categorical_crossentropy'.
     """
@@ -791,8 +809,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'hinge'.
     """
     super(Hinge, self).__init__(hinge, name=name, reduction=reduction)
@@ -852,8 +870,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'squared_hinge'.
     """
     super(SquaredHinge, self).__init__(
@@ -912,8 +930,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'categorical_hinge'.
     """
     super(CategoricalHinge, self).__init__(
@@ -969,8 +987,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'poisson'.
     """
     super(Poisson, self).__init__(poisson, name=name, reduction=reduction)
@@ -1026,8 +1044,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'log_cosh'.
     """
     super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction)
@@ -1086,8 +1104,8 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'kl_divergence'.
     """
     super(KLDivergence, self).__init__(
@@ -1154,20 +1172,17 @@
         `tf.distribute.Strategy`, outside of built-in training loops such as
         `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
         will raise an error. Please see this custom training [tutorial](
-          https://www.tensorflow.org/tutorials/distribute/custom_training)
-        for more details.
+          https://www.tensorflow.org/tutorials/distribute/custom_training) for
+            more details.
       name: Optional name for the op. Defaults to 'huber_loss'.
     """
     super(Huber, self).__init__(
         huber, name=name, reduction=reduction, delta=delta)
 
 
-@keras_export('keras.metrics.mean_squared_error',
-              'keras.metrics.mse',
-              'keras.metrics.MSE',
-              'keras.losses.mean_squared_error',
-              'keras.losses.mse',
-              'keras.losses.MSE')
+@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse',
+              'keras.metrics.MSE', 'keras.losses.mean_squared_error',
+              'keras.losses.mse', 'keras.losses.MSE')
 @dispatch.add_dispatch_support
 def mean_squared_error(y_true, y_pred):
   """Computes the mean squared error between labels and predictions.
@@ -1198,12 +1213,9 @@
   return K.mean(math_ops.squared_difference(y_pred, y_true), axis=-1)
 
 
-@keras_export('keras.metrics.mean_absolute_error',
-              'keras.metrics.mae',
-              'keras.metrics.MAE',
-              'keras.losses.mean_absolute_error',
-              'keras.losses.mae',
-              'keras.losses.MAE')
+@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae',
+              'keras.metrics.MAE', 'keras.losses.mean_absolute_error',
+              'keras.losses.mae', 'keras.losses.MAE')
 @dispatch.add_dispatch_support
 def mean_absolute_error(y_true, y_pred):
   """Computes the mean absolute error between labels and predictions.
@@ -1232,11 +1244,9 @@
 
 
 @keras_export('keras.metrics.mean_absolute_percentage_error',
-              'keras.metrics.mape',
-              'keras.metrics.MAPE',
+              'keras.metrics.mape', 'keras.metrics.MAPE',
               'keras.losses.mean_absolute_percentage_error',
-              'keras.losses.mape',
-              'keras.losses.MAPE')
+              'keras.losses.mape', 'keras.losses.MAPE')
 @dispatch.add_dispatch_support
 def mean_absolute_percentage_error(y_true, y_pred):
   """Computes the mean absolute percentage error between `y_true` and `y_pred`.
@@ -1269,11 +1279,9 @@
 
 
 @keras_export('keras.metrics.mean_squared_logarithmic_error',
-              'keras.metrics.msle',
-              'keras.metrics.MSLE',
+              'keras.metrics.msle', 'keras.metrics.MSLE',
               'keras.losses.mean_squared_logarithmic_error',
-              'keras.losses.msle',
-              'keras.losses.MSLE')
+              'keras.losses.msle', 'keras.losses.MSLE')
 @dispatch.add_dispatch_support
 def mean_squared_logarithmic_error(y_true, y_pred):
   """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
@@ -1518,7 +1526,9 @@
     y_pred: Tensor of predicted targets.
     from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
       we assume that `y_pred` encodes a probability distribution.
-    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
+    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
+      example, if `0.1`, use `0.1 / num_classes` for non-target labels
+      and `0.9 + 0.1 / num_classes` for target labels.
 
   Returns:
     Categorical crossentropy loss value.
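The smoothing rule spelled out above can be checked in isolation. A minimal sketch with illustrative values (assuming eager TF 2.x): with `num_classes=3` and `label_smoothing=0.1`, a one-hot label `[0, 1, 0]` becomes `[0.033, 0.933, 0.033]`, and the smoothed loss equals the plain loss on pre-smoothed labels:

```python
import tensorflow as tf

y_true = tf.constant([[0., 1., 0.]])            # one-hot, num_classes = 3
y_pred = tf.constant([[0.05, 0.90, 0.05]])
smoothing = 0.1

# Non-target labels become 0.1 / 3; the target label becomes 0.9 + 0.1 / 3.
smoothed = y_true * (1.0 - smoothing) + smoothing / 3.0
print(smoothed.numpy())  # [[0.0333, 0.9333, 0.0333]]

loss_a = tf.keras.losses.categorical_crossentropy(
    y_true, y_pred, label_smoothing=smoothing)
loss_b = tf.keras.losses.categorical_crossentropy(smoothed, y_pred)
print(loss_a.numpy(), loss_b.numpy())  # the same value
```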
@@ -1589,7 +1599,9 @@
     y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
     from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
       we assume that `y_pred` encodes a probability distribution.
-    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
+    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by
+      squeezing them towards 0.5. That is, using `1. - 0.5 * label_smoothing`
+      for the target class and `0.5 * label_smoothing` for the non-target class.
 
   Returns:
     Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
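Likewise for the binary case described above, a small standalone illustration (arbitrary example values, assuming eager TF 2.x):

```python
import tensorflow as tf

y_true = tf.constant([[0., 1., 0., 0.]])
y_pred = tf.constant([[0.10, 0.80, 0.30, 0.05]])
smoothing = 0.2

# Labels are squeezed towards 0.5: 1 -> 0.9 and 0 -> 0.1 when smoothing = 0.2.
smoothed = y_true * (1.0 - smoothing) + 0.5 * smoothing
print(smoothed.numpy())  # [[0.1, 0.9, 0.1, 0.1]]

loss_a = tf.keras.losses.binary_crossentropy(
    y_true, y_pred, label_smoothing=smoothing)
loss_b = tf.keras.losses.binary_crossentropy(smoothed, y_pred)
print(loss_a.numpy(), loss_b.numpy())  # the same value
```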
@@ -1609,12 +1621,9 @@
 
 
 @keras_export('keras.metrics.kl_divergence',
-              'keras.metrics.kullback_leibler_divergence',
-              'keras.metrics.kld',
-              'keras.metrics.KLD',
-              'keras.losses.kl_divergence',
-              'keras.losses.kullback_leibler_divergence',
-              'keras.losses.kld',
+              'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
+              'keras.metrics.KLD', 'keras.losses.kl_divergence',
+              'keras.losses.kullback_leibler_divergence', 'keras.losses.kld',
               'keras.losses.KLD')
 @dispatch.add_dispatch_support
 def kl_divergence(y_true, y_pred):
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index f05fb91..5b70197 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -37,8 +37,8 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.activations import sigmoid
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import keras_tensor
@@ -2134,7 +2134,7 @@
     label_weights = None if self.multi_label else self.label_weights
 
     if self._from_logits:
-      y_pred = sigmoid(y_pred)
+      y_pred = activations.sigmoid(y_pred)
 
     with ops.control_dependencies(deps):
       return metrics_utils.update_confusion_matrix_variables(
diff --git a/tensorflow/python/keras/mixed_precision/autocast_variable.py b/tensorflow/python/keras/mixed_precision/autocast_variable.py
index 6882a05..834e10d 100644
--- a/tensorflow/python/keras/mixed_precision/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/autocast_variable.py
@@ -70,12 +70,11 @@
   called.
   """
 
-  def __init__(self, variable, op=None):
+  def __init__(self, variable):
     """Creates an AutoCastVariable instance.
 
     Args:
       variable: A floating-point resource variable to wrap.
-      op: Optional operation of this variable.
 
     Raises:
       ValueError: If `variable` is not a floating-point resource variable
@@ -87,7 +86,11 @@
       raise ValueError('variable must be a floating point variable but has '
                        'type: %s' % variable.dtype.name)
     self._variable = variable
-    self._op = op
+    # 'delegate' means AutoCastVariable.op returns self._variable.op, which
+    # raises an AttributeError in Eager mode (as intended). If set to any other
+    # value, AutoCastVariable.op returns that value instead; this is used to
+    # set the op attribute in AutoCastVariable.assign().
+    self._op = 'delegate'
 
   def _should_cast(self):
     """Returns True if this variable should be casted when accessed."""
@@ -212,10 +215,18 @@
                            use_locking=None,
                            name=None,
                            read_value=True):
+    # TODO(b/146181571): This logic can be simplified once
+    # DistributedVariable.assign returns a DistributedVariable. Currently for
+    # MirroredStrategy, it returns a Mirrored value.
     if ops.executing_eagerly_outside_functions():
       assign_op = update_fn(value, use_locking, name, False)
       if read_value:
-        return create_autocast_variable(self._variable, op=assign_op)
+        # We create a new AutoCastVariable with the same underlying tf.Variable.
+        # The new AutoCastVariable is identical except the 'op' attribute is
+        # defined. This matches the behavior of tf.Variable.assign.
+        var = create_autocast_variable(self._variable)
+        var._op = assign_op  # pylint:disable=protected-access
+        return var
       return assign_op
 
     # Fallback to wrapping the returned variable in graph mode if possible
@@ -311,9 +322,9 @@
 
   @property
   def op(self):
-    if self._op is not None:
-      return self._op
-    return self._variable.op
+    if self._op == 'delegate':
+      return self._variable.op
+    return self._op
 
   def _as_graph_element(self):
     graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
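The `'delegate'` sentinel introduced above is easiest to see in a framework-free sketch (the `Wrapper`/`Inner` names are purely illustrative and not part of the TensorFlow API):

```python
class Inner(object):
  """Stands in for the wrapped tf.Variable."""
  op = 'inner-op'


class Wrapper(object):
  """Stands in for AutoCastVariable: delegates .op unless explicitly set."""

  def __init__(self, inner):
    self._inner = inner
    # 'delegate' means "no explicit op was set; fall through to the inner one".
    self._op = 'delegate'

  @property
  def op(self):
    if self._op == 'delegate':
      return self._inner.op  # may raise, e.g. AttributeError in eager mode
    return self._op


w = Wrapper(Inner())
print(w.op)          # 'inner-op' -- delegated to the wrapped object
w._op = 'assign-op'  # what assign() does after creating the new wrapper
print(w.op)          # 'assign-op' -- the explicitly set op wins
```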
@@ -482,7 +493,7 @@
                                         AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
 
 
-def create_autocast_variable(variable, op=None):
+def create_autocast_variable(variable):
   """Creates an AutoCastVariable that wraps another variable.
 
   This typically just returns `AutoCastVariable(variable)`. But, if the variable
@@ -494,14 +505,13 @@
 
   Args:
     variable: A floating-point resource variable to wrap.
-    op: Optional operation of this variable.
 
   Returns:
     An AutoCastVariable that wraps the variable.
   """
   if not isinstance(variable, (distribute_values.DistributedVariable,
                                ps_distribute_values.AggregatingVariable)):
-    return AutoCastVariable(variable, op=op)
+    return AutoCastVariable(variable)
 
   class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
     """An AutoCastVariable that also subclasses from variable.__class__.
@@ -524,7 +534,7 @@
              ).format(v=self)
       # pylint: enable=missing-format-attribute
 
-  return AutoCastDistributedVariable(variable, op=op)
+  return AutoCastDistributedVariable(variable)
 
 
 class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
diff --git a/tensorflow/python/keras/mixed_precision/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py
index c21ff86..6d70aaa 100644
--- a/tensorflow/python/keras/mixed_precision/autocast_variable_test.py
+++ b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py
@@ -37,7 +37,14 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras.mixed_precision import autocast_variable
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.keras.optimizer_v2 import adamax
+from tensorflow.python.keras.optimizer_v2 import ftrl
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.keras.optimizer_v2 import nadam
+from tensorflow.python.keras.optimizer_v2 import rmsprop
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
@@ -352,11 +359,28 @@
         self.assertAllClose(5., self.evaluate(run_assign()))
 
   @ds_combinations.generate(maybe_distribute)
-  def test_assign_op(self, distribution):
+  def test_op_attribute(self, distribution):
     with distribution.scope():
       x = get_var(0., dtypes.float32)
       x = autocast_variable.create_autocast_variable(x)
 
+      # Variable.op raises an AttributeError in Eager mode and is an op in graph
+      # mode. Variable.assign(...).op is None in Eager mode and an op in graph
+      # mode or a tf.function. We test that this also holds for AutoCastVariable.
+      if context.executing_eagerly():
+        with self.assertRaisesRegex(
+            AttributeError,
+            'Tensor.op is meaningless when eager execution is enabled'):
+          x.op  # pylint: disable=pointless-statement
+        self.assertIsNone(x.assign(1.0).op)
+        self.assertIsNone(x.assign_add(1.0).op)
+        self.assertIsNone(x.assign_sub(1.0).op)
+      else:
+        self.assertIsNotNone(x.op)
+        self.assertIsNotNone(x.assign(1.0).op)
+        self.assertIsNotNone(x.assign_add(1.0).op)
+        self.assertIsNotNone(x.assign_sub(1.0).op)
+
       @def_function.function
       def func():
         self.assertIsNotNone(x.assign(1.0).op)
@@ -503,25 +527,51 @@
             'dtype_to_cast_to=float32 '
             'inner_variable=MirroredVariable.*>')
 
-  @parameterized.named_parameters(
-      ('v1', gradient_descent_v1.GradientDescentOptimizer),
-      ('v2', gradient_descent_v2.SGD))
-  def test_optimizer(self, optimizer_class):
+  @ds_combinations.generate(combinations.combine(
+      optimizer_class=[
+          adadelta.Adadelta,
+          adagrad.Adagrad,
+          adam.Adam,
+          adamax.Adamax,
+          ftrl.Ftrl,
+          gradient_descent_v2.SGD,
+          nadam.Nadam,
+          rmsprop.RMSprop,
+          gradient_descent_v1.GradientDescentOptimizer
+      ],
+      use_tf_function=[False, True]))
+  def test_optimizer(self, optimizer_class, use_tf_function):
+    if use_tf_function and not context.executing_eagerly():
+      self.skipTest('Test does not support graph mode with tf.function')
     x = get_var(1., dtypes.float32)
     x = autocast_variable.create_autocast_variable(x)
-    opt = optimizer_class(1.)
+    y = get_var(1., dtypes.float32)
+    opt = optimizer_class(learning_rate=1.)
 
-    @def_function.function
     def f():
-      opt.minimize(lambda: x + 1., var_list=[x])
+      # Minimize both the AutoCastVariable and the normal tf.Variable. Both
+      # variables should be updated to the same value.
+      op = opt.minimize(lambda: x + y, var_list=[x, y])
+      return None if ops.executing_eagerly_outside_functions() else op
+
+    if use_tf_function:
+      f = def_function.function(f)
 
     if context.executing_eagerly():
       f()
     else:
-      op = f()  # pylint: disable=assignment-from-no-return
+      op = f()
       self.evaluate(variables.global_variables_initializer())
       self.evaluate(op)
-    self.assertEqual(self.evaluate(x), 0)
+    # Assert the AutoCastVariable has changed from its initial value
+    self.assertNotEqual(self.evaluate(x), 1.)
+    # Assert AutoCastVariable is updated correctly by comparing it to the normal
+    # variable
+    self.assertAlmostEqual(self.evaluate(x), self.evaluate(y))
+    if optimizer_class in (gradient_descent_v2.SGD,
+                           gradient_descent_v1.GradientDescentOptimizer):
+      # With SGD, the variables decrease by exactly 1.
+      self.assertEqual(self.evaluate(x), 0)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index 2c5660c..138741d 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -83,8 +83,8 @@
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
index 404b3f8..7922e42 100644
--- a/tensorflow/python/keras/optimizer_v2/adadelta.py
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -153,7 +153,7 @@
     config = super(Adadelta, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'rho': self._serialize_hyperparameter('rho'),
         'epsilon': self.epsilon,
     })
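This `get_config` change, repeated for the other optimizers below, stores the plain initial `decay` value instead of re-serializing the hyperparameter. A hedged sketch of the observable effect (assuming the public `tf.keras.optimizers` API in TF 2.x):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adadelta(learning_rate=1.0, decay=0.5)
config = opt.get_config()
print(config['decay'])  # 0.5, a plain Python float

# The config round-trips cleanly through from_config().
restored = tf.keras.optimizers.Adadelta.from_config(config)
print(restored.get_config()['decay'])  # still 0.5
```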
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
index 4d3294a..f18b02b 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -157,7 +157,7 @@
     config = super(Adagrad, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'initial_accumulator_value': self._initial_accumulator_value,
         'epsilon': self.epsilon,
     })
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
index a1d8b70..02e14b2 100644
--- a/tensorflow/python/keras/optimizer_v2/adam.py
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -244,7 +244,7 @@
     config = super(Adam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
@@ -468,7 +468,7 @@
     config = super(NonFusedAdam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py
index 26cc59b..a205c3c 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax.py
@@ -180,7 +180,7 @@
     config = super(Adamax, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
diff --git a/tensorflow/python/keras/optimizer_v2/ftrl.py b/tensorflow/python/keras/optimizer_v2/ftrl.py
index 0bbba96..6525f16 100644
--- a/tensorflow/python/keras/optimizer_v2/ftrl.py
+++ b/tensorflow/python/keras/optimizer_v2/ftrl.py
@@ -234,7 +234,7 @@
         'learning_rate':
             self._serialize_hyperparameter('learning_rate'),
         'decay':
-            self._serialize_hyperparameter('decay'),
+            self._initial_decay,
         'initial_accumulator_value':
             self._initial_accumulator_value,
         'learning_rate_power':
diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py
index ee7de98..29ddb7f 100644
--- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py
+++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py
@@ -187,7 +187,7 @@
     config = super(SGD, self).get_config()
     config.update({
         "learning_rate": self._serialize_hyperparameter("learning_rate"),
-        "decay": self._serialize_hyperparameter("decay"),
+        "decay": self._initial_decay,
         "momentum": self._serialize_hyperparameter("momentum"),
         "nesterov": self.nesterov,
     })
diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py
index 6efaa90..bca744c 100644
--- a/tensorflow/python/keras/optimizer_v2/nadam.py
+++ b/tensorflow/python/keras/optimizer_v2/nadam.py
@@ -214,7 +214,7 @@
     config = super(Nadam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 49da484..06af303 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -1004,7 +1004,7 @@
       lr_t = math_ops.cast(lr_t(local_step), var_dtype)
     if self._initial_decay > 0.:
       local_step = math_ops.cast(self.iterations, var_dtype)
-      decay_t = self._get_hyper("decay", var_dtype)
+      decay_t = math_ops.cast(self._initial_decay, var_dtype)
       lr_t = lr_t / (1. + decay_t * local_step)
     return lr_t
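The decayed learning rate computed above is plain inverse-time decay, `lr / (1 + decay * iterations)`. A quick pure-Python check with illustrative numbers:

```python
lr, decay = 1.0, 0.5
for step in range(4):
  # Mirrors lr_t = lr_t / (1. + decay_t * local_step) in _decayed_lr above.
  print(step, lr / (1.0 + decay * step))
# 0 1.0
# 1 0.666...
# 2 0.5
# 3 0.4
```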
 
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index d56ec49..e0acf96 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -908,6 +908,19 @@
     self.assertAllClose([0., 1.], fn(), atol=1e-4)
     self.assertAllClose([-1, 0.], fn(), atol=1e-4)
 
+  def testBasicWithConstantDecay(self):
+    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
+    loss = lambda: 3 * var
+    opt = adam.Adam(learning_rate=1.0)
+
+    @def_function.function
+    def fn():
+      opt.minimize(loss, [var])
+      return var
+
+    self.assertAllClose([0., 1.], fn(), atol=1e-4)
+    self.assertAllClose([-1, 0.], fn(), atol=1e-4)
+
   def testVarKeyWithVarCreatedInEager(self):
     a = variables.Variable([1., 2.], name='var')
     b = variables.Variable([1.], name='var')
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
index 315b0a6..6493b28 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -290,7 +290,7 @@
     config = super(RMSprop, self).get_config()
     config.update({
         "learning_rate": self._serialize_hyperparameter("learning_rate"),
-        "decay": self._serialize_hyperparameter("decay"),
+        "decay": self._initial_decay,
         "rho": self._serialize_hyperparameter("rho"),
         "momentum": self._serialize_hyperparameter("momentum"),
         "epsilon": self.epsilon,
diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py
index 372dc18..2d49fc1 100644
--- a/tensorflow/python/keras/preprocessing/text.py
+++ b/tensorflow/python/keras/preprocessing/text.py
@@ -45,7 +45,7 @@
   Arguments:
       input_text: Input text (string).
       filters: list (or concatenation) of characters to filter out, such as
-          punctuation. Default: `'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n'`,
+          punctuation. Default: ``'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n'``,
             includes basic punctuation, tabs, and newlines.
       lower: boolean. Whether to convert the input to lowercase.
       split: str. Separator for word splitting.
diff --git a/tensorflow/python/keras/preprocessing/timeseries.py b/tensorflow/python/keras/preprocessing/timeseries.py
index 7121c0f..4c77655 100644
--- a/tensorflow/python/keras/preprocessing/timeseries.py
+++ b/tensorflow/python/keras/preprocessing/timeseries.py
@@ -87,34 +87,60 @@
     `shuffle=False`, the dataset will yield batches of sequences
     composed of the following indices:
 
-  ```
-  First sequence:  [0  2  4  6  8 10 12 14 16 18]
-  Second sequence: [3  5  7  9 11 13 15 17 19 21]
-  Third sequence:  [6  8 10 12 14 16 18 20 22 24]
-  ...
-  Last sequence:   [78 80 82 84 86 88 90 92 94 96]
-  ```
+    ```
+    First sequence:  [0  2  4  6  8 10 12 14 16 18]
+    Second sequence: [3  5  7  9 11 13 15 17 19 21]
+    Third sequence:  [6  8 10 12 14 16 18 20 22 24]
+    ...
+    Last sequence:   [78 80 82 84 86 88 90 92 94 96]
+    ```
 
-  In this case the last 3 data points are discarded since no full sequence
-  can be generated to include them (the next sequence would have started
-  at index 81, and thus its last step would have gone over 99).
+    In this case the last 3 data points are discarded since no full sequence
+    can be generated to include them (the next sequence would have started
+    at index 81, and thus its last step would have gone over 99).
 
-  Example 2: temporal regression. Consider an array `data` of scalar
-  values, of shape `(steps,)`. To generate a dataset that uses the past 10
-  timesteps to predict the next timestep, you would use:
+  Example 2: temporal regression.
+    Consider an array `data` of scalar values, of shape `(steps,)`.
+    To generate a dataset that uses the past 10
+    timesteps to predict the next timestep, you would use:
 
-  ```python
-  input_data = data
-  offset = 10
-  targets = data[offset:]
-  dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
-      input_data, targets, sequence_length=offset)
-  for batch in dataset:
-    inputs, targets = batch
-    assert np.array_equal(inputs[0], data[:10])  # First sequence: steps [0-9]
-    assert np.array_equal(targets[0], data[10])  # Corresponding target: step 10
-    break
-  ```
+    ```python
+    input_data = data[:-10]
+    targets = data[10:]
+    dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
+        input_data, targets, sequence_length=10)
+    for batch in dataset:
+      inputs, targets = batch
+      assert np.array_equal(inputs[0], data[:10])  # First sequence: steps [0-9]
+      assert np.array_equal(targets[0], data[10])  # Corresponding target: step 10
+      break
+    ```
+
+  Example 3: temporal regression for many-to-many architectures.
+    Consider two arrays of scalar values `X` and `Y`,
+    both of shape `(100,)`. The resulting dataset should consist of samples
+    with 20 timesteps each. The samples should not overlap.
+    To generate a dataset that uses the current timestep
+    to predict the corresponding target timestep, you would use:
+
+    ```python
+    X = np.arange(100)
+    Y = X*2
+
+    sample_length = 20
+    input_dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
+      X, None, sequence_length=sample_length, sequence_stride=sample_length)
+    target_dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
+      Y, None, sequence_length=sample_length, sequence_stride=sample_length)
+
+    for batch in zip(input_dataset, target_dataset):
+      inputs, targets = batch
+      assert np.array_equal(inputs[0], X[:sample_length])
+
+      # The second sample corresponds to target timesteps 20-39.
+      assert np.array_equal(targets[1], Y[sample_length:2*sample_length])
+      break
+    ```
   """
   if start_index and (start_index < 0 or start_index >= len(data)):
     raise ValueError('start_index must be higher than 0 and lower than the '
diff --git a/tensorflow/python/keras/saving/hdf5_format.py b/tensorflow/python/keras/saving/hdf5_format.py
index d3bb10c..400f830 100644
--- a/tensorflow/python/keras/saving/hdf5_format.py
+++ b/tensorflow/python/keras/saving/hdf5_format.py
@@ -179,7 +179,9 @@
     model_config = f.attrs.get('model_config')
     if model_config is None:
       raise ValueError('No model found in config file.')
-    model_config = json_utils.decode(model_config.decode('utf-8'))
+    if hasattr(model_config, 'decode'):
+      model_config = model_config.decode('utf-8')
+    model_config = json_utils.decode(model_config)
     model = model_config_lib.model_from_config(model_config,
                                                custom_objects=custom_objects)
 
@@ -189,11 +191,13 @@
     if compile:
       # instantiate optimizer
       training_config = f.attrs.get('training_config')
+      if hasattr(training_config, 'decode'):
+        training_config = training_config.decode('utf-8')
       if training_config is None:
         logging.warning('No training configuration found in the save file, so '
                         'the model was *not* compiled. Compile it manually.')
         return model
-      training_config = json_utils.decode(training_config.decode('utf-8'))
+      training_config = json_utils.decode(training_config)
 
       # Compile model.
       model.compile(**saving_utils.compile_args_from_training_config(
@@ -659,11 +663,15 @@
           and weights file.
   """
   if 'keras_version' in f.attrs:
-    original_keras_version = f.attrs['keras_version'].decode('utf8')
+    original_keras_version = f.attrs['keras_version']
+    if hasattr(original_keras_version, 'decode'):
+      original_keras_version = original_keras_version.decode('utf8')
   else:
     original_keras_version = '1'
   if 'backend' in f.attrs:
-    original_backend = f.attrs['backend'].decode('utf8')
+    original_backend = f.attrs['backend']
+    if hasattr(original_backend, 'decode'):
+      original_backend = original_backend.decode('utf8')
   else:
     original_backend = None
 
@@ -730,11 +738,15 @@
           and weights file and skip_match=False.
   """
   if 'keras_version' in f.attrs:
-    original_keras_version = f.attrs['keras_version'].decode('utf8')
+    original_keras_version = f.attrs['keras_version']
+    if hasattr(original_keras_version, 'decode'):
+      original_keras_version = original_keras_version.decode('utf8')
   else:
     original_keras_version = '1'
   if 'backend' in f.attrs:
-    original_backend = f.attrs['backend'].decode('utf8')
+    original_backend = f.attrs['backend']
+    if hasattr(original_backend, 'decode'):
+      original_backend = original_backend.decode('utf8')
   else:
     original_backend = None
 
@@ -849,13 +861,18 @@
       data: Attributes data.
   """
   if name in group.attrs:
-    data = [n.decode('utf8') for n in group.attrs[name]]
+    data = [
+        n.decode('utf8') if hasattr(n, 'decode') else n
+        for n in group.attrs[name]
+    ]
   else:
     data = []
     chunk_id = 0
     while '%s%d' % (name, chunk_id) in group.attrs:
-      data.extend(
-          [n.decode('utf8') for n in group.attrs['%s%d' % (name, chunk_id)]])
+      data.extend([
+          n.decode('utf8') if hasattr(n, 'decode') else n
+          for n in group.attrs['%s%d' % (name, chunk_id)]
+      ])
       chunk_id += 1
   return data
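The recurring `hasattr(..., 'decode')` guard added throughout this file handles HDF5 attributes that may come back either as `bytes` or as `str`, depending on the h5py version. A tiny standalone illustration (the helper name `_to_text` is hypothetical, not part of this change):

```python
def _to_text(value):
  """Decode bytes to str; pass str through unchanged."""
  return value.decode('utf-8') if hasattr(value, 'decode') else value

print(_to_text(b'2.4.0'))  # bytes attribute (older h5py) -> '2.4.0'
print(_to_text('2.4.0'))   # str attribute (newer h5py)   -> '2.4.0'
```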
 
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index 2b1d7b5..4a4c345 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -18,11 +18,11 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
 import six
 
 from tensorflow.python import tf2
 from tensorflow.python.keras.saving import hdf5_format
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import load as saved_model_load
 from tensorflow.python.keras.saving.saved_model import load_context
 from tensorflow.python.keras.saving.saved_model import save as saved_model_save
@@ -39,12 +39,6 @@
   h5py = None
 # pylint: enable=g-import-not-at-top
 
-_HDF5_EXTENSIONS = ['.h5', '.hdf5', '.keras']
-
-
-# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
-_KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = True
-
 
 @keras_export('keras.models.save_model')
 def save_model(model,
@@ -140,7 +134,7 @@
 
   if (save_format == 'h5' or
       (h5py is not None and isinstance(filepath, h5py.File)) or
-      os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
+      saving_utils.is_hdf5_filepath(filepath)):
     # TODO(b/130258301): add utility method for detecting model type.
     if (not model._is_graph_network and  # pylint:disable=protected-access
         not isinstance(model, sequential.Sequential)):
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 3ebad7c..00c7bb2 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -677,8 +677,8 @@
       self.assertAllClose(out, out2, atol=1e-05)
 
       # Test non-default options in h5
-      with h5py.File('_', driver='core',
-                     backing_store=False) as h5file:
+      with h5py.File(
+          '_', driver='core', mode='w', backing_store=False) as h5file:
         keras.models.save_model(model, h5file)
         loaded_model = keras.models.load_model(h5file)
         out2 = loaded_model.predict(x)
diff --git a/tensorflow/python/keras/saving/save_weights_test.py b/tensorflow/python/keras/saving/save_weights_test.py
index 229a891..1f5fbb4 100644
--- a/tensorflow/python/keras/saving/save_weights_test.py
+++ b/tensorflow/python/keras/saving/save_weights_test.py
@@ -54,11 +54,14 @@
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
 
+  def _save_model_dir(self, dirname='saved_model'):
+    temp_dir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+    return os.path.join(temp_dir, dirname)
+
   @keras_parameterized.run_with_all_weight_formats
   def test_weight_loading(self):
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    saved_model_dir = os.path.join(temp_dir, 'saved_model')
+    saved_model_dir = self._save_model_dir()
     save_format = testing_utils.get_save_format()
     with self.cached_session():
       a = keras.layers.Input(shape=(2,))
@@ -213,9 +216,7 @@
     if h5py is None:
       return
 
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')
 
     num_hidden = 5
     input_dim = 3
@@ -244,9 +245,7 @@
       exclude_formats=['tf_no_traces'])
   def test_nested_model_weight_loading(self):
     save_format = testing_utils.get_save_format()
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    saved_model_dir = os.path.join(temp_dir, 'saved_model')
+    saved_model_dir = self._save_model_dir()
 
     batch_size = 5
     shape = (None, None, 3)
@@ -284,9 +283,7 @@
     if h5py is None:
       return
 
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')
 
     num_hidden = 5
     input_dim = 3
@@ -326,9 +323,7 @@
     if h5py is None:
       return
 
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')
 
     num_hidden = 5
     input_dim = 3
@@ -367,6 +362,32 @@
       self.assertAllClose([3.5] * num_classes,
                           keras.backend.get_value(model.layers[1].bias))
 
+  @keras_parameterized.run_with_all_saved_model_formats(
+      exclude_formats=['tf_no_traces'])
+  @keras_parameterized.run_with_all_model_types
+  def test_load_weights_from_saved_model(self):
+    save_path = self._save_model_dir()
+    save_format = testing_utils.get_save_format()
+
+    if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
+      # TODO(b/173646281): HDF5 format currently does not allow saving
+      # subclassed models.
+      return
+
+    with self.cached_session():
+      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+      data = np.random.random((1, 3))
+      labels = np.random.random((1, 4))
+      model.compile(loss='mse', optimizer='rmsprop')
+      model.fit(data, labels)
+      model.save(save_path, save_format=save_format)
+      new_model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+      if testing_utils.get_model_type() == 'subclass':
+        # Call on test data to build the model.
+        new_model.predict(data)
+      new_model.load_weights(save_path)
+      self.assertAllClose(model.weights, new_model.weights)
+
 
 class SubclassedModel(training.Model):
 
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 3754568..a2bc7ed 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -146,7 +146,7 @@
 
   # Recreate layers and metrics using the info stored in the metadata.
   keras_loader = KerasObjectLoader(metadata, object_graph_def)
-  keras_loader.load_layers()
+  keras_loader.load_layers(compile=compile)
 
   # Generate a dictionary of all loaded nodes.
   nodes_to_load = {'root': None}
@@ -371,7 +371,7 @@
           obj_child, child_proto, child_id)
       self.loaded_nodes[child_id] = obj_child, setter
 
-  def load_layers(self):
+  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
     """Load all layer nodes from the metadata."""
     # Load metrics after models and layers, since it's likely that models
     # and layers will create the metric when initialized (this avoids wasting
@@ -387,9 +387,20 @@
           node_metadata.metadata)
 
     for node_metadata in metric_list:
-      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
-          node_metadata.node_id, node_metadata.identifier,
-          node_metadata.metadata)
+      try:
+        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
+            node_metadata.node_id, node_metadata.identifier,
+            node_metadata.metadata)
+      except ValueError:
+        # Metrics are only needed when the model is compiled later. We ignore
+        # errors when trying to load custom metrics when `compile=False` until
+        # custom metrics are serialized properly (b/135550038).
+        if compile:
+          raise
+        logging.warning('Unable to restore custom metric. Please ensure that '
+                        'the layer implements `get_config` and `from_config` '
+                        'when saving. In addition, please use the '
+                        '`custom_objects` arg when calling `load_model()`.')
 
   def _load_layer(self, node_id, identifier, metadata):
     """Load a single layer from a SavedUserObject proto."""
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index 9bbb31e..6456a47 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -1159,6 +1159,22 @@
       self._test_metric_save_and_load(
           metric, self._save_model_dir(), 1, test_sample_weight=False)
 
+  @keras_parameterized.run_with_all_model_types
+  def test_custom_metric_model(self):
+
+    class CustomMetric(keras.metrics.MeanSquaredError):
+      pass
+
+    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+    model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()])
+
+    saved_model_dir = self._save_model_dir()
+    tf_save.save(model, saved_model_dir)
+    with self.assertRaisesRegex(ValueError, 'custom_objects'):
+      keras_load.load(saved_model_dir)
+
+    keras_load.load(saved_model_dir, compile=False)
+
 
 if __name__ == '__main__':
   test.main()
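From a user's point of view, the behaviour exercised by the test above looks roughly like the following (a hedged sketch using the public `tf.keras` API; the save path and class name are illustrative):

```python
import tensorflow as tf


class CustomMetric(tf.keras.metrics.MeanSquaredError):
  pass


model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()])
model.save('/tmp/custom_metric_model')  # SavedModel format

# Reviving the custom metric without `custom_objects` fails when compiling...
try:
  tf.keras.models.load_model('/tmp/custom_metric_model')
except ValueError as e:
  print(e)  # message mentions `custom_objects`

# ...but compile=False now loads the model and only logs a warning.
model2 = tf.keras.models.load_model('/tmp/custom_metric_model', compile=False)
```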
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index e459d17..fc092df 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -321,3 +321,8 @@
           'Compiled the loaded model, but the compiled metrics have yet to '
           'be built. `model.compile_metrics` will be empty until you train '
           'or evaluate the model.')
+
+
+def is_hdf5_filepath(filepath):
+  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
+          filepath.endswith('.hdf5'))
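For reference, a quick sanity check of the new helper (assuming a build that includes this change and plain string paths, as used in `save.py` above):

```python
from tensorflow.python.keras.saving import saving_utils

assert saving_utils.is_hdf5_filepath('model.h5')
assert saving_utils.is_hdf5_filepath('weights.hdf5')
assert saving_utils.is_hdf5_filepath('legacy.keras')
assert not saving_utils.is_hdf5_filepath('exported/saved_model')
```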
diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD
index 53ac326..7718884 100644
--- a/tensorflow/python/keras/tests/BUILD
+++ b/tensorflow/python/keras/tests/BUILD
@@ -147,7 +147,6 @@
     python_version = "PY3",
     shard_count = 16,
     tags = [
-        "no_rocm",
         "notsan",
     ],
     deps = [
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3dac22c..dc00408 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -241,7 +241,6 @@
     srcs = ["cholesky_op_test.py"],
     shard_count = 5,
     tags = [
-        "no_rocm",  # TODO(rocm): feature not supported on ROCm platform
         "nomsan",  # TODO(b/131773093): Re-enable.
     ],
     deps = [
@@ -1720,6 +1719,7 @@
     name = "betainc_op_test",
     size = "small",
     srcs = ["betainc_op_test.py"],
+    tags = ["no_rocm"],  # ROCm 3.9 regression
     xla_tags = [
         "no_cuda_asan",  # times out
     ],
@@ -3128,6 +3128,10 @@
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/saved_model:load",
+        "//tensorflow/python/saved_model:save",
+        "//tensorflow/python/training/tracking",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
@@ -3847,7 +3851,6 @@
     size = "medium",
     srcs = ["tridiagonal_matmul_op_test.py"],
     shard_count = 10,
-    tags = ["no_rocm"],
     deps = [
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index c4fc23b..006737f 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1151,15 +1151,11 @@
 
 class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
-  @test_util.run_deprecated_v1
   def testInvalidSlice(self):
-    with self.cached_session() as sess:
-      foo = constant_op.constant([1, 2, 3])
-      with self.assertRaisesRegex(
-          ValueError, "Sliced assignment"
-          " is only supported for variables"):
-        bar = foo[:2].assign(constant_op.constant([1, 2]))
-        sess.run(bar)
+    foo = constant_op.constant([1, 2, 3])
+    with self.assertRaisesRegex(AttributeError, "no attribute 'assign'"):
+      bar = foo[:2].assign(constant_op.constant([1, 2]))
+      self.evaluate(bar)
 
   def doTestSliceAssign(self, use_resource):
     for dtype in STRIDED_SLICE_TYPES:
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index fdaf321..916b234 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -43,6 +43,8 @@
 class CollectiveOpsV1(object):
   all_reduce = _collective_ops.all_reduce
   all_gather = _collective_ops.all_gather
+  broadcast_send = _collective_ops.broadcast_send
+  broadcast_recv = _collective_ops.broadcast_recv
 
 
 class CollectiveOpsV2(object):
@@ -63,6 +65,25 @@
     return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
                                          *args, **kwargs)
 
+  @staticmethod
+  def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
+                     *args, **kwargs):
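+    # Convert the scalar group/instance keys to tensors so the v2 op is
+    # exercised with tensor inputs.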
+    group_size = array_ops.identity(group_size)
+    group_key = array_ops.identity(group_key)
+    instance_key = array_ops.identity(instance_key)
+    return _collective_ops.broadcast_send_v2(t, group_size, group_key,
+                                             instance_key, *args, **kwargs)
+
+  @staticmethod
+  def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
+                     **kwargs):
+    group_size = array_ops.identity(group_size)
+    group_key = array_ops.identity(group_key)
+    instance_key = array_ops.identity(instance_key)
+    shape = array_ops.identity(shape)
+    return _collective_ops.broadcast_recv_v2(
+        shape, dtype, group_size, group_key, instance_key, *args, **kwargs)
+
 
 device_combination = (
     combinations.combine(device='CPU', communication='RING', required_gpus=0) +
@@ -191,6 +212,42 @@
     for result in run_all_gather_2devices():
       self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)
 
+  def testBroadcast(self, collective_ops, device, communication):
+    dev0 = '/device:%s:0' % device
+    dev1 = '/device:%s:1' % device
+
+    @def_function.function
+    def run_broadcast_2devices():
+      shape = [3]
+      in_value = constant_op.constant([1., 2., 3.], shape=shape)
+      group_size = 2
+      group_key = 2
+      instance_key = 2
+      collectives = []
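+      # Device 0 sends the value and device 1 receives it using the same
+      # group and instance keys, so both collectives yield the same tensor.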
+      with ops.device(dev0):
+        collectives.append(
+            collective_ops.broadcast_send(
+                in_value,
+                shape,
+                in_value.dtype,
+                group_size,
+                group_key,
+                instance_key,
+                communication_hint=communication))
+      with ops.device(dev1):
+        collectives.append(
+            collective_ops.broadcast_recv(
+                shape,
+                in_value.dtype,
+                group_size,
+                group_key,
+                instance_key,
+                communication_hint=communication))
+      return collectives
+
+    for result in run_broadcast_2devices():
+      self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)
+
   def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
                                          communication):
     if device == 'GPU' and context.num_gpus() < 4:
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 7ccf4b8..3289771 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -19,11 +19,14 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
+from tensorflow.python.eager import remote
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -288,23 +291,24 @@
 
   @test_util.run_v1_only("b/120545219")
   def testDefunInCond(self):
-    x = constant_op.constant(1.0, name="x")
-    y = constant_op.constant(2.0, name="y")
+    with ops.Graph().as_default():
+      x = constant_op.constant(1.0, name="x")
+      y = constant_op.constant(2.0, name="y")
 
-    def true_fn():
+      def true_fn():
 
-      @function.defun
-      def fn():
-        return x * y * 2.0
+        @function.defun
+        def fn():
+          return x * y * 2.0
 
-      return fn()
+        return fn()
 
-    def false_fn():
-      return 2.0
+      def false_fn():
+        return 2.0
 
-    self._testCond(true_fn, false_fn, [x])
-    self._testCond(true_fn, false_fn, [x, y])
-    self._testCond(true_fn, false_fn, [y])
+      self._testCond(true_fn, false_fn, [x])
+      self._testCond(true_fn, false_fn, [x, y])
+      self._testCond(true_fn, false_fn, [y])
 
   @test_util.run_deprecated_v1
   def testNestedDefunInCond(self):
@@ -942,24 +946,23 @@
     self.assertAllEqual(self.evaluate(fn_output), [2.0, 4.0])
 
   def testGradientTapeOfCondWithResourceVariableInFunction(self):
-    with context.eager_mode():
-      v = variables.Variable(2.)
+    v = variables.Variable(2.)
 
-      @def_function.function
-      def fn_with_cond():
-        with backprop.GradientTape() as tape:
-          pred = constant_op.constant(True, dtype=dtypes.bool)
+    @def_function.function
+    def fn_with_cond():
+      with backprop.GradientTape() as tape:
+        pred = constant_op.constant(True, dtype=dtypes.bool)
 
-          def true_fn():
-            return math_ops.pow(v, 3)
+        def true_fn():
+          return math_ops.pow(v, 3)
 
-          def false_fn():
-            return v
+        def false_fn():
+          return v
 
-          cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
-        return tape.gradient(cond, v)
+        cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
+      return tape.gradient(cond, v)
 
-      self.assertAllEqual(fn_with_cond(), 12.0)
+    self.assertAllEqual(fn_with_cond(), 12.0)
 
   def _CheckIteratedCosGradients(self, func):
 
@@ -1458,9 +1461,10 @@
       self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
 
 
-class CondV2ColocationGroupAndDeviceTest(test.TestCase):
+class CondV2ColocationGroupAndDeviceTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
+    context._reset_context()
     super(CondV2ColocationGroupAndDeviceTest, self).setUp()
     cpus = context.context().list_physical_devices("CPU")
     context.context().set_logical_device_configuration(
@@ -1468,6 +1472,8 @@
             context.LogicalDeviceConfiguration(),
             context.LogicalDeviceConfiguration()
         ])
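+    # Bring up a single-worker cluster so the remote-placement tests below
+    # have a /job:worker device to target.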
+    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
+    remote.connect_to_remote_host(workers[0].target)
 
   def testColocateWithBeforeCond(self):
     with ops.Graph().as_default() as g:
@@ -1544,64 +1550,113 @@
         self.assertTrue(len(run_metadata.partition_graphs) >= 2)
 
   def testDeviceBeforeCond(self):
-    with context.eager_mode():
-      def fn():
-        cpu_zero_op = test_ops.device_placement_op()
-        self.assertEqual("/device:CPU:0", cpu_zero_op.device)
-        with ops.device("CPU:1"):
-          cpu_one_op = test_ops.device_placement_op()
-          self.assertEqual("/device:CPU:1", cpu_one_op.device)
-        return cpu_zero_op, cpu_one_op
 
-      @def_function.function
-      def _cond_wrapper():
-        with ops.device("/device:CPU:0"):
-          return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
+    def fn():
+      cpu_zero_op = test_ops.device_placement_op()
+      self.assertEqual("/job:localhost/device:CPU:0", cpu_zero_op.device)
+      with ops.device("CPU:1"):
+        cpu_one_op = test_ops.device_placement_op()
+        self.assertEqual("/job:localhost/device:CPU:1", cpu_one_op.device)
+      return cpu_zero_op, cpu_one_op
 
-      zero_expected, one_expected = self.evaluate(_cond_wrapper())
-      self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
-      self.assertIn(compat.as_bytes("CPU:1"), one_expected)
+    @def_function.function
+    def _cond_wrapper():
+      with ops.device("/job:localhost/device:CPU:0"):
+        return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
 
-      def fn2():
-        self.assertEqual("/device:GPU:0", constant_op.constant(3.0).op.device)
-        return test_ops.device_placement_op()
+    zero_expected, one_expected = self.evaluate(_cond_wrapper())
+    self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
+    self.assertIn(compat.as_bytes("CPU:1"), one_expected)
+    self.assertIn(compat.as_bytes("job:localhost"), zero_expected)
+    self.assertIn(compat.as_bytes("job:localhost"), one_expected)
 
-      @def_function.function
-      def _cond_wrapper2():
-        with ops.device("/device:GPU:0"):
-          return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
+    def fn2():
+      self.assertEqual("/job:localhost/device:GPU:0",
+                       constant_op.constant(3.0).op.device)
+      return test_ops.device_placement_op()
 
-      if test_util.is_gpu_available():
-        self.assertIn(compat.as_bytes("GPU:0"),
-                      self.evaluate(_cond_wrapper2()))
-      else:
-        self.skipTest("Test requires a GPU to check GPU device placement.")
+    @def_function.function
+    def _cond_wrapper2():
+      with ops.device("/job:localhost/device:GPU:0"):
+        return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
+
+    if test_util.is_gpu_available():
+      self.assertIn(compat.as_bytes("GPU:0"), self.evaluate(_cond_wrapper2()))
+      self.assertIn(
+          compat.as_bytes("job:localhost"), self.evaluate(_cond_wrapper2()))
+    else:
+      self.skipTest("Test requires a GPU to check GPU device placement.")
+
+  @parameterized.named_parameters([
+      dict(
+          testcase_name="Function",
+          functional_op_to_test=lambda fn: def_function.function(fn)()),
+      dict(
+          testcase_name="Cond",
+          functional_op_to_test=
+          lambda fn: cond_v2.cond_v2(constant_op.constant(True), fn, fn))
+  ])
+  def testDeviceBeforeRemote(self, functional_op_to_test):
+    context.context().log_device_placement = True
+
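+    # The unplaced op should stay on the enclosing localhost job while the
+    # explicitly placed op runs on /job:worker, even inside the functional op.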
+    def _fn():
+      local_op = test_ops.device_placement_op()
+      with ops.device("/job:worker/CPU:0"):
+        worker_op = test_ops.device_placement_op()
+      return local_op, worker_op
+
+    @def_function.function
+    def _wrapper():
+      with ops.device("/job:localhost"):
+        return functional_op_to_test(_fn)
+
+    local_expected, worker_expected = self.evaluate(_wrapper())
+    self.assertIn(compat.as_bytes("job:localhost"), local_expected)
+    self.assertIn(compat.as_bytes("job:worker"), worker_expected)
+
+    del _fn, _wrapper
+
+    # There's nothing special about localhost; if we swap roles (functional op
+    # on worker, op on localhost) the inner placement still wins.
+    def _fn2():
+      local_op = test_ops.device_placement_op()
+      with ops.device("/job:localhost/CPU:0"):
+        worker_op = test_ops.device_placement_op()
+      return local_op, worker_op
+
+    @def_function.function
+    def _wrapper2():
+      with ops.device("/job:worker"):
+        return functional_op_to_test(_fn2)
+
+    worker_expected, local_expected = self.evaluate(_wrapper2())
+    self.assertIn(compat.as_bytes("job:worker"), worker_expected)
+    self.assertIn(compat.as_bytes("job:localhost"), local_expected)
 
   def testColocationBeforeCond(self):
-    with context.eager_mode():
 
-      def _fn():
-        result = test_ops.device_placement_op()
-        self.assertIn("colocation_test_op",
-                      result.op.colocation_groups()[0].decode())
-        return result
+    def _fn():
+      result = test_ops.device_placement_op()
+      self.assertIn("colocation_test_op",
+                    result.op.colocation_groups()[0].decode())
+      return result
 
-      @def_function.function(autograph=False)
-      def _cond_wrapper():
-        with ops.device("/device:CPU:0"):
-          op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
-        with ops.device("/device:CPU:1"):
-          op_on_cpu_1 = test_ops.device_placement_op(
-              name="colocation_test_op_1")
-        condition = constant_op.constant(True)
-        with ops.colocate_with(op_on_cpu_0.op):
-          zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
-        with ops.colocate_with(op_on_cpu_1.op):
-          one_expected = cond_v2.cond_v2(condition, _fn, _fn)
-        return zero_expected, one_expected
-      zero_expected, one_expected = self.evaluate(_cond_wrapper())
-      self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
-      self.assertIn(compat.as_bytes("CPU:1"), one_expected)
+    @def_function.function(autograph=False)
+    def _cond_wrapper():
+      with ops.device("/device:CPU:0"):
+        op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
+      with ops.device("/device:CPU:1"):
+        op_on_cpu_1 = test_ops.device_placement_op(name="colocation_test_op_1")
+      condition = constant_op.constant(True)
+      with ops.colocate_with(op_on_cpu_0.op):
+        zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
+      with ops.colocate_with(op_on_cpu_1.op):
+        one_expected = cond_v2.cond_v2(condition, _fn, _fn)
+      return zero_expected, one_expected
+
+    zero_expected, one_expected = self.evaluate(_cond_wrapper())
+    self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
+    self.assertIn(compat.as_bytes("CPU:1"), one_expected)
 
   def testDeviceInAndOutOfCond(self):
     with ops.Graph().as_default() as g:
@@ -1702,4 +1757,5 @@
 
 
 if __name__ == "__main__":
+  ops.enable_eager_execution()
   test.main()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 532dac1..54bbd2b 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -730,8 +730,6 @@
         g for g in run_metadata.partition_graphs
         if device_str in g.node[0].device
     ]
-    if not device_graphs:
-      return 0
     self.assertLen(device_graphs, 1)
     switch_nodes = [
         n for n in device_graphs[0].node
@@ -761,6 +759,7 @@
       options = config_pb2.RunOptions(output_partition_graphs=True)
       sess.run(
           r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
+      self.assertLen(run_metadata.partition_graphs, 2)
       # Check that the Switch for `arg` gets placed on CPU.
       self.assertEqual(
           self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index d4f29d4..7bdcdef 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -61,6 +61,7 @@
     size = "small",
     srcs = ["beta_test.py"],
     tags = [
+        "no_rocm",  # ROCm 3.9 regression
         "notsan",  # b/173653918
     ],
     xla_tags = [
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 097183d..0da8a0e 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -149,7 +149,6 @@
     srcs = ["linear_operator_circulant_test.py"],
     shard_count = 10,
     tags = [
-        "no_rocm",  # calls BLAS ops for complex types
         "noasan",  # times out, b/63678675
         "optonly",  # times out, b/79171797
     ],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
index 0100eb4..b2bc189 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
@@ -54,7 +54,6 @@
     self._stored_shape = shape
     super(LinearOperatorShape, self).__init__(
         dtype=dtypes.float32,
-        graph_parents=None,
         is_non_singular=is_non_singular,
         is_self_adjoint=is_self_adjoint,
         is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/BUILD b/tensorflow/python/kernel_tests/linalg/sparse/BUILD
index 96ebc38..0352ae7 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/sparse/BUILD
@@ -40,7 +40,10 @@
     srcs = ["csr_sparse_matrix_ops_test.py"],
     main = "csr_sparse_matrix_ops_test.py",
     shard_count = 10,
-    tags = ["notsan"],  # b/149115441
+    tags = [
+        "no_rocm",  # ROCm 3.8 regression
+        "notsan",  # b/149115441
+    ],
     deps = [
         "//tensorflow/python/ops/linalg/sparse",
         "//tensorflow/python/ops/linalg/sparse:gen_sparse_csr_matrix_ops",
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 411087d..796caa4 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -192,30 +192,25 @@
       for extra in [(), (2,), (3,)] + [(3, 2)] * (size < 10):
         for adjoint in False, True:
           shape = extra + (size, size)
-          name = '%s_%s_adj_%s' % (dtype.__name__, '_'.join(map(str, shape)),
-                                   str(adjoint))
-          _AddTest(MatrixBinaryFunctorGradientTest, 'MatrixSolveGradient', name,
-                   _GetMatrixBinaryFunctorGradientTest(
-                       linalg_ops.matrix_solve, dtype, shape, adjoint=adjoint))
+          name = '%s_%s_adj_%s' % (dtype.__name__, '_'.join(map(
+              str, shape)), str(adjoint))
+          _AddTest(
+              MatrixBinaryFunctorGradientTest, 'MatrixSolveGradient', name,
+              _GetMatrixBinaryFunctorGradientTest(
+                  linalg_ops.matrix_solve, dtype, shape, adjoint=adjoint))
 
           for lower in True, False:
             name = '%s_low_%s' % (name, lower)
-            if (name == 'float32_10_10_adj_False_low_True') and \
-               test_lib.is_built_with_rocm():
-              # Skip this one particular subtest on the ROCm platform
-              # It will fail because of 1 element in 10,000 mismatch,
-              # and the mismatch is minor (tolerance is 0.20, mismatch is 0.22)
-              # TODO(rocm) : investigate cause of mismatch and fix
-              continue
-            _AddTest(MatrixBinaryFunctorGradientTest,
-                     'MatrixTriangularSolveGradient', name,
-                     _GetMatrixBinaryFunctorGradientTest(
-                         linalg_ops.matrix_triangular_solve,
-                         dtype,
-                         shape,
-                         float32_tol_fudge=4.0,
-                         adjoint=adjoint,
-                         lower=lower))
+            _AddTest(
+                MatrixBinaryFunctorGradientTest,
+                'MatrixTriangularSolveGradient', name,
+                _GetMatrixBinaryFunctorGradientTest(
+                    linalg_ops.matrix_triangular_solve,
+                    dtype,
+                    shape,
+                    float32_tol_fudge=4.0,
+                    adjoint=adjoint,
+                    lower=lower))
 
             band_shape = extra + (size // 2 + 1, size)
             name = '%s_%s_adj_%s_low_%s' % (dtype.__name__, '_'.join(
@@ -239,9 +234,10 @@
       for extra in [(), (2,), (3,)] + [(3, 2)] * (size < 10):
         shape = extra + (size, size)
         name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
-        _AddTest(MatrixUnaryFunctorGradientTest, 'MatrixInverseGradient', name,
-                 _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_inverse,
-                                                    dtype, shape))
+        _AddTest(
+            MatrixUnaryFunctorGradientTest, 'MatrixInverseGradient', name,
+            _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_inverse, dtype,
+                                               shape))
         if not test_lib.is_built_with_rocm():
           # TODO(rocm) :
           # re-enable this test when upstream issues are resolved
@@ -258,8 +254,8 @@
             MatrixUnaryFunctorGradientTest, 'LogMatrixDeterminantGradient',
             name,
             _GetMatrixUnaryFunctorGradientTest(
-                lambda x: linalg_ops.log_matrix_determinant(x)[1],
-                dtype, shape))
+                lambda x: linalg_ops.log_matrix_determinant(x)[1], dtype,
+                shape))
 
         # The numerical Jacobian is consistently invalid for these four shapes
         # because the matrix square root of the perturbed input doesn't exist
@@ -278,8 +274,8 @@
       for cols in 2, 5, 10:
         for l2_regularization in 1e-6, 0.001, 1.0:
           shape = (rows, cols)
-          name = '%s_%s_%s' % (dtype.__name__, '_'.join(map(str, shape)),
-                               l2_regularization)
+          name = '%s_%s_%s' % (dtype.__name__, '_'.join(map(
+              str, shape)), l2_regularization)
           float32_tol_fudge = 5.1 if l2_regularization == 1e-6 else 4.0
           _AddTest(
               MatrixBinaryFunctorGradientTest,
@@ -287,10 +283,7 @@
               name,
               # pylint: disable=long-lambda,g-long-lambda
               _GetMatrixBinaryFunctorGradientTest(
-                  (lambda a, b, l=l2_regularization:
-                   linalg_ops.matrix_solve_ls(a, b, l)),
-                  dtype,
-                  shape,
-                  float32_tol_fudge))
+                  (lambda a, b, l=l2_regularization: linalg_ops.matrix_solve_ls(
+                      a, b, l)), dtype, shape, float32_tol_fudge))
 
   test_lib.main()
diff --git a/tensorflow/python/kernel_tests/map_fn_test.py b/tensorflow/python/kernel_tests/map_fn_test.py
index af0f8e9..68c2694 100644
--- a/tensorflow/python/kernel_tests/map_fn_test.py
+++ b/tensorflow/python/kernel_tests/map_fn_test.py
@@ -138,10 +138,11 @@
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
       y = map_fn.map_fn(
           lambda x: math_ops.multiply(math_ops.square(x), param), elems)
-      r = gradients_impl.gradients(y, param)[0]
-      self.assertAllEqual(91.0, self.evaluate(r))
-      r = gradients_impl.gradients(y, elems)[0]
-      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))
+      r_param = gradients_impl.gradients(y, param)[0]
+      r_elems = gradients_impl.gradients(y, elems)[0]
+      self.assertAllEqual(91.0, self.evaluate(r_param))
+      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0],
+                          self.evaluate(r_elems))
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_SimpleNotTensor(self):
diff --git a/tensorflow/python/kernel_tests/pool_test.py b/tensorflow/python/kernel_tests/pool_test.py
index b554413..0e6bbeb 100644
--- a/tensorflow/python/kernel_tests/pool_test.py
+++ b/tensorflow/python/kernel_tests/pool_test.py
@@ -274,9 +274,6 @@
               strides=[1, 2],
               dilation_rate=[1, 1],
               data_format="NCHW")
-          if test.is_built_with_rocm():
-            # Pooling with 3D tensors is not supported in ROCm
-            continue
           self._test(
               input_shape=[2, 2, 7, 5, 3],
               window_shape=[2, 2, 2],
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 20699f5..98e043e 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -1004,12 +1004,14 @@
     ]
 
     Config = collections.namedtuple(
-        "Config", ["use_gpu", "include_batch_in_index", "argmax"])
+        "Config", ["use_gpu", "include_batch_in_index", "argmax", "Targmax"])
     configs = [
-        Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8]),
-        Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17]),
-        Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8]),
-        Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17])
+        Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
+        Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
+        Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int32),
+        Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int32),
+        Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
+        Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
     ]
 
     for config in configs:
@@ -1019,7 +1021,7 @@
             t,
             ksize=[1, 2, 2, 1],
             strides=[1, 1, 1, 1],
-            Targmax=dtypes.int64,
+            Targmax=config.Targmax,
             padding="VALID",
             include_batch_in_index=config.include_batch_in_index)
         out, argmax = self.evaluate([out_op, argmax_op])
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 7fa31d1..bb47d60 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -26,14 +26,16 @@
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
@@ -47,6 +49,9 @@
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import load
+from tensorflow.python.saved_model import save
+from tensorflow.python.training.tracking import tracking
 from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import nest
 
@@ -3060,6 +3065,29 @@
     reconstructed_wrapper = wrapper_cls.from_config(config_copy)
     self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
 
+  def testSavedModel(self):
+    if test_util.is_gpu_available():
+      self.skipTest("b/175887901")
+
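+    # Save an LSTMCell wrapped in a trackable object as a SavedModel and check
+    # that the reloaded function reproduces the original output.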
+    with self.cached_session():
+      root = tracking.AutoTrackable()
+      root.cell = rnn_cell_impl.LSTMCell(8)
+      @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])])
+      def call(x):
+        state = root.cell.zero_state(3, dtype=x.dtype)
+        y, _ = root.cell(x, state)
+        return y
+      root.call = call
+      expected = root.call(array_ops.zeros((3, 8)))
+      self.evaluate(variables_lib.global_variables_initializer())
+
+      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+      save.save(root, save_dir)
+      loaded = load.load(save_dir)
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.assertAllClose(
+          expected, loaded.call(array_ops.zeros((3, 8))))
+
 
 @test_util.run_all_in_graph_and_eager_modes
 @test_util.run_all_without_tensor_float_32(
diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD
index bd89318..d2d6296 100644
--- a/tensorflow/python/kernel_tests/signal/BUILD
+++ b/tensorflow/python/kernel_tests/signal/BUILD
@@ -125,6 +125,7 @@
     srcs = ["spectral_ops_test.py"],
     python_version = "PY3",
     tags = [
+        "no_rocm",
         "nomac",
     ],
     deps = [
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 99f70c1..c53f196 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -182,23 +182,6 @@
           np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
           np.array([0, 3]).astype(label_dtype))
 
-  def testBfloat16(self):
-    for label_dtype in np.int32, np.int64:
-      np_features = np.array([[1., 1., 1., 1.], [1., 2., 3.,
-                                                 4.]]).astype(np.float32)
-      np_labels = np.array([0, 3]).astype(label_dtype)
-      np_loss, np_backprop = self._npXent(np_features, np_labels)
-
-      np_features_bf16 = math_ops.cast(np_features, dtypes.bfloat16)
-      np_loss_bf16 = math_ops.cast(np_loss, dtypes.bfloat16)
-      np_backprop_bf16 = math_ops.cast(np_backprop, dtypes.bfloat16)
-      with self.cached_session(use_gpu=False):
-        loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
-            np_features_bf16, np_labels)
-        tf_loss, tf_backprop = self.evaluate([loss, backprop])
-      self.assertAllCloseAccordingToType(np_loss_bf16, tf_loss)
-      self.assertAllCloseAccordingToType(np_backprop_bf16, tf_backprop)
-
   def testHalf(self):
     for label_dtype in np.int32, np.int64:
       self._testXent(
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index 436fef8..072a901 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -21,6 +21,7 @@
 import numpy as np
 
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.platform import test
@@ -106,6 +107,51 @@
     for i in range(len(x)):
       self.assertEqual(x[i], tf_y[tf_idx[i]])
 
+  @test_util.run_deprecated_v1
+  def testShapeInferenceV2(self):
+    """Test shape inference."""
+    x = np.arange(6).reshape(3, 2, 1)
+    _, idx = gen_array_ops.unique_v2(x, axis=[0])
+    self.assertEqual(idx.shape.as_list(), [3])
+    _, idx = gen_array_ops.unique_v2(x, axis=[1])
+    self.assertEqual(idx.shape.as_list(), [2])
+    _, idx = gen_array_ops.unique_v2(x, axis=[2])
+    self.assertEqual(idx.shape.as_list(), [1])
+    _, idx = gen_array_ops.unique_v2(x, axis=[-1])
+    self.assertEqual(idx.shape.as_list(), [1])
+    _, idx = gen_array_ops.unique_v2(x, axis=[-2])
+    self.assertEqual(idx.shape.as_list(), [2])
+    _, idx = gen_array_ops.unique_v2(x, axis=[-3])
+    self.assertEqual(idx.shape.as_list(), [3])
+    _, idx = gen_array_ops.unique_v2([0, 1, 2], axis=[])
+    self.assertEqual(idx.shape.as_list(), [3])
+
+    with self.assertRaisesRegexp(ValueError, "axis expects a 1D vector"):
+      gen_array_ops.unique_v2(x, axis=[[0]])
+
+    with self.assertRaisesRegexp(ValueError, "x expects a 1D vector"):
+      gen_array_ops.unique_v2(x, axis=[])
+
+    with self.assertRaisesRegexp(
+        ValueError, "axis does not support input tensors larger than"):
+      gen_array_ops.unique_v2(x, axis=[1, 2])
+
+    with self.assertRaisesRegexp(ValueError,
+                                 r"axis expects to be in the range \[-3, 3\)"):
+      gen_array_ops.unique_v2(x, axis=[3])
+
+    with self.assertRaisesRegexp(ValueError,
+                                 r"axis expects to be in the range \[-3, 3\)"):
+      gen_array_ops.unique_v2(x, axis=[-4])
+
+    x_t = array_ops.placeholder(dtypes.int32, shape=None)
+    _, idx = gen_array_ops.unique_v2(x_t, axis=[0])
+    self.assertEqual(idx.shape.as_list(), [None])
+
+    axis_t = array_ops.placeholder(dtypes.int32, shape=None)
+    _, idx = gen_array_ops.unique_v2(x, axis=axis_t)
+    self.assertEqual(idx.shape.as_list(), [None])
+
 
 class UniqueWithCountsTest(test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index a161b4b..feba76a 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
@@ -42,6 +43,7 @@
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_list_ops
 from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
@@ -1316,6 +1318,50 @@
 
     Fn()
 
+  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):
+
+    @def_function.function
+    def Fn():
+      with backprop.GradientTape() as tape:
+        e = constant_op.constant(2.)
+        x = list_ops.empty_tensor_list(
+            element_dtype=dtypes.float32, element_shape=e.shape)
+        x = list_ops.tensor_list_push_back(x, e)
+        tape.watch(x)
+
+        def Body(i, x):
+          forward_graph = ops.get_default_graph()
+
+          @custom_gradient.custom_gradient
+          def IdentityWithZeroGrad(x):
+
+            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
+              del variables
+              gradient_graph = ops.get_default_graph()
+              shape = gen_list_ops.tensor_list_element_shape(
+                  x, shape_type=dtypes.int32)
+              assert shape.graph is forward_graph
+              size = gen_list_ops.tensor_list_length(x)
+              assert size.graph is forward_graph
+              zeros = gen_list_ops.tensor_list_reserve(shape, size,
+                                                       dtypes.float32)
+              assert zeros.graph is gradient_graph
+              return zeros
+
+            return x, Grad
+
+          return i + 1, IdentityWithZeroGrad(x)
+
+        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
+      ones_like = list_ops.tensor_list_from_tensor(
+          array_ops.ones_like(
+              list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
+          element_shape=tensor_shape.TensorShape([]))
+      grad = tape.gradient(result, x, output_gradients=[ones_like])
+      return grad
+
+    Fn()
+
   @test_util.run_v2_only
   def testInheritParentNameScope(self):
 
diff --git a/tensorflow/python/lib/core/BUILD b/tensorflow/python/lib/core/BUILD
new file mode 100644
index 0000000..d3a28b2
--- /dev/null
+++ b/tensorflow/python/lib/core/BUILD
@@ -0,0 +1,370 @@
+# python/lib/core package
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
+visibility = [
+    "//engedu/ml/tf_from_scratch:__pkg__",
+    "//third_party/cloud_tpu/convergence_tools:__subpackages__",
+    "//third_party/mlperf:__subpackages__",
+    "//tensorflow:internal",
+    "//tensorflow/lite/toco/python:__pkg__",
+    "//tensorflow_models:__subpackages__",
+    "//tensorflow_model_optimization:__subpackages__",
+    "//third_party/py/cleverhans:__subpackages__",
+    "//third_party/py/launchpad:__subpackages__",
+    "//third_party/py/reverb:__subpackages__",
+    "//third_party/py/neural_structured_learning:__subpackages__",
+    "//third_party/py/tensorflow_examples:__subpackages__",
+    "//third_party/py/tf_agents:__subpackages__",  # For benchmarks.
+    "//third_party/py/tf_slim:__subpackages__",
+    "//third_party/py/tensorflow_docs:__subpackages__",
+    "//third_party/py/keras:__subpackages__",
+]
+
+package(
+    default_visibility = visibility,
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "numpy_lib",
+    srcs = ["numpy.cc"],
+    hdrs = ["numpy.h"],
+    deps = [
+        "//third_party/py/numpy:headers",
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+cc_library(
+    name = "bfloat16_lib",
+    srcs = ["bfloat16.cc"],
+    hdrs = ["bfloat16.h"],
+    deps = [
+        ":numpy_lib",
+        "//tensorflow/core/platform:logging",
+        "//third_party/eigen3",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_bfloat16",
+    srcs = ["bfloat16_wrapper.cc"],
+    hdrs = ["bfloat16.h"],
+    module_name = "_pywrap_bfloat16",
+    deps = [
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "ndarray_tensor_bridge",
+    srcs = ["ndarray_tensor_bridge.cc"],
+    hdrs = ["ndarray_tensor_bridge.h"],
+    visibility = tf_external_workspace_visible(
+        visibility + [
+            "//tensorflow:ndarray_tensor_allow_list",
+        ],
+    ),
+    deps = [
+        ":bfloat16_lib",
+        ":numpy_lib",
+        "//tensorflow/c:c_api_no_xla",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
+cc_library(
+    name = "py_exception_registry",
+    srcs = ["py_exception_registry.cc"],
+    hdrs = ["py_exception_registry.h"],
+    deps = [
+        "//tensorflow/c:tf_status_headers",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//third_party/python_runtime:headers",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "pybind11_absl",
+    hdrs = ["pybind11_absl.h"],
+    features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
+    deps = [
+        "//tensorflow/core/platform:stringpiece",
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "pybind11_lib",
+    hdrs = ["pybind11_lib.h"],
+    compatible_with = get_compatible_with_portable(),
+    features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
+    deps = [
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "pybind11_status_headers",
+    hdrs = [
+        "py_exception_registry.h",
+        "pybind11_status.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+    ],
+    features = [
+        "-parse_headers",
+    ],
+    visibility = tf_external_workspace_visible(visibility),
+    deps = [
+        "//tensorflow/c:tf_status_headers",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "pybind11_status",
+    hdrs = [
+        "py_exception_registry.h",
+        "pybind11_status.h",
+        "//tensorflow/c:headers",
+    ],
+    features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
+    deps = [
+        ":pybind11_status_headers",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "pybind11_proto",
+    hdrs = ["pybind11_proto.h"],
+    features = ["-parse_headers"],
+    visibility = tf_external_workspace_visible(visibility),
+    deps = [
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+    ],
+)
+
+filegroup(
+    name = "py_exception_registry_hdr",
+    srcs = [
+        "py_exception_registry.h",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+filegroup(
+    name = "numpy_hdr",
+    srcs = ["numpy.h"],
+)
+
+filegroup(
+    name = "safe_ptr_hdr",
+    srcs = ["safe_ptr.h"],
+)
+
+filegroup(
+    name = "ndarray_tensor_hdr",
+    srcs = ["ndarray_tensor.h"],
+)
+
+filegroup(
+    name = "basic_hdrs",
+    srcs = [
+        "bfloat16.h",
+        "ndarray_tensor.h",
+        "ndarray_tensor_bridge.h",
+        "numpy.h",
+        "py_exception_registry.h",
+        "pybind11_status.h",
+        "safe_ptr.h",
+        "safe_pyobject_ptr.h",
+    ],
+)
+
+cc_library(
+    name = "py_func_lib",
+    srcs = ["py_func.cc"],
+    hdrs = ["py_func.h"],
+    deps = [
+        ":ndarray_tensor",
+        ":ndarray_tensor_bridge",
+        ":numpy_lib",
+        ":py_util",
+        ":safe_ptr",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:tfe_context_internal",
+        "//tensorflow/c/eager:tfe_tensorhandle_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:script_ops_op_lib",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "//tensorflow/python/eager:pywrap_tfe_lib",
+        "//third_party/py/numpy:headers",
+        "//third_party/python_runtime:headers",
+    ],
+    alwayslink = 1,
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_py_func",
+    srcs = ["py_func_wrapper.cc"],
+    module_name = "_pywrap_py_func",
+    deps = [
+        "//tensorflow/python:py_func_headers_lib",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "safe_pyobject_ptr",
+    srcs = ["safe_pyobject_ptr.cc"],
+    hdrs = ["safe_pyobject_ptr.h"],
+    deps = [
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+cc_library(
+    name = "safe_pyobject_ptr_required_hdrs",
+    textual_hdrs = ["safe_pyobject_ptr.h"],
+)
+
+cc_library(
+    name = "safe_ptr",
+    srcs = [
+        "safe_ptr.cc",
+        "//tensorflow/c/eager:headers",
+    ],
+    hdrs = ["safe_ptr.h"],
+    deps = [
+        ":safe_pyobject_ptr",
+        "//tensorflow/c:c_api_no_xla",
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+cc_library(
+    name = "ndarray_tensor_headers",
+    hdrs = [
+        "bfloat16.h",
+        "ndarray_tensor.h",
+        "ndarray_tensor_bridge.h",
+        "numpy.h",
+        "safe_ptr.h",
+        "safe_pyobject_ptr.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+    ],
+    features = [
+        "-parse_headers",
+    ],
+    visibility = tf_external_workspace_visible(visibility + [
+        "//tensorflow:ndarray_tensor_allow_list",
+    ]),
+    deps = [
+        ":numpy_lib",
+        "//tensorflow/c:pywrap_required_hdrs",
+        "//tensorflow/c:tf_status_headers",
+        "//tensorflow/core:framework_internal_headers_lib",
+        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//third_party/py/numpy:headers",
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+cc_library(
+    name = "ndarray_tensor",
+    srcs = ["ndarray_tensor.cc"],
+    hdrs = ["ndarray_tensor.h"],
+    visibility = tf_external_workspace_visible(visibility + [
+        "//tensorflow:ndarray_tensor_allow_list",
+    ]),
+    deps = [
+        ":bfloat16_lib",
+        ":ndarray_tensor_bridge",
+        ":numpy_lib",
+        ":safe_ptr",
+        "//tensorflow/c:c_api_internal",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c:tf_tensor_internal",
+        "//tensorflow/c/eager:tfe_context_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "py_seq_tensor",
+    srcs = ["py_seq_tensor.cc"],
+    hdrs = ["py_seq_tensor.h"],
+    features = ["-parse_headers"],
+    deps = [
+        ":ndarray_tensor",
+        ":ndarray_tensor_bridge",
+        ":numpy_lib",
+        ":py_util",
+        ":safe_ptr",
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c:tf_tensor_internal",
+        "//tensorflow/c/eager:c_api_internal",
+        "//tensorflow/c/eager:tfe_context_internal",
+        "//tensorflow/c/eager:tfe_tensorhandle_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/python_runtime:headers",  # build_cleaner: keep; DNR: b/35864863
+    ],
+)
+
+cc_library(
+    name = "py_util",
+    srcs = ["py_util.cc"],
+    hdrs = ["py_util.h"],
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core:script_ops_op_lib",
+        "//tensorflow/core/platform:logging",
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+tf_py_test(
+    name = "bfloat16_test",
+    size = "small",
+    srcs = ["bfloat16_test.py"],
+    python_version = "PY3",
+    deps = [
+        "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/platform:client_testlib",
+    ],
+)
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 31def39..8d35186 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -13,67 +13,54 @@
 limitations under the License.
 ==============================================================================*/
 
-#include <array>
-
 #include "tensorflow/python/lib/core/bfloat16.h"
 
-#include "tensorflow/core/framework/numeric_types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include <array>
+#include <locale>
+// Place <locale> before <Python.h> to avoid a build failure on macOS.
+#include <Python.h>
+
+#include "absl/strings/str_cat.h"
+#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/python/lib/core/numpy.h"
-#include "tensorflow/python/lib/core/safe_ptr.h"
 
 namespace tensorflow {
 namespace {
 
-// Workarounds for Python 2 vs 3 API differences.
-#if PY_MAJOR_VERSION < 3
+using bfloat16 = Eigen::bfloat16;
 
-PyObject* MakePyString(const string& s) {
-  return PyString_FromString(s.c_str());
+struct PyDecrefDeleter {
+  void operator()(PyObject* p) const { Py_DECREF(p); }
+};
+
+// Safe container for an owned PyObject. On destruction, the reference count of
+// the contained object will be decremented.
+using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
+Safe_PyObjectPtr make_safe(PyObject* object) {
+  return Safe_PyObjectPtr(object);
 }
 
-typedef long HashType;  // NOLINT
-
-bool TfPyInt_Check(PyObject* object) { return PyInt_Check(object); }
-
-PyObject* TfPyInt_FromLong(long x) {  // NOLINT
-  return PyInt_FromLong(x);
-}
-
-long TfPyInt_AsLong(PyObject* x) {  // NOLINT
-  return PyInt_AsLong(x);
-}
-
-#else  // PY_MAJOR_VERSION < 3
-
-PyObject* MakePyString(const string& s) {
-  return PyUnicode_FromString(s.c_str());
-}
-
-bool TfPyInt_Check(PyObject* object) {
+bool PyLong_CheckNoOverflow(PyObject* object) {
   if (!PyLong_Check(object)) {
-    return 0;
+    return false;
   }
   int overflow = 0;
   PyLong_AsLongAndOverflow(object, &overflow);
   return (overflow == 0);
 }
 
-PyObject* TfPyInt_FromLong(long x) {  // NOLINT
-  return PyLong_FromLong(x);
-}
-
-long TfPyInt_AsLong(PyObject* x) {  // NOLINT
-  return PyLong_AsLong(x);
-}
-
-typedef Py_hash_t HashType;
-
-#endif  // PY_MAJOR_VERSION < 3
+// Registered numpy type ID. Global variable populated by the registration code.
+// Protected by the GIL.
+int npy_bfloat16 = NPY_NOTYPE;
 
 // Forward declaration.
-extern PyTypeObject PyBfloat16_Type;
+extern PyTypeObject bfloat16_type;
+
+// Pointer to the bfloat16 type object we are using. This is either a pointer
+// to bfloat16_type, if we choose to register it, or to the bfloat16 type
+// registered by another system into NumPy.
+PyTypeObject* bfloat16_type_ptr = nullptr;
 
 // Representation of a Python bfloat16 object.
 struct PyBfloat16 {
@@ -84,7 +71,7 @@
 // Returns true if 'object' is a PyBfloat16.
 bool PyBfloat16_Check(PyObject* object) {
   return PyObject_IsInstance(object,
-                             reinterpret_cast<PyObject*>(&PyBfloat16_Type));
+                             reinterpret_cast<PyObject*>(&bfloat16_type));
 }
 
 // Extracts the value of a PyBfloat16 object.
@@ -94,8 +81,7 @@
 
 // Constructs a PyBfloat16 object from a bfloat16.
 Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
-  Safe_PyObjectPtr ref =
-      make_safe(PyBfloat16_Type.tp_alloc(&PyBfloat16_Type, 0));
+  Safe_PyObjectPtr ref = make_safe(bfloat16_type.tp_alloc(&bfloat16_type, 0));
   PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
   if (p) {
     p->value = x;
@@ -105,7 +91,7 @@
 
 // Converts a Python object to a bfloat16 value. Returns true on success,
 // returns false and reports a Python error on failure.
-bool AsBfloat16(PyObject* arg, bfloat16* output) {
+bool CastToBfloat16(PyObject* arg, bfloat16* output) {
   if (PyBfloat16_Check(arg)) {
     *output = PyBfloat16_Bfloat16(arg);
     return true;
@@ -119,8 +105,8 @@
     *output = bfloat16(d);
     return true;
   }
-  if (TfPyInt_Check(arg)) {
-    long l = TfPyInt_AsLong(arg);  // NOLINT
+  if (PyLong_CheckNoOverflow(arg)) {
+    long l = PyLong_AsLong(arg);  // NOLINT
     if (PyErr_Occurred()) {
       return false;
     }
@@ -128,14 +114,46 @@
     *output = bfloat16(static_cast<float>(l));
     return true;
   }
+  if (PyArray_IsScalar(arg, Half)) {
+    Eigen::half f;
+    PyArray_ScalarAsCtype(arg, &f);
+    *output = bfloat16(f);
+    return true;
+  }
   if (PyArray_IsScalar(arg, Float)) {
     float f;
     PyArray_ScalarAsCtype(arg, &f);
     *output = bfloat16(f);
     return true;
   }
-  PyErr_Format(PyExc_TypeError, "expected number, got %s",
-               arg->ob_type->tp_name);
+  if (PyArray_IsScalar(arg, Double)) {
+    double f;
+    PyArray_ScalarAsCtype(arg, &f);
+    *output = bfloat16(f);
+    return true;
+  }
+  if (PyArray_IsZeroDim(arg)) {
+    Safe_PyObjectPtr ref;
+    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
+    if (PyArray_TYPE(arr) != npy_bfloat16) {
+      ref = make_safe(PyArray_Cast(arr, npy_bfloat16));
+      if (PyErr_Occurred()) {
+        return false;
+      }
+      arg = ref.get();
+      arr = reinterpret_cast<PyArrayObject*>(arg);
+    }
+    *output = *reinterpret_cast<bfloat16*>(PyArray_DATA(arr));
+    return true;
+  }
+  return false;
+}
+
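+// Casts `arg` to bfloat16 only if it is already a PyBfloat16; mixed-type
+// operands are left for NumPy to handle.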
+bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) {
+  if (PyBfloat16_Check(arg)) {
+    *output = PyBfloat16_Bfloat16(arg);
+    return true;
+  }
   return false;
 }
 
@@ -149,7 +167,7 @@
 PyObject* PyBfloat16_Int(PyObject* self) {
   bfloat16 x = PyBfloat16_Bfloat16(self);
   long y = static_cast<long>(x);  // NOLINT
-  return TfPyInt_FromLong(y);
+  return PyLong_FromLong(y);
 }
 
 // Negates a PyBfloat16.
@@ -158,28 +176,43 @@
   return PyBfloat16_FromBfloat16(-x).release();
 }
 
-// Binary arithmetic operators on PyBfloat16 values.
-#define BFLOAT16_BINOP(name, op)                                  \
-  PyObject* PyBfloat16_##name(PyObject* a, PyObject* b) {         \
-    bfloat16 x, y;                                                \
-    if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr; \
-    bfloat16 z = x op y;                                          \
-    return PyBfloat16_FromBfloat16(z).release();                  \
+PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x + y).release();
   }
-BFLOAT16_BINOP(Add, +)
-BFLOAT16_BINOP(Subtract, -)
-BFLOAT16_BINOP(Multiply, *)
-BFLOAT16_BINOP(Divide, /)
-#undef BFLOAT16_BINOP
+  return PyArray_Type.tp_as_number->nb_add(a, b);
+}
+
+PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x - y).release();
+  }
+  return PyArray_Type.tp_as_number->nb_subtract(a, b);
+}
+
+PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x * y).release();
+  }
+  return PyArray_Type.tp_as_number->nb_multiply(a, b);
+}
+
+PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x / y).release();
+  }
+  return PyArray_Type.tp_as_number->nb_true_divide(a, b);
+}
 
 // Python number methods for PyBfloat16 objects.
 PyNumberMethods PyBfloat16_AsNumber = {
     PyBfloat16_Add,       // nb_add
     PyBfloat16_Subtract,  // nb_subtract
     PyBfloat16_Multiply,  // nb_multiply
-#if PY_MAJOR_VERSION < 3
-    PyBfloat16_Divide,  // nb_divide
-#endif
     nullptr,              // nb_remainder
     nullptr,              // nb_divmod
     nullptr,              // nb_power
@@ -193,27 +226,13 @@
     nullptr,              // nb_and
     nullptr,              // nb_xor
     nullptr,              // nb_or
-#if PY_MAJOR_VERSION < 3
-    nullptr,  // nb_coerce
-#endif
-    PyBfloat16_Int,  // nb_int
-#if PY_MAJOR_VERSION < 3
-    PyBfloat16_Int,  // nb_long
-#else
-    nullptr,  // reserved
-#endif
-    PyBfloat16_Float,  // nb_float
-#if PY_MAJOR_VERSION < 3
-    nullptr,  // nb_oct
-    nullptr,  // nb_hex
-#endif
+    PyBfloat16_Int,       // nb_int
+    nullptr,              // reserved
+    PyBfloat16_Float,     // nb_float
 
     nullptr,  // nb_inplace_add
     nullptr,  // nb_inplace_subtract
     nullptr,  // nb_inplace_multiply
-#if PY_MAJOR_VERSION < 3
-    nullptr,  // nb_inplace_divide
-#endif
     nullptr,  // nb_inplace_remainder
     nullptr,  // nb_inplace_power
     nullptr,  // nb_inplace_lshift
@@ -222,11 +241,11 @@
     nullptr,  // nb_inplace_xor
     nullptr,  // nb_inplace_or
 
-    nullptr,            // nb_floor_divide
-    PyBfloat16_Divide,  // nb_true_divide
-    nullptr,            // nb_inplace_floor_divide
-    nullptr,            // nb_inplace_true_divide
-    nullptr,            // nb_index
+    nullptr,                // nb_floor_divide
+    PyBfloat16_TrueDivide,  // nb_true_divide
+    nullptr,                // nb_inplace_floor_divide
+    nullptr,                // nb_inplace_true_divide
+    nullptr,                // nb_index
 };
 
 // Constructs a new PyBfloat16.
@@ -243,22 +262,32 @@
   }
   PyObject* arg = PyTuple_GetItem(args, 0);
 
+  bfloat16 value;
   if (PyBfloat16_Check(arg)) {
     Py_INCREF(arg);
     return arg;
-  } else {
-    bfloat16 value;
-    if (!AsBfloat16(arg, &value)) {
-      return nullptr;
-    }
+  } else if (CastToBfloat16(arg, &value)) {
     return PyBfloat16_FromBfloat16(value).release();
+  } else if (PyArray_Check(arg)) {
+    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
+    if (PyArray_TYPE(arr) != npy_bfloat16) {
+      return PyArray_Cast(arr, npy_bfloat16);
+    } else {
+      Py_INCREF(arg);
+      return arg;
+    }
   }
+  PyErr_Format(PyExc_TypeError, "expected number, got %s",
+               arg->ob_type->tp_name);
+  return nullptr;
 }
 
 // Comparisons on PyBfloat16s.
 PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
   bfloat16 x, y;
-  if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr;
+  if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) {
+    return PyGenericArrType_Type.tp_richcompare(a, b, op);
+  }
   bool result;
   switch (op) {
     case Py_LT:
@@ -288,81 +317,77 @@
 // Implementation of repr() for PyBfloat16.
 PyObject* PyBfloat16_Repr(PyObject* self) {
   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  string v = strings::StrCat("bfloat16(", static_cast<float>(x), ")");
-  return MakePyString(v);
+  std::string v = absl::StrCat(static_cast<float>(x));
+  return PyUnicode_FromString(v.c_str());
 }
 
 // Implementation of str() for PyBfloat16.
 PyObject* PyBfloat16_Str(PyObject* self) {
   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  string v = strings::StrCat(static_cast<float>(x));
-  return MakePyString(v);
+  std::string v = absl::StrCat(static_cast<float>(x));
+  return PyUnicode_FromString(v.c_str());
 }
 
 // Hash function for PyBfloat16. We use the identity function, which is a weak
 // hash function.
-HashType PyBfloat16_Hash(PyObject* self) {
+Py_hash_t PyBfloat16_Hash(PyObject* self) {
   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
   return x.value;
 }
 
 // Python type for PyBfloat16 objects.
-PyTypeObject PyBfloat16_Type = {
-#if PY_MAJOR_VERSION < 3
-    PyObject_HEAD_INIT(nullptr) 0,  // ob_size
-#else
-    PyVarObject_HEAD_INIT(nullptr, 0)
-#endif
-    "bfloat16",          // tp_name
-    sizeof(PyBfloat16),  // tp_basicsize
-    0,                   // tp_itemsize
-    nullptr,             // tp_dealloc
+PyTypeObject bfloat16_type = {
+    PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16",  // tp_name
+    sizeof(PyBfloat16),                            // tp_basicsize
+    0,                                             // tp_itemsize
+    nullptr,                                       // tp_dealloc
 #if PY_VERSION_HEX < 0x03080000
     nullptr,  // tp_print
 #else
     0,  // tp_vectorcall_offset
 #endif
-    nullptr,                                   // tp_getattr
-    nullptr,                                   // tp_setattr
-    nullptr,                                   // tp_compare / tp_reserved
-    PyBfloat16_Repr,                           // tp_repr
-    &PyBfloat16_AsNumber,                      // tp_as_number
-    nullptr,                                   // tp_as_sequence
-    nullptr,                                   // tp_as_mapping
-    PyBfloat16_Hash,                           // tp_hash
-    nullptr,                                   // tp_call
-    PyBfloat16_Str,                            // tp_str
-    nullptr,                                   // tp_getattro
-    nullptr,                                   // tp_setattro
-    nullptr,                                   // tp_as_buffer
-    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  // tp_flags
-    "bfloat16 floating-point values",          // tp_doc
-    nullptr,                                   // tp_traverse
-    nullptr,                                   // tp_clear
-    PyBfloat16_RichCompare,                    // tp_richcompare
-    0,                                         // tp_weaklistoffset
-    nullptr,                                   // tp_iter
-    nullptr,                                   // tp_iternext
-    nullptr,                                   // tp_methods
-    nullptr,                                   // tp_members
-    nullptr,                                   // tp_getset
-    nullptr,                                   // tp_base
-    nullptr,                                   // tp_dict
-    nullptr,                                   // tp_descr_get
-    nullptr,                                   // tp_descr_set
-    0,                                         // tp_dictoffset
-    nullptr,                                   // tp_init
-    nullptr,                                   // tp_alloc
-    PyBfloat16_New,                            // tp_new
-    nullptr,                                   // tp_free
-    nullptr,                                   // tp_is_gc
-    nullptr,                                   // tp_bases
-    nullptr,                                   // tp_mro
-    nullptr,                                   // tp_cache
-    nullptr,                                   // tp_subclasses
-    nullptr,                                   // tp_weaklist
-    nullptr,                                   // tp_del
-    0,                                         // tp_version_tag
+    nullptr,               // tp_getattr
+    nullptr,               // tp_setattr
+    nullptr,               // tp_compare / tp_reserved
+    PyBfloat16_Repr,       // tp_repr
+    &PyBfloat16_AsNumber,  // tp_as_number
+    nullptr,               // tp_as_sequence
+    nullptr,               // tp_as_mapping
+    PyBfloat16_Hash,       // tp_hash
+    nullptr,               // tp_call
+    PyBfloat16_Str,        // tp_str
+    nullptr,               // tp_getattro
+    nullptr,               // tp_setattro
+    nullptr,               // tp_as_buffer
+                           // tp_flags
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
+    "bfloat16 floating-point values",  // tp_doc
+    nullptr,                           // tp_traverse
+    nullptr,                           // tp_clear
+    PyBfloat16_RichCompare,            // tp_richcompare
+    0,                                 // tp_weaklistoffset
+    nullptr,                           // tp_iter
+    nullptr,                           // tp_iternext
+    nullptr,                           // tp_methods
+    nullptr,                           // tp_members
+    nullptr,                           // tp_getset
+    nullptr,                           // tp_base
+    nullptr,                           // tp_dict
+    nullptr,                           // tp_descr_get
+    nullptr,                           // tp_descr_set
+    0,                                 // tp_dictoffset
+    nullptr,                           // tp_init
+    nullptr,                           // tp_alloc
+    PyBfloat16_New,                    // tp_new
+    nullptr,                           // tp_free
+    nullptr,                           // tp_is_gc
+    nullptr,                           // tp_bases
+    nullptr,                           // tp_mro
+    nullptr,                           // tp_cache
+    nullptr,                           // tp_subclasses
+    nullptr,                           // tp_weaklist
+    nullptr,                           // tp_del
+    0,                                 // tp_version_tag
 };
 
 // Numpy support
@@ -370,31 +395,32 @@
 PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
 
 PyArray_Descr NPyBfloat16_Descr = {
-    PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type,  // typeobj
+    PyObject_HEAD_INIT(nullptr)
+    /*typeobj=*/&bfloat16_type,
     // We must register bfloat16 with a kind other than "f", because numpy
     // considers two types with the same kind and size to be equal, but
     // float16 != bfloat16.
-    'V',  // kind
+    // The downside of this is that NumPy scalar promotion does not work with
+    // bfloat16 values.
+    /*kind=*/'V',
     // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
     // character is unique.
-    'E',                                                  // type
-    '=',                                                  // byteorder
-    NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,  // hasobject
-    0,                                                    // type_num
-    sizeof(bfloat16),                                     // elsize
-    alignof(bfloat16),                                    // alignment
-    nullptr,                                              // subarray
-    nullptr,                                              // fields
-    nullptr,                                              // names
-    &NPyBfloat16_ArrFuncs,                                // f
-    nullptr,                                              // metadata
-    nullptr,                                              // c_metadata
-    -1,                                                   // hash
+    /*type=*/'E',
+    /*byteorder=*/'=',
+    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
+    /*type_num=*/0,
+    /*elsize=*/sizeof(bfloat16),
+    /*alignment=*/alignof(bfloat16),
+    /*subarray=*/nullptr,
+    /*fields=*/nullptr,
+    /*names=*/nullptr,
+    /*f=*/&NPyBfloat16_ArrFuncs,
+    /*metadata=*/nullptr,
+    /*c_metadata=*/nullptr,
+    /*hash=*/-1,  // -1 means "not computed yet".
 };
 
-// Registered numpy type ID. Global variable populated by the registration code.
-int npy_bfloat16_ = -1;
-
 // Implementations of NumPy array methods.
 
 PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
@@ -405,7 +431,11 @@
 
 int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
   bfloat16 x;
-  if (!AsBfloat16(item, &x)) return -1;
+  if (!CastToBfloat16(item, &x)) {
+    PyErr_Format(PyExc_TypeError, "expected number, got %s",
+                 item->ob_type->tp_name);
+    return -1;
+  }
   memcpy(data, &x, sizeof(bfloat16));
   return 0;
 }
@@ -486,16 +516,183 @@
   return 0;
 }
 
+void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
+                         void* op, npy_intp n, void* arr) {
+  char* c1 = reinterpret_cast<char*>(ip1);
+  char* c2 = reinterpret_cast<char*>(ip2);
+  float acc = 0.0f;
+  for (npy_intp i = 0; i < n; ++i) {
+    bfloat16* const b1 = reinterpret_cast<bfloat16*>(c1);
+    bfloat16* const b2 = reinterpret_cast<bfloat16*>(c2);
+    acc += static_cast<float>(*b1) * static_cast<float>(*b2);
+    c1 += is1;
+    c2 += is2;
+  }
+  bfloat16* out = reinterpret_cast<bfloat16*>(op);
+  *out = static_cast<bfloat16>(acc);
+}
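
The dot loop accumulates products in float32 and rounds back to bfloat16 only once at the end, which is markedly more accurate than accumulating in bfloat16 itself. A pure-Python sketch of the same reduction, with bfloat16 taken from the extension module as in the tests below:

    import numpy as np
    from tensorflow.python.lib.core import _pywrap_bfloat16

    bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()

    def bfloat16_dot(xs, ys):
        # Multiply-accumulate in float32, round to bfloat16 once at the end,
        # mirroring NPyBfloat16_DotFunc.
        acc = np.float32(0.0)
        for x, y in zip(xs, ys):
            acc += np.float32(x) * np.float32(y)
        return bfloat16(acc)
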
+
+int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) {
+  bfloat16 b1 = *reinterpret_cast<const bfloat16*>(v1);
+  bfloat16 b2 = *reinterpret_cast<const bfloat16*>(v2);
+  if (b1 < b2) {
+    return -1;
+  }
+  if (b1 > b2) {
+    return 1;
+  }
+  return 0;
+}
+
+int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
+                           void* arr) {
+  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
+  float max_val = -std::numeric_limits<float>::infinity();
+  for (npy_intp i = 0; i < n; ++i) {
+    if (static_cast<float>(bdata[i]) > max_val) {
+      max_val = static_cast<float>(bdata[i]);
+      *max_ind = i;
+    }
+  }
+  return 0;
+}
+
+int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
+                           void* arr) {
+  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
+  float min_val = std::numeric_limits<float>::infinity();
+  for (npy_intp i = 0; i < n; ++i) {
+    if (static_cast<float>(bdata[i]) < min_val) {
+      min_val = static_cast<float>(bdata[i]);
+      *min_ind = i;
+    }
+  }
+  return 0;
+}
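
Once these hooks are installed on NPyBfloat16_ArrFuncs in Initialize(), NumPy reductions and sorting on bfloat16 arrays route through them. For example, assuming the same import path as the test file:

    import numpy as np
    from tensorflow.python.lib.core import _pywrap_bfloat16

    bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()

    x = np.array([0.5, -2.0, 3.5, 1.0], dtype=bfloat16)
    print(np.argmax(x), np.argmin(x))  # 2 1, via NPyBfloat16_ArgMaxFunc/ArgMinFunc
    print(np.sort(x))                  # generic sort driven by NPyBfloat16_CompareFunc
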
+
 // NumPy casts
 
+template <typename T, typename Enable = void>
+struct TypeDescriptor {
+  // typedef ... T;  // Representation type in memory for NumPy values of type T.
+  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
+};
+
+template <>
+struct TypeDescriptor<bfloat16> {
+  typedef bfloat16 T;
+  static int Dtype() { return npy_bfloat16; }
+};
+
+template <>
+struct TypeDescriptor<uint8> {
+  typedef uint8 T;
+  static int Dtype() { return NPY_UINT8; }
+};
+
+template <>
+struct TypeDescriptor<uint16> {
+  typedef uint16 T;
+  static int Dtype() { return NPY_UINT16; }
+};
+
+// We register "int", "long", and "long long" types for portability across
+// Linux, where "long" and "long long" have the same size, and Windows, where
+// "int" and "long" have the same size.
+template <>
+struct TypeDescriptor<unsigned int> {
+  typedef unsigned int T;
+  static int Dtype() { return NPY_UINT; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long> {  // NOLINT
+  typedef unsigned long T;              // NOLINT
+  static int Dtype() { return NPY_ULONG; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long long> {  // NOLINT
+  typedef unsigned long long T;              // NOLINT
+  static int Dtype() { return NPY_ULONGLONG; }
+};
+
+template <>
+struct TypeDescriptor<int8> {
+  typedef int8 T;
+  static int Dtype() { return NPY_INT8; }
+};
+
+template <>
+struct TypeDescriptor<int16> {
+  typedef int16 T;
+  static int Dtype() { return NPY_INT16; }
+};
+
+template <>
+struct TypeDescriptor<int> {
+  typedef int T;
+  static int Dtype() { return NPY_INT; }
+};
+
+template <>
+struct TypeDescriptor<long> {  // NOLINT
+  typedef long T;              // NOLINT
+  static int Dtype() { return NPY_LONG; }
+};
+
+template <>
+struct TypeDescriptor<long long> {  // NOLINT
+  typedef long long T;              // NOLINT
+  static int Dtype() { return NPY_LONGLONG; }
+};
+
+template <>
+struct TypeDescriptor<bool> {
+  typedef int8 T;
+  static int Dtype() { return NPY_BOOL; }
+};
+
+template <>
+struct TypeDescriptor<Eigen::half> {
+  typedef Eigen::half T;
+  static int Dtype() { return NPY_HALF; }
+};
+
+template <>
+struct TypeDescriptor<float> {
+  typedef float T;
+  static int Dtype() { return NPY_FLOAT; }
+};
+
+template <>
+struct TypeDescriptor<double> {
+  typedef double T;
+  static int Dtype() { return NPY_DOUBLE; }
+};
+
+template <>
+struct TypeDescriptor<std::complex<float>> {
+  typedef std::complex<float> T;
+  static int Dtype() { return NPY_COMPLEX64; }
+};
+
+template <>
+struct TypeDescriptor<std::complex<double>> {
+  typedef std::complex<double> T;
+  static int Dtype() { return NPY_COMPLEX128; }
+};
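
Spelling out int, long, and long long (plus their unsigned variants) instead of the fixed-width aliases matters because NumPy keeps them as distinct dtypes even when two of them share a size on a given platform; a quick way to see the distinction:

    import numpy as np

    # On 64-bit Linux np.int_ (C long) and np.longlong have the same size but
    # different dtype characters; on 64-bit Windows np.intc and np.int_ match
    # instead. Registering all three covers both layouts.
    for t in (np.intc, np.int_, np.longlong):
        print(np.dtype(t).char, np.dtype(t).itemsize)
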
+
 // Performs a NumPy array cast from type 'From' to 'To'.
 template <typename From, typename To>
 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
              void* toarr) {
-  const From* from = reinterpret_cast<From*>(from_void);
-  To* to = reinterpret_cast<To*>(to_void);
+  const auto* from =
+      reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
+  auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
   for (npy_intp i = 0; i < n; ++i) {
-    to[i] = static_cast<To>(from[i]);
+    to[i] =
+        static_cast<typename TypeDescriptor<To>::T>(static_cast<To>(from[i]));
   }
 }
 
@@ -504,7 +701,7 @@
 // safely coerced to T.
 template <typename T>
 bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
-  if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16_,
+  if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16,
                                NPyCast<T, bfloat16>) < 0) {
     return false;
   }
@@ -520,60 +717,591 @@
 }
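
The cast_is_safe flag only marks the widening bfloat16-to-T direction as safe for NumPy's implicit casting rules; narrowing into bfloat16 still works via an explicit cast but is not advertised as safe. Roughly, assuming the registrations in Initialize() below succeed:

    import numpy as np
    from tensorflow.python.lib.core import _pywrap_bfloat16

    bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()

    print(np.can_cast(bfloat16, np.float32))   # True: widening, registered as safe
    print(np.can_cast(np.float32, bfloat16))   # False: narrowing, not registered
    # An explicit astype still works in both directions:
    print(np.arange(4, dtype=np.float32).astype(bfloat16))
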
 
 template <typename InType, typename OutType, typename Functor>
-void BinaryUFunc(char** args, const npy_intp* dimensions, const npy_intp* steps,
-                 void* data) {
-  const char* i0 = args[0];
-  const char* i1 = args[1];
-  char* o = args[2];
-  for (npy_intp k = 0; k < *dimensions; k++) {
-    InType x = *reinterpret_cast<const InType*>(i0);
-    InType y = *reinterpret_cast<const InType*>(i1);
-    *reinterpret_cast<OutType*>(o) = Functor()(x, y);
-    i0 += steps[0];
-    i1 += steps[1];
-    o += steps[2];
+struct UnaryUFunc {
+  static std::vector<int> Types() {
+    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype()};
   }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    char* o = args[1];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
+      *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) = Functor()(x);
+      i0 += steps[0];
+      o += steps[1];
+    }
+  }
+};
+
+template <typename InType, typename OutType, typename OutType2,
+          typename Functor>
+struct UnaryUFunc2 {
+  static std::vector<int> Types() {
+    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype(),
+            TypeDescriptor<OutType2>::Dtype()};
+  }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    char* o0 = args[1];
+    char* o1 = args[2];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
+      std::tie(*reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o0),
+               *reinterpret_cast<typename TypeDescriptor<OutType2>::T*>(o1)) =
+          Functor()(x);
+      i0 += steps[0];
+      o0 += steps[1];
+      o1 += steps[2];
+    }
+  }
+};
+
+template <typename InType, typename OutType, typename Functor>
+struct BinaryUFunc {
+  static std::vector<int> Types() {
+    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType>::Dtype(),
+            TypeDescriptor<OutType>::Dtype()};
+  }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    const char* i1 = args[1];
+    char* o = args[2];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
+      auto y = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i1);
+      *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
+          Functor()(x, y);
+      i0 += steps[0];
+      i1 += steps[1];
+      o += steps[2];
+    }
+  }
+};
+
+template <typename InType, typename InType2, typename OutType, typename Functor>
+struct BinaryUFunc2 {
+  static std::vector<int> Types() {
+    return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType2>::Dtype(),
+            TypeDescriptor<OutType>::Dtype()};
+  }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    const char* i1 = args[1];
+    char* o = args[2];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
+      auto y =
+          *reinterpret_cast<const typename TypeDescriptor<InType2>::T*>(i1);
+      *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
+          Functor()(x, y);
+      i0 += steps[0];
+      i1 += steps[1];
+      o += steps[2];
+    }
+  }
+};
+
+template <typename UFunc>
+bool RegisterUFunc(PyObject* numpy, const char* name) {
+  std::vector<int> types = UFunc::Types();
+  PyUFuncGenericFunction fn =
+      reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
+  Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
+  if (!ufunc_obj) {
+    return false;
+  }
+  PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
+  if (static_cast<int>(types.size()) != ufunc->nargs) {
+    PyErr_Format(PyExc_AssertionError,
+                 "ufunc %s takes %d arguments, loop takes %lu", name,
+                 ufunc->nargs, types.size());
+    return false;
+  }
+  if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn,
+                                  const_cast<int*>(types.data()),
+                                  nullptr) < 0) {
+    return false;
+  }
+  return true;
 }
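
RegisterUFunc attaches a bfloat16-specific inner loop to an existing NumPy ufunc, so after Initialize() the standard ufunc entry points dispatch to the functors defined below whenever they see bfloat16 operands:

    import numpy as np
    from tensorflow.python.lib.core import _pywrap_bfloat16

    bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()

    a = np.array([1.5, 2.25, -3.0], dtype=bfloat16)
    b = np.array([0.5, 4.0, 2.0], dtype=bfloat16)
    print(np.add(a, b).dtype)   # bfloat16, via the Add loop registered below
    print(np.less(a, b).dtype)  # bool, via the Lt loop (bfloat16 inputs, bool output)
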
 
-// Numpy changed const-ness of PyUFuncGenericFunction, provide overload.
-template <typename Functor>
-void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
-                  void* data) {
-  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
-}
-template <typename Functor>
-void CompareUFunc(char** args, const npy_intp* dimensions,
-                  const npy_intp* steps, void* data) {
-  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
+namespace ufuncs {
+
+struct Add {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; }
+};
+struct Subtract {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; }
+};
+struct Multiply {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a * b; }
+};
+struct TrueDivide {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
+};
+
+std::pair<float, float> divmod(float a, float b) {
+  if (b == 0.0f) {
+    float nan = std::numeric_limits<float>::quiet_NaN();
+    return {nan, nan};
+  }
+  float mod = std::fmod(a, b);
+  float div = (a - mod) / b;
+  if (mod != 0.0f) {
+    if ((b < 0.0f) != (mod < 0.0f)) {
+      mod += b;
+      div -= 1.0f;
+    }
+  } else {
+    mod = std::copysign(0.0f, b);
+  }
+
+  float floordiv;
+  if (div != 0.0f) {
+    floordiv = std::floor(div);
+    if (div - floordiv > 0.5f) {
+      floordiv += 1.0f;
+    }
+  } else {
+    floordiv = std::copysign(0.0f, a / b);
+  }
+  return {floordiv, mod};
 }
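
The helper reproduces Python/NumPy floor-division semantics: the remainder takes the sign of the divisor, the quotient is rounded toward negative infinity, and division by zero yields NaN instead of raising. The convention on plain Python floats, for comparison:

    print(divmod(7.0, -2.0))   # (-4.0, -1.0): remainder has the divisor's sign
    print(divmod(-7.0, 2.0))   # (-4.0, 1.0)
    print(divmod(-7.5, 2.5))   # (-3.0, 0.0)
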
 
-struct Bfloat16EqFunctor {
+struct FloorDivide {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
+  }
+};
+struct Remainder {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(
+        divmod(static_cast<float>(a), static_cast<float>(b)).second);
+  }
+};
+struct DivmodUFunc {
+  static std::vector<int> Types() {
+    return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16};
+  }
+  static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
+                   void* data) {
+    const char* i0 = args[0];
+    const char* i1 = args[1];
+    char* o0 = args[2];
+    char* o1 = args[3];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
+      bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
+      float floordiv, mod;
+      std::tie(floordiv, mod) =
+          divmod(static_cast<float>(x), static_cast<float>(y));
+      *reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
+      *reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
+      i0 += steps[0];
+      i1 += steps[1];
+      o0 += steps[2];
+      o1 += steps[3];
+    }
+  }
+};
+struct Fmod {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::fmod(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Negative {
+  bfloat16 operator()(bfloat16 a) { return -a; }
+};
+struct Positive {
+  bfloat16 operator()(bfloat16 a) { return a; }
+};
+struct Power {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::pow(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Abs {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::abs(static_cast<float>(a)));
+  }
+};
+struct Cbrt {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::cbrt(static_cast<float>(a)));
+  }
+};
+struct Ceil {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::ceil(static_cast<float>(a)));
+  }
+};
+struct CopySign {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(
+        std::copysign(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Exp {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::exp(static_cast<float>(a)));
+  }
+};
+struct Exp2 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::exp2(static_cast<float>(a)));
+  }
+};
+struct Expm1 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::expm1(static_cast<float>(a)));
+  }
+};
+struct Floor {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::floor(static_cast<float>(a)));
+  }
+};
+struct Frexp {
+  std::pair<bfloat16, int> operator()(bfloat16 a) {
+    int exp;
+    float f = std::frexp(static_cast<float>(a), &exp);
+    return {bfloat16(f), exp};
+  }
+};
+struct Heaviside {
+  bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
+    float x = static_cast<float>(bx);
+    if (Eigen::numext::isnan(x)) {
+      return bx;
+    }
+    if (x < 0) {
+      return bfloat16(0.0f);
+    }
+    if (x > 0) {
+      return bfloat16(1.0f);
+    }
+    return h0;  // x == 0
+  }
+};
+struct Conjugate {
+  bfloat16 operator()(bfloat16 a) { return a; }
+};
+struct IsFinite {
+  bool operator()(bfloat16 a) { return std::isfinite(static_cast<float>(a)); }
+};
+struct IsInf {
+  bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
+};
+struct IsNan {
+  bool operator()(bfloat16 a) {
+    return Eigen::numext::isnan(static_cast<float>(a));
+  }
+};
+struct Ldexp {
+  bfloat16 operator()(bfloat16 a, int exp) {
+    return bfloat16(std::ldexp(static_cast<float>(a), exp));
+  }
+};
+struct Log {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log(static_cast<float>(a)));
+  }
+};
+struct Log2 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log2(static_cast<float>(a)));
+  }
+};
+struct Log10 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log10(static_cast<float>(a)));
+  }
+};
+struct Log1p {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log1p(static_cast<float>(a)));
+  }
+};
+struct LogAddExp {
+  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
+    float x = static_cast<float>(bx);
+    float y = static_cast<float>(by);
+    if (x == y) {
+      // Handles infinities of the same sign.
+      return bfloat16(x + std::log(2.0f));
+    }
+    float out = std::numeric_limits<float>::quiet_NaN();
+    if (x > y) {
+      out = x + std::log1p(std::exp(y - x));
+    } else if (x < y) {
+      out = y + std::log1p(std::exp(x - y));
+    }
+    return bfloat16(out);
+  }
+};
+struct LogAddExp2 {
+  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
+    float x = static_cast<float>(bx);
+    float y = static_cast<float>(by);
+    if (x == y) {
+      // Handles infinities of the same sign.
+      return bfloat16(x + 1.0f);
+    }
+    float out = std::numeric_limits<float>::quiet_NaN();
+    if (x > y) {
+      out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
+    } else if (x < y) {
+      out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
+    }
+    return bfloat16(out);
+  }
+};
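
Both log-add-exp functors compute in float32 using the stable rewriting max(x, y) + log1p(exp(-|x - y|)) (and its base-2 analogue), so the intermediate exponentials cannot overflow. The identity on plain floats:

    import numpy as np

    x, y = 1000.0, 1000.5
    naive = np.log(np.exp(x) + np.exp(y))               # inf (exp overflows, with a warning)
    stable = max(x, y) + np.log1p(np.exp(-abs(x - y)))  # ~1000.974
    print(naive, stable, np.logaddexp(x, y))
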
+struct Modf {
+  std::pair<bfloat16, bfloat16> operator()(bfloat16 a) {
+    float integral;
+    float f = std::modf(static_cast<float>(a), &integral);
+    return {bfloat16(f), bfloat16(integral)};
+  }
+};
+
+struct Reciprocal {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(1.f / static_cast<float>(a));
+  }
+};
+struct Rint {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::rint(static_cast<float>(a)));
+  }
+};
+struct Sign {
+  bfloat16 operator()(bfloat16 a) {
+    float f(a);
+    if (f < 0) {
+      return bfloat16(-1);
+    }
+    if (f > 0) {
+      return bfloat16(1);
+    }
+    return a;
+  }
+};
+struct SignBit {
+  bool operator()(bfloat16 a) { return std::signbit(static_cast<float>(a)); }
+};
+struct Sqrt {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::sqrt(static_cast<float>(a)));
+  }
+};
+struct Square {
+  bfloat16 operator()(bfloat16 a) {
+    float f(a);
+    return bfloat16(f * f);
+  }
+};
+struct Trunc {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::trunc(static_cast<float>(a)));
+  }
+};
+
+// Trigonometric functions
+struct Sin {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::sin(static_cast<float>(a)));
+  }
+};
+struct Cos {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::cos(static_cast<float>(a)));
+  }
+};
+struct Tan {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::tan(static_cast<float>(a)));
+  }
+};
+struct Arcsin {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::asin(static_cast<float>(a)));
+  }
+};
+struct Arccos {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::acos(static_cast<float>(a)));
+  }
+};
+struct Arctan {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::atan(static_cast<float>(a)));
+  }
+};
+struct Arctan2 {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::atan2(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Hypot {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::hypot(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Sinh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::sinh(static_cast<float>(a)));
+  }
+};
+struct Cosh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::cosh(static_cast<float>(a)));
+  }
+};
+struct Tanh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::tanh(static_cast<float>(a)));
+  }
+};
+struct Arcsinh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::asinh(static_cast<float>(a)));
+  }
+};
+struct Arccosh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::acosh(static_cast<float>(a)));
+  }
+};
+struct Arctanh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::atanh(static_cast<float>(a)));
+  }
+};
+struct Deg2rad {
+  bfloat16 operator()(bfloat16 a) {
+    static constexpr float radians_per_degree = M_PI / 180.0f;
+    return bfloat16(static_cast<float>(a) * radians_per_degree);
+  }
+};
+struct Rad2deg {
+  bfloat16 operator()(bfloat16 a) {
+    static constexpr float degrees_per_radian = 180.0f / M_PI;
+    return bfloat16(static_cast<float>(a) * degrees_per_radian);
+  }
+};
+
+struct Eq {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
 };
-struct Bfloat16NeFunctor {
+struct Ne {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
 };
-struct Bfloat16LtFunctor {
+struct Lt {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
 };
-struct Bfloat16GtFunctor {
+struct Gt {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
 };
-struct Bfloat16LeFunctor {
+struct Le {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
 };
-struct Bfloat16GeFunctor {
+struct Ge {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
 };
+struct Maximum {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fa) || fa > fb ? a : b;
+  }
+};
+struct Minimum {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fa) || fa < fb ? a : b;
+  }
+};
+struct Fmax {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fb) || fa > fb ? a : b;
+  }
+};
+struct Fmin {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fb) || fa < fb ? a : b;
+  }
+};
+
+struct LogicalNot {
+  npy_bool operator()(bfloat16 a) { return !a; }
+};
+struct LogicalAnd {
+  npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; }
+};
+struct LogicalOr {
+  npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; }
+};
+struct LogicalXor {
+  npy_bool operator()(bfloat16 a, bfloat16 b) {
+    return static_cast<bool>(a) ^ static_cast<bool>(b);
+  }
+};
+
+struct NextAfter {
+  bfloat16 operator()(bfloat16 from, bfloat16 to) {
+    uint16_t from_as_int, to_as_int;
+    const uint16_t sign_mask = 1 << 15;
+    float from_as_float(from), to_as_float(to);
+    memcpy(&from_as_int, &from, sizeof(bfloat16));
+    memcpy(&to_as_int, &to, sizeof(bfloat16));
+    if (Eigen::numext::isnan(from_as_float) ||
+        Eigen::numext::isnan(to_as_float)) {
+      return bfloat16(std::numeric_limits<float>::quiet_NaN());
+    }
+    if (from_as_int == to_as_int) {
+      return to;
+    }
+    if (from_as_float == 0) {
+      if (to_as_float == 0) {
+        return to;
+      } else {
+        // Smallest subnormal signed like `to`.
+        uint16_t out_int = (to_as_int & sign_mask) | 1;
+        bfloat16 out;
+        memcpy(&out, &out_int, sizeof(bfloat16));
+        return out;
+      }
+    }
+    uint16_t from_sign = from_as_int & sign_mask;
+    uint16_t to_sign = to_as_int & sign_mask;
+    uint16_t from_abs = from_as_int & ~sign_mask;
+    uint16_t to_abs = to_as_int & ~sign_mask;
+    uint16_t magnitude_adjustment =
+        (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
+    uint16_t out_int = from_as_int + magnitude_adjustment;
+    bfloat16 out;
+    memcpy(&out, &out_int, sizeof(bfloat16));
+    return out;
+  }
+};
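
NextAfter works directly on the 16-bit encoding: in IEEE-style binary formats, adjacent representable magnitudes differ by exactly one in the raw integer representation, so after handling NaN, zeros, and sign changes the neighbouring value is reached by stepping the bits by one. The same trick on float32, for illustration:

    import numpy as np

    bits = np.float32(1.0).view(np.uint32)          # 0x3F800000
    stepped = np.uint32(bits + 1).view(np.float32)  # next representable float32 above 1.0
    print(stepped, np.nextafter(np.float32(1.0), np.float32(2.0)))  # both 1.0000001
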
+
+// TODO(phawkins): implement spacing
+
+}  // namespace ufuncs
+
+}  // namespace
 
 // Initializes the module.
 bool Initialize() {
-  // It's critical to ImportNumpy and import umath
-  // to avoid crash in open source build.
   ImportNumpy();
   import_umath1(false);
 
-  Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
+  Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
   if (!numpy_str) {
     return false;
   }
@@ -582,10 +1310,30 @@
     return false;
   }
 
-  // We hit a mysterious crash if we haven't initialized numpy before this:
-  PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
+  // If another module (presumably either TF or JAX) has registered a bfloat16
+  // type, use it. We don't want two bfloat16 types if we can avoid it since it
+  // leads to confusion if we have two different types with the same name. This
+  // assumes that the other module has a sufficiently complete bfloat16
+  // implementation. The only known NumPy bfloat16 extension at the time of
+  // writing is this one (distributed in TF and JAX).
+  // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
+  // so we can unambiguously refer to a single canonical definition of bfloat16.
+  int typenum = PyArray_TypeNumFromName(const_cast<char*>("bfloat16"));
+  if (typenum != NPY_NOTYPE) {
+    PyArray_Descr* descr = PyArray_DescrFromType(typenum);
+    // The test for an argmax function here is to verify that the
+    // bfloat16 implementation is sufficiently new and not, say, from an
+    // older version of TF or JAX.
+    if (descr && descr->f && descr->f->argmax) {
+      npy_bfloat16 = typenum;
+      bfloat16_type_ptr = descr->typeobj;
+      return true;
+    }
+  }
 
-  if (PyType_Ready(&PyBfloat16_Type) < 0) {
+  bfloat16_type.tp_base = &PyGenericArrType_Type;
+
+  if (PyType_Ready(&bfloat16_type) < 0) {
     return false;
   }
 
@@ -598,127 +1346,263 @@
   NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
   NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
   NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
+  NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc;
+  NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc;
+  NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc;
+  NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc;
 
   Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
-  npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
-  if (npy_bfloat16_ < 0) return false;
+  npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr);
+  bfloat16_type_ptr = &bfloat16_type;
+  if (npy_bfloat16 < 0) {
+    return false;
+  }
 
   // Support dtype(bfloat16)
-  if (PyDict_SetItemString(PyBfloat16_Type.tp_dict, "dtype",
+  if (PyDict_SetItemString(bfloat16_type.tp_dict, "dtype",
                            reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
       0) {
     return false;
   }
 
   // Register casts
-
-  // We lie shamelessly and say that a cast from half to bfloat16 is safe.
-  // Numpy frequently uses the smallest legal representation type for small
-  // float constants (e.g., 1.0), which is often float16. Things break if these
-  // cannot be converted transparently to bfloat16.
-  if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/true)) {
+  if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/false)) {
     return false;
   }
-
   if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
     return false;
   }
   if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<bool>(NPY_BOOL, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<uint8>(NPY_UINT8, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG,  // NOLINT
+                                           /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long long>(  // NOLINT
+          NPY_ULONGLONG, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<int8>(NPY_INT8, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<long>(NPY_LONG,  // NOLINT
+                                  /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<long long>(  // NOLINT
+          NPY_LONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   // Following the numpy convention. imag part is dropped when converting to
   // float.
-  if (!RegisterBfloat16Cast<complex64>(NPY_COMPLEX64, /*cast_is_safe=*/true)) {
+  if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64,
+                                                 /*cast_is_safe=*/true)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<complex128>(NPY_COMPLEX128,
-                                        /*cast_is_safe=*/true)) {
+  if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128,
+                                                  /*cast_is_safe=*/true)) {
     return false;
   }
 
-  // Register ufuncs
-  auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
-                            const std::array<int, 3>& types) {
-    Safe_PyObjectPtr ufunc_obj =
-        make_safe(PyObject_GetAttrString(numpy.get(), name));
-    if (!ufunc_obj) {
-      return false;
-    }
-    PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
-    if (types.size() != ufunc->nargs) {
-      PyErr_Format(PyExc_AssertionError,
-                   "ufunc %s takes %d arguments, loop takes %lu", name,
-                   ufunc->nargs, types.size());
-      return false;
-    }
-    if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
-                                    const_cast<int*>(types.data()),
-                                    nullptr) < 0) {
-      return false;
-    }
-    return true;
-  };
+  bool ok =
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Add>>(numpy.get(),
+                                                                  "add") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Subtract>>(
+          numpy.get(), "subtract") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Multiply>>(
+          numpy.get(), "multiply") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
+          numpy.get(), "divide") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp>>(
+          numpy.get(), "logaddexp") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp2>>(
+          numpy.get(), "logaddexp2") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Negative>>(
+          numpy.get(), "negative") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Positive>>(
+          numpy.get(), "positive") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
+          numpy.get(), "true_divide") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::FloorDivide>>(
+          numpy.get(), "floor_divide") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Power>>(numpy.get(),
+                                                                    "power") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
+          numpy.get(), "remainder") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
+          numpy.get(), "mod") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmod>>(numpy.get(),
+                                                                   "fmod") &&
+      RegisterUFunc<ufuncs::DivmodUFunc>(numpy.get(), "divmod") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
+                                                                 "absolute") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
+                                                                 "fabs") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rint>>(numpy.get(),
+                                                                  "rint") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sign>>(numpy.get(),
+                                                                  "sign") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Heaviside>>(
+          numpy.get(), "heaviside") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Conjugate>>(
+          numpy.get(), "conjugate") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp>>(numpy.get(),
+                                                                 "exp") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp2>>(numpy.get(),
+                                                                  "exp2") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Expm1>>(numpy.get(),
+                                                                   "expm1") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log>>(numpy.get(),
+                                                                 "log") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log2>>(numpy.get(),
+                                                                  "log2") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log10>>(numpy.get(),
+                                                                   "log10") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log1p>>(numpy.get(),
+                                                                   "log1p") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sqrt>>(numpy.get(),
+                                                                  "sqrt") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Square>>(numpy.get(),
+                                                                    "square") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cbrt>>(numpy.get(),
+                                                                  "cbrt") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Reciprocal>>(
+          numpy.get(), "reciprocal") &&
 
-  // Comparisons
-  const std::array<int, 3> compare_types = {
-      {npy_bfloat16_, npy_bfloat16_, NPY_BOOL}};
+      // Trigonometric functions
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sin>>(numpy.get(),
+                                                                 "sin") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cos>>(numpy.get(),
+                                                                 "cos") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tan>>(numpy.get(),
+                                                                 "tan") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsin>>(numpy.get(),
+                                                                    "arcsin") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccos>>(numpy.get(),
+                                                                    "arccos") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctan>>(numpy.get(),
+                                                                    "arctan") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Arctan2>>(
+          numpy.get(), "arctan2") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Hypot>>(numpy.get(),
+                                                                    "hypot") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sinh>>(numpy.get(),
+                                                                  "sinh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cosh>>(numpy.get(),
+                                                                  "cosh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tanh>>(numpy.get(),
+                                                                  "tanh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsinh>>(
+          numpy.get(), "arcsinh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccosh>>(
+          numpy.get(), "arccosh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctanh>>(
+          numpy.get(), "arctanh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Deg2rad>>(
+          numpy.get(), "deg2rad") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rad2deg>>(
+          numpy.get(), "rad2deg") &&
 
-  if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  return true;
+      // Comparison functions
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Eq>>(numpy.get(),
+                                                             "equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ne>>(numpy.get(),
+                                                             "not_equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Lt>>(numpy.get(),
+                                                             "less") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Gt>>(numpy.get(),
+                                                             "greater") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Le>>(numpy.get(),
+                                                             "less_equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ge>>(numpy.get(),
+                                                             "greater_equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Maximum>>(
+          numpy.get(), "maximum") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Minimum>>(
+          numpy.get(), "minimum") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmax>>(numpy.get(),
+                                                                   "fmax") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmin>>(numpy.get(),
+                                                                   "fmin") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalAnd>>(
+          numpy.get(), "logical_and") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalOr>>(
+          numpy.get(), "logical_or") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalXor>>(
+          numpy.get(), "logical_xor") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::LogicalNot>>(
+          numpy.get(), "logical_not") &&
+
+      // Floating point functions
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsFinite>>(numpy.get(),
+                                                                  "isfinite") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsInf>>(numpy.get(),
+                                                               "isinf") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsNan>>(numpy.get(),
+                                                               "isnan") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::SignBit>>(numpy.get(),
+                                                                 "signbit") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::CopySign>>(
+          numpy.get(), "copysign") &&
+      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, bfloat16, ufuncs::Modf>>(
+          numpy.get(), "modf") &&
+      RegisterUFunc<BinaryUFunc2<bfloat16, int, bfloat16, ufuncs::Ldexp>>(
+          numpy.get(), "ldexp") &&
+      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, int, ufuncs::Frexp>>(
+          numpy.get(), "frexp") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Floor>>(numpy.get(),
+                                                                   "floor") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
+                                                                  "ceil") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
+                                                                   "trunc") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
+          numpy.get(), "nextafter");
+
+  return ok;
 }
 
-}  // namespace
-
-void RegisterNumpyBfloat16() {
-  if (npy_bfloat16_ >= 0) {
+bool RegisterNumpyBfloat16() {
+  if (npy_bfloat16 != NPY_NOTYPE) {
     // Already initialized.
-    return;
+    return true;
   }
   if (!Initialize()) {
     if (!PyErr_Occurred()) {
       PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
     }
     PyErr_Print();
+    return false;
   }
+  return true;
 }
 
-PyObject* Bfloat16PyType() {
-  CHECK(PyBfloat16_Type.tp_base != nullptr);
-  Py_INCREF(&PyBfloat16_Type);
-  return reinterpret_cast<PyObject*>(&PyBfloat16_Type);
+PyObject* Bfloat16Dtype() {
+  return reinterpret_cast<PyObject*>(bfloat16_type_ptr);
 }
 
-int Bfloat16NumpyType() {
-  CHECK_GE(npy_bfloat16_, 0);
-  return npy_bfloat16_;
-}
+int Bfloat16NumpyType() { return npy_bfloat16; }
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/lib/core/bfloat16.h b/tensorflow/python/lib/core/bfloat16.h
index a609928..e40207b 100644
--- a/tensorflow/python/lib/core/bfloat16.h
+++ b/tensorflow/python/lib/core/bfloat16.h
@@ -20,11 +20,11 @@
 
 namespace tensorflow {
 
-// Register the bfloat16 numpy type.
-void RegisterNumpyBfloat16();
+// Register the bfloat16 numpy type. Returns true on success.
+bool RegisterNumpyBfloat16();
 
-// Returns the PyObject for the bfloat16 type.
-PyObject* Bfloat16PyType();
+// Returns a pointer to the bfloat16 dtype object.
+PyObject* Bfloat16Dtype();
 
 // Returns the id number of the bfloat16 numpy type.
 int Bfloat16NumpyType();
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index f190299..0bd5f0c 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -12,54 +12,82 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Test cases for the bfloat16 Python type."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import copy
+import itertools
 import math
 
+from absl.testing import absltest
+from absl.testing import parameterized
+
 import numpy as np
 
 # pylint: disable=unused-import,g-bad-import-order
-from tensorflow.python import _pywrap_bfloat16
 from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.core import _pywrap_bfloat16
 from tensorflow.python.platform import test
 
-
 bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
 
 
-def float_values():
-  """Returns values that should round trip exactly to float and back."""
-  epsilon = float.fromhex("1.0p-7")
-  return [
-      0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-      -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
-      float("inf"),
-      float("-inf"),
-      float("nan")
-  ]
+def numpy_assert_allclose(a, b, **kwargs):
+  a = a.astype(np.float32) if a.dtype == bfloat16 else a
+  b = b.astype(np.float32) if b.dtype == bfloat16 else b
+  return np.testing.assert_allclose(a, b, **kwargs)
 
 
-class Bfloat16Test(test.TestCase):
+epsilon = float.fromhex("1.0p-7")
 
-  def _assertFloatIdentical(self, v, w):
-    if math.isnan(v):
-      self.assertTrue(math.isnan(w))
-    else:
-      self.assertEqual(v, w)
+# Values that should round trip exactly to float and back.
+FLOAT_VALUES = [
+    0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
+    -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
+    float("inf"),
+    float("-inf"),
+    float("nan")
+]
+
+
+class Bfloat16Test(parameterized.TestCase):
+  """Tests the non-numpy Python methods of the bfloat16 type."""
 
   def testRoundTripToFloat(self):
-    for v in float_values():
-      self._assertFloatIdentical(v, float(bfloat16(v)))
+    for v in FLOAT_VALUES:
+      np.testing.assert_equal(v, float(bfloat16(v)))
+
+  def testRoundTripNumpyTypes(self):
+    for dtype in [np.float16, np.float32, np.float64]:
+      np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
+      np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
+      np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
+      np.testing.assert_equal(
+          np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
 
   def testRoundTripToInt(self):
     for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
       self.assertEqual(v, int(bfloat16(v)))
 
+  # pylint: disable=g-complex-comprehension
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + dtype.__name__,
+      "dtype": dtype
+  } for dtype in [bfloat16, np.float16, np.float32, np.float64]))
+  def testRoundTripToNumpy(self, dtype):
+    for v in FLOAT_VALUES:
+      np.testing.assert_equal(v, bfloat16(dtype(v)))
+      np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
+      np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
+    if dtype != bfloat16:
+      np.testing.assert_equal(
+          np.array(FLOAT_VALUES, dtype),
+          bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
+
   def testStr(self):
     self.assertEqual("0", str(bfloat16(0.0)))
     self.assertEqual("1", str(bfloat16(1.0)))
@@ -70,14 +98,13 @@
     self.assertEqual("nan", str(bfloat16(float("nan"))))
 
   def testRepr(self):
-    self.assertEqual("bfloat16(0)", repr(bfloat16(0)))
-    self.assertEqual("bfloat16(1)", repr(bfloat16(1)))
-    self.assertEqual("bfloat16(-3.5)", repr(bfloat16(-3.5)))
-    self.assertEqual("bfloat16(0.0078125)",
-                     repr(bfloat16(float.fromhex("1.0p-7"))))
-    self.assertEqual("bfloat16(inf)", repr(bfloat16(float("inf"))))
-    self.assertEqual("bfloat16(-inf)", repr(bfloat16(float("-inf"))))
-    self.assertEqual("bfloat16(nan)", repr(bfloat16(float("nan"))))
+    self.assertEqual("0", repr(bfloat16(0)))
+    self.assertEqual("1", repr(bfloat16(1)))
+    self.assertEqual("-3.5", repr(bfloat16(-3.5)))
+    self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
+    self.assertEqual("inf", repr(bfloat16(float("inf"))))
+    self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
+    self.assertEqual("nan", repr(bfloat16(float("nan"))))
 
   def testHash(self):
     self.assertEqual(0, hash(bfloat16(0.0)))
@@ -86,115 +113,166 @@
 
   # Tests for Python operations
   def testNegate(self):
-    for v in float_values():
-      self._assertFloatIdentical(-v, float(-bfloat16(v)))
+    for v in FLOAT_VALUES:
+      np.testing.assert_equal(-v, float(-bfloat16(v)))
 
   def testAdd(self):
-    self._assertFloatIdentical(0, float(bfloat16(0) + bfloat16(0)))
-    self._assertFloatIdentical(1, float(bfloat16(1) + bfloat16(0)))
-    self._assertFloatIdentical(0, float(bfloat16(1) + bfloat16(-1)))
-    self._assertFloatIdentical(5.5, float(bfloat16(2) + bfloat16(3.5)))
-    self._assertFloatIdentical(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(float("inf")) + bfloat16(-2.25)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(float("-inf")) + bfloat16(-2.25)))
+    np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
+    np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
+    np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
+    np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
+    np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
     self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
 
+    # Test type promotion against Numpy scalar values.
+    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
+    self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
+    self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
+    self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
+    self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float32,
+                     type(bfloat16(3.5) + np.array(2.25, np.float32)))
+    self.assertEqual(np.float32,
+                     type(np.array(3.5, np.float32) + bfloat16(2.25)))
+
   def testSub(self):
-    self._assertFloatIdentical(0, float(bfloat16(0) - bfloat16(0)))
-    self._assertFloatIdentical(1, float(bfloat16(1) - bfloat16(0)))
-    self._assertFloatIdentical(2, float(bfloat16(1) - bfloat16(-1)))
-    self._assertFloatIdentical(-1.5, float(bfloat16(2) - bfloat16(3.5)))
-    self._assertFloatIdentical(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(-2.25) - bfloat16(float("inf"))))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(-2.25) - bfloat16(float("-inf"))))
+    np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
+    np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
+    np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
+    np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
+    np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
     self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
 
   def testMul(self):
-    self._assertFloatIdentical(0, float(bfloat16(0) * bfloat16(0)))
-    self._assertFloatIdentical(0, float(bfloat16(1) * bfloat16(0)))
-    self._assertFloatIdentical(-1, float(bfloat16(1) * bfloat16(-1)))
-    self._assertFloatIdentical(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(float("inf")) * bfloat16(-2.25)))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(float("-inf")) * bfloat16(-2.25)))
+    np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
+    np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
+    np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
+    np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
     self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
 
   def testDiv(self):
     self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
-    self._assertFloatIdentical(float("inf"), float(bfloat16(1) / bfloat16(0)))
-    self._assertFloatIdentical(-1, float(bfloat16(1) / bfloat16(-1)))
-    self._assertFloatIdentical(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(float("inf")) / bfloat16(-2.25)))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(float("-inf")) / bfloat16(-2.25)))
+    np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
+    np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
+    np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
     self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
 
   def testLess(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
 
   def testLessEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
 
   def testGreater(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
 
   def testGreaterEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
 
   def testEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
 
   def testNotEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
 
   def testNan(self):
     a = np.isnan(bfloat16(float("nan")))
     self.assertTrue(a)
-    np.testing.assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
+    numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
 
-    a = np.array(
-        [bfloat16(1.34375),
-         bfloat16(1.4375),
-         bfloat16(float("nan"))],
-        dtype=dtypes.bfloat16.as_numpy_dtype)
+    a = np.array([bfloat16(1.34375),
+                  bfloat16(1.4375),
+                  bfloat16(float("nan"))],
+                 dtype=bfloat16)
     b = np.array(
         [bfloat16(1.3359375),
          bfloat16(1.4375),
          bfloat16(float("nan"))],
-        dtype=dtypes.bfloat16.as_numpy_dtype)
-    np.testing.assert_allclose(
+        dtype=bfloat16)
+    numpy_assert_allclose(
         a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
 
+  def testSort(self):
+    values_to_sort = np.float32(FLOAT_VALUES)
+    sorted_f32 = np.sort(values_to_sort)
+    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
+    np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
 
-class Bfloat16NumPyTest(test.TestCase):
+
+BinaryOp = collections.namedtuple("BinaryOp", ["op"])
+
+UNARY_UFUNCS = [
+    np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
+    np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
+    np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
+    np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
+    np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
+]
+
+BINARY_UFUNCS = [
+    np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
+    np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
+    np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
+]
+
+BINARY_PREDICATE_UFUNCS = [
+    np.equal, np.not_equal, np.less, np.greater, np.less_equal,
+    np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
+]
+
+
+class Bfloat16NumPyTest(parameterized.TestCase):
+  """Tests the NumPy integration of the bfloat16 type."""
 
   def testDtype(self):
     self.assertEqual(bfloat16, np.dtype(bfloat16))
 
+  def testDeepCopyDoesNotAlterHash(self):
+    # For context, see https://github.com/google/jax/issues/4651. If the hash
+    # value of the type descriptor is not initialized correctly, a deep copy
+    # can change the type hash.
+    dtype = np.dtype(bfloat16)
+    h = hash(dtype)
+    _ = copy.deepcopy(dtype)
+    self.assertEqual(h, hash(dtype))
+
   def testArray(self):
     x = np.array([[1, 2, 3]], dtype=bfloat16)
     self.assertEqual(bfloat16, x.dtype)
-    self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
-    self.assertAllEqual(x, x)
-    self.assertAllClose(x, x)
+    self.assertEqual("[[1 2 3]]", str(x))
+    np.testing.assert_equal(x, x)
+    numpy_assert_allclose(x, x)
     self.assertTrue((x == x).all())
 
   def testComparisons(self):
@@ -202,12 +280,12 @@
     bx = x.astype(bfloat16)
     y = np.array([82432, 7, 0], dtype=np.float32)
     by = y.astype(bfloat16)
-    self.assertAllEqual(x == y, bx == by)
-    self.assertAllEqual(x != y, bx != by)
-    self.assertAllEqual(x < y, bx < by)
-    self.assertAllEqual(x > y, bx > by)
-    self.assertAllEqual(x <= y, bx <= by)
-    self.assertAllEqual(x >= y, bx >= by)
+    np.testing.assert_equal(x == y, bx == by)
+    np.testing.assert_equal(x != y, bx != by)
+    np.testing.assert_equal(x < y, bx < by)
+    np.testing.assert_equal(x > y, bx > by)
+    np.testing.assert_equal(x <= y, bx <= by)
+    np.testing.assert_equal(x >= y, bx >= by)
 
   def testEqual2(self):
     a = np.array([401408], bfloat16)
@@ -216,8 +294,10 @@
 
   def testCasts(self):
     for dtype in [
-        np.float16, np.float32, np.float64, np.int32, np.int64,
-        np.complex64, np.complex128]:
+        np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
+        np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
+        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
+    ]:
       x = np.array([[1, 2, 3]], dtype=dtype)
       y = x.astype(bfloat16)
       z = y.astype(dtype)
@@ -231,44 +311,133 @@
       x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
       y_np = x.astype(np.float32)
       y_tf = x.astype(bfloat16)
-      self.assertAllClose(y_np, y_tf, atol=2e-2)
+      numpy_assert_allclose(y_np, y_tf, atol=2e-2)
 
       z_np = y_np.astype(dtype)
       z_tf = y_tf.astype(dtype)
-      self.assertAllClose(z_np, z_tf, atol=2e-2)
-
-  def testAdd(self):
-    x = np.array([[1, 2, 3]], dtype=bfloat16)
-    y = np.array([[4, 5, 6]], dtype=bfloat16)
-    self.assertAllClose(np.array([[5, 7, 9]]), x + y)
-
-  def testLogSumExp(self):
-    x = np.array([[1, 2, 3]], dtype=np.float32)
-    y = np.array([[4, 5, 6]], dtype=np.float32)
-    self.assertAllClose(np.logaddexp(x, y),
-                        np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
-                        atol=2e-2)
+      numpy_assert_allclose(z_np, z_tf, atol=2e-2)
 
   def testArange(self):
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(100, dtype=np.float32).astype(bfloat16),
         np.arange(100, dtype=bfloat16))
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
         np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
         np.arange(-0., -7., -0.25, dtype=bfloat16))
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
         np.arange(-16384., 16384., 64., dtype=bfloat16))
 
-  def testSort(self):
-    values_to_sort = np.float32(float_values())
-    sorted_f32 = np.sort(values_to_sort)
-    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
-    self.assertAllEqual(sorted_f32, np.float32(sorted_bf16))
+  # pylint: disable=g-complex-comprehension
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in UNARY_UFUNCS))
+  def testUnaryUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7, 10).astype(bfloat16)
+    numpy_assert_allclose(
+        op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
+
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in BINARY_UFUNCS))
+  def testBinaryUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7, 10).astype(bfloat16)
+    y = rng.randn(4, 1, 7, 10).astype(bfloat16)
+    numpy_assert_allclose(
+        op(x, y).astype(np.float32),
+        op(x.astype(np.float32), y.astype(np.float32)),
+        rtol=1e-2)
+
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in BINARY_PREDICATE_UFUNCS))
+  def testBinaryPredicateUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    y = rng.randn(4, 1, 7).astype(bfloat16)
+    np.testing.assert_equal(
+        op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
+
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
+  def testPredicateUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    shape = (3, 7, 10)
+    posinf_flips = rng.rand(*shape) < 0.1
+    neginf_flips = rng.rand(*shape) < 0.1
+    nan_flips = rng.rand(*shape) < 0.1
+    vals = rng.randn(*shape)
+    vals = np.where(posinf_flips, np.inf, vals)
+    vals = np.where(neginf_flips, -np.inf, vals)
+    vals = np.where(nan_flips, np.nan, vals)
+    vals = vals.astype(bfloat16)
+    np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
+
+  def testDivmod(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    y = rng.randn(4, 1, 7).astype(bfloat16)
+    o1, o2 = np.divmod(x, y)
+    e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
+    numpy_assert_allclose(o1, e1, rtol=1e-2)
+    numpy_assert_allclose(o2, e2, rtol=1e-2)
+
+  def testModf(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    o1, o2 = np.modf(x)
+    e1, e2 = np.modf(x.astype(np.float32))
+    numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
+    numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
+
+  def testLdexp(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    y = rng.randint(-50, 50, (1, 7))
+    numpy_assert_allclose(
+        np.ldexp(x, y).astype(np.float32),
+        np.ldexp(x.astype(np.float32), y),
+        rtol=1e-2,
+        atol=1e-6)
+
+  def testFrexp(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    mant1, exp1 = np.frexp(x)
+    mant2, exp2 = np.frexp(x.astype(np.float32))
+    np.testing.assert_equal(exp1, exp2)
+    numpy_assert_allclose(mant1, mant2, rtol=1e-2)
+
+  def testNextAfter(self):
+    one = np.array(1., dtype=bfloat16)
+    two = np.array(2., dtype=bfloat16)
+    zero = np.array(0., dtype=bfloat16)
+    nan = np.array(np.nan, dtype=bfloat16)
+    np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
+    np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
+    np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
+    np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
+    np.testing.assert_equal(np.nextafter(one, one), one)
+    smallest_denormal = float.fromhex("1.0p-133")
+    np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
+    np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
+    for a, b in itertools.permutations([0., -0., nan], 2):
+      np.testing.assert_equal(
+          np.nextafter(
+              np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
+          np.nextafter(
+              np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
 
 
 if __name__ == "__main__":
-  test.main()
+  absltest.main()
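For readers skimming the test rewrite above, a minimal sketch of how the registered bfloat16 NumPy dtype is used from Python; the import path and the TF_bfloat16_type() accessor are assumptions based on the pybind wrapper below, and this sketch is not part of the diff itself.

# Hedged sketch: exercising the bfloat16 NumPy dtype the way the tests do.
import numpy as np
from tensorflow.python.lib.core import _pywrap_bfloat16  # assumed module path

bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()

x = np.array([1.0, 2.5, -3.75], dtype=bfloat16)   # construction from Python floats
np.testing.assert_allclose(x.astype(np.float32), [1.0, 2.5, -3.75])

# Mixed bfloat16/float16 scalar arithmetic promotes to float32, as testAdd asserts.
assert type(bfloat16(3.5) + np.float16(2.25)) is np.float32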
diff --git a/tensorflow/python/lib/core/bfloat16_wrapper.cc b/tensorflow/python/lib/core/bfloat16_wrapper.cc
index eb346af..741468b 100644
--- a/tensorflow/python/lib/core/bfloat16_wrapper.cc
+++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc
@@ -20,5 +20,5 @@
   tensorflow::RegisterNumpyBfloat16();
 
   m.def("TF_bfloat16_type",
-        [] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
+        [] { return pybind11::handle(tensorflow::Bfloat16Dtype()); });
 }
diff --git a/tensorflow/python/lib/io/BUILD b/tensorflow/python/lib/io/BUILD
new file mode 100644
index 0000000..71c35f2
--- /dev/null
+++ b/tensorflow/python/lib/io/BUILD
@@ -0,0 +1,103 @@
+# python/lib/io package
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+visibility = [
+    "//tensorflow:__subpackages__",
+]
+
+package(
+    default_visibility = visibility,
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "py_record_reader_lib",
+    srcs = ["py_record_reader.cc"],
+    hdrs = ["py_record_reader.h"],
+    deps = [
+        "//tensorflow/c:c_api",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_file_io",
+    srcs = ["file_io_wrapper.cc"],
+    module_name = "_pywrap_file_io",
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "@pybind11",
+    ],
+)
+
+py_library(
+    name = "lib",
+    srcs = [
+        "file_io.py",
+        "python_io.py",
+        "tf_record.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":_pywrap_file_io",
+        ":_pywrap_record_io",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python:util",
+        "@six_archive//:six",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_record_io",
+    srcs = ["record_io_wrapper.cc"],
+    module_name = "_pywrap_record_io",
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "@com_google_absl//absl/memory",
+        "@pybind11",
+    ],
+)
+
+tf_py_test(
+    name = "file_io_test",
+    size = "small",
+    srcs = ["file_io_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_rocm",
+        "no_windows",
+    ],
+    deps = [
+        ":lib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/platform:client_testlib",
+    ],
+)
+
+tf_py_test(
+    name = "tf_record_test",
+    size = "small",
+    srcs = ["tf_record_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":lib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:util",
+        "//tensorflow/python/platform:client_testlib",
+    ],
+)
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index a4c9b61..73365b3 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -23,8 +23,8 @@
 
 import six
 
-from tensorflow.python import _pywrap_file_io
 from tensorflow.python.framework import errors
+from tensorflow.python.lib.io import _pywrap_file_io
 from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/lib/io/file_io_wrapper.cc b/tensorflow/python/lib/io/file_io_wrapper.cc
index 3ede938..c1ee09b 100644
--- a/tensorflow/python/lib/io/file_io_wrapper.cc
+++ b/tensorflow/python/lib/io/file_io_wrapper.cc
@@ -239,7 +239,7 @@
              py::gil_scoped_release release;
              auto* env = tensorflow::Env::Default();
              std::unique_ptr<WritableFile> self;
-             const auto status = mode.find("a") == std::string::npos
+             const auto status = mode.find('a') == std::string::npos
                                      ? env->NewWritableFile(filename, &self)
                                      : env->NewAppendableFile(filename, &self);
              py::gil_scoped_acquire acquire;
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 6c2be8d..f4315ee 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -19,7 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import _pywrap_record_io
+from tensorflow.python.lib.io import _pywrap_record_io
 from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
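The BUILD target and import moves above only relocate the pybind modules backing TensorFlow's file and record I/O; the public surface is unchanged. As a reminder of what that surface looks like, a small sketch (the file path is illustrative):

import tensorflow as tf

path = "/tmp/example.tfrecord"  # hypothetical path
with tf.io.TFRecordWriter(path) as writer:   # backed by _pywrap_record_io
  writer.write(b"hello record")

for raw in tf.data.TFRecordDataset(path):
  print(raw.numpy())                          # b'hello record'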
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 02375c8..652b1ee 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1221,30 +1221,27 @@
 
   parent_name = name
 
-  if not (var is None and isinstance(op, ops.EagerTensor)):
-
+  if var is not None:
     def assign(val, name=None):
       """Closure that holds all the arguments to create an assignment."""
 
-      if var is None:
-        raise ValueError("Sliced assignment is only supported for variables")
-      else:
-        if name is None:
-          name = parent_name + "_assign"
+      if name is None:
+        name = parent_name + "_assign"
 
-        return var._strided_slice_assign(
-            begin=begin,
-            end=end,
-            strides=strides,
-            value=val,
-            name=name,
-            begin_mask=begin_mask,
-            end_mask=end_mask,
-            ellipsis_mask=ellipsis_mask,
-            new_axis_mask=new_axis_mask,
-            shrink_axis_mask=shrink_axis_mask)
+      return var._strided_slice_assign(
+          begin=begin,
+          end=end,
+          strides=strides,
+          value=val,
+          name=name,
+          begin_mask=begin_mask,
+          end_mask=end_mask,
+          ellipsis_mask=ellipsis_mask,
+          new_axis_mask=new_axis_mask,
+          shrink_axis_mask=shrink_axis_mask)
 
     op.assign = assign
+
   return op
 
 
@@ -4759,6 +4756,7 @@
 
 
 @tf_export("reverse_sequence", v1=[])
+@dispatch.add_dispatch_support
 def reverse_sequence_v2(input,
                         seq_lengths,
                         seq_axis=None,
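The _slice_helper change above attaches the assign closure only when a variable is present. A minimal sketch of the user-facing behaviour it supports (values are illustrative):

import tensorflow as tf

v = tf.Variable([1., 2., 3., 4.])
v[1:3].assign([20., 30.])   # routed through var._strided_slice_assign
print(v.numpy())            # [ 1. 20. 30.  4.]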
diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py
index 5749be9..15c670e 100644
--- a/tensorflow/python/ops/batch_ops_test.py
+++ b/tensorflow/python/ops/batch_ops_test.py
@@ -25,12 +25,17 @@
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework.errors import InvalidArgumentError
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import batch_ops
 from tensorflow.python.ops import gen_batch_ops
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -50,7 +55,7 @@
     """Tests that a single batched tensor executes together and only once."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
@@ -92,7 +97,7 @@
     """Test that batching with padding up to an allowed batch size works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -124,7 +129,7 @@
     """Tests that multiple batched tensors execute together."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, _, _ = batch_ops.batch(
@@ -165,7 +170,7 @@
     """Tests illegally feeding tensors with different dim0 sizes."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
@@ -181,7 +186,7 @@
     """Tests that batch and unbatch work together."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -207,7 +212,7 @@
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       # TODO(apassos): Removing this line causes test flakiness! Ideally should
       # be investigated.
       default_inp = array_ops.placeholder_with_default(2, shape=[])  # pylint: disable=unused-variable
@@ -235,33 +240,62 @@
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
-      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
-      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+    with self.cached_session(use_gpu=True) as sess:
+      captured_inp0 = array_ops.placeholder_with_default(2., shape=[])
+      captured_inp1 = resource_variable_ops.ResourceVariable(3.)
+      with ops.device("/cpu:0"):
+        captured_inp2 = resource_variable_ops.ResourceVariable(4.)
 
       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
-        return in_t + captured_inp0 - captured_inp1
+        return in_t + captured_inp0 + captured_inp1 + captured_inp2
 
-      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       result = computation(inp)
       thread_results = []
 
       def worker():
         thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
 
+      sess.run(variables.global_variables_initializer())
       worker_thread = threading.Thread(target=worker)
       worker_thread.start()
       main_results = sess.run([result], feed_dict={inp: [2]})
       worker_thread.join()
-      self.assertEqual(thread_results[0], [2])
-      self.assertEqual(main_results[0], [3])
+      self.assertEqual(thread_results[0], [10])
+      self.assertEqual(main_results[0], [11])
+
+  @test_util.disable_xla("DeviceIndex returns sentinel value with XLA")
+  def testBatchDecoratedGpu(self):
+    if context.executing_eagerly():
+      return
+    with self.cached_session(use_gpu=True) as sess:
+
+      @batch_ops.batch_function(1, 10, 100000)
+      def computation(in_t):
+        # index is 0 on CPU and 1 on GPU
+        index = gen_functional_ops.DeviceIndex(device_names=["CPU", "GPU"])
+        return in_t + math_ops.cast(index, dtypes.float32)
+
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
+      result = computation(inp)
+      thread_results = []
+
+      def worker():
+        thread_results.extend(sess.run([result], feed_dict={inp: [10.]}))
+
+      worker_thread = threading.Thread(target=worker)
+      worker_thread.start()
+      main_results = sess.run([result], feed_dict={inp: [20.]})
+      worker_thread.join()
+      self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
+      self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])
 
   def testBatchFunctionOp(self):
     """Tests that the batch_function op works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
 
       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -292,7 +326,7 @@
     """Tests that batch_function op works with captured input."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
       captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -328,7 +362,7 @@
     """Tests that batch_function op works with error in the inputs."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
 
       @function.Defun(dtypes.int32, dtypes.int32)
@@ -345,8 +379,9 @@
           captured_tensors=computation.captured_inputs,
           Tout=[o.type for o in computation.definition.signature.output_arg])
 
-      with self.assertRaisesRegex(InvalidArgumentError,
-                                  ".*2 arguments.*but 1.*"):
+      with self.assertRaisesRegex(
+          InvalidArgumentError,
+          r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed"):
         sess.run([result], feed_dict={inp: [2]})
 
   def testBatchFunctionOpWithLargeBatchSplitted(self):
@@ -354,7 +389,7 @@
     if context.executing_eagerly():
       return
 
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
 
       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -408,7 +443,7 @@
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
 
       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
@@ -432,7 +467,7 @@
     """Tests that the unbatch timeout works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
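The batching tests above exercise the internal batch_ops module; the same decorator is exposed publicly as tf.nondifferentiable_batch_function. A hedged sketch of the pattern, assuming a single un-batched call simply waits out the batch timeout:

import tensorflow as tf

@tf.nondifferentiable_batch_function(
    num_batch_threads=1, max_batch_size=10, batch_timeout_micros=100000)
def computation(in_t):
  return in_t + 1.

@tf.function
def call(x):
  return computation(x)   # concurrent callers are transparently batched

print(call(tf.constant([10.])))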
diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py
index 4f33c3a..fcbd0ca 100644
--- a/tensorflow/python/ops/collective_ops.py
+++ b/tensorflow/python/ops/collective_ops.py
@@ -261,6 +261,40 @@
       timeout_seconds=timeout)
 
 
+def broadcast_send_v2(t,
+                      group_size,
+                      group_key,
+                      instance_key,
+                      communication_hint='auto',
+                      timeout=0):
+  """Broadcasts one tensor to a group of others, across devices.
+
+  Args:
+    t: the tensor to be sent.
+    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
+        the total number of devices participating.  Each tensor must reside on a
+        different device.
+    group_key: an int32 tensor identifying the group of devices.
+    instance_key: an int32 tensor identifying the participating group of Ops.
+    communication_hint: preferred collective communication.  The implementation
+      may fall back to another mechanism.  Options include `auto`, `ring`, and
+      `nccl`.
+    timeout: If set to a non-zero value, sets a completion timeout to detect
+      staleness. If the timer goes off, a DeadlineExceededError is raised.
+      The value is the timeout in seconds. This feature is experimental.
+
+  Returns:
+    An Op implementing the distributed broadcast send.
+  """
+  return gen_collective_ops.collective_bcast_send_v2(
+      t,
+      group_size=group_size,
+      group_key=group_key,
+      instance_key=instance_key,
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
+
+
 def broadcast_recv(shape,
                    dtype,
                    group_size,
@@ -302,3 +336,41 @@
       instance_key=instance_key,
       communication_hint=communication_hint.lower(),
       timeout_seconds=timeout)
+
+
+def broadcast_recv_v2(shape,
+                      dtype,
+                      group_size,
+                      group_key,
+                      instance_key,
+                      communication_hint='auto',
+                      timeout=0):
+  """Receives a broadcasts tensor, across devices.
+
+  Args:
+    shape: an int tensor.  Shape of the tensor to be received.
+    dtype: Type of the tensor to be received.
+    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
+        the total number of devices participating.  Each tensor must reside on a
+        different device.
+    group_key: an int32 tensor identifying the group of devices.
+    instance_key: an int32 tensor identifying the participating group of Ops.
+    communication_hint: preferred collective communication.  The implementation
+      may fall back to another mechanism.  Options include `auto`, `ring`, and
+      `nccl`.
+    timeout: If set to a non-zero value, sets a completion timeout to detect
+      staleness. If the timer goes off, a DeadlineExceededError is raised.
+      The value is the timeout in seconds. This feature is experimental.
+
+  Returns:
+    An Op implementing the broadcast receive.
+  """
+  return gen_collective_ops.collective_bcast_recv_v2(
+      T=dtype,
+      group_size=group_size,
+      group_key=group_key,
+      instance_key=instance_key,
+      shape=shape,
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
+
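A hedged sketch of how the new *_v2 broadcast helpers pair up on two logical devices in one process; the device setup, keys, and shape below are illustrative, and both sides must agree on group_size, group_key, and instance_key to rendezvous.

import tensorflow as tf
from tensorflow.python.ops import collective_ops

# Assumption: two logical CPU devices are configured before any other TF op runs.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)

group_size, group_key, instance_key = 2, 100, 7

@tf.function
def broadcast():
  with tf.device("/device:CPU:0"):
    sent = collective_ops.broadcast_send_v2(
        tf.constant([1., 2., 3.]), group_size, group_key, instance_key)
  with tf.device("/device:CPU:1"):
    received = collective_ops.broadcast_recv_v2(
        tf.constant([3]), tf.float32, group_size, group_key, instance_key)
  return sent, received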
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 2acd2d2..2e0b944 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2918,8 +2918,9 @@
   output.
 
   Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
-  this method, as code executes in your expected order.* Only use tf.group when
-  working with v1-style code or in a graph context such as inside `Dataset.map`.
+  this method, as ops execute in the expected order thanks to automatic control
+  dependencies.* Only use `tf.group` when working with v1
+  `tf.Graph` code.
 
   When operating in a v1-style graph context, ops are not executed in the same
   order as specified in the code; TensorFlow will attempt to execute ops in
@@ -2991,22 +2992,16 @@
 @tf_export("tuple", v1=[])
 @dispatch.add_dispatch_support
 def tuple_v2(tensors, control_inputs=None, name=None):
-  """Group tensors together.
+  """Groups tensors together.
 
-  This creates a tuple of tensors with the same values as the `tensors`
-  argument, except that the value of each tensor is only returned after the
-  values of all tensors have been computed.
+  The returned tensors have the same value as the input tensors, but they
+  are computed only after all the input tensors have been computed.
 
-  `control_inputs` contains additional ops that have to finish before this op
-  finishes, but whose outputs are not returned.
+  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
+  this method, as ops execute in the expected order thanks to automatic control
+  dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
 
-  This can be used as a "join" mechanism for parallel computations: all the
-  argument tensors can be computed in parallel, but the values of any tensor
-  returned by `tuple` are only available after all the parallel computations
-  are done.
-
-  See also `tf.group` and
-  `tf.control_dependencies`.
+  See also `tf.group` and `tf.control_dependencies`.
 
   Args:
     tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
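A short sketch of the v1-graph usage the reworded tf.group and tf.tuple notes point at: tf.group returns an op that only sequences its inputs, while tf.tuple returns tensors whose values are available only after all inputs have run (variable names are illustrative).

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  v1 = tf.compat.v1.get_variable("v1", shape=[], initializer=tf.zeros_initializer())
  v2 = tf.compat.v1.get_variable("v2", shape=[], initializer=tf.zeros_initializer())
  inc1 = tf.compat.v1.assign_add(v1, 1.)
  inc2 = tf.compat.v1.assign_add(v2, 2.)
  step = tf.group(inc1, inc2)     # a NoOp that runs both updates, returns nothing
  a, b = tf.tuple([inc1, inc2])   # tensors gated on both updates having run

with tf.compat.v1.Session(graph=g) as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  sess.run(step)
  print(sess.run([a, b]))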
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 0b6bbbd..34a1413 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -57,11 +57,9 @@
 from tensorflow.python.ops import variables
 import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
 from tensorflow.python.training import momentum
 from tensorflow.python.util import nest
 
-
 TestTuple = collections.namedtuple("TestTuple", "a b")
 SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a")
 
@@ -85,7 +83,8 @@
       c = constant_op.constant(0, name="c")
       control_flow_ops.group(a.op, b.op, c.op, name="root")
     gd = g.as_graph_def()
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       node { name: "a" op: "Const"}
       node { name: "b" op: "Const"}
       node { name: "c" op: "Const"}
@@ -99,7 +98,8 @@
         b = constant_op.constant(0, name="b")
       control_flow_ops.group(a.op, b.op, name="root")
     gd = g.as_graph_def()
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       node { name: "a" op: "Const" device: "/task:0" }
       node { name: "b" op: "Const" device: "/task:0" }
       node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
@@ -116,7 +116,8 @@
       with g.device("/task:2"):
         control_flow_ops.group(a.op, b.op, c.op, d.op, name="root")
     gd = g.as_graph_def()
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       node { name: "a" op: "Const" device: "/task:0"}
       node { name: "b" op: "Const" device: "/task:0"}
       node { name: "c" op: "Const" device: "/task:1"}
@@ -135,7 +136,8 @@
       b = constant_op.constant(0, name="b")
       control_flow_ops.group([a.op, b.op], name="root")
     gd = g.as_graph_def()
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       node { name: "a" op: "Const"}
       node { name: "b" op: "Const"}
       node { name: "root" op: "NoOp" input: "^a" input: "^b" }
@@ -165,8 +167,7 @@
         "my_counter", shape=[], initializer=init_ops.zeros_initializer())
     increment_counter = state_ops.assign_add(counter, 1)
     const_with_dep = control_flow_ops.with_dependencies(
-        (increment_counter, constant_op.constant(42)),
-        constant_op.constant(7))
+        (increment_counter, constant_op.constant(42)), constant_op.constant(7))
 
     self.evaluate(variables.global_variables_initializer())
     self.assertEqual(0, self.evaluate(counter))
@@ -179,8 +180,7 @@
         "my_counter", shape=[], initializer=init_ops.zeros_initializer())
     increment_counter = state_ops.assign_add(counter, 1)
     const_with_dep = control_flow_ops.with_dependencies(
-        [increment_counter, constant_op.constant(42)],
-        constant_op.constant(7))
+        [increment_counter, constant_op.constant(42)], constant_op.constant(7))
 
     self.evaluate(variables.global_variables_initializer())
     self.assertEqual(0, self.evaluate(counter))
@@ -364,18 +364,16 @@
     x = constant_op.constant(2)
     y = constant_op.constant(5)
     z = control_flow_ops.cond(
-        math_ops.less(
-            x,
-            y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
+        math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
+        lambda: math_ops.add(y, 23))
     self.assertEqual(self.evaluate(z), 34)
 
   def testCondFalse(self):
     x = constant_op.constant(2)
     y = constant_op.constant(1)
     z = control_flow_ops.cond(
-        math_ops.less(
-            x,
-            y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
+        math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
+        lambda: math_ops.add(y, 23))
     self.assertEqual(self.evaluate(z), 24)
 
   def testCondTrueLegacy(self):
@@ -508,16 +506,18 @@
 
   @test_util.run_deprecated_v1
   def testControlContextImportScope(self):
+
     class NoABCControlFlowContext(control_flow_ops.ControlFlowContext):
       """A noop wrapper around `ControlFlowContext`.
 
       `ControlFlowContext` is an ABC and therefore cannot be instantiated.
       """
+
       # pylint: disable=useless-super-delegation
 
       def to_control_flow_context_def(self, context_def, export_scope=None):
-        super(NoABCControlFlowContext, self).to_control_flow_context_def(
-            context_def, export_scope)
+        super(NoABCControlFlowContext,
+              self).to_control_flow_context_def(context_def, export_scope)
 
     with self.cached_session():
       constant_op.constant(0, name="a")
@@ -557,8 +557,8 @@
 
 
 def _create_tensor_array(size, shape):
-  ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size,
-                                    clear_after_read=False)
+  ta = tensor_array_ops.TensorArray(
+      dtype=dtypes.float32, size=size, clear_after_read=False)
   for i in range(size):
     ta = ta.write(i, array_ops.zeros(shape))
   return ta
@@ -585,30 +585,37 @@
     else:
       self.assertAllEqual(a, b)
 
-  def _testShape(self, fn_true, fn_false, expected_shape,
-                 strict=False):
+  def _testShape(self, fn_true, fn_false, expected_shape, strict=False):
     condition = array_ops.placeholder(dtypes.bool)
-    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
-                                        strict=strict)
+    output_cond = control_flow_ops.cond(
+        condition, fn_true, fn_false, strict=strict)
     self.assertEqual(
         _raw_nested_shape(_get_nested_shape(output_cond)),
         _raw_nested_shape(expected_shape))
 
-    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
+    output_case = control_flow_ops.case([(condition, fn_true)],
+                                        fn_false,
                                         strict=strict)
     self.assertEqual(
         _raw_nested_shape(_get_nested_shape(output_case)),
         _raw_nested_shape(expected_shape))
 
-  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
-                        expected_value_false, strict=False,
-                        check_cond=True, feed_dict=None):
-    if feed_dict is None: feed_dict = {}
+  def _testReturnValues(self,
+                        fn_true,
+                        fn_false,
+                        expected_value_true,
+                        expected_value_false,
+                        strict=False,
+                        check_cond=True,
+                        feed_dict=None):
+    if feed_dict is None:
+      feed_dict = {}
 
     condition = array_ops.placeholder(dtypes.bool)
-    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
-                                        strict=strict)
-    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
+    output_cond = control_flow_ops.cond(
+        condition, fn_true, fn_false, strict=strict)
+    output_case = control_flow_ops.case([(condition, fn_true)],
+                                        fn_false,
                                         strict=strict)
 
     with self.cached_session() as sess:
@@ -650,8 +657,12 @@
   def test_noop(self):
     shape = tensor_shape.TensorShape(None)
     self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
-    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
-                           True, False, check_cond=False)
+    self._testReturnValues(
+        control_flow_ops.no_op,
+        control_flow_ops.no_op,
+        True,
+        False,
+        check_cond=False)
 
   @test_util.run_deprecated_v1
   def test_string(self):
@@ -686,22 +697,24 @@
     def _build_true_branch(dtype):
 
       def _build():
-        return (array_ops.zeros([2, 2], dtype=dtype),
-                array_ops.ones([3, 3], dtype=dtype))
+        return (array_ops.zeros([2, 2],
+                                dtype=dtype), array_ops.ones([3, 3],
+                                                             dtype=dtype))
 
       return _build
 
     def _build_false_branch(dtype):
 
       def _build():
-        return (array_ops.ones([2, 2], dtype=dtype),
-                array_ops.zeros([3, 3], dtype=dtype))
+        return (array_ops.ones([2, 2],
+                               dtype=dtype), array_ops.zeros([3, 3],
+                                                             dtype=dtype))
 
       return _build
 
     for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
-      shape = (tensor_shape.TensorShape([2, 2]),
-               tensor_shape.TensorShape([3, 3]))
+      shape = (tensor_shape.TensorShape([2,
+                                         2]), tensor_shape.TensorShape([3, 3]))
       fn_true = _build_true_branch(dtype)
       fn_false = _build_false_branch(dtype)
       self._testShape(fn_true, fn_false, shape)
@@ -733,27 +746,36 @@
       fn_true, true_tensor = _build_true_branch(dtype)
       fn_false, false_tensor = _build_false_branch(dtype)
       self._testShape(fn_true, fn_false, shape)
-      self._testReturnValues(fn_true, fn_false,
-                             np.zeros([2, 2]), np.ones([2, 2]),
-                             feed_dict={true_tensor: np.zeros([2, 2]),
-                                        false_tensor: np.ones([2, 2])})
+      self._testReturnValues(
+          fn_true,
+          fn_false,
+          np.zeros([2, 2]),
+          np.ones([2, 2]),
+          feed_dict={
+              true_tensor: np.zeros([2, 2]),
+              false_tensor: np.ones([2, 2])
+          })
 
   @test_util.run_deprecated_v1
   def test_sparse_tensors(self):
     shape = tensor_shape.TensorShape([None, None])
 
     def true_fn():
-      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
-                                         values=[1, 2], dense_shape=[3, 4])]
+      return [
+          sparse_tensor.SparseTensor(
+              indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+      ]
 
     def false_fn():
-      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
-                                         values=[3, 4], dense_shape=[3, 4])]
+      return [
+          sparse_tensor.SparseTensor(
+              indices=[[0, 0], [2, 1]], values=[3, 4], dense_shape=[3, 4])
+      ]
 
-    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
-                                             values=[1, 2], dense_shape=[3, 4])
-    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
-                                             values=[3, 4], dense_shape=[3, 4])
+    value1 = sparse_tensor.SparseTensorValue(
+        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+    value2 = sparse_tensor.SparseTensorValue(
+        indices=[[0, 0], [2, 1]], values=[3, 4], dense_shape=[3, 4])
     # Non-strict cond is only available in v1
     if not tf2.enabled():
       self._testShape(true_fn, false_fn, shape)
@@ -775,21 +797,24 @@
       return _build, (a, b, c)
 
     for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
-      shape = (tensor_shape.TensorShape([None, 2]),
-               tensor_shape.TensorShape([None]),
+      shape = (tensor_shape.TensorShape([None,
+                                         2]), tensor_shape.TensorShape([None]),
                tensor_shape.TensorShape([3, None]))
       fn_true, true_tensors = _build_branch(dtype, shape)
       fn_false, false_tensors = _build_branch(dtype, shape)
       self._testShape(fn_true, fn_false, shape)
-      self._testReturnValues(fn_true, fn_false,
-                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
-                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
-                             feed_dict={true_tensors[0]: np.zeros([2, 2]),
-                                        false_tensors[0]: np.zeros([2, 2]),
-                                        true_tensors[1]: np.zeros([5]),
-                                        false_tensors[1]: np.zeros([5]),
-                                        true_tensors[2]: np.ones([3, 3]),
-                                        false_tensors[2]: np.ones([3, 3])})
+      self._testReturnValues(
+          fn_true,
+          fn_false, (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
+          (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
+          feed_dict={
+              true_tensors[0]: np.zeros([2, 2]),
+              false_tensors[0]: np.zeros([2, 2]),
+              true_tensors[1]: np.zeros([5]),
+              false_tensors[1]: np.zeros([5]),
+              true_tensors[2]: np.ones([3, 3]),
+              false_tensors[2]: np.ones([3, 3])
+          })
 
   @test_util.run_deprecated_v1
   def test_tensor_arrays(self):
@@ -811,8 +836,11 @@
 
   @test_util.run_v1_only("b/138741991")
   def test_list(self):
-    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
-             tensor_shape.TensorShape([])]
+    shape = [
+        tensor_shape.TensorShape([]),
+        tensor_shape.TensorShape([]),
+        tensor_shape.TensorShape([])
+    ]
     fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
     fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
     self._testShape(fn_true, fn_false, shape)
@@ -838,19 +866,21 @@
     fn_tuple = lambda: (constant_op.constant(3),)
 
     with self.assertRaises(ValueError):
-      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
-                            strict=True)
+      control_flow_ops.cond(
+          constant_op.constant(True), fn_tensor, fn_list, strict=True)
 
     with self.assertRaises(TypeError):
-      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
-                            strict=True)
+      control_flow_ops.cond(
+          constant_op.constant(True), fn_list, fn_tuple, strict=True)
 
     with self.assertRaises(ValueError):
-      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
+      control_flow_ops.case([(constant_op.constant(True), fn_tensor)],
+                            fn_list,
                             strict=True)
 
     with self.assertRaises(TypeError):
-      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
+      control_flow_ops.case([(constant_op.constant(True), fn_list)],
+                            fn_tuple,
                             strict=True)
 
   @test_util.run_deprecated_v1
@@ -875,8 +905,7 @@
       self._testShape(fn_true, fn_false, shape)
       self._testReturnValues(fn_true, fn_false, 1, 3)
     self._testShape(fn_true, fn_false, (shape,), strict=True)
-    self._testReturnValues(fn_true, fn_false, (1,), (3,),
-                           strict=True)
+    self._testReturnValues(fn_true, fn_false, (1,), (3,), strict=True)
 
   @test_util.run_deprecated_v1
   def test_singleton_namedtuple(self):
@@ -887,10 +916,13 @@
     if not tf2.enabled():
       self._testShape(fn_true, fn_false, shape)
       self._testReturnValues(fn_true, fn_false, 1, 3)
-    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
-                    strict=True)
-    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
-                           SingletonTestTuple(3), strict=True)
+    self._testShape(fn_true, fn_false, SingletonTestTuple(shape), strict=True)
+    self._testReturnValues(
+        fn_true,
+        fn_false,
+        SingletonTestTuple(1),
+        SingletonTestTuple(3),
+        strict=True)
 
   @test_util.run_deprecated_v1
   def test_tuple(self):
@@ -902,8 +934,8 @@
 
   @test_util.run_deprecated_v1
   def test_namedtuple(self):
-    shape = TestTuple(tensor_shape.TensorShape([]),
-                      tensor_shape.TensorShape([]))
+    shape = TestTuple(
+        tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
     fn_true = lambda: TestTuple(constant_op.constant(1), 2)
     fn_false = lambda: TestTuple(constant_op.constant(3), 4)
     self._testShape(fn_true, fn_false, shape)
@@ -911,22 +943,29 @@
 
   @test_util.run_deprecated_v1
   def test_nested(self):
-    shape = [tensor_shape.TensorShape([]),
-             TestTuple(tensor_shape.TensorShape([]),
-                       [tensor_shape.TensorShape([]),
-                        tensor_shape.TensorShape([])]),
-             tensor_shape.TensorShape([5, 5]),
-             tensor_shape.TensorShape([])]
+    shape = [
+        tensor_shape.TensorShape([]),
+        TestTuple(
+            tensor_shape.TensorShape([]),
+            [tensor_shape.TensorShape([]),
+             tensor_shape.TensorShape([])]),
+        tensor_shape.TensorShape([5, 5]),
+        tensor_shape.TensorShape([])
+    ]
 
     def true_fn():
-      return [constant_op.constant(1),
-              TestTuple(constant_op.constant(2), [3, 4]),
-              array_ops.zeros([5, 5]), 6]
+      return [
+          constant_op.constant(1),
+          TestTuple(constant_op.constant(2), [3, 4]),
+          array_ops.zeros([5, 5]), 6
+      ]
 
     def false_fn():
-      return [constant_op.constant(11),
-              TestTuple(constant_op.constant(12), [13, 14]),
-              array_ops.ones([5, 5]), 16]
+      return [
+          constant_op.constant(11),
+          TestTuple(constant_op.constant(12), [13, 14]),
+          array_ops.ones([5, 5]), 16
+      ]
 
     self._testShape(true_fn, false_fn, shape)
     self._testReturnValues(
@@ -940,10 +979,10 @@
 
     def body(i, matrix):
       result_tuple, unused_matrix = control_flow_ops.cond(
-          constant_op.constant(True),
-          lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
-          lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
-      return [i+1, result_tuple.a]
+          constant_op.constant(True), lambda:
+          (TestTuple(matrix * 2, matrix * 4), matrix), lambda:
+          (TestTuple(matrix * 4, matrix * 2), matrix))
+      return [i + 1, result_tuple.a]
 
     iteration, matrix = control_flow_ops.while_loop(
         lambda i, matrix: i < 10,
@@ -1113,9 +1152,6 @@
     """Verify disjoint branches across while iterations are run in parallel."""
     if control_flow_v2_toggles.control_flow_v2_enabled():
       self.skipTest("b/138870290")
-    if test.is_built_with_rocm():
-      self.skipTest(
-          "Disable subtest on ROCm due to missing Cholesky op support")
 
     with ops.Graph().as_default() as g:
       nbranches = 7
@@ -1124,16 +1160,20 @@
               random_ops.random_uniform([nbranches, 8, 512]) + 1e-3))
 
       def make_branch(i, mat, name):
+
         def branch_fn():
           next_i = i + 1
           with ops.device("gpu:0"):
             return next_i, math_ops.reduce_sum(
                 linalg_ops.cholesky(mat, name=name + "_Cholesky"))
+
         return branch_fn
 
       def make_branches(i):
-        return [make_branch(i, matrices[bi], "br{}".format(bi))
-                for bi in range(nbranches)]
+        return [
+            make_branch(i, matrices[bi], "br{}".format(bi))
+            for bi in range(nbranches)
+        ]
 
       def cond(i, _):
         return i < nbranches
@@ -1163,9 +1203,7 @@
     self.assertLen(chol_node_stats, nbranches)
 
     chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name)
-    op_start_nanos = [
-        stats.all_start_nanos for stats in chol_node_stats
-    ]
+    op_start_nanos = [stats.all_start_nanos for stats in chol_node_stats]
     op_end_nanos = [
         stats.all_start_nanos + stats.op_end_rel_nanos
         for stats in chol_node_stats
@@ -1494,20 +1532,26 @@
   @test_util.enable_control_flow_v2
   @test_util.run_in_graph_and_eager_modes
   def testSkipsUnnecessaryCaptureGradients(self):
+
     @custom_gradient.custom_gradient
     def gradient_trap(t):
+
       def grad(w):
         # Computing this gradient should fail the test
         check_ops.assert_equal(0, 1)
         return w
+
       return t, grad
 
     x = array_ops.constant(0.0, name="x")
     y = array_ops.constant(1.0, name="y")
+
     def cond(s):
       return s < 10.0
+
     def body(s):
-      return s + 2*x + gradient_trap(y)
+      return s + 2 * x + gradient_trap(y)
+
     with backprop.GradientTape() as tape:
       tape.watch(x)
       out = control_flow_ops.while_loop(cond, body, (array_ops.constant(0.0),))
@@ -1548,5 +1592,6 @@
 
     self.assertAllEqual(whiny(True), 5)
 
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/ops/handle_data_util.py b/tensorflow/python/ops/handle_data_util.py
index d83bea3..4f17cf4 100644
--- a/tensorflow/python/ops/handle_data_util.py
+++ b/tensorflow/python/ops/handle_data_util.py
@@ -19,20 +19,11 @@
 from __future__ import print_function
 
 from tensorflow.python.client import pywrap_tf_session
-from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.util import compat
 
 
-def get_resource_handle_data(graph_op):
-  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck
-
-  handle_data = pywrap_tf_session.GetHandleShapeAndType(
-      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
-
-  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
-      compat.as_bytes(handle_data))
+get_resource_handle_data = ops.get_resource_handle_data
 
 
 def copy_handle_data(source_t, target_t):
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index bbd5831..70dcc80 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1837,7 +1837,7 @@
       dimensions of each image.
 
   Returns:
-    A `Tensor` with the same shape and dtype as `image`.
+    A `Tensor` with the same shape as `image` and dtype `float32`.
 
   Raises:
     ValueError: if the shape of 'image' is incompatible with this function.
@@ -1846,22 +1846,18 @@
     image = ops.convert_to_tensor(image, name='image')
     image = _AssertAtLeast3DImage(image)
 
-    # Remember original dtype to so we can convert back if needed
-    orig_dtype = image.dtype
-    if orig_dtype not in [dtypes.float16, dtypes.float32]:
-      image = convert_image_dtype(image, dtypes.float32)
-
+    image = math_ops.cast(image, dtype=dtypes.float32)
     num_pixels = math_ops.reduce_prod(array_ops.shape(image)[-3:])
     image_mean = math_ops.reduce_mean(image, axis=[-1, -2, -3], keepdims=True)
 
     # Apply a minimum normalization that protects us against uniform images.
     stddev = math_ops.reduce_std(image, axis=[-1, -2, -3], keepdims=True)
-    min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, image.dtype))
+    min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, dtypes.float32))
     adjusted_stddev = math_ops.maximum(stddev, min_stddev)
 
     image -= image_mean
     image = math_ops.divide(image, adjusted_stddev, name=scope)
-    return convert_image_dtype(image, orig_dtype, saturate=True)
+    return image
 
 
 @tf_export('image.random_brightness')
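The hunk above changes `per_image_standardization` to always compute and return `float32` instead of converting back to the input dtype. As a point of reference, a minimal NumPy sketch of the same arithmetic (the function name and NumPy usage are illustrative only, not part of the patch):

```python
import numpy as np

def per_image_standardization_sketch(image):
  """Rough NumPy equivalent of the standardization math, assuming an HWC image."""
  image = image.astype(np.float32)
  num_pixels = np.prod(image.shape[-3:])
  mean = image.mean(axis=(-1, -2, -3), keepdims=True)
  stddev = image.std(axis=(-1, -2, -3), keepdims=True)
  # Floor the stddev at 1/sqrt(N) to protect against uniform images.
  adjusted_stddev = np.maximum(stddev, 1.0 / np.sqrt(num_pixels))
  return (image - mean) / adjusted_stddev
```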
@@ -4252,9 +4248,14 @@
   Example:
 
   ```python
-      # Read images from file.
-      im1 = tf.decode_png('path/to/im1.png')
-      im2 = tf.decode_png('path/to/im2.png')
+      # Read images (of size 255 x 255) from file.
+      im1 = tf.image.decode_image(tf.io.read_file('path/to/im1.png'))
+      im2 = tf.image.decode_image(tf.io.read_file('path/to/im2.png'))
+      tf.shape(im1)  # `im1.png` has 3 channels; shape is `(255, 255, 3)`
+      tf.shape(im2)  # `im2.png` has 3 channels; shape is `(255, 255, 3)`
+      # Add an outer batch for each image.
+      im1 = tf.expand_dims(im1, axis=0)
+      im2 = tf.expand_dims(im2, axis=0)
       # Compute SSIM over tf.uint8 Tensors.
       ssim1 = tf.image.ssim(im1, im2, max_val=255, filter_size=11,
                             filter_sigma=1.5, k1=0.01, k2=0.03)
@@ -4268,8 +4269,10 @@
   ```
 
   Args:
-    img1: First image batch.
-    img2: Second image batch.
+    img1: First image batch. 4-D Tensor of shape `[batch, height, width,
+      channels]`.
+    img2: Second image batch. 4-D Tensor of shape `[batch, height, width,
+      channels]`.
     max_val: The dynamic range of the images (i.e., the difference between the
      maximum and minimum allowed values).
     filter_size: Default value 11 (size of gaussian filter).
@@ -4455,7 +4458,7 @@
     image = tf.reshape(tf.range(IMAGE_HEIGHT * IMAGE_WIDTH * CHANNELS,
       delta=1, dtype=tf.float32),
       shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS))
-    dx, dy = tf.image.image_gradients(image)
+    dy, dx = tf.image.image_gradients(image)
     print(image[0, :,:,0])
     tf.Tensor(
       [[ 0.  1.  2.  3.  4.]
@@ -4463,14 +4466,14 @@
       [10. 11. 12. 13. 14.]
       [15. 16. 17. 18. 19.]
       [20. 21. 22. 23. 24.]], shape=(5, 5), dtype=float32)
-    print(dx[0, :,:,0])
+    print(dy[0, :,:,0])
     tf.Tensor(
       [[5. 5. 5. 5. 5.]
       [5. 5. 5. 5. 5.]
       [5. 5. 5. 5. 5.]
       [5. 5. 5. 5. 5.]
       [0. 0. 0. 0. 0.]], shape=(5, 5), dtype=float32)
-    print(dy[0, :,:,0])
+    print(dx[0, :,:,0])
     tf.Tensor(
       [[1. 1. 1. 1. 0.]
       [1. 1. 1. 1. 0.]
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 980f892..0e87194 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -1643,7 +1643,8 @@
     self._testBrightness(x_np, y_np, delta=-10. / 255.)
 
 
-class PerImageWhiteningTest(test_util.TensorFlowTestCase):
+class PerImageWhiteningTest(test_util.TensorFlowTestCase,
+                            parameterized.TestCase):
 
   def _NumpyPerImageWhitening(self, x):
     num_pixels = np.prod(x.shape)
@@ -1656,13 +1657,19 @@
     y /= stddev
     return y
 
-  def testBasic(self):
+  @parameterized.named_parameters([("_int8", np.int8), ("_int16", np.int16),
+                                   ("_int32", np.int32), ("_int64", np.int64),
+                                   ("_uint8", np.uint8), ("_uint16", np.uint16),
+                                   ("_uint32", np.uint32),
+                                   ("_uint64", np.uint64),
+                                   ("_float32", np.float32)])
+  def testBasic(self, data_type):
     x_shape = [13, 9, 3]
-    x_np = np.arange(0, np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+    x_np = np.arange(0, np.prod(x_shape), dtype=data_type).reshape(x_shape)
     y_np = self._NumpyPerImageWhitening(x_np)
 
     with self.cached_session(use_gpu=True):
-      x = constant_op.constant(x_np, shape=x_shape)
+      x = constant_op.constant(x_np, dtype=data_type, shape=x_shape)
       y = image_ops.per_image_standardization(x)
       y_tf = self.evaluate(y)
       self.assertAllClose(y_tf, y_np, atol=1e-4)
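The test above is now parameterized over integer and floating-point input dtypes via `absl.testing.parameterized`. A small self-contained sketch of that pattern (the class name, chosen dtypes, and float32-output assertion are illustrative assumptions, not part of the patch):

```python
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import tensorflow as tf


class StandardizationSketchTest(parameterized.TestCase):

  @parameterized.named_parameters([("_uint8", np.uint8), ("_float32", np.float32)])
  def test_output_is_float32(self, data_type):
    x = tf.constant(np.arange(12, dtype=data_type).reshape(2, 2, 3))
    y = tf.image.per_image_standardization(x)
    self.assertEqual(y.dtype, tf.float32)


if __name__ == "__main__":
  absltest.main()
```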
@@ -1685,17 +1692,6 @@
       for w_tf, w_np in zip(whiten_tf, whiten_np):
         self.assertAllClose(w_tf, w_np, atol=1e-4)
 
-  def testPreservesDtype(self):
-    imgs_npu8 = np.random.uniform(0., 255., [2, 5, 5, 3]).astype(np.uint8)
-    imgs_tfu8 = constant_op.constant(imgs_npu8)
-    whiten_tfu8 = image_ops.per_image_standardization(imgs_tfu8)
-    self.assertEqual(whiten_tfu8.dtype, dtypes.uint8)
-
-    imgs_npf16 = np.random.uniform(0., 255., [2, 5, 5, 3]).astype(np.float16)
-    imgs_tff16 = constant_op.constant(imgs_npf16)
-    whiten_tff16 = image_ops.per_image_standardization(imgs_tff16)
-    self.assertEqual(whiten_tff16.dtype, dtypes.float16)
-
 
 class CropToBoundingBoxTest(test_util.TensorFlowTestCase):
 
@@ -5014,7 +5010,7 @@
     max_output_size_np = 6
     iou_threshold_np = 0.5
     score_threshold_np = 0.0
-    soft_nms_sigma_np = 1.0
+    soft_nms_sigma_np = 0.5
     boxes = constant_op.constant(boxes_np)
     scores = constant_op.constant(scores_np)
     max_output_size = constant_op.constant(max_output_size_np)
diff --git a/tensorflow/python/ops/init_ops_v2_test.py b/tensorflow/python/ops/init_ops_v2_test.py
index d524f1e..2de636c 100644
--- a/tensorflow/python/ops/init_ops_v2_test.py
+++ b/tensorflow/python/ops/init_ops_v2_test.py
@@ -47,10 +47,7 @@
     self.assertEqual(tensor_shape.as_shape(shape), t2.shape)
     self.assertEqual(assertion, np.allclose(t1, t2, rtol=1e-15, atol=1e-15))
 
-  def _duplicated_test(self,
-                       init,
-                       shape=None,
-                       dtype=dtypes.float32):
+  def _duplicated_test(self, init, shape=None, dtype=dtypes.float32):
     if shape is None:
       shape = [100]
     t1 = self.evaluate(init(shape, dtype))
@@ -98,8 +95,8 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testZeros(self):
-    self._range_test(init_ops_v2.Zeros(), shape=(4, 5),
-                     target_mean=0., target_max=0.)
+    self._range_test(
+        init_ops_v2.Zeros(), shape=(4, 5), target_mean=0., target_max=0.)
 
   @test_util.run_in_graph_and_eager_modes
   def testZerosPartition(self):
@@ -115,8 +112,8 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testOnes(self):
-    self._range_test(init_ops_v2.Ones(), shape=(4, 5),
-                     target_mean=1., target_max=1.)
+    self._range_test(
+        init_ops_v2.Ones(), shape=(4, 5), target_mean=1., target_max=1.)
 
   @test_util.run_in_graph_and_eager_modes
   def testOnesPartition(self):
@@ -176,15 +173,13 @@
 
     self._testNDimConstantInitializer(value, shape, expected)
     self._testNDimConstantInitializer(np.asarray(value), shape, expected)
-    self._testNDimConstantInitializer(np.asarray(value).reshape(tuple(shape)),
-                                      shape, expected)
+    self._testNDimConstantInitializer(
+        np.asarray(value).reshape(tuple(shape)), shape, expected)
 
   def _testNDimConstantInitializerIncorrectNumberValues(self, value, shape):
     with test_util.use_gpu():
       init = init_ops_v2.constant_initializer(value)
-      self.assertRaises(TypeError,
-                        init,
-                        shape=shape)
+      self.assertRaises(TypeError, init, shape=shape)
 
   @test_util.run_in_graph_and_eager_modes
   def testNDimConstantInitializerIncorrectNumberValues(self):
@@ -192,8 +187,8 @@
 
     for shape in [[2, 4], [2, 2]]:
       self._testNDimConstantInitializerIncorrectNumberValues(value, shape)
-      self._testNDimConstantInitializerIncorrectNumberValues(np.asarray(value),
-                                                             shape)
+      self._testNDimConstantInitializerIncorrectNumberValues(
+          np.asarray(value), shape)
       self._testNDimConstantInitializerIncorrectNumberValues(
           np.asarray(value).reshape(tuple([2, 3])), shape)
 
@@ -351,8 +346,7 @@
     shape = [100, 100]
     expect_mean = 0.
     expect_var = 1. / shape[0]
-    init = init_ops_v2.VarianceScaling(
-        distribution="untruncated_normal")
+    init = init_ops_v2.VarianceScaling(distribution="untruncated_normal")
 
     with test_util.use_gpu(), test.mock.patch.object(
         random_ops, "random_normal",
@@ -399,8 +393,8 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testRangeInitializer(self):
-    self._range_test(init_ops_v2.Orthogonal(seed=123), shape=(20, 20),
-                     target_mean=0.)
+    self._range_test(
+        init_ops_v2.Orthogonal(seed=123), shape=(20, 20), target_mean=0.)
 
   @test_util.run_in_graph_and_eager_modes
   def testInitializerIdentical(self):
@@ -443,10 +437,6 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testShapesValues(self):
-
-    if test.is_built_with_rocm():
-      self.skipTest("Disable subtest on ROCm due to missing QR op support")
-
     for shape in [(10, 10), (10, 9, 8), (100, 5, 5), (50, 40), (40, 50)]:
       init = init_ops_v2.Orthogonal()
       tol = 1e-5
@@ -518,11 +508,12 @@
       init_default = init_ops_v2.Identity()
       init_custom = init_ops_v2.Identity(gain=0.9)
       with test_util.use_gpu():
-        self.assertAllClose(self.evaluate(init_default(shape, dtype=dtype)),
-                            np.eye(*shape))
+        self.assertAllClose(
+            self.evaluate(init_default(shape, dtype=dtype)), np.eye(*shape))
       with test_util.use_gpu():
-        self.assertAllClose(self.evaluate(init_custom(shape, dtype=dtype)),
-                            np.eye(*shape) * 0.9)
+        self.assertAllClose(
+            self.evaluate(init_custom(shape, dtype=dtype)),
+            np.eye(*shape) * 0.9)
 
   @test_util.run_in_graph_and_eager_modes
   def testPartition(self):
@@ -577,10 +568,7 @@
     fan_in, _ = init_ops_v2._compute_fans(shape)
     std = np.sqrt(2. / fan_in)
     self._range_test(
-        init_ops_v2.he_uniform(seed=123),
-        shape,
-        target_mean=0.,
-        target_std=std)
+        init_ops_v2.he_uniform(seed=123), shape, target_mean=0., target_std=std)
 
   @test_util.run_in_graph_and_eager_modes
   def testLecunNormal(self):
@@ -599,10 +587,7 @@
     fan_in, _ = init_ops_v2._compute_fans(shape)
     std = np.sqrt(2. / fan_in)
     self._range_test(
-        init_ops_v2.he_normal(seed=123),
-        shape,
-        target_mean=0.,
-        target_std=std)
+        init_ops_v2.he_normal(seed=123), shape, target_mean=0., target_std=std)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/ops/linalg/linear_operator_adjoint.py b/tensorflow/python/ops/linalg/linear_operator_adjoint.py
index 1af0ce9..0fe8e29 100644
--- a/tensorflow/python/ops/linalg/linear_operator_adjoint.py
+++ b/tensorflow/python/ops/linalg/linear_operator_adjoint.py
@@ -153,7 +153,6 @@
     with ops.name_scope(name, values=operator.graph_parents):
       super(LinearOperatorAdjoint, self).__init__(
           dtype=operator.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_block_diag.py b/tensorflow/python/ops/linalg/linear_operator_block_diag.py
index 514b023..6367e00 100644
--- a/tensorflow/python/ops/linalg/linear_operator_block_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_block_diag.py
@@ -228,7 +228,6 @@
     with ops.name_scope(name, values=graph_parents):
       super(LinearOperatorBlockDiag, self).__init__(
           dtype=dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index 31dd5b2..d3d5c50 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -119,7 +119,6 @@
 
       super(_BaseLinearOperatorCirculant, self).__init__(
           dtype=dtypes.as_dtype(input_output_dtype),
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_composition.py b/tensorflow/python/ops/linalg/linear_operator_composition.py
index ace7e85..bfe3479 100644
--- a/tensorflow/python/ops/linalg/linear_operator_composition.py
+++ b/tensorflow/python/ops/linalg/linear_operator_composition.py
@@ -185,7 +185,6 @@
     with ops.name_scope(name, values=graph_parents):
       super(LinearOperatorComposition, self).__init__(
           dtype=dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index 3f298bc..f5b26ba 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -166,7 +166,6 @@
 
       super(LinearOperatorDiag, self).__init__(
           dtype=self._diag.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
index a616a8c..4319d01 100644
--- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
+++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
@@ -149,7 +149,6 @@
 
       super(LinearOperatorFullMatrix, self).__init__(
           dtype=self._matrix.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_householder.py b/tensorflow/python/ops/linalg/linear_operator_householder.py
index cbb7a88..e9a1af0 100644
--- a/tensorflow/python/ops/linalg/linear_operator_householder.py
+++ b/tensorflow/python/ops/linalg/linear_operator_householder.py
@@ -155,7 +155,6 @@
 
       super(LinearOperatorHouseholder, self).__init__(
           dtype=self._reflection_axis.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_inversion.py b/tensorflow/python/ops/linalg/linear_operator_inversion.py
index b2784c4..7d7ae63 100644
--- a/tensorflow/python/ops/linalg/linear_operator_inversion.py
+++ b/tensorflow/python/ops/linalg/linear_operator_inversion.py
@@ -166,7 +166,6 @@
     with ops.name_scope(name, values=operator.graph_parents):
       super(LinearOperatorInversion, self).__init__(
           dtype=operator.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_kronecker.py b/tensorflow/python/ops/linalg/linear_operator_kronecker.py
index b351bc5..d4a3482 100644
--- a/tensorflow/python/ops/linalg/linear_operator_kronecker.py
+++ b/tensorflow/python/ops/linalg/linear_operator_kronecker.py
@@ -230,7 +230,6 @@
     with ops.name_scope(name, values=graph_parents):
       super(LinearOperatorKronecker, self).__init__(
           dtype=dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index 2f12c71..4157233 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -260,7 +260,6 @@
 
       super(LinearOperatorLowRankUpdate, self).__init__(
           dtype=self._base_operator.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
index fbc1f53..7a6ac9d 100644
--- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
@@ -158,7 +158,6 @@
 
       super(LinearOperatorLowerTriangular, self).__init__(
           dtype=self._tril.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
index 95546c2..a68a94e 100644
--- a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
+++ b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
@@ -159,7 +159,6 @@
 
       super(LinearOperatorToeplitz, self).__init__(
           dtype=self._row.dtype,
-          graph_parents=None,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 5ec95b6..13fdf37 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -646,6 +646,22 @@
   For `x` with more dimensions, independently normalizes each 1-D slice along
   dimension `axis`.
 
+  * 1-D tensor example:
+  >>> x = tf.constant([3.0, 4.0])
+  >>> tf.math.l2_normalize(x).numpy()
+  array([0.6, 0.8], dtype=float32)
+
+  * 2-D tensor example:
+  >>> x = tf.constant([[3.0], [4.0]])
+  >>> tf.math.l2_normalize(x, 0).numpy()
+  array([[0.6],
+       [0.8]], dtype=float32)
+
+  >>> x = tf.constant([[3.0], [4.0]])
+  >>> tf.math.l2_normalize(x, 1).numpy()
+  array([[1.],
+       [1.]], dtype=float32)
+
   Args:
     x: A `Tensor`.
     axis: Dimension along which to normalize.  A scalar or a vector of
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 851bfcb..aaf2f77 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -27,6 +27,7 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
@@ -1260,6 +1261,7 @@
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, y_val_expected)
 
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
   def testArbitraryASCII(self):
     x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
     y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
@@ -1269,6 +1271,46 @@
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, y_val_expected)
 
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testInvalidLength(self):
+    x = [-4, -3, -2, -1, 0, 1, 2, 3]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Source format must be of length 4 or 5"):
+      op = nn_ops.data_format_dim_map(
+          x, src_format="12345678", dst_format="87654321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testDuplicateSrc(self):
+    x = [-4, -3, -2, -1, 0, 1, 2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Destination and source format must determine a permutation"):
+      op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testDuplicateDst(self):
+    x = [-4, -3, -2, -1, 0, 1, 2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Destination and source format must determine a permutation"):
+      op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testExtraSpecifiers(self):
+    x = [-4, -3, -2, -1, 0, 1, 2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Destination and source format must determine a permutation"):
+      op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
 
 class DataFormatVectorPermuteTest(test_lib.TestCase):
 
@@ -1370,6 +1412,60 @@
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])
 
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testInvalidLength(self):
+    x = [0, 1, 2, 3]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Source format must be of length 4 or 5"):
+      op = nn_ops.data_format_vec_permute(
+          x, src_format="12345678", dst_format="87654321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testDuplicateSrc(self):
+    x = [0, 1, 2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Destination and source format must determine a permutation"):
+      op = nn_ops.data_format_vec_permute(
+          x, src_format="1233", dst_format="4321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testDuplicateDst(self):
+    x = [0, 1, 2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Destination and source format must determine a permutation"):
+      op = nn_ops.data_format_vec_permute(
+          x, src_format="1234", dst_format="3321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def testExtraSpecifiers(self):
+    x = [0, 1, 2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Destination and source format must determine a permutation"):
+      op = nn_ops.data_format_vec_permute(
+          x, src_format="1234", dst_format="5321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
+  @test_util.disable_xla("XLA catches the error and rethrows as different one")
+  def test2DNoWH(self):
+    x = [[0, 1], [2, 3]]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Format specifier must contain H and W for 2D case"):
+      op = nn_ops.data_format_vec_permute(
+          x, src_format="1234", dst_format="4321")
+      with test_util.use_gpu():
+        self.evaluate(op)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class AvgPoolTest(test_lib.TestCase):
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index 3ab9963..169eb17 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -218,7 +218,7 @@
 
 
 # pylint: disable=protected-access
-def _composite_to_tensors(value):
+def _composite_to_tensors(value, is_batched=False):
   """Converts a CompositeTensor into a list of stackable tensors."""
   if _should_expand_composite(value):
     spec = value._type_spec
@@ -227,6 +227,8 @@
                        "parallel_for or vectorized_map loop body must provide "
                        "a `BatchableTypeSpec` (saw: {}).".format(
                            value, spec))
+    if is_batched:
+      return spec._to_batched_tensor_list(value)
     return spec._to_tensor_list(value)
   return value
 # pylint: enable=protected-access
@@ -421,14 +423,26 @@
   return result
 
 
+# pylint: disable=protected-access
+def _gather_from_tensor_or_composite(x, i):
+  """Wrapper for gather that handles CompositeTensors."""
+  if _should_expand_composite(x):
+    spec = x._type_spec
+    gathered_tensors = [_broadcasting_gather(t, i)
+                        for t in spec._to_batched_tensor_list(x)]
+    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
+  return _broadcasting_gather(x, i)
+# pylint: enable=protected-access
+
+
 @tf_export("vectorized_map")
 def vectorized_map(fn, elems, fallback_to_while_loop=True):
   """Parallel map on the list of tensors unpacked from `elems` on dimension 0.
 
   This method works similarly to `tf.map_fn` but is optimized to run much faster,
   possibly with a much larger memory footprint. The speedups are obtained by
-  vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians, 
-  Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea 
+  vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
+  Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
   behind vectorization is to semantically launch all the invocations of `fn` in
   parallel and fuse corresponding operations across all these invocations. This
   fusion is done statically at graph generation time and the generated code is
@@ -518,19 +532,21 @@
   Raises:
     ValueError: If vectorization fails and fallback_to_while_loop is False.
   """
-  def _convert_to_tensor_or_ndarray(x):
-    if isinstance(x, np_arrays.ndarray):
-      return x
-    return ops.convert_to_tensor(x)
-  elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
+  elems = nest.map_structure(ops.convert_to_tensor,
+                             elems,
+                             expand_composites=True)
 
   def loop_fn(i):
-    gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
-                                        elems)
+    gathered_elems = nest.map_structure(
+        lambda x: _gather_from_tensor_or_composite(x, i), elems)
     return fn(gathered_elems)
 
   # Extract batch size from the maximum first dimension of any element.
-  flat_elems = nest.flatten(elems)
+  flat_elems = nest.flatten(
+      nest.map_structure(
+          functools.partial(_composite_to_tensors,
+                            is_batched=True),
+          elems))
   def _get_shape(x):
     if isinstance(x, np_arrays.ndarray):
       x = x.data
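The changes above let `vectorized_map` unstack `CompositeTensor` inputs through their batchable type specs before gathering per-iteration slices. A hedged usage sketch with the public API (the `Particle` namedtuple and its values are illustrative; any nest of tensors sharing a batch dimension behaves the same way):

```python
import collections
import tensorflow as tf

Particle = collections.namedtuple("Particle", ["mass", "velocity"])

particles = Particle(mass=tf.constant([1., 2., 3.]),
                     velocity=tf.constant([4., 5., 6.]))
# Each invocation of the lambda sees one "row" of the structure.
momentum = tf.vectorized_map(lambda p: p.mass * p.velocity, particles)
# momentum is approximately [4., 10., 18.]
```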
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index f10d07f..f27f952 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -70,6 +70,7 @@
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
 from tensorflow.python.ops.parallel_for.test_util import PForTestCase
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.signal import fft_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -2157,6 +2158,27 @@
     self.assertTrue(particles.mass.shape, [4, 1, 3])
     self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
 
+  def test_vectorized_map_gathers_composite_tensors(self):
+    particles = Particle(mass=[1., 2., 3., 4., 5.],
+                         velocity=[1., 2., 3., 4., 5.])
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.mass * x.velocity, particles),
+        particles.mass * particles.velocity)
+
+  def test_vectorized_map_of_ragged_tensors(self):
+    # Vmap should be able to handle ragged Tensors as long as they're not
+    # *actually* ragged.
+    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
+        ragged_tensor.RaggedTensor.from_row_lengths(
+            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            row_lengths=[3, 3, 3, 3]),
+        uniform_row_length=2)  # Overall shape [2, 2, 3].
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.to_tensor(shape=[2, 3]), ragged),
+        ragged.to_tensor(shape=[2, 2, 3]))
+
 
 class ParsingTest(PForTestCase):
 
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 2934491..d726c3b 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -1104,7 +1104,6 @@
     srcs = ["ragged_map_fn_op_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    tags = ["no_rocm"],
     deps = [
         ":ragged",  # fixdeps: keep
         ":ragged_factory_ops",
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 843c747..2527252 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -168,9 +168,17 @@
                      name=None):
   """Outputs random values from a truncated normal distribution.
 
-  The generated values follow a normal distribution with specified mean and
-  standard deviation, except that values whose magnitude is more than 2 standard
-  deviations from the mean are dropped and re-picked.
+  The values are drawn from a normal distribution with specified mean and
+  standard deviation, discarding and re-drawing any samples that are more than
+  two standard deviations from the mean.
+
+  Examples:
+
+  >>> tf.random.truncated_normal(shape=[2])
+  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([..., ...], dtype=float32)>
+
+  >>> tf.random.truncated_normal(shape=[2], mean=3, stddev=1, dtype=tf.float32)
+  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([..., ...], dtype=float32)>
 
   Args:
     shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
@@ -178,11 +186,10 @@
       truncated normal distribution.
     stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
       of the normal distribution, before truncation.
-    dtype: The type of the output.
+    dtype: The type of the output. Restricted to floating-point types:
+      `tf.half`, `tf.float32`, `tf.float64`, etc.
     seed: A Python integer. Used to create a random seed for the distribution.
-      See
-      `tf.random.set_seed`
-      for behavior.
+      See `tf.random.set_seed` for more information.
     name: A name for the operation (optional).
 
   Returns:
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 6cda36d..f58dd5c 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -27,7 +27,6 @@
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.client import pywrap_tf_session
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -53,6 +52,7 @@
 # pylint: enable=wildcard-import
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.types import core
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated
 
diff --git a/tensorflow/python/ops/risc/risc_grad.py b/tensorflow/python/ops/risc/risc_grad.py
index 035bd9b..6ecd81d 100644
--- a/tensorflow/python/ops/risc/risc_grad.py
+++ b/tensorflow/python/ops/risc/risc_grad.py
@@ -28,6 +28,27 @@
   return None, None
 
 
+@ops.RegisterGradient("RiscBinaryArithmetic")
+def _RiscBinaryArithmeticGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscBinaryComparison")
+def _RiscBinaryComparisonGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscBitcast")
+def _RiscBitcastGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
 @ops.RegisterGradient("RiscBroadcast")
 def _RiscBroadcastGrad(_, grad):
   # pylint: disable=unused-argument
@@ -35,6 +56,20 @@
   return None, None
 
 
+@ops.RegisterGradient("RiscCast")
+def _RiscCastGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscCholesky")
+def _RiscCholeskyGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
 @ops.RegisterGradient("RiscConcat")
 def _RiscConcatGrad(_, grad):
   # pylint: disable=unused-argument
@@ -42,6 +77,13 @@
   return None, None
 
 
+@ops.RegisterGradient("RiscCondition")
+def _RiscConditionGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
 @ops.RegisterGradient("RiscConv")
 def _RiscConvGrad(_, grad):
   # pylint: disable=unused-argument
@@ -56,6 +98,48 @@
   return None, None
 
 
+@ops.RegisterGradient("RiscFft")
+def _RiscFftGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscGather")
+def _RiscGatherGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscIsFinite")
+def _RiscIsFiniteGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscLogicalAnd")
+def _RiscLogicalAndGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscLogicalNot")
+def _RiscLogicalNotGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscLogicalOr")
+def _RiscLogicalOrGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
 @ops.RegisterGradient("RiscMax")
 def _RiscMaxGrad(_, grad):
   # pylint: disable=unused-argument
@@ -77,6 +161,20 @@
   return None, None
 
 
+@ops.RegisterGradient("RiscRandomUniform")
+def _RiscRandomUniformGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscReduce")
+def _RiscReduceGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
 @ops.RegisterGradient("RiscReshape")
 def _RiscReshapeGrad(_, grad):
   # pylint: disable=unused-argument
@@ -84,6 +182,20 @@
   return None, None
 
 
+@ops.RegisterGradient("RiscReverse")
+def _RiscReverseGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscScatter")
+def _RiscScatterGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
 @ops.RegisterGradient("RiscShape")
 def _RiscShapeGrad(_, grad):
   # pylint: disable=unused-argument
@@ -96,3 +208,45 @@
   # pylint: disable=unused-argument
   # TODO(b/171294012): Implement gradient of RISC with RISC ops.
   return None, None
+
+
+@ops.RegisterGradient("RiscSort")
+def _RiscSortGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscSqueeze")
+def _RiscSqueezeGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscTranspose")
+def _RiscTransposeGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscTriangularSolve")
+def _RiscTriangularSolvesGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscUnary")
+def _RiscUnaryGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
+
+
+@ops.RegisterGradient("RiscWhile")
+def _RiscWhileGrad(_, grad):
+  # pylint: disable=unused-argument
+  # TODO(b/171294012): Implement gradient of RISC with RISC ops.
+  return None, None
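Every registration above is a placeholder that returns `None` gradients until b/171294012 is resolved. For orientation, a completed registration would follow the same decorator pattern; a hedged sketch for a hypothetical elementwise op (`RiscSquare` is illustrative, not a real RISC op):

```python
from tensorflow.python.framework import ops  # same import the module already uses

@ops.RegisterGradient("RiscSquare")
def _RiscSquareGrad(op, grad):
  x = op.inputs[0]
  # d/dx (x ** 2) = 2 * x, chained with the incoming gradient.
  return grad * 2.0 * x
```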
diff --git a/tensorflow/python/ops/risc/risc_ops.py b/tensorflow/python/ops/risc/risc_ops.py
index 14cdb9e..506a481 100644
--- a/tensorflow/python/ops/risc/risc_ops.py
+++ b/tensorflow/python/ops/risc/risc_ops.py
@@ -21,11 +21,6 @@
 
 from tensorflow.python.ops import gen_risc_ops
 
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.risc_ops_gen import *
-# pylint: enable=wildcard-import
-
 
 def risc_add(
     input_lhs,
@@ -34,14 +29,49 @@
   return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)
 
 
+def risc_binary_arithmetic(x, y, op_type, name='RISC_BinaryArithmetic'):
+  return gen_risc_ops.risc_binary_arithmetic(x, y, op_type=op_type, name=name)
+
+
+def risc_binary_comparison(x, y, op_type, name='RISC_BinaryComparison'):
+  return gen_risc_ops.risc_binary_comparison(x, y, op_type=op_type, name=name)
+
+
+def risc_bitcast(x, dtype, name='RISC_BITCAST'):
+  return gen_risc_ops.risc_bitcast(x, dtype, name=name)
+
+
 def risc_broadcast(x, shape, name='RISC_BROADCAST'):
   return gen_risc_ops.risc_broadcast(x, shape, name=name)
 
 
+def risc_cast(x, dtype, name='RISC_CAST'):
+  return gen_risc_ops.risc_cast(x, dtype, name=name)
+
+
+def risc_cholesky(x, name='RISC_CHOLESKY'):
+  return gen_risc_ops.risc_cholesky(x, name=name)
+
+
 def risc_concat(x, axis, name='RISC_CONCAT'):
   return gen_risc_ops.risc_concat(x, axis, name=name)
 
 
+def risc_condition(pred,
+                   input_true,
+                   input_false,
+                   func_true,
+                   func_false,
+                   name='RISC_CONDITION'):
+  return gen_risc_ops.risc_condition(
+      pred,
+      input_true,
+      input_false,
+      func_true=func_true,
+      func_false=func_false,
+      name=name)
+
+
 def risc_conv(x,
               kernel,
               strides,
@@ -70,6 +100,41 @@
       name=name)
 
 
+def risc_fft(x, name='RISC_FFT'):
+  return gen_risc_ops.risc_fft(x, name=name)
+
+
+def risc_gather(params,
+                indices,
+                validate_indices=None,
+                axis=None,
+                batch_dims=0,
+                name='RISC_GATHER'):
+  return gen_risc_ops.risc_gather(
+      params,
+      indices,
+      validate_indices=validate_indices,
+      name=name,
+      axis=axis,
+      batch_dims=batch_dims)
+
+
+def risc_is_finite(x, name='RISC_IS_FINITE'):
+  return gen_risc_ops.risc_is_finite(x, name=name)
+
+
+def risc_logical_and(a, b, name='RISC_LOGICAL_AND'):
+  return gen_risc_ops.risc_logical_and(a, b, name=name)
+
+
+def risc_logical_not(a, b, name='RISC_LOGICAL_NOT'):
+  return gen_risc_ops.risc_logical_not(a, b, name=name)
+
+
+def risc_logical_or(a, b, name='RISC_LOGICAL_OR'):
+  return gen_risc_ops.risc_logical_or(a, b, name=name)
+
+
 def risc_max(input_lhs, input_rhs, name='RISC_MAX'):
   return gen_risc_ops.risc_max(input_lhs, input_rhs, name=name)
 
@@ -83,13 +148,76 @@
       x, ksize, strides, pooling_type=pooling_type, name=name)
 
 
+def risc_random_uniform(shape, seed, name='RISC_RANDOM_UNIFORM'):
+  return gen_risc_ops.risc_random_uniform(shape, seed, name=name)
+
+
+def risc_reduce(x, axis, reduce_type, name='RISC_REDUCE'):
+  return gen_risc_ops.risc_reduce(x, axis, reduce_type=reduce_type, name=name)
+
+
 def risc_reshape(x, shape, name='RISC_RESHAPE'):
   return gen_risc_ops.risc_reshape(x, shape, name=name)
 
 
+def risc_reverse(x, axis, name='RISC_REVERSE'):
+  return gen_risc_ops.risc_reverse(x, axis, name=name)
+
+
+def risc_scatter(indices, updates, shape, name='RISC_SCATTER'):
+  return gen_risc_ops.risc_scatter(indices, updates, shape, name=name)
+
+
 def risc_shape(x, name='RISC_SHAPE'):
   return gen_risc_ops.risc_shape(x, name=name)
 
 
 def risc_slice(x, begin, size, name='RISC_SLICE'):
   return gen_risc_ops.risc_slice(x, begin, size, name=name)
+
+
+def risc_sort(x, axis, direction='ASCENDING', name='RISC_SORT'):
+  return gen_risc_ops.risc_sort(x, axis, direction=direction, name=name)
+
+
+def risc_squeeze(x, axis=None, name='RISC_SQUEEZE'):
+  return gen_risc_ops.risc_squeeze(x, axis, name=name)
+
+
+def risc_transpose(x, perm=None, name='RISC_TRANSPOSE'):
+  return gen_risc_ops.risc_transpose(x, perm, name=name)
+
+
+def risc_triangular_solve(matrix,
+                          rhs,
+                          lower=True,
+                          adjoint=False,
+                          name='RISC_TRIANGULAR_SOLVE'):
+  return gen_risc_ops.risc_triangular_solve(
+      matrix, rhs, lower=lower, adjoint=adjoint, name=name)
+
+
+def risc_unary(x, op_type='ABL', name='RISC_UNARY'):
+  return gen_risc_ops.risc_unary(x, op_type=op_type, name=name)
+
+
+def risc_while(cond,
+               body,
+               loop_vars,
+               shape_invariants=None,
+               parallel_iterations=10,
+               back_prop=True,
+               swap_memory=False,
+               maximum_iterations=None,
+               name='RISC_WHILE'):
+  return gen_risc_ops.risc_while(
+      cond=cond,
+      body=body,
+      loop_vars=loop_vars,
+      shape_invariants=shape_invariants,
+      parallel_iterations=parallel_iterations,
+      back_prop=back_prop,
+      swap_memory=swap_memory,
+      name=name,
+      maximum_iterations=maximum_iterations,
+      return_same_structure=True)
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 8575cdf..363c8b8 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -28,13 +28,13 @@
 import numpy as np
 import six
 
-from tensorflow.python import _pywrap_py_func
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.lib.core import _pywrap_py_func
 from tensorflow.python.ops import gen_script_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import compat
diff --git a/tensorflow/python/ops/signal/spectral_ops.py b/tensorflow/python/ops/signal/spectral_ops.py
index 7c4c554..9db5f05 100644
--- a/tensorflow/python/ops/signal/spectral_ops.py
+++ b/tensorflow/python/ops/signal/spectral_ops.py
@@ -120,10 +120,6 @@
       The returned window is suitable for reconstructing original waveform in
       inverse_stft.
   """
-  with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
-    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
-    frame_step.shape.assert_has_rank(0)
-
   def inverse_stft_window_fn_inner(frame_length, dtype):
     """Computes a window that can be used in `inverse_stft`.
 
@@ -141,18 +137,20 @@
      or `frame_step` is not scalar.
     """
     with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
+      frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
+      frame_step_.shape.assert_has_rank(0)
       frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
       frame_length.shape.assert_has_rank(0)
 
       # Use equation 7 from Griffin + Lim.
       forward_window = forward_window_fn(frame_length, dtype=dtype)
       denom = math_ops.square(forward_window)
-      overlaps = -(-frame_length // frame_step)  # Ceiling division.
-      denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)])
-      denom = array_ops.reshape(denom, [overlaps, frame_step])
+      overlaps = -(-frame_length // frame_step_)  # Ceiling division.
+      denom = array_ops.pad(denom, [(0, overlaps * frame_step_ - frame_length)])
+      denom = array_ops.reshape(denom, [overlaps, frame_step_])
       denom = math_ops.reduce_sum(denom, 0, keepdims=True)
       denom = array_ops.tile(denom, [overlaps, 1])
-      denom = array_ops.reshape(denom, [overlaps * frame_step])
+      denom = array_ops.reshape(denom, [overlaps * frame_step_])
 
       return forward_window / denom[:frame_length]
   return inverse_stft_window_fn_inner
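With the change above, `frame_step` is only converted to a tensor when the returned closure runs, so `inverse_stft_window_fn` can be called outside any graph scope. Usage through the public API is unchanged; a hedged sketch (the signal length and frame sizes are illustrative):

```python
import tensorflow as tf

signal = tf.random.normal([8000])
frames = tf.signal.stft(signal, frame_length=256, frame_step=64)
reconstructed = tf.signal.inverse_stft(
    frames, frame_length=256, frame_step=64,
    window_fn=tf.signal.inverse_stft_window_fn(64))
```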
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 8ba2edb..ba323cb 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -126,7 +126,13 @@
 
   The provided value can be a python boolean, a scalar boolean Tensor, or
  a callable providing such a value; if a callable is passed it will be
-  invoked on-demand to determine whether summary writing will occur.
+  invoked on-demand to determine whether summary writing will occur.  Note that
+  when calling record_if() in an eager mode context, if you intend to provide a
+  varying condition like `step % 100 == 0`, you must wrap this in a
+  callable to avoid immediate eager evaluation of the condition.  In particular,
+  using a callable is the only way to have your condition evaluated as part of
+  the traced body of an @tf.function that is invoked from within the
+  `record_if()` context.
 
   Args:
     condition: can be True, False, a bool Tensor, or a callable providing such.
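As the docstring addition above notes, a varying condition must be wrapped in a callable so it is re-evaluated (and traced into any enclosing `@tf.function`) rather than computed once eagerly. A hedged sketch (the log directory, `step` variable, and loss value are placeholders):

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/logs")
step = tf.Variable(0, dtype=tf.int64)

with writer.as_default():
  # The lambda defers evaluation of the condition to each summary call.
  with tf.summary.record_if(lambda: step % 100 == 0):
    tf.summary.scalar("loss", 0.5, step=step)
```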
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 4e79ec9..387cde1 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -27,7 +27,6 @@
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
-from tensorflow.python import _pywrap_utils
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -41,6 +40,7 @@
 from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_should_use
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8c6d969..8c6af52 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -23,6 +23,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import pywrap_tf_session as c_api
 from tensorflow.python.eager import backprop_util
@@ -862,6 +864,19 @@
   return None
 
 
+OptimizedReductionOpsCacheKey = collections.namedtuple(
+    "OptimizedReductionOpsCacheKey", [
+        "op_type",
+        "inputs",
+        "dtypes",
+        "input_types",
+        "name",
+        "attrs",
+        "op_def",
+        "compute_device",
+    ])
+
+
 class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
   """FuncGraph for the gradient function of the body of a While op.
 
@@ -957,29 +972,25 @@
     # This optimization is currently also disabled when under a persistent tape,
     # since it leads to an unbounded number of side outputs. With caching it may
     # be possible to re-enable it.
-    if (op_type in {"Shape", "Size", "Rank"} and
+    optimized_reduction_ops = {
+        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
+    }
+    if (op_type in optimized_reduction_ops and
+        not util.output_all_intermediates() and
         all(input.graph is self._forward_graph for input in inputs) and
         all(_get_accumulator(input) is None for input in inputs) and
         not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
         not util.graph_wrapped_for_higher_order_tape_gradients(
             self._forward_graph)):
-      with self._forward_graph.as_default():
-        # `name` was built using name_scope stack of gradient graph and may not
-        # be unique in the forward graph. `Graph.create_op` does not uniquify
-        # names which are name scopes i.e. end in `/`. To ensure that the op
-        # created gets a unique name in the forward graph we get rid of the
-        # trailing slash.
-        name = ops.name_from_scope_name(name)
-        result = self._forward_graph._create_op_internal(
-            op_type,
-            inputs,
-            dtypes=dtypes,
-            input_types=input_types,
-            name=name,
-            attrs=attrs,
-            op_def=op_def,
-            compute_device=compute_device)
-        return result
+      return self._move_op_to_forward_graph(
+          op_type,
+          inputs,
+          dtypes=dtypes,
+          input_types=input_types,
+          name=name,
+          attrs=attrs,
+          op_def=op_def,
+          compute_device=compute_device)
 
     return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
         op_type,
@@ -991,6 +1002,83 @@
         op_def=op_def,
         compute_device=compute_device)
 
+  def _move_op_to_forward_graph(
+      self,
+      op_type,
+      inputs,
+      dtypes=None,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_device=True):
+    # We have a cache of reduction ops that have already been moved to the
+    # forward graph, and we will check it first to avoid moving an op twice.
+    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
+      self._forward_graph._optimized_reduction_ops_cache = {}
+    cache_key = self._get_optimized_reduction_ops_cache_key(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device)
+    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
+        cache_key)
+    if cached_op is not None:
+      # This op has already been moved to the forward graph and we have it in
+      # the cache.
+      return cached_op
+
+    with self._forward_graph.as_default():
+      # `name` was built using name_scope stack of gradient graph and may not
+      # be unique in the forward graph. `Graph.create_op` does not uniquify
+      # names which are name scopes i.e. end in `/`. To ensure that the op
+      # created gets a unique name in the forward graph we get rid of the
+      # trailing slash.
+      name = ops.name_from_scope_name(name)
+      result = self._forward_graph._create_op_internal(
+          op_type,
+          inputs,
+          dtypes=dtypes,
+          input_types=input_types,
+          name=name,
+          attrs=attrs,
+          op_def=op_def,
+          compute_device=compute_device)
+
+      # Store the op we just moved to the forward graph so that it does
+      # not need to be added there again.
+      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
+      return result
+
+  def _get_optimized_reduction_ops_cache_key(
+      self,
+      op_type,
+      inputs,
+      dtypes=None,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_device=True):
+    # We need all elements of CacheKey to be hashable.
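+    # `Tensor` objects themselves are not hashable, so store their `ref()`
+    # wrappers; lists become tuples and protos are serialized to bytes below.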
+    inputs = tuple(map(lambda t: t.ref(), inputs))
+
+    if dtypes is not None:
+      dtypes = tuple(dtypes)
+
+    if input_types is not None:
+      input_types = tuple(input_types)
+
+    if attrs is not None:
+      hashable_attrs = []
+      for attr_name, attr_value in sorted(attrs.items()):
+        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
+      attrs = tuple(hashable_attrs)
+
+    if op_def is not None:
+      op_def = op_def.SerializeToString()
+
+    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
+                                         name, attrs, op_def, compute_device)
+
   def _capture_helper(self, tensor, name):
     """Implements the capturing described in the class docstring."""
     captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD
index 76b6197..024fcbc 100644
--- a/tensorflow/python/platform/BUILD
+++ b/tensorflow/python/platform/BUILD
@@ -46,10 +46,10 @@
     deps = [
         ":build_info",
         "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:_pywrap_util_port",
-        "//tensorflow/python:lib",
         "//tensorflow/python:pywrap_tfe",
         "//tensorflow/python:util",
+        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/util:_pywrap_util_port",
         "@absl_py//absl/flags",
         "@rules_python//python/runfiles",
         "@six_archive//:six",
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index 9996f5c..42e542c 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -88,8 +88,8 @@
         ":option_builder",
         ":tfprof_logger",
         "//tensorflow/core/profiler:protos_all_py",
-        "//tensorflow/python:_pywrap_tfprof",
         "//tensorflow/python:errors",
+        "//tensorflow/python/util:_pywrap_tfprof",
         "@six_archive//:six",
     ],
 )
@@ -229,8 +229,8 @@
         "//tensorflow:internal",
     ],
     deps = [
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/profiler/internal:_pywrap_traceme",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/profiler/internal/model_analyzer_testlib.py b/tensorflow/python/profiler/internal/model_analyzer_testlib.py
index 459822c..a9b03b1 100644
--- a/tensorflow/python/profiler/internal/model_analyzer_testlib.py
+++ b/tensorflow/python/profiler/internal/model_analyzer_testlib.py
@@ -19,7 +19,6 @@
 
 import contextlib
 
-from tensorflow.python import _pywrap_tfprof as print_mdl
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -32,6 +31,7 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.profiler import model_analyzer
 from tensorflow.python.training import gradient_descent
+from tensorflow.python.util import _pywrap_tfprof as print_mdl
 from tensorflow.python.util import compat
 
 
diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py
index 12ef107..4d8f097 100644
--- a/tensorflow/python/profiler/model_analyzer.py
+++ b/tensorflow/python/profiler/model_analyzer.py
@@ -27,12 +27,12 @@
 from google.protobuf import message
 from tensorflow.core.profiler import tfprof_options_pb2
 from tensorflow.core.profiler import tfprof_output_pb2
-from tensorflow.python import _pywrap_tfprof as print_mdl
 from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.profiler import option_builder
 from tensorflow.python.profiler import tfprof_logger
+from tensorflow.python.util import _pywrap_tfprof as print_mdl
 from tensorflow.python.util.tf_export import tf_export
 
 _DEFAULT_PROFILE_OPTIONS = 0
diff --git a/tensorflow/python/profiler/profile_context.py b/tensorflow/python/profiler/profile_context.py
index e8e9ebd..6566550 100644
--- a/tensorflow/python/profiler/profile_context.py
+++ b/tensorflow/python/profiler/profile_context.py
@@ -25,12 +25,12 @@
 import threading
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python import _pywrap_tfprof as print_mdl
 from tensorflow.python.client import session
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.profiler import model_analyzer
+from tensorflow.python.util import _pywrap_tfprof as print_mdl
 from tensorflow.python.util import compat
 
 WARMUP_STEPS = 10
diff --git a/tensorflow/python/profiler/profiler_v2.py b/tensorflow/python/profiler/profiler_v2.py
index 102a510..2bd210a 100644
--- a/tensorflow/python/profiler/profiler_v2.py
+++ b/tensorflow/python/profiler/profiler_v2.py
@@ -54,7 +54,7 @@
     ])):
   """Options for finer control over the profiler.
 
-  Use `tf.profiler.ProfilerOptions` to control `tf.profiler`
+  Use `tf.profiler.experimental.ProfilerOptions` to control `tf.profiler`
   behavior.
 
   Fields:
@@ -204,8 +204,8 @@
 
     Args:
       logdir: profile data will be saved to this directory.
-      options: An optional tf.profiler.ProfilerOptions can be provided to fine
-        tune the profiler's behavior.
+      options: An optional `tf.profiler.experimental.ProfilerOptions` can be
+        provided to fine tune the profiler's behavior.
     """
     self._logdir = logdir
     self._options = options
diff --git a/tensorflow/python/lib/core/py_exception_registry_wrapper.cc b/tensorflow/python/py_exception_registry_wrapper.cc
similarity index 100%
rename from tensorflow/python/lib/core/py_exception_registry_wrapper.cc
rename to tensorflow/python/py_exception_registry_wrapper.cc
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 1043d51..0de4337 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -42,21 +42,21 @@
     name = "constants",
     srcs = ["constants.py"],
     srcs_version = "PY2AND3",
-    deps = ["//tensorflow/python:tf_export"],
+    deps = ["//tensorflow/python/util:tf_export"],
 )
 
 py_strict_library(
     name = "signature_constants",
     srcs = ["signature_constants.py"],
     srcs_version = "PY2AND3",
-    deps = ["//tensorflow/python:tf_export"],
+    deps = ["//tensorflow/python/util:tf_export"],
 )
 
 py_strict_library(
     name = "tag_constants",
     srcs = ["tag_constants.py"],
     srcs_version = "PY2AND3",
-    deps = ["//tensorflow/python:tf_export"],
+    deps = ["//tensorflow/python/util:tf_export"],
 )
 
 py_strict_library(
@@ -75,9 +75,9 @@
         "//tensorflow/python:lib",
         "//tensorflow/python:platform",
         "//tensorflow/python:saver",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -97,9 +97,9 @@
         "//tensorflow/python:lib",
         "//tensorflow/python:platform",
         "//tensorflow/python:saver",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -137,8 +137,8 @@
         ":signature_def_utils",
         ":tag_constants",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -153,9 +153,9 @@
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:lookup_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -205,9 +205,9 @@
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:lib",
         "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -238,8 +238,8 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -347,7 +347,6 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:versions",
         "//tensorflow/python/eager:context",
@@ -360,6 +359,7 @@
         "//tensorflow/python/training/tracking:base",
         "//tensorflow/python/training/tracking:graph_view",
         "//tensorflow/python/training/tracking:util",
+        "//tensorflow/python/util:tf_export",
         "@absl_py//absl/logging",
     ],
 )
@@ -367,7 +367,6 @@
 tf_py_test(
     name = "save_test",
     srcs = ["save_test.py"],
-    tags = ["no_rocm"],
     deps = [
         ":loader",
         ":save",
@@ -407,7 +406,6 @@
         "//tensorflow/python:lookup_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
@@ -421,6 +419,7 @@
         "//tensorflow/python/training/tracking:base",
         "//tensorflow/python/training/tracking:graph_view",
         "//tensorflow/python/training/tracking:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -604,8 +603,8 @@
     name = "save_options",
     srcs = ["save_options.py"],
     deps = [
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
         "@enum34_archive//:enum",
         "@six_archive//:six",
     ],
@@ -615,7 +614,7 @@
     name = "load_options",
     srcs = ["load_options.py"],
     deps = [
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -628,8 +627,8 @@
         ":loader",
         "//tensorflow/python:lib",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 1d513b4..dda321b 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -573,7 +573,10 @@
     filename = os.path.join(
         saved_model_utils.get_assets_dir(self._export_dir),
         self._asset_file_def[proto.asset_file_def_index].filename)
-    return tracking.Asset(filename), setattr
+    asset = tracking.Asset(filename)
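+    # In graph mode, also record the asset path in the ASSET_FILEPATHS
+    # collection so graph-based consumers that scan this collection can find
+    # the restored asset file.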
+    if not context.executing_eagerly():
+      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
+    return asset, setattr
 
   def _recreate_function(self, proto):
     return function_deserialization.recreate_function(
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index c2fad72..5c67dce 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -292,6 +292,7 @@
       imported_tensor = imported.f()
       with monitored_session.MonitoredSession() as sess:
         imported_output = sess.run(imported_tensor)
+        self.assertLen(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1)
         self.assertNotEqual(original_output, imported_output)
         with open(imported_output, "r") as f:
           self.assertEqual("contents", f.read())
@@ -798,6 +799,39 @@
     self.assertIsNotNone(imported_gradient)
     self.assertAllClose(imported_gradient, 2.)
 
+  def test_nested_fn_backprop(self, cycles):
+    weight = variables.Variable(2., trainable=True)
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
+    def g(x):
+      weight.read_value()  # Just get the tape to watch the variable
+      handle = array_ops.identity(weight.handle)
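+      # Routing the handle through a nested tf.function means the variable
+      # read below consumes a handle produced by a function call rather than
+      # the variable itself, so backprop must flow through that nested call.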
+      @def_function.function
+      def launder_var_handle():
+        return array_ops.identity(handle)
+      return x + resource_variable_ops.read_variable_op(
+          launder_var_handle(), dtypes.float32)
+
+    root = tracking.AutoTrackable()
+    root.weight = weight
+    root.g = g
+    imported = cycle(root, cycles)
+    def get_gradient(obj, persistent):
+      with backprop.GradientTape(persistent=persistent) as t:
+        x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
+        y = obj.g(x)
+        self.assertAllClose(y, obj.weight + x)
+        loss = math_ops.reduce_sum(y)
+        return t.gradient(loss, obj.weight)
+
+    imported_gradient = get_gradient(imported, persistent=False)
+    original_gradient = get_gradient(root, persistent=False)
+    self.assertIsNotNone(original_gradient)
+    self.assertAllClose(original_gradient, 6.)
+    self.assertIsNotNone(imported_gradient)
+    self.assertAllClose(imported_gradient, 6.)
+
   def test_restored_func_with_captured_var_backprop_float32(self, cycles):
     self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float32)
 
@@ -2064,6 +2098,7 @@
   # allocations at a lower level.
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def test_functions_cleaned(self):
+    self.skipTest("TODO(b/175152958): The test is leaking function definitions")
     if sys.version_info.major < 3:
       self.skipTest("Not working in Python 2")
     root = module.Module()
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 3725576..ce96ac5 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -781,8 +781,9 @@
   elif resource_variable_ops.is_resource_variable(obj):
     proto.variable.SetInParent()
     if not obj.name.endswith(":0"):
-      raise ValueError("Cowardly refusing to save variable %s because of"
-                       " unexpected suffix which won't be restored.")
+      raise ValueError("Cowardly refusing to save variable {} because of"
+                       " unexpected suffix which won't be restored.".format(
+                           obj.name))
     proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
     proto.variable.trainable = obj.trainable
     proto.variable.dtype = obj.dtype.as_datatype_enum
diff --git a/tensorflow/python/tf_program/mlir_gen.py b/tensorflow/python/tf_program/mlir_gen.py
index 8395848..3e41084 100644
--- a/tensorflow/python/tf_program/mlir_gen.py
+++ b/tensorflow/python/tf_program/mlir_gen.py
@@ -100,14 +100,14 @@
     attr = getattr(value, node.attr)
 
     if attr == core.Tensor:
-      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx))
+      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(self.prog.ctx, 32))
     return attr
 
   def visit_Name(self, node):
     if node.id == 'int':
-      return tfp.IntegerType.get(32, self.prog.ctx)
+      return tfp.IntegerType.get(self.prog.ctx, 32)
     if node.id == 'bool':
-      return tfp.IntegerType.get(1, self.prog.ctx)
+      return tfp.IntegerType.get(self.prog.ctx, 1)
     if node.id in self.ctx.info.namespace:
       return self.ctx.info.namespace[node.id]
 
@@ -203,7 +203,7 @@
       value = tfp.Tf_ConstOp.create(
           opb, opb.getUnknownLoc(),
           tfp.IntegerAttr.get(
-              tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0)
+              tfp.IntegerType.get(self.prog.ctx, 32), node.value)).getResult(0)
     return value
 
   def visit_FunctionDef(self, node):
diff --git a/tensorflow/python/tf_program/pywrap_tfd.py b/tensorflow/python/tf_program/pywrap_tfd.py
index a7a30b7..af198f6 100644
--- a/tensorflow/python/tf_program/pywrap_tfd.py
+++ b/tensorflow/python/tf_program/pywrap_tfd.py
@@ -85,7 +85,7 @@
   def create(cls, opb, loc, values):
     state = mlir.OperationState(loc, "tfp.Or")
     state.addTypes(
-        [UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
+        [UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
     state.addOperands(values)
     return opb.createOperation(state)
 
@@ -103,7 +103,7 @@
   def create(cls, opb, loc, values):
     state = mlir.OperationState(loc, "tfp.And")
     state.addTypes(
-        [UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
+        [UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
     state.addOperands(values)
     return opb.createOperation(state)
 
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 8e40d51..ba8d623 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -274,8 +274,8 @@
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
-        "//tensorflow/python:_pywrap_kernel_registry",
         "//tensorflow/python:platform",
+        "//tensorflow/python/util:_pywrap_kernel_registry",
     ],
 )
 
@@ -444,7 +444,6 @@
         "//tensorflow/cc/saved_model:saved_model_half_plus_two",
     ],
     force_without_xla_support_flag = False,
-    tags = ["no_rocm"],
 )
 
 saved_model_compile_aot(
@@ -455,7 +454,6 @@
         "//tensorflow/cc/saved_model:saved_model_half_plus_two",
     ],
     force_without_xla_support_flag = False,
-    tags = ["no_rocm"],
 )
 
 saved_model_compile_aot(
@@ -466,7 +464,6 @@
         "//tensorflow/cc/saved_model:saved_model_half_plus_two",
     ],
     force_without_xla_support_flag = False,
-    tags = ["no_rocm"],
     variables_to_feed = "variable_x",
 )
 
@@ -503,7 +500,6 @@
     srcs = if_xla_available([
         "aot_compiled_test.cc",
     ]),
-    tags = ["no_rocm"],
     deps = [
         "//tensorflow/core:test_main",
     ] + if_xla_available([
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py
index ff2b3db..a6fbd2a 100644
--- a/tensorflow/python/tools/selective_registration_header_lib.py
+++ b/tensorflow/python/tools/selective_registration_header_lib.py
@@ -28,9 +28,9 @@
 
 from google.protobuf import text_format
 from tensorflow.core.framework import graph_pb2
-from tensorflow.python import _pywrap_kernel_registry
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import _pywrap_kernel_registry
 
 # Usually, we use each graph node to induce registration of an op and
 # corresponding kernel; nodes without a corresponding kernel (perhaps due to
diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index 234ca52..44838e6 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -34,7 +34,6 @@
         "no_oss_py2",
         "no_oss_py35",
         "no_pip",
-        "no_rocm",
     ],
     deps = [
         "//tensorflow/python:client_testlib",
@@ -97,7 +96,7 @@
     deps = [
         ":topology",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -171,7 +170,7 @@
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/core/protobuf/tpu:topology_proto_py",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -288,12 +287,12 @@
         "//tensorflow/python:function",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/compiler/xla",
         "//tensorflow/python/distribute:device_util",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
         "@absl_py//absl/logging",
         "@enum34_archive//:enum",
@@ -478,7 +477,7 @@
     name = "tpu_name_util",
     srcs = ["tpu_name_util.py"],
     deps = [
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
index 32e5d54..75235c3 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
@@ -123,11 +123,6 @@
     self.feature_friends_row_lengths = [1, 3, 1, 3]
     self.resolver = None
 
-  def tearDown(self):
-    if self.resolver:
-      tpu_strategy_util.shutdown_tpu_system(self.resolver)
-    super(TPUEmbeddingCorrectness, self).tearDown()
-
   def _get_strategy(self):
     self.resolver = tpu_cluster_resolver.TPUClusterResolver(
         tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py
index d5f9e64..5843bb6 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py
@@ -99,10 +99,6 @@
     self.cpu_mid_level = self.build_mid_level(
         self.second_mid_level_contents, self.cpu_mid_level_optimizer)
 
-  def tearDown(self):
-    tpu_strategy_util.shutdown_tpu_system(self.resolver)
-    super(TPUEmbeddingCheckpointTest, self).tearDown()
-
   def test_checkpoint_save_retrieves(self):
     # Ensure that the variables from the first model are loaded.
     self.first_mid_level._load_variables()
@@ -401,11 +397,6 @@
     self.feature_friends_row_lengths = [1, 3, 1, 3]
     self.resolver = None
 
-  def tearDown(self):
-    if self.resolver:
-      tpu_strategy_util.shutdown_tpu_system(self.resolver)
-    super(TPUEmbeddingTest, self).tearDown()
-
   def test_tables_with_same_name(self):
     with self.assertRaisesRegex(
         ValueError, 'Multiple tables with name table found.'):
diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD
index 0c7dbb1..543e3d2 100644
--- a/tensorflow/python/training/BUILD
+++ b/tensorflow/python/training/BUILD
@@ -113,9 +113,9 @@
         ":warm_starting_util",
         "//tensorflow/python:learning_rate_decay",
         "//tensorflow/python:sdca_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/training/experimental:loss_scale_optimizer",
         "//tensorflow/python/training/experimental:mixed_precision",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -139,7 +139,7 @@
         ":training_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -154,7 +154,7 @@
         "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -170,7 +170,7 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -186,8 +186,8 @@
         "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -197,7 +197,7 @@
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:errors",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -226,11 +226,11 @@
         "//tensorflow/python:io_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/training/saving:saveable_object_util",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -242,8 +242,8 @@
     deps = [
         "//tensorflow/python:errors",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -256,7 +256,7 @@
         ":server_lib",
         "//tensorflow/python:device",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -297,7 +297,7 @@
         "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -311,7 +311,7 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -335,10 +335,10 @@
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:summary",
         "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -352,7 +352,7 @@
         ":training_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -368,11 +368,11 @@
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:state_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:reduce_util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -390,7 +390,6 @@
         "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
@@ -399,6 +398,7 @@
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -413,7 +413,7 @@
         "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -426,7 +426,7 @@
         ":training_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -436,8 +436,8 @@
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:_pywrap_quantize_training",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -450,9 +450,9 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:session",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -474,7 +474,7 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -488,8 +488,8 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:session",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -532,11 +532,11 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:state_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -573,10 +573,10 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:state_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/training/saving:saveable_object_util",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -795,11 +795,11 @@
     name = "py_checkpoint_reader",
     srcs = ["py_checkpoint_reader.py"],
     deps = [
-        "//tensorflow/python:_pywrap_checkpoint_reader",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:_pywrap_checkpoint_reader",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -818,10 +818,10 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:lib",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -867,13 +867,13 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:session",
         "//tensorflow/python:string_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
         "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
     ],
 )
@@ -996,8 +996,8 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:platform",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -1007,7 +1007,7 @@
     name = "session_run_hook",
     srcs = ["session_run_hook.py"],
     srcs_version = "PY2AND3",
-    deps = ["//tensorflow/python:tf_export"],
+    deps = ["//tensorflow/python/util:tf_export"],
 )
 
 py_library(
@@ -1025,10 +1025,10 @@
         "//tensorflow/python:lookup_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:summary",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -1065,8 +1065,8 @@
     deps = [
         "//tensorflow/python:errors",
         "//tensorflow/python:pywrap_tf_session",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -1082,10 +1082,10 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -1351,10 +1351,10 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:resources",
         "//tensorflow/python:summary",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_coordinator_context",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/training/experimental/BUILD b/tensorflow/python/training/experimental/BUILD
index 0e43766..239ea2b 100644
--- a/tensorflow/python/training/experimental/BUILD
+++ b/tensorflow/python/training/experimental/BUILD
@@ -15,7 +15,6 @@
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
@@ -23,6 +22,7 @@
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/util:tf_export",
         "@six_archive//:six",
     ],
 )
@@ -37,9 +37,9 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:smart_cond",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/training:optimizer",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index b95e366..768188f 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -473,7 +473,8 @@
         self._averages[var.ref()] = avg
 
     with ops.name_scope(self.name) as scope:
-      decay = ops.convert_to_tensor(self._decay, name="decay")
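+      # Convert with an explicit float32 dtype so `decay` stays floating
+      # point (matching the float32 `num_updates` below) even if a Python
+      # int is passed in.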
+      decay = ops.convert_to_tensor(
+          self._decay, dtype=dtypes.float32, name="decay")
       if self._num_updates is not None:
         num_updates = math_ops.cast(
             self._num_updates, dtypes.float32, name="num_updates")
diff --git a/tensorflow/python/training/py_checkpoint_reader.py b/tensorflow/python/training/py_checkpoint_reader.py
index 83ab6e2..e3165e2 100644
--- a/tensorflow/python/training/py_checkpoint_reader.py
+++ b/tensorflow/python/training/py_checkpoint_reader.py
@@ -17,10 +17,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.util import compat
+from tensorflow.python.util._pywrap_checkpoint_reader import CheckpointReader
 from tensorflow.python.util.tf_export import tf_export
 
 
diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD
index 218ce9d..1edad74 100644
--- a/tensorflow/python/training/saving/BUILD
+++ b/tensorflow/python/training/saving/BUILD
@@ -16,7 +16,7 @@
     name = "checkpoint_options",
     srcs = ["checkpoint_options.py"],
     deps = [
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD
index d48f066..ed08419 100644
--- a/tensorflow/python/types/BUILD
+++ b/tensorflow/python/types/BUILD
@@ -33,7 +33,7 @@
     ],
     deps = [
         ":doc_typealias",
-        "//tensorflow/python:tf_export",
+        "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
         "@typing_extensions_archive//:typing_extensions",
     ],
diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD
new file mode 100644
index 0000000..292e4d1
--- /dev/null
+++ b/tensorflow/python/util/BUILD
@@ -0,0 +1,627 @@
+# Tensorflow util package
+
+load("//tensorflow:tensorflow.bzl", "py_strict_library")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")  # @unused
+load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
+
+visibility = [
+    "//engedu/ml/tf_from_scratch:__pkg__",
+    "//third_party/cloud_tpu/convergence_tools:__subpackages__",
+    "//third_party/mlperf:__subpackages__",
+    "//tensorflow:internal",
+    "//tensorflow/lite/toco/python:__pkg__",
+    "//tensorflow_models:__subpackages__",
+    "//tensorflow_model_optimization:__subpackages__",
+    "//third_party/py/cleverhans:__subpackages__",
+    "//third_party/py/launchpad:__subpackages__",
+    "//third_party/py/reverb:__subpackages__",
+    "//third_party/py/neural_structured_learning:__subpackages__",
+    "//third_party/py/tensorflow_examples:__subpackages__",
+    "//third_party/py/tf_agents:__subpackages__",  # For benchmarks.
+    "//third_party/py/tf_slim:__subpackages__",
+    "//third_party/py/tensorflow_docs:__subpackages__",
+    "//third_party/py/keras:__subpackages__",
+]
+
+package(
+    default_visibility = visibility,
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# TODO(mdan): Move this utility outside of TF.
+cc_library(
+    name = "kernel_registry",
+    srcs = ["kernel_registry.cc"],
+    hdrs = ["kernel_registry.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+    alwayslink = 1,
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_tfprof",
+    srcs = ["tfprof_wrapper.cc"],
+    module_name = "_pywrap_tfprof",
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core/profiler/internal:print_model_analysis_hdr",
+        "//third_party/eigen3",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_utils",
+    srcs = ["util_wrapper.cc"],
+    hdrs = ["util.h"],
+    module_name = "_pywrap_utils",
+    deps = [
+        "//tensorflow/python:pybind11_lib",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_kernel_registry",
+    srcs = ["kernel_registry_wrapper.cc"],
+    hdrs = ["kernel_registry.h"],
+    module_name = "_pywrap_kernel_registry",
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python:pybind11_lib",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_stat_summarizer",
+    srcs = ["stat_summarizer_wrapper.cc"],
+    module_name = "_pywrap_stat_summarizer",
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//third_party/eigen3",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/memory",
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_tensor_float_32_execution",
+    srcs = ["tensor_float_32.cc"],
+    hdrs = ["//tensorflow/core/platform:tensor_float_32_hdr"],
+    compatible_with = get_compatible_with_portable(),
+    module_name = "_pywrap_tensor_float_32_execution",
+    deps = [
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_util_port",
+    srcs = ["port_wrapper.cc"],
+    hdrs = ["//tensorflow/core/util:port_hdrs"],
+    module_name = "_pywrap_util_port",
+    deps = [
+        "//tensorflow/core/util:port",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_transform_graph",
+    srcs = ["transform_graph_wrapper.cc"],
+    hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
+    module_name = "_pywrap_transform_graph",
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_checkpoint_reader",
+    srcs = ["py_checkpoint_reader_wrapper.cc"],
+    hdrs = [
+        "//tensorflow/c:checkpoint_reader_hdrs",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        "//tensorflow/python/lib/core:ndarray_tensor_hdr",
+        "//tensorflow/python/lib/core:py_exception_registry_hdr",
+        "//tensorflow/python/lib/core:safe_ptr_hdr",
+    ],
+    module_name = "_pywrap_checkpoint_reader",
+    deps = [
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:op_gen_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//third_party/py/numpy:headers",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "cpp_python_util",
+    srcs = ["util.cc"],
+    hdrs = ["util.h"],
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/python/lib/core:safe_ptr",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+tf_py_test(
+    name = "decorator_utils_test",
+    srcs = ["decorator_utils_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform",
+    ],
+)
+
+tf_py_test(
+    name = "deprecation_test",
+    srcs = ["deprecation_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform",
+    ],
+)
+
+tf_py_test(
+    name = "dispatch_test",
+    srcs = ["dispatch_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform",
+    ],
+)
+
+tf_py_test(
+    name = "keyword_args_test",
+    srcs = ["keyword_args_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_strict_library(
+    name = "tf_export",
+    srcs = ["tf_export.py"],
+    compatible_with = get_compatible_with_portable(),
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":tf_decorator",
+    ],
+)
+
+tf_py_test(
+    name = "tf_export_test",
+    srcs = ["tf_export_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform",
+    ],
+)
+
+# Leaf library: may not depend on anything else inside TensorFlow.
+# TODO(mdan): Move this utility outside of TF.
+py_strict_library(
+    name = "tf_decorator",
+    srcs = [
+        "tf_contextlib.py",
+        "tf_decorator.py",
+        "tf_inspect.py",
+    ],
+    compatible_with = get_compatible_with_portable(),
+    srcs_version = "PY2AND3",
+    visibility = [
+        "//tensorflow:__subpackages__",
+        # TODO(mdan): Remove these dependencies.
+        "//third_party/py/tf_slim:__subpackages__",
+        "//learning/deepmind/research/language/translation/lm:__subpackages__",
+    ],
+    deps = [
+        "@six_archive//:six",
+    ],
+)
+
+# Note: this is a heavyweight library specialized for TensorFlow graphs. Do not use for
+# other purposes.
+py_strict_library(
+    name = "tf_stack",
+    srcs = ["tf_stack.py"],
+    srcs_version = "PY2AND3",
+    # TODO(mdan): Remove public visibility.
+    visibility = ["//visibility:public"],
+    deps = [
+        ":_tf_stack",
+        "@six_archive//:six",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_tf_stack",
+    srcs = ["tf_stack.cc"],
+    hdrs = [
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        # Using header directly is required to avoid ODR violations.
+        "stack_trace.h",
+    ],
+    # TODO(b/138203821): change to "util._tf_stack" once the bug is fixed.
+    module_name = "_tf_stack",
+    deps = [
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
+        "@pybind11",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "//tensorflow/c:pywrap_required_hdrs",
+        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//tensorflow/core/framework:pywrap_required_hdrs",
+        "//tensorflow/core/platform:path",
+    ] + if_static([
+        ":stack_trace",
+    ]),
+)
+
+tf_py_test(
+    name = "tf_stack_test",
+    srcs = ["tf_stack_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":tf_export",
+        ":tf_stack",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+cc_library(
+    name = "stack_trace",
+    srcs = ["stack_trace.cc"],
+    hdrs = ["stack_trace.h"],
+    deps = [
+        "//tensorflow/core/platform:str_util",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/core/util:abstract_stack_trace",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
+cc_library(
+    name = "function_parameter_canonicalizer",
+    srcs = ["function_parameter_canonicalizer.cc"],
+    hdrs = ["function_parameter_canonicalizer.h"],
+    deps = [
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/python/lib/core:py_util",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_function_parameter_canonicalizer_binding_for_test",
+    testonly = True,
+    srcs = ["function_parameter_canonicalizer_binding_for_test.cc"],
+    hdrs = [
+        "function_parameter_canonicalizer.h",
+    ],
+    module_name = "_function_parameter_canonicalizer_binding_for_test",
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/types:span",
+        "@pybind11",
+    ],
+)
+
+tf_py_test(
+    name = "function_parameter_canonicalizer_test",
+    srcs = ["function_parameter_canonicalizer_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_pip",  # b/168621686
+        "no_windows",  # b/169275019
+    ],
+    deps = [
+        ":_function_parameter_canonicalizer_binding_for_test",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_library(
+    name = "util",
+    srcs = glob(
+        ["**/*.py"],
+        exclude = [
+            "example_parser*",
+            "tf_contextlib.py",
+            "tf_should_use.py",
+            "tf_export.py",
+            "tf_stack.py",
+            "tf_decorator.py",
+            "**/*_test.py",
+        ],
+    ),
+    compatible_with = get_compatible_with_portable(),
+    srcs_version = "PY2AND3",
+    visibility = visibility + [
+        "//tensorflow:__pkg__",
+        "//third_party/py/tensorflow_core:__subpackages__",
+        "//third_party/py/tf_agents:__subpackages__",
+        "//third_party/py/tfx:__subpackages__",
+    ],
+    deps = [
+        ":_pywrap_tensor_float_32_execution",
+        # global_test_configuration is added here because all major tests depend on this
+        # library. It isn't possible to add these test dependencies via tensorflow.bzl's
+        # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
+        "//tensorflow/python:global_test_configuration",
+        ":tf_decorator",
+        ":tf_export",
+        "@org_python_pypi_backports_weakref",
+        "@com_google_protobuf//:protobuf_python",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+        "@wrapt",
+        "//tensorflow/tools/docs:doc_controls",
+        "//tensorflow/tools/compatibility:all_renames_v2",
+    ],
+)
+
+tf_py_test(
+    name = "object_identity_test",
+    size = "small",
+    srcs = ["object_identity_test.py"],
+    python_version = "PY3",
+)
+
+# Placeholder for internal nest_test comments.
+tf_py_test(
+    name = "nest_test",
+    size = "small",
+    srcs = ["nest_test.py"],
+    main = "nest_test.py",
+    python_version = "PY3",
+    deps = [":nest_test_main_lib"],
+)
+
+py_library(
+    name = "nest_test_main_lib",
+    testonly = True,
+    srcs = ["nest_test.py"],
+    deps = [
+        ":util",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+tf_py_test(
+    name = "serialization_test",
+    size = "small",
+    srcs = ["serialization_test.py"],
+    main = "serialization_test.py",
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+tf_py_test(
+    name = "function_utils_test",
+    srcs = ["function_utils_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+tf_py_test(
+    name = "tf_contextlib_test",
+    size = "small",
+    srcs = ["tf_contextlib_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+tf_py_test(
+    name = "tf_decorator_test",
+    size = "small",
+    srcs = ["tf_decorator_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_library(
+    name = "tf_should_use",
+    srcs = ["tf_should_use.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":util",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/eager:context",
+        "@six_archive//:six",
+    ],
+)
+
+tf_py_test(
+    name = "tf_should_use_test",
+    size = "small",
+    srcs = ["tf_should_use_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":tf_should_use",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+tf_py_test(
+    name = "tf_inspect_test",
+    size = "small",
+    srcs = ["tf_inspect_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_library(
+    name = "example_parser_configuration",
+    srcs = ["example_parser_configuration.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+    ],
+)
+
+tf_py_test(
+    name = "lock_util_test",
+    size = "small",
+    srcs = ["lock_util_test.py"],
+    main = "lock_util_test.py",
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+tf_py_test(
+    name = "module_wrapper_test",
+    size = "small",
+    srcs = ["module_wrapper_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/tools/compatibility:all_renames_v2",
+        "@six_archive//:six",
+    ],
+)
+
+tf_proto_library(
+    name = "compare_test_proto",
+    testonly = 1,
+    srcs = ["protobuf/compare_test.proto"],
+    cc_api_version = 2,
+)
+
+tf_py_test(
+    name = "protobuf_compare_test",
+    size = "small",
+    srcs = ["protobuf/compare_test.py"],
+    main = "protobuf/compare_test.py",
+    python_version = "PY3",
+    tags = ["no_pip"],  # compare_test_pb2 proto is not available in pip.
+    deps = [
+        ":compare_test_proto_py",
+        ":util",
+        "//tensorflow/python:platform_test",
+        "@six_archive//:six",
+    ],
+)
+
+tf_py_test(
+    name = "example_parser_configuration_test",
+    size = "small",
+    srcs = ["example_parser_configuration_test.py"],
+    main = "example_parser_configuration_test.py",
+    python_version = "PY3",
+    deps = [
+        ":example_parser_configuration",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:parsing_ops",
+    ],
+)
+
+filegroup(
+    name = "util_hdr",
+    srcs = ["util.h"],
+)
+
+filegroup(
+    name = "compare_test_proto_src",
+    srcs = ["protobuf/compare_test.proto"],
+)
diff --git a/tensorflow/python/util/function_parameter_canonicalizer_test.py b/tensorflow/python/util/function_parameter_canonicalizer_test.py
index 968265f..5dc87b5 100644
--- a/tensorflow/python/util/function_parameter_canonicalizer_test.py
+++ b/tensorflow/python/util/function_parameter_canonicalizer_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import _function_parameter_canonicalizer_binding_for_test
 from tensorflow.python.platform import test
+from tensorflow.python.util import _function_parameter_canonicalizer_binding_for_test
 
 
 class FunctionParameterCanonicalizerTest(test.TestCase):
diff --git a/tensorflow/python/util/module_wrapper.py b/tensorflow/python/util/module_wrapper.py
index c5856ee..c9b3511 100644
--- a/tensorflow/python/util/module_wrapper.py
+++ b/tensorflow/python/util/module_wrapper.py
@@ -19,12 +19,12 @@
 from __future__ import print_function
 
 import importlib
+import inspect
 import types
 
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
-from tensorflow.python.util import tf_stack
 from tensorflow.tools.compatibility import all_renames_v2
 
 
@@ -41,11 +41,12 @@
   # We want to get stack frame 3 frames up from current frame,
   # i.e. above __getattr__, _tfmw_add_deprecation_warning,
   # and _call_location calls.
-  stack = tf_stack.extract_stack(limit=4)
-  if not stack:  # should never happen as we're in a function
-    return 'UNKNOWN'
-  frame = stack[0]
-  return '{}:{}'.format(frame.filename, frame.lineno)
+  frame = inspect.currentframe()
+  for _ in range(4):
+    parent = frame.f_back
+    if parent is None:
+      break
+    frame = parent
+  return '{}:{}'.format(frame.f_code.co_filename, frame.f_lineno)
 
 
 def contains_deprecation_decorator(decorators):
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index db3ad27..9c83158 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -47,10 +47,10 @@
 import six as _six
 import wrapt as _wrapt
 
-from tensorflow.python import _pywrap_utils
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util.compat import collections_abc as _collections_abc
 from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.platform import tf_logging
 
 
 _SHALLOW_TREE_HAS_INVALID_KEYS = (
@@ -438,15 +438,62 @@
                           expand_composites=False):
   """Asserts that two structures are nested in the same way.
 
-  Note that namedtuples with identical name and fields are always considered
-  to have the same shallow structure (even with `check_types=True`).
-  For instance, this code will print `True`:
+  Note that the method does not check the types of data inside the structures.
 
-  ```python
-  def nt(a, b):
-    return collections.namedtuple('foo', 'a b')(a, b)
-  print(assert_same_structure(nt(0, 1), nt(2, 3)))
-  ```
+  Examples:
+
+  * These scalar vs. scalar comparisons will pass:
+
+    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
+    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))
+
+  * These sequence vs. sequence comparisons will pass:
+
+    >>> structure1 = (((1, 2), 3), 4, (5, 6))
+    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
+    >>> tf.nest.assert_same_structure(structure1, structure2)
+    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)
+
+    >>> import collections
+    >>> tf.nest.assert_same_structure(
+    ...     collections.namedtuple("bar", "a b")(1, 2),
+    ...     collections.namedtuple("foo", "a b")(2, 3),
+    ...     check_types=False)
+
+    >>> tf.nest.assert_same_structure(
+    ...     collections.namedtuple("bar", "a b")(1, 2),
+    ...     { "a": 1, "b": 2 },
+    ...     check_types=False)
+
+    >>> tf.nest.assert_same_structure(
+    ...     { "a": 1, "b": 2, "c": 3 },
+    ...     { "c": 6, "b": 5, "a": 4 })
+
+    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
+    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
+    ...       row_splits=[0, 4, 4, 7, 8, 8])
+    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
+    ...       values=[3, 1, 4],
+    ...       row_splits=[0, 3])
+    >>> tf.nest.assert_same_structure(
+    ...       ragged_tensor1,
+    ...       ragged_tensor2,
+    ...       expand_composites=True)
+
+  * These examples will raise exceptions:
+
+    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
+    Traceback (most recent call last):
+    ...
+    ValueError: The two structures don't have the same nested structure
+
+    >>> tf.nest.assert_same_structure(
+    ...       collections.namedtuple('bar', 'a b')(1, 2),
+    ...       collections.namedtuple('foo', 'a b')(2, 3))
+    Traceback (most recent call last):
+    ...
+    TypeError: The two structures don't have the same nested structure
 
   Args:
     nest1: an arbitrarily nested structure.
diff --git a/tensorflow/python/util/stack_trace.cc b/tensorflow/python/util/stack_trace.cc
index 40e05e6..8aed669 100644
--- a/tensorflow/python/util/stack_trace.cc
+++ b/tensorflow/python/util/stack_trace.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/python/util/stack_trace.h"
 
+#include <limits>
+
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/stringpiece.h"
 
@@ -40,27 +42,37 @@
 namespace tensorflow {
 
 std::vector<StackFrame> StackTrace::ToStackFrames(
-    const StackTraceMap& mapper, const StackTraceFilter& filtered) const {
+    const StackTraceMap& mapper, const StackTraceFilter& filtered,
+    bool reverse_traversal, int limit) const {
   DCheckPyGilStateForStackTrace();
   std::vector<StackFrame> result;
   result.reserve(code_objs_.size());
 
-  for (int i = code_objs_.size() - 1; i >= 0; --i) {
-    const char* file_name = GetPythonString(code_objs_[i]->co_filename);
-    const int line_number =
-        PyCode_Addr2Line(code_objs_[i], last_instructions_[i]);
+  if (limit == -1) limit = std::numeric_limits<int>::max();
 
-    if (!result.empty() && filtered.count(file_name)) {
-      continue;  // Never filter the innermost frame.
+  for (int i = 0; i < code_objs_.size(); i++) {
+    int idx = reverse_traversal ? i : code_objs_.size() - 1 - i;
+
+    const std::pair<PyCodeObject*, int>& code_obj = code_objs_[idx];
+    const char* file_name = GetPythonString(code_obj.first->co_filename);
+    const int line_number = PyCode_Addr2Line(code_obj.first, code_obj.second);
+
+    if (filtered && filtered(file_name)) {
+      continue;
     }
 
-    auto it = mapper.find(std::make_pair(file_name, line_number));
+    absl::optional<StackFrame> mapped =
+        mapper ? mapper(std::make_pair(file_name, line_number)) : absl::nullopt;
 
-    if (it != mapper.end()) {
-      result.push_back(it->second);
+    if (mapped) {
+      result.push_back(*mapped);
     } else {
       result.emplace_back(StackFrame{file_name, line_number,
-                                     GetPythonString(code_objs_[i]->co_name)});
+                                     GetPythonString(code_obj.first->co_name)});
+    }
+
+    if (result.size() == limit) {
+      break;
     }
   }
 
diff --git a/tensorflow/python/util/stack_trace.h b/tensorflow/python/util/stack_trace.h
index b416e3b..eec6fee 100644
--- a/tensorflow/python/util/stack_trace.h
+++ b/tensorflow/python/util/stack_trace.h
@@ -44,10 +44,10 @@
 
 // Maps filename/line_no combination into a stack frame.
 using StackTraceMap =
-    absl::flat_hash_map<std::pair<std::string, int>, StackFrame>;
+    std::function<absl::optional<StackFrame>(std::pair<const char*, int>)>;
 
-// Contains filenames which should be skipped.
-using StackTraceFilter = absl::flat_hash_set<std::string>;
+// Returns "true" on filenames which should be skipped.
+using StackTraceFilter = std::function<bool(const char*)>;
 
 // A class for capturing Python stack trace.
 class StackTrace final {
@@ -74,8 +74,7 @@
       DCHECK(code_obj != nullptr);
 
       Py_INCREF(code_obj);
-      result.code_objs_.push_back(code_obj);
-      result.last_instructions_.push_back(frame->f_lasti);
+      result.code_objs_.push_back(std::make_pair(code_obj, frame->f_lasti));
     }
     return result;
   }
@@ -84,41 +83,38 @@
   ABSL_ATTRIBUTE_HOT
   ~StackTrace() { Clear(); }
 
-  StackTrace(StackTrace&& other) {
-    code_objs_ = std::move(other.code_objs_);
-    last_instructions_ = std::move(other.last_instructions_);
-    other.code_objs_ = {};
-  }
+  StackTrace(StackTrace&& other) { std::swap(code_objs_, other.code_objs_); }
 
   // Python GIL must be acquired beforehand.
   ABSL_ATTRIBUTE_HOT
   StackTrace& operator=(StackTrace&& other) {
     Clear();
     std::swap(code_objs_, other.code_objs_);
-    std::swap(last_instructions_, other.last_instructions_);
     return *this;
   }
 
   // Returns a structured representation of the captured stack trace.
   // `mapper` provides a custom mapping for translating stack frames, `filter`
-  // returns `true` for the stack frames which should be omitted, and if
-  // `drop_last` is set, the last stack frame is dropped.
-  std::vector<StackFrame> ToStackFrames(
-      const StackTraceMap& mapper = {},
-      const StackTraceFilter& filtered = {}) const;
+  // returns `true` for the stack frames which should be omitted.
+  //
+  // `reverse_traversal` changes the traversal order of the stack trace, and
+  // `limit` bounds the number of returned frames (after filtering).
+  std::vector<StackFrame> ToStackFrames(const StackTraceMap& mapper = {},
+                                        const StackTraceFilter& filtered = {},
+                                        bool reverse_traversal = false,
+                                        int limit = -1) const;
 
   // Python GIL must be acquired beforehand.
   ABSL_ATTRIBUTE_HOT
   void Clear() {
     if (!code_objs_.empty()) DCheckPyGilStateForStackTrace();
-    for (PyCodeObject* obj : code_objs_) Py_DECREF(obj);
+    for (const auto& p : code_objs_) Py_DECREF(p.first);
     code_objs_.clear();
-    last_instructions_.clear();
   }
 
  private:
-  absl::InlinedVector<PyCodeObject*, kStackTraceInitialSize> code_objs_;
-  absl::InlinedVector<int, kStackTraceInitialSize> last_instructions_;
+  absl::InlinedVector<std::pair<PyCodeObject*, int>, kStackTraceInitialSize>
+      code_objs_;
 
   StackTrace(const StackTrace&) = delete;
   StackTrace& operator=(const StackTrace&) = delete;
diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc
index d549896..bc25159 100644
--- a/tensorflow/python/util/tf_stack.cc
+++ b/tensorflow/python/util/tf_stack.cc
@@ -19,7 +19,7 @@
 // We store the retrieved stack trace within the Node object directly. Then
 // whenever the graph is instantiated/copies, we copy the stack trace with it.
 // Since the graph instantiation goes through the protobuf roundtrip, we store
-// the original Graph with stack traces attached in FunctionLibraryDefinition.
+// the original stack traces mapping attached in FunctionLibraryDefinition.
 
 #include <Python.h>
 #include <frameobject.h>
@@ -54,6 +54,30 @@
 
 namespace py = pybind11;
 
+using SourceLoc = std::tuple<std::string, int>;
+
+using SourceMap = absl::flat_hash_map<SourceLoc, StackFrame>;
+
+using StringSet = absl::flat_hash_set<std::string>;
+
+// Python wrapper for a SourceMap.
+class PyBindSourceMap {
+ public:
+  PyBindSourceMap() : source_map_(std::make_shared<SourceMap>()) {}
+
+  // Shares ownership with whoever captures traces in the scope of this map.
+  std::shared_ptr<SourceMap> source_map_;
+};
+
+// Python wrapper for a FileSet.
+class PyBindFileSet {
+ public:
+  PyBindFileSet() : file_set_(std::make_shared<StringSet>()) {}
+
+  // Shares ownership with whoever captures traces in the scope of this set.
+  std::shared_ptr<StringSet> file_set_;
+};
+
 // Returns contents of the line corresponding to the given frame.
 //
 // Precondition: must be holding Python GIL.
@@ -98,47 +122,68 @@
 
 class StackTraceWrapper : public AbstractStackTrace {
  public:
-  StackTraceWrapper(StackTrace&& captured, const py::dict& source_map,
-                    const py::set& filtered_filenames)
+  StackTraceWrapper(StackTrace&& captured,
+                    const std::shared_ptr<SourceMap>& source_map,
+                    const std::shared_ptr<StringSet>& filter)
       : captured_(std::move(captured)),
         source_map_(source_map),
-        filtered_filenames_(filtered_filenames) {}
+        filter_(filter) {}
 
   explicit StackTraceWrapper(absl::Span<StackFrame const> stack_frames)
       : stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(),
                                                     stack_frames.end())) {}
 
-  static StackTraceWrapper ExtractStack(const py::object& limit,
-                                        const py::list& mappers,
-                                        const py::list& filters) {
-    // In Python 3.X ``traceback.extract_stack`` allows ``limit`` to
-    // either be None or -1.
-    int casted_limit = limit.is_none() ? -1 : py::cast<ssize_t>(limit);
-
-    // Raise limit by one since we are dropping the last frame.
-    if (casted_limit != -1) casted_limit++;
-
-    const py::dict& source_map =
-        mappers.empty()
-            ? py::dict()
-            : mappers[mappers.size() - 1].attr("get_effective_source_map")();
-    const py::set& filtered_filenames =
-        filters.empty()
-            ? py::set()
-            : filters[filters.size() - 1].attr("get_filtered_filenames")();
-    return StackTraceWrapper{StackTrace::Capture(casted_limit), source_map,
-                             filtered_filenames};
+  static StackTraceWrapper ExtractStack(
+      const std::shared_ptr<SourceMap>& source_map,
+      const std::shared_ptr<StringSet>& filter) {
+    return StackTraceWrapper{StackTrace::Capture(-1), source_map, filter};
   }
 
   absl::Span<StackFrame const> ToFrames() const override {
-    GenerateCache();
+    if (stack_frames_cache_) {
+      return *stack_frames_cache_;
+    }
+
+    // Grabbing the GIL serves two purposes: 1) makes the class thread-safe,
+    // and 2) ToStackFrames and LineContents actually need it.
+    PyGILState_STATE state = PyGILState_Ensure();
+
+    stack_frames_cache_ = captured_.ToStackFrames(
+        [&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
+        [&](const char* f) { return StackTraceFiltering(f); });
+    stack_frames_cache_->pop_back();  // Drop last stack frame.
+    PyGILState_Release(state);
     return *stack_frames_cache_;
   }
 
+  StackFrame LastUserFrame() const override {
+    if (last_stack_frame_cache_) {
+      return *last_stack_frame_cache_;
+    }
+
+    PyGILState_STATE state = PyGILState_Ensure();
+    std::vector<StackFrame> last_frame = captured_.ToStackFrames(
+        [&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
+        [&](const char* file_name) {
+          return StackTraceFiltering(file_name) ||
+                 IsInternalFrameForFilename(file_name);
+        },
+        /*reverse_traversal=*/true,
+        /*limit=*/1);
+
+    if (last_frame.empty()) {
+      last_stack_frame_cache_ = StackFrame{"", -1, ""};
+    } else {
+      DCHECK_EQ(last_frame.size(), 1);
+      last_stack_frame_cache_ = last_frame[0];
+    }
+    PyGILState_Release(state);
+    return *last_stack_frame_cache_;
+  }
+
   std::string ToString(const TracePrintingOptions& opts) const override {
-    GenerateCache();
     std::vector<std::string> files_to_find_prefix;
-    for (const StackFrame& frame : *stack_frames_cache_) {
+    for (const StackFrame& frame : ToFrames()) {
       if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) {
         files_to_find_prefix.push_back(frame.file_name);
       }
@@ -147,45 +192,18 @@
         opts.filter_common_prefix
             ? io::CommonPathPrefix(files_to_find_prefix).size()
             : 0;
-    return absl::StrJoin(
-        *stack_frames_cache_, "\n",
-        [&](std::string* out, const StackFrame& frame) {
-          absl::StrAppend(out,
-                          StackFrameToString(frame, opts, shared_prefix_size));
-        });
-  }
 
-  bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
-
-  void GenerateCache() const {
-    // Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
-    // 2) ToStackFrames and LineContents actually need it.
-    PyGILState_STATE state = PyGILState_Ensure();
-    if (stack_frames_cache_) {
-      return;
+    if (!opts.drop_internal_frames) {
+      return ToStringHelper(*stack_frames_cache_, opts, shared_prefix_size);
     }
 
-    absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
-    absl::flat_hash_set<std::string> f;
-
-    for (const std::pair<py::handle, py::handle>& p : *source_map_) {
-      const py::tuple& key = py::cast<py::tuple>(p.first);
-      const py::tuple& value = py::cast<py::tuple>(p.second);
-
-      m.emplace(std::make_pair(std::string(py::cast<py::str>(key[0])),
-                               py::cast<ssize_t>(key[1])),
-                StackFrame{std::string(py::cast<py::str>(value[0])),
-                           py::cast<py::int_>(value[1]),
-                           std::string(py::cast<py::str>(value[2]))});
+    std::vector<StackFrame> filtered_frames;
+    for (const StackFrame& frame : *stack_frames_cache_) {
+      if (!IsInternalFrameForFilename(frame.file_name)) {
+        filtered_frames.push_back(frame);
+      }
     }
-
-    for (const py::handle& h : *filtered_filenames_) {
-      f.emplace(py::cast<py::str>(h));
-    }
-
-    stack_frames_cache_ = captured_.ToStackFrames(m, f);
-    stack_frames_cache_->pop_back();  // Drop last stack frame.
-    PyGILState_Release(state);
+    return ToStringHelper(filtered_frames, opts, shared_prefix_size);
   }
 
   StackTraceWrapper(StackTraceWrapper&&) = default;
@@ -193,21 +211,90 @@
     PyGILState_STATE state = PyGILState_Ensure();
     captured_.Clear();
     source_map_.reset();
-    filtered_filenames_.reset();
+    filter_.reset();
     PyGILState_Release(state);
   }
 
  private:
-  mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
+  static std::string ToStringHelper(absl::Span<StackFrame const> stack_frames,
+                                    const TracePrintingOptions& opts,
+                                    int shared_prefix_size) {
+    return absl::StrJoin(
+        stack_frames, "\n", [&](std::string* out, const StackFrame& frame) {
+          absl::StrAppend(out,
+                          StackFrameToString(frame, opts, shared_prefix_size));
+        });
+  }
+
+  static bool IsInternalFrameForFilename(absl::string_view file_name) {
+    // Use a simple heuristic for now.
+    // TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export.
+    return (absl::StrContains(file_name, "tensorflow/python") ||
+            absl::StrContains(file_name, "tensorflow\\python")) &&
+           !absl::StrContains(file_name, "keras") &&
+           !absl::StrContains(file_name, "test.py");
+  }
+
+  absl::optional<StackFrame> StackTraceMapping(SourceLoc loc) const {
+    if (source_map_->contains(loc)) {
+      return source_map_->at(loc);
+    }
+
+    return absl::nullopt;
+  }
+
+  bool StackTraceFiltering(const char* file_name) const {
+    return filter_->contains(file_name);
+  }
+
   StackTrace captured_;
+  std::shared_ptr<SourceMap> source_map_;
+  std::shared_ptr<StringSet> filter_;
+
   // Using optional to force destruction while we hold a GIL.
-  absl::optional<py::dict> source_map_;
-  absl::optional<py::set> filtered_filenames_;
+  mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
+  mutable absl::optional<StackFrame> last_stack_frame_cache_;
 };
 
 }  // namespace
 
 PYBIND11_MODULE(_tf_stack, m) {
+  py::class_<PyBindSourceMap>(m, "PyBindSourceMap")
+      .def(py::init())
+      .def("update_to",
+           [](const PyBindSourceMap& self, const py::tuple& source_map) {
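+             // Each `source_map` item is a pair of tuples:
+             // ((filename, lineno), (filename, lineno, function_name or None)),
+             // as produced by StackTraceMapper.get_effective_source_map().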
+             self.source_map_->clear();
+             for (const auto& item : source_map) {
+               const auto& tuple_item = py::cast<py::tuple>(item);
+
+               const auto& key = py::cast<py::tuple>(tuple_item[0]);
+               std::string&& k_filename = py::cast<std::string>(key[0]);
+               int k_lineno = py::cast<int>(key[1]);
+
+               const auto& value = py::cast<py::tuple>(tuple_item[1]);
+               std::string&& v_filename = py::cast<std::string>(value[0]);
+               int v_lineno = py::cast<int>(value[1]);
+               const auto& function_name_val = value[2];
+               std::string&& v_function_name =
+                   function_name_val.is_none()
+                       ? ""
+                       : py::cast<std::string>(function_name_val);
+
+               self.source_map_->emplace(
+                   SourceLoc(k_filename, k_lineno),
+                   StackFrame({v_filename, v_lineno, v_function_name}));
+             }
+           });
+
+  py::class_<PyBindFileSet>(m, "PyBindFileSet")
+      .def(py::init())
+      .def("update_to", [](const PyBindFileSet& self, const py::set& file_set) {
+        self.file_set_->clear();
+        for (const auto& item : file_set) {
+          self.file_set_->insert(py::cast<std::string>(item));
+        }
+      });
+
   py::class_<StackFrame>(m, "StackFrame")
       .def_property_readonly(
           "filename",
@@ -293,32 +380,33 @@
            })
       .def("__hash__",
            [](const StackTraceWrapper& self) {
-             self.GenerateCache();
              return py::hash(py::str(self.ToString({})));
            })
-      .def("__repr__", [](const StackTraceWrapper& self) {
-        self.GenerateCache();
-        return py::str(self.ToString({}));
-      });
+      .def("__repr__",
+           [](const StackTraceWrapper& self) {
+             return py::str(self.ToString({}));
+           })
+      .def("last_user_frame",
+           [](const StackTraceWrapper& self) { return self.LastUserFrame(); });
 
   m.def(
       "extract_stack_for_node",
-      [](const py::object& limit, const py::list& mappers,
-         const py::list& filters,
+      [](const PyBindSourceMap& source_map, const PyBindFileSet& file_set,
          TF_Operation* op) -> const AbstractStackTrace& {
         Node* node = reinterpret_cast<Node*>(op);
         DCHECK(!node->GetStackTrace()) << "Should not reset the stack trace";
-        node->SetStackTrace(std::make_shared<StackTraceWrapper>(
-            StackTraceWrapper::ExtractStack(limit, mappers, filters)));
+        node->SetStackTrace(
+            std::make_shared<StackTraceWrapper>(StackTraceWrapper::ExtractStack(
+                source_map.source_map_, file_set.file_set_)));
         return *node->GetStackTrace();
       },
       py::return_value_policy::reference);
 
   m.def(
       "extract_stack",
-      [](const py::object& limit, const py::list& mappers,
-         const py::list& filters) {
-        return StackTraceWrapper::ExtractStack(limit, mappers, filters);
+      [](const PyBindSourceMap& source_map, const PyBindFileSet& file_set) {
+        return StackTraceWrapper::ExtractStack(source_map.source_map_,
+                                               file_set.file_set_);
       },
       py::return_value_policy::move);
 }
diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py
index aad0a0f..a6d1bbe 100644
--- a/tensorflow/python/util/tf_stack.py
+++ b/tensorflow/python/util/tf_stack.py
@@ -25,7 +25,7 @@
 import six
 
 # TODO(b/138203821): change to from ...util import ... once the bug is fixed.
-from tensorflow.python import _tf_stack
+from tensorflow.python.util import _tf_stack
 
 # Generally such lookups should be done using `threading.local()`. See
 # https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed
@@ -40,8 +40,10 @@
   _get_thread_key = threading.get_ident
 
 
-_source_mapper_stacks = collections.defaultdict(list)
-_source_filter_stacks = collections.defaultdict(list)
+# TODO(mdan): Move these to C++ as well.
+# Moving to C++ can further avoid extra copies made by get_effective_source_map.
+_source_mapper_stacks = collections.defaultdict(lambda: [SentinelMapper()])
+_source_filter_stacks = collections.defaultdict(lambda: [SentinelFilter()])
 
 
 class StackTraceTransform(object):
@@ -51,8 +53,6 @@
   _thread_key = None
 
   def __enter__(self):
-    self.reset()
-
     # Any given instance is assumed to be used by a single thread, which reduces
     # expensive thread local lookups.
     if self._thread_key is None:
@@ -61,48 +61,71 @@
       assert self._thread_key == _get_thread_key(), 'Shared across threads?'
 
     stack = self._stack_dict[self._thread_key]
-    if stack:
-      self.parent = stack[-1]
-    else:
-      self.parent = None
+    self.parent = stack[-1]
     stack.append(self)
+    self.update()
     return self
 
   def __exit__(self, unused_type, unused_value, unused_traceback):
     top = self._stack_dict[self._thread_key].pop()
     assert top is self, 'Concurrent access?'
 
-  def reset(self):
-    pass
+  def update(self):
+    raise NotImplementedError('subclasses need to override this')
 
 
 class StackTraceMapper(StackTraceTransform):
   """Allows remapping traceback information to different source code."""
   _stack_dict = _source_mapper_stacks
 
-  def reset(self):
-    self._effective_source_map = None
+  def __init__(self):
+    self.internal_map = _tf_stack.PyBindSourceMap()
+
+  def update(self):
+    self.internal_map.update_to(tuple(self.get_effective_source_map().items()))
 
   def get_effective_source_map(self):
     """Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
     raise NotImplementedError('subclasses need to override this')
 
 
+EMPTY_DICT = {}
+
+
+class SentinelMapper(StackTraceMapper):
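+  """No-op mapper installed at the bottom of each per-thread mapper stack."""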
+
+  def get_effective_source_map(self):
+    return EMPTY_DICT
+
+
 class StackTraceFilter(StackTraceTransform):
   """Allows filtering traceback information by removing superfluous frames."""
   _stack_dict = _source_filter_stacks
 
-  def reset(self):
-    self._filtered_filenames = None
+  def __init__(self):
+    self.internal_set = _tf_stack.PyBindFileSet()
+
+  def update(self):
+    self.internal_set.update_to(set(self.get_filtered_filenames()))
 
   def get_filtered_filenames(self):
     raise NotImplementedError('subclasses need to override this')
 
 
+EMPTY_SET = frozenset()
+
+
+class SentinelFilter(StackTraceFilter):
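+  """No-op filter installed at the bottom of each per-thread filter stack."""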
+
+  def get_filtered_filenames(self):
+    return EMPTY_SET
+
+
 class CurrentModuleFilter(StackTraceFilter):
   """Filters stack frames from the module where this is used (best effort)."""
 
   def __init__(self):
+    super().__init__()
     filter_filename = None
     outer_f = None
     f = inspect.currentframe()
@@ -114,6 +137,9 @@
         if outer_f is not None:
           filter_filename = inspect.getsourcefile(outer_f)
       self._filename = filter_filename
+      # This may be called repeatedly: once on entry by the superclass, then by
+      # each child context manager.
+      self._cached_set = None
     finally:
       # Avoid reference cycles, see:
       # https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
@@ -121,58 +147,52 @@
       del outer_f
 
   def get_filtered_filenames(self):
-    if self._filtered_filenames is None:
-      self._filtered_filenames = frozenset((self._filename,))
-      if self.parent is not None:
-        self._filtered_filenames |= self.parent.get_filtered_filenames()
-    return self._filtered_filenames
+    if self._cached_set is not None:
+      return self._cached_set
+
+    filtered_filenames = frozenset((self._filename,))
+    if self.parent is not None:
+      filtered_filenames |= self.parent.get_filtered_filenames()
+    self._cached_set = filtered_filenames
+    return filtered_filenames
 
 
-def extract_stack(limit=-1):
-  """A lightweight, extensible re-implementation of traceback.extract_stack.
-
-  NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
-      each stack frame using linecache, which results in an abundance of stat()
-      calls. This implementation does not retrieve the code, and any consumer
-      should apply _convert_stack to the result to obtain a traceback that can
-      be formatted etc. using traceback methods.
-
-  Args:
-    limit: A limit on the number of frames to return.
+def extract_stack():
+  """An eager-friendly alternative to traceback.extract_stack.
 
   Returns:
-    An object wrapping the sequence of StackFrame objects (filename, lineno,
-    name, line) corresponding to the call stack of the current thread. The
-    returned object can be indexed as a Python list.
+    A list-like FrameSummary containing StackFrame-like objects, which are
+    namedtuple-like objects with the following fields: filename, lineno, name,
+    line, meant to masquerade as traceback.FrameSummary objects.
   """
   # N.B ExtractStack in tf_stack.cc will drop this frame prior to
   # traversing the stack.
   # TODO(cheshire): Remove this function, use extract_stack_for_node or Python
   # traceback module.
   thread_key = _get_thread_key()
-  return _tf_stack.extract_stack(limit, _source_mapper_stacks[thread_key],
-                                 _source_filter_stacks[thread_key])
+  return _tf_stack.extract_stack(
+      _source_mapper_stacks[thread_key][-1].internal_map,
+      _source_filter_stacks[thread_key][-1].internal_set)
 
 
-def extract_stack_for_node(node, limit=-1):
-  """Same as extract_stack, but also saves the retrieved stack in `node`.
+# TODO(mdan): Revisit these - a single location is almost always sufficient.
+def extract_stack_for_node(node):
+  """Attaches the current stack trace to `node`.
 
   Args:
-    node: Pointer to the Node object.
-    limit: A limit on the number of frames to return.
+    node: a Node object.
 
   Returns:
-    An object wrapping the sequence of StackFrame objects (filename, lineno,
-    name, line) corresponding to the call stack of the current thread. The
-    returned object can be indexed as a Python list.
+    A list-like FrameSummary containing StackFrame-like objects, which are
+    namedtuple-like objects with the following fields: filename, lineno, name,
+    line, meant to masquerade as traceback.FrameSummary objects.
   """
   # N.B ExtractStack in tf_stack.cc will drop this frame prior to
   # traversing the stack.
   thread_key = _get_thread_key()
-  return _tf_stack.extract_stack_for_node(limit,
-                                          _source_mapper_stacks[thread_key],
-                                          _source_filter_stacks[thread_key],
-                                          node)
+  return _tf_stack.extract_stack_for_node(
+      _source_mapper_stacks[thread_key][-1].internal_map,
+      _source_filter_stacks[thread_key][-1].internal_set, node)
 
 
 StackSummary = _tf_stack.StackTraceWrapper
diff --git a/tensorflow/python/util/tf_stack_test.py b/tensorflow/python/util/tf_stack_test.py
index 07dc2d3..c704f7d 100644
--- a/tensorflow/python/util/tf_stack_test.py
+++ b/tensorflow/python/util/tf_stack_test.py
@@ -26,31 +26,19 @@
 
 class TFStackTest(test.TestCase):
 
-  def testLimit(self):
-    self.assertEmpty(tf_stack.extract_stack(limit=0))
-    self.assertLen(tf_stack.extract_stack(limit=1), 1)
+  def testFormatStackSelfConsistency(self):
+    # Both defined on the same line to produce identical stacks.
+    stacks = tf_stack.extract_stack(), traceback.extract_stack()
     self.assertEqual(
-        len(tf_stack.extract_stack(limit=-1)),
-        len(tf_stack.extract_stack()))
-
-  def testConsistencyWithTraceback(self):
-    stack, expected_stack = extract_stack()
-    for frame, expected in zip(stack, expected_stack):
-      self.assertEqual(convert_stack_frame(frame), expected)
-
-  def testFormatStack(self):
-    stack, expected_stack = extract_stack()
-    self.assertEqual(
-        traceback.format_list(stack),
-        traceback.format_list(expected_stack))
+        traceback.format_list(stacks[0]), traceback.format_list(stacks[1]))
 
   def testFrameSummaryEquality(self):
-    frame0, frame1 = tf_stack.extract_stack(limit=2)
-    self.assertNotEqual(frame0, frame1)
-    self.assertEqual(frame0, frame0)
+    frames1 = tf_stack.extract_stack()
+    frames2 = tf_stack.extract_stack()
 
-    another_frame0, _ = tf_stack.extract_stack(limit=2)
-    self.assertEqual(frame0, another_frame0)
+    self.assertNotEqual(frames1[0], frames1[1])
+    self.assertEqual(frames1[0], frames1[0])
+    self.assertEqual(frames1[0], frames2[0])
 
   def testFrameSummaryEqualityAndHash(self):
     # Both defined on the same line to produce identical stacks.
@@ -63,23 +51,16 @@
     self.assertEqual(frame1, frame2)
     self.assertEqual(hash(tuple(frame1)), hash(tuple(frame2)))
 
+  def testLastUserFrame(self):
+    trace = tf_stack.extract_stack()  # COMMENT
+    frame = trace.last_user_frame()
+    self.assertRegex(frame.line, "# COMMENT")
+
 
 def extract_stack(limit=None):
   # Both defined on the same line to produce identical stacks.
   return tf_stack.extract_stack(limit), traceback.extract_stack(limit)
 
 
-def convert_stack_frame(frame):
-  """Converts a TF stack frame into Python's."""
-  # TODO(mihaimaruseac): Remove except case when dropping suport for py2
-  try:
-    return traceback.FrameSummary(
-        frame.filename, frame.lineno, frame.name, line=frame.line)
-  except AttributeError:
-    # On Python < 3.5 (i.e., Python2), we don't have traceback.FrameSummary so
-    # we don't need to match with that class. Instead, just a tuple is enough.
-    return tuple(frame)
-
-
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/security/README.md b/tensorflow/security/README.md
index 27e24b6..32e06cb 100644
--- a/tensorflow/security/README.md
+++ b/tensorflow/security/README.md
@@ -10,10 +10,16 @@
 
 | Advisory Number | Type               | Versions affected | Reported by           | Additional Information      |
 |-----------------|--------------------|:-----------------:|-----------------------|-----------------------------|
-| [TFSA-2020-028](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-028.md)   | Float cast overflow undefined behavior               | <= 2.3 | (Reported on GitHub) | [issue report](https://github.com/tensorflow/tensorflow/issues/42129) |
-| [TFSA-2020-027](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-027.md)   | Segfault in `tf.quantization.quantize_and_dequantize`| <= 2.3 | (Reported on GitHub) | [issue report](https://github.com/tensorflow/tensorflow/issues/42105) |
-| [TFSA-2020-026](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-026.md)   | Segfault in `tf.raw_ops.Switch` in eager mode                                             | 2.2.0, 2.3.0        | Aivul Team from Qihoo 360                     |  |
-| [TFSA-2020-025](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-025.md)   | Undefined behavior in `dlpack.to_dlpack`                                                  | 2.2.0, 2.3.0        | Aivul Team from Qihoo 360                     |  |
+| [TFSA-2020-034](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-034.md)   | Heap out of bounds access in MakeEdge                                              | >= 1.15.0, <= 2.3.0 | (discovered internally)                       |  |
+| [TFSA-2020-033](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-033.md)   | CHECK-fail in LSTM with zero-length input                                          | >= 1.15.0, <= 2.3.0 | (discovered internally)                       |  |
+| [TFSA-2020-032](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-032.md)   | Heap out of bounds read in filesystem glob matching                                | 2.4.0-rc{0,1,2,3}   | Aivul Team from Qihoo 360                     |  |
+| [TFSA-2020-031](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-031.md)   | Write to immutable memory region                                                   | >= 1.15.0, <= 2.3.0 | Aivul Team from Qihoo 360                     |  |
+| [TFSA-2020-030](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-030.md)   | Lack of validation in data format attributes                                       | >= 1.15.0, <= 2.3.0 | Aivul Team from Qihoo 360                     |  |
+| [TFSA-2020-029](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-029.md)   | Uninitialized memory access in Eigen types                                         | >= 1.15.0, <= 2.3.0 | (discovered internally)                       |  |
+| [TFSA-2020-028](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-028.md)   | Float cast overflow undefined behavior                                             | <= 2.3 | (Reported on GitHub) | [issue report](https://github.com/tensorflow/tensorflow/issues/42129) |
+| [TFSA-2020-027](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-027.md)   | Segfault in `tf.quantization.quantize_and_dequantize`                              | <= 2.3 | (Reported on GitHub) | [issue report](https://github.com/tensorflow/tensorflow/issues/42105) |
+| [TFSA-2020-026](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-026.md)   | Segfault in `tf.raw_ops.Switch` in eager mode                                      | 2.2.0, 2.3.0        | Aivul Team from Qihoo 360                     |  |
+| [TFSA-2020-025](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-025.md)   | Undefined behavior in `dlpack.to_dlpack`                                           | 2.2.0, 2.3.0        | Aivul Team from Qihoo 360                     |  |
 | [TFSA-2020-024](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-024.md)   | Memory leak in `dlpack.to_dlpack`                                                  | 2.2.0, 2.3.0        | Aivul Team from Qihoo 360                     |  |
 | [TFSA-2020-023](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-023.md)   | Memory corruption in `dlpack.to_dlpack`                                            | 2.2.0, 2.3.0        | Aivul Team from Qihoo 360                     |  |
 | [TFSA-2020-022](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-022.md)   | Crash due to invalid shape of `grad_values` in SparseFillEmptyRowsGrad             | >= 1.15.0, <= 2.3.0 | (variant analysis, Aivul Team from Qihoo 360) |  |
diff --git a/tensorflow/security/advisory/tfsa-2020-029.md b/tensorflow/security/advisory/tfsa-2020-029.md
new file mode 100644
index 0000000..e145969
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2020-029.md
@@ -0,0 +1,53 @@
+## TFSA-2020-029: Uninitialized memory access in Eigen types
+
+### CVE Number
+CVE-2020-26266
+
+### Impact
+In certain cases, a saved model can trigger use of uninitialized values
+during code execution. This is caused by having tensor buffers be filled with
+the default value of the type but forgetting to [default initialize the
+quantized floating point types in
+Eigen](https://github.com/tensorflow/tensorflow/blob/f70160322a579144950dff1537dcbe3c7c09d6f5/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h#L61-L104):
+
+```cc
+struct QUInt8 {
+  QUInt8() {}
+  // ...
+  uint8_t value;
+};
+
+struct QInt16 {
+  QInt16() {}
+  // ...
+  int16_t value;
+};
+
+struct QUInt16 {
+  QUInt16() {}
+  // ...
+  uint16_t value;
+};
+
+struct QInt32 {
+  QInt32() {}
+  // ...
+  int32_t value;
+};
+```
+
+### Patches
+
+We have patched the issue in GitHub commit
+[ace0c15a22f7f054abcc1f53eabbcb0a1239a9e2](https://github.com/tensorflow/tensorflow/commit/ace0c15a22f7f054abcc1f53eabbcb0a1239a9e2)
+and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly
+packages after this commit will also have the issue resolved.
+
+Since this issue also impacts TF versions before 2.4, we will patch all releases
+between 1.15 and 2.3 inclusive.
+
+### For more information
+Please consult [our security
+guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for
+more information regarding the security model and how to contact us with issues
+and questions.
diff --git a/tensorflow/security/advisory/tfsa-2020-030.md b/tensorflow/security/advisory/tfsa-2020-030.md
new file mode 100644
index 0000000..5c8f1d7
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2020-030.md
@@ -0,0 +1,89 @@
+## TFSA-2020-030: Lack of validation in data format attributes
+
+### CVE Number
+CVE-2020-26267
+
+### Impact
+The `tf.raw_ops.DataFormatVecPermute` API does not validate the `src_format` and
+`dst_format` attributes. [The
+code](https://github.com/tensorflow/tensorflow/blob/304b96815324e6a73d046df10df6626d63ac12ad/tensorflow/core/kernels/data_format_ops.cc)
+assumes that these two arguments define a permutation of `NHWC`.
+
+However, these assumptions are not checked and this can result in uninitialized
+memory accesses, read outside of bounds and even crashes.
+
+```python
+>>> import tensorflow as tf
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,4], src_format='1234', dst_format='1234')
+<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 757100143], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,4], src_format='HHHH', dst_format='WWWW')
+<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 32701], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,4], src_format='H', dst_format='W')
+<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 32701], dtype=int32)>
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,2,3,4], 
+                                    src_format='1234', dst_format='1253')
+<tf.Tensor: shape=(4,), dtype=int32, numpy=array([4, 2, 939037184, 3], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,2,3,4],
+                                    src_format='1234', dst_format='1223')
+<tf.Tensor: shape=(4,), dtype=int32, numpy=array([4, 32701, 2, 3], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,2,3,4],
+                                    src_format='1224', dst_format='1423')
+<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 4, 3, 32701], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,2,3,4], src_format='1234', dst_format='432')
+<tf.Tensor: shape=(4,), dtype=int32, numpy=array([4, 3, 2, 32701], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[1,2,3,4],
+                                    src_format='12345678', dst_format='87654321')
+munmap_chunk(): invalid pointer
+Aborted
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[[1,5],[2,6],[3,7],[4,8]],           
+                                    src_format='12345678', dst_format='87654321')
+<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
+array([[71364624,        0],
+       [71365824,        0],
+       [     560,        0],
+       [      48,        0]], dtype=int32)>
+...
+>>> tf.raw_ops.DataFormatVecPermute(x=[[1,5],[2,6],[3,7],[4,8]], 
+                                    src_format='12345678', dst_format='87654321')
+free(): invalid next size (fast)
+Aborted
+```
+
+A similar issue occurs in `tf.raw_ops.DataFormatDimMap`, for the same reasons:
+
+```python
+>>> tf.raw_ops.DataFormatDimMap(x=[[1,5],[2,6],[3,7],[4,8]], src_format='1234',
+...                             dst_format='8765')
+<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
+array([[1954047348, 1954047348],
+       [1852793646, 1852793646],
+       [1954047348, 1954047348],
+       [1852793632, 1852793632]], dtype=int32)>
+```
+
+### Patches
+
+We have patched the issue in GitHub commit
+[ebc70b7a592420d3d2f359e4b1694c236b82c7ae](https://github.com/tensorflow/tensorflow/commit/ebc70b7a592420d3d2f359e4b1694c236b82c7ae)
+and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly
+packages after this commit will also have the issue resolved.
+
+Since this issue also impacts TF versions before 2.4, we will patch all releases
+between 1.15 and 2.3 inclusive.
+
+### For more information
+Please consult [our security
+guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for
+more information regarding the security model and how to contact us with issues
+and questions.
+
+### Attribution
+This vulnerability has been reported by members of the Aivul Team from Qihoo
+360.
diff --git a/tensorflow/security/advisory/tfsa-2020-031.md b/tensorflow/security/advisory/tfsa-2020-031.md
new file mode 100644
index 0000000..24bb891
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2020-031.md
@@ -0,0 +1,47 @@
+## TFSA-2020-031: Write to immutable memory region
+
+### CVE Number
+CVE-2020-26268
+
+### Impact
+The `tf.raw_ops.ImmutableConst` operation returns a constant tensor created from
+a memory mapped file which is assumed immutable. However, if the type of the
+tensor is not an integral type, the operation crashes the Python interpreter as
+it tries to write to the memory area:
+
+```python
+>>> import tensorflow as tf
+>>> with open('/tmp/test.txt','w') as f: f.write('a'*128)
+>>> tf.raw_ops.ImmutableConst(dtype=tf.string,shape=2,
+                              memory_region_name='/tmp/test.txt')
+```
+
+If the file is too small, TensorFlow properly returns an error as the memory
+area has fewer bytes than what is needed for the tensor it creates. However, as
+soon as there are enough bytes, the above snippet causes a segmentation fault.
+
+This is because the allocator used to return the buffer data is not marked as
+returning an opaque handle since the [needed virtual
+method](https://github.com/tensorflow/tensorflow/blob/c1e1fc899ad5f8c725dcbb6470069890b5060bc7/tensorflow/core/framework/typed_allocator.h#L78-L85)
+is [not
+overridden](https://github.com/tensorflow/tensorflow/blob/acdf3c04fcfa767ae8d109b9e1f727ef050dba4d/tensorflow/core/kernels/immutable_constant_op.cc).
+
+### Patches
+
+We have patched the issue in GitHub commit
+[c1e1fc899ad5f8c725dcbb6470069890b5060bc7](https://github.com/tensorflow/tensorflow/commit/c1e1fc899ad5f8c725dcbb6470069890b5060bc7)
+and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly
+packages after this commit will also have the issue resolved.
+
+Since this issue also impacts TF versions before 2.4, we will patch all releases
+between 1.15 and 2.3 inclusive.
+
+### For more information
+Please consult [our security
+guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for
+more information regarding the security model and how to contact us with issues
+and questions.
+
+### Attribution
+This vulnerability has been reported by members of the Aivul Team from Qihoo
+360.
diff --git a/tensorflow/security/advisory/tfsa-2020-032.md b/tensorflow/security/advisory/tfsa-2020-032.md
new file mode 100644
index 0000000..93bbaeb
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2020-032.md
@@ -0,0 +1,51 @@
+## TFSA-2020-032: Heap out of bounds read in filesystem glob matching
+
+### CVE Number
+CVE-2020-26269
+
+### Impact
+The general implementation for matching filesystem paths to a globbing pattern
+is vulnerable to an out of bounds access on [the array holding the
+directories](https://github.com/tensorflow/tensorflow/blob/458c6260265c46ebaf18052d6c61aea4b6b40926/tensorflow/core/platform/file_system_helper.cc#L127):
+
+```cc
+if (!fs->Match(child_path, dirs[dir_index])) { ... }
+```
+
+Since `dir_index` is [unconditionally
+incremented](https://github.com/tensorflow/tensorflow/blob/458c6260265c46ebaf18052d6c61aea4b6b40926/tensorflow/core/platform/file_system_helper.cc#L106)
+outside of the lambda function where the vulnerable pattern occurs, this results
+in an access out of bounds issue under certain scenarios. For example, if
+`/tmp/x` is a directory that only contains a single file `y`, then the following
+snippet will cause a crash due to the out of bounds read:
+
+```python
+>>> tf.io.gfile.glob('/tmp/x/')
+Segmentation fault
+```
+
+There are multiple invariants and preconditions that are assumed by the parallel
+implementation of `GetMatchingPaths` but are not verified by the PRs introducing
+it ([#40861](https://github.com/tensorflow/tensorflow/pull/40861) and
+[#44310](https://github.com/tensorflow/tensorflow/pull/44310)). Thus, we are
+completely rewriting the implementation to fully specify and validate these.
+
+### Patches
+
+We have patched the issue in GitHub commit
+[8b5b9dc96666a3a5d27fad7179ff215e3b74b67c](https://github.com/tensorflow/tensorflow/commit/8b5b9dc96666a3a5d27fad7179ff215e3b74b67c)
+and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly
+packages after this commit will also have the issue resolved.
+
+This issue only impacts the master branch and the release candidates for TF
+version 2.4. The final 2.4 release will be patched.
+
+### For more information
+Please consult [our security
+guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for
+more information regarding the security model and how to contact us with issues
+and questions.
+
+### Attribution
+This vulnerability has been reported by members of the Aivul Team from Qihoo
+360.
diff --git a/tensorflow/security/advisory/tfsa-2020-033.md b/tensorflow/security/advisory/tfsa-2020-033.md
new file mode 100644
index 0000000..e5537c5
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2020-033.md
@@ -0,0 +1,27 @@
+## TFSA-2020-033: CHECK-fail in LSTM with zero-length input
+
+### CVE Number
+CVE-2020-26270
+
+### Impact
+Running an LSTM/GRU model where the LSTM/GRU layer receives a zero-length
+input results in a `CHECK` failure when using the CUDA backend.
+
+This can result in a query-of-death vulnerability, via denial of service, if
+users can control the input to the layer.
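+
+For illustration only, here is a minimal sketch of the kind of zero-length
+input described above. The specific layer and input shape are assumptions, not
+taken from the report, and a CUDA-enabled build is required to reach the
+affected code path:
+
+```python
+import tensorflow as tf
+
+# Hypothetical reproduction sketch: a recurrent layer fed a batch whose time
+# dimension is zero. Exact trigger conditions may differ.
+layer = tf.keras.layers.LSTM(4)
+x = tf.zeros([1, 0, 8])  # batch of 1, zero timesteps, 8 features
+layer(x)
+```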
+
+### Patches
+
+We have patched the issue in GitHub commit
+[14755416e364f17fb1870882fa778c7fec7f16e3](https://github.com/tensorflow/tensorflow/commit/14755416e364f17fb1870882fa778c7fec7f16e3)
+and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly
+packages after this commit will also have the issue resolved.
+
+Since this issue also impacts TF versions before 2.4, we will patch all releases
+between 1.15 and 2.3 inclusive.
+
+### For more information
+Please consult [our security
+guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for
+more information regarding the security model and how to contact us with issues
+and questions.
diff --git a/tensorflow/security/advisory/tfsa-2020-034.md b/tensorflow/security/advisory/tfsa-2020-034.md
new file mode 100644
index 0000000..aa8f456
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2020-034.md
@@ -0,0 +1,44 @@
+## TFSA-2020-034: Heap out of bounds access in MakeEdge
+
+### CVE Number
+CVE-2020-26271
+
+### Impact
+In certain cases, loading a saved model can result in accessing uninitialized
+memory while building the computation graph. The [`MakeEdge`
+function](https://github.com/tensorflow/tensorflow/blob/3616708cb866365301d8e67b43b32b46d94b08a0/tensorflow/core/common_runtime/graph_constructor.cc#L1426-L1438)
+creates an edge between one output tensor of the `src` node (given by
+`output_index`) and the input slot of the `dst` node (given by `input_index`).
+This is only possible if the types of the tensors on both sides coincide, so the
+function begins by obtaining the corresponding `DataType` values and comparing
+these for equality:
+
+```cc
+  DataType src_out = src->output_type(output_index);
+  DataType dst_in = dst->input_type(input_index);
+  //...
+```
+
+However, there is no check that the indices point inside the arrays they
+index into. Thus, this can result in accessing data out of bounds of the
+corresponding heap allocated arrays.
+
+In most scenarios, this can manifest as unitialized data access, but if the
+index points far away from the boundaries of the arrays this can be used to leak
+addresses from the library.
+
+### Patches
+
+We have patched the issue in GitHub commit
+[0cc38aaa4064fd9e79101994ce9872c6d91f816b](https://github.com/tensorflow/tensorflow/commit/0cc38aaa4064fd9e79101994ce9872c6d91f816b)
+and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly
+packages after this commit will also have the issue resolved.
+
+Since this issue also impacts TF versions before 2.4, we will patch all releases
+between 1.15 and 2.3 inclusive.
+
+### For more information
+Please consult [our security
+guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for
+more information regarding the security model and how to contact us with issues
+and questions.
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index 0ee227d..839950b 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -599,6 +599,16 @@
 )
 
 cc_library(
+    name = "cuda_asm_compiler",
+    srcs = if_cuda_is_configured(["cuda_asm_compiler.cc"]),
+    deps = if_cuda_is_configured([
+        "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/stream_executor/gpu:asm_compiler",
+        "//tensorflow/stream_executor/gpu:gpu_driver_header",
+    ]),
+)
+
+cc_library(
     name = "cuda_gpu_executor",
     srcs = if_cuda_is_configured(["cuda_gpu_executor.cc"]),
     hdrs = if_cuda_is_configured(["cuda_gpu_executor.h"]),
@@ -611,6 +621,7 @@
         ":cuda_platform_id",
         ":cuda_stream",
         ":cuda_timer",
+        ":cuda_asm_compiler",
         "@com_google_absl//absl/strings",
         "//tensorflow/stream_executor:event",
         "//tensorflow/stream_executor:plugin_registry",
diff --git a/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
new file mode 100644
index 0000000..f92d3c4
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#include "tensorflow/stream_executor/gpu/gpu_driver.h"
+
+namespace stream_executor {
+
+#define RETURN_IF_CUDA_ERROR(expr)                                            \
+  do {                                                                        \
+    CUresult _status = expr;                                                  \
+    if (!SE_PREDICT_TRUE(_status == CUDA_SUCCESS)) {                          \
+      const char* error_string;                                               \
+      cuGetErrorString(_status, &error_string);                               \
+      std::ostringstream oss;                                                 \
+      oss << error_string << "\nin " << __FILE__ << "(" << __LINE__ << "): '" \
+          << #expr << "'";                                                    \
+      return port::Status(port::error::UNKNOWN, oss.str().c_str());           \
+    }                                                                         \
+  } while (false)
+
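+// Links the given CUBIN images into a single CUBIN using the CUDA driver's
+// JIT linker (cuLinkCreate / cuLinkAddData / cuLinkComplete).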
+port::StatusOr<std::vector<uint8>> LinkGpuAsm(
+    gpu::GpuContext* context, std::vector<CubinOrPTXImage> images) {
+  gpu::ScopedActivateContext activation(context);
+
+  CUlinkState link_state;
+  RETURN_IF_CUDA_ERROR(cuLinkCreate(0, nullptr, nullptr, &link_state));
+  for (auto& image : images) {
+    RETURN_IF_CUDA_ERROR(cuLinkAddData(
+        link_state, CU_JIT_INPUT_CUBIN, static_cast<void*>(image.bytes.data()),
+        image.bytes.size(), "", 0, nullptr, nullptr));
+  }
+  void* cubin_out;
+  size_t cubin_size;
+  RETURN_IF_CUDA_ERROR(cuLinkComplete(link_state, &cubin_out, &cubin_size));
+  std::vector<uint8> cubin(static_cast<uint8*>(cubin_out),
+                           static_cast<uint8*>(cubin_out) + cubin_size);
+  RETURN_IF_CUDA_ERROR(cuLinkDestroy(link_state));
+  return std::move(cubin);
+}
+
+}  // namespace stream_executor
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
index e2923ba..c16f269 100644
--- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
@@ -295,7 +295,7 @@
 
   std::string version_and_rest = driver_version_file_contents.substr(
       offset + strlen(kDriverFilePrelude), std::string::npos);
-  size_t space_index = version_and_rest.find(" ");
+  size_t space_index = version_and_rest.find(' ');
   auto kernel_version = version_and_rest.substr(0, space_index);
   // TODO(b/22689637): Eliminate the explicit namespace if possible.
   auto stripped_kernel_version = absl::StripSuffix(kernel_version, ".ld64");
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index c03eb0a..e4e9914 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1468,7 +1468,9 @@
   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
       cudnnDataType_t data_type) {
-    CHECK_GT(max_seq_length, 0);
+    if (max_seq_length <= 0) {
+      return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
+    }
     int dims[] = {batch_size, data_size, 1};
     int strides[] = {dims[1] * dims[2], dims[2], 1};
     TensorDescriptor tensor_desc = CreateTensorDescriptor();
@@ -1486,7 +1488,9 @@
       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
       const absl::Span<const int>& seq_lengths, bool time_major,
       cudnnDataType_t data_type) {
-    CHECK_GT(max_seq_length, 0);
+    if (max_seq_length <= 0) {
+      return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
+    }
     int dims[] = {batch_size, data_size, 1};
     int strides[] = {dims[1] * dims[2], dims[2], 1};
     TensorDescriptor tensor_desc = CreateTensorDescriptor();
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 67fd72d..42db563 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -890,6 +890,137 @@
   return true;
 }
 
+#if CUDA_VERSION >= 10020
+/* static */ port::StatusOr<GpuDriver::VmemSpan>
+GpuDriver::ReserveVirtualMemory(GpuContext* context, uint64 bytes) {
+  ScopedActivateContext activation(context);
+  CUdeviceptr base;
+  CUresult res = cuMemAddressReserve(&base, bytes, /*alignment=*/0,
+                                     /*addr=*/0, /*flags=*/0);
+  if (res != CUDA_SUCCESS) {
+    return port::InternalError(
+        absl::StrFormat("error reserving %d bytes of virtual GPU memory: %s",
+                        bytes, ToString(res)));
+  }
+  return {{base, bytes}};
+}
+
+/* static */ void GpuDriver::FreeVirtualMemory(
+    GpuContext* context, GpuDriver::VmemSpan reservation) {
+  ScopedActivateContext activation(context);
+  CUresult res = cuMemAddressFree(reservation.base, reservation.size_bytes);
+  if (res != CUDA_SUCCESS) {
+    LOG(ERROR) << "error freeing vmem reservation of size "
+               << reservation.size_bytes << " at address " << reservation.base;
+  }
+}
+
+/* static */ port::StatusOr<uint64> GpuDriver::GetMinAllocationGranularity(
+    int device_ordinal) {
+  CUmemAllocationProp props = {};
+  props.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+  props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  props.location.id = device_ordinal;
+
+  size_t granularity;
+  CUresult res = cuMemGetAllocationGranularity(
+      &granularity, &props, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+  if (res != CUDA_SUCCESS) {
+    return port::InternalError(absl::StrCat(
+        "failed to get min allocation granularity: ", ToString(res)));
+  }
+  return granularity;
+}
+
+/* static */ port::StatusOr<GpuDriver::GenericMemoryHandle>
+GpuDriver::CreateMemoryHandle(GpuContext* context, uint64 bytes) {
+  ScopedActivateContext activation(context);
+  auto device = DeviceFromContext(context);
+  if (!device.ok()) {
+    LOG(ERROR) << "Failed to get device from context" << device.status();
+    return device.status();
+  }
+
+  CUmemAllocationProp props = {};
+  props.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+  props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  props.location.id = device.ValueOrDie();
+
+  CUmemGenericAllocationHandle mem_handle;
+  CUresult res = cuMemCreate(&mem_handle, bytes, &props, 0);
+  if (res != CUDA_SUCCESS) {
+    return port::InternalError(
+        absl::StrFormat("failed to create memory allocation of size %d: %s",
+                        bytes, ToString(res)));
+  }
+  return GpuDriver::GenericMemoryHandle{mem_handle, bytes};
+}
+
+/* static */ void GpuDriver::ReleaseMemoryHandle(
+    GpuContext* context, GpuDriver::GenericMemoryHandle handle) {
+  ScopedActivateContext activation(context);
+
+  CUresult res = cuMemRelease(handle.handle);
+  if (res != CUDA_SUCCESS) {
+    LOG(ERROR) << "Failed to release memory handle " << handle.handle
+               << " of size " << handle.bytes << ": " << ToString(res);
+  }
+}
+
+/* static */ port::Status GpuDriver::MapMemory(
+    GpuContext* context, CUdeviceptr va,
+    const GpuDriver::GenericMemoryHandle& handle,
+    const std::vector<int>& device_ordinals) {
+  ScopedActivateContext activation(context);
+
+  auto device = DeviceFromContext(context);
+  if (!device.ok()) {
+    return device.status();
+  }
+
+  // NB: Zero is the only valid value for both flags and offset.
+  CUresult res =
+      cuMemMap(va, handle.bytes, /*offset=*/0, handle.handle, /*flags=*/0);
+  if (res != CUDA_SUCCESS) {
+    return port::InternalError(absl::StrFormat(
+        "Failed to map %d bytes at %d: %s", handle.bytes, va, ToString(res)));
+  }
+
+  std::vector<CUmemAccessDesc> access_descriptors(device_ordinals.size());
+  for (int i = 0; i < access_descriptors.size(); ++i) {
+    access_descriptors[i].location.id = device_ordinals[i];
+    access_descriptors[i].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+    access_descriptors[i].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+  }
+
+  res = cuMemSetAccess(va, handle.bytes, access_descriptors.data(),
+                       access_descriptors.size());
+  if (res != CUDA_SUCCESS) {
+    // Unmap the memory that we failed to set access for.
+    if (cuMemUnmap(va, handle.bytes) != CUDA_SUCCESS) {
+      LOG(ERROR)
+          << "Failed to unmap memory in GpuDriver::MapMemory error path.";
+    }
+    return port::InternalError(absl::StrFormat(
+        "Failed to set read/write access on memory mapped at %d: %s", va,
+        ToString(res)));
+  }
+  return port::Status::OK();
+}
+
+/* static */ void GpuDriver::UnmapMemory(GpuContext* context, CUdeviceptr va,
+                                         uint64 bytes) {
+  ScopedActivateContext activation(context);
+
+  CUresult res = cuMemUnmap(va, bytes);
+  if (res != CUDA_SUCCESS) {
+    LOG(ERROR) << "Failed to unmap memory at " << va << " of size " << bytes
+               << ": " << ToString(res);
+  }
+}
+
+#endif
+
 /* static */ port::Status GpuDriver::DestroyEvent(GpuContext* context,
                                                   CUevent* event) {
   if (*event == nullptr) {
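
Note: CreateMemoryHandle requires bytes to be a multiple of the granularity reported by GetMinAllocationGranularity, so callers typically round their request up first; a minimal sketch under CUDA_VERSION >= 10020, where the helper name RoundUpToGranularity is illustrative and not part of this change:

#include "tensorflow/stream_executor/gpu/gpu_driver.h"

namespace stream_executor {
namespace gpu {

// Hypothetical helper: round a requested size up to the next multiple of the
// device's minimum allocation granularity before calling CreateMemoryHandle.
static port::StatusOr<uint64> RoundUpToGranularity(int device_ordinal,
                                                   uint64 requested_bytes) {
  auto granularity = GpuDriver::GetMinAllocationGranularity(device_ordinal);
  if (!granularity.ok()) {
    return granularity.status();
  }
  const uint64 g = granularity.ValueOrDie();
  // Integer round-up: ((x + g - 1) / g) * g.
  return ((requested_bytes + g - 1) / g) * g;
}

}  // namespace gpu
}  // namespace stream_executor
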
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc
index 21beeb0..d081557 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.cc
+++ b/tensorflow/stream_executor/gpu/asm_compiler.cc
@@ -108,23 +108,32 @@
 port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
     int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
   using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
+  using PtxCompilerResult = port::StatusOr<std::vector<uint8>>;
   static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
   static auto& ptx_cache TF_GUARDED_BY(ptx_cache_mutex) =
-      *new absl::flat_hash_map<PtxCacheKey, std::vector<uint8>>();
+      *new absl::flat_hash_map<PtxCacheKey, PtxCompilerResult>();
 
   tensorflow::mutex_lock lock(ptx_cache_mutex);
   PtxCacheKey cache_key{device_ordinal, std::string(ptx),
                         compilation_options.ToTuple()};
   auto it = ptx_cache.find(cache_key);
   if (it == ptx_cache.end()) {
-    TF_ASSIGN_OR_RETURN(
-        std::vector<uint8> compiled,
-        CompileGpuAsm(device_ordinal, ptx, compilation_options));
+    PtxCompilerResult compiled =
+        CompileGpuAsm(device_ordinal, ptx, compilation_options);
     it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
   }
 
   CHECK(it != ptx_cache.end());
-  const std::vector<uint8>& compiled = it->second;
+
+  // Failed compilation attempts are also cached.
+  // Check the status and call ValueOrDie() on the ptx_cache entry separately,
+  // to avoid moving the cached value out, as TF_ASSIGN_OR_RETURN would do.
+
+  if (TF_PREDICT_FALSE(!it->second.ok())) {
+    return it->second.status();
+  }
+
+  const std::vector<uint8>& compiled = it->second.ValueOrDie();
   return absl::MakeSpan(compiled);
 }
 
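
Note: the pattern above, caching a StatusOr so that failed compilations are memoized and the value is read without being moved out, shown in isolation; a toy sketch using std::string keys and uint8_t bytes, which omits the mutex the real cache holds:

#include <cstdint>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace stream_executor {

// Toy memoization of a fallible compile step: both successes and failures are
// cached, so the expensive callback runs at most once per key.
port::StatusOr<std::vector<uint8_t>> CompileOnce(
    const std::string& key,
    port::StatusOr<std::vector<uint8_t>> (*compile)(const std::string&)) {
  static auto& cache = *new absl::flat_hash_map<
      std::string, port::StatusOr<std::vector<uint8_t>>>();
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, compile(key)).first;
  }
  // Check the status first, then read the cached value without moving it out.
  if (!it->second.ok()) {
    return it->second.status();
  }
  return it->second.ValueOrDie();
}

}  // namespace stream_executor
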
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/stream_executor/gpu/asm_compiler.h
index 1ac58aa..388f919 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.h
+++ b/tensorflow/stream_executor/gpu/asm_compiler.h
@@ -24,6 +24,9 @@
 #include "tensorflow/stream_executor/platform/port.h"
 
 namespace stream_executor {
+namespace gpu {
+class GpuContext;
+}
 
 // Compiles the given PTX string using ptxas and returns the resulting machine
 // code (i.e. a cubin) as a byte array. The generated cubin matches the compute
@@ -72,6 +75,11 @@
 port::StatusOr<std::vector<uint8>> BundleGpuAsm(
     std::vector<HsacoImage> images, const std::string rocm_root_dir);
 
+// Links multiple relocatable GPU images (e.g. results of ptxas -c) into a
+// single image.
+port::StatusOr<std::vector<uint8>> LinkGpuAsm(
+    gpu::GpuContext* context, std::vector<CubinOrPTXImage> images);
+
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_
diff --git a/tensorflow/stream_executor/gpu/gpu_driver.h b/tensorflow/stream_executor/gpu/gpu_driver.h
index 25b90be..3cd13dc 100644
--- a/tensorflow/stream_executor/gpu/gpu_driver.h
+++ b/tensorflow/stream_executor/gpu/gpu_driver.h
@@ -140,6 +140,63 @@
   // previously registered.
   static bool HostUnregister(GpuContext* context, void* location);
 
+  // Virtual memory support was added to CUDA in 10.2
+#if CUDA_VERSION >= 10020
+
+  // Reserves a range of virtual device memory addresses via
+  // cuMemAddressReserve. bytes must be a multiple of the host page size.
+  // Returns an error status if the reservation fails.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1ge489256c107df2a07ddf96d80c86cd9b
+  struct VmemSpan {
+    GpuDevicePtr base;
+    // Size in bytes.
+    uint64 size_bytes;
+  };
+  static port::StatusOr<VmemSpan> ReserveVirtualMemory(GpuContext* context,
+                                                       uint64 bytes);
+
+  // Frees a range of virtual addresses that were previously reserved through
+  // ReserveVirtualMemory via cuMemAddressFree.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g6993ecea2ea03e1b802b8255edc2da5b
+  static void FreeVirtualMemory(GpuContext* context, VmemSpan reservation);
+
+  // Calculates the minimum alignment for memory allocations done through
+  // cuMemCreate via cuMemGetAllocationGranularity.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g30ee906c2cf66a0347b3dfec3d7eb31a
+  static port::StatusOr<uint64> GetMinAllocationGranularity(int device_ordinal);
+
+  // Allocates physical memory via cuMemCreate and returns a handle that can
+  // be mapped into reserved virtual addresses. bytes must be a multiple of
+  // the granularity returned by GetMinAllocationGranularity.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g899d69a862bba36449789c64b430dc7c
+  struct GenericMemoryHandle {
+    uint64 handle;
+    uint64 bytes;
+  };
+  static port::StatusOr<GenericMemoryHandle> CreateMemoryHandle(
+      GpuContext* context, uint64 bytes);
+
+  // Frees memory represented by the provided GenericMemoryHandle via cuMemRelease.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g3014f0759f43a8d82db951b8e4b91d68
+  static void ReleaseMemoryHandle(GpuContext* context,
+                                  GenericMemoryHandle handle);
+
+  // Maps a memory allocation handle to a reserved virtual address range via
+  // cuMemMap and sets the appropriate access settings via cuMemSetAccess.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gff1d395423af5c5c75375516959dae56
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g1b6b12b10e8324bf462ecab4e7ef30e1
+  static port::Status MapMemory(GpuContext* context, GpuDevicePtr va,
+                                const GenericMemoryHandle& handle,
+                                const std::vector<int>& device_ordinals);
+
+  // Unmaps the backing memory from the given virtual address range. The range
+  // must cover the entire mapping of a handle previously mapped with
+  // MapMemory; partial unmapping is not supported.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gfb50aac00c848fd7087e858f59bf7e2a
+  static void UnmapMemory(GpuContext* context, GpuDevicePtr va, uint64 bytes);
+
+#endif  // CUDA_VERSION >= 10020
+
   // Given a device ordinal, returns a device handle into the device outparam,
   // which must not be null.
   //
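
Note: a minimal sketch of how these calls are meant to compose (CUDA >= 10.2 only): reserve a VA range, create a physical allocation, map it with access for one device, then tear down in reverse; VmemRoundTrip is a hypothetical name, granularity handling is elided, and error handling is kept to the minimum:

#include <vector>

#include "tensorflow/stream_executor/gpu/gpu_driver.h"

namespace stream_executor {
namespace gpu {

// Hypothetical round trip through the new virtual memory API.
static port::Status VmemRoundTrip(GpuContext* context, int device_ordinal,
                                  uint64 bytes) {
  auto reservation = GpuDriver::ReserveVirtualMemory(context, bytes);
  if (!reservation.ok()) return reservation.status();

  auto handle = GpuDriver::CreateMemoryHandle(context, bytes);
  if (!handle.ok()) {
    GpuDriver::FreeVirtualMemory(context, reservation.ValueOrDie());
    return handle.status();
  }

  // Back the reserved range with the allocation and grant read/write access
  // to the one device we are using.
  port::Status map_status =
      GpuDriver::MapMemory(context, reservation.ValueOrDie().base,
                           handle.ValueOrDie(), {device_ordinal});
  if (map_status.ok()) {
    // ... use the memory at reservation.ValueOrDie().base ...
    GpuDriver::UnmapMemory(context, reservation.ValueOrDie().base, bytes);
  }
  GpuDriver::ReleaseMemoryHandle(context, handle.ValueOrDie());
  GpuDriver::FreeVirtualMemory(context, reservation.ValueOrDie());
  return map_status;
}

}  // namespace gpu
}  // namespace stream_executor
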
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index 2223cb9..1604c2b 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -104,174 +104,174 @@
 #define ROCBLAS_BLAS_ROUTINE_EACH(__macro)  \
   __macro(rocblas_snrm2)                    \
   __macro(rocblas_dnrm2)                    \
-  /*__macro(rocblas_scnrm2)                   \
-    __macro(rocblas_dznrm2)                */ \
+  __macro(rocblas_scnrm2)		    \
+  __macro(rocblas_dznrm2)                   \
   __macro(rocblas_sdot)                     \
   __macro(rocblas_ddot)                     \
-  /*__macro(rocblas_cdotu)                    \
-    __macro(rocblas_cdotc)                    \
-    __macro(rocblas_zdotu)                    \
-    __macro(rocblas_zdotc)                 */ \
+  __macro(rocblas_cdotu)                    \
+  __macro(rocblas_cdotc)		    \
+  __macro(rocblas_zdotu)		    \
+  __macro(rocblas_zdotc)		    \
   __macro(rocblas_sscal)                    \
   __macro(rocblas_dscal)                    \
   __macro(rocblas_cscal)                    \
-    __macro(rocblas_csscal)                   \
-    __macro(rocblas_zscal)                    \
-    __macro(rocblas_zdscal)                 \
+  __macro(rocblas_csscal)		    \
+  __macro(rocblas_zscal)		    \
+  __macro(rocblas_zdscal)		    \
   __macro(rocblas_saxpy)                    \
   __macro(rocblas_daxpy)                    \
-  /*__macro(rocblas_caxpy)                    \
-    __macro(rocblas_zaxpy)                 */ \
+  __macro(rocblas_caxpy)                    \
+  __macro(rocblas_zaxpy)		    \
   __macro(rocblas_scopy)                    \
   __macro(rocblas_dcopy)                    \
-  /*__macro(rocblas_ccopy)                    \
-    __macro(rocblas_zcopy)                 */ \
+  __macro(rocblas_ccopy)                    \
+  __macro(rocblas_zcopy)		    \
   __macro(rocblas_sswap)                    \
   __macro(rocblas_dswap)                    \
-  /*__macro(rocblas_cswap)                    \
-    __macro(rocblas_zswap)                 */ \
+  __macro(rocblas_cswap)                    \
+  __macro(rocblas_zswap)		    \
   __macro(rocblas_isamax)                   \
   __macro(rocblas_idamax)                   \
-  /*__macro(rocblas_icamax)                   \
-    __macro(rocblas_izamax)                */ \
+  __macro(rocblas_icamax)                   \
+  __macro(rocblas_izamax)		    \
   __macro(rocblas_isamin)                   \
   __macro(rocblas_idamin)                   \
-  /*__macro(rocblas_icamin)                   \
-    __macro(rocblas_izamin)                */ \
+  __macro(rocblas_icamin)                   \
+  __macro(rocblas_izamin)		    \
   __macro(rocblas_sasum)                    \
   __macro(rocblas_dasum)                    \
-  /*__macro(rocblas_scasum)                   \
-    __macro(rocblas_dzasum)                   \
-    __macro(rocblas_srot)                     \
-    __macro(rocblas_drot)                     \
-    __macro(rocblas_crot)                     \
-    __macro(rocblas_csrot)                    \
-    __macro(rocblas_zrot)                     \
-    __macro(rocblas_zdrot)                    \
-    __macro(rocblas_srotg)                    \
-    __macro(rocblas_drotg)                    \
-    __macro(rocblas_Crotg)                    \
-    __macro(rocblas_crotg)                    \
-    __macro(rocblas_zrotm)                    \
-    __macro(rocblas_drotm)                    \
-    __macro(rocblas_srotmg)                   \
-    __macro(rocblas_drotmg)                */ \
+  __macro(rocblas_scasum)                   \
+  __macro(rocblas_dzasum)		    \
+  __macro(rocblas_srot)			    \
+  __macro(rocblas_drot)			    \
+  __macro(rocblas_crot)			    \
+  __macro(rocblas_csrot)		    \
+  __macro(rocblas_zrot)			    \
+  __macro(rocblas_zdrot)		    \
+  __macro(rocblas_srotg)		    \
+  __macro(rocblas_drotg)		    \
+  __macro(rocblas_crotg)		    \
+  __macro(rocblas_zrotg)		    \
+  __macro(rocblas_srotm)		    \
+  __macro(rocblas_drotm)		    \
+  __macro(rocblas_srotmg)		    \
+  __macro(rocblas_drotmg)		    \
   __macro(rocblas_sgemv)                    \
   __macro(rocblas_dgemv)                    \
   __macro(rocblas_cgemv)                    \
-    __macro(rocblas_zgemv)                    \
-  /*  __macro(rocblas_sgbmv)                    \
-    __macro(rocblas_dgbmv)                    \
-    __macro(rocblas_cgbmv)                    \
-    __macro(rocblas_zgbmv)                    \
-    __macro(rocblas_strmv)                    \
-    __macro(rocblas_dtrmv)                    \
-    __macro(rocblas_ctrmv)                    \
-    __macro(rocblas_ztrmv)                    \
-    __macro(rocblas_stbmv)                    \
-    __macro(rocblas_dtbmv)                    \
-    __macro(rocblas_ctbmv)                    \
-    __macro(rocblas_ztbmv)                    \
-    __macro(rocblas_stpmv)                    \
-    __macro(rocblas_dtpmv)                    \
-    __macro(rocblas_ctpmv)                    \
-    __macro(rocblas_ztpmv)                    \
-    __macro(rocblas_strsv)                    \
-    __macro(rocblas_dtrsv)                    \
-    __macro(rocblas_ctrsv)                    \
-    __macro(rocblas_ztrsv)                    \
-    __macro(rocblas_stpsv)                    \
-    __macro(rocblas_dtpsv)                    \
-    __macro(rocblas_ctpsv)                    \
-    __macro(rocblas_ztpsv)                    \
-    __macro(rocblas_stbsv)                    \
-    __macro(rocblas_dtbsv)                    \
-    __macro(rocblas_ctbsv)                    \
-    __macro(rocblas_ztbsv)                    \
-    __macro(rocblas_ssymv)                    \
-    __macro(rocblas_dsymv)                    \
-    __macro(rocblas_csymv)                    \
-    __macro(rocblas_zsymv)                    \
-    __macro(rocblas_chemv)                    \
-    __macro(rocblas_zhemv)                    \
-    __macro(rocblas_ssbmv)                    \
-    __macro(rocblas_dsbmv)                    \
-    __macro(rocblas_chbmv)                    \
-    __macro(rocblas_zhbmv)                    \
-    __macro(rocblas_sspmv)                    \
-    __macro(rocblas_dspmv)                    \
-    __macro(rocblas_chpmv)                    \
-    __macro(rocblas_zhpmv)                 */ \
+  __macro(rocblas_zgemv)		    \
+  __macro(rocblas_sgbmv)		    \
+  __macro(rocblas_dgbmv)		    \
+  __macro(rocblas_cgbmv)		    \
+  __macro(rocblas_zgbmv)		    \
+  __macro(rocblas_strmv)		    \
+  __macro(rocblas_dtrmv)		    \
+  __macro(rocblas_ctrmv)		    \
+  __macro(rocblas_ztrmv)		    \
+  __macro(rocblas_stbmv)		    \
+  __macro(rocblas_dtbmv)		    \
+  __macro(rocblas_ctbmv)		    \
+  __macro(rocblas_ztbmv)		    \
+  __macro(rocblas_stpmv)		    \
+  __macro(rocblas_dtpmv)		    \
+  __macro(rocblas_ctpmv)		    \
+  __macro(rocblas_ztpmv)		    \
+  __macro(rocblas_strsv)		    \
+  __macro(rocblas_dtrsv)		    \
+  __macro(rocblas_ctrsv)		    \
+  __macro(rocblas_ztrsv)		    \
+  __macro(rocblas_stpsv)		    \
+  __macro(rocblas_dtpsv)		    \
+  __macro(rocblas_ctpsv)		    \
+  __macro(rocblas_ztpsv)		    \
+  __macro(rocblas_stbsv)		    \
+  __macro(rocblas_dtbsv)		    \
+  __macro(rocblas_ctbsv)		    \
+  __macro(rocblas_ztbsv)		    \
+  __macro(rocblas_ssymv)		    \
+  __macro(rocblas_dsymv)		    \
+  /*    __macro(rocblas_csymv)		    \
+    __macro(rocblas_zsymv)              */  \
+  __macro(rocblas_chemv)		    \
+  __macro(rocblas_zhemv)		    \
+  __macro(rocblas_ssbmv)		    \
+  __macro(rocblas_dsbmv)		    \
+  __macro(rocblas_chbmv)		    \
+  __macro(rocblas_zhbmv)		    \
+  __macro(rocblas_sspmv)		    \
+  __macro(rocblas_dspmv)		    \
+  __macro(rocblas_chpmv)		    \
+  __macro(rocblas_zhpmv)		    \
   __macro(rocblas_sger)                     \
   __macro(rocblas_dger)                     \
-  /*__macro(rocblas_cgeru)                    \
-    __macro(rocblas_cgerc)                    \
-    __macro(rocblas_zgeru)                    \
-    __macro(rocblas_zgerc)                 */ \
+  __macro(rocblas_cgeru)		    \
+  __macro(rocblas_cgerc)		    \
+  __macro(rocblas_zgeru)		    \
+  __macro(rocblas_zgerc)		    \
   __macro(rocblas_ssyr)                     \
   __macro(rocblas_dsyr)                     \
-  /*__macro(rocblas_csyr)                     \
-    __macro(rocblas_zsyr)                     \
-    __macro(rocblas_cher)                     \
-    __macro(rocblas_zher)                     \
-    __macro(rocblas_sspr)                     \
-    __macro(rocblas_dspr)                     \
-    __macro(rocblas_chpr)                     \
-    __macro(rocblas_zhpr)                     \
-    __macro(rocblas_ssyr2)                    \
-    __macro(rocblas_dsyr2)                    \
-    __macro(rocblas_csyr2)                    \
-    __macro(rocblas_zsyr2)                    \
-    __macro(rocblas_cher2)                    \
-    __macro(rocblas_zher2)                    \
-    __macro(rocblas_sspr2)                    \
-    __macro(rocblas_dspr2)                    \
-    __macro(rocblas_chpr2)                    \
-    __macro(rocblas_zhpr2)                 */ \
+  /*__macro(rocblas_csyr)                   \
+    __macro(rocblas_zsyr)               */  \
+  __macro(rocblas_cher)			    \
+  __macro(rocblas_zher)			    \
+  __macro(rocblas_sspr)			    \
+  __macro(rocblas_dspr)			    \
+  __macro(rocblas_chpr)			    \
+  __macro(rocblas_zhpr)			    \
+  __macro(rocblas_ssyr2)		    \
+  __macro(rocblas_dsyr2)		    \
+  /*  __macro(rocblas_csyr2)		    \
+    __macro(rocblas_zsyr2)              */  \
+  __macro(rocblas_cher2)		    \
+  __macro(rocblas_zher2)		    \
+  __macro(rocblas_sspr2)		    \
+  __macro(rocblas_dspr2)		    \
+  __macro(rocblas_chpr2)                    \
+  __macro(rocblas_zhpr2)		    \
   __macro(rocblas_sgemm)                    \
   __macro(rocblas_dgemm)                    \
   __macro(rocblas_hgemm)                    \
   __macro(rocblas_cgemm)                    \
-    __macro(rocblas_zgemm)                    \
-  /*  __macro(rocblas_ssyrk)                    \
-    __macro(rocblas_dsyrk)                    \
-    __macro(rocblas_csyrk)                    \
-    __macro(rocblas_zsyrk)                    \
-    __macro(rocblas_cherk)                    \
-    __macro(rocblas_zherk)                    \
-    __macro(rocblas_ssyr2k)                   \
-    __macro(rocblas_dsyr2k)                   \
-    __macro(rocblas_csyr2k)                   \
-    __macro(rocblas_zsyr2k)                   \
-    __macro(rocblas_cher2k)                   \
-    __macro(rocblas_zher2k)                   \
-    __macro(rocblas_ssyrkx)                   \
-    __macro(rocblas_dsyrkx)                   \
-    __macro(rocblas_csyrkx)                   \
-    __macro(rocblas_zsyrkx)                   \
-    __macro(rocblas_cherkx)                   \
-    __macro(rocblas_zherkx)                   \
-    __macro(rocblas_ssymm)                    \
-    __macro(rocblas_dsymm)                    \
-    __macro(rocblas_csymm)                    \
-    __macro(rocblas_zsymm)                    \
-    __macro(rocblas_chemm)                    \
-    __macro(rocblas_zhemm)                 */ \
+  __macro(rocblas_zgemm)		    \
+  __macro(rocblas_ssyrk)		    \
+  __macro(rocblas_dsyrk)		    \
+  __macro(rocblas_csyrk)		    \
+  __macro(rocblas_zsyrk)		    \
+  __macro(rocblas_cherk)		    \
+  __macro(rocblas_zherk)		    \
+  __macro(rocblas_ssyr2k)		    \
+  __macro(rocblas_dsyr2k)		    \
+  __macro(rocblas_csyr2k)		    \
+  __macro(rocblas_zsyr2k)		    \
+  __macro(rocblas_cher2k)		    \
+  __macro(rocblas_zher2k)		    \
+  /*    __macro(rocblas_ssyrkx)		    \
+    __macro(rocblas_dsyrkx)                 \
+    __macro(rocblas_csyrkx)                 \
+    __macro(rocblas_zsyrkx)                 \
+    __macro(rocblas_cherkx)                 \
+    __macro(rocblas_zherkx)             */  \
+  __macro(rocblas_ssymm)		    \
+  __macro(rocblas_dsymm)		    \
+  __macro(rocblas_csymm)		    \
+  __macro(rocblas_zsymm)		    \
+  __macro(rocblas_chemm)		    \
+  __macro(rocblas_zhemm)		    \
   __macro(rocblas_strsm)                    \
   __macro(rocblas_dtrsm)                    \
-  /*__macro(rocblas_ctrsm)                    \
-    __macro(rocblas_ztrsm)                    \
-    __macro(rocblas_strmm)                    \
-    __macro(rocblas_dtrmm)                    \
-    __macro(rocblas_ctrmm)                    \
-    __macro(rocblas_ztrmm)                 */ \
+  __macro(rocblas_ctrsm)                    \
+  __macro(rocblas_ztrsm)		    \
+  __macro(rocblas_strmm)		    \
+  __macro(rocblas_dtrmm)		    \
+  __macro(rocblas_ctrmm)		    \
+  __macro(rocblas_ztrmm)		    \
   __macro(rocblas_sgeam)                    \
   __macro(rocblas_dgeam)                    \
-  /*__macro(rocblas_cgeam)                    \
-    __macro(rocblas_zgeam)                    \
-    __macro(rocblas_sdgmm)                    \
-    __macro(rocblas_ddgmm)                    \
-    __macro(rocblas_cdgmm)                    \
+  /*__macro(rocblas_cgeam)                  \
+    __macro(rocblas_zgeam)                  \
+    __macro(rocblas_sdgmm)                  \
+    __macro(rocblas_ddgmm)                  \
+    __macro(rocblas_cdgmm)                  \
     __macro(rocblas_zdgmm) */
 // clang-format on
 
@@ -445,7 +445,7 @@
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<float> *result) {
   return DoBlasInternal(wrap::rocblas_sasum, stream,
-                        false /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ false, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
@@ -453,24 +453,24 @@
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<double> *result) {
   return DoBlasInternal(wrap::rocblas_dasum, stream,
-                        false /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ false, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<float> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_scasum, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           DeviceMemory<double> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dzasum, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
@@ -478,7 +478,7 @@
                           DeviceMemory<float> *y, int incy) {
   blas_log("DoBlasAxpy");
   return DoBlasInternal(wrap::rocblas_saxpy, stream,
-                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        /* pointer_mode_host = */ true, elem_count, &alpha,
                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
 }
 
@@ -487,7 +487,7 @@
                           DeviceMemory<double> *y, int incy) {
   blas_log("DoBlasAxpy");
   return DoBlasInternal(wrap::rocblas_daxpy, stream,
-                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        /* pointer_mode_host = */ true, elem_count, &alpha,
                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
 }
 
@@ -495,25 +495,25 @@
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count,
+      complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count,
+      complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<float> *y, int incy) {
   return DoBlasInternal(wrap::rocblas_scopy, stream,
-                        true /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ true, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
 }
 
@@ -521,24 +521,24 @@
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<double> *y, int incy) {
   return DoBlasInternal(wrap::rocblas_dcopy, stream,
-                        true /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ true, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ccopy, stream,
+                        /* pointer_mode_host = */ true, elem_count,
+                        complex_cast(x), incx, complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zcopy, stream,
+                        /* pointer_mode_host = */ true, elem_count,
+                        complex_cast(x), incx, complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
@@ -547,7 +547,7 @@
                          DeviceMemory<float> *result) {
   blas_log("DoBlasDot");
   return DoBlasInternal(
-      wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count,
+      wrap::rocblas_sdot, stream, /* pointer_mode_host = */ false, elem_count,
       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
 }
 
@@ -557,7 +557,7 @@
                          DeviceMemory<double> *result) {
   blas_log("DoBlasDot");
   return DoBlasInternal(
-      wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count,
+      wrap::rocblas_ddot, stream, /* pointer_mode_host = */ false, elem_count,
       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
 }
 
@@ -565,43 +565,43 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           const DeviceMemory<std::complex<float>> &y, int incy,
                           DeviceMemory<std::complex<float>> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_cdotc, stream, /* pointer_mode_host = */ false, elem_count,
+      complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
 }
 
 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           const DeviceMemory<std::complex<double>> &y, int incy,
                           DeviceMemory<std::complex<double>> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zdotc, stream, /* pointer_mode_host = */ false, elem_count,
+      complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
 }
 
 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           const DeviceMemory<std::complex<float>> &y, int incy,
                           DeviceMemory<std::complex<float>> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_cdotu, stream, /* pointer_mode_host = */ false, elem_count,
+      complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
 }
 
 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           const DeviceMemory<std::complex<double>> &y, int incy,
                           DeviceMemory<std::complex<double>> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zdotu, stream, /* pointer_mode_host = */ false, elem_count,
+      complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
 }
 
 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<float> *result) {
   return DoBlasInternal(wrap::rocblas_snrm2, stream,
-                        false /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ false, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
@@ -609,157 +609,161 @@
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<double> *result) {
   return DoBlasInternal(wrap::rocblas_dnrm2, stream,
-                        false /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ false, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<float> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_scnrm2, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           DeviceMemory<double> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dznrm2, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy, float c, float s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_srot, stream, /* pointer_mode_host = */ true, elem_count,
+      GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s);
 }
 
 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy, double c,
                          double s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_drot, stream, /* pointer_mode_host = */ true, elem_count,
+      GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s);
 }
 
 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<float>> *x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy,
                          float c, float s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_csrot, stream,
+                        /* pointer_mode_host = */ true, elem_count,
+                        complex_cast(x), incx, complex_cast(y), incy, &c, &s);
 }
 
 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<double>> *x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy,
                          double c, double s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zdrot, stream,
+                        /* pointer_mode_host = */ true, elem_count,
+                        complex_cast(x), incx, complex_cast(y), incy, &c, &s);
 }
 
 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
                           DeviceMemory<float> *b, DeviceMemory<float> *c,
                           DeviceMemory<float> *s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_srotg, stream,
+                        /* pointer_mode_host = */ false, GpuMemoryMutable(a),
+                        GpuMemoryMutable(b), GpuMemoryMutable(c),
+                        GpuMemoryMutable(s));
 }
 
 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
                           DeviceMemory<double> *b, DeviceMemory<double> *c,
                           DeviceMemory<double> *s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_drotg, stream,
+                        /* pointer_mode_host = */ false, GpuMemoryMutable(a),
+                        GpuMemoryMutable(b), GpuMemoryMutable(c),
+                        GpuMemoryMutable(s));
 }
 
 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
                           DeviceMemory<std::complex<float>> *b,
                           DeviceMemory<float> *c,
                           DeviceMemory<std::complex<float>> *s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_crotg, stream,
+                        /* pointer_mode_host = */ false, complex_cast(a),
+                        complex_cast(b), GpuMemoryMutable(c), complex_cast(s));
 }
 
 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
                           DeviceMemory<std::complex<double>> *b,
                           DeviceMemory<double> *c,
                           DeviceMemory<std::complex<double>> *s) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zrotg, stream,
+                        /* pointer_mode_host = */ false, complex_cast(a),
+                        complex_cast(b), GpuMemoryMutable(c), complex_cast(s));
 }
 
 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
                           DeviceMemory<float> *x, int incx,
                           DeviceMemory<float> *y, int incy,
                           const DeviceMemory<float> &param) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_srotm, stream, /* pointer_mode_host = */ false, elem_count,
+      GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param));
 }
 
 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
                           DeviceMemory<double> *x, int incx,
                           DeviceMemory<double> *y, int incy,
                           const DeviceMemory<double> &param) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_drotm, stream, /* pointer_mode_host = */ false, elem_count,
+      GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param));
 }
 
 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
                            const DeviceMemory<float> &y1,
                            DeviceMemory<float> *param) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_srotmg, stream,
+                        /* pointer_mode_host = */ false, GpuMemoryMutable(d1),
+                        GpuMemoryMutable(d2), GpuMemoryMutable(x1),
+                        GpuMemory(y1), GpuMemoryMutable(param));
 }
 
 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
                            const DeviceMemory<double> &y1,
                            DeviceMemory<double> *param) {
-  LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_drotmg, stream,
+                        /* pointer_mode_host = */ false, GpuMemoryMutable(d1),
+                        GpuMemoryMutable(d2), GpuMemoryMutable(x1),
+                        GpuMemory(y1), GpuMemoryMutable(param));
 }
 
 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                           DeviceMemory<float> *x, int incx) {
   blas_log("DoBlasScal<float>");
   return DoBlasInternal(wrap::rocblas_sscal, stream,
-                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        /* pointer_mode_host = */ true, elem_count, &alpha,
                         GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                           DeviceMemory<double> *x, int incx) {
   return DoBlasInternal(wrap::rocblas_dscal, stream,
-                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        /* pointer_mode_host = */ true, elem_count, &alpha,
                         GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                           DeviceMemory<std::complex<float>> *x, int incx) {
   return DoBlasInternal(wrap::rocblas_csscal, stream,
-                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        /* pointer_mode_host = */ true, elem_count, &alpha,
                         complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                           DeviceMemory<std::complex<double>> *x, int incx) {
   return DoBlasInternal(wrap::rocblas_zdscal, stream,
-                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        /* pointer_mode_host = */ true, elem_count, &alpha,
                         complex_cast(x), incx);
 }
 
@@ -767,7 +771,7 @@
                           std::complex<float> alpha,
                           DeviceMemory<std::complex<float>> *x, int incx) {
   return DoBlasInternal(wrap::rocblas_cscal, stream,
-                        true /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ true, elem_count,
                         complex_cast(alpha), complex_cast(x), incx);
 }
 
@@ -775,7 +779,7 @@
                           std::complex<double> alpha,
                           DeviceMemory<std::complex<double>> *x, int incx) {
   return DoBlasInternal(wrap::rocblas_zscal, stream,
-                        true /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ true, elem_count,
                         complex_cast(alpha), complex_cast(x), incx);
 }
 
@@ -783,7 +787,7 @@
                           DeviceMemory<float> *x, int incx,
                           DeviceMemory<float> *y, int incy) {
   return DoBlasInternal(wrap::rocblas_sswap, stream,
-                        true /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ true, elem_count,
                         GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
 }
 
@@ -791,31 +795,31 @@
                           DeviceMemory<double> *x, int incx,
                           DeviceMemory<double> *y, int incy) {
   return DoBlasInternal(wrap::rocblas_dswap, stream,
-                        true /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ true, elem_count,
                         GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                           DeviceMemory<std::complex<float>> *x, int incx,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_cswap, stream,
+                        /* pointer_mode_host = */ true, elem_count,
+                        complex_cast(x), incx, complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                           DeviceMemory<std::complex<double>> *x, int incx,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zswap, stream,
+                        /* pointer_mode_host = */ true, elem_count,
+                        complex_cast(x), incx, complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                            const DeviceMemory<float> &x, int incx,
                            DeviceMemory<int> *result) {
   return DoBlasInternal(wrap::rocblas_isamax, stream,
-                        false /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ false, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
@@ -823,56 +827,56 @@
                            const DeviceMemory<double> &x, int incx,
                            DeviceMemory<int> *result) {
   return DoBlasInternal(wrap::rocblas_idamax, stream,
-                        false /* = pointer_mode_host */, elem_count,
+                        /* pointer_mode_host = */ false, elem_count,
                         GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                            const DeviceMemory<std::complex<float>> &x, int incx,
                            DeviceMemory<int> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_icamax, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                            const DeviceMemory<std::complex<double>> &x,
                            int incx, DeviceMemory<int> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_izamax, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                            const DeviceMemory<float> &x, int incx,
                            DeviceMemory<int> *result) {
-  return DoBlasInternal(
-      wrap::rocblas_isamin, stream, false /* = pointer_mode_host */, elem_count,
-      GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
+  return DoBlasInternal(wrap::rocblas_isamin, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                            const DeviceMemory<double> &x, int incx,
                            DeviceMemory<int> *result) {
-  return DoBlasInternal(
-      wrap::rocblas_idamin, stream, false /* = pointer_mode_host */, elem_count,
-      GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
+  return DoBlasInternal(wrap::rocblas_idamin, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                            const DeviceMemory<std::complex<float>> &x, int incx,
                            DeviceMemory<int> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_icamin, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                            const DeviceMemory<std::complex<double>> &x,
                            int incx, DeviceMemory<int> *result) {
-  LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_izamin, stream,
+                        /* pointer_mode_host = */ false, elem_count,
+                        complex_cast(x), incx, GpuMemoryMutable(result));
 }
 
 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
@@ -880,9 +884,10 @@
                           const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &x, int incx, float beta,
                           DeviceMemory<float> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_sgbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda,
+      GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
@@ -890,9 +895,10 @@
                           const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &x, int incx, double beta,
                           DeviceMemory<double> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dgbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda,
+      GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
@@ -902,9 +908,11 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_cgbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha),
+      complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta),
+      complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
@@ -914,9 +922,11 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zgbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha),
+      complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta),
+      complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
@@ -925,7 +935,7 @@
                           float beta, DeviceMemory<float> *y, int incy) {
   blas_log("DoBlasGemv");
   return DoBlasInternal(
-      wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
       incx, &beta, GpuMemoryMutable(y), incy);
 }
@@ -936,7 +946,7 @@
                           double beta, DeviceMemory<double> *y, int incy) {
   blas_log("DoBlasGemv");
   return DoBlasInternal(
-      wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
       incx, &beta, GpuMemoryMutable(y), incy);
 }
@@ -949,7 +959,7 @@
                           DeviceMemory<std::complex<float>> *y, int incy) {
   blas_log("DoBlasGemv");
   return DoBlasInternal(
-      wrap::rocblas_cgemv, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
@@ -962,7 +972,7 @@
                           DeviceMemory<std::complex<double>> *y, int incy) {
   blas_log("DoBlasGemv\n");
   return DoBlasInternal(
-      wrap::rocblas_zgemv, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
@@ -972,7 +982,7 @@
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *a, int lda) {
   return DoBlasInternal(
-      wrap::rocblas_sger, stream, true /* = pointer_mode_host */, m, n, &alpha,
+      wrap::rocblas_sger, stream, /* pointer_mode_host = */ true, m, n, &alpha,
       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
 }
 
@@ -981,7 +991,7 @@
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *a, int lda) {
   return DoBlasInternal(
-      wrap::rocblas_dger, stream, true /* = pointer_mode_host */, m, n, &alpha,
+      wrap::rocblas_dger, stream, /* pointer_mode_host = */ true, m, n, &alpha,
       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
 }
 
@@ -990,9 +1000,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           const DeviceMemory<std::complex<float>> &y, int incy,
                           DeviceMemory<std::complex<float>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the GER operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_cgerc, stream,
+                        /* pointer_mode_host = */ true, m, n,
+                        complex_cast(alpha), complex_cast(x), incx,
+                        complex_cast(y), incy, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
@@ -1000,9 +1011,10 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           const DeviceMemory<std::complex<double>> &y, int incy,
                           DeviceMemory<std::complex<double>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the GER operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zgerc, stream,
+                        /* pointer_mode_host = */ true, m, n,
+                        complex_cast(alpha), complex_cast(x), incx,
+                        complex_cast(y), incy, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
@@ -1010,9 +1022,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           const DeviceMemory<std::complex<float>> &y, int incy,
                           DeviceMemory<std::complex<float>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_cgeru, stream,
+                        /* pointer_mode_host = */ true, m, n,
+                        complex_cast(alpha), complex_cast(x), incx,
+                        complex_cast(y), incy, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
@@ -1020,9 +1033,10 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           const DeviceMemory<std::complex<double>> &y, int incy,
                           DeviceMemory<std::complex<double>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zgeru, stream,
+                        /* pointer_mode_host = */ true, m, n,
+                        complex_cast(alpha), complex_cast(x), incx,
+                        complex_cast(y), incy, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1031,9 +1045,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_chbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda,
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1042,9 +1057,10 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zhbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda,
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1053,9 +1069,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_chemv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda,
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1064,27 +1081,30 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zhemv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda,
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the HER operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_cher, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
+                        complex_cast(x), incx, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the HER operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zher, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
+                        complex_cast(x), incx, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1092,9 +1112,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           const DeviceMemory<std::complex<float>> &y, int incy,
                           DeviceMemory<std::complex<float>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_cher2, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
+      complex_cast(y), incy, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1102,9 +1123,10 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           const DeviceMemory<std::complex<double>> &y, int incy,
                           DeviceMemory<std::complex<double>> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zher2, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
+      complex_cast(y), incy, complex_cast(a), lda);
 }
 
 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1113,9 +1135,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_chpmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap),
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1124,27 +1147,30 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zhpmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap),
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_chpr, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
+                        complex_cast(x), incx, complex_cast(ap));
 }
 
 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zhpr, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
+                        complex_cast(x), incx, complex_cast(ap));
 }
 
 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1152,9 +1178,10 @@
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           const DeviceMemory<std::complex<float>> &y, int incy,
                           DeviceMemory<std::complex<float>> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_chpr2, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
+      complex_cast(y), incy, complex_cast(ap));
 }
 
 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
@@ -1162,105 +1189,115 @@
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           const DeviceMemory<std::complex<double>> &y, int incy,
                           DeviceMemory<std::complex<double>> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zhpr2, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
+      complex_cast(y), incy, complex_cast(ap));
 }
 
 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                           uint64 k, float alpha, const DeviceMemory<float> &a,
                           int lda, const DeviceMemory<float> &x, int incx,
                           float beta, DeviceMemory<float> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
-             << "for the \"complex<float>\" datatype";
-
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
+      incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                           uint64 k, double alpha, const DeviceMemory<double> &a,
                           int lda, const DeviceMemory<double> &x, int incx,
                           double beta, DeviceMemory<double> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
+      incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                           float alpha, const DeviceMemory<float> &ap,
                           const DeviceMemory<float> &x, int incx, float beta,
                           DeviceMemory<float> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_sspmv, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
+                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                           double alpha, const DeviceMemory<double> &ap,
                           const DeviceMemory<double> &x, int incx, double beta,
                           DeviceMemory<double> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dspmv, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
+                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_sspr, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemoryMutable(ap));
 }
 
 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dspr, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemoryMutable(ap));
 }
 
 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                           float alpha, const DeviceMemory<float> &x, int incx,
                           const DeviceMemory<float> &y, int incy,
                           DeviceMemory<float> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_sspr2, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemory(y), incy, GpuMemoryMutable(ap));
 }
 
 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                           double alpha, const DeviceMemory<double> &x, int incx,
                           const DeviceMemory<double> &y, int incy,
                           DeviceMemory<double> *ap) {
-  LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dspr2, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemory(y), incy, GpuMemoryMutable(ap));
 }
 
 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &x, int incx, float beta,
                           DeviceMemory<float> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ssymv, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
+                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &x, int incx, double beta,
                           DeviceMemory<double> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dsymv, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
+                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
 }
 
 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *a, int lda) {
   return DoBlasInternal(wrap::rocblas_ssyr, stream,
-                        true /* = pointer_mode_host */,
+                        /* pointer_mode_host = */ true,
                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                         GpuMemoryMutable(a), lda);
 }
@@ -1269,7 +1306,7 @@
                          double alpha, const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *a, int lda) {
   return DoBlasInternal(wrap::rocblas_dsyr, stream,
-                        true /* = pointer_mode_host */,
+                        /* pointer_mode_host = */ true,
                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                         GpuMemoryMutable(a), lda);
 }
@@ -1278,36 +1315,42 @@
                           float alpha, const DeviceMemory<float> &x, int incx,
                           const DeviceMemory<float> &y, int incy,
                           DeviceMemory<float> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ssyr2, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemory(y), incy, GpuMemoryMutable(a), lda);
 }
 
 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                           double alpha, const DeviceMemory<double> &x, int incx,
                           const DeviceMemory<double> &y, int incy,
                           DeviceMemory<double> *a, int lda) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dsyr2, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemory(y), incy, GpuMemoryMutable(a), lda);
 }
 
 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           uint64 k, const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_stbmv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
+                        GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           uint64 k, const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dtbmv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
+                        GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
@@ -1315,9 +1358,11 @@
                           uint64 k, const DeviceMemory<std::complex<float>> &a,
                           int lda, DeviceMemory<std::complex<float>> *x,
                           int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ctbmv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
+                        complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
@@ -1325,27 +1370,33 @@
                           uint64 k, const DeviceMemory<std::complex<double>> &a,
                           int lda, DeviceMemory<std::complex<double>> *x,
                           int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ztbmv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
+                        complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           uint64 k, const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_stbsv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
+                        GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           uint64 k, const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dtbsv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
+                        GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
@@ -1353,9 +1404,11 @@
                           uint64 k, const DeviceMemory<std::complex<float>> &a,
                           int lda, DeviceMemory<std::complex<float>> *x,
                           int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ctbsv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
+                        complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
@@ -1363,153 +1416,171 @@
                           uint64 k, const DeviceMemory<std::complex<double>> &a,
                           int lda, DeviceMemory<std::complex<double>> *x,
                           int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ztbsv, stream,
+                        /* pointer_mode_host = */ false,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+                        ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
+                        complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                           int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_stpmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<double> &ap,
                           DeviceMemory<double> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dtpmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<float>> &ap,
                           DeviceMemory<std::complex<float>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ctpmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<double>> &ap,
                           DeviceMemory<std::complex<double>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ztpmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                           int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_stpsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<double> &ap,
                           DeviceMemory<double> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dtpsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<float>> &ap,
                           DeviceMemory<std::complex<float>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ctpsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<double>> &ap,
                           DeviceMemory<std::complex<double>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ztpsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_strmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dtrmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           DeviceMemory<std::complex<float>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ctrmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ztrmv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_strsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dtrsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           DeviceMemory<std::complex<float>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ctrsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *x, int incx) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ztrsv, stream, /* pointer_mode_host = */ false,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
+      ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
 }
 
 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@@ -1549,7 +1620,7 @@
   const Eigen::half alpha_half(alpha);
   const Eigen::half beta_half(beta);
   return DoBlasInternal(
-      wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
       reinterpret_cast<const rocblas_half *>(&alpha_half),
       reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
@@ -1593,7 +1664,7 @@
     }
   }
   return DoBlasInternal(
-      wrap::rocblas_sgemm, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
@@ -1605,7 +1676,7 @@
                           DeviceMemory<double> *c, int ldc) {
   blas_log("DoBlasGemm");
   return DoBlasInternal(
-      wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
@@ -1619,7 +1690,7 @@
                           DeviceMemory<std::complex<float>> *c, int ldc) {
   blas_log("DoBlasGemm");
   return DoBlasInternal(
-      wrap::rocblas_cgemm, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
       complex_cast(beta), complex_cast(c), ldc);
@@ -1634,7 +1705,7 @@
                           DeviceMemory<std::complex<double>> *c, int ldc) {
   blas_log("DoBlasGemm");
   return DoBlasInternal(
-      wrap::rocblas_zgemm, stream, true /* = pointer_mode_host */,
+      wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
       complex_cast(beta), complex_cast(c), ldc);
@@ -2044,7 +2115,7 @@
   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
 
   bool ok;
-  ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
+  ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true,
                       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
                       n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
                       batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
@@ -2164,9 +2235,11 @@
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_chemm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
@@ -2176,9 +2249,11 @@
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zhemm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
@@ -2187,9 +2262,11 @@
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           float beta, DeviceMemory<std::complex<float>> *c,
                           int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_cherk, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
+                        k, complex_cast(alpha), complex_cast(a), lda,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
@@ -2198,9 +2275,11 @@
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           double beta, DeviceMemory<std::complex<double>> *c,
                           int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zherk, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
+                        k, complex_cast(alpha), complex_cast(a), lda,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
@@ -2210,9 +2289,11 @@
                            const DeviceMemory<std::complex<float>> &b, int ldb,
                            float beta, DeviceMemory<std::complex<float>> *c,
                            int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_cher2k, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
@@ -2222,9 +2303,11 @@
                            const DeviceMemory<std::complex<double>> &b, int ldb,
                            double beta, DeviceMemory<std::complex<double>> *c,
                            int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zher2k, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
@@ -2232,9 +2315,10 @@
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &b, int ldb, float beta,
                           DeviceMemory<float> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ssymm, stream, /* pointer_mode_host = */ true,
+      ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a),
+      lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
@@ -2242,9 +2326,10 @@
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &b, int ldb, double beta,
                           DeviceMemory<double> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dsymm, stream, /* pointer_mode_host = */ true,
+      ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a),
+      lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
@@ -2254,9 +2339,11 @@
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_csymm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
@@ -2266,27 +2353,31 @@
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zsymm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           float beta, DeviceMemory<float> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ssyrk, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
+      GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           double beta, DeviceMemory<double> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dsyrk, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
+      GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
@@ -2295,9 +2386,11 @@
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_csyrk, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
+                        k, complex_cast(alpha), complex_cast(a), lda,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
@@ -2306,9 +2399,11 @@
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_zsyrk, stream,
+                        /* pointer_mode_host = */ true,
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
+                        k, complex_cast(alpha), complex_cast(a), lda,
+                        complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
@@ -2316,9 +2411,10 @@
                            float alpha, const DeviceMemory<float> &a, int lda,
                            const DeviceMemory<float> &b, int ldb, float beta,
                            DeviceMemory<float> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_ssyr2k, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
+      GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
@@ -2326,9 +2422,10 @@
                            double alpha, const DeviceMemory<double> &a, int lda,
                            const DeviceMemory<double> &b, int ldb, double beta,
                            DeviceMemory<double> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_dsyr2k, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
+      GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
@@ -2338,9 +2435,11 @@
                            const DeviceMemory<std::complex<float>> &b, int ldb,
                            std::complex<float> beta,
                            DeviceMemory<std::complex<float>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_csyr2k, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
@@ -2350,9 +2449,11 @@
                            const DeviceMemory<std::complex<double>> &b, int ldb,
                            std::complex<double> beta,
                            DeviceMemory<std::complex<double>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(
+      wrap::rocblas_zsyr2k, stream, /* pointer_mode_host = */ true,
+      ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
@@ -2360,9 +2461,11 @@
                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                           const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *b, int ldb) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
-             << "for the \"float\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_strmm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
+                        GpuMemoryMutable(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
@@ -2370,9 +2473,11 @@
                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                           const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *b, int ldb) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
-             << "for the \"double\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_dtrmm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
+                        GpuMemoryMutable(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
@@ -2381,9 +2486,11 @@
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           DeviceMemory<std::complex<float>> *b, int ldb) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ctrmm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
@@ -2392,9 +2499,11 @@
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *b, int ldb) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ztrmm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
@@ -2403,11 +2512,11 @@
                           const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *b, int ldb) {
   blas_log("DoBlasTrsm");
-  return DoBlasInternal(
-      wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
-      ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
-      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
-      lda, GpuMemoryMutable(b), ldb);
+  return DoBlasInternal(wrap::rocblas_strsm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
+                        GpuMemoryMutable(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
@@ -2416,11 +2525,11 @@
                           const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *b, int ldb) {
   blas_log("DoBlasTrsm");
-  return DoBlasInternal(
-      wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
-      ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
-      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
-      lda, GpuMemoryMutable(b), ldb);
+  return DoBlasInternal(wrap::rocblas_dtrsm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
+                        GpuMemoryMutable(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
@@ -2429,9 +2538,11 @@
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           DeviceMemory<std::complex<float>> *b, int ldb) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ctrsm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb);
 }
 
 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
@@ -2440,9 +2551,11 @@
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *b, int ldb) {
-  LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  return DoBlasInternal(wrap::rocblas_ztrsm, stream,
+                        /* pointer_mode_host = */ true, ROCMBlasSide(side),
+                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+                        ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
+                        complex_cast(a), lda, complex_cast(b), ldb);
 }
 
 bool ROCMBlas::DoBlasGemmStridedBatched(
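
Note on the hunks above: the SYRK/SYR2K/TRMM/TRSM error stubs are replaced by real dispatches through DoBlasInternal, and the complex variants pass std::complex<T> arguments through complex_cast before they reach rocBLAS. The standalone sketch below is only an illustration of why such a cast works in practice (std::complex<T> and a {T, T} struct share the same layout); my_double_complex, fake_zscal, and this local complex_cast are hypothetical stand-ins, not StreamExecutor or rocBLAS code.

    #include <complex>
    #include <iostream>

    // Stand-in for a rocBLAS-style complex struct (hypothetical type).
    struct my_double_complex { double x, y; };

    // Stand-in for a BLAS entry point that expects that struct: v[i] *= alpha.
    void fake_zscal(int n, const my_double_complex* alpha, my_double_complex* v) {
      for (int i = 0; i < n; ++i) {
        double re = alpha->x * v[i].x - alpha->y * v[i].y;
        double im = alpha->x * v[i].y + alpha->y * v[i].x;
        v[i].x = re;
        v[i].y = im;
      }
    }

    // The complex_cast idea: reinterpret std::complex<double> in place as the
    // BLAS struct, since both are two contiguous doubles.
    my_double_complex* complex_cast(std::complex<double>* p) {
      return reinterpret_cast<my_double_complex*>(p);
    }

    int main() {
      std::complex<double> alpha(0.0, 1.0);  // multiply by i
      std::complex<double> v[2] = {{1.0, 0.0}, {0.0, 2.0}};
      fake_zscal(2, complex_cast(&alpha), complex_cast(v));
      std::cout << v[0] << " " << v[1] << "\n";  // expect (0,1) (-2,0)
      return 0;
    }
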
diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
index 2a85cb8..dbab030 100644
--- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
+++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
@@ -856,6 +856,11 @@
 
     float clock_rate_ghz = static_cast<float>(prop.clockRate) / 1e6;
     builder.set_clock_rate_ghz(clock_rate_ghz);
+
+    // mem_bandwidth = 2 * mem_bus_width_in_bytes * mem_clock_rate_in_hz
+    int64 memory_bandwidth = 2 * (int64(prop.memoryBusWidth) / 8) *
+                             (int64(prop.memoryClockRate) * 1000);
+    builder.set_memory_bandwidth(memory_bandwidth);
   }
 
   {
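
Note on the memory_bandwidth hunk above: the device properties report the memory bus width in bits and the memory clock in kHz (mirroring the CUDA equivalents), and DDR-style memory transfers data twice per clock, hence the factor of 2. A minimal standalone sketch of the same arithmetic, with a hypothetical helper name and example device values rather than TensorFlow code:

    #include <cstdint>
    #include <iostream>

    // Peak bandwidth in bytes/s: 2 transfers per clock (DDR) * bytes per
    // transfer * memory clock in Hz. Mirrors the expression added above.
    int64_t PeakMemoryBandwidth(int bus_width_bits, int mem_clock_khz) {
      return 2 * (static_cast<int64_t>(bus_width_bits) / 8) *
             (static_cast<int64_t>(mem_clock_khz) * 1000);
    }

    int main() {
      // Hypothetical device: 4096-bit HBM2 bus, 1.2 GHz memory clock.
      std::cout << PeakMemoryBandwidth(4096, 1200000) << " bytes/s\n";
      // Prints 1228800000000, i.e. roughly 1.23 TB/s.
      return 0;
    }
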
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index b8f7582..1a7e130 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -5,6 +5,7 @@
 package(
     default_visibility = [
         "//learning/brain/experimental/dtensor:__subpackages__",
+        "//tensorflow/core/profiler/internal/tpu:__subpackages__",
         "//tensorflow/core/tpu:__subpackages__",
     ],
     licenses = ["notice"],  # Apache 2.0
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 236475a..ffe9c8b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -1015,7 +1015,7 @@
     native.py_library(
         name = generated_target_name,
         srcs = [out],
-        srcs_version = "PY2AND3",
+        srcs_version = "PY3",
         visibility = visibility,
         deps = [
             clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
@@ -1828,7 +1828,7 @@
         srcs = [],
         dso = [],
         kernels = [],
-        srcs_version = "PY2AND3",
+        srcs_version = "PY3",
         visibility = None,
         deps = [],
         **kwargs):
@@ -2019,7 +2019,7 @@
     native.py_library(
         name = name,
         srcs = [":" + name + ".py"],
-        srcs_version = "PY2AND3",
+        srcs_version = "PY3",
         data = select({
             clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
             "//conditions:default": [":" + cc_library_name],
@@ -2139,7 +2139,7 @@
             deps.append(clean_dep(to_add))
 
     # Python version placeholder
-    kwargs.setdefault("srcs_version", "PY2AND3")
+    kwargs.setdefault("srcs_version", "PY3")
     py_test(
         name = name,
         size = size,
@@ -2500,7 +2500,7 @@
         module_name,
         hdrs = [],
         features = [],
-        srcs_version = "PY2AND3",
+        srcs_version = "PY3",
         data = [],
         copts = [],
         linkopts = [],
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index 88fd63d..c070882 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -100,6 +100,12 @@
       label: LABEL_OPTIONAL
       type: TYPE_INT64
     }
+    field {
+      name: "use_tfrt"
+      number: 18
+      label: LABEL_OPTIONAL
+      type: TYPE_BOOL
+    }
     enum_type {
       name: "MlirBridgeRollout"
       value: {
@@ -114,6 +120,10 @@
         name: "MLIR_BRIDGE_ROLLOUT_DISABLED"
         number: 2
       }
+      value: {
+        name: "MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED"
+        number: 3
+      }
     }
     reserved_range {
       start: 2
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index a598071..34bf0d5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -229,6 +229,12 @@
         label: LABEL_OPTIONAL
         type: TYPE_INT64
       }
+      field {
+        name: "use_tfrt"
+        number: 18
+        label: LABEL_OPTIONAL
+        type: TYPE_BOOL
+      }
       enum_type {
         name: "MlirBridgeRollout"
         value: {
@@ -243,6 +249,10 @@
           name: "MLIR_BRIDGE_ROLLOUT_DISABLED"
           number: 2
         }
+        value: {
+          name: "MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED"
+          number: 3
+        }
       }
       reserved_range {
         start: 2
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt
index fa15dc8..bc49b72 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt
@@ -1,7 +1,7 @@
 path: "tensorflow.io.TFRecordWriter"
 tf_class {
   is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
-  is_instance: "<class \'tensorflow.python._pywrap_record_io.RecordWriter\'>"
+  is_instance: "<class \'tensorflow.python.lib.io._pywrap_record_io.RecordWriter\'>"
   is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
   member_method {
     name: "__init__"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-writer.pbtxt
index f2053da..bbd6491 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-writer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-writer.pbtxt
@@ -1,7 +1,7 @@
 path: "tensorflow.python_io.TFRecordWriter"
 tf_class {
   is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
-  is_instance: "<class \'tensorflow.python._pywrap_record_io.RecordWriter\'>"
+  is_instance: "<class \'tensorflow.python.lib.io._pywrap_record_io.RecordWriter\'>"
   is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
   member_method {
     name: "__init__"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 21dd896..7f6943d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -753,10 +753,18 @@
     argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
+    name: "CollectiveBcastRecvV2"
+    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
+  member_method {
     name: "CollectiveBcastSend"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
+    name: "CollectiveBcastSendV2"
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
+  member_method {
     name: "CollectiveGather"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt
index 0df534b..529f715 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt
index 15db2f1..1e87df9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt
index c23aa78..3075435 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt
index 70412ed..28c7048 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt
index b6f3b9f..b0df723 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.Identity"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt
index a392394..8a5fd18 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt
index d863752..3a0d9ee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt
index c918524..cac6ffc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.Orthogonal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt
index 53b9f20..02712ef 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.TruncatedNormal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt
index bb9a847..bb6c193 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.VarianceScaling"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
index 30e92a3..78d7748 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
index fc43bef..9808f8b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt
index 0cade59..4ae0d68 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt
index 3b43fd2..f5bbf39 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt
index e857e75..8968272 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.identity"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt
index 8dfe4da..a9af697 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt
index df8dfef..fd91a6a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt
index fa90188..81cee0a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.orthogonal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt
index 9750914..4d0a400 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.truncated_normal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt
index 5cff80e..97bc11c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.initializers.variance_scaling"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt
index fa15dc8..bc49b72 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt
@@ -1,7 +1,7 @@
 path: "tensorflow.io.TFRecordWriter"
 tf_class {
   is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
-  is_instance: "<class \'tensorflow.python._pywrap_record_io.RecordWriter\'>"
+  is_instance: "<class \'tensorflow.python.lib.io._pywrap_record_io.RecordWriter\'>"
   is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
   member_method {
     name: "__init__"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt
index a9f5593..d332e0b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt
index 255b1c1..10afe97 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt
index 5b53b41..d59d94a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt
index 41fd8a2..ac1e5fc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt
index 1a02232..0e763b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.Identity"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt
index 6ef45b2..0096045 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt
index d2e590a..7ef39b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt
index e1d23ed..4e50862 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.Orthogonal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
index 14fe954..08f2164 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.TruncatedNormal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt
index c0e3d35..5e6d5f3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.VarianceScaling"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
index 5bca6a3..40f65fb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
index 3a6cbe1..c42584c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.GlorotUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt
index 5ece8ae..6393b9a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt
index 0d2dc7e..5fd1fc7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.HeUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt
index 647864a..36a4924 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.identity"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt
index 4eb04c9..ccaafdf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunNormal\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt
index d1f8e8a..11d760a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt
@@ -2,8 +2,6 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.LecunUniform\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt
index 227f895..f89f581 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.orthogonal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Orthogonal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
index e5ebf90..b03849a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.truncated_normal"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.TruncatedNormal\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt
index 4ec96ca..d5e4f79 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt
@@ -1,8 +1,6 @@
 path: "tensorflow.keras.initializers.variance_scaling"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
-  is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
   is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
   is_instance: "<type \'object\'>"
   member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 21dd896..7f6943d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -753,10 +753,18 @@
     argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
+    name: "CollectiveBcastRecvV2"
+    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
+  member_method {
     name: "CollectiveBcastSend"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
+    name: "CollectiveBcastSendV2"
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
+  member_method {
     name: "CollectiveGather"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
diff --git a/tensorflow/tools/ci_build/Dockerfile.micro b/tensorflow/tools/ci_build/Dockerfile.micro
index cf2d715..3122980 100644
--- a/tensorflow/tools/ci_build/Dockerfile.micro
+++ b/tensorflow/tools/ci_build/Dockerfile.micro
@@ -5,7 +5,14 @@
 
 LABEL maintainer="Pete Warden <petewarden@google.com>"
 
-RUN apt-get update && apt-get install -y zip xxd
+RUN echo deb http://apt.llvm.org/buster/ llvm-toolchain-buster main > /etc/apt/sources.list.d/llvm.list
+RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add -
+
+RUN apt-get update
+
+RUN apt-get install -y zip xxd
+RUN apt-get install -y clang-format
+
 RUN pip install six
 # Install Renode test dependencies
 RUN pip install pyyaml requests psutil robotframework==3.1
diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh
index bfd551f..aa7cbef 100755
--- a/tensorflow/tools/ci_build/builds/libtensorflow.sh
+++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh
@@ -57,6 +57,8 @@
     BAZEL_OPTS="${BAZEL_OPTS} --config=cuda --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain"
     export TF_NEED_ROCM=0
     export TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
+  else
+    BAZEL_OPTS="${BAZEL_OPTS} --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
   fi
   bazel clean --expunge
   yes "" | ./configure
diff --git a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
index ce7789b..0f0f182 100755
--- a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
@@ -102,7 +102,7 @@
 pip2 install keras_preprocessing==1.0.5 --no-deps
 pip3 install keras_preprocessing==1.0.5 --no-deps
 pip2 install --upgrade h5py==2.8.0
-pip3 install --upgrade h5py==2.8.0
+pip3 install --upgrade h5py==3.1.0
 
 # Estimator
 pip2 install tf-estimator-nightly --no-deps
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 578967a..f9893f0 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -134,7 +134,7 @@
 pip2 install keras_preprocessing==1.1.0 --no-deps
 pip3 install keras_preprocessing==1.1.0 --no-deps
 pip2 install --upgrade h5py==2.8.0
-pip3 install --upgrade h5py==2.8.0
+pip3 install --upgrade h5py==3.1.0
 
 # Estimator
 pip2 install tf-estimator-nightly --no-deps
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index bb53fc9..9530c9f 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -87,7 +87,7 @@
 
 # Keras
 pip3.5 install keras_preprocessing==1.0.5
-pip3.5 install --upgrade h5py==2.8.0
+pip3.5 install --upgrade h5py==3.1.0
 
 # Estimator
 pip3.5 install tf-estimator-nightly==1.12.0.dev20181203 --no-deps
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index bcf0d0b..f130ab8 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -101,7 +101,7 @@
 pip3 install --upgrade gast
 pip3 install --upgrade termcolor
 
-pip3 install --upgrade h5py==2.8.0
+pip3 install --upgrade h5py==3.1.0
 
 # Keras
 pip3 install keras_preprocessing==1.0.5
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index 7b2ba29..4fd671c 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -31,7 +31,7 @@
 ROOT_DIR="$(realpath ${SCRIPT_DIR}/../../../../)"
 
 DOCKER_IMAGE="tf-libtensorflow-cpu"
-DOCKER_FILE="Dockerfile.cpu"
+DOCKER_FILE="Dockerfile.rbe.ubuntu16.04-manylinux2010"
 DOCKER_BINARY="docker"
 if [ "${TF_NEED_CUDA}" == "1" ]; then
   DOCKER_IMAGE="tf-tensorflow-gpu"
diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh
index dd7e2a5..3f36d2c 100644
--- a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh
+++ b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh
@@ -27,7 +27,7 @@
 source tf_build_env/bin/activate
 
 # Install macos pip dependencies
-install_macos_pip_deps sudo pip3.6
+install_macos_pip_deps virtualenv
 
 # Run configure.
 export TF_NEED_CUDA=0
diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh
index 2f73ad6..b3cddd0 100644
--- a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh
+++ b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh
@@ -22,11 +22,11 @@
 # Pick a more recent version of xcode
 export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer
 sudo xcode-select -s "${DEVELOPER_DIR}"
-python -m virtualenv tf_build_env --system-site-packages
+python3.7 -m virtualenv tf_build_env --system-site-packages
 source tf_build_env/bin/activate
 
 # Install macos pip dependencies
-install_macos_pip_deps sudo pip3.7
+install_macos_pip_deps virtualenv
 
 # Run configure.
 export TF_NEED_CUDA=0
diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh
index 11b557a..70d742b 100644
--- a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh
+++ b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh
@@ -23,11 +23,11 @@
 export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer
 export MACOSX_DEPLOYMENT_TARGET=10.10
 sudo xcode-select -s "${DEVELOPER_DIR}"
-python -m virtualenv tf_build_env --system-site-packages
+python3.8 -m virtualenv tf_build_env --system-site-packages
 source tf_build_env/bin/activate
 
 # Install macos pip dependencies
-install_macos_pip_deps sudo pip3.8
+install_macos_pip_deps virtualenv
 
 # Run configure.
 export TF_NEED_CUDA=0
diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh
index 11b7c2a..d572ac1 100644
--- a/tensorflow/tools/ci_build/release/common.sh
+++ b/tensorflow/tools/ci_build/release/common.sh
@@ -126,7 +126,7 @@
   "${PIP_CMD}" install --user 'astunparse ~= 1.6.3'
   "${PIP_CMD}" install --user 'flatbuffers ~= 1.12.0'
   "${PIP_CMD}" install --user 'google_pasta ~= 0.2'
-  "${PIP_CMD}" install --user 'h5py ~= 2.10.0'
+  "${PIP_CMD}" install --user 'h5py ~= 3.1.0'
   "${PIP_CMD}" install --user 'keras_preprocessing ~= 1.1.2'
   "${PIP_CMD}" install --user 'numpy ~= 1.19.2'
   "${PIP_CMD}" install --user 'opt_einsum ~= 3.3.0'
@@ -188,7 +188,7 @@
   ${PIP_CMD} install $USER_FLAG 'astunparse ~= 1.6.3'
   ${PIP_CMD} install $USER_FLAG 'flatbuffers ~= 1.12.0'
   ${PIP_CMD} install $USER_FLAG 'google_pasta ~= 0.2'
-  ${PIP_CMD} install $USER_FLAG 'h5py ~= 2.10.0'
+  ${PIP_CMD} install $USER_FLAG 'h5py ~= 3.1.0'
   ${PIP_CMD} install $USER_FLAG 'keras_preprocessing ~= 1.1.2'
   ${PIP_CMD} install $USER_FLAG 'numpy ~= 1.19.2'
   ${PIP_CMD} install $USER_FLAG 'opt_einsum ~= 3.3.0'
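
[Editor's note] `~=` is pip's compatible-release operator, so the updated pin in `common.sh` tracks h5py 3.1.x patch releases without ever pulling in 3.2:

```bash
# 'h5py ~= 3.1.0' is equivalent to 'h5py >= 3.1.0, == 3.1.*' (PEP 440).
pip install --user 'h5py ~= 3.1.0'
```
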
diff --git a/tensorflow/tools/ci_build/release/common_win.bat b/tensorflow/tools/ci_build/release/common_win.bat
index f27ec31..dbe159a 100644
--- a/tensorflow/tools/ci_build/release/common_win.bat
+++ b/tensorflow/tools/ci_build/release/common_win.bat
@@ -18,7 +18,7 @@
 @REM Set Environment Variables
 @REM
 IF NOT DEFINED PYTHON_DIRECTORY (
-  SET PYTHON_DIRECTORY=Python36
+  SET PYTHON_DIRECTORY=Python37
 )
 SET PY_EXE=C:\%PYTHON_DIRECTORY%\python.exe
 SET PATH=%PATH%;C:\%PYTHON_DIRECTORY%
@@ -32,7 +32,7 @@
 %PY_EXE% -m pip install "astunparse ~= 1.6.3"
 %PY_EXE% -m pip install "flatbuffers ~= 1.12.0"
 %PY_EXE% -m pip install "google_pasta ~= 0.2"
-%PY_EXE% -m pip install "h5py ~= 2.10.0"
+%PY_EXE% -m pip install "h5py ~= 3.1.0"
 %PY_EXE% -m pip install "keras_preprocessing ~= 1.1.2"
 %PY_EXE% -m pip install "numpy ~= 1.19.2"
 %PY_EXE% -m pip install "opt_einsum ~= 3.3.0"
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD
index 80b1e59..c5d53f6 100644
--- a/tensorflow/tools/common/BUILD
+++ b/tensorflow/tools/common/BUILD
@@ -48,7 +48,7 @@
         ":test_module1",
         ":test_module2",
         ":traverse",
-        "//tensorflow/python:platform_test",
+        "//tensorflow/python/platform:test",
     ],
 )
 
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 1aa76fb..238b690 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -270,7 +270,6 @@
     srcs = ["test_file_v2_0.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    tags = ["no_rocm"],
     deps = [
         "//tensorflow:tensorflow_py",
     ],
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index ca13f43..959bc57 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -1,4 +1,4 @@
-[//tensorflow/python:cpp_python_util] # util tfe
+[//tensorflow/python/util:cpp_python_util] # util tfe
 tensorflow::swig::IsSequence
 tensorflow::swig::IsSequenceOrComposite
 tensorflow::swig::IsCompositeTensor
@@ -48,7 +48,7 @@
 
 [//tensorflow/python:bfloat16_lib] # bfloat16
 tensorflow::RegisterNumpyBfloat16
-tensorflow::Bfloat16PyType
+tensorflow::Bfloat16Dtype
 
 [//tensorflow/python:py_func_lib] # py_func
 tensorflow::InitializePyTrampoline
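
[Editor's note] `symbols_pybind.txt` feeds the Windows .def file generation, so its entries need to stay in sync with the symbols actually defined in the C++ sources; this hunk follows the `cpp_python_util` target move and the `Bfloat16PyType` → `Bfloat16Dtype` rename. A Linux-side sanity check, with the extension path given only as an assumption about the local bazel output layout:

```bash
# Confirm the renamed symbol is still exported from the built Python extension
# (demangled dynamic symbol table). The path depends on the local bazel output tree.
nm -DC bazel-bin/tensorflow/python/_pywrap_tensorflow_internal.so | grep Bfloat16Dtype
```
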
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8-jupyter.Dockerfile
index 74ffbb8..4f7fe84 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8-jupyter.Dockerfile
@@ -62,7 +62,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -117,9 +117,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
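
[Editor's note] Two edits recur across the generated dockerfiles from here on: pip is capped below 20.3 (the release that switched to the new dependency resolver), and the wget/git install is folded into a single layer that runs `apt-get update` first, so the install never resolves against a stale package index cached in an earlier layer. The equivalent shell commands each RUN layer executes:

```bash
# Keep the image on the pre-20.3 pip resolver.
python3 -m pip --no-cache-dir install --upgrade "pip<20.3" setuptools
# Refresh the index in the same layer as the install to avoid a stale apt cache.
apt-get update && apt-get install -y --no-install-recommends wget git
```
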
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8.Dockerfile
index c2861e9..52727c9 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/arm64v8/devel-cpu-arm64v8.Dockerfile
@@ -62,7 +62,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
index 107d1b4..deec0d2 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
@@ -33,7 +33,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -60,9 +60,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
index e83592c..e12571e 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
@@ -33,7 +33,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
index 78ec441..a496ad7 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
@@ -62,7 +62,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -111,9 +111,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
index 018b7bb..4973ddd 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
@@ -62,7 +62,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
index 23a4e33..c1d5732 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
@@ -101,7 +101,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -150,9 +150,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
index a477081..010b756 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
@@ -101,7 +101,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
index 7dda72e..a5fa472 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
@@ -79,7 +79,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -106,9 +106,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
index 279a790..5802593 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
@@ -79,7 +79,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-jupyter.Dockerfile
index c51b7bf..5961438 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod-jupyter.Dockerfile
index 35494c2..41047e8 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod.Dockerfile
index cd0f5f0..afddc75 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod-jupyter.Dockerfile
index f4fd26e..22e4759 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod.Dockerfile
index 751c093..4e0189e 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel.Dockerfile
index a4d28dd..91a44dd 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-devel.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-jupyter.Dockerfile
index 5f6a898..8e64c7e 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod-jupyter.Dockerfile
index e995a73..9b1a8ab 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod.Dockerfile
index 7e853dd..ea62cea 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod-jupyter.Dockerfile
index 2e91c6b..de65f0a 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod.Dockerfile
index 50b19bf..2b82d58 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8.Dockerfile
index 692c83e..00f6191 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/centos-8.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-jupyter.Dockerfile
index ffc951f..aa67879 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod-jupyter.Dockerfile
index 34485a5..5dd1fb3 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod.Dockerfile
index 85e271f..c6e1482 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod-jupyter.Dockerfile
index ee6abd8..56a95f9 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod.Dockerfile
index daf92ea..181a3a8 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel.Dockerfile
index 10ae251..d84f707 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-devel.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-jupyter.Dockerfile
index 30729f9..61bce40 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod-jupyter.Dockerfile
index 7a46ea0..268972d 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod.Dockerfile
index 8fb1ee5..a9d6603 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod-jupyter.Dockerfile
index 32f935e..bea23d8 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod.Dockerfile
index 1187500..acef18e 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04.Dockerfile
index 6a6cdf5..9033cae 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-16.04.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-jupyter.Dockerfile
index ffc951f..aa67879 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod-jupyter.Dockerfile
index 34485a5..5dd1fb3 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod.Dockerfile
index 85e271f..c6e1482 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod-jupyter.Dockerfile
index 030fb86..13765bb 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod.Dockerfile
index ad763a8..4937995 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel.Dockerfile
index 10ae251..d84f707 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-devel.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-jupyter.Dockerfile
index 30729f9..61bce40 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod-jupyter.Dockerfile
index 65043d1..f147acb 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod.Dockerfile
index 69efc88..b6e0d02 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod-jupyter.Dockerfile
index 0b42892..b2fc1c9 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod.Dockerfile
index f570e92..8043109 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04.Dockerfile
index 6a6cdf5..9033cae 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-18.04.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-jupyter.Dockerfile
index b1f1edf..09e9695 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod-jupyter.Dockerfile
index 92b8101..3655569 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod.Dockerfile
index 72275fc..66a3c8e 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod-jupyter.Dockerfile
index f123955..b93d1a6 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod.Dockerfile
index d4abafe..95a63fe 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel.Dockerfile
index f8ae3df..4ad8192 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-devel.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-jupyter.Dockerfile
index 2b14525..8794668 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod-jupyter.Dockerfile
index 09527a8..7830420 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod.Dockerfile
index a703ed3..a5084ee 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpi-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod-jupyter.Dockerfile
index 65473ac..a35aa12 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod.Dockerfile
index 24bd164..2090352 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04-mpich-horovod.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04.Dockerfile
index 666e083..42a74df 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/onednn/ubuntu-20.04.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
index 0a284f4..16163ae 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
@@ -33,7 +33,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -78,9 +78,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
index 831e5ae..cbcd2e0 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
@@ -33,7 +33,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
index 7a5c905..bc07321 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
@@ -62,7 +62,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -112,9 +112,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
index cd97e2e..d75993e 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
@@ -62,7 +62,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
index 946136f..68f9f22 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
@@ -22,37 +22,34 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.4.30-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
-        libcublas-dev=10.2.1.243-1 \
+        libcublas-${CUDA/./-} \
+        libcublas-dev-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
         cuda-nvrtc-dev-${CUDA/./-} \
         cuda-cudart-dev-${CUDA/./-} \
-        cuda-cufft-dev-${CUDA/./-} \
-        cuda-curand-dev-${CUDA/./-} \
-        cuda-cusolver-dev-${CUDA/./-} \
-        cuda-cusparse-dev-${CUDA/./-} \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
-        libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+        libcufft-dev-${CUDA/./-} \
+        libcurand-dev-${CUDA/./-} \
+        libcusolver-dev-${CUDA/./-} \
+        libcusparse-dev-${CUDA/./-} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
+        libcudnn8-dev=${CUDNN}+cuda${CUDA} \
         libcurl3-dev \
         libfreetype6-dev \
         libhdf5-serial-dev \
@@ -67,7 +64,7 @@
         git \
         && \
     find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
-    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v8.a
 
 # Install TensorRT if not building for PowerPC
 RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
@@ -104,7 +101,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -154,9 +151,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
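
[Editor's note] The ppc64le GPU dockerfiles (this one and the three below) move from CUDA 10.1 / cuDNN 7 / TensorRT 6 to CUDA 11.0 / cuDNN 8 / TensorRT 7. Under CUDA 11 the math libraries ship as standalone `libcufft-*`, `libcurand-*`, `libcusolver-*`, and `libcusparse-*` apt packages instead of the old `cuda-*` names, and the explicit libcublas 10.2.1 pin that worked around the earlier cuBLAS regression is dropped. A quick resolution check before a full image build, assuming it runs inside the CUDA 11.0 base image where NVIDIA's apt repository is already configured:

```bash
# Verify the CUDA 11 / cuDNN 8 package names used above resolve in the base image.
apt-get update
apt-cache policy libcudnn8 libcudnn8-dev libcublas-11-0 libcufft-dev-11-0 libcusolver-dev-11-0
```
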
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
index cf84f4a..a4f379b 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
@@ -22,37 +22,34 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.4.30-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
-        libcublas-dev=10.2.1.243-1 \
+        libcublas-${CUDA/./-} \
+        libcublas-dev-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
         cuda-nvrtc-dev-${CUDA/./-} \
         cuda-cudart-dev-${CUDA/./-} \
-        cuda-cufft-dev-${CUDA/./-} \
-        cuda-curand-dev-${CUDA/./-} \
-        cuda-cusolver-dev-${CUDA/./-} \
-        cuda-cusparse-dev-${CUDA/./-} \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
-        libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+        libcufft-dev-${CUDA/./-} \
+        libcurand-dev-${CUDA/./-} \
+        libcusolver-dev-${CUDA/./-} \
+        libcusparse-dev-${CUDA/./-} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
+        libcudnn8-dev=${CUDNN}+cuda${CUDA} \
         libcurl3-dev \
         libfreetype6-dev \
         libhdf5-serial-dev \
@@ -67,7 +64,7 @@
         git \
         && \
     find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
-    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v8.a
 
 # Install TensorRT if not building for PowerPC
 RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
@@ -104,7 +101,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
index 6ef0810..dbb9942 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
@@ -22,17 +22,17 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.4.30-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
@@ -40,17 +40,14 @@
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
+        libcublas-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
-        cuda-cufft-${CUDA/./-} \
-        cuda-curand-${CUDA/./-} \
-        cuda-cusolver-${CUDA/./-} \
-        cuda-cusparse-${CUDA/./-} \
+        libcufft-${CUDA/./-} \
+        libcurand-${CUDA/./-} \
+        libcusolver-${CUDA/./-} \
+        libcusparse-${CUDA/./-} \
         curl \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
         libfreetype6-dev \
         libhdf5-serial-dev \
         libzmq3-dev \
@@ -82,7 +79,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
@@ -127,9 +124,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
index f10e9f9..b22077c 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
@@ -22,17 +22,17 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.4.30-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
@@ -40,17 +40,14 @@
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
+        libcublas-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
-        cuda-cufft-${CUDA/./-} \
-        cuda-curand-${CUDA/./-} \
-        cuda-cusolver-${CUDA/./-} \
-        cuda-cusparse-${CUDA/./-} \
+        libcufft-${CUDA/./-} \
+        libcurand-${CUDA/./-} \
+        libcusolver-${CUDA/./-} \
+        libcusparse-${CUDA/./-} \
         curl \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
         libfreetype6-dev \
         libhdf5-serial-dev \
         libzmq3-dev \
@@ -82,7 +79,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
index cd84872..49905e7 100644
--- a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
@@ -5,9 +5,7 @@
 
 RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
 RUN mkdir /.local && chmod a+rwx /.local
-RUN apt-get install -y --no-install-recommends wget
-# some examples require git to fetch dependencies
-RUN apt-get install -y --no-install-recommends git
+RUN apt-get update && apt-get install -y --no-install-recommends wget git
 WORKDIR /tf/tensorflow-tutorials
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
 RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile
index a3c0738..6318a5f 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile
@@ -6,7 +6,7 @@
     python3-pip
 
 RUN python3 -m pip --no-cache-dir install --upgrade \
-    pip \
+    "pip<20.3" \
     setuptools
 
 # Some TF tools expect a "python" binary
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 72374a8..6d74df3 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -183,8 +183,8 @@
         ":base_dir_oss",
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:tf_export",
         "@absl_py//absl:app",
         "@absl_py//absl/flags",
     ],
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 2f5eceb..3aeaf86 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -223,6 +223,7 @@
     visibility = [
         "//tensorflow/core:__pkg__",
         "//tensorflow/python:__pkg__",
+        "//tensorflow/python/util:__pkg__",
     ],
 )
 
@@ -336,9 +337,9 @@
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:_pywrap_transform_graph",
         "//tensorflow/python:errors",
         "//tensorflow/python:util",
+        "//tensorflow/python/util:_pywrap_transform_graph",
     ],
 )
 
diff --git a/tensorflow/tools/graph_transforms/__init__.py b/tensorflow/tools/graph_transforms/__init__.py
index 8746567..84f7ea0 100644
--- a/tensorflow/tools/graph_transforms/__init__.py
+++ b/tensorflow/tools/graph_transforms/__init__.py
@@ -19,8 +19,8 @@
 
 # pylint: disable=unused-import,wildcard-import, line-too-long
 from tensorflow.core.framework import graph_pb2
-from tensorflow.python._pywrap_transform_graph import TransformGraphWithStringInputs
 from tensorflow.python.util import compat
+from tensorflow.python.util._pywrap_transform_graph import TransformGraphWithStringInputs
 
 
 def TransformGraph(input_graph_def, inputs, outputs, transforms):
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 5b9fa84..a004d7f 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/tools/graph_transforms/transform_utils.h"
 #if !defined(PLATFORM_WINDOWS)
 #include <pwd.h>
+#include <unistd.h>
 #endif
 
 namespace tensorflow {
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 21e031c..1252c01 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -46,17 +46,17 @@
         # Code that relies on these headers should dynamically link to
         # _pywrap_tensorflow_internal.so as well.
         "//tensorflow/python:model_analyzer_lib",
-        "//tensorflow/python:py_exception_registry",
-        "//tensorflow/python:pybind11_absl",
-        "//tensorflow/python:pybind11_lib",
-        "//tensorflow/python:pybind11_status",
-        "//tensorflow/python:pybind11_proto",
-        "//tensorflow/python:kernel_registry",
-        "//tensorflow/python:cpp_python_util",
-        "//tensorflow/python:py_func_lib",
-        "//tensorflow/python:py_seq_tensor",
-        "//tensorflow/python:py_util",
-        "//tensorflow/python:py_record_reader_lib",
+        "//tensorflow/python/lib/core:py_exception_registry",
+        "//tensorflow/python/lib/core:pybind11_proto",
+        "//tensorflow/python/lib/core:pybind11_absl",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//tensorflow/python/lib/core:py_func_lib",
+        "//tensorflow/python/lib/core:py_seq_tensor",
+        "//tensorflow/python/lib/core:py_util",
+        "//tensorflow/python/lib/io:py_record_reader_lib",
+        "//tensorflow/python/util:cpp_python_util",
+        "//tensorflow/python/util:kernel_registry",
         "//tensorflow/python:python_op_gen",
         "//tensorflow/python:tf_session_helper",
         "//third_party/eigen3",
@@ -116,7 +116,6 @@
     "//tensorflow/python/training/experimental:loss_scale_optimizer",
     "//tensorflow/python:memory_checker",
     "//tensorflow/python:meta_graph_testdata",
-    "//tensorflow/python:util_example_parser_configuration",
     "//tensorflow/python/data/benchmarks:benchmark_base",
     "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
     "//tensorflow/python/data/experimental/kernel_tests:data_service_test_base",
@@ -147,6 +146,7 @@
     "//tensorflow/python/tools:tools_pip",
     "//tensorflow/python/tools/api/generator:create_python_api",
     "//tensorflow/python/tpu",
+    "//tensorflow/python/util:example_parser_configuration",
     "//tensorflow/python:image_grad_test_base",
     "//tensorflow/python:test_ops",
     "//tensorflow/python:while_v2",
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 60e1ae5..a862d6d 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -87,7 +87,7 @@
     "//tensorflow/python/debug:grpc_tensorflow_server.par",
     "//tensorflow/python/feature_column:vocabulary_testdata",
     "//tensorflow/python:framework/test_file_system.so",
-    "//tensorflow/python:util_nest_test_main_lib",
+    "//tensorflow/python/util:nest_test_main_lib",
     # lite
     "//tensorflow/lite/experimental/examples/lstm:rnn_cell",
     "//tensorflow/lite/experimental/examples/lstm:rnn_cell.py",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 613ce9f..d84b08d 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -79,7 +79,7 @@
     'astunparse ~= 1.6.3',
     'flatbuffers ~= 1.12.0',
     'google_pasta ~= 0.2',
-    'h5py ~= 2.10.0',
+    'h5py ~= 3.1.0',
     'keras_preprocessing ~= 1.1.2',
     'numpy ~= 1.19.2',
     'opt_einsum ~= 3.3.0',
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
index 159976f..76e1456 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
@@ -105,7 +105,7 @@
     const tensorflow::protobuf::FileDescriptor* fd =
         importer.Import(proto_path);
 
-    const int index = proto_path.find_last_of(".");
+    const int index = proto_path.find_last_of('.');
     string proto_path_no_suffix = proto_path.substr(0, index);
 
     proto_path_no_suffix =
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 2a9061b..a16b8bf 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -134,31 +134,31 @@
     # and update the sha256 with the result.
     tf_http_archive(
         name = "XNNPACK",
-        sha256 = "4982a2b2849fc3853bf8dda099e46306477c2abd139481adf37f6835e227a860",
-        strip_prefix = "XNNPACK-6eaa1521288d268dd4cceca4ae5c018cf009179b",
+        sha256 = "59ccf0c1c64899b511f8872a278e54c293970f57933b056492a364aa5ac709ec",
+        strip_prefix = "XNNPACK-094e692629d57ddb932fcc993193626f60daa61b",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/6eaa1521288d268dd4cceca4ae5c018cf009179b.zip",
-            "https://github.com/google/XNNPACK/archive/6eaa1521288d268dd4cceca4ae5c018cf009179b.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/094e692629d57ddb932fcc993193626f60daa61b.zip",
+            "https://github.com/google/XNNPACK/archive/094e692629d57ddb932fcc993193626f60daa61b.zip",
         ],
     )
 
     tf_http_archive(
         name = "FXdiv",
-        sha256 = "ab7dfb08829bee33dca38405d647868fb214ac685e379ec7ef2bebcd234cd44d",
-        strip_prefix = "FXdiv-b408327ac2a15ec3e43352421954f5b1967701d1",
+        sha256 = "3d7b0e9c4c658a84376a1086126be02f9b7f753caa95e009d9ac38d11da444db",
+        strip_prefix = "FXdiv-63058eff77e11aa15bf531df5dd34395ec3017c8",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/FXdiv/archive/b408327ac2a15ec3e43352421954f5b1967701d1.zip",
-            "https://github.com/Maratyszcza/FXdiv/archive/b408327ac2a15ec3e43352421954f5b1967701d1.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip",
+            "https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip",
         ],
     )
 
     tf_http_archive(
         name = "pthreadpool",
-        sha256 = "8461f6540ae9f777ce20d1c0d1d249e5e61c438744fb390c0c6f91940aa69ea3",
-        strip_prefix = "pthreadpool-545ebe9f225aec6dca49109516fac02e973a3de2",
+        sha256 = "e576de3e2504018462a3ee2282c99c2d0d708f01d17cd2f71f9f1fe6d3ba8b9b",
+        strip_prefix = "pthreadpool-77f9d3bcfabd1bdb910dd33b549d5290b968ef05",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/pthreadpool/archive/545ebe9f225aec6dca49109516fac02e973a3de2.zip",
-            "https://github.com/Maratyszcza/pthreadpool/archive/545ebe9f225aec6dca49109516fac02e973a3de2.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/pthreadpool/archive/77f9d3bcfabd1bdb910dd33b549d5290b968ef05.zip",
+            "https://github.com/Maratyszcza/pthreadpool/archive/77f9d3bcfabd1bdb910dd33b549d5290b968ef05.zip",
         ],
     )
 
@@ -202,11 +202,11 @@
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
-        sha256 = "306f15c04fbd514b4adc3a327a2c6f63521ea6805cab75691fa30c30fea55193",  # SHARED_EIGEN_SHA
-        strip_prefix = "eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed",
+        sha256 = "df23a89e4cdfa7de2d81ee28190bd194413e47ff177c94076f845b32d7280344",  # SHARED_EIGEN_SHA
+        strip_prefix = "eigen-5dc2fbabeee17fe023c38756ebde0c1d56472913",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed/eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed.tar.gz",
-            "https://gitlab.com/libeigen/eigen/-/archive/fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed/eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/5dc2fbabeee17fe023c38756ebde0c1d56472913/eigen-5dc2fbabeee17fe023c38756ebde0c1d56472913.tar.gz",
+            "https://gitlab.com/libeigen/eigen/-/archive/5dc2fbabeee17fe023c38756ebde0c1d56472913/eigen-5dc2fbabeee17fe023c38756ebde0c1d56472913.tar.gz",
         ],
     )
 
@@ -685,8 +685,8 @@
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "ecaff13fc0bc1105ad910a72a5d0dcd164b35191"
-    LLVM_SHA256 = "d0178d6f6a23ce60752d11ee8b1d64784d8ce9625f03d76943b0e40a0043211a"
+    LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941"
+    LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/gpus/find_rocm_config.py b/third_party/gpus/find_rocm_config.py
index c1eb119..69b2c19 100644
--- a/third_party/gpus/find_rocm_config.py
+++ b/third_party/gpus/find_rocm_config.py
@@ -251,6 +251,28 @@
   return hipsparse_config
 
 
+def _find_rocsolver_config(rocm_install_path):
+
+  def rocsolver_version_numbers(path):
+    version_file = os.path.join(path, "rocsolver/include/rocsolver-version.h")
+    if not os.path.exists(version_file):
+      raise ConfigError(
+          'rocsolver version file "{}" not found'.format(version_file))
+    major = _get_header_version(version_file, "ROCSOLVER_VERSION_MAJOR")
+    minor = _get_header_version(version_file, "ROCSOLVER_VERSION_MINOR")
+    patch = _get_header_version(version_file, "ROCSOLVER_VERSION_PATCH")
+    return major, minor, patch
+
+  major, minor, patch = rocsolver_version_numbers(rocm_install_path)
+
+  rocsolver_config = {
+      "rocsolver_version_number":
+          _get_composite_version_number(major, minor, patch)
+  }
+
+  return rocsolver_config
+
+
 def find_rocm_config():
   """Returns a dictionary of ROCm components config info."""
   rocm_install_path = _get_rocm_install_path()
@@ -269,6 +291,7 @@
   result.update(_find_rocfft_config(rocm_install_path))
   result.update(_find_roctracer_config(rocm_install_path))
   result.update(_find_hipsparse_config(rocm_install_path))
+  result.update(_find_rocsolver_config(rocm_install_path))
 
   return result
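
For context on the probe added above: _find_rocsolver_config looks for the rocsolver version header under the ROCm install tree and pulls the three version macros with _get_header_version before folding them into the single rocsolver_version_number entry. A representative excerpt of that header (the version numbers below are made up; real values depend on the installed ROCm release):

/* rocsolver/include/rocsolver-version.h -- illustrative excerpt only */
#define ROCSOLVER_VERSION_MAJOR 3
#define ROCSOLVER_VERSION_MINOR 10
#define ROCSOLVER_VERSION_PATCH 0
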
 
diff --git a/third_party/gpus/find_rocm_config.py.gz.base64 b/third_party/gpus/find_rocm_config.py.gz.base64
index f38b64f..60247f1 100644
--- a/third_party/gpus/find_rocm_config.py.gz.base64
+++ b/third_party/gpus/find_rocm_config.py.gz.base64
@@ -1 +1 @@
-eJy9WtFu27gSfddXEAqKyhtHSXsfdpFFHrxpFvXe1gns7C4WTWDQMm1zK4u6JJXUKPrvd4akZEqWEydOa6CoZQ0PhzNnDkdiDsi5yFeSzxeavD15e0KuF4xcs0wJ+Xsq7kmv0AshVUx6aUqGaKbIkCkm79g0Dg6CA/KBJ2DOpqTIpkwSDeN7OU3gP3enS/5iUnGRkbfxCYnQIHS3ws6vgLASBVnSFcmEJoViAMEVmfGUEfYlYbkmPCOJWOYpp1nCyD3XCzONAwE3yD8OQkw0BWsK9jlczXw7QrVxGD8LrfPT4+P7+/uYGmdjIefHqTVUxx/65xeD0cUROGyG/JmlTCki2f8KLmGpkxWhOfiT0Al4mdJ7IiShc8ngnhbo773kmmfzLlFipu+pZIAy5UpLPil0LVild7Bm3wDCRTMS9kakPwrJb71Rf9QFjL/71+8v/7wmf/eGw97gun8xIpdDcn45eNe/7l8O4Op30hv8Q/7bH7zrEgahgmnYl1yi/+AkxzCa1JERYzUHZsI6pHKW8BlPYF3ZvKBzRubijskMlkNyJpdcYTIVuDcFlJQvuaba/LKxKJzm7EU/QRiGV5JnSMPL8yVMP5FUrtAZsmAU559CihItJGfGR3Jn2QeUEuAgBtascqU0W8ZBgIRXieTAM8WoBC4oE4pt8EhMVUfpQsYxaloF8OMSKTBlGkOVmRBzWTphgHLrP45PRDbj80KaAOI4paei0LHxKqdIdFGCI0NcbpBmCymK+QJJwrI7LkW2ZJkmd1RyQ8oI/P84vupdv+/EQX8GxQX3Uj5tTMldWLp2OTYOpYPGHSalSbVkupAm7QR+ggAlYsrq8dP0M7PrKnOw8jyGosFblV+tfsc+XirEZ5sMG3ubzzInNhGm2hdUTo/QnynkUEPdB6qY+DyYSbEkE6pcUJ0wrH2r/I0JxGrtIoQHVCmoDE2YoCyPRa6PpUiWIZoUKH8UfNGQ9xktUlxPWrAA2RoEUHNCQvpE+U2o8hvogvsGTAqCIEkp1Om5SdEFRjm6MBIIqeqcBgS8V2gGs5DxnOmxm26MroxxaZExs7ny3fQHGWMglaZp6g0CX9+VrLWRLlPu0rYkbpClDo6McYEwXxORnG31D8z5jIRVjENMoVCxYwN60g64tvnkjb4F+4Nt9sZFyWhqpt4w6qxDtXHPDxhuPkJxzcaujMdZsZwwGS3pv0J2CUQM/4NhycKP/5sT+JCfiDEjh3iNV2gNV8bcn8aSvJwjQi+6JKNLVqZn6PIBap6DwyDpAATDTcIgiue+RrnUYPWkqENwn4tY5KxEDmUIO0QGZQPKfhYWenb0S9ix8V+ibxBDyWLzNZLhgZ2IvFLkMLqZHnZC8sp41zX4HTMOMmvsLQqxZQA45sd4DpKVR2867qaLEmhNZOw6AYaOcqhSvwJelzOHX7+ZcrO72k32OobFAXJkYkQOHWz9ExLTV+AGiDoDUTCdxddvsG/eZGEJYeiwHYKDHNq1WEVAFQS/2ZxhdKFoaBp2ylyCs1PLeKu1LcQ7xaWisblVZ5WKnIkJoLtlnF5z+l/BqzTGPJuJY2d4NGV3YZUL9LQcwr5Ad6EiH7BTpqkl6Kbey13LzG7Cj4hmw6iCXwPs1Jx26wG/P92aG6Z1MxSsjULtnFWUcXewFYKO44zMsIKnUSfGX/Ko07BbT1IfGSvodHQUxi4ctgjPLN3qYz+d3DobU5rtNm+cTe5Ko83m7W0561HYqVAdz1vEAlnQ8jMWXhsvWgSsFF/LNBj41YUnbEEIT5+hZ4D3LWgIpZ2txvcFz2WRab5kO7DeM2548DTqA84xz5K0mLJj/A7/SsB48RJl4OnB6/f9KzK0Xj+3MA7g8QoExFNpaF1sS/UFHyeggSjFDnUbZhz/dTEcQWPfJfcLAc7ZvsJiAVLkmYw/9v6AZ4GfzDZzSGp3+oPLYekC7CKS8Tvb/Jdbk2W+3fSPEpomBezy4B2ToO+Km+cuaOuEaYcszj0FPQU6YSUIAzYpFG5FCvurnMLqoX6hr4OaxycQVSvDth3Pjxhkd2NtYb1KnwqBQQhrJfmmtj0bXKToBpf9utrKXaiu7bxu7T38utqYs1ZdS46quUNlOcP9dhQLUlWWu/xuhfWxfwn4z62pJxAKJrq6GOzLqSaKR6tSvZ+OAr3s+fvw+fvFlry37xg1NvncbkUJT71k7b1/1OZudkwTePzZrWkylg+wHPxTHJ4nx37YTSdSLSZ0MBXP3fXRmuhdo5fYCf0n/rl8dgeN2AqBPaHMaPoI1i8G4batHgcis/qOG8AM+9X2pTTbJfyx+eTj1fSsbJugah+o2LEXwhbnNkwrwwn0aJ99WfBNd5KBMpZ1HagkACMBmlAKQaNRb4/R0xUC8vPbh95oX4nYgHmeRjRh9hWJbXWzta/06rHRWrbgvKxO1GdvCoWEWtxNKIzlXk2mA/Flwgd9+f2wtRCe+/S1M9OGvcG7kml1jtUBHTW8DDSo0RJxaI22pOLRvqg+VZMHs5nejQZouF9TZEF8FsDl0fckAeD/gKaoERwz8umK10TBkU8XvAaKGbmf3rXlfavcrdnUoPQmyouL3XruJse1pAmEbCeaO9u9mW5xfLLbX74Lzy30D6A6qNz1sHd+MXyB/b0J5O3wB+tMVGdSgtn3hgsKD97UMbRcsRNWv1hO9iH9FhZs5X2NYQ3qt2K9OPtrHjTfKyl8l7DjayVnu18BVDj+yyX7y/cT/GqKH1AI1VzuZP7j80R/A+Z5qt+EudpT9rfToL0CmhRrvO1pxXrZCmh64Cpg4yCheRBEyZQneBiH56FiZh8vjSMZwxNctx48IHjspK7lUDDYxu62EnyE4K9H1Vnr+nTVcLuSRgNeEXxzDqtWTOHxKmTo2/ryk33ZrYVIP3NtrMPb8iV67VCvHBEX+ZRqFu1yVtPZMmqXN97bxj72Pm/buEffkTww8OFnpgcGPthkPzDuscblgcA+Jvkdf/swEK5klhRk3NBRy9Vp9SLlM1t1ywO8jCghNZtGm+UVQ/kuVdSpJNv8JUQUvlKn5JXCA89ojWT8d3+h5HEeD7PcGw+1UrH9C4oY/x6IReFNdjEcXg5Pgb43mXf8qLSMALBTDYNi0HhSGgRQguMxnnCOx+TsjITjMa5xPDYKZJcb/B9qO76I
\ No newline at end of file
+eJy9Wm1v2zgS/q5fQSgoKm8cJe19uEUO+eBNs6j32iSws10smsCgbdrmRhZ1JJU0KPrfb4akZEqWEid2GqCoJQ0fDmeeeaGoPXIqsgfJ5wtN3h+9PyJXC0auWKqE/D0R96SX64WQKia9JCEDFFNkwBSTd2waB3vBHvnEJyDOpiRPp0wSDeN7GZ3Af+5Jl3xhUnGRkvfxEYlQIHSPws5/AOFB5GRJH0gqNMkVAwiuyIwnjLBvE5ZpwlMyEcss4TSdMHLP9cJM40BADfK3gxBjTUGagnwGVzNfjlBtFMa/hdbZ8eHh/f19TI2ysZDzw8QKqsNP/dOz8+HZAShshvyZJkwpItn/ci5hqeMHQjPQZ0LHoGVC74mQhM4lg2daoL73kmuezrtEiZm+p5IBypQrLfk41xVjFdrBmn0BMBdNSdgbkv4wJL/1hv1hFzD+6l99vPjzivzVGwx651f9syG5GJDTi/MP/av+xTlc/U5653+T//bPP3QJA1PBNOxbJlF/UJKjGY3ryJCxigIzYRVSGZvwGZ/AutJ5TueMzMUdkyksh2RMLrlCZypQbwooCV9yTbW5s7YonOZkp39BGIaXkqdIw4vTJUw/llQ+oDJkwSjOPwUXTbSQnBkdyZ1lH1BKgIJoWLPKB6XZMg4CJLyaSA48U4xK4IIypmiDR2KqKkoXPI5W0yqAm0ukwJRpNFVqTMxloYQByqz+OH4i0hmf59IYEMcpPRW5jo1WGUWiiwIcGeJ8gzRbSJHPF0gSlt5xKdIlSzW5o5IbUkag/+fRZe/qYycO+jMILniW8GltSu7M0rXLsXYoFDTqMCmNqyXTuTRuJ3ALDDQRU1a1n6a3zK6r8MGDpzEEDT4q9WrUO/bxEiFurTOs7a0/C59YR5hoX1A5PUB9puBDDXEfqHzs82AmxZKMqXJGdYlhpVupb0zAVisVwTyQlYJS0JgJwvJQZPpQiskyRJEc0x8FXTT4fUbzBNeT5CxAtgYBxJyQ4D5R/BKq+AV5wf0CJgVBMEkoxOmpcdEZWjk6MykQXNU5Dghor1AMZiGjOdMjN90IVRnh0iIjZn3lq+kPMsJAKk2TxBsEun4oWGstXbjcuW1J3CBLHRwZ4wJhvjoiOWnVD8T5jISljUN0oVCxYwNq0gy4kvnqjb4B+b02eaOiZDQxU68JdVamWnvmGwyLj1Bcs5EL41GaL8dMRkv6j5BdAhbD/2DYZOHb/90R/JFfiBEj+3iNVygNV0bcn8aSvJgjQi26JKVLVrhn4PwB2TwDhSGlAxAMNw4DK576Ocq5BqMnwTwEz7mIRcYK5FCGUCFSCBvI7CdhrmcHv4Yda/8l6gY2lCw2PyMZ7tmJyBtF9qPr6X4nJG+Mdl2D3zHjwLNG3qIQGwaAY27Gc0hZWfSu4x46K0GuiYxcJ0DTUQ5R6kfA22Lm8PsPE262ql2nb2NYHCBHxkZk38FW/0Ji+gosgJhnwAqms/j+A+rmdRoWEIYO7RAc0qFdi80ImAVBbzZnaF0IGpqEncKXoOzUMt7m2gbiHeNSUdg8qrJKRU7EGNA9MkqvOP2P4KUbY57OxKETPJiyu7D0BWpaDGHfoLtQkQ/YKdzUYHQT70XVMrMb8yOiKRil8SuAnYrSbj2g99cb88C0boaClVGYO2clZdwTbIWg4zghM4zgadSJ8U4WdWpyq0mqI2MFnY6OwtiZwwbhiaVbdezXoxsnY0KzWeadk8lcaDTJvL8pZj0IOyWq43lDskAWNNzGwGviRUMCK5KvZRoM/O7MEzYghMcvyGeA9yOoJUo7W4XvC57JPNV8yTZgvSdc0+B51AecQ55OknzKDvE3/CsA48UuwsDLB28/9i/JwGr90sDYg+0VJBAvS0PrYluqb7idgAaiSHaYt2HG0ZezwRAa+y65XwhQzvYVFguQIk9k9Ln3B+wFfjFlZp9UnvTPLwaFClBFJON3tvkvSpNlvi36BxOaTHKo8qAdk5DfFTf7LmjrhGmHLM49hXwKdMJIEAZsnCssRQr7q4zC6iF+oa+DmMcdiKqEYVPF8y0G3l1bW1iN0udCoBHCSki+q5Rng4sUXeOyH1et3IXoaud1Y+/hx9XanJXoWnLMmhtElhPcrqJYkDKy3OWrBdbn/gXgvzSmnkEomOjy7HxbTtVRPFoV2fv5KNDLnn4MX14vWvzeXDEqbPK53YgSHnvO2rp+VOaud0xj2P5s1jQZyUdYDvopDvvJkW9204mUiwkdTMlzd32wInrX5EvshP4V/7vYu0OOaIXAnlCmNHkC61eDcNMUj+citfkdC8AM+9XmpdTbJbxZ3/l4MT0r2iaI2kciduSZsEG5NdFScAw92q2fFnzRjdJAYctqHihTAFoCckKRCGqNerONnp8hwD+/feoNt00RazAvyxF1mG2TRFvctPaVXjzWWssGnN3miers9UQhIRY3SxRGcqsm04H4acIH3X09bAyEl+6+NmbaoHf+oWBalWNVQEcNzwM1ajRYHFqjFlc82RdVp6rzYDbTm9EABbdriiyIzwK4PHhNEgD+T2iKasYxI5+f8eooOPL5Ca+GYkZul++a/N6a7lZsqlF6HWXnyW41d53jWtIJmGwjmjvZrZlucXyy2zuvwnML/ROoDlnuatA7PRvsoL7XgbwKv7fyRHkmJZh9b7igsPGmjqHFil1i9YPlaBvSt7CglfcVhtWo34i1c/ZXNKi/V1L4LmHD10pOdrsAKHH8l0v2zusl/HKKnxAI5VzuZP7zy5L+GszLsn4d5nLLtN9Og+YIqFOs9ranEWu3EVDXoF4ClEjuNi0BTnbrEmBx/BJg77xqy2On+DmlYHjx6ctOSkEd6MWbvRrQDrZ7LWRorQQVotUqQSPWzitBRQMXB2sHavUDUUqmfIKH0vhdgJjZ1yxGkZThlwxuPXhQ9tSJdcPheNDG8aZAfILmb4flNwerrwwMw8sWwYCXNF+fw/qKKfzMADz0Y3X51R76aCGSW66NdHhTHCZVDreLEXGeTalm0SZnlp2WUZuc/LSNfeq9dtu4J98VPjLw8XcHjwx8dLP5yLinGvhHDPtU6/PIrE/VjI4fdwbCRduSQh0wTNby4bh8F3nLHrrFGXhKlJCaTaP1yIwh8pcq6pQ533xMFIVv1DF5o/CbgWiFZPR3H/l54YLnwe6loXpQsf0IKcZP6lgUXqdng8HF4BiYf516J/hKywgAO+UwiCONHxsEAUTvaIQfCYxG5OSEhKMRrnE0MsnLLjf4P/9j3RI=
\ No newline at end of file
diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl
index ecbb4b5..3161219 100644
--- a/third_party/gpus/rocm/BUILD.tpl
+++ b/third_party/gpus/rocm/BUILD.tpl
@@ -109,6 +109,7 @@
         ":hiprand",
         ":miopen",
         ":hipsparse",
+        ":rocsolver",
     ],
 )
 
@@ -143,6 +144,12 @@
     data = ["rocm/lib/%{hipsparse_lib}"],
 )
 
+cc_library(
+    name = "rocsolver",
+    srcs = ["rocm/lib/%{rocsolver_lib}"],
+    data = ["rocm/lib/%{rocsolver_lib}"],
+)
+
 filegroup(
     name = "rocm_root",
     srcs = [
diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl
index ce4c1b0..8b6a929 100644
--- a/third_party/gpus/rocm/build_defs.bzl.tpl
+++ b/third_party/gpus/rocm/build_defs.bzl.tpl
@@ -47,3 +47,7 @@
     if rocm_is_configured():
       return x
     return []
+
+def rocm_library(copts = [], **kwargs):
+    """Wrapper over cc_library which adds default ROCm options."""
+    native.cc_library(copts = rocm_default_copts() + copts, **kwargs)
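
A usage sketch for the rocm_library macro introduced above (the load label assumes the generated @local_config_rocm repository this template is expanded into; the target and source names are hypothetical):

load("@local_config_rocm//rocm:build_defs.bzl", "rocm_library")

rocm_library(
    name = "example_rocm_kernels",   # hypothetical target name
    srcs = ["example_kernels.cc"],   # hypothetical source file
    copts = ["-fno-exceptions"],     # appended after rocm_default_copts()
)
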
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl
index 10f03bf..6218a17 100644
--- a/third_party/gpus/rocm_configure.bzl
+++ b/third_party/gpus/rocm_configure.bzl
@@ -332,6 +332,7 @@
             ("MIOpen", rocm_config.rocm_toolkit_path + "/miopen"),
             ("rccl", rocm_config.rocm_toolkit_path + "/rccl"),
             ("hipsparse", rocm_config.rocm_toolkit_path + "/hipsparse"),
+            ("rocsolver", rocm_config.rocm_toolkit_path + "/rocsolver"),
         ]
     ]
 
@@ -457,6 +458,7 @@
             "%{rocfft_lib}": _lib_name("rocfft"),
             "%{hiprand_lib}": _lib_name("hiprand"),
             "%{hipsparse_lib}": _lib_name("hipsparse"),
+            "%{rocsolver_lib}": _lib_name("rocsolver"),
             "%{copy_rules}": "",
             "%{rocm_headers}": "",
         },
@@ -574,6 +576,12 @@
             src_dir = rocm_toolkit_path + "/hipsparse/include",
             out_dir = "rocm/include/hipsparse",
         ),
+        make_copy_dir_rule(
+            repository_ctx,
+            name = "rocsolver-include",
+            src_dir = rocm_toolkit_path + "/rocsolver/include",
+            out_dir = "rocm/include/rocsolver",
+        ),
     ]
 
     rocm_libs = _find_libs(repository_ctx, rocm_config, bash_bin)
@@ -627,13 +635,15 @@
             "%{miopen_lib}": rocm_libs["MIOpen"].file_name,
             "%{rccl_lib}": rocm_libs["rccl"].file_name,
             "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name,
+            "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name,
             "%{copy_rules}": "\n".join(copy_rules),
             "%{rocm_headers}": ('":rocm-include",\n' +
                                 '":rocfft-include",\n' +
                                 '":rocblas-include",\n' +
                                 '":miopen-include",\n' +
                                 '":rccl-include",\n' +
-                                '":hipsparse-include",'),
+                                '":hipsparse-include",' +
+                                '":rocsolver-include"'),
         },
     )
 
diff --git a/third_party/hexagon/workspace.bzl b/third_party/hexagon/workspace.bzl
index a22e2db..4331aba 100644
--- a/third_party/hexagon/workspace.bzl
+++ b/third_party/hexagon/workspace.bzl
@@ -7,9 +7,9 @@
 def repo():
     third_party_http_archive(
         name = "hexagon_nn",
-        sha256 = "2b0e29a061f389ad52054c12fcae38991b5f731d7a05770c7ac421433ed17cc2",
+        sha256 = "b94b653417a7eb871881438bb98cb2f4a652d4d92ff90f1faaa01a8ce82b2e3c",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/storage.cloud.google.com/download.tensorflow.org/tflite/hexagon_nn_headers_v1.20.0.0.tgz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/storage.cloud.google.com/download.tensorflow.org/tflite/hexagon_nn_headers_v1.20.0.1.tgz",
         ],
         build_file = "//third_party/hexagon:BUILD",
     )
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 9d595d9..9413625 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -53,6 +53,22 @@
 ]
 
 gentbl(
+    name = "BuiltinDialectIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-dialect-decls",
+            "include/mlir/IR/BuiltinDialect.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/IR/BuiltinDialect.td",
+    td_srcs = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl(
     name = "BuiltinOpsIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
@@ -64,21 +80,40 @@
             "-gen-op-defs",
             "include/mlir/IR/BuiltinOps.cpp.inc",
         ),
-        (
-            "-gen-dialect-decls",
-            "include/mlir/IR/BuiltinDialect.h.inc",
-        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/IR/BuiltinOps.td",
     td_srcs = [
         "include/mlir/IR/BuiltinOps.td",
+        "include/mlir/IR/BuiltinDialect.td",
         "include/mlir/Interfaces/CallInterfaces.td",
         "include/mlir/IR/SymbolInterfaces.td",
         ":OpBaseTdFiles",
     ],
 )
 
+gentbl(
+    name = "BuiltinTypesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "--gen-typedef-decls",
+            "include/mlir/IR/BuiltinTypes.h.inc",
+        ),
+        (
+            "--gen-typedef-defs",
+            "include/mlir/IR/BuiltinTypes.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/IR/BuiltinTypes.td",
+    td_srcs = [
+        "include/mlir/IR/BuiltinTypes.td",
+        "include/mlir/IR/BuiltinDialect.td",
+        ":OpBaseTdFiles",
+    ],
+)
+
 cc_library(
     name = "IR",
     srcs = glob([
@@ -94,7 +129,9 @@
     ],
     includes = ["include"],
     deps = [
+        ":BuiltinDialectIncGen",
         ":BuiltinOpsIncGen",
+        ":BuiltinTypesIncGen",
         ":CallOpInterfacesIncGen",
         ":InferTypeOpInterfaceIncGen",
         ":OpAsmInterfaceIncGen",
@@ -376,6 +413,360 @@
 )
 
 ##---------------------------------------------------------------------------##
+# ArmNeon dialect.
+##---------------------------------------------------------------------------##
+
+filegroup(
+    name = "ArmNeonTdFiles",
+    srcs = [
+        "include/mlir/Dialect/ArmNeon/ArmNeon.td",
+        "include/mlir/Dialect/LLVMIR/LLVMOpBase.td",
+        "include/mlir/IR/OpBase.td",
+        ":SideEffectTdFiles",
+    ],
+)
+
+gentbl(
+    name = "ArmNeonIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-dialect-decls -dialect arm_neon",
+            "include/mlir/Dialect/ArmNeon/ArmNeonDialect.h.inc",
+        ),
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/ArmNeon/ArmNeon.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/ArmNeon/ArmNeon.cpp.inc",
+        ),
+        (
+            "-gen-op-doc",
+            "g3doc/Dialects/ArmNeon/ArmNeon.md",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/ArmNeon/ArmNeon.td",
+    td_srcs = [
+        ":ArmNeonTdFiles",
+    ],
+)
+
+cc_library(
+    name = "ArmNeon",
+    srcs = [
+        "lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/ArmNeon/ArmNeonDialect.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":ArmNeonIncGen",
+        ":IR",
+        ":SideEffectInterfaces",
+        ":VectorOps",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "ArmNeonToLLVM",
+    srcs = glob([
+        "lib/Conversion/ArmNeonToLLVM/*.cpp",
+    ]) + ["lib/Conversion/PassDetail.h"],
+    hdrs = glob([
+        "include/mlir/Conversion/ArmNeonToLLVM/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ArmNeon",
+        ":ConversionPassIncGen",
+        ":EDSC",
+        ":IR",
+        ":LLVMArmNeon",
+        ":LLVMDialect",
+        ":Pass",
+        ":StandardOps",
+        ":StandardToLLVM",
+        ":Support",
+        ":Transforms",
+        ":VectorOps",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+filegroup(
+    name = "LLVMArmNeonTdFiles",
+    srcs = [
+        "include/mlir/Dialect/LLVMIR/LLVMArmNeon.td",
+        ":LLVMOpsTdFiles",
+    ],
+)
+
+gentbl(
+    name = "LLVMArmNeonIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-dialect-decls -dialect=llvm_arm_neon",
+            "include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h.inc",
+        ),
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/LLVMIR/LLVMArmNeon.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc",
+        ),
+        (
+            "-gen-op-doc",
+            "g3doc/Dialects/LLVMIR/LLVMArmNeon.md",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/LLVMArmNeon.td",
+    td_srcs = [
+        ":LLVMArmNeonTdFiles",
+    ],
+)
+
+cc_library(
+    name = "LLVMArmNeon",
+    srcs = [
+        "lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LLVMArmNeonIncGen",
+        ":LLVMDialect",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+gentbl(
+    name = "LLVMArmNeonConversionIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-llvmir-conversions",
+            "include/mlir/Dialect/LLVMIR/LLVMArmNeonConversions.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/LLVMArmNeon.td",
+    td_srcs = [
+        ":LLVMArmNeonTdFiles",
+    ],
+)
+
+cc_library(
+    name = "TargetLLVMArmNeonIntr",
+    srcs = [
+        "lib/Target/LLVMIR/LLVMArmNeonIntr.cpp",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LLVMArmNeon",
+        ":LLVMArmNeonConversionIncGen",
+        ":LLVMIRModuleTranslation",
+        ":Translation",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+##---------------------------------------------------------------------------##
+# ArmSVE dialect.
+##---------------------------------------------------------------------------##
+
+filegroup(
+    name = "ArmSVETdFiles",
+    srcs = [
+        "include/mlir/Dialect/ArmSVE/ArmSVE.td",
+        "include/mlir/Dialect/LLVMIR/LLVMOpBase.td",
+        "include/mlir/IR/OpBase.td",
+        ":SideEffectTdFiles",
+    ],
+)
+
+gentbl(
+    name = "ArmSVEIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/ArmSVE/ArmSVE.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/ArmSVE/ArmSVE.cpp.inc",
+        ),
+        (
+            "-gen-typedef-decls",
+            "include/mlir/Dialect/ArmSVE/ArmSVETypes.h.inc",
+        ),
+        (
+            "-gen-typedef-defs",
+            "include/mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc",
+        ),
+        (
+            "-gen-dialect-decls -dialect arm_sve",
+            "include/mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/ArmSVE/ArmSVE.td",
+    td_srcs = [
+        ":ArmSVETdFiles",
+    ],
+)
+
+cc_library(
+    name = "ArmSVE",
+    srcs = [
+        "lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/ArmSVE/ArmSVEDialect.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":ArmSVEIncGen",
+        ":IR",
+        ":SideEffectInterfaces",
+        ":VectorOps",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "ArmSVEToLLVM",
+    srcs = glob([
+        "lib/Conversion/ArmSVEToLLVM/*.cpp",
+    ]) + ["lib/Conversion/PassDetail.h"],
+    hdrs = glob([
+        "include/mlir/Conversion/ArmSVEToLLVM/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ArmSVE",
+        ":ConversionPassIncGen",
+        ":EDSC",
+        ":IR",
+        ":LLVMArmSVE",
+        ":LLVMDialect",
+        ":Pass",
+        ":StandardOps",
+        ":StandardToLLVM",
+        ":Support",
+        ":Transforms",
+        ":VectorOps",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+filegroup(
+    name = "LLVMArmSVETdFiles",
+    srcs = [
+        "include/mlir/Dialect/LLVMIR/LLVMArmSVE.td",
+        ":LLVMOpsTdFiles",
+    ],
+)
+
+gentbl(
+    name = "LLVMArmSVEIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-dialect-decls -dialect=llvm_arm_sve",
+            "include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc",
+        ),
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc",
+        ),
+        (
+            "-gen-op-doc",
+            "g3doc/Dialects/LLVMIR/LLVMArmSVE.md",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/LLVMArmSVE.td",
+    td_srcs = [
+        ":LLVMArmSVETdFiles",
+    ],
+)
+
+cc_library(
+    name = "LLVMArmSVE",
+    srcs = [
+        "lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LLVMArmSVEIncGen",
+        ":LLVMDialect",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+gentbl(
+    name = "LLVMArmSVEConversionIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-llvmir-conversions",
+            "include/mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/LLVMArmSVE.td",
+    td_srcs = [
+        ":LLVMArmSVETdFiles",
+    ],
+)
+
+cc_library(
+    name = "TargetLLVMArmSVEIntr",
+    srcs = [
+        "lib/Target/LLVMIR/LLVMArmSVEIntr.cpp",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LLVMArmSVE",
+        ":LLVMArmSVEConversionIncGen",
+        ":LLVMIRModuleTranslation",
+        ":Translation",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+##---------------------------------------------------------------------------##
 # AVX512 dialect.
 ##---------------------------------------------------------------------------##
 
@@ -463,6 +854,97 @@
 )
 
 filegroup(
+    name = "LLVMAVX512TdFiles",
+    srcs = [
+        "include/mlir/Dialect/LLVMIR/LLVMAVX512.td",
+        ":LLVMOpsTdFiles",
+    ],
+)
+
+gentbl(
+    name = "LLVMAVX512IncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-dialect-decls -dialect=llvm_avx512",
+            "include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc",
+        ),
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/LLVMIR/LLVMAVX512.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc",
+        ),
+        (
+            "-gen-op-doc",
+            "g3doc/Dialects/LLVMIR/LLVMAVX512.md",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/LLVMAVX512.td",
+    td_srcs = [
+        ":LLVMAVX512TdFiles",
+    ],
+)
+
+cc_library(
+    name = "LLVMAVX512",
+    srcs = [
+        "lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LLVMAVX512IncGen",
+        ":LLVMDialect",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+gentbl(
+    name = "LLVMAVX512ConversionIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-llvmir-conversions",
+            "include/mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/LLVMAVX512.td",
+    td_srcs = [
+        ":LLVMAVX512TdFiles",
+    ],
+)
+
+cc_library(
+    name = "TargetLLVMAVX512Intr",
+    srcs = [
+        "lib/Target/LLVMIR/LLVMAVX512Intr.cpp",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LLVMAVX512",
+        ":LLVMAVX512ConversionIncGen",
+        ":LLVMIRModuleTranslation",
+        ":Translation",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+##---------------------------------------------------------------------------##
+# SCF dialect.
+##---------------------------------------------------------------------------##
+
+filegroup(
     name = "SCFTdFiles",
     srcs = [
         "include/mlir/Dialect/SCF/SCFOps.td",
@@ -772,6 +1254,7 @@
     deps = [
         ":AVX512ToLLVM",
         ":AffineToStandard",
+        ":ArmNeonToLLVM",
         ":AsyncToLLVM",
         ":ConversionPassIncGen",
         ":GPUToGPURuntimeTransforms",
@@ -1047,6 +1530,7 @@
         ":ShapeToStandardGen",
         ":StandardOps",
         ":Support",
+        ":TensorDialect",
         ":Transforms",
     ],
 )
@@ -1104,6 +1588,7 @@
         ":SideEffectInterfaces",
         ":StandardOpsIncGen",
         ":Support",
+        ":TensorDialect",
         ":VectorInterfaces",
         ":ViewLikeInterface",
         "@llvm-project//llvm:Support",
@@ -1139,6 +1624,7 @@
         ":StandardOps",
         ":StandardOpsTransformsPassIncGen",
         ":Support",
+        ":TensorDialect",
         ":Transforms",
         "@llvm-project//llvm:Support",
     ],
@@ -1229,93 +1715,6 @@
     ],
 )
 
-filegroup(
-    name = "LLVMAVX512TdFiles",
-    srcs = [
-        "include/mlir/Dialect/LLVMIR/LLVMAVX512.td",
-        ":LLVMOpsTdFiles",
-    ],
-)
-
-gentbl(
-    name = "LLVMAVX512IncGen",
-    strip_include_prefix = "include",
-    tbl_outs = [
-        (
-            "-gen-dialect-decls -dialect=llvm_avx512",
-            "include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc",
-        ),
-        (
-            "-gen-op-decls",
-            "include/mlir/Dialect/LLVMIR/LLVMAVX512.h.inc",
-        ),
-        (
-            "-gen-op-defs",
-            "include/mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc",
-        ),
-        (
-            "-gen-op-doc",
-            "g3doc/Dialects/LLVMIR/LLVMAVX512.md",
-        ),
-    ],
-    tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/LLVMIR/LLVMAVX512.td",
-    td_srcs = [
-        ":LLVMAVX512TdFiles",
-    ],
-)
-
-cc_library(
-    name = "LLVMAVX512",
-    srcs = [
-        "lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp",
-    ],
-    hdrs = [
-        "include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h",
-    ],
-    includes = ["include"],
-    deps = [
-        ":IR",
-        ":LLVMAVX512IncGen",
-        ":LLVMDialect",
-        "@llvm-project//llvm:Core",
-        "@llvm-project//llvm:Support",
-    ],
-)
-
-gentbl(
-    name = "LLVMAVX512ConversionIncGen",
-    strip_include_prefix = "include",
-    tbl_outs = [
-        (
-            "-gen-llvmir-conversions",
-            "include/mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc",
-        ),
-    ],
-    tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/LLVMIR/LLVMAVX512.td",
-    td_srcs = [
-        ":LLVMAVX512TdFiles",
-    ],
-)
-
-cc_library(
-    name = "TargetLLVMAVX512Intr",
-    srcs = [
-        "lib/Target/LLVMIR/LLVMAVX512Intr.cpp",
-    ],
-    includes = ["include"],
-    deps = [
-        ":IR",
-        ":LLVMAVX512",
-        ":LLVMAVX512ConversionIncGen",
-        ":LLVMIRModuleTranslation",
-        ":Translation",
-        "@llvm-project//llvm:Core",
-        "@llvm-project//llvm:Support",
-    ],
-)
-
 cc_library(
     name = "LLVMDialect",
     srcs = glob(
@@ -1326,6 +1725,10 @@
         exclude = [
             "lib/Dialect/LLVMIR/IR/*AVX512*.cpp",
             "lib/Dialect/LLVMIR/IR/*AVX512*.h",
+            "lib/Dialect/LLVMIR/IR/*ArmNeon*.cpp",
+            "lib/Dialect/LLVMIR/IR/*ArmNeon*.h",
+            "lib/Dialect/LLVMIR/IR/*ArmSVE*.cpp",
+            "lib/Dialect/LLVMIR/IR/*ArmSVE*.h",
             "lib/Dialect/LLVMIR/IR/NVVM*.cpp",
             "lib/Dialect/LLVMIR/IR/NVVM*.h",
             "lib/Dialect/LLVMIR/IR/ROCDL*.cpp",
@@ -1338,6 +1741,8 @@
         ],
         exclude = [
             "include/mlir/Dialect/LLVMIR/*AVX512*.h",
+            "include/mlir/Dialect/LLVMIR/*ArmNeon*.h",
+            "include/mlir/Dialect/LLVMIR/*ArmSVE*.h",
             "include/mlir/Dialect/LLVMIR/NVVM*.h",
             "include/mlir/Dialect/LLVMIR/ROCDL*.h",
         ],
@@ -1656,8 +2061,8 @@
     deps = [
         ":ConversionPassIncGen",
         ":Pass",
+        ":SPIRVConversion",
         ":SPIRVDialect",
-        ":SPIRVLowering",
         ":Transforms",
         ":VectorOps",
     ],
@@ -1796,8 +2201,8 @@
         ":Pass",
         ":SCFDialect",
         ":SCFToSPIRV",
+        ":SPIRVConversion",
         ":SPIRVDialect",
-        ":SPIRVLowering",
         ":StandardToSPIRVTransforms",
         ":Support",
         ":Transforms",
@@ -1847,6 +2252,7 @@
         ":LLVMDialect",
         ":Pass",
         ":SPIRVDialect",
+        ":SPIRVUtils",
         ":StandardOps",
         ":StandardToLLVM",
         ":Support",
@@ -2174,7 +2580,7 @@
         "include/mlir/Interfaces/ControlFlowInterfaces.td",
         ":SideEffectTdFiles",
         ":OpBaseTdFiles",
-    ] + glob(["include/mlir/Dialect/SPIRV/*.td"]),
+    ] + glob(["include/mlir/Dialect/SPIRV/IR/*.td"]),
 )
 
 gentbl(
@@ -2183,15 +2589,15 @@
     tbl_outs = [
         (
             "-gen-op-decls",
-            "include/mlir/Dialect/SPIRV/SPIRVOps.h.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVOps.h.inc",
         ),
         (
             "-gen-op-defs",
-            "include/mlir/Dialect/SPIRV/SPIRVOps.cpp.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc",
         ),
         (
             "-gen-dialect-decls",
-            "include/mlir/Dialect/SPIRV/SPIRVOpsDialect.h.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.h.inc",
         ),
         (
             "-gen-op-doc",
@@ -2199,27 +2605,27 @@
         ),
         (
             "-gen-enum-decls",
-            "include/mlir/Dialect/SPIRV/SPIRVEnums.h.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h.inc",
         ),
         (
             "-gen-enum-defs",
-            "include/mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVEnums.cpp.inc",
         ),
         (
             "-gen-spirv-enum-avail-decls",
-            "include/mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVEnumAvailability.h.inc",
         ),
         (
             "-gen-spirv-enum-avail-defs",
-            "include/mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVEnumAvailability.cpp.inc",
         ),
         (
             "-gen-spirv-capability-implication",
-            "include/mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVCapabilityImplication.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/SPIRV/SPIRVOps.td",
+    td_file = "include/mlir/Dialect/SPIRV/IR/SPIRVOps.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
     ],
@@ -2227,18 +2633,18 @@
 
 gentbl(
     name = "SPIRVCanonicalizationIncGen",
-    strip_include_prefix = "lib/Dialect/SPIRV",
+    strip_include_prefix = "lib/Dialect/SPIRV/IR",
     tbl_outs = [
         (
             "-gen-rewriters",
-            "lib/Dialect/SPIRV/SPIRVCanonicalization.inc",
+            "lib/Dialect/SPIRV/IR/SPIRVCanonicalization.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "lib/Dialect/SPIRV/SPIRVCanonicalization.td",
+    td_file = "lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
-        "lib/Dialect/SPIRV/SPIRVCanonicalization.td",
+        "lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td",
     ],
 )
 
@@ -2248,19 +2654,19 @@
     tbl_outs = [
         (
             "-gen-avail-interface-decls",
-            "include/mlir/Dialect/SPIRV/SPIRVAvailability.h.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.h.inc",
         ),
         (
             "-gen-avail-interface-defs",
-            "include/mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc",
         ),
         (
             "-gen-spirv-avail-impls",
-            "include/mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/SPIRV/SPIRVOps.td",
+    td_file = "include/mlir/Dialect/SPIRV/IR/SPIRVOps.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
     ],
@@ -2271,15 +2677,15 @@
     tbl_outs = [
         (
             "-gen-struct-attr-decls",
-            "include/mlir/Dialect/SPIRV/TargetAndABI.h.inc",
+            "include/mlir/Dialect/SPIRV/IR/TargetAndABI.h.inc",
         ),
         (
             "-gen-struct-attr-defs",
-            "include/mlir/Dialect/SPIRV/TargetAndABI.cpp.inc",
+            "include/mlir/Dialect/SPIRV/IR/TargetAndABI.cpp.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/SPIRV/TargetAndABI.td",
+    td_file = "include/mlir/Dialect/SPIRV/IR/TargetAndABI.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
         ":StdOpsTdFiles",
@@ -2287,16 +2693,16 @@
 )
 
 gentbl(
-    name = "SPIRVOpUtilsIncGen",
+    name = "SPIRVAttrUtilsGen",
     strip_include_prefix = "include",
     tbl_outs = [
         (
-            "-gen-spirv-op-utils",
-            "include/mlir/Dialect/SPIRV/SPIRVOpUtils.inc",
+            "-gen-spirv-attr-utils",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/SPIRV/SPIRVBase.td",
+    td_file = "include/mlir/Dialect/SPIRV/IR/SPIRVBase.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
         ":SPIRVAvailabilityIncGen",
@@ -2309,11 +2715,11 @@
     tbl_outs = [
         (
             "-gen-spirv-serialization",
-            "include/mlir/Dialect/SPIRV/SPIRVSerialization.inc",
+            "include/mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/SPIRV/SPIRVOps.td",
+    td_file = "include/mlir/Dialect/SPIRV/IR/SPIRVOps.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
     ],
@@ -2321,26 +2727,15 @@
 
 cc_library(
     name = "SPIRVDialect",
-    srcs = glob(
-        [
-            "lib/Dialect/SPIRV/*.cpp",
-            "lib/Dialect/SPIRV/*.h",
-        ],
-        exclude = [
-            "lib/Dialect/SPIRV/SPIRVLowering.cpp",
-        ],
-    ) + [
+    srcs = glob([
+        "lib/Dialect/SPIRV/IR/*.cpp",
+        "lib/Dialect/SPIRV/IR/*.h",
+    ]) + [
         "include/mlir/Transforms/InliningUtils.h",
     ],
-    hdrs = glob(
-        [
-            "include/mlir/Dialect/SPIRV/*.h",
-        ],
-        exclude = [
-            "include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h",
-            "include/mlir/Dialect/SPIRV/SPIRVLowering.h",
-        ],
-    ),
+    hdrs = glob([
+        "include/mlir/Dialect/SPIRV/IR/*.h",
+    ]),
     includes = ["include"],
     deps = [
         ":CommonFolders",
@@ -2348,11 +2743,10 @@
         ":IR",
         ":Parser",
         ":Pass",
+        ":SPIRVAttrUtilsGen",
         ":SPIRVAvailabilityIncGen",
         ":SPIRVCanonicalizationIncGen",
-        ":SPIRVOpUtilsIncGen",
         ":SPIRVOpsIncGen",
-        ":SPIRVPassIncGen",
         ":SPIRVSerializationGen",
         ":SPIRVTargetAndABIStructGen",
         ":SideEffectInterfaces",
@@ -2368,39 +2762,82 @@
     tbl_outs = [
         (
             "-gen-pass-decls -name SPIRV",
-            "include/mlir/Dialect/SPIRV/Passes.h.inc",
+            "include/mlir/Dialect/SPIRV/Transforms/Passes.h.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/SPIRV/Passes.td",
+    td_file = "include/mlir/Dialect/SPIRV/Transforms/Passes.td",
     td_srcs = [
         ":PassBaseTdFiles",
     ],
 )
 
 cc_library(
-    name = "SPIRVLowering",
+    name = "SPIRVUtils",
     srcs = glob([
-        "lib/Dialect/SPIRV/Transforms/*.cpp",
-        "lib/Dialect/SPIRV/Transforms/*.h",
-    ]) + [
-        "lib/Dialect/SPIRV/SPIRVLowering.cpp",
+        "lib/Dialect/SPIRV/Utils/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/SPIRV/Utils/*.h",
+    ]),
+    includes = [
+        "include",
+    ],
+    deps = [
+        ":SPIRVDialect",
+        ":Support",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SPIRVConversion",
+    srcs = [
+        "lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp",
     ],
     hdrs = [
-        "include/mlir/Dialect/SPIRV/Passes.h",
-        "include/mlir/Dialect/SPIRV/SPIRVLowering.h",
-        "include/mlir/Dialect/SPIRV/TargetAndABI.h",
+        "include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h",
     ],
     includes = [
         "include",
     ],
     deps = [
+        ":SPIRVDialect",
+        ":Support",
+        ":TransformUtils",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SPIRVTransforms",
+    srcs = glob(
+        [
+            "lib/Dialect/SPIRV/Transforms/*.cpp",
+            "lib/Dialect/SPIRV/Transforms/*.h",
+        ],
+        exclude = [
+            "lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp",
+        ],
+    ),
+    hdrs = glob(
+        [
+            "include/mlir/Dialect/SPIRV/Transforms/*.h",
+        ],
+        exclude = [
+            "include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h",
+        ],
+    ),
+    includes = [
+        "include",
+    ],
+    deps = [
         ":IR",
         ":Pass",
+        ":SPIRVConversion",
         ":SPIRVDialect",
         ":SPIRVPassIncGen",
-        ":SPIRVTargetAndABIStructGen",
-        ":StandardOps",
+        ":SPIRVUtils",
         ":Support",
         ":Transforms",
         "@llvm-project//llvm:Support",
@@ -2424,8 +2861,9 @@
         ":ConversionPassIncGen",
         ":IR",
         ":Pass",
+        ":SPIRVConversion",
         ":SPIRVDialect",
-        ":SPIRVLowering",
+        ":SPIRVUtils",
         ":StandardOps",
         ":Support",
         ":Transforms",
@@ -2440,24 +2878,38 @@
 )
 
 cc_library(
-    name = "SPIRVSerialization",
-    srcs = glob(
-        [
-            "lib/Dialect/SPIRV/Serialization/*.cpp",
-        ],
-        exclude = [
-            "lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp",
-        ],
-    ),
+    name = "SPIRVBinaryUtils",
+    srcs = [
+        "lib/Target/SPIRV/SPIRVBinaryUtils.cpp",
+    ],
     hdrs = [
-        "include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h",
-        "include/mlir/Dialect/SPIRV/Serialization.h",
+        "include/mlir/Target/SPIRV/SPIRVBinaryUtils.h",
     ],
     includes = ["include"],
     deps = [
         ":IR",
+        ":SPIRVAttrUtilsGen",
         ":SPIRVDialect",
-        ":SPIRVOpUtilsIncGen",
+        ":SPIRVOpsIncGen",
+        ":Support",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SPIRVSerialization",
+    srcs = [
+        "lib/Target/SPIRV/Serialization.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Target/SPIRV/Serialization.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":SPIRVAttrUtilsGen",
+        ":SPIRVBinaryUtils",
+        ":SPIRVDialect",
         ":SPIRVOpsIncGen",
         ":SPIRVSerializationGen",
         ":Support",
@@ -2467,14 +2919,36 @@
 )
 
 cc_library(
-    name = "SPIRVLinking",
+    name = "SPIRVDeserialization",
+    srcs = [
+        "lib/Target/SPIRV/Deserialization.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Target/SPIRV/Deserialization.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":SPIRVAttrUtilsGen",
+        ":SPIRVBinaryUtils",
+        ":SPIRVDialect",
+        ":SPIRVOpsIncGen",
+        ":SPIRVSerializationGen",
+        ":Support",
+        ":Transforms",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SPIRVModuleCombiner",
     srcs = glob(
         [
             "lib/Dialect/SPIRV/Linking/ModuleCombiner/*.cpp",
         ],
     ),
     hdrs = [
-        "include/mlir/Dialect/SPIRV/ModuleCombiner.h",
+        "include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h",
     ],
     includes = ["include"],
     deps = [
@@ -2488,12 +2962,13 @@
 cc_library(
     name = "SPIRVTranslateRegistration",
     srcs = [
-        "lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp",
+        "lib/Target/SPIRV/TranslateRegistration.cpp",
     ],
     includes = ["include"],
     deps = [
         ":IR",
         ":Parser",
+        ":SPIRVDeserialization",
         ":SPIRVDialect",
         ":SPIRVSerialization",
         ":Support",
@@ -2502,6 +2977,120 @@
     ],
 )
 
+filegroup(
+    name = "TensorOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Tensor/IR/TensorBase.td",
+        "include/mlir/Dialect/Tensor/IR/TensorOps.td",
+        ":OpBaseTdFiles",
+        ":SideEffectTdFiles",
+    ],
+)
+
+gentbl(
+    name = "TensorBaseIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-dialect-decls -dialect=tensor",
+            "include/mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Tensor/IR/TensorBase.td",
+    td_srcs = [
+        ":TensorOpsTdFiles",
+    ],
+)
+
+gentbl(
+    name = "TensorOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/Tensor/IR/TensorOps.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/Tensor/IR/TensorOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Tensor/IR/TensorOps.td",
+    td_srcs = [
+        ":TensorOpsTdFiles",
+    ],
+)
+
+cc_library(
+    name = "TensorDialect",
+    srcs = glob(
+        [
+            "lib/Dialect/Tensor/IR/*.cpp",
+            "lib/Dialect/Tensor/IR/*.h",
+        ],
+    ) + [
+        "include/mlir/Transforms/InliningUtils.h",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Tensor/IR/Tensor.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":TensorBaseIncGen",
+        ":TensorOpsIncGen",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+gentbl(
+    name = "TensorPassIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-pass-decls -name Tensor",
+            "include/mlir/Dialect/Tensor/Transforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Tensor/Transforms/Passes.td",
+    td_srcs = [
+        ":PassBaseTdFiles",
+    ],
+)
+
+cc_library(
+    name = "TensorTransforms",
+    srcs = glob(
+        [
+            "lib/Dialect/Tensor/Transforms/*.cpp",
+            "lib/Dialect/Tensor/Transforms/*.h",
+        ],
+    ),
+    hdrs = [
+        "include/mlir/Dialect/Tensor/Transforms/Passes.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":Async",
+        ":EDSC",
+        ":IR",
+        ":ParallelLoopMapperAttrGen",
+        ":Pass",
+        ":SCFDialect",
+        ":StandardOps",
+        ":Support",
+        ":TensorDialect",
+        ":TensorPassIncGen",
+        ":Transforms",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "Rewrite",
     srcs = glob([
@@ -2789,8 +3378,8 @@
         ":IR",
         ":Pass",
         ":SCFDialect",
+        ":SPIRVConversion",
         ":SPIRVDialect",
-        ":SPIRVLowering",
         ":StandardOps",
         ":Support",
         ":TransformUtils",
@@ -3136,6 +3725,8 @@
         ":OpenMPDialect",
         ":Support",
         ":TargetLLVMAVX512Intr",
+        ":TargetLLVMArmNeonIntr",
+        ":TargetLLVMArmSVEIntr",
         ":Translation",
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:IRReader",
@@ -3334,6 +3925,10 @@
         ":AffinePassIncGen",
         ":AffineToStandard",
         ":AffineTransforms",
+        ":ArmNeon",
+        ":ArmNeonToLLVM",
+        ":ArmSVE",
+        ":ArmSVEToLLVM",
         ":Async",
         ":AsyncPassIncGen",
         ":AsyncToLLVM",
@@ -3349,6 +3944,8 @@
         ":GPUTransforms",
         ":IR",
         ":LLVMAVX512",
+        ":LLVMArmNeon",
+        ":LLVMArmSVE",
         ":LLVMDialect",
         ":LLVMIRTransforms",
         ":LLVMPassIncGen",
@@ -3375,9 +3972,9 @@
         ":SCFTransforms",
         ":SDBM",
         ":SPIRVDialect",
-        ":SPIRVLowering",
         ":SPIRVPassIncGen",
         ":SPIRVToLLVM",
+        ":SPIRVTransforms",
         ":Shape",
         ":ShapeToStandard",
         ":ShapeTransforms",
@@ -3387,6 +3984,8 @@
         ":StandardOpsTransformsPassIncGen",
         ":StandardToLLVM",
         ":StandardToSPIRVTransforms",
+        ":TensorDialect",
+        ":TensorTransforms",
         ":TosaDialect",
         ":Transforms",
         ":TransformsPassIncGen",
@@ -3470,15 +4069,18 @@
 )
 
 cc_library(
-    name = "mlir_async_runtime",
-    srcs = [
-        "lib/ExecutionEngine/AsyncRuntime.cpp",
-    ],
-    hdrs = [
-        "include/mlir/ExecutionEngine/AsyncRuntime.h",
-    ],
+    name = "mlir_async_runtime_api",
+    hdrs = ["include/mlir/ExecutionEngine/AsyncRuntime.h"],
     includes = ["include"],
-    deps = ["@llvm-project//llvm:Support"],
+)
+
+cc_library(
+    name = "mlir_async_runtime",
+    srcs = ["lib/ExecutionEngine/AsyncRuntime.cpp"],
+    deps = [
+        ":mlir_async_runtime_api",
+        "@llvm-project//llvm:Support",
+    ],
 )
 
 cc_library(
@@ -3603,6 +4205,7 @@
         ":MlirJitRunner",
         ":Pass",
         ":SPIRVDialect",
+        ":SPIRVTransforms",
         ":StandardToLLVM",
         ":StandardToSPIRVTransforms",
         "@llvm-project//llvm:Support",
@@ -4138,8 +4741,8 @@
         ":LinalgOps",
         ":LinalgTransforms",
         ":Pass",
+        ":SPIRVConversion",
         ":SPIRVDialect",
-        ":SPIRVLowering",
         ":StandardOps",
     ],
 )
@@ -4232,6 +4835,7 @@
         ":StandardOpsTransforms",
         ":StandardToLLVM",
         ":Support",
+        ":TensorDialect",
         ":TransformUtils",
         ":Transforms",
         ":TransformsPassIncGen",
@@ -4294,11 +4898,17 @@
     deps = [
         ":AVX512",
         ":AVX512ToLLVM",
+        ":ArmNeon",
+        ":ArmNeonToLLVM",
+        ":ArmSVE",
+        ":ArmSVEToLLVM",
         ":ConversionPassIncGen",
         ":DialectUtils",
         ":EDSC",
         ":IR",
         ":LLVMAVX512",
+        ":LLVMArmNeon",
+        ":LLVMArmSVE",
         ":LLVMDialect",
         ":LLVMIRModuleTranslation",
         ":Pass",
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index ab0da68..6ba0fbe 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -308,9 +308,9 @@
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SPIRVConversion",
         "@llvm-project//mlir:SPIRVDialect",
-        "@llvm-project//mlir:SPIRVLinking",
-        "@llvm-project//mlir:SPIRVLowering",
+        "@llvm-project//mlir:SPIRVModuleCombiner",
     ],
 )