diff --git a/.bazelrc b/.bazelrc
index d06e0836184..a1f323c142d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -365,13 +365,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
 build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
 
 build:rbe_linux_cuda_nvcc --config=rbe_linux
-build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
-build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
-build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
-build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
-build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
-build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
-build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
+build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
+build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
+build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
 build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
diff --git a/README.md b/README.md
index 56baa0740c3..e95fea22c56 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@ $ python
 3
 >>> hello = tf.constant('Hello, TensorFlow!')
 >>> hello.numpy()
-'Hello, TensorFlow!'
+b'Hello, TensorFlow!'
 ```
 
 For more examples, see the
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 67a2dde6c27..29dba253fee 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -418,21 +418,11 @@ void TensorHandleSilentCopy(bool async,
 
     auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
         matmul->operation.get());
-    if (!async) {
-      // The input handles should never change since they have been mirrored.
-      ASSERT_EQ(op->GetInput(0), arg0);
-      ASSERT_EQ(op->GetInput(1), arg1);
-    } else {
-      if (cpu_op) {
-        ASSERT_EQ(op->GetInput(0), arg0);
-        // The GPU handle should be replaced with a CPU copy
-        ASSERT_NE(op->GetInput(1), arg1);
-      } else {
-        // The CPU handle should be replaced with a GPU copy
-        ASSERT_NE(op->GetInput(0), arg0);
-        ASSERT_EQ(op->GetInput(1), arg1);
-      }
-    }
+
+    // The input handles should never change since they have been mirrored.
+    EXPECT_EQ(op->GetInput(0), arg0);
+    EXPECT_EQ(op->GetInput(1), arg1);
+
     TFE_DeleteOp(matmul);
     TFE_DeleteTensorHandle(retvals[0]);
     TFE_DeleteTensorHandle(hgpu);
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index c283328403b..acbd2d27a45 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -14,6 +14,10 @@ package_group(
     includes = [
         "//tensorflow/compiler/tf2xla:internal",
     ],
+    packages = [
+        "//tensorflow/compiler/tests/...",
+        "//tensorflow/python/...",
+    ],
 )
 
 package_group(
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 6023507cac7..6753ab9e728 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -46,7 +46,7 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Diagnostics.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 084bd26d28e..ac20ab68eaa 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -42,7 +42,7 @@ limitations under the License.
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
@@ -165,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
 // `isSigned` is set to false for other types.
 static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
                                                   bool is_signed = true) {
-  if (!is_signed && type.isInteger(8)) {
+  if (!is_signed && type.isSignlessInteger(8)) {
     return tflite::TensorType_UINT8;
   }
   if (!is_signed) {
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index e73f6b732eb..83e372e5732 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -26,7 +26,7 @@ limitations under the License.
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Matchers.h"  // TF:llvm-project
@@ -275,7 +275,7 @@ Attribute ConstFoldBinaryOp(
     return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                         float_calculate, is_commutative);
 
-  if (elemType.isa<IntegerType>())
+  if (elemType.isSignlessInteger())
     return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
                                           int_calculate, is_commutative);
 
@@ -1560,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
            limit_tensor.getType().getRank() == 0 &&
            delta_tensor.getType().getRank() == 0);
     Type elem_type = getType().cast<ShapedType>().getElementType();
-    if (elem_type.isa<IntegerType>()) {
+    if (elem_type.isSignlessInteger()) {
       auto start_attr = start_tensor.getValue<IntegerAttr>({});
       auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
       auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
@@ -1662,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 
   // Do not try to fold elements attr of a quant type because
   // DenseElementsAttr does not support it.
-  if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
+  if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
     return nullptr;
 
   assert(perm_tensor.getType().getRank() == 1);
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 36a1e93dc26..c4d5c6ce98f 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3100,7 +3100,9 @@ def LstmMandatoryInputsConstraint : PredOpTrait<
   "mandatory operands element types should match",
   // TODO(ashwinm): Replace the indices with input tensor names when that
   // support is available.
-  TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>>;
+  Or<[
+    TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>,
+    Neg<TypeIsPred<"input", F32>>]>>;
 
 def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
   "the optional peephole weights should all be specified or none",
@@ -3126,7 +3128,7 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait<
 
 def LstmResultConstraint : PredOpTrait<
   "the input and result tensor elemental types must be same",
-  TCresVTEtIsSameAsOp<0, 0>>;
+  TFL_TCresVTEtIsSameAsOp<0, 0>>;
 
 // This is the basic kernel type LSTM op.
 // TODO(b/142417845): Refactor this part to return its tflite node name as
@@ -3195,47 +3197,47 @@ Ba et al. “Layer Normalization”
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32]>:$input,
+    ins TFL_TensorOf<[F32, QI8]>:$input,
 
     // Weights
-    TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
-    TFL_TensorOf<[F32, I8]>:$input_to_forget_weights,
-    TFL_TensorOf<[F32, I8]>:$input_to_cell_weights,
-    TFL_TensorOf<[F32, I8]>:$input_to_output_weights,
+    TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights,
+    TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights,
+    TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights,
+    TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights,
 
     // Recurrent weights
-    TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
-    TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
-    TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
-    TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights,
+    TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights,
+    TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights,
+    TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights,
+    TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights,
 
     // Cell weights
-    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
+    TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights,
     // Optional input
-    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
+    TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights,
     // Optional input
-    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,
+    TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights,
 
     // Bias
-    TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
-    TFL_TensorOf<[F32]>:$forget_gate_bias,
-    TFL_TensorOf<[F32]>:$cell_bias,
-    TFL_TensorOf<[F32]>:$output_gate_bias,
+    TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
+    TFL_TensorOf<[F32, QI32]>:$forget_gate_bias,
+    TFL_TensorOf<[F32, QI32]>:$cell_bias,
+    TFL_TensorOf<[F32, QI32]>:$output_gate_bias,
 
     // Projection weight and bias
-    TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
+    TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights,
     // Optional input
-    TFL_TensorOfOrNone<[F32]>:$projection_bias,
+    TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias,
 
     // Stateful activation and cell states.
     TFL_StatefulTensor:$input_activation_state,
     TFL_StatefulTensor:$input_cell_state,
 
     // Layer norm coefficients
-    TFL_TensorOfOrNone<[F32]>:$input_layer_norm_coefficients,
-    TFL_TensorOfOrNone<[F32]>:$forget_layer_norm_coefficients,
-    TFL_TensorOfOrNone<[F32]>:$cell_layer_norm_coefficients,
-    TFL_TensorOfOrNone<[F32]>:$output_layer_norm_coefficients,
+    TFL_TensorOfOrNone<[F32, QI16]>:$input_layer_norm_coefficients,
+    TFL_TensorOfOrNone<[F32, QI16]>:$forget_layer_norm_coefficients,
+    TFL_TensorOfOrNone<[F32, QI16]>:$cell_layer_norm_coefficients,
+    TFL_TensorOfOrNone<[F32, QI16]>:$output_layer_norm_coefficients,
 
     // Attributes
     TFL_AFAttr:$fused_activation_function,
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 45e87e63475..617f968b958 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/AffineExpr.h"  // TF:llvm-project
 #include "mlir/IR/AffineMap.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index b2355b2ae6e..5f52c892421 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -26,7 +26,7 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index ed998510328..9bb1d677df2 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -26,7 +26,7 @@ limitations under the License.
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
@@ -191,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
         auto ele_type = operand.getType().cast<TensorType>().getElementType();
         if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
           inputs.push_back(op_inst.input());
-        } else if (ele_type.isa<IntegerType>()) {
+        } else if (ele_type.isSignlessInteger()) {
           // If the operand is an integer tensor, then it doesn't require the
           // DQ op in the pattern.
           inputs.push_back(operand);
@@ -225,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
           auto user = llvm::cast<Q>(*result.user_begin());
           outputs_replaced.insert({user.output(), enumerated_result.index()});
           output_types.push_back(user.getType());
-        } else if (result_ele_type.template isa<IntegerType>()) {
+        } else if (result_ele_type.isSignlessInteger()) {
           // If the result is an integer tensor, then it doesn't require the
           // D op in the pattern.
           outputs_replaced.insert({result, enumerated_result.index()});
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc
index 7c2846231c9..0c746d0c943 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc
@@ -26,7 +26,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
index a6d6ec52234..5fe5fbfb3ee 100644
--- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
@@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
 
+func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+  %cst = constant dense<[2, 2]> : tensor<2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+  %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+  return %2 : tensor<1x128x128x8xf32>
+
+  // CHECK-LABEL: testDilatedConvWithNonTrivialDilations
+  // CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
+  // CHECK-NEXT: return [[RESULT]]
+}
+
 func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
@@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:
 
 func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
 
   // CHECK-LABEL: testDilatedConvWithExpandSqueeze1
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
 
 func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
 
 func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
 
   // CHECK-LABEL: testDilatedConvWithExpandSqueeze2
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
 
 func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
   %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
 
 func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
 
   // CHECK-LABEL: testDilatedConvWithExpandSqueeze3
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
 
 func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
   // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
+
+func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
+  %cst = constant dense<[2, 2]> : tensor<2xi32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
+  %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
+  return %4 : tensor<1x128x128x1xf32>
+
+  // CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
+  // CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
+  // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
+  // CHECK-NEXT: return [[RESULT]]
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index da58b3704d0..57e2340dd37 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -593,6 +593,21 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
   return %0 : tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: testLstmQuantizedType
+func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
+    %cst = constant unit
+    %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
+    }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+    return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+  // CHECK: %[[RES0:.*]] = constant unit
+  // CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
+  // CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+  // CHECK: return %[[RES1]]
+}
+
+
 // -----
 
 // CHECK-LABEL: testLstm
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
index 448a4f9eb5f..0dcce6bf4e8 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
@@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
 // -----
 
 module {
-func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
   %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
   %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
@@ -165,7 +165,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
   return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 }
 
-// CHECK:       func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+// CHECK:       func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
 // CHECK:           [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@@ -189,7 +189,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
 // -----
 
 module {
-func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
   %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
   %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
@@ -200,7 +200,7 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
   return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 }
 
-// CHECK:       func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+// CHECK:       func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
 // CHECK:           [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
 // CHECK:           [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@@ -224,3 +224,84 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
 // CHECK:           return [[VAL_24]] : tensor<8x8x10xf32>
 // CHECK:         }
 }
+
+// -----
+
+module {
+func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
+  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
+  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
+  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
+  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
+  %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
+  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+}
+
+// CHECK:       func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+// CHECK:           [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32>
+// CHECK:           [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
+// CHECK:           [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
+// CHECK:           [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK:           [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
+// CHECK:           [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
+// CHECK:           [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
+// CHECK:           [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
+// CHECK:           [[VAL_21:%.*]] = constant unit
+// CHECK:           [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
+// CHECK:           }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
+// CHECK:           return [[VAL_23:%.*]] : tensor<8x?x10xf32>
+// CHECK:         }
+
+}
+
+// -----
+
+module {
+func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
+  %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
+  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
+  %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
+  %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
+  %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
+  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+}
+
+// CHECK:       func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+// CHECK:           [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
+// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
+// CHECK:           [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32>
+// CHECK:           [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
+// CHECK:           [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK:           [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
+// CHECK:           [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK:           [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
+// CHECK:           [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
+// CHECK:           [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
+// CHECK:           [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
+// CHECK:           [[VAL_23:%.*]] = constant unit
+// CHECK:           [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
+// CHECK:           }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
+// CHECK:           [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
+// CHECK:           [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
+// CHECK:           return [[VAL_26]] : tensor<8x8x10xf32>
+// CHECK:         }
+
+}
+
diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
index 0472bd6abcf..30fe391762f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
index c3d3df14e0b..65bed845bae 100644
--- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
+++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 
 namespace mlir {
@@ -80,6 +81,17 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 template <typename Conv2dOpTy>
 PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
     Conv2dOpTy op, PatternRewriter& rewriter) const {
+  // Make sure Conv2D has 'VALID' padding.
+  if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
+    return Pattern::matchFailure();
+  }
+  // Make sure dilations are all ones if set.
+  const ArrayAttr& dilations =
+      op.template getAttrOfType<ArrayAttr>("dilations");
+  if (dilations && !TFIntListIsAllOnes(dilations)) {
+    return Pattern::matchFailure();
+  }
+
   // Check if the ConvOp is preceded by a `Expand` op and succeeded by a
   // `Squeeze` op.
   Operation* prev_op = op.getOperation()->getPrevNode();
@@ -90,6 +102,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
 
   TF::ExpandDimsOp expand_op;
   TF::SqueezeOp squeeze_op;
+  int64_t expand_axis;
   // Expand + Squeeze op.
   if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
     if (!llvm::isa<TF::SqueezeOp>(next_op)) {
@@ -99,6 +112,22 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
     expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
     squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
 
+    // Make sure that the axis in `expand_op` is constant.
+    if (auto const_op =
+            llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
+      expand_axis =
+          (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
+              .getSExtValue();
+    } else {
+      return Pattern::matchFailure();
+    }
+    // Make sure that `squeeze_dims` contains exactly one element, equal to `expand_axis`.
+    auto squeeze_dims = squeeze_op.squeeze_dims();
+    if (squeeze_dims.size() != 1 ||
+        squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
+      return Pattern::matchFailure();
+    }
+
     // Update previous/next op pointer.
     prev_op = prev_op->getPrevNode();
     if (!prev_op) return Pattern::matchFailure();
@@ -108,10 +137,14 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
 
   // SpaceToBatchND op.
   if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
+  // TODO(b/149936532): Check the `paddings` input, currently ignored.
   TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
 
   // Pad op.
   TF::PadOp pad_op;
+  // TODO(b/149936532): Currently we just ignore the PadOp. However, note that
+  // in real scenarios this may not always be correct: a user can put a PadOp
+  // here with non-trivial consequences.
   if (llvm::isa<TF::PadOp>(next_op)) {
     pad_op = llvm::cast<TF::PadOp>(next_op);
     next_op = next_op->getNextNode();
@@ -119,6 +152,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
   }
 
   // BatchToSpaceND + BiasAdd.
+  // TODO(b/149936532): Check the `crops` input, currently ignored.
   TF::BatchToSpaceNDOp bts_op;
   TF::BiasAddOp biasadd_op;
   bool final_op_is_bts = true;
@@ -146,14 +180,10 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
   if (!dilations_attr.hasValue()) return Pattern::matchFailure();
   op.setAttr("dilations", dilations_attr.getValue());
 
-  // Here we need to set the correct padding for Conv op. In TF, the conv op
-  // inserted after 'SpaceToBatch' always has 'VALID' padding. This might
-  // become a problem here if the original Conv op has 'SAME' padding. When
-  // the original conv has 'SAME' padding, TF will set a non-zero padding for
-  // the 'SpaceToBatch' op, so we rely on this information to check if we need
-  // to change the padding from 'VALID' to 'SAME' (a.k.a when we see non-zero
-  // values in `stb_op.paddings`, we change the current Conv's padding to
-  // 'SAME').
+  // Padding is set to 'SAME' when `stb_op` has non-zero paddings.
+  // TODO(b/149936532): This assumption only holds when the input width and
+  // height are multiples of the dilation width and height. We should fix it to
+  // support other use cases.
   auto stb_paddings = stb_op.paddings();
   ElementsAttr stb_paddings_attr;
   if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
@@ -175,7 +205,8 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
     auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
     SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
                                          input_shape.end());
-    expand_shape.push_back(1);
+    expand_shape.insert(expand_shape.begin() + expand_axis, 1);
+
     auto expand_result_type = RankedTensorType::get(
         expand_shape, getElementTypeOrSelf(stb_op.input()));
     expand_op.getResult().setType(expand_result_type);
@@ -208,7 +239,7 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
   ElementsAttr stb_bs_attr, bts_bs_attr;
   if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
       !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
-    // Returns failure status if block shape is not a constant.
+    // Returns failure status if block_shape is not a constant.
     return {};
   }
   // Check that the block_shape of `stb_op` and `bts_op` are equal.
@@ -217,9 +248,8 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
     if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
   }
 
-  // TODO(haoliang): support 1-D dilated conv.
+  // Set dilation factor.
   if (stb_bs_attr.getNumElements() < 2) return {};
-
   int dilation_h_factor =
       stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
   int dilation_w_factor =
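As a hedged illustration (not part of the patch), the op sequence this pattern matches looks roughly like the following TensorFlow-dialect IR; the function name and shapes are made up for the example. A constant block_shape of [2, 2] would turn this SpaceToBatchND / Conv2D ('VALID') / BatchToSpaceND chain into a single Conv2D with dilations = [1, 2, 2, 1]:

// Illustrative input IR for the dilated-conv rewrite (hypothetical shapes).
func @dilated_conv_example(%input: tensor<1x128x128x3xf32>,
                           %filter: tensor<3x3x3x8xf32>) -> tensor<1x128x128x8xf32> {
  %block = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32>
  %paddings = "tf.Const"() {value = dense<2> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  %crops = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  // 1x128x128x3 is padded to 132x132 and split into a batch of 4 (block_shape [2, 2]).
  %stb = "tf.SpaceToBatchND"(%input, %block, %paddings) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x66x66x3xf32>
  // The conv inserted after SpaceToBatchND always carries 'VALID' padding and
  // all-ones dilations, matching the new checks in matchAndRewrite.
  %conv = "tf.Conv2D"(%stb, %filter) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x66x66x3xf32>, tensor<3x3x3x8xf32>) -> tensor<4x64x64x8xf32>
  %bts = "tf.BatchToSpaceND"(%conv, %block, %crops) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  return %bts : tensor<1x128x128x8xf32>
}

Because `stb_op.paddings` is non-zero in this sketch, the rewritten Conv2D would be given 'SAME' padding, per the heuristic described in the comments above.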
diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc
index e07cea8535e..3582046f13f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc
@@ -22,7 +22,7 @@ limitations under the License.
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc
index e31b143ab43..f3a15b7ebd3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringMap.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index fd09d2e5c24..683905d06c7 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -16,7 +16,7 @@ limitations under the License.
 // TFLite legalization patterns
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
@@ -341,7 +341,7 @@ def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>;
 class I32VectorElementsAttr<int len> : ElementsAttrBase<
   CPred<"$_self.isa<DenseIntElementsAttr>() &&"
       "$_self.cast<DenseIntElementsAttr>().getType()."
-      "getElementType().isInteger(32)">,
+      "getElementType().isSignlessInteger(32)">,
   "32-bit int elements attribute of shape [" # len # "]"> {
 
   let storageType = [{ DenseIntElementsAttr }];
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 77ea344b3a4..cf24ed7e0f4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -255,7 +255,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
 
   ShapedType shape_type = shape.getType().cast<ShapedType>();
   // The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
-  if (!shape_type.getElementType().isInteger(32)) {
+  if (!shape_type.getElementType().isSignlessInteger(32)) {
     auto new_shape = shape_type.getShape();
     IntegerType new_ele_type = rewriter.getIntegerType(32);
     ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc
index 7d1dbbb9fcc..ea44a34eb2b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 // Converts TF While to TFL While with single call in body and cond.
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
index 3349261af02..4fde08bc1cf 100644
--- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
@@ -20,7 +20,7 @@ limitations under the License.
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/Optional.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 1b240e2e674..00159644185 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -32,7 +32,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
@@ -335,8 +335,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
       ConversionPatternRewriter &rewriter) const override {
     Type dtype = op.element_dtype();
     if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
-          dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
-          dtype.isInteger(32) || dtype.isInteger(64))) {
+          dtype.isInteger(1) || dtype.isSignlessInteger(8) ||
+          dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) ||
+          dtype.isSignlessInteger(64))) {
       op.emitError(
           "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
           "integer or 16-bit/32-bit/64-bit float type during TF Lite "
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 112e7f788ce..dbc12a85b67 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -31,7 +31,7 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Matchers.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 71017fe2801..0ad5be055dc 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -16,7 +16,7 @@ limitations under the License.
 // This is the optimization pattern definition file for TensorFlow Lite.
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td
index 283b29ea005..ecceba5316e 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td
@@ -16,7 +16,7 @@ limitations under the License.
 // This is the quantization pattern definition file for TensorFlow Lite.
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 
 // Both Quantize and Dequantize ops have side effects, so we have to define
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
index 7181877085d..98f9c73f791 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
@@ -23,7 +23,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
index 07dd8ab4455..22bcc563f7b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
@@ -16,7 +16,7 @@ limitations under the License.
 // This is the quantization pattern definition file for TensorFlow Lite.
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 
 // Quantize attribute $0 by using quantization parameter from %1.
diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
index 17125bffd85..c8aa67084ce 100644
--- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
@@ -18,7 +18,7 @@ limitations under the License.
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td
index ff024ad0463..b0435b7cf4c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td
@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
index 5a7397ed9c9..13afa1bf9b8 100644
--- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
@@ -20,7 +20,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Identifier.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
index 5589fc04c2d..8ed5b0e0341 100644
--- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
@@ -17,7 +17,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Identifier.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
index a9cc483df76..3d4bbdfa13c 100644
--- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
@@ -38,7 +38,7 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
 
 IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
   if (attr.getType().getNumElements() != 1 ||
-      !attr.getType().getElementType().isa<IntegerType>()) {
+      !attr.getType().getElementType().isSignlessInteger()) {
     return {};
   }
   SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h
index 5a11690d15f..7c0ff910db1 100644
--- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h
@@ -19,7 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 
 namespace mlir {
 namespace TFL {
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 6d8bfab0e6c..400d504b8f1 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -21,7 +21,7 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
@@ -95,6 +95,14 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
   return Transpose(builder, value_to_transpose, perm, type, location);
 }
 
+Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
+              RankedTensorType type, mlir::Location location) {
+  auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
+  // The result type will be the same as the input.
+  return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
+                                          axis_op);
+}
+
 ArrayRef<int64_t> GetRankedTensorShape(Value value) {
   return value.getType().cast<RankedTensorType>().getShape();
 }
@@ -615,6 +623,16 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
     final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
   }
 
+  // Handle go_backwards:
+  // Keras LSTM semantics reverse the input sequence when go_backwards is set.
+  auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards");
+
+  if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
+    // We assume the input is already in {time, batch, size} layout.
+    final_inputs =
+        Reverse(builder, final_inputs, 0, final_input_type, func_op.getLoc());
+  }
+
   int batch = final_input_type.getDimSize(1);
   int time = final_input_type.getDimSize(0);
 
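As a hedged sketch (the function name and shapes are illustrative, not from the patch), the new Reverse helper emits a tf.ReverseV2 over axis 0, flipping the time dimension of a {time, batch, size} input before it reaches the fused LSTM:

// Illustrative IR for reversing a time-major input when go_backwards is set.
func @reverse_time_major_input(%inputs: tensor<8x2x16xf32>) -> tensor<8x2x16xf32> {
  // Axis is a 1-element i32 splat constant, as produced by
  // CreateI32SplatConst(builder, {1}, /*axis=*/0, loc).
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
  %rev = "tf.ReverseV2"(%inputs, %axis) : (tensor<8x2x16xf32>, tensor<1xi32>) -> tensor<8x2x16xf32>
  return %rev : tensor<8x2x16xf32>
}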
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
index b229206a4e4..0593bd150c7 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc
index a12cad15256..4067cfb04b9 100644
--- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include <vector>
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h
index 917ae93f6a8..635922d5cbb 100644
--- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 
 namespace mlir {
 namespace TFL {
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h
index e1ae4392881..fa1304c68e0 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.h
+++ b/tensorflow/compiler/mlir/lite/utils/validators.h
@@ -19,7 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 58e6cb005c1..466f882133f 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -90,7 +90,7 @@ gentbl(
     td_file = "ir/tf_saved_model_ops.td",
     td_srcs = [
         "@llvm-project//mlir:include/mlir/IR/OpBase.td",
-        "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
+        "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
     ],
 )
 
@@ -114,7 +114,7 @@ gentbl(
     td_file = "ir/tf_executor_ops.td",
     td_srcs = [
         "@llvm-project//mlir:include/mlir/IR/OpBase.td",
-        "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
+        "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
     ],
 )
 
@@ -138,7 +138,7 @@ gentbl(
     td_file = "ir/tf_device_ops.td",
     td_srcs = [
         "@llvm-project//mlir:include/mlir/IR/OpBase.td",
-        "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
+        "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
     ],
 )
 
@@ -281,6 +281,7 @@ cc_library(
         "transforms/generated_canonicalize.inc",
         "transforms/generated_optimize.inc",
         "transforms/graph_pruning.cc",
+        "transforms/launch_to_device_attribute.cc",
         "transforms/layout_optimization.cc",
         "transforms/mark_function_visibility.cc",
         "transforms/materialize_mlir_passthrough_op.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index c6144ec21e3..85d87a56f01 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -26,7 +26,7 @@ limitations under the License.
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/Dialect/Traits.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 25ba5cc2cb1..77e098c37e5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3111,6 +3111,70 @@ cublas.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> {
+  let summary = [{
+Copy a tensor setting everything outside a central band in each innermost matrix
+to zero.
+  }];
+
+  let description = [{
+The `band` part is computed as follows:
+Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
+tensor with the same shape where
+
+`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
+
+The indicator function
+
+`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
+                 (num_upper < 0 || (n-m) <= num_upper)`.
+
+For example:
+
+```
+# if 'input' is [[ 0,  1,  2, 3]
+                 [-1,  0,  1, 2]
+                 [-2, -1,  0, 1]
+                 [-3, -2, -1, 0]],
+
+tf.matrix_band_part(input, 1, -1) ==> [[ 0,  1,  2, 3]
+                                       [-1,  0,  1, 2]
+                                       [ 0, -1,  0, 1]
+                                       [ 0,  0, -1, 0]],
+
+tf.matrix_band_part(input, 2, 1) ==> [[ 0,  1,  0, 0]
+                                      [-1,  0,  1, 0]
+                                      [-2, -1,  0, 1]
+                                      [ 0, -2, -1, 0]]
+```
+
+Useful special cases:
+
+```
+ tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
+ tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
+ tf.matrix_band_part(input, 0, 0) ==> Diagonal.
+```
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input,
+    TF_I32OrI64Tensor:$num_lower,
+    TF_I32OrI64Tensor:$num_upper
+  );
+
+  let results = (outs
+    TF_Tensor:$band
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tindex = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
+}
+
 def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> {
   let summary = [{
 Returns a batched diagonal tensor with a given batched diagonal values.
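As a hedged, illustrative example (not part of the patch), the new op can be used with scalar constant band limits; with num_lower = 1 and num_upper = -1 it keeps the first subdiagonal and the whole upper triangle of each innermost matrix, matching the first example in the description above:

// Hypothetical use of the new op with constant band limits.
func @matrix_band_part_example(%input: tensor<4x4xf32>) -> tensor<4x4xf32> {
  // Scalar i64 band limits; -1 means "keep the entire triangle" on that side.
  %num_lower = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
  %num_upper = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
  %band = "tf.MatrixBandPart"(%input, %num_lower, %num_upper) : (tensor<4x4xf32>, tensor<i64>, tensor<i64>) -> tensor<4x4xf32>
  return %band : tensor<4x4xf32>
}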
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index f3fdab674e4..92e6d522125 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -85,7 +85,7 @@ class TF_TensorFlowType <string name, string description> :
 
 // Any tensor element type allowed in TensorFlow ops
 def TF_ElementType : Type<Or<[AnyFloat.predicate,
-                              AnyInteger.predicate,
+                              AnySignlessInteger.predicate,
                               AnyComplex.predicate,
                               TF_TFDialectType.predicate]>,
                           "tf.dtype">;
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 86de5578553..8d4c284bcf8 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -35,7 +35,7 @@ limitations under the License.
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/Dialect/Traits.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
@@ -1506,6 +1506,29 @@ void LogicalNotOp::getCanonicalizationPatterns(
                  LogicalNotOfLess, LogicalNotOfLessEqual>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// MatrixBandPartOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(MatrixBandPartOp op) {
+  if (!HasRankAtLeast(op.input(), 2)) {
+    return op.emitOpError()
+           << "requires `input` to have rank of at least 2, but found "
+           << op.input().getType();
+  }
+  if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
+    return op.emitOpError()
+           << "requires `num_lower` to have 0 dimensions, but found "
+           << op.num_lower().getType();
+  }
+  if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
+    return op.emitOpError()
+           << "requires `num_upper` to have 0 dimensions, but found "
+           << op.num_upper().getType();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MaxOp
 //===----------------------------------------------------------------------===//
@@ -2104,7 +2127,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
   }
 
   Type element_type = result_ranked_type.getElementType();
-  if (!element_type.isInteger(32) && !element_type.isInteger(64))
+  if (!element_type.isSignlessInteger(32) &&
+      !element_type.isSignlessInteger(64))
     return op->emitOpError("requires int32 or int64 return type for result")
            << variadic_idx_str;
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
index 2898338f8eb..4059aba209f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
@@ -91,7 +91,7 @@ class TensorFlowType : public Type {
 // Returns true if the specified type is a valid TensorFlow element type.
 static inline bool IsValidTFElementType(Type type) {
   return type.isa<ComplexType>() || type.isa<FloatType>() ||
-         type.isa<IntegerType>() || type.isa<TensorFlowType>();
+         type.isSignlessInteger() || type.isa<TensorFlowType>();
 }
 
 // Returns true if this is a valid TensorFlow tensor type.
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir
new file mode 100644
index 00000000000..6e57c3aa6ca
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir
@@ -0,0 +1,123 @@
+// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute | FileCheck %s --dump-input=fail
+
+
+// Tests single TensorFlow op is hoisted out and has the correct device assigned
+// by parent `tf_device.launch`.
+// CHECK-LABEL: func @single_op_launch
+func @single_op_launch() {
+  tf_executor.graph {
+    %0:5 = tf_executor.island {
+      %a = "tf.opA"() : () -> tensor<i1>
+      %launch:2 = "tf_device.launch"() ( {
+        %b:2 = "tf.opB"(%a) : (tensor<i1>) -> (tensor<i32>, tensor<f32>)
+        tf_device.return %b#1, %b#0 : tensor<f32>, tensor<i32>
+      }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
+      %c = "tf.opC"() : () -> tensor<i1>
+      tf_executor.yield %a, %launch#0, %launch#1, %c : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:      %[[A:.*]] = "tf.opA"
+// CHECK:      %[[B:.*]]:2 = "tf.opB"(%[[A]])
+// CHECK-SAME: device = "CPU:0"
+// CHECK:      %[[C:.*]] = "tf.opC"
+// CHECK-NOT:  "tf_device.launch"
+// CHECK:      tf_executor.yield %[[A]], %[[B]]#1, %[[B]]#0, %[[C]]
+
+
+// Tests multiple TensorFlow ops are hoisted out and all have the correct device
+// assigned by parent `tf_device.launch`.
+// CHECK-LABEL: func @multi_op_launch
+func @multi_op_launch() {
+  tf_executor.graph {
+    %0:5 = tf_executor.island {
+      %a = "tf.opA"() : () -> tensor<i1>
+      %launch:2 = "tf_device.launch"() ( {
+        %b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32>
+        %c = "tf.opC"(%b) : (tensor<i32>) -> tensor<f32>
+        tf_device.return %c, %b : tensor<f32>, tensor<i32>
+      }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
+      %d = "tf.opD"() : () -> tensor<i1>
+      tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:      %[[A:.*]] = "tf.opA"
+// CHECK:      %[[B:.*]] = "tf.opB"(%[[A]])
+// CHECK-SAME: device = "CPU:0"
+// CHECK:      %[[C:.*]] = "tf.opC"(%[[B]])
+// CHECK-SAME: device = "CPU:0"
+// CHECK:      %[[D:.*]] = "tf.opD"
+// CHECK-NOT:  "tf_device.launch"
+// CHECK:      tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]]
+
+
+// -----
+
+
+// Tests ops are not hoisted and an error is reported if the `tf_device.launch`
+// contains non-TensorFlow dialect ops.
+func @non_tf_dialect_op_launch() {
+  tf_executor.graph {
+    %0:5 = tf_executor.island {
+      %a = "tf.opA"() : () -> tensor<i1>
+      // expected-error@+1 {{'tf_device.launch' op must contain only 'tf' dialect ops}}
+      %launch:2 = "tf_device.launch"() ( {
+        %b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32>
+        %c = "unknown.opC"(%b) : (tensor<i32>) -> tensor<f32>
+        tf_device.return %c, %b : tensor<f32>, tensor<i32>
+      }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
+      %d = "tf.opD"() : () -> tensor<i1>
+      tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+
+// -----
+
+
+// Tests TensorFlow op with conflicting `device` attribute compared to parent
+// `tf_device.launch`.
+func @conflicting_device() {
+  tf_executor.graph {
+    %0 = tf_executor.island {
+      // expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has conflicting 'device' attribute, got 'GPU:0' but expected 'CPU:0'}}
+      "tf_device.launch"() ( {
+        "tf.opA"() {device = "GPU:0"} : () -> ()
+        tf_device.return
+      }) {device = "CPU:0"} : () -> ()
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+
+// -----
+
+
+// Tests TensorFlow op with bad `device` attribute already set.
+func @bad_tf_device_attr() {
+  tf_executor.graph {
+    %0 = tf_executor.island {
+      // expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has bad 'device' attribute}}
+      "tf_device.launch"() ( {
+        "tf.opA"() {device = 0 : i32} : () -> ()
+        tf_device.return
+      }) {device = "CPU:0"} : () -> ()
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 9b29c5c1d92..319660ae4bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -854,6 +854,78 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
 
 // -----
 
+// Test valid tf.MatrixBandPart
+// CHECK-LABEL: func @testValidMatrixBandPartOp
+func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
+  return %0 : tensor<64x64xbf16>
+}
+
+// -----
+
+// Test valid tf.MatrixBandPart
+// CHECK-LABEL: func @testValidMatrixBandPartOp3D
+func @testValidMatrixBandPartOp3D(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64x64xbf16> {
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64x64xbf16>
+  return %0 : tensor<64x64x64xbf16>
+}
+
+// -----
+
+// Test valid tf.MatrixBandPart
+// CHECK-LABEL: func @testValidMatrixBandPartOpUnranked
+func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
+  return %0 : tensor<*xbf16>
+}
+
+// -----
+
+// Test invalid tf.MatrixBandPart
+func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
+  // expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
+  return %0 : tensor<64x64xbf16>
+}
+
+// -----
+
+// Test invalid tf.MatrixBandPart
+func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
+  // expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
+  return %0 : tensor<*xbf16>
+}
+
+// -----
+
+// Test invalid tf.MatrixBandPart
+func @testInvalidMatrixBandPartOp(%arg0: tensor<i64>, %arg1: tensor<64x64xi64>, %arg2: tensor<i64>) -> tensor<i64> {
+  // expected-error @+1 {{op requires `input` to have rank of at least 2, but found 'tensor<i64>'}}
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<i64>, tensor<64x64xi64>, tensor<i64>) -> tensor<i64>
+  return %0 : tensor<i64>
+}
+
+// -----
+
+// Test invalid tf.MatrixBandPart
+func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<32xi64>, %arg2: tensor<i64>) -> tensor<64x64xi64> {
+  // expected-error @+1 {{op requires `num_lower` to have 0 dimensions, but found 'tensor<32xi64>'}}
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<32xi64>, tensor<i64>) -> tensor<64x64xi64>
+  return %0 : tensor<64x64xi64>
+}
+
+// -----
+
+// Test invalid tf.MatrixBandPart
+func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<i64>, %arg2: tensor<32xi64>) -> tensor<64x64xi64> {
+  // expected-error @+1 {{op requires `num_upper` to have 0 dimensions, but found 'tensor<32xi64>'}}
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<i64>, tensor<32xi64>) -> tensor<64x64xi64>
+  return %0 : tensor<64x64xi64>
+}
+
+// -----
+
 //===--------------------------------------------------------------------===//
 //  tf.{|Stateful}PartitionedCall
 //===--------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir
new file mode 100644
index 00000000000..afc95865236
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir
@@ -0,0 +1,220 @@
+// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail
+//===----------------------------------------------------------------------===//
+// Immutability.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: This test exercises marking a global tensor as immutable after it propagates
+  // through a chain of calls: f -> f_callee -> f_callee_callee.
+
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-NOT: is_mutable
+  // CHECK-SAME: } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["f"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
+    return %val : tensor<f32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+
+  // Test case:
+  // This test exercises trying to mark a global tensor immutable when the same func is
+  // called by multiple callers with different global tensors.
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-NOT: is_mutable
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-NOT: is_mutable
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v2", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["f"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  func @f2(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v2}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["f2"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
+    return %val : tensor<f32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: This test exercises marking a global tensor as immutable when it has
+  // no explicit use via ReadVariableOp.
+
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-NOT: is_mutable
+  // CHECK-SAME: } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["f"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    %val_2 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val_2 : tensor<f32>
+  }
+
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %cst_1 = constant dense<2.0> : tensor<f32>
+    return %cst_1 : tensor<f32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test case: Test that mutation detection propagates across function calls
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-SAME: is_mutable
+  // CHECK-SAME: } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  // CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
+  func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["f"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
+    "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
+    return %c0 : tensor<f32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: The inter-procedural analysis handles different types of
+  // TF call ops.
+
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-SAME: is_mutable
+  // CHECK-SAME: } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  // CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
+  func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["f"]} {
+    %val = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
+    "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
+    return %c0 : tensor<f32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: The inter-procedural analysis does not recurse infinitely
+
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-NOT: is_mutable
+  // CHECK-SAME: } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["exported_f"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+
+  // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+  // CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module attributes {tf_saved_model.semantics}
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: Inter-procedural analysis with resource usage in an
+  // unknown op; we assume mutating behavior and propagate that.
+
+  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK-SAME: is_mutable
+  // CHECK-SAME: } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
+
+  func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
+  attributes {tf_saved_model.exported_names = ["exported_f"]} {
+    %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
+    return %val : tensor<f32>
+  }
+
+
+  // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
+  func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+    %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
+    "tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
+    return %c0 : tensor<f32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
index f181924d0a6..0fef58ebb8a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -17,7 +17,7 @@ limitations under the License.
 // `tf_device.launch` with equivalent `tf_device.launch_func` operations.
 
 #include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
index a95a319d0a4..bac7b9ba01c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 // Here, the element type can be any integer or float type. But, note that only
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc
index 9660367cb68..ad844883453 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc
@@ -17,7 +17,7 @@ limitations under the License.
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Visitors.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
index 57ea1822b5b..01901d8b5a4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
@@ -16,7 +16,7 @@ limitations under the License.
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/SymbolTable.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
index 44309a5e019..4d5ad5ad423 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
@@ -31,7 +31,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
index 31898ec1048..8cfa69c396e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
@@ -16,7 +16,7 @@ limitations under the License.
 // This transformation pass transforms functional control flow operations in the
 // standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc
new file mode 100644
index 00000000000..9a196aef54b
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc
@@ -0,0 +1,135 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This pass hoists a `tf_device.launch` body and assigns a `device` attribute
+// to each TensorFlow dialect op in the body based on the `device` attribute on
+// the `tf_device.launch`. If a TensorFlow dialect op already has a `device`
+// attribute, it must match the `tf_device.launch` device; a conflicting or
+// non-string `device` attribute results in an error.
+//
+// For example:
+//   %island:5 = tf_executor.island {
+//     %a = "tf.opA"() : () -> tensor<i1>
+//     %launch:2 = "tf_device.launch"() ( {
+//       %b = "tf.opB"() : () -> tensor<i32>
+//       %c = "tf.opC"() : () -> tensor<f32>
+//       tf_device.return %c, %b : tensor<f32>, tensor<i32>
+//     }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
+//     %d = "tf.opD"() : () -> tensor<i1>
+//     tf_executor.yield %a, %launch#0, %launch#1, %d :
+//                       tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
+//   }
+//
+// Will be transformed into:
+//   %island:5 = tf_executor.island {
+//     %a = "tf.opA"() : () -> tensor<i1>
+//     %b = "tf.opB"() {device = "CPU:0"} : () -> tensor<i32>
+//     %c = "tf.opC"() {device = "CPU:0"} : () -> tensor<f32>
+//     %d = "tf.opD"() : () -> tensor<i1>
+//     tf_executor.yield %a, %c, %b, %d :
+//                       tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
+//   }
+
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Block.h"  // TF:llvm-project
+#include "mlir/IR/Dialect.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/Visitors.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+
+namespace mlir {
+namespace TFDevice {
+namespace {
+constexpr char kDeviceAttr[] = "device";
+
+struct LaunchToDeviceAttributePass
+    : public FunctionPass<LaunchToDeviceAttributePass> {
+  void runOnFunction() override;
+};
+
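+// Assigns the launch's `device` attribute to each inner 'tf' dialect op,
+// verifies that any pre-existing `device` attribute matches, and hoists the
+// inner ops out of the `tf_device.launch` before erasing it.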
+LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect,
+                                            tf_device::LaunchOp launch) {
+  // Forward launch inner op results to launch op results.
+  launch.replaceAllUsesWith(launch.GetBody().getTerminator()->getOperands());
+
+  // For all inner ops of the TensorFlow dialect, assign the launch device as a
+  // `device` attribute.
+  auto body = launch.GetBody().without_terminator();
+  for (Operation& op : body) {
+    if (op.getDialect() != tf_dialect)
+      return launch.emitOpError() << "must contain only 'tf' dialect ops";
+
+    auto device_attr = op.getAttr(kDeviceAttr);
+    if (!device_attr) {
+      op.setAttr(kDeviceAttr, launch.deviceAttr());
+      continue;
+    }
+
+    if (auto device_str_attr = device_attr.dyn_cast<StringAttr>()) {
+      if (launch.device() != device_str_attr.getValue())
+        return launch.emitOpError()
+               << "inner 'tf' dialect op has conflicting 'device' attribute, "
+                  "got '"
+               << device_str_attr.getValue() << "' but expected '"
+               << launch.device() << "'";
+    } else {
+      return launch.emitOpError()
+             << "inner 'tf' dialect op has bad 'device' attribute";
+    }
+  }
+
+  // Move all inner ops of the launch to the block containing the launch.
+  Operation* launch_op = launch.getOperation();
+  launch_op->getBlock()->getOperations().splice(
+      launch_op->getIterator(), launch.GetBody().getOperations(), body.begin(),
+      body.end());
+
+  launch.erase();
+
+  return success();
+}
+
+void LaunchToDeviceAttributePass::runOnFunction() {
+  const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
+  if (!tf_dialect) {
+    getFunction().emitError() << "'tf' dialect is not registered";
+    return signalPassFailure();
+  }
+
+  auto result = getFunction().walk([&](tf_device::LaunchOp launch) {
+    if (failed(HoistOpsAndAnnotateWithDevice(tf_dialect, launch)))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (result.wasInterrupted()) return signalPassFailure();
+}
+
+}  // anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass() {
+  return std::make_unique<LaunchToDeviceAttributePass>();
+}
+
+static PassRegistration<LaunchToDeviceAttributePass> pass(
+    "tf-launch-to-device-attribute",
+    "Hoists and annotates device launch inner ops with associated device "
+    "attribute");
+
+}  // namespace TFDevice
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index ec0ac5e3c1e..1074f9e1926 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 // Here, the element type can be any integer or float type. But, note that only
@@ -122,7 +122,7 @@ def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern<
 //===----------------------------------------------------------------------===//
 
 def ComplexTensor   : TensorOf<[AnyComplex]>;
-def RealTensor   : TensorOf<[AnyInteger, AnyFloat]>;
+def RealTensor   : TensorOf<[AnySignlessInteger, AnyFloat]>;
 
 def : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
 
@@ -179,7 +179,7 @@ def LowerL2LossOp :
 // Pad op patterns.
 //===----------------------------------------------------------------------===//
 
-def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings),
+def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
           (TF_PadV2Op $input, $paddings,
              (TF_ConstOp (GetScalarOfType<0> $input)))>;
 
@@ -224,6 +224,6 @@ def CreateTFShapeOp : NativeCodeCall<
 
 // TODO(hinsu): Support inputs of TensorList types.
 def LowerZerosLikeOp :
-  Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnyInteger, AnyFloat]>:$input),
+  Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
       (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
                         (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
index 36d7712eb2c..acaf7974280 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <iostream>
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
index 87467238e57..0fb62cb064d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
index 2a61b0cf0d1..1aaceb8ecc7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
@@ -15,16 +15,27 @@ limitations under the License.
 
 // This pass optimizes tf_saved_model.global_tensor ops.
 
+#include <cstddef>
 #include <map>
 #include <set>
 
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
+#include "mlir/Analysis/CallInterfaces.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
+#include "mlir/IR/Types.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Support/LLVM.h"  // TF:llvm-project
+#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
 namespace tf_saved_model {
@@ -45,8 +56,142 @@ struct GlobalTensorUse {
 using GlobalTensorUsesMap =
     std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
 
+static bool IsResourceType(Type type) {
+  if (auto tensor_type = type.dyn_cast<TensorType>()) {
+    return tensor_type.getElementType().isa<TF::ResourceType>();
+  }
+  return false;
+}
+
+static bool IsResource(Value value) { return IsResourceType(value.getType()); }
+
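+// Analyzes a module to determine which resource values may be written to,
+// following writes through function calls and control flow ops.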
+class ResourceAnalyzer {
+ public:
+  explicit ResourceAnalyzer(ModuleOp module) {
+    SymbolTable symbol_table(module);
+    for (auto func : module.getOps<FuncOp>()) {
+      AnalyzeFunc(func, symbol_table);
+    }
+  }
+
+  bool IsPotentiallyWritten(Value resource) const {
+    assert(IsResource(resource));
+    auto it = resource_infos_.find(resource);
+    if (it == resource_infos_.end()) {
+      return false;
+    }
+    return it->second.potentially_written;
+  }
+
+ private:
+  // Analyzes the specified func for resource-mutating operations (namely
+  // TF::AssignVariableOp) and marks the associated resource as "potentially
+  // written". This is done recursively across the chain of funcs reachable
+  // via call or control flow ops.
+  // TODO(ashwinm): Move to iterative traversal.
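+  // For example, if a func passes a resource argument into a tf.If and one of
+  // the branches assigns to that argument, the corresponding tf.If operand
+  // (and hence the caller's resource) is marked as potentially written.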
+  LogicalResult AnalyzeFunc(FuncOp func, const SymbolTable& symbol_table) {
+    // Avoid infinite recursion.
+    if (!discovered_.insert(func).second) {
+      return success();
+    }
+
+    func.walk([&](Operation* op) {
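+      // Reads of a resource and function returns do not mutate the resource.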
+      if (isa<TF::ReadVariableOp>(op) || isa<ReturnOp>(op)) {
+        return;
+      }
+      if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
+        SetPotentiallyWritten(assign_variable.resource());
+        return;
+      }
+      if (auto call = dyn_cast<CallOpInterface>(op)) {
+        if (auto sym = op->getAttrOfType<SymbolRefAttr>("f")) {
+          PropagatePotentiallyWrittenUpFromCallee(
+              sym.cast<FlatSymbolRefAttr>().getValue(), call.getArgOperands(),
+              symbol_table);
+        }
+        return;
+      }
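+      // tf.If and tf.While reference their branch/body functions by symbol
+      // name; propagate writes from those callees to the corresponding
+      // operands as well.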
+      if (auto if_op = dyn_cast<TF::IfOp>(op)) {
+        for (auto callee : {if_op.then_branch(), if_op.else_branch()}) {
+          PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input(),
+                                                  symbol_table);
+        }
+        return;
+      }
+      if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
+        for (auto callee : {while_op.cond(), while_op.body()}) {
+          PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input(),
+                                                  symbol_table);
+        }
+        return;
+      }
+      // For all other ops, we conservatively assume they mutate every
+      // resource they use. This could be improved by using a property or a
+      // trait that clearly identifies ops with resource-mutating behavior.
+      if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) {
+        return;
+      }
+    });
+    return success();
+  }
+
+  // If an op is not one of the handled ones, we assume all resource usages
+  // within its purview are mutating in nature.
+  bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
+    for (auto operand : op->getOperands()) {
+      if (IsResource(operand)) {
+        SetPotentiallyWritten(operand);
+        return true;
+      }
+    }
+    bool uses_resources = false;
+    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
+      if (IsResource(operand->get())) {
+        SetPotentiallyWritten(operand->get());
+        uses_resources = true;
+      }
+    });
+    return uses_resources;
+  }
+
+  // Given the func associated with a callee and the operands of the
+  // corresponding call, marks each call operand as potentially written if the
+  // matching callee argument is a potentially written resource.
+  void PropagatePotentiallyWrittenUpFromCallee(
+      StringRef callee, Operation::operand_range propagate_to,
+      const SymbolTable& symbol_table) {
+    auto func = symbol_table.lookup<FuncOp>(callee);
+    AnalyzeFunc(func, symbol_table);
+    for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
+      if (!IsResource(std::get<0>(t))) {
+        continue;
+      }
+      if (IsPotentiallyWritten(std::get<0>(t))) {
+        SetPotentiallyWritten(std::get<1>(t));
+      }
+    }
+  }
+
+  void SetPotentiallyWritten(Value resource) {
+    assert(IsResource(resource));
+    resource_infos_[resource].potentially_written = true;
+  }
+  struct ResourceInfo {
+    bool potentially_written = false;
+  };
+  // Key: a resource Value.
+  // Value: information we know about that Value.
+  // Note that these Values are in general in different functions.
+  DenseMap<Value, ResourceInfo> resource_infos_;
+  // The set of funcs we have already discovered.
+  DenseSet<FuncOp> discovered_;
+};
+
 bool IsImmutable(GlobalTensorOp global_tensor,
-                 ArrayRef<GlobalTensorUse> global_tensor_uses) {
+                 ArrayRef<GlobalTensorUse> global_tensor_uses,
+                 const ResourceAnalyzer& resource_analyzer) {
   // Global tensor is already known to be immutable.
   if (!global_tensor.is_mutable()) {
     return false;
@@ -57,17 +202,11 @@ bool IsImmutable(GlobalTensorOp global_tensor,
     return false;
   }
 
-  // Check the uses to see if this global tensor is only used in a way that
-  // is compatible with being immutable.
-  // Right now, this uses a very simple algorithm that only checks the top-level
-  // func for tf.ReadVariableOp. If the resource is passed into other functions
-  // or control flow, we fail to prove it is freezable even though we could.
+  // A global tensor is immutable if the resource analyzer deems it so.
   for (auto& global_tensor_use : global_tensor_uses) {
     auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
-    for (auto user : arg.getUsers()) {
-      if (!isa<TF::ReadVariableOp>(user)) {
-        return false;
-      }
+    if (resource_analyzer.IsPotentiallyWritten(arg)) {
+      return false;
     }
   }
   return true;
@@ -96,11 +235,12 @@ static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
 // Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
 // can prove it is safe to do so.
 void MarkGlobalTensorsImmutable(
-    ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map) {
+    ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
+    const ResourceAnalyzer& resource_analyzer) {
   for (const auto& kv : global_tensor_uses_map) {
     auto global_tensor = kv.first;
     const auto& global_tensor_uses = kv.second;
-    if (IsImmutable(global_tensor, global_tensor_uses)) {
+    if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
       global_tensor.removeAttr("is_mutable");
     }
   }
@@ -137,17 +277,15 @@ void EraseUnusedBoundInputs(ModuleOp module) {
 }
 
 void OptimizeGlobalTensorsPass::runOnModule() {
-  // This analysis could be much more elaborate, including tracking global
-  // tensors interprocedurally and uses in a wide variety of ops. But I don't
-  // know if we need that complexity.
   auto module = getModule();
-
   EraseUnusedBoundInputs(module);
 
-  // Figure out which func's use each tf_saved_model.global_tensor.
+  ResourceAnalyzer resource_analyzer(module);
+
   GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
 
-  MarkGlobalTensorsImmutable(module, global_tensor_uses);
+  MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
+
   EraseUnusedGlobalTensors(module, global_tensor_uses);
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 16d8ddfb900..17af8c3cfbb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -186,6 +186,10 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass();
 // same data across replicas.
 std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass();
 
+// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
+// attribute to each TensorFlow dialect op in the body based on the `device`
+// attribute on the `tf_device.launch`.
+std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass();
 }  // namespace TFDevice
 
 namespace TFTPU {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
index 13dda3ed0d0..d3cc508a490 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
@@ -27,7 +27,7 @@ limitations under the License.
 
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 8dc21feca90..dee5e8b079f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -22,7 +22,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index c44f0f97fd6..b3474e2faf1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Diagnostics.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 384b66bc737..eb57e8ff742 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -26,7 +26,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
index 8136db7d164..32cb2e02930 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
@@ -19,7 +19,7 @@ limitations under the License.
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc
index 672ba418489..b89b3d8e6b2 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc
@@ -22,7 +22,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/Support/Debug.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
index 96a7fcbb5ba..7755f5f2259 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
@@ -21,7 +21,7 @@ limitations under the License.
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index 529c2517508..0ae02ed63b6 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -28,7 +28,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 3a624490da1..1f0f8e2b9de 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -42,7 +42,7 @@ limitations under the License.
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Analysis/Verifier.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h
index 79a302b066b..4a67b7fae76 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/lib/core/status.h"
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
index a97bca9fc3d..2ee3893eac9 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "llvm/Support/Debug.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index cce5fde2883..ef1e52ee5c0 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/IR/OpDefinition.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index a64b7ecfdb3..f8c118ac9d9 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -23,7 +23,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Identifier.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index bc9bdf49a39..f00e880f36b 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -19,7 +19,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
index f8eabeb046d..82304f95e33 100644
--- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
 #include "mlir/IR/OperationSupport.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 41ef8690735..b011b6069c7 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -1130,7 +1130,7 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
   // Illegal attributes.
   ShapedType attr_ty = start_indices.getType();
   if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
-      !attr_ty.getElementType().isInteger(64) ||
+      !attr_ty.getElementType().isSignlessInteger(64) ||
       limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
     return ty;
 
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 269e1cc8897..42b42d99380 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -52,7 +52,7 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>;
 
 def HLO_PredTensor : TensorOf<[HLO_Pred]>;
 
-def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger, AnyComplex]>;
+def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
 
 def HLO_ComplexTensor : TensorOf<[AnyComplex]>;
 
@@ -64,13 +64,13 @@ def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
 // an index type (as it stores indices) but that is currently disallowed in
 // MLIR.
 def HLO_DimensionTensor : ShapedContainerType<
-    [AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
+    [AnySignlessInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
     "a 1D tensor of dimensions">;
 
 // In general, static shaped tensor constraints should be avoided unless
 // it is for a legacy op which is only correct with static shapes.
 def HLO_StaticShapeTensor : StaticShapeTensorOf<[
-      AnyFloat, AnyInteger, AnyComplex]>;
+      AnyFloat, AnySignlessInteger, AnyComplex]>;
 
 //===----------------------------------------------------------------------===//
 // XLA combined type definitions.
@@ -784,7 +784,7 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor",
     compute shape arguments to dynamic operations.
   }];
 
-  let arguments = (ins Variadic<AnyInteger>);
+  let arguments = (ins Variadic<AnySignlessInteger>);
   let results = (outs HLO_DimensionTensor);
 
   // Cannot be exported to legacy formats.
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
index 120b035e5d0..3e3570f5b54 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
@@ -40,7 +40,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
 
   // Handle integer elements.
   Attribute elementAttr;
-  if (valElementType.isa<IntegerType>())
+  if (valElementType.isSignlessInteger())
     elementAttr = b->getIntegerAttr(valElementType, constant);
   else if (valElementType.isa<FloatType>())
     elementAttr = b->getFloatAttr(valElementType, constant);
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index 794fee181a6..3a675f20d92 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -42,7 +42,7 @@ def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
 // Any integer or floating-point tensor types
 def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
 
-def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger]>;
+def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger]>;
 
 def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;
 
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 8fa7d809024..92614755ec3 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -27,7 +27,7 @@ limitations under the License.
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
@@ -1221,7 +1221,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
                  << "requires arg " << padding_arg_index
                  << " to be a scalar for use as a dynamic parameter";
 
-      if (!mlir::getElementTypeOrSelf(padding_arg_type).isa<IntegerType>())
+      if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger())
         return entry_func.emitError()
                << "requires arg " << padding_arg_index
                << " to be of an int type for use as a dynamic parameter";
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
index 33d5884a882..d43ca3b6bb2 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
@@ -210,7 +210,6 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) {
 
 // -----
 
-// CHECK-DAG: #[[OPERAND_MPA:.*]] = affine_map<(d0, d1, d2) -> (0)>
 // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @broadcast_scalar
 func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) {
@@ -219,9 +218,10 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) {
     : (memref<f32>, memref<7x10x6xf32>) -> ()
   return
 }
-// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[OPERAND:[a-zA-Z0-9_]*]]: f32, %[[RESULT:.*]]: f32):
-// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32
+// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
+// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32):
+// CHECK-NEXT: %[[CONST:.*]] = load %{{.*}} : memref<f32>
+// CHECK-NEXT:   linalg.yield %[[CONST]] : f32
 
 // -----
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 1384abed91c..29d399c68fa 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -16,7 +16,7 @@ limitations under the License.
 // This file implements logic for lowering HLO dialect to LHLO dialect.
 
 #include "absl/memory/memory.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h
index 7c6d162632f..d2a1f47e540 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
 #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
index 8351f94d172..72ea2e18ec0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
@@ -18,7 +18,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Block.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index da135ea1860..8f955d6944a 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -24,7 +24,7 @@ limitations under the License.
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/Dialect/Traits.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Diagnostics.h"  // TF:llvm-project
@@ -119,7 +119,7 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
 Type GetSumAccumulationType(Type input_type) {
   MLIRContext *ctx = input_type.getContext();
   if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
-  if (input_type.isInteger(8) || input_type.isInteger(16))
+  if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
     return IntegerType::get(32, ctx);
   return input_type;
 }
@@ -1274,7 +1274,7 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
                                      PatternRewriter &rewriter) const override {
     Type element_type =
         op.input().getType().cast<TensorType>().getElementType();
-    if (!element_type.isIntOrFloat()) return matchFailure();
+    if (!element_type.isSignlessIntOrFloat()) return matchFailure();
     Location loc = op.getLoc();
     ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
 
@@ -2248,7 +2248,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
     Type input_element_type = input_type.getElementType();
     // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
     // tf.ArgMax doesn't support complex data types, this check can be removed.
-    if (!input_element_type.isIntOrFloat()) return this->matchFailure();
+    if (!input_element_type.isSignlessIntOrFloat()) return this->matchFailure();
 
     Location loc = op.getLoc();
     Value init_value =
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
index 58e98a881e9..265466ef3a4 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
@@ -28,7 +28,7 @@ limitations under the License.
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/iterator_range.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 872a288c259..519ba9235f1 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -16,7 +16,7 @@ limitations under the License.
 // This is the legalization pattern definition file for TF to XLA.
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
 
@@ -504,8 +504,8 @@ def : Pat<(TF_SignOp $x),
           )>;
 
 def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
-  "getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
-  "getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
+  "getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && "
+  "getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && "
   "getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
   "getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
   "element types must be integers or floats of same width">;
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
index 5ee6010c3a8..3c15f0be7e8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
@@ -16,7 +16,7 @@ limitations under the License.
 // This file implements logic for lowering XLA dialect to Standard dialect.
 
 #include "llvm/ADT/StringSwitch.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
@@ -45,8 +45,8 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
     // Broadcasting not supported by this rewrite.
     if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
 
-    if (!lhs_type.getElementType().isa<IntegerType>() ||
-        !rhs_type.getElementType().isa<IntegerType>())
+    if (!lhs_type.getElementType().isSignlessInteger() ||
+        !rhs_type.getElementType().isSignlessInteger())
       return matchFailure();
 
     auto comparison_direction = op.comparison_direction();
@@ -113,7 +113,8 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
                                      PatternRewriter &rewriter) const override {
     auto output_type = op.getType().cast<ShapedType>();
     // TODO(prakalps): Handle FP and ComplexType iota ops.
-    if (!output_type.getElementType().isa<IntegerType>()) return matchFailure();
+    if (!output_type.getElementType().isSignlessInteger())
+      return matchFailure();
     auto output_size = output_type.getNumElements();
     auto dimension = op.iota_dimension().getSExtValue();
     auto max_dim_size = output_type.getDimSize(dimension);
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td
index a15b28193cd..c0f6c2c3541 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td
@@ -16,7 +16,7 @@ limitations under the License.
 // This is the legalization pattern definition file for XLA to StandardOps.
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
index b0f6b83038a..2c550465302 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
index e991b186d72..c9245d93e56 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
@@ -22,7 +22,7 @@ limitations under the License.
 #include "mlir/Dialect/GPU/GPUDialect.h"  // TF:llvm-project
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // TF:llvm-project
 #include "mlir/Dialect/LoopOps/LoopOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td
index d8a5ae6c6de..dcb0ab20e9e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td
@@ -17,7 +17,7 @@ limitations under the License.
 // equivalent real value operations.
 
 include "mlir/IR/OpBase.td"
-include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc
index c956cd6b277..f18607dfffb 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
index 0e33a8646ad..6554942954e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
@@ -18,7 +18,7 @@ limitations under the License.
 
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
 
@@ -206,7 +206,7 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
   const auto& lhs = args[0];
   const auto& rhs = args[1];
   Type element_type = lhs.getType();
-  if (element_type.isa<IntegerType>()) {
+  if (element_type.isSignlessInteger()) {
     Optional<CmpIPredicate> predicate =
         getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
     assert(predicate.hasValue() && "expected valid comparison direction");
@@ -288,8 +288,8 @@ template <>
 inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
     xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
     ArrayRef<Value> args, OpBuilder* b) {
-  const Type& sourceType = args.front().getType();
-  const Type& targetType = result_types.front();
+  Type sourceType = args.front().getType();
+  Type targetType = result_types.front();
 
   if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
     return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
@@ -307,7 +307,7 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
     // No conversion is needed for the same width floats
     return args.front();
   }
-  if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
+  if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
     IntegerType src = sourceType.cast<IntegerType>();
     IntegerType res = targetType.cast<IntegerType>();
     if (src.getWidth() > res.getWidth()) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
index 13467be41d9..4c20a589ce0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include <numeric>
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc
index 596b67f0eed..644fffcc7ea 100644
--- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
index 46df664dc42..7f7060fef64 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
@@ -19,7 +19,7 @@ limitations under the License.
 #include "llvm/ADT/APInt.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // TF:llvm-project
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/AffineExpr.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
@@ -83,7 +83,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
       emitError(loc, "lhlo to linalg conversion expects ranked args");
       return ConversionPattern::matchFailure();
     }
-    if (!argType.getElementType().isIntOrFloat()) {
+    if (!argType.getElementType().isSignlessIntOrFloat()) {
       return ConversionPattern::matchFailure();
     }
 
@@ -171,7 +171,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
     auto loc = lhlo_op.getLoc();
     auto argType =
         lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
-    if (!argType || !argType.getElementType().isIntOrFloat() ||
+    if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
         (argType.getRank() != 0)) {
       return ConversionPattern::matchFailure();
     }
@@ -208,6 +208,9 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
     auto resultType = getXLAOpResultType<isLHLO>(op);
     if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op))
       return ConversionPattern::matchFailure();
+    // TODO(b/150203558) Enable once tiling/fusion works in this case.
+    if (isLHLO && (operandType.getRank() == 0))
+      return ConversionPattern::matchFailure();
     ArrayAttr indexingMapsAttr =
         static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter);
     if (!indexingMapsAttr) return ConversionPattern::matchFailure();
@@ -278,6 +281,52 @@ class BroadcastInDimConverter
   }
 };
 
+// Special case for scalar broadcast in lhlo.
+// TODO(b/150203558) Remove once the bug is fixed.
+class ScalarBroadcastInDimConverter
+    : public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
+ public:
+  using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      xla_lhlo::BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    auto operandMemrefType =
+        broadcastOp.operand().getType().dyn_cast<MemRefType>();
+    auto resultMemrefType =
+        broadcastOp.output().getType().dyn_cast<MemRefType>();
+    if (!operandMemrefType || !resultMemrefType) return matchFailure();
+    // Only support scalar operands.
+    if (operandMemrefType.getRank() != 0) return matchFailure();
+    auto broadcastDims = broadcastOp.broadcast_dimensions();
+    if (!broadcastDims.hasValue()) return matchFailure();
+
+    unsigned nloops = resultMemrefType.getRank();
+    SmallVector<Attribute, 1> indexingMaps{
+        AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))};
+    auto loc = broadcastOp.getLoc();
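+    // Build a linalg.generic with zero tensor inputs and one output buffer;
+    // its region loads the scalar operand and yields it for every element of
+    // the result.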
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, ArrayRef<Type>{}, broadcastOp.output(),
+        rewriter.getI64IntegerAttr(0),  // args_in
+        rewriter.getI64IntegerAttr(1),  // args_out
+        rewriter.getArrayAttr(indexingMaps),
+        GetNParallelLoopsAttrs(nloops, &rewriter),
+        /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
+
+    // Add a block to the region.
+    auto* region = &linalgOp.region();
+    auto* block = rewriter.createBlock(region, region->end());
+    block->addArguments(resultMemrefType.getElementType());
+
+    rewriter.setInsertionPointToEnd(block);
+    auto scalar =
+        rewriter.create<LoadOp>(loc, broadcastOp.operand(), llvm::None);
+    rewriter.create<linalg::YieldOp>(loc, scalar.getResult());
+    rewriter.eraseOp(broadcastOp);
+    return matchSuccess();
+  }
+};
+
 template <typename OpTy, bool isLHLO = true>
 class TransposeConverter
     : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
@@ -385,7 +434,7 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
     if (!resultMemrefType) return matchFailure();
 
     auto resultElementType = resultMemrefType.getElementType();
-    if (!resultElementType.isIntOrFloat()) return matchFailure();
+    if (!resultElementType.isSignlessIntOrFloat()) return matchFailure();
 
     // Construct the indexing maps needed for linalg.generic ops.
     unsigned nloops = resultMemrefType.getRank();
@@ -502,6 +551,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                    PointwiseToLinalgConverter<xla_lhlo::SubOp>,
                    PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
                    ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
+                   ScalarBroadcastInDimConverter,
                    ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
                    SliceConverter
                   >(context);
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index d517b5d0bdd..f3ee4e38f31 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1514,6 +1514,7 @@ cuda_py_test(
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
     ],
     xla_enable_strict_auto_jit = False,
+    xla_enabled = True,
     deps = [
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client",
@@ -1534,6 +1535,7 @@ cuda_py_test(
         "no_rocm",
     ],
     xla_enable_strict_auto_jit = False,
+    xla_enabled = True,
     deps = [
         ":test_utils",
         "//tensorflow/core:protos_all_py",
@@ -1558,6 +1560,7 @@ cuda_py_test(
         "no_rocm",
     ],
     xla_enable_strict_auto_jit = False,
+    xla_enabled = True,
     deps = [
         ":test_utils",
         "//tensorflow/core:protos_all_py",
@@ -1650,6 +1653,7 @@ cuda_py_test(
         "no_rocm",
     ],
     xla_enable_strict_auto_jit = False,
+    xla_enabled = True,
     deps = [
         ":lstm",
         ":xla_test",
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index a03980f20ba..0ed81b7e9e5 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
       wholly_compiled_f = def_function.function(f)
       op_by_op_f = def_function.function(f, experimental_compile=False)
 
-      x = constant_op.constant([0.0, 2.0], name='data')
+      x = array_ops.identity([0.0, 2.0], name='data')
 
       # When function is wholly compiled, all outputs will be on the
       # device on which it is run.
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 86da500b1dd..ba5d7e9d788 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -620,6 +620,7 @@ cc_library(
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/stream_executor:device_description",
         "@com_google_absl//absl/algorithm:container",
         "@llvm-project//llvm:core",
     ],
@@ -1124,6 +1125,7 @@ cc_library(
         ":cudnn_batchnorm_rewriter",
         ":cudnn_pad_for_convolutions",
         ":fusion_merger",
+        ":gemm_rewriter",
         ":gpu_constants",
         ":gpu_conv_algorithm_picker",
         ":gpu_conv_padding_legalization",
@@ -1140,9 +1142,13 @@ cc_library(
         ":ir_emitter",
         ":multi_output_fusion",
         ":partition_assignment",
+        ":reduction_degenerate_dim_remover",
+        ":reduction_dimension_grouper",
+        ":reduction_layout_normalizer",
         ":stream_assignment",
         ":stream_executor_util",
         ":target_constants",
+        ":tree_reduction_rewriter",
         ":variadic_op_splitter",
         "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -1232,19 +1238,13 @@ cc_library(
         ":cudnn_pad_for_convolutions",
         ":cusolver_rewriter",
         ":gemm_algorithm_picker",
-        ":gemm_rewriter",
         ":gpu_compiler",
-        ":gpu_conv_algorithm_picker",
         ":gpu_conv_padding_legalization",
         ":gpu_conv_rewriter",
         ":gpu_layout_assignment",
         ":ir_emission_utils",
-        ":reduction_degenerate_dim_remover",
-        ":reduction_dimension_grouper",
-        ":reduction_layout_normalizer",
         ":stream_executor_util",
         ":target_constants",
-        ":tree_reduction_rewriter",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 97013804271..b78748edb7e 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -87,39 +87,6 @@ Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   return Status::OK();
 }
 
-Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
-    HloModule* hlo_module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  HloPassPipeline pipeline("post-layout_assignment");
-  pipeline.AddInvariantChecker<HloVerifier>(
-      /*layout_sensitive=*/true,
-      /*allow_mixed_precision=*/false,
-      LayoutAssignment::InstructionCanChangeLayout);
-
-  pipeline.AddPass<ReductionDegenerateDimRemover>();
-  pipeline.AddPass<ReductionLayoutNormalizer>();
-  pipeline.AddPass<ReductionDimensionGrouper>();
-
-  // The LayoutAssignment pass may leave behind kCopy instructions which are
-  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
-  AlgebraicSimplifierOptions options;
-  options.set_is_layout_sensitive(true);
-  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
-
-  // Rewrite GEMMs into custom calls.
-  pipeline.AddPass<GemmRewriter>();
-
-  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
-
-  // Clean up new_tuple described above.
-  pipeline.AddPass<TupleSimplifier>();
-
-  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
-  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-
-  return Status::OK();
-}
-
 AMDGPUCompiler::AMDGPUCompiler()
     : GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::kTargetTriple,
                   amdgpu::kDataLayout) {}
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
index d1a74a7822e..acc5e021e3d 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -37,10 +37,6 @@ class AMDGPUCompiler : public GpuCompiler {
       HloModule* hlo_module, se::StreamExecutor* stream_exec,
       se::DeviceMemoryAllocator* device_allocator) override;
 
-  Status OptimizeHloPostLayoutAssignment(
-      HloModule* hlo_module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
-
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index e4c57203543..51b30e238e9 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -46,7 +46,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
+#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
@@ -61,10 +63,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
+#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
+#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
@@ -106,6 +112,7 @@ limitations under the License.
 #include "tensorflow/core/platform/subprocess.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/util/env_var.h"
 
 namespace xla {
 namespace gpu {
@@ -333,6 +340,81 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
   return pipeline.Run(hlo_module).status();
 }
 
+// TODO(cheshire): Duplication with gpu_conv_algorithm_picker, figure out the
+// right way to share this.
+static bool RequireDeterminism() {
+  bool deterministic_ops = false;
+  TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                             /*default_val=*/false,
+                                             &deterministic_ops));
+  return deterministic_ops;
+}
+
+Status GpuCompiler::OptimizeHloPostLayoutAssignment(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  HloPassPipeline pipeline("post-layout_assignment");
+  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
+   * fixing the ticket. */
+  pipeline.AddInvariantChecker<HloVerifier>(
+      /*layout_sensitive=*/true,
+      /*allow_mixed_precision=*/false,
+      LayoutAssignment::InstructionCanChangeLayout);
+
+  pipeline.AddPass<ReductionDegenerateDimRemover>();
+  pipeline.AddPass<ReductionLayoutNormalizer>();
+  pipeline.AddPass<ReductionDimensionGrouper>();
+
+  // The LayoutAssignment pass may leave behind kCopy instructions which are
+  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+  AlgebraicSimplifierOptions options;
+  options.set_is_layout_sensitive(true);
+  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
+
+  if (RequireDeterminism() ||
+      hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) {
+    pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>();
+  }
+
+  // Rewrite GEMMs into custom calls.
+  pipeline.AddPass<GemmRewriter>();
+
+  // Choose the fastest algorithm for each conv.
+  //
+  // We pick the algorithm before fusion so we can generate better HLO. After
+  // GpuConvRewriter, our convolutions are CustomCalls which return a
+  // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
+  // scratch:
+  //
+  //   customcall = (f32[...], f32[0])
+  //   return gte(customcall, 0)
+  //
+  // The algorithm picker then chooses the best algorithm, and potentially
+  // increases the scratch space.  It replaces customcall with new_tuple,
+  // giving us the following:
+  //
+  //   new_customcall = (f32[...], f32[N])
+  //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
+  //   return gte(new_tuple, 0)
+  //
+  // The new tuple and gte instructions can then be simplified away, because
+  // nobody is expected to use the scratch value.
+  //
+  // However, if we were to run GpuConvAlgorithmPicker after fusion, the
+  // gte(customcall, 0) would probably already have been absorbed into a
+  // fusion node.  We
+  // can't simplify across HloComputation boundaries, so in this case we
+  // wouldn't be able to simplify away the new_tuple bits.
+  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
+
+  // Clean up new_tuple described above.
+  pipeline.AddPass<TupleSimplifier>();
+
+  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+
+  return Status::OK();
+}
+
 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
     se::DeviceMemoryAllocator* device_allocator) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 901d994d4ad..b52af5392d1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -64,7 +64,7 @@ class GpuCompiler : public LLVMCompiler {
 
   virtual Status OptimizeHloPostLayoutAssignment(
       HloModule* hlo_module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) = 0;
+      se::DeviceMemoryAllocator* device_allocator);
 
   virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() {
     return
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 8646b4b5016..566d4f0e463 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/stream_executor/device_description.h"
 
 namespace xla {
 namespace gpu {
@@ -126,7 +127,9 @@ bool IsCublasGemm(const HloInstruction& hlo) {
 }
 
 std::array<int64, 3> GetReductionTiling(
-    const ReductionDimensions& reduction_dimensions) {
+    const ReductionDimensions& reduction_dimensions,
+    int smallest_input_dtype_bits,
+    const stream_executor::DeviceDescription* device_description) {
   if (reduction_dimensions.is_row_reduction) {
     int64 tile_z = std::min(reduction_dimensions.dimensions[0], int64{8});
     if (reduction_dimensions.dimensions[1] == 1) {
@@ -137,7 +140,17 @@ std::array<int64, 3> GetReductionTiling(
         0) {
       return {tile_z, 1, 64};
     }
-    return {tile_z, 1, 8};
+    int cc_major = 0, cc_minor = 0;
+    if (device_description != nullptr) {
+      device_description->cuda_compute_capability(&cc_major, &cc_minor);
+    }
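+    // On GPUs with compute capability >= 6.0, widen the per-thread tile in X
+    // for narrow element types: 16-bit inputs unroll by 16, 8-bit inputs by
+    // 64.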
+    int unroll_x = 8;
+    if (cc_major >= 6 && smallest_input_dtype_bits == 16) {
+      unroll_x = 16;
+    } else if (cc_major >= 6 && smallest_input_dtype_bits == 8) {
+      unroll_x = 64;
+    }
+    return {tile_z, 1, unroll_x};
   }
 
   // Column reduction.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 82b10a50c39..8a2385a242b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/stream_executor/device_description.h"
 
 // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
 // don't belong in "ir_emission_utils".
@@ -193,8 +194,12 @@ ReductionDimensions GetReductionKindAndContiguousComponents(
 
 // Get tiling per thread for the given reduction in dimensions [D, H, W] per
 // thread.
+// If the device isn't known, pass null for device_description and a
+// non-optimized tiling will be returned.
 std::array<int64, 3> GetReductionTiling(
-    const ReductionDimensions& reduction_dimensions);
+    const ReductionDimensions& reduction_dimensions,
+    int smallest_input_dtype_bits,
+    const stream_executor::DeviceDescription* device_description);
 
 // Emits call to "vprintf" with given format and arguments.
 llvm::Value* EmitPrintf(absl::string_view fmt,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f99e43cc06d..99bfc9185a7 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3097,9 +3097,20 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
            << " " << reduction_dimensions.dimensions[0] << " "
            << reduction_dimensions.dimensions[1] << " "
            << reduction_dimensions.dimensions[2];
+  auto get_dtype_bits = [](const HloInstruction* i) {
+    return primitive_util::BitWidth(i->shape().element_type());
+  };
 
+  // For fusion with multiple inputs, use the smallest input dtype to
+  // select the reduction_tiling.
+  int smallest_input_dtype_bits = get_dtype_bits(first_reduce->operand(0));
+  for (xla::HloInstruction* input : unnested_hlo->operands()) {
+    smallest_input_dtype_bits =
+        std::min(get_dtype_bits(input), smallest_input_dtype_bits);
+  }
   std::array<int64, 3> reduction_tiling =
-      GetReductionTiling(reduction_dimensions);
+      GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
+                         &ir_emitter_context_->device_description());
   bool dilated_x =
       reduction_dimensions.is_row_reduction ||
       !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index a1a901f0b94..6d036094a69 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -27,19 +27,13 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
 #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
-#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
-#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
-#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
-#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
-#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
 #include "tensorflow/compiler/xla/service/hlo_cse.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
@@ -55,7 +49,6 @@ limitations under the License.
 #include "tensorflow/core/platform/cuda_libdevice_path.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
-#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
 
@@ -152,83 +145,20 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
   return Status::OK();
 }
 
-// TODO(cheshire): Duplication with gpu_conv_algorithm picker, figure out a
-// right way to share this.
-static bool RequireDeterminism() {
-  bool deterministic_ops = false;
-  TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
-                                             /*default_val=*/false,
-                                             &deterministic_ops));
-  return deterministic_ops;
-}
-
 Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
     HloModule* hlo_module, se::StreamExecutor* stream_exec,
     se::DeviceMemoryAllocator* device_allocator) {
-  HloPassPipeline pipeline("post-layout_assignment");
-  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-   * fixing the ticket. */
-  pipeline.AddInvariantChecker<HloVerifier>(
-      /*layout_sensitive=*/true,
-      /*allow_mixed_precision=*/false,
-      LayoutAssignment::InstructionCanChangeLayout);
-
-  pipeline.AddPass<ReductionDegenerateDimRemover>();
-  pipeline.AddPass<ReductionLayoutNormalizer>();
-  pipeline.AddPass<ReductionDimensionGrouper>();
-
-  // The LayoutAssignment pass may leave behind kCopy instructions which are
-  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
-  AlgebraicSimplifierOptions options;
-  options.set_is_layout_sensitive(true);
-  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
-
-  if (RequireDeterminism() ||
-      hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) {
-    pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>();
-  }
+  TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment(
+      hlo_module, stream_exec, device_allocator));
 
+  HloPassPipeline pipeline("nvptx post-layout_assignment");
   // Pad the dimensions of matrices in dot operations to multiples of 8.
   if (IsVoltaOrLater(*stream_exec)) {
     pipeline.AddPass<CublasGemmPadForTensorCores>();
   }
-  // Rewrite GEMMs into custom calls.
-  pipeline.AddPass<GemmRewriter>();
-
-  // Choose the fastest algorithm for each conv.
-  //
-  // We pick the algorithm before fusion so we can generate better HLO. After
-  // GpuConvRewriter, our convolutions are CustomCalls which return a
-  // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of
-  // scratch:
-  //
-  //   customcall = (f32[...], f32[0])
-  //   return gte(customcall, 0)
-  //
-  // The algorithm picker then chooses the best algorithm, and potentially
-  // increases the scratch space.  It replaces customcall with new_tuple,
-  // giving us the following:
-  //
-  //   new_customcall = (f32[...], f32[N])
-  //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
-  //   return gte(new_tuple, 0)
-  //
-  // The new tuple and gte instructions then be simplified away, because
-  // nobody is expected to use the scratch value.
-  //
-  // However, if we were to run GpuConvAlgorithmPicker after fusion
-  // the gte(customcall, 0) would probably already be into a fusion node.  We
-  // can't simplify across HloComputation boundaries, so in this case we
-  // wouldn't be able to simplify away the new_tuple bits.
-  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
 
   // Find the fastest algorithm for GEMMs.
   pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);
-
-  // Clean up new_tuple described above.
-  pipeline.AddPass<TupleSimplifier>();
-
-  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
   TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
 
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
index e1ab7bb9646..79f5c0fd901 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -30,7 +30,7 @@ limitations under the License.
 namespace stream_executor {
 namespace interpreter {
 
-XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name,
+XlaInterpreterPlatform::XlaInterpreterPlatform(const std::string& name,
                                                const Platform::Id& id)
     : name_(name), id_(id) {}
 
@@ -40,7 +40,7 @@ Platform::Id XlaInterpreterPlatform::id() const { return id_; }
 
 int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; }
 
-const string& XlaInterpreterPlatform::Name() const { return name_; }
+const std::string& XlaInterpreterPlatform::Name() const { return name_; }
 
 port::StatusOr<std::unique_ptr<DeviceDescription>>
 XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const {
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h
index ff9c5d07f8d..da037bf17bc 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.h
+++ b/tensorflow/compiler/xla/service/interpreter/platform.h
@@ -31,14 +31,14 @@ class XlaInterpreterPlatform : public Platform {
  public:
   XlaInterpreterPlatform()
       : XlaInterpreterPlatform("Interpreter", kXlaInterpreterPlatformId) {}
-  XlaInterpreterPlatform(const string& name, const Platform::Id& id);
+  XlaInterpreterPlatform(const std::string& name, const Platform::Id& id);
   ~XlaInterpreterPlatform() override;
 
   Platform::Id id() const override;
 
   int VisibleDeviceCount() const override;
 
-  const string& Name() const override;
+  const std::string& Name() const override;
 
   port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
       int ordinal) const override;
@@ -60,7 +60,7 @@ class XlaInterpreterPlatform : public Platform {
 
  private:
   // This platform's name.
-  string name_;
+  std::string name_;
   // This platform's id.
   Platform::Id id_;
 
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
index aa28a36c945..b3e4002a898 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
@@ -31,7 +31,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/AffineExpr.h"  // TF:llvm-project
 #include "mlir/IR/AffineMap.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
index bb67305c344..184d8d202c3 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h"
 
 #include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/IR/Types.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc
index 5d67d7dcf7f..bd64c18680c 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h"
 
-#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 namespace mlir {
 namespace {
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 176e7af5d8c..ca26ae4e756 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -31,7 +31,7 @@ limitations under the License.
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // TF:llvm-project
 #include "mlir/Dialect/Linalg/Passes.h"  // TF:llvm-project
 #include "mlir/Dialect/LoopOps/LoopOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
index 13009992ab5..75c7c284881 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
index e383169399c..e471ba192e1 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
@@ -21,7 +21,7 @@ limitations under the License.
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // TF:llvm-project
 #include "mlir/Dialect/GPU/GPUDialect.h"  // TF:llvm-project
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 8f89c4655a3..c35f05ebf45 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -60,8 +60,8 @@ int main(int argc, char** argv) {
     LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU";
   }
 
-  string triple_string;
-  string target_cpu = argv[1];
+  std::string triple_string;
+  std::string target_cpu = argv[1];
   if (target_cpu == "k8") {
     triple_string = "x86_64-none-linux-gnu";
   } else if (target_cpu == "darwin") {
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 3913d604881..5c71f539efa 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -112,10 +112,7 @@ class ReduceTest : public ClientLibraryTestBase {
     std::unique_ptr<GlobalData> input_global_data =
         client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
-    float expected = 0.0;
-    for (float item : input_data) {
-      expected += item;
-    }
+    float expected = absl::c_accumulate(input_data, 0.0f);
     ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
                                ErrorSpec(0.001));
   }
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4bae89cda01..b02eb89ebfc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -106,6 +106,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 # buildifier: disable=same-origin-load
 # Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib")
 
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_monitoring_deps")
+
 # For platform specific build config
 load(
     "//tensorflow/core/platform:build_config.bzl",
@@ -1988,7 +1991,6 @@ cc_library(
         "//tensorflow/core/platform:hash",
         "//tensorflow/core/platform:load_library",
         "//tensorflow/core/platform:logger",
-        "//tensorflow/core/platform:monitoring",
         "//tensorflow/core/platform:mutex",
         "//tensorflow/core/platform:notification",
         "//tensorflow/core/platform:net",
@@ -2026,7 +2028,7 @@ cc_library(
         "@zlib",
         "@double_conversion//:double-conversion",
         "@com_google_protobuf//:protobuf",
-    ] + tf_protos_all_impl() + tf_protos_grappler_impl() + tf_protos_profiler_impl(),
+    ] + tf_protos_all_impl() + tf_protos_grappler_impl() + tf_protos_profiler_impl() + tf_monitoring_deps(),
     # Alwayslink causes a cc_binary to "always link" in the
     # srcs for a given cc_library, even if they are unreferenced, see:
     # https://docs.bazel.build/versions/master/be/c-cpp.html#cc_library.alwayslink
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index b17278fb365..c7583c374f2 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -45,7 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
     }
     const auto& t = d->device_type();
     device_type_counts_[t]++;
-    if (cpu_device_ == nullptr && t == "CPU") {
+    if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) {
       cpu_device_ = d.get();
     }
   }
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 66b6bdde1c8..968713f8acd 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -2813,4 +2813,66 @@ TEST_F(DirectSessionCollectiveTest,
   ASSERT_EQ(key1, key2);
 }
 
+// An op that emits, on each of its required outputs, the number of outputs
+// that are required in the current subgraph.
+class StatefulOutputRequiredOp : public OpKernel {
+ public:
+  explicit StatefulOutputRequiredOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override {
+    // The op counts the number of outputs required in the current subgraph,
+    // and emits that number on each of its required outputs.
+    Tensor count_outputs_required_t(0LL);
+    int64& count_outputs_required = count_outputs_required_t.scalar<int64>()();
+    for (int i = 0; i < num_outputs(); ++i) {
+      if (ctx->output_required(i)) ++count_outputs_required;
+    }
+    for (int i = 0; i < num_outputs(); ++i) {
+      if (ctx->output_required(i)) ctx->set_output(i, count_outputs_required_t);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("StatefulOutputRequired").Device(DEVICE_CPU),
+                        StatefulOutputRequiredOp);
+REGISTER_OP("StatefulOutputRequired")
+    .Output("results : num_outs * int64")
+    .Attr("num_outs : int = 5")
+    .SetIsStateful();
+
+TEST(DirectSessionTest, TestStatefulOutputRequiredOp) {
+  GraphDef graph;
+  // Creates a graph with a StatefulOutputRequired op with 5 outputs.
+  protobuf::TextFormat::ParseFromString(
+      R"proto(
+        node { name: 'n' op: 'StatefulOutputRequired' device: '/device:CPU:0' }
+        versions { producer: 9 }
+      )proto",
+      &graph);
+
+  std::unique_ptr<Session> session(NewSession(SessionOptions()));
+  ASSERT_TRUE(session != nullptr);
+  TF_ASSERT_OK(session->Create(std::move(graph)));
+
+  // As a stateful op, a single StatefulOutputRequired kernel will be created
+  // and shared across multiple subgraphs. We create 5 different subgraphs,
+  // fetching different prefixes of the output of the op.
+  for (int num_outputs_required = 1; num_outputs_required <= 5;
+       ++num_outputs_required) {
+    std::vector<string> fetch_tensor_names;
+    fetch_tensor_names.reserve(num_outputs_required);
+    for (int output_idx = 0; output_idx < num_outputs_required; ++output_idx) {
+      fetch_tensor_names.push_back(strings::StrCat("n:", output_idx));
+    }
+    std::vector<Tensor> fetch_tensors;
+    TF_ASSERT_OK(session->Run({}, fetch_tensor_names, {}, &fetch_tensors));
+    ASSERT_EQ(num_outputs_required, fetch_tensors.size());
+    for (const Tensor& t : fetch_tensors) {
+      ASSERT_EQ(num_outputs_required, t.scalar<int64>()());
+    }
+  }
+
+  TF_ASSERT_OK(session->Close());
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc
index f7e2e27e4ab..4bea08bb021 100644
--- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc
+++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc
@@ -194,7 +194,8 @@ Device* DynamicDeviceMgr::HostCPU() const {
   }
   cpu_device_ = nullptr;
   for (const auto& pair : dynamic_devices_) {
-    if (pair.first->device_type() == DEVICE_CPU) {
+    if (pair.first->device_type() == DEVICE_CPU &&
+        pair.first->parsed_name().id == 0) {
       cpu_device_ = pair.first;
       break;
     }
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 76e34173459..9211a022b24 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -135,6 +135,7 @@ tf_cuda_library(
         "//conditions:default": [
             "//tensorflow/core:framework",
             "//tensorflow/core:lib",
+            "//tensorflow/core/profiler/lib:traceme",
         ],
     }),
 )
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index fb051a9a583..5a053c2b51a 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -49,7 +49,6 @@ limitations under the License.
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
-#include "tensorflow/core/platform/monitoring.h"
 #include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
@@ -102,7 +101,6 @@ EagerContext::EagerContext(
   // provided). For builds using "tensorflow/core/platform/default", this is
   // currently a no-op.
   eager_context_created->GetCell()->Set(true);
-  monitoring::StartExporter();
   InitPrioritizedDeviceTypeList();
   runner_ = [this](std::function<void()> closure) {
     this->thread_pool_->Schedule(std::move(closure));
@@ -167,6 +165,16 @@ std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
   return v;
 }
 
+std::vector<string> DeviceTypesToString(
+    const PrioritizedDeviceTypeVector& types) {
+  std::vector<string> v;
+  v.reserve(types.size());
+  for (const auto& p : types) {
+    v.push_back(p.first.type_string());
+  }
+  return v;
+}
+
 // Selects the "best" device that both exists and is supported.
 //
 // The `existing` argument specifies the available devices in the system, in
@@ -232,13 +240,17 @@ Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
     return errors::InvalidArgument(
         "Could not satisfy device specification '", preferred,
         "'. enable_soft_placement=", AllowSoftPlacement(),
-        ". All available devices [",
+        ". Supported device types [",
+        absl::StrJoin(DeviceTypesToString(supported), ", "),
+        "]. All available devices [",
         absl::StrJoin(DevicesToString(existing), ", "), "].");
   }
   return errors::InvalidArgument(
       "No supported device found in available devices [",
       absl::StrJoin(DevicesToString(existing), ", "),
-      "]. enable_soft_placement=", AllowSoftPlacement(), ".");
+      "]. enable_soft_placement=", AllowSoftPlacement(),
+      ". Supported devices types [",
+      absl::StrJoin(DeviceTypesToString(supported), ", "), "].");
 }
 
 void EagerContext::ResetClusterFLR(
diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h
index 4c4a1e13266..1a1459a9f1c 100644
--- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h
+++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h
@@ -27,15 +27,25 @@ namespace tensorflow {
 class CopyToDeviceNode : public EagerNode {
  public:
   CopyToDeviceNode(TensorHandle* src, TensorHandle* dst, Device* dstd,
-                   const EagerContext& ctx)
-      : EagerNode(), src_(src), dst_(dst), dstd_(dstd), ctx_(ctx) {
-    src_->Ref();
-    dst_->Ref();
+                   const EagerContext& ctx, bool async, bool mirror)
+      : EagerNode(),
+        src_(src),
+        dst_(dst),
+        dstd_(dstd),
+        ctx_(ctx),
+        async_(async),
+        mirror_(mirror) {
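+    // In sync mode the node is run inline on the caller's stack, so the
+    // handles are guaranteed to outlive it and no extra references are taken.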
+    if (async_) {
+      src_->Ref();
+      dst_->Ref();
+    }
   }
 
   ~CopyToDeviceNode() override {
-    src_->Unref();
-    dst_->Unref();
+    if (async_) {
+      src_->Unref();
+      dst_->Unref();
+    }
   }
 
   Status Run() override {
@@ -43,16 +53,20 @@ class CopyToDeviceNode : public EagerNode {
     MEMDEBUG_CACHE_OP(MEMDEBUG_CACHE_VAL ? MEMDEBUG_CACHE_VAL
                                          : "eager::CopyToDeviceNode");
     TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor));
-    return dst_->SetTensor(std::move(tensor), ctx_.CanonicalDevice(dstd_));
+    if (!async_ && mirror_) {
+      return dst_->AddLocalMirror(std::move(tensor), dstd_);
+    } else {
+      return dst_->SetTensor(std::move(tensor), dstd_);
+    }
   }
 
-  void Abort(Status status) override { dst_->Poison(status); }
+  void Abort(Status status) override { dst_->Poison(status, dstd_); }
 
   string DebugString() const override {
     string out = "[CopyToDeviceNode]";
     strings::StrAppend(&out, " src_tensor: ", src_->DebugString());
     strings::StrAppend(&out, ", dst_tensor: ", dst_->DebugString());
-    strings::StrAppend(&out, ", dst_device: ", dstd_->name());
+    strings::StrAppend(&out, ", dst_device: ", dstd_ ? dstd_->name() : "[]");
     return out;
   }
 
@@ -63,6 +77,8 @@ class CopyToDeviceNode : public EagerNode {
   TensorHandle* dst_;
   Device* dstd_;
   const EagerContext& ctx_;
+  bool async_;
+  bool mirror_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index 2e50bd9de49..d49c9a5064b 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -19,8 +19,17 @@ limitations under the License.
 
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
+namespace {
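+// Reads the TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION environment variable;
+// asynchronous waiting for remote functions defaults to enabled.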
+bool IsAsyncWaitForRemoteFunctionEnabled() {
+  bool enabled = true;
+  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION",
+                                 true, &enabled));
+  return enabled;
+}
+}  // namespace
 
 EagerExecutor::EagerExecutor(bool async)
     : next_node_id_(0),
@@ -28,7 +37,10 @@ EagerExecutor::EagerExecutor(bool async)
       thread_(async ? tensorflow::Env::Default()->StartThread(
                           tensorflow::ThreadOptions(), "eager_async_executor",
                           std::bind(&EagerExecutor::Run, this))
-                    : nullptr) {}
+                    : nullptr),
+      last_eager_client_(nullptr),
+      enable_async_wait_for_remote_function_(
+          IsAsyncWaitForRemoteFunctionEnabled()) {}
 
 EagerExecutor::~EagerExecutor() {
   tensorflow::mutex_lock l(node_queue_mutex_);
@@ -194,6 +206,7 @@ void EagerExecutor::ClearError() {
   DCHECK(node_queue_.empty());
   status_ = tensorflow::Status::OK();
   ok_ = true;
+  last_eager_client_ = nullptr;
   nodes_pending_.notify_all();
 }
 
@@ -327,6 +340,33 @@ Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
                               bool from_queue) {
   DVLOG(3) << "Running Node: [id " << item->id << "] "
            << item->node->DebugString();
+  AsyncRemoteExecuteNode* async_remote_node =
+      item->node->AsAsyncRemoteExecuteNode();
+  if (enable_async_wait_for_remote_function_) {
+    if (async_remote_node != nullptr) {
+      if (last_eager_client_ != nullptr &&
+          async_remote_node->eager_client() != nullptr &&
+          last_eager_client_ != async_remote_node->eager_client()) {
+        // Running a remote function; we need to sync if it is going to a
+        // different device than the last remote distributed function we ran.
+        DVLOG(3) << "Executing Sync Executor for node " << item->id;
+        tensorflow::Status status = async_remote_node->SyncExecutors();
+        if (!status.ok()) {
+          NodeDone(item, status, from_queue);
+          return status;
+        }
+        last_eager_client_ = nullptr;
+      }
+      if (async_remote_node->eager_client() != nullptr &&
+          async_remote_node->needs_remote_inputs() &&
+          async_remote_node->allow_multiple_pending_requests()) {
+        // We are running a remote distributed function; update
+        // last_eager_client_.
+        last_eager_client_ = async_remote_node->eager_client();
+      }
+    }
+  }
+
   AsyncEagerNode* async_node = item->node->AsAsync();
   if (async_node == nullptr) {
     tensorflow::Status status = item->node->Run();
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
index e414ab4b6c5..375e9a6e6a7 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.h
+++ b/tensorflow/core/common_runtime/eager/eager_executor.h
@@ -39,6 +39,10 @@ limitations under the License.
 namespace tensorflow {
 
 class AsyncEagerNode;
+class AsyncRemoteExecuteNode;
+namespace eager {
+class EagerClient;
+}
 
 // A unit of execution for the EagerExecutor class below. Example subclasses
 // encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
@@ -65,6 +69,7 @@ class EagerNode {
 
   // Returns nullptr iff this Eager node is synchronous.
   virtual AsyncEagerNode* AsAsync() { return nullptr; }
+  virtual AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() { return nullptr; }
 
   virtual string DebugString() const = 0;
 
@@ -86,6 +91,16 @@ class AsyncEagerNode : public EagerNode {
   }
 };
 
+class AsyncRemoteExecuteNode : public AsyncEagerNode {
+ public:
+  AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() final { return this; }
+
+  virtual const eager::EagerClient* eager_client() const = 0;
+  virtual bool needs_remote_inputs() const = 0;
+  virtual bool allow_multiple_pending_requests() const = 0;
+  virtual Status SyncExecutors() = 0;
+};
+
 // A class for handling async execution (see TFE_ContextSetAsync).
 // Note that this class is thread-safe.
 // TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
@@ -228,6 +243,11 @@ class EagerExecutor {
   // Thread object that calls the `Run` method in async mode. This thread runs
   // until state_ is set to kShuttingDown. It is `nullptr` in sync mode.
   const std::unique_ptr<Thread> thread_;
+
+  // Eager client that ran the last remote function with remote inputs.
+  const eager::EagerClient* last_eager_client_;
+
+  const bool enable_async_wait_for_remote_function_;
 };
 
 inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 3496a39714d..4f1a9cfc2cc 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -769,9 +769,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
         id, i, remote_task, output_dtypes[i], op_device, &ctx, &retvals[i]);
     if (!status.ok()) {
       for (int j = 0; j < i; ++j) {
-        retvals[j]->Poison(errors::Internal(
-            "Failed to construct unshaped remote tensor handle at index ", i,
-            " for op ", op->Name()));
+        retvals[j]->PoisonRemote(
+            errors::Internal(
+                "Failed to construct unshaped remote tensor handle at index ",
+                i, " for op ", op->Name()),
+            op_device, ctx.GetContextViewId());
       }
       return status;
     }
@@ -796,7 +798,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
            << " (is async?: " << executor.Async() << ").";
 
   std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
-      std::move(request), op_device, ctx.GetContextViewId(), eager_client.get(),
+      &op->EagerContext(), std::move(request), op_device,
+      ctx.GetContextViewId(), eager_client.get(),
       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
       op->Inputs(), {retvals, num_outputs}));
   Status s = executor.AddOrExecute(std::move(node));
@@ -1057,11 +1060,15 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     return Status::OK();
   }
 
-  // TODO(gjn): Need to add support for async execution. Note if receiver
-  // is local, we need to first add support in TensorHandle to wait on local
-  // mirrors.
-  if (mirror && !executor->Async()) {
-    TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
+  bool async = executor->Async();
+  if (mirror) {
+    // We don't bother adding an empty local mirror in sync mode since we'll
+    // be executing the operation directly and calling AddLocalMirror. A
+    // reference is still taken; it will be removed if the operation fails.
+    if (async) {
+      TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
+    }
     h->Ref();
     *result = h;
   } else {
@@ -1069,10 +1076,18 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
         true, d, dstd, h->resource_device(), h->dtype, ctx, result));
   }
 
-  // Note that `h` may not be currently ready. However execution order will
-  // make sure that `h` is ready before the copy is actually done.
-  std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, *ctx));
-  Status s = executor->AddOrExecute(std::move(node));
+  Status s;
+  if (async) {
+    // Note that `h` may not be currently ready. However, execution order will
+    // make sure that `h` is ready before the copy is actually done.
+    std::unique_ptr<EagerNode> node(
+        new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
+    s = executor->AddOrExecute(std::move(node));
+  } else {
+    CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
+    s = executor->SyncExecute(&node);
+  }
+
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -1121,7 +1136,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
       // TODO(gjn): Need to add support for async execution. Note if receiver
       // is local, we need to first add support in TensorHandle to wait on local
       // mirrors.
-      if (mirror && !executor->Async()) {
+      if (mirror) {
         TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
         h->Ref();
         *result = h;
@@ -1159,7 +1174,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
       }
     }
 
-    auto node = absl::make_unique<eager::RemoteCopyNode>(
+    auto node = std::make_unique<eager::RemoteCopyNode>(
         ctx, executor, h, result[0], device, recv_op_id);
     Status s = executor->AddOrExecute(std::move(node));
     if (!s.ok()) {
diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h
index 7e5340575c9..ed1bd956179 100644
--- a/tensorflow/core/common_runtime/eager/execute_node.h
+++ b/tensorflow/core/common_runtime/eager/execute_node.h
@@ -189,8 +189,10 @@ class AsyncExecuteNode : public EagerNode {
   }
 
   void Abort(Status status) override {
+    int i = 0;
     for (auto handle : retvals_) {
-      handle->Poison(status);
+      handle->Poison(status, ctx_->CanonicalDevice(kernel_->OutputDevice(i)));
+      ++i;
     }
   }
 
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 4dbf5d6313d..47a7125ced8 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -319,21 +319,23 @@ Status TensorHandle::WaitReady(const char* caller) const {
   if (!IsReady()) {
     profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
                                profiler::TraceMeLevel::kInfo);
-    DVLOG(3) << "Waiting on TensorHandle " << this;
     tf_shared_lock l(mu_);
     mu_.Await(Condition(&is_ready_));
-    DVLOG(3) << "TensorHandle ready: " << this;
   }
   return is_poisoned_;
 }
 
 Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
+  DVLOG(3) << "Tensor on TensorHandle: " << this;
+
   TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
   return tensor_handle_data_->Tensor(t);
 }
 
 Status TensorHandle::TensorFromDevice(const Device* d,
                                       const tensorflow::Tensor** t) const {
+  DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
+
   if (d == absl::get<Device*>(device_)) {
     if (is_remote_) {
       return errors::Internal("Invalid Tensor call on remote handle: ", this);
@@ -344,19 +346,20 @@ Status TensorHandle::TensorFromDevice(const Device* d,
   }
 
   tf_shared_lock l(mu_);
-  auto mirror = local_mirrors_.find(d);
-  if (mirror != local_mirrors_.end()) {
-    return mirror->second->Tensor(t);
+  auto elem = local_mirrors_.find(d);
+  if (elem == local_mirrors_.end()) {
+    return errors::Internal("Invalid device: ", d,
+                            " in Tensor call to handle: ", this);
   }
 
-  auto empty_mirror = empty_local_mirrors_.find(d);
-  if (empty_mirror != empty_local_mirrors_.end()) {
-    // TODO(gjn): Add support for waiting on local mirrors
-    return errors::Internal("Attempted to get Tensor for empty mirror");
+  // Check if the handle is non-empty, else wait.
+  auto& mirror = elem->second;
+  if (mirror.second == nullptr) {
+    TF_RETURN_IF_ERROR(
+        mirror.first->WaitReady("TensorHandle::TensorFromDevice"));
   }
 
-  return errors::Internal("Invalid device: ", d,
-                          " in Tensor call to handle: ", this);
+  return mirror.second->Tensor(t);
 }
 
 Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
@@ -371,19 +374,19 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
   }
 
   tf_shared_lock l(mu_);
-  auto mirror = local_mirrors_.find(d);
-  if (mirror != local_mirrors_.end()) {
-    return mirror->second->TensorValue(t);
+  auto elem = local_mirrors_.find(d);
+  if (elem == local_mirrors_.end()) {
+    return errors::Internal("Invalid device: ", d,
+                            " in TensorValue call to handle: ", this);
   }
 
-  auto empty_mirror = empty_local_mirrors_.find(d);
-  if (empty_mirror != empty_local_mirrors_.end()) {
-    // TODO(gjn): Add support for waiting on local mirrors
-    return errors::Internal("Attempted to get TensorValue for empty mirror");
+  // Check if the handle is non-empty, else wait.
+  auto& mirror = elem->second;
+  if (mirror.second == nullptr) {
+    TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue"));
   }
 
-  return errors::Internal("Invalid device: ", d,
-                          " in TensorValue call to handle: ", this);
+  return mirror.second->TensorValue(t);
 }
 
 TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
@@ -506,53 +509,50 @@ Status TensorHandle::NumElements(int64* num_elements) const {
 }
 
 Status TensorHandle::Unprotect(const Device* d) {
+  DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
+
   if (d == absl::get<Device*>(device_)) {
     return tensor_handle_data_->Unprotect();
   }
 
   tf_shared_lock l(mu_);
-  auto mirror = local_mirrors_.find(d);
-  if (mirror != local_mirrors_.end()) {
-    return mirror->second->Unprotect();
+  auto elem = local_mirrors_.find(d);
+  if (elem == local_mirrors_.end()) {
+    return errors::Internal("Invalid device: ", d,
+                            " in Unprotect call to handle: ", this);
   }
 
-  auto empty_mirror = empty_local_mirrors_.find(d);
-  if (empty_mirror != empty_local_mirrors_.end()) {
+  // Check if the handle is non-empty.
+  auto& mirror = elem->second;
+  if (mirror.second == nullptr) {
     return errors::Internal("Attempted to unprotect an empty mirror");
   }
 
-  return errors::Internal("Invalid device: ", d,
-                          " in Unprotect call to handle: ", this);
+  return mirror.second->Unprotect();
 }
 
 bool TensorHandle::HasLocalMirror(const Device* d) const {
-  mutex_lock l(mu_);
-  auto mirror = local_mirrors_.find(d);
-  if (mirror != local_mirrors_.end()) {
-    return true;
-  }
+  DVLOG(3) << "HasLocalMirror on TensorHandle: " << this << " device: " << d;
 
-  auto empty_mirror = empty_local_mirrors_.find(d);
-  if (empty_mirror != empty_local_mirrors_.end()) {
-    return true;
-  }
-
-  return false;
+  tf_shared_lock l(mu_);
+  return local_mirrors_.find(d) != local_mirrors_.end();
 }
 
 Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
   DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this
            << " device: " << d;
 
+  if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
+    return errors::Internal("Cannot add mirror for primary device.");
+  }
+
   mutex_lock l(mu_);
   if (local_mirrors_.find(d) != local_mirrors_.end()) {
     return errors::Internal("Attempted to duplicate a local mirror.");
   }
 
-  auto ret = empty_local_mirrors_.insert(d);
-  if (!ret.second) {
-    return errors::Internal("Attempted to duplicate an empty local mirror.");
-  }
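+  // Insert a placeholder entry (empty handle data, no tensor); a later
+  // SetTensor call fills in the tensor and marks the mirror ready.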
+  local_mirrors_[d] =
+      std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr);
 
   return Status::OK();
 }
@@ -560,6 +560,9 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
 #if !defined(IS_MOBILE_PLATFORM)
 Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
                                    int32* output_num) const {
+  DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
+           << " " << d->name();
+
   if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
     tf_shared_lock l(mu_);
     auto mirror = remote_mirrors_.find(d->name());
@@ -593,7 +596,8 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
 
 bool TensorHandle::HasRemoteMirror(const Device* d,
                                    uint64 context_view_id) const {
-  DVLOG(3) << "HasRemoteMirror on TensorHandle: " << this;
+  DVLOG(3) << "HasRemoteMirror on TensorHandle: " << this << " device: " << d
+           << " " << d->name();
 
   tf_shared_lock l(mu_);
   auto mirror = remote_mirrors_.find(d->name());
@@ -619,7 +623,8 @@ bool TensorHandle::HasRemoteMirror(const Device* d,
 
 bool TensorHandle::HasResourceShapeMirror(const Device* d,
                                           uint64 context_view_id) const {
-  DVLOG(3) << "HasResourceShapeMirror on TensorHandle: " << this;
+  DVLOG(3) << "HasResourceShapeMirror on TensorHandle: " << this
+           << " device: " << d << " " << d->name();
 
   tf_shared_lock l(mu_);
   auto mirror = resource_shape_mirrors_.find(d->name());
@@ -635,7 +640,8 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
 
 Status TensorHandle::AddUnshapedRemoteMirror(
     std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
-  DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this;
+  DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
+           << " device: " << d << " " << d->name();
 
   mutex_lock l(mu_);
   auto remote_mirror = remote_mirrors_.find(d->name());
@@ -704,7 +710,8 @@ Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
 
 Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
                                     uint64 context_view_id) {
-  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;
+  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
+           << " " << d->name();
 
   if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
     mutex_lock l(mu_);
@@ -758,16 +765,57 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
 
   return Status::OK();
 }
+
+void TensorHandle::PoisonRemote(Status status, const Device* d,
+                                uint64 context_view_id) {
+  DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d
+           << " " << d->name();
+
+  if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
+    DCHECK(!is_async_ || !IsReady())
+        << "PoisonRemote can only be called on non-ready handle: " << this;
+
+    is_poisoned_ = status;
+    mutex_lock l(mu_);
+    is_ready_ = true;
+  } else {
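+    // Otherwise poison the matching unshaped remote mirror, if one exists for
+    // this device and context view.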
+    tf_shared_lock l(mu_);
+    auto mirror = unshaped_remote_mirrors_.find(d->name());
+    if (mirror != unshaped_remote_mirrors_.end()) {
+      if (mirror->second->context_view_id() == context_view_id) {
+        mirror->second->Poison(status);
+      }
+    }
+  }
+}
 #endif
 
+Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
+                                    const Device* d) {
+  if (d == absl::get<Device*>(device_)) {
+    return errors::Internal(
+        "Local mirror assign conflicts with primary device.");
+  }
+
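+  // Directly insert a populated mirror; the insert fails if any mirror (empty
+  // or populated) already exists for this device.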
+  mutex_lock l(mu_);
+  auto elem = local_mirrors_.insert(std::make_pair(
+      d, std::make_pair(nullptr,
+                        std::make_unique<LocalTensorHandleData>(tensor))));
+  if (!elem.second) {
+    return errors::Internal("Attempted to set tensor for existing mirror.");
+  }
+
+  return Status::OK();
+}
+
 Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
+  DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
+
   if (d == absl::get<Device*>(device_)) {
     DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
     DCHECK(!is_async_ || !IsReady())
         << "SetTensor is only called on non-ready handles.";
 
-    DVLOG(3) << "SetTensor on TensorHandle: " << this;
-
     if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
       auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
       handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
@@ -779,37 +827,53 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
       is_ready_ = true;
     }
   } else {
-    mutex_lock l(mu_);
-    if (local_mirrors_.find(d) != local_mirrors_.end()) {
-      return errors::Internal("Attempted to set tensor for existing mirror.");
-    }
-
-    auto elem = empty_local_mirrors_.find(d);
-    if (elem == empty_local_mirrors_.end()) {
+    tf_shared_lock l(mu_);
+    auto elem = local_mirrors_.find(d);
+    if (elem == local_mirrors_.end()) {
       return errors::Internal(
           "Attempted to set tensor for non-existent local mirror.");
     }
-    local_mirrors_[d] = absl::make_unique<LocalTensorHandleData>(tensor);
-    empty_local_mirrors_.erase(elem);
+
+    auto& mirror = elem->second;
+    if (mirror.second != nullptr) {
+      return errors::Internal("Attempted to set tensor for existing mirror.");
+    }
+
+    mirror.second = absl::make_unique<LocalTensorHandleData>(tensor);
+    mirror.first->SetReady();
   }
 
   return Status::OK();
 }
 
-void TensorHandle::Poison(Status status) {
-  DCHECK(!is_async_ || !IsReady())
-      << "Poison(status) can only be called on non-ready handle: " << this;
+void TensorHandle::Poison(Status status, const Device* d) {
+  DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
 
-  DVLOG(3) << "Poison on TensorHandle: " << this;
+  if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
+    DCHECK(!is_async_ || !IsReady())
+        << "Poison can only be called on non-ready handle: " << this;
 
-  is_poisoned_ = status;
-  mutex_lock l(mu_);
-  is_ready_ = true;
+    is_poisoned_ = status;
+    mutex_lock l(mu_);
+    is_ready_ = true;
+  } else {
+    tf_shared_lock l(mu_);
+    auto elem = local_mirrors_.find(d);
+    DCHECK(elem != local_mirrors_.end())
+        << "Attempted to poison non-existent local mirror, handle: " << this
+        << " device: " << d;
+
+    auto& mirror = elem->second;
+    DCHECK(mirror.second == nullptr) << "Attempted to poison existing mirror.";
+
+    mirror.first->Poison(status);
+  }
 }
 
 Status TensorHandle::CopyToDevice(const EagerContext& ctx,
-                                  tensorflow::Device* dstd,
+                                  tensorflow::Device* d,
                                   tensorflow::Tensor* output) {
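+  // A null destination device is interpreted as the host CPU.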
+  tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d;
   tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx));
   const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
   const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 817347adc92..1e5b741cbb7 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -153,6 +153,9 @@ class TensorHandle : public core::RefCounted {
   // Add an empty mirror placeholder for the specified device. The expectation
   // is this will be populated by a call to SetTensor.
   Status AddEmptyLocalMirror(const Device* d);
+  // Add a local mirror. This will fail if an empty local mirror was previously
+  // added. For that case, SetTensor should be used instead.
+  Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d);
 
 #if !defined(IS_MOBILE_PLATFORM)
   bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
@@ -176,6 +179,13 @@ class TensorHandle : public core::RefCounted {
   // were created without a known shape.
   Status SetRemoteShape(const TensorShape& shape, const Device* d,
                         uint64 context_view_id);
+
+  // Poisons either this handle or a remote mirror with error `status`.
+  // Poisoning means that the handle will become ready and methods trying
+  // to access the remote shape will return this error `status`.
+  // Exactly one of SetRemoteShape or PoisonRemote must be called on an
+  // unshaped handle on a remote device.
+  void PoisonRemote(Status status, const Device* d, uint64 context_view_id);
 #endif
 
   // Sets the `tensor` for this async non-ready handle making it ready.
@@ -183,14 +193,14 @@ class TensorHandle : public core::RefCounted {
   // handles to make them ready.
   Status SetTensor(tensorflow::Tensor&& tensor, const Device* d);
 
-  // Poisons this non-ready handle with an error `status`.
+  // Poisons either this handle or a local mirror with error `status`.
   // Poisoning means that the handle will become ready and methods trying
   // to access the actual tensor or shape will return this error `status`.
-  // Exactly one of SetTensor, SetRemoteShape, or Poison methods must be called
-  // on a non-ready tensor.
-  void Poison(Status status);
+  // Exactly one of SetTensor or Poison methods must be called on a non-ready
+  // tensor for a specific device.
+  void Poison(Status status, const Device* d);
 
-  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* dstd,
+  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d,
                       tensorflow::Tensor* output);
 
   Status InferenceShape(
@@ -265,9 +275,14 @@ class TensorHandle : public core::RefCounted {
 
   mutable mutex mu_;
 
-  std::map<const tensorflow::Device*, std::unique_ptr<LocalTensorHandleData>>
+  // Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is
+  // nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage
+  // waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the
+  // LocalTensorHandleData should be used.
+  std::map<const tensorflow::Device*,
+           std::pair<std::unique_ptr<EmptyLocalTensorHandleData>,
+                     std::unique_ptr<LocalTensorHandleData>>>
       local_mirrors_ GUARDED_BY(mu_);
-  std::set<const tensorflow::Device*> empty_local_mirrors_ GUARDED_BY(mu_);
 #if !defined(IS_MOBILE_PLATFORM)
   // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
   // variable is ready, since we could get the shape locally without remote copy
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc
index b6d17e1ee1a..690c14e2ffd 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
 
 namespace tensorflow {
 
@@ -104,6 +105,32 @@ Status EmptyLocalTensorHandleData::Unprotect() {
   return errors::Unavailable("Unable to unprotect an empty handle.");
 }
 
+bool EmptyLocalTensorHandleData::IsReady() const {
+  tf_shared_lock l(mu_);
+  return is_ready_;
+}
+
+void EmptyLocalTensorHandleData::SetReady() {
+  mutex_lock l(mu_);
+  is_ready_ = true;
+}
+
+Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const {
+  if (!IsReady()) {
+    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
+                               profiler::TraceMeLevel::kInfo);
+    tf_shared_lock l(mu_);
+    mu_.Await(Condition(&is_ready_));
+  }
+  return is_poisoned_;
+}
+
+void EmptyLocalTensorHandleData::Poison(Status status) {
+  is_poisoned_ = status;
+  mutex_lock l(mu_);
+  is_ready_ = true;
+}
+
 string EmptyLocalTensorHandleData::DebugString() const {
   return "EmptyLocalTensorHandleData";
 }
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/tensorflow/core/common_runtime/eager/tensor_handle_data.h
index 5e600cc8818..3a791d94315 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_data.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.h
@@ -57,7 +57,9 @@ class LocalTensorHandleData : public TensorHandleData {
   Status NumElements(int64* num_elements) const override;
   Status Unprotect() override;
 
-  string DebugString() const override { return tensor_.DebugString(); }
+  string DebugString() const override {
+    return tensor_.DeviceSafeDebugString();
+  }
 
  private:
   tensorflow::Tensor tensor_;
@@ -87,7 +89,18 @@ class EmptyLocalTensorHandleData : public TensorHandleData {
   Status NumElements(int64* num_elements) const override;
   Status Unprotect() override;
 
+  bool IsReady() const;
+  void SetReady();
+  Status WaitReady(const char* caller) const;
+  void Poison(Status status);
+  Status IsPoisoned() const { return is_poisoned_; }
+
   string DebugString() const override;
+
+ private:
+  mutable mutex mu_;
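+  // Flipped by SetReady() or Poison(); WaitReady() blocks until it is true.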
+  bool is_ready_ GUARDED_BY(mu_) = false;
+  Status is_poisoned_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 0be1d5df616..4a8d38c8e53 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -197,6 +197,10 @@ struct NodeItem {
   // Number of output edges.
   size_t num_output_edges;
 
+  // If non-null, contains an array of num_outputs bools, where the ith bool
+  // is true if and only if the ith output is consumed by another node.
+  std::unique_ptr<bool[]> outputs_required;
+
   const EdgeInfo* output_edge_list() const { return output_edge_base(); }
 
   // ith output edge.
@@ -708,16 +712,24 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
     }
 
     // Record information about whether each output of the op is used.
-    std::vector<bool> used_outputs(n->num_outputs(), false);
+    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
+    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
+    size_t unused_outputs = n->num_outputs();
     for (const Edge* e : n->out_edges()) {
       if (e->src_output() >= 0) {
-        used_outputs[e->src_output()] = true;
+        if (!outputs_required[e->src_output()]) {
+          --unused_outputs;
+          outputs_required[e->src_output()] = true;
+        }
       }
     }
-    for (bool used_output : used_outputs) {
-      if (!used_output) {
-        metrics::RecordUnusedOutput(n->type_string());
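+    // Only attach the array when at least one output is unused; a null
+    // outputs_required means all outputs are required.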
+    if (unused_outputs > 0) {
+      for (int i = 0; i < n->num_outputs(); ++i) {
+        if (!outputs_required[i]) {
+          metrics::RecordUnusedOutput(n->type_string());
+        }
       }
+      item->outputs_required = std::move(outputs_required);
     }
   }
 
@@ -1823,6 +1835,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
       params.is_input_dead = is_input_dead;
       params.output_attr_array = item.output_attrs();
       params.forward_from_array = item.forward_from();
+      params.outputs_required_array = item.outputs_required.get();
 
       if (item.kernel_is_async) {
         // Asynchronous computes.
@@ -2093,9 +2106,10 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   for (int i = 0; i < item.num_outputs; ++i) {
     const TensorValue val = ctx->release_output(i);
     if (val.tensor == nullptr) {
-      // Unless it's a Switch or a Recv, the node must produce a
-      // tensor value at i-th output.
-      if (!item.is_recv_or_switch) {
+      // Unless it's a Switch or a Recv, or the executor has marked the output
+      // as not required, the node must produce a tensor value at i-th output.
+      if (!(item.is_recv_or_switch ||
+            (item.outputs_required && !item.outputs_required[i]))) {
         s.Update(errors::Internal("Missing ", i, "-th output from ",
                                   FormatNodeDefForError(item.kernel->def())));
       }
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
index 575fafdbcde..3817736020e 100644
--- a/tensorflow/core/common_runtime/session.cc
+++ b/tensorflow/core/common_runtime/session.cc
@@ -21,7 +21,6 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/monitoring.h"
 
 namespace tensorflow {
 namespace {
@@ -71,7 +70,6 @@ Session* NewSession(const SessionOptions& options) {
   // provided). For builds using "tensorflow/core/platform/default", this is
   // currently a no-op.
   session_created->GetCell()->Set(true);
-  monitoring::StartExporter();
   Session* out_session;
   s = NewSession(options, &out_session);
   if (!s.ok()) {
@@ -93,7 +91,6 @@ Status NewSession(const SessionOptions& options, Session** out_session) {
   // provided). For builds using "tensorflow/core/platform/default", this is
   // currently a no-op.
   session_created->GetCell()->Set(true);
-  monitoring::StartExporter();
   s = factory->NewSession(options, out_session);
   if (!s.ok()) {
     *out_session = nullptr;
diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h
index 3b083f3cae6..5f260e477d6 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_client.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_client.h
@@ -57,6 +57,8 @@ class EagerClient : public core::RefCounted {
   virtual void StreamingEnqueueAsync(const EnqueueRequest* request,
                                      EnqueueResponse* response,
                                      StatusCallback done) = 0;
+
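+  // Returns true if this client may have more than one enqueue request
+  // pending at the same time.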
+  virtual bool allow_multiple_pending_requests() const = 0;
 };
 
 // Simple wrapper class that can be used to retrieve EagerClients.
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 686f471ca5e..41db645507b 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -96,6 +96,8 @@ class FakeEagerClient : public EagerClient {
     done(impl_->Enqueue(request, response));
   }
 
+  bool allow_multiple_pending_requests() const override { return false; }
+
  private:
   TestEagerServiceImpl* impl_;
 };
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index f5fc68a8e38..7bb001ef853 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -200,11 +200,12 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
   auto* remote_op = request.add_queue()->mutable_operation();
   PrepareRemoteOp(remote_op, op);
   remote_op->set_id(recv_op_id_);
+  uint64 context_view_id = ctx_->GetContextViewId();
 
   core::RefCountPtr<eager::EagerClient> eager_client;
   Status status = ctx_->GetClient(recv_device_, &eager_client);
   if (!status.ok()) {
-    captured_state_->dst()->Poison(status);
+    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
     done(status);
     return;
   }
@@ -216,7 +217,7 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
   // Blocks until send has completed.
   Status send_status = captured_state_->GetSendStatus();
   if (!send_status.ok()) {
-    captured_state_->dst()->Poison(send_status);
+    captured_state_->dst()->PoisonRemote(send_status, recv_device_,
+                                         context_view_id);
     done(send_status);
     return;
   }
@@ -224,7 +225,6 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
   EnqueueResponse* response = new EnqueueResponse;
   const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
   Device* recv_device = recv_device_;
-  uint64 context_view_id = ctx_->GetContextViewId();
   eager_client->StreamingEnqueueAsync(
       &request, response,
       [captured_state, response, recv_device, context_view_id,
@@ -241,7 +241,7 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
                           "Please file an issue with the TensorFlow Team.";
           }
         } else {
-          captured_state->dst()->Poison(s);
+          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
         }
         done(s);
         delete response;
@@ -254,8 +254,9 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
   EagerOperation op(ctx_);
   Status status = op.Reset("_Recv", /*raw_device_name=*/nullptr,
                            /*remote=*/false, /*executor=*/nullptr);
+  Device* recv_device = ctx_->CanonicalDevice(recv_device_);
   if (!status.ok()) {
-    captured_state_->dst()->Poison(status);
+    captured_state_->dst()->Poison(status, recv_device);
     done(status);
     return;
   }
@@ -276,12 +277,12 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
     std::vector<Tensor> outputs(1);
     status = RunLocalRecv(&op, &outputs);
     if (!status.ok()) {
-      captured_state_->dst()->Poison(status);
+      captured_state_->dst()->Poison(status, recv_device);
       done(status);
       return;
     }
-    status = captured_state_->dst()->SetTensor(
-        std::move(outputs[0]), ctx_->CanonicalDevice(recv_device_));
+    status =
+        captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
     done(status);
   } else {
     // Handles captured_state_->dst_ internally.
@@ -297,6 +298,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
   auto* send_tensor = request.add_queue()->mutable_send_tensor();
   send_tensor->set_op_id(recv_op_id_);
   send_tensor->set_device_name(recv_device_->name());
+  uint64 context_view_id = ctx_->GetContextViewId();
 
   // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
   // copy it to the CPU before copying it out.
@@ -314,7 +316,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
   core::RefCountPtr<eager::EagerClient> eager_client;
   s = ctx_->GetClient(recv_device_, &eager_client);
   if (!s.ok()) {
-    captured_state_->dst()->Poison(s);
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
     done(s);
     return;
   }
@@ -322,7 +324,6 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
   const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
   captured_state->SetSrcShape(tensor.shape());
   Device* recv_device = recv_device_;
-  uint64 context_view_id = ctx_->GetContextViewId();
   eager_client->StreamingEnqueueAsync(
       &request, response,
       [captured_state, response, recv_device, context_view_id,
@@ -336,7 +337,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
                        << status.ToString();
           }
         } else {
-          captured_state->dst()->Poison(s);
+          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
         }
         done(s);
         delete response;
@@ -378,7 +379,8 @@ void RemoteCopyNode::RunAsync(StatusCallback done) {
 
 void RemoteCopyNode::Abort(Status status) {
   if (!started_) {
-    captured_state_->dst()->Poison(status);
+    uint64 context_view_id = ctx_->GetContextViewId();
+    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
   }
 }
 
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
index f84e0ebb5ee..81547615706 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
@@ -81,7 +81,7 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
                             "Please file an issue with the TensorFlow Team.";
             }
           } else {
-            retvals[i]->Poison(status);
+            retvals[i]->PoisonRemote(status, device, context_view_id);
           }
           retvals[i]->Unref();
         }
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index de5d66a1efb..ed9f9c0ee0f 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -34,14 +34,16 @@ namespace eager {
 
 // RemoteExecuteNode is an implementation of EagerNode which enqueues
 // an operation via RPC in a remote EagerService.
-class RemoteExecuteNode : public AsyncEagerNode {
+class RemoteExecuteNode : public AsyncRemoteExecuteNode {
  public:
-  RemoteExecuteNode(std::unique_ptr<EnqueueRequest> request, Device* device,
+  RemoteExecuteNode(EagerContext* eager_context,
+                    std::unique_ptr<EnqueueRequest> request, Device* device,
                     uint64 context_view_id, EagerClient* eager_client,
                     const NodeDef& ndef, FunctionLibraryDefinition* lib_def,
                     const gtl::InlinedVector<TensorHandle*, 4>& inputs,
                     absl::Span<TensorHandle*> retvals)
-      : AsyncEagerNode(),
+      : AsyncRemoteExecuteNode(),
+        eager_context_(eager_context),
         request_(std::move(request)),
         device_(device),
         context_view_id_(context_view_id),
@@ -62,6 +64,16 @@ class RemoteExecuteNode : public AsyncEagerNode {
       handle->Ref();
     }
     eager_client_->Ref();
+
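+    // Record whether any resource input lives on a device other than the
+    // op's target device.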
+    needs_remote_inputs_ = false;
+    for (const TensorHandle* input : inputs_) {
+      // TODO(bramandia): Should this be op_device() instead?
+      if (input->resource_device() != nullptr &&
+          input->resource_device() != device_) {
+        needs_remote_inputs_ = true;
+        break;
+      }
+    }
   }
 
   ~RemoteExecuteNode() override {
@@ -81,12 +93,24 @@ class RemoteExecuteNode : public AsyncEagerNode {
 
   void RunAsync(StatusCallback done) override;
 
+  Status SyncExecutors() override { return eager_context_->SyncExecutors(); }
+
   void Abort(Status status) override {
     for (auto handle : retvals_) {
-      handle->Poison(status);
+      handle->PoisonRemote(status, device_, context_view_id_);
     }
   }
 
+  const EagerClient* eager_client() const override { return eager_client_; }
+
+  bool needs_remote_inputs() const override { return needs_remote_inputs_; }
+
+  bool allow_multiple_pending_requests() const override {
+    return eager_client_->allow_multiple_pending_requests();
+  }
+
   string DebugString() const override {
     string out = "[RemoteExecuteNode]";
     strings::StrAppend(&out, " request: ", request_->DebugString());
@@ -95,9 +119,11 @@ class RemoteExecuteNode : public AsyncEagerNode {
   }
 
  private:
+  EagerContext* eager_context_;  // Not owned, and must outlive this node.
   std::unique_ptr<EnqueueRequest> request_;
   Device* device_;             // Not owned
   uint64 context_view_id_;
+  bool needs_remote_inputs_;
   EagerClient* eager_client_;  // Not owned, and must outlive this node.
   const NodeDef ndef_;
   const FunctionLibraryDefinition* lib_def_;
diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h
index 9f7db52b447..34e3ec5f83d 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h
@@ -75,6 +75,9 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData {
   Status NumElements(int64* num_elements) const override;
   Status Unprotect() override;
 
+  void Poison(Status status) { is_poisoned_ = status; }
+  Status IsPoisoned() const { return is_poisoned_; }
+
   string DebugString() const override;
 
   int64 op_id() const { return op_id_; }
@@ -94,6 +97,7 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData {
   void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; }
 
  private:
+  Status is_poisoned_;
   // IDs required when this class is representing a remote tensor handle.
   const int64 op_id_;
   const int32 output_num_;
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index 5ad48118ae9..c1811303bc9 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -115,6 +115,10 @@ class GrpcEagerClient : public EagerClient {
   }
   ~GrpcEagerClient() override { thread_->Unref(); }
 
+  bool allow_multiple_pending_requests() const override {
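+    // Requests share a single streaming call when streaming enqueue is
+    // enabled, so more than one may be pending at a time.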
+    return EnableStreaming();
+  }
+
 #define CLIENT_METHOD(method)                                             \
   void method##Async(const method##Request* request,                      \
                      method##Response* response, StatusCallback done)     \
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index 69559f27c29..fac28514115 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -14,8 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/distributed_runtime/worker_session.h"
 
+#include "tensorflow/core/lib/monitoring/collection_registry.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
-#include "tensorflow/core/platform/monitoring.h"
 
 namespace tensorflow {
 
@@ -123,7 +123,6 @@ WorkerSession::WorkerSession(
   // provided). For builds using "tensorflow/core/platform/default", this is
   // currently a no-op.
   worker_session_created->GetCell()->Set(true);
-  monitoring::StartExporter();
 }
 
 Status WorkerSession::UpdateWorkerCacheAndDevices(
@@ -170,7 +169,6 @@ WorkerSession::WorkerSession(
   // provided). For builds using "tensorflow/core/platform/default", this is
   // currently a no-op.
   worker_session_created->GetCell()->Set(true);
-  monitoring::StartExporter();
 }
 
 WorkerSession::~WorkerSession() {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 9e22321b42c..876c7551263 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -156,6 +156,18 @@ class OpKernel {
   // Returns a pointer to the tensor stored inside constant ops.
   virtual const Tensor* const_tensor() const { return nullptr; }
 
+  // Returns true if this kernel must produce its ith output.
+  // REQUIRES: 0 <= i < num_outputs().
+  bool output_required(int i) const { return outputs_required_[i]; }
+
+  // Hints whether or not the ith output must be produced when running the
+  // kernel. By default, all outputs are required. The kernel implementation
+  // may ignore the hint.
+  // REQUIRES: 0 <= i < num_outputs().
+  void set_output_required(int i, bool is_required) {
+    outputs_required_[i] = is_required;
+  }
+
   // Updates the dynamic cost estimate, which is used to determine whether this
   // op is expensive. The new cost estimate is a weighted average of the old
   // cost estimate and the latest cost.
@@ -223,6 +235,7 @@ class OpKernel {
   const bool is_deferred_;
   bool expensive_;
   std::atomic_uint_fast64_t cost_estimate_;
+  std::vector<bool> outputs_required_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
 };
@@ -732,6 +745,10 @@ class OpKernelContext {
     // For tracking actively running deferred ops.
     std::function<void()> inc_num_deferred_ops_function;
     std::function<void()> dec_num_deferred_ops_function;
+
+    // For implementing `OpKernelContext::output_required()`. If null, all
+    // outputs are required.
+    bool* outputs_required_array = nullptr;
   };
 
   // params must outlive the OpKernelContext.
@@ -941,10 +958,9 @@ class OpKernelContext {
   // should call allocate_output(index, ...), set_output(index, ...),
   // set_output_ref(index, ...), or set the status to a non-ok value.
   // If it returns false, it may output, but is not required to do so.
-  // TODO(mrry): Convert this to return Status, and implement a string
-  // name version.
   bool output_required(int index) const {
-    return true;  // TODO(josh11b): implement
+    return !params_->outputs_required_array ||
+           params_->outputs_required_array[index];
   }
 
   // Allocation of tensors during kernel execution inside the Compute
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index cdb841e06d7..79e0cd2bcf1 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -17,6 +17,8 @@ limitations under the License.
 // all over the place, we should log an error and execute the original graph.
 #ifdef INTEL_MKL
 
+#include "tensorflow/core/graph/mkl_layout_pass.h"
+
 #include <algorithm>
 #include <functional>
 #include <memory>
@@ -34,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -43,9 +46,6 @@ limitations under the License.
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/util.h"
 
-#include "tensorflow/core/graph/mkl_graph_util.h"
-#include "tensorflow/core/graph/mkl_layout_pass.h"
-
 namespace tensorflow {
 
 // This pass implements rewriting of graph to support following scenarios:
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 9d7f37d2be6..19ed267b441 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel {
   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
       ConstMatrixVector;
 
-  explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit ConcatBaseOp(OpKernelConstruction* c)
+      : OpKernel(c),
+        axis_attribute_name_(AxisArgName == NAME_IS_AXIS
+                                 ? "axis"
+                                 : AxisArgName == NAME_IS_CONCAT_DIM
+                                       ? "concat_dim"
+                                       : "<invalid>") {
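+    // Resolve the axis and "values" input index ranges once at construction so
+    // Compute() can index inputs directly instead of looking them up by name.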
+    int unused;
+    OP_REQUIRES_OK(
+        c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
+    OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
+                                 &values_input_end_index_));
+  }
 
   void Compute(OpKernelContext* c) override {
-    const Tensor* concat_dim_tensor;
-    const char* axis_attribute_name =
-        AxisArgName == NAME_IS_AXIS
-            ? "axis"
-            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
-    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+    const Tensor& concat_dim_tensor = c->input(axis_input_index_);
+
     // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
     OP_REQUIRES(c,
-                (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) ||
-                 (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) &&
-                  concat_dim_tensor->shape().dim_size(0) == 1)),
+                (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
+                 (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
+                  concat_dim_tensor.shape().dim_size(0) == 1)),
                 errors::InvalidArgument(
-                    axis_attribute_name,
+                    axis_attribute_name_,
                     " tensor should be a scalar integer, but got shape ",
-                    concat_dim_tensor->shape().DebugString()));
+                    concat_dim_tensor.shape().DebugString()));
     int64 concat_dim;
     // In case of ConcatV2, "axis" could be int32 or int64
     if (AxisArgName == NAME_IS_AXIS) {
       OP_REQUIRES(
           c,
-          (concat_dim_tensor->dtype() == DT_INT32 ||
-           concat_dim_tensor->dtype() == DT_INT64),
-          errors::InvalidArgument(axis_attribute_name,
+          (concat_dim_tensor.dtype() == DT_INT32 ||
+           concat_dim_tensor.dtype() == DT_INT64),
+          errors::InvalidArgument(axis_attribute_name_,
                                   " tensor should be int32 or int64, but got ",
-                                  DataTypeString(concat_dim_tensor->dtype())));
+                                  DataTypeString(concat_dim_tensor.dtype())));
     } else {
-      OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32),
+      OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
                   errors::InvalidArgument(
-                      axis_attribute_name, " tensor should be int32, but got ",
-                      DataTypeString(concat_dim_tensor->dtype())));
+                      axis_attribute_name_, " tensor should be int32, but got ",
+                      DataTypeString(concat_dim_tensor.dtype())));
     }
-    if (concat_dim_tensor->dtype() == DT_INT32) {
+    if (concat_dim_tensor.dtype() == DT_INT32) {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
     } else {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int64>()());
     }
 
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int N = values.size();
-    const int input_dims = values[0].dims();
-    const TensorShape& input_shape = values[0].shape();
+    const int N = values_input_end_index_ - values_input_start_index_;
+    const Tensor& first_input = c->input(values_input_start_index_);
+    const int input_dims = first_input.dims();
+    const TensorShape& input_shape = first_input.shape();
 
     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
     // concat_dim==0 allows concatenating a list of scalars into a vector.
@@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel {
     }
     int64 output_concat_dim = 0;
     for (int i = 0; i < N; ++i) {
-      const auto& in = values[i];
+      const auto& in = c->input(values_input_start_index_ + i);
       OP_REQUIRES(
           c, in.dims() == input_dims,
           errors::InvalidArgument(
@@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel {
       if (in.NumElements() > 0) {
         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+            in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
       // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
@@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel {
       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
     }
   }
+
+ private:
+  const char* const axis_attribute_name_;
+  int axis_input_index_;
+  int values_input_start_index_;
+  int values_input_end_index_;
 };
 
 template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index ccdafdf91c9..4f26aed641e 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -48,7 +48,7 @@ namespace tensorflow {
 namespace {
 
 NodeDef StripTensorDataFromNodeDef(OpKernelConstruction* ctx) {
-#ifndef __ANDROID__
+#ifndef TENSORFLOW_LITE_PROTOS
   DCHECK_EQ(NodeDef::descriptor()->field_count(), 6)
       << "The NodeDef format has changed, and the attr-stripping code may need "
       << "to be updated.";
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 911462c8eff..54bde28ad62 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -44,12 +44,9 @@ class SelectOp : public OpKernel {
   explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);
 
     if (TensorShapeUtils::IsScalar(cond->shape())) {
       ComputeScalar(ctx, cond, then, else_);
@@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel {
   explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);
 
     // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
     // This matches the behavior of numpy.
@@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel {
             ctx->input(1).shape().DebugString(), " is not supported yet."));
         break;
     }
-    return;
   }
 
  private:
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index d041ab5ac6a..a68b3faeb37 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -434,6 +434,7 @@ tf_kernel_library(
     name = "snapshot_dataset_op",
     srcs = ["snapshot_dataset_op.cc"],
     deps = [
+        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -441,6 +442,7 @@ tf_kernel_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/kernels/data:dataset_utils",
+        "//tensorflow/core/platform:platform_port",
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/time",
     ],
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index ae3015bc833..68ee3c4c134 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include <random>
 
 #include "absl/time/clock.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -32,7 +33,9 @@ limitations under the License.
 #include "tensorflow/core/lib/io/compression.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/snappy.h"
 #if !defined(IS_SLIM_BUILD)
 #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
 #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
@@ -63,9 +66,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
 // Defaults to 10 GiB per shard.
 const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;
 
-const int64 kSnappyWriterInputBufferSizeBytes = 16 << 20;   // 16 MiB
-const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20;  // 16 MiB
-
 // The reader input buffer size is deliberately large because the input reader
 // will throw an error if the compressed block length cannot fit in the input
 // buffer.
@@ -75,6 +75,8 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20;  // 32 MiB
 
 const size_t kHeaderSize = sizeof(uint64);
 
+const int64 kCurrentVersion = 1;
+
 constexpr char kModeAuto[] = "auto";
 constexpr char kModeWrite[] = "write";
 constexpr char kModeRead[] = "read";
@@ -95,6 +97,7 @@ constexpr char kState[] = "state";
 constexpr char kHashDir[] = "hash_dir";
 constexpr char kRunId[] = "run_id";
 constexpr char kRunDir[] = "run_dir";
+constexpr char kVersionStr[] = "version";
 constexpr char kFilenames[] = "filenames";
 constexpr char kCurrentFilenames[] = "current_filenames";
 constexpr char kElementsProduced[] = "elements_produced";
@@ -115,9 +118,9 @@ class SnapshotWriter {
   static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
   static constexpr const char* const kWriteCord = "WriteCord";
 
-  explicit SnapshotWriter(WritableFile* dest, const string& compression_type =
-                                                  io::compression::kNone)
-      : dest_(dest), compression_type_(compression_type) {
+  explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
+                          int version, const DataTypeVector& dtypes)
+      : dest_(dest), compression_type_(compression_type), version_(version) {
 #if defined(IS_SLIM_BUILD)
     if (compression_type != io::compression::kNone) {
       LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@@ -134,41 +137,100 @@ class SnapshotWriter {
       TF_CHECK_OK(zlib_output_buffer->Init());
       dest_ = zlib_output_buffer;
       dest_is_owned_ = true;
-    } else if (compression_type == io::compression::kSnappy) {
-      io::SnappyOutputBuffer* snappy_output_buffer = new io::SnappyOutputBuffer(
-          dest, /*input_buffer_bytes=*/kSnappyWriterInputBufferSizeBytes,
-          /*output_buffer_bytes=*/kSnappyWriterOutputBufferSizeBytes);
-      dest_ = snappy_output_buffer;
-      dest_is_owned_ = true;
     }
 #endif  // IS_SLIM_BUILD
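+    // Classify each dtype as "simple" (raw buffer can be memcpy'd) or
+    // "complex" (requires TensorProto serialization).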
+    simple_tensor_mask_.reserve(dtypes.size());
+    for (const auto& dtype : dtypes) {
+      if (DataTypeCanUseMemcpy(dtype)) {
+        simple_tensor_mask_.push_back(true);
+        num_simple_++;
+      } else {
+        simple_tensor_mask_.push_back(false);
+        num_complex_++;
+      }
+    }
   }
 
-  Status WriteRecord(const StringPiece& data) {
-    profiler::TraceMe activity(
-        [&]() {
-          return absl::StrCat(kClassName, kSeparator, kWriteStringPiece);
-        },
-        profiler::TraceMeLevel::kInfo);
-    char header[kHeaderSize];
-    core::EncodeFixed64(header, data.size());
-    TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
-    return dest_->Append(data);
-  }
-
+  Status WriteTensors(const std::vector<Tensor>& tensors) {
+    if (compression_type_ != io::compression::kSnappy) {
+      experimental::SnapshotRecord record;
+      for (const auto& tensor : tensors) {
+        TensorProto* t = record.add_tensor();
+        tensor.AsProtoTensorContent(t);
+      }
 #if defined(PLATFORM_GOOGLE)
-  Status WriteRecord(const absl::Cord& data) {
-    profiler::TraceMe activity(
-        [&]() { return absl::StrCat(kClassName, kSeparator, kWriteCord); },
-        profiler::TraceMeLevel::kInfo);
-    char header[kHeaderSize];
-    core::EncodeFixed64(header, data.size());
-
-    TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
-
-    return dest_->Append(data);
-  }
+      return WriteRecord(record.SerializeAsCord());
+#else   // PLATFORM_GOOGLE
+      return WriteRecord(record.SerializeAsString());
 #endif  // PLATFORM_GOOGLE
+    }
+
+    if (version_ != 1) {
+      return errors::InvalidArgument("Version: ", version_,
+                                     " is not supported.");
+    }
+    if (compression_type_ != io::compression::kSnappy) {
+      return errors::InvalidArgument(
+          "Version 1 is only compatible with snappy compression");
+    }
+
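+    // First pass: record each tensor's shape and byte size in the metadata
+    // proto, collecting raw buffers for simple tensors and TensorProtos for
+    // complex ones.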
+    std::vector<const TensorBuffer*> tensor_buffers;
+    tensor_buffers.reserve(num_simple_);
+    std::vector<TensorProto> tensor_protos;
+    tensor_protos.reserve(num_complex_);
+    SnapshotTensorMetadata metadata;
+    int64 total_size = 0;
+    for (int i = 0; i < tensors.size(); ++i) {
+      const Tensor& tensor = tensors[i];
+      TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
+      tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
+      int64 size = 0;
+      if (simple_tensor_mask_[i]) {
+        auto tensor_buffer = DMAHelper::buffer(&tensor);
+        tensor_buffers.push_back(tensor_buffer);
+        size = tensor_buffer->size();
+      } else {
+        TensorProto proto;
+        tensor.AsProtoTensorContent(&proto);
+        size = proto.ByteSizeLong();
+        tensor_protos.push_back(std::move(proto));
+      }
+      tensor_metadata->set_tensor_size_bytes(size);
+      total_size += size;
+    }
+
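+    // Second pass: pack all tensors back-to-back into one buffer, then
+    // snappy-compress it and write the metadata and data as two records.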
+    std::vector<char> uncompressed(total_size);
+    char* position = uncompressed.data();
+    int buffer_index = 0;
+    int proto_index = 0;
+    for (int i = 0; i < tensors.size(); ++i) {
+      const auto& tensor_metadata = metadata.tensor_metadata(i);
+      if (simple_tensor_mask_[i]) {
+        memcpy(position, tensor_buffers[buffer_index]->data(),
+               tensor_metadata.tensor_size_bytes());
+        buffer_index++;
+      } else {
+        tensor_protos[proto_index].SerializeToArray(
+            position, tensor_metadata.tensor_size_bytes());
+        proto_index++;
+      }
+      position += tensor_metadata.tensor_size_bytes();
+    }
+    DCHECK_EQ(position, uncompressed.data() + total_size);
+
+    string output;
+    if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
+      return errors::Internal("Failed to compress using snappy.");
+    }
+#if defined(PLATFORM_GOOGLE)
+    absl::Cord metadata_serialized = metadata.SerializeAsCord();
+#else   // PLATFORM_GOOGLE
+    std::string metadata_serialized = metadata.SerializeAsString();
+#endif  // PLATFORM_GOOGLE
+    TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
+    TF_RETURN_IF_ERROR(WriteRecord(output));
+    return Status::OK();
+  }
 
   Status Sync() { return dest_->Sync(); }
 
@@ -192,9 +254,29 @@ class SnapshotWriter {
   }
 
  private:
+  Status WriteRecord(const StringPiece& data) {
+    char header[kHeaderSize];
+    core::EncodeFixed64(header, data.size());
+    TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
+    return dest_->Append(data);
+  }
+
+#if defined(PLATFORM_GOOGLE)
+  Status WriteRecord(const absl::Cord& data) {
+    char header[kHeaderSize];
+    core::EncodeFixed64(header, data.size());
+    TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
+    return dest_->Append(data);
+  }
+#endif  // PLATFORM_GOOGLE
+
   WritableFile* dest_;
   bool dest_is_owned_ = false;
   const string compression_type_;
+  const int version_;
+  std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
+  int num_simple_ = 0;
+  int num_complex_ = 0;
 };
 
 class SnapshotReader {
@@ -203,12 +285,14 @@ class SnapshotReader {
   static constexpr const char* const kReadString = "ReadString";
   static constexpr const char* const kReadCord = "ReadCord";
 
-  explicit SnapshotReader(
-      RandomAccessFile* file,
-      const string& compression_type = io::compression::kNone)
+  explicit SnapshotReader(RandomAccessFile* file,
+                          const string& compression_type, int version,
+                          const DataTypeVector& dtypes)
       : file_(file),
         input_stream_(new io::RandomAccessInputStream(file)),
-        compression_type_(compression_type) {
+        compression_type_(compression_type),
+        version_(version),
+        dtypes_(dtypes) {
 #if defined(IS_SLIM_BUILD)
     if (compression_type_ != io::compression::kNone) {
       LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@@ -223,17 +307,167 @@ class SnapshotReader {
           input_stream_.release(), zlib_options.input_buffer_size,
           zlib_options.output_buffer_size, zlib_options, true);
     } else if (compression_type_ == io::compression::kSnappy) {
-      input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
-          file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
-          /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
+      if (version_ == 0) {
+        input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
+            file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
+            /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
+      } else {
+        input_stream_ =
+            absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
+      }
     }
 #endif  // IS_SLIM_BUILD
+    simple_tensor_mask_.reserve(dtypes.size());
+    for (const auto& dtype : dtypes) {
+      if (DataTypeCanUseMemcpy(dtype)) {
+        simple_tensor_mask_.push_back(true);
+        num_simple_++;
+      } else {
+        simple_tensor_mask_.push_back(false);
+        num_complex_++;
+      }
+    }
+  }
+
+  Status ReadTensors(std::vector<Tensor>* read_tensors) {
+    profiler::TraceMe activity(
+        [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
+        profiler::TraceMeLevel::kInfo);
+    if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
+      return ReadTensorsV0(read_tensors);
+    }
+    if (version_ != 1) {
+      return errors::InvalidArgument("Version: ", version_,
+                                     " is not supported.");
+    }
+    if (compression_type_ != io::compression::kSnappy) {
+      return errors::InvalidArgument("Version 1 only supports snappy.");
+    }
+
+    SnapshotTensorMetadata metadata;
+    tstring metadata_str;
+    TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
+    if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
+      return errors::DataLoss("Could not parse SnapshotTensorMetadata");
+    }
+    read_tensors->reserve(metadata.tensor_metadata_size());
+
+    std::vector<Tensor> simple_tensors;
+    simple_tensors.reserve(num_simple_);
+    std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
+    tensor_proto_strs.reserve(num_complex_);
+    TF_RETURN_IF_ERROR(
+        SnappyUncompress(metadata, &simple_tensors, &tensor_proto_strs));
+
+    int simple_index = 0;
+    int complex_index = 0;
+    for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
+      if (simple_tensor_mask_[i]) {
+        read_tensors->push_back(std::move(simple_tensors[simple_index]));
+        simple_index++;
+      } else {
+        auto tensor_proto_str =
+            std::move(tensor_proto_strs[complex_index].first);
+        size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
+        TensorProto tp;
+#if defined(PLATFORM_GOOGLE)
+        auto tensor_proto_ptr = tensor_proto_str.release();
+        absl::Cord c;
+        c.AppendExternalMemory(
+            absl::string_view(tensor_proto_ptr, tensor_proto_size),
+            tensor_proto_ptr,
+            [](void* arg) { delete[] static_cast<char*>(arg); });
+        if (!tp.ParseFromCord(c)) {
+          return errors::Internal("Could not parse TensorProto");
+        }
+#else   // PLATFORM_GOOGLE
+        if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
+          return errors::Internal("Could not parse TensorProto");
+        }
+#endif  // PLATFORM_GOOGLE
+        Tensor t;
+        if (!t.FromProto(tp)) {
+          return errors::Internal("Could not parse Tensor");
+        }
+        read_tensors->push_back(std::move(t));
+        complex_index++;
+      }
+    }
+    return Status::OK();
+  }
+
+ private:
+  Status ReadTensorsV0(std::vector<Tensor>* read_tensors) {
+    experimental::SnapshotRecord record;
+#if defined(PLATFORM_GOOGLE)
+    absl::Cord c;
+    TF_RETURN_IF_ERROR(ReadRecord(&c));
+    record.ParseFromCord(c);
+#else   // PLATFORM_GOOGLE
+    tstring record_bytes;
+    TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
+    record.ParseFromArray(record_bytes.data(), record_bytes.size());
+#endif  // PLATFORM_GOOGLE
+    read_tensors->reserve(record.tensor_size());
+    for (int i = 0; i < record.tensor_size(); ++i) {
+      read_tensors->emplace_back();
+      if (!read_tensors->back().FromProto(record.tensor(i))) {
+        return errors::DataLoss("Unable to parse tensor from proto.");
+      }
+    }
+    return Status::OK();
+  }
+
+  Status SnappyUncompress(
+      const SnapshotTensorMetadata& metadata,
+      std::vector<Tensor>* simple_tensors,
+      std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
+          tensor_proto_strs) {
+    tstring compressed;
+    TF_RETURN_IF_ERROR(ReadRecord(&compressed));
+    size_t size;
+    if (!port::Snappy_GetUncompressedLength(compressed.data(),
+                                            compressed.size(), &size)) {
+      return errors::Internal("Could not get snappy uncompressed length");
+    }
+
+    int num_tensors = metadata.tensor_metadata_size();
+    std::vector<struct iovec> iov(num_tensors);
+    int index = 0;
+    int64 total_size = 0;
+    for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
+      const auto& tensor_metadata = metadata.tensor_metadata(i);
+      if (simple_tensor_mask_[i]) {
+        TensorShape shape(tensor_metadata.tensor_shape());
+        Tensor simple_tensor(dtypes_[i], shape);
+        TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
+        iov[index].iov_base = buffer->data();
+        iov[index].iov_len = buffer->size();
+        simple_tensors->push_back(std::move(simple_tensor));
+      } else {
+        auto tensor_proto_str =
+            absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
+        iov[index].iov_base = tensor_proto_str.get();
+        iov[index].iov_len = tensor_metadata.tensor_size_bytes();
+        tensor_proto_strs->push_back(std::make_pair(
+            std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
+      }
+      total_size += iov[index].iov_len;
+      index++;
+    }
+    if (size != total_size) {
+      return errors::Internal("Uncompressed size mismatch. Snappy expects ",
+                              size, " whereas the tensor metadata suggests ",
+                              total_size);
+    }
+    if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
+                                        iov.data(), num_tensors)) {
+      return errors::Internal("Failed to perform snappy decompression.");
+    }
+    return Status::OK();
   }
 
   Status ReadRecord(tstring* record) {
-    profiler::TraceMe activity(
-        [&]() { return absl::StrCat(kClassName, kSeparator, kReadString); },
-        profiler::TraceMeLevel::kInfo);
     tstring header;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
     uint64 length = core::DecodeFixed64(header.data());
@@ -245,13 +479,6 @@ class SnapshotReader {
     tstring header;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
     uint64 length = core::DecodeFixed64(header.data());
-    profiler::TraceMe activity(
-        [&]() {
-          return absl::StrCat(kClassName, kSeparator, kReadCord,
-                              "#length=", length, "#");
-        },
-        profiler::TraceMeLevel::kInfo);
-
     if (compression_type_ == io::compression::kNone) {
       return input_stream_->ReadNBytes(length, record);
     } else {
@@ -268,50 +495,31 @@ class SnapshotReader {
   }
 #endif
 
- private:
   RandomAccessFile* file_;
   std::unique_ptr<io::InputStreamInterface> input_stream_;
   const string compression_type_;
+  const int version_;
+  const DataTypeVector dtypes_;
+  int num_simple_ = 0;
+  int num_complex_ = 0;
+  std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
 };
 
 Status WriteMetadataFile(const string& hash_dir,
                          const experimental::SnapshotMetadataRecord& metadata) {
   string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
   TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
-
   std::string tmp_filename =
       absl::StrCat(metadata_filename, "-tmp-", random::New64());
-
-  std::unique_ptr<WritableFile> file;
-  TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(tmp_filename, &file));
-
-  auto writer = absl::make_unique<SnapshotWriter>(file.get());
-  TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString()));
-  TF_RETURN_IF_ERROR(writer->Close());
-  TF_RETURN_IF_ERROR(file->Sync());
-  TF_RETURN_IF_ERROR(file->Close());
-
-  TF_RETURN_IF_ERROR(
-      Env::Default()->RenameFile(tmp_filename, metadata_filename));
-
-  return Status::OK();
+  TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, metadata));
+  return Env::Default()->RenameFile(tmp_filename, metadata_filename);
 }
 
 Status ReadMetadataFile(const string& hash_dir,
                         experimental::SnapshotMetadataRecord* metadata) {
   string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
   TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
-
-  std::unique_ptr<RandomAccessFile> file;
-  TF_RETURN_IF_ERROR(
-      Env::Default()->NewRandomAccessFile(metadata_filename, &file));
-
-  tstring record_bytes;
-  SnapshotReader reader(file.get());
-  TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes));
-
-  metadata->ParseFromArray(record_bytes.data(), record_bytes.size());
-  return Status::OK();
+  return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
 }
 
 Status DumpDatasetGraph(const std::string& path, uint64 hash,
@@ -332,6 +540,10 @@ Status DetermineOpState(const std::string& mode_string,
                         const uint64 pending_snapshot_expiry_seconds,
                         SnapshotMode* mode) {
   if (mode_string == kModeRead) {
+    // In read mode, a metadata file is expected to already exist.
+    if (errors::IsNotFound(file_status)) {
+      return file_status;
+    }
     LOG(INFO) << "Overriding mode to reader.";
     *mode = READER;
     return Status::OK();
@@ -727,10 +939,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             if (run_id.empty()) {
               run_id = metadata.run_id();
             }
+            // The dtypes in metadata should match dataset()->output_dtypes().
+            if (metadata.dtype_size() != dataset()->output_dtypes().size()) {
+              return errors::Internal(
+                  "Expected number of dtypes: ",
+                  dataset()->output_dtypes().size(),
+                  " but number in snapshot: ", metadata.dtype_size());
+            }
+            for (int i = 0; i < metadata.dtype_size(); ++i) {
+              if (metadata.dtype(i) != dataset()->output_dtypes()[i]) {
+                return errors::Internal(
+                    "Type: ", i,
+                    " doesn't match. Snapshot: ", metadata.dtype(i),
+                    "; dataset: ", dataset()->output_dtypes()[i]);
+              }
+            }
             iterator_ = absl::make_unique<SnapshotReaderIterator>(
                 SnapshotReaderIterator::Params{
                     dataset(), absl::StrCat(prefix(), "ReaderImpl")},
-                hash_dir_, run_id);
+                hash_dir_, run_id, metadata.version());
             break;
           case PASSTHROUGH:
             iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
@@ -748,10 +975,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
         explicit SnapshotReaderIterator(const Params& params,
                                         const string& hash_dir,
-                                        const string& run_id)
+                                        const string& run_id, int64 version)
             : DatasetIterator<Dataset>(params),
               hash_dir_(hash_dir),
-              run_id_(run_id) {}
+              run_id_(run_id),
+              version_(version) {}
 
         ~SnapshotReaderIterator() override {
           mutex_lock l(mu_);
@@ -889,6 +1117,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               writer->WriteScalar(full_name(kHashDir), hash_dir_));
           TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
           TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name(kVersionStr), version_));
           TF_RETURN_IF_ERROR(writer->WriteScalar(
               full_name(strings::StrCat(kFilenames, kSizeSuffix)),
               filenames_.size()));
@@ -932,6 +1162,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           }
           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(full_name(kVersionStr), &version_));
           curr_filenames_.clear();
           curr_filenames_.reserve(dataset()->num_reader_threads_);
           for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
@@ -986,7 +1218,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           std::unique_ptr<RandomAccessFile> file;
           TF_RETURN_IF_ERROR(
               Env::Default()->NewRandomAccessFile(filename, &file));
-          SnapshotReader reader(file.get(), dataset()->compression_);
+          SnapshotReader reader(file.get(), dataset()->compression_, version_,
+                                dataset()->output_dtypes());
 
           while (true) {
             // Wait for a slot in the buffer.
@@ -1003,30 +1236,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                     "ReadFile");
               }
             }
-#if !defined(PLATFORM_GOOGLE)
-            tstring record_bytes;
-            Status s = reader.ReadRecord(&record_bytes);
-#else
-            absl::Cord record_cord;
-            Status s = reader.ReadRecord(&record_cord);
-#endif
+            std::vector<Tensor> read_tensors;
+            Status s = reader.ReadTensors(&read_tensors);
             if (s.ok()) {
               profiler::TraceMe activity(
                   [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
                   profiler::TraceMeLevel::kInfo);
-              experimental::SnapshotRecord record;
-#if !defined(PLATFORM_GOOGLE)
-              record.ParseFromArray(record_bytes.data(), record_bytes.size());
-#else
-              record.ParseFromCord(record_cord);
-#endif
               BufferElement elem;
-              for (int i = 0; i < record.tensor_size(); ++i) {
-                elem.value.emplace_back();
-                if (!elem.value.back().FromProto(record.tensor(i))) {
-                  return errors::DataLoss("Unable to parse tensor from proto.");
-                }
-              }
+              elem.value = std::move(read_tensors);
               elem.status = Status::OK();
               mutex_lock l(mu_);
               buffer_.push_back(std::move(elem));
@@ -1142,9 +1359,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         condition_variable cond_var_;
 
         const string hash_dir_;
-        const experimental::SnapshotMetadataRecord metadata_;
         tstring run_id_ GUARDED_BY(mu_);
         tstring run_dir_ GUARDED_BY(mu_);
+        int64 version_;
         std::vector<tstring> filenames_;
 
         uint64 elements_produced_ GUARDED_BY(mu_) = 0;
@@ -1220,6 +1437,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                 metadata.set_creation_timestamp(EnvTime::NowMicros());
                 metadata.set_graph_hash(dataset()->graph_hash_);
                 metadata.set_run_id(run_id_.data(), run_id_.size());
+                metadata.set_version(kCurrentVersion);
+                for (const auto& output_dtype : dataset()->output_dtypes()) {
+                  metadata.add_dtype(output_dtype);
+                }
                 metadata.set_finalized(false);
                 TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));
               }
@@ -1564,11 +1785,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           }
 
           if (produced_elem) {
-            experimental::SnapshotRecord record;
             for (const auto& out_tensor : elem.value) {
               *bytes_written += out_tensor.TotalBytes();
-              TensorProto* t = record.add_tensor();
-              out_tensor.AsProtoTensorContent(t);
             }
 
             bool should_close;
@@ -1584,16 +1802,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
                   *snapshot_data_filename, file));
               *writer = absl::make_unique<SnapshotWriter>(
-                  file->get(), dataset()->compression_);
+                  file->get(), dataset()->compression_, kCurrentVersion,
+                  dataset()->output_dtypes());
               *bytes_written = 0;
             }
-#if defined(PLATFORM_GOOGLE)
-            TF_RETURN_IF_ERROR(
-                (*writer)->WriteRecord(record.SerializeAsCord()));
-#else   // PLATFORM_GOOGLE
-            TF_RETURN_IF_ERROR(
-                (*writer)->WriteRecord(record.SerializeAsString()));
-#endif  // PLATFORM_GOOGLE
+            TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
             return Status::OK();
           }
 
@@ -1641,7 +1854,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             return;
           }
           std::unique_ptr<SnapshotWriter> writer(
-              new SnapshotWriter(file.get(), dataset()->compression_));
+              new SnapshotWriter(file.get(), dataset()->compression_,
+                                 kCurrentVersion, dataset()->output_dtypes()));
 
           bool end_of_processing = false;
           while (!end_of_processing) {
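The snapshot V1 format introduced above frames every on-disk record as an 8-byte length header (written with core::EncodeFixed64, so little-endian) followed by the payload; a tensor record is the serialized SnapshotTensorMetadata followed by one snappy-compressed blob in which memcpy-able tensors appear as raw buffer bytes and all other tensors as serialized TensorProtos. The standalone sketch below mirrors only the length-prefix framing in plain C++ (AppendFramedRecord is an illustrative name, not a TensorFlow function):

#include <cstdint>
#include <string>

// Append one length-prefixed record, mirroring WriteRecord(): an 8-byte
// little-endian length header followed by the payload bytes.
void AppendFramedRecord(const std::string& payload, std::string* out) {
  const uint64_t len = payload.size();
  char header[8];
  for (int i = 0; i < 8; ++i) {
    header[i] = static_cast<char>((len >> (8 * i)) & 0xff);
  }
  out->append(header, sizeof(header));
  out->append(payload);
}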
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc
index 2c22fec2b11..8eb2d9d8309 100644
--- a/tensorflow/core/kernels/mfcc_mel_filterbank.cc
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc
@@ -100,8 +100,8 @@ bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate,
     if ((i < start_index_) || (i > end_index_)) {
       band_mapper_[i] = -2;  // Indicate an unused Fourier coefficient.
     } else {
-      while ((center_frequencies_[channel] < melf) &&
-             (channel < num_channels_)) {
+      while ((channel < num_channels_) &&
+             (center_frequencies_[channel] < melf)) {
         ++channel;
       }
       band_mapper_[i] = channel - 1;  // Can be == -1
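The reordered loop condition above is a short-circuit fix: checking channel < num_channels_ before dereferencing center_frequencies_[channel] prevents a read one element past the end of the array on the highest band. A standalone illustration of the same pattern (hypothetical helper name, not the kernel code):

#include <vector>

// Advance to the first center frequency >= melf; the bounds check runs first,
// so centers[channel] is never read out of range.
int FirstChannelAtOrAbove(const std::vector<double>& centers, double melf) {
  int channel = 0;
  while (channel < static_cast<int>(centers.size()) && centers[channel] < melf) {
    ++channel;
  }
  return channel;
}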
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index e36e481ebbf..0751fff4c26 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -111,16 +111,18 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
       else
         pooling_prop_kind = prop_kind::forward_training;
 #ifdef ENABLE_MKLDNN_V1
+      // TODO(DNNL): Find out what we should use for input_md.data.format.
       MklPoolingParams fwdParams(
           src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
           padding_right, ALGORITHM::pooling_avg_exclude_padding,
           pooling_prop_kind,
-          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
+          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
 #else
       MklPoolingParams fwdParams(
           src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
           padding_right, ALGORITHM::pooling_avg_exclude_padding,
-          pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format));
+          pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format),
+          input_md);
 #endif
       pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
 
@@ -234,17 +236,18 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
       // Pass prop_kind::forward_training to create a forward primitive
       // that is used in the backward pass.
 #ifdef ENABLE_MKLDNN_V1
+      // TODO(DNNL): Find out what we should use for src_md.data.format.
       MklPoolingParams bwdParams(
           orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
           strides, padding_left, padding_right,
           ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
-          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
+          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
 #else
       MklPoolingParams bwdParams(
           orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
           strides, padding_left, padding_right,
           ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
-          static_cast<MEMORY_FORMAT>(src_md.data.format));
+          static_cast<MEMORY_FORMAT>(src_md.data.format), src_md);
 #endif
       MklPoolingBwdPrimitive<T>* pooling_bwd =
           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 30023f360cc..2638d835f2c 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -26,6 +26,9 @@ limitations under the License.
 #define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
 
+#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
+#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
+
 using mkldnn::batch_normalization_backward;
 using mkldnn::batch_normalization_forward;
 using mkldnn::prop_kind;
@@ -51,18 +54,17 @@ struct MklBatchNormFwdParams {
   MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
 #ifndef ENABLE_MKLDNN_V1
                         bool training, MEMORY_FORMAT src_format)
+#else
+                        bool training, memory::desc src_md)
+#endif  // !ENABLE_MKLDNN_V1
       : src_dims(src_dims),
         depth(depth),
         eps(eps),
         training(training),
+#ifndef ENABLE_MKLDNN_V1
         src_format(src_format) {
   }
 #else
-                        bool training, memory::desc src_md)
-      : src_dims(src_dims),
-        depth(depth),
-        eps(eps),
-        training(training),
         src_md(src_md) {
   }
 #endif  // !ENABLE_MKLDNN_V1
@@ -231,6 +233,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     }
 
     // BatchNorm forward primitive.
+    // TODO(intel-tf): Merge all the #ifdefs and simplify code
     if (!fwdParams.training && !(IS_SET(use_global_stats))) {
 #ifdef ENABLE_MKLDNN_V1
       if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) {
@@ -383,6 +386,7 @@ struct MklBatchNormBwdParams {
   int depth;
   float eps;
   bool training;
+
 #ifndef ENABLE_MKLDNN_V1
   MEMORY_FORMAT src_format;
 #else
@@ -466,10 +470,8 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
 #ifdef ENABLE_MKLDNN_V1
     // Execute backward batch-normalization primitives.
     DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
-    for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
-      context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
-                                            context_.net_args.at(i));
-    }
+    execute_primitives(context_.bwd_primitives, context_.bwd_stream,
+                       context_.net_args);
 #else
     context_.bwd_stream->submit(context_.bwd_primitives);
 #endif  // ENABLE_MKLDNN_V1
@@ -841,7 +843,17 @@ class MklFusedBatchNormOp : public OpKernel {
       MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
           MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
 
-      const T* src_data = src_tensor.flat<T>().data();
+      // Check if reorder is needed for src.
+      const T* src_data = nullptr;
+      std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
+      if (IS_SRC_REORDER_NEEDED(src_md, bn_fwd_pd, bn_fwd)) {
+        src.SetUsrMem(src_md, &src_tensor);
+        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+            GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd), cpu_engine_));
+        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
+      } else {
+        src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
+      }
 
       // Allocate output (dst) tensor; always set it as MKL-DNN layout
       MklDnnShape dnn_shape_dst;
diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
index 18f6667fd1e..2ff24f5d2fc 100644
--- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
@@ -86,11 +86,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     const int k = src_tf_shape.dim_size(dim_pair[0]);
     const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
 
-    OP_REQUIRES(
-        ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
-            ", In[1]: ", weight_tf_shape.DebugString()));
+    OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        src_tf_shape.DebugString(), ", In[1]: ",
+                                        weight_tf_shape.DebugString()));
     OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
                 errors::InvalidArgument(
                     "Must provide as many biases as the channel size: ",
@@ -201,9 +200,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
       // Execute fused matmul op.
       matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           ctx, errors::Aborted("Operation received an exception:", error_msg));
     }
diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h
index ca90a24a1cd..c5d27b92e00 100644
--- a/tensorflow/core/kernels/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -443,16 +443,16 @@ class MklDnnMatMulOpBase : public OpKernel {
 
     Tensor* weight_tensor_ptr = nullptr;
 
-    size_t size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size();
+    size_t weight_size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size();
     TensorShape weight_tf_shape;
-    weight_tf_shape.AddDim(size / sizeof(Tweight));
+    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));
 
     OP_REQUIRES_OK(context, context->allocate_persistent(
                                 DataTypeToEnum<Tweight>::value, weight_tf_shape,
                                 &weight_oi_, &weight_tensor_ptr));
 
     void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr);
-    memcpy(weight_oi_t_data, weight_data, size);
+    memcpy(weight_oi_t_data, weight_data, weight_size);
 
 // cache the memory descriptor
 #ifdef ENABLE_MKLDNN_V1
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index de4974f659e..098ea049246 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -139,15 +139,16 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
       else
         pooling_prop_kind = prop_kind::forward_training;
 #ifdef ENABLE_MKLDNN_V1
+      // TODO(DNNL): Figure out what should be used for input_md.data.format.
       MklPoolingParams fwdParams(
           src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
           padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
-          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
+          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
 #else
       MklPoolingParams fwdParams(
           src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
           padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
-          static_cast<MEMORY_FORMAT>(input_md.data.format));
+          static_cast<MEMORY_FORMAT>(input_md.data.format), input_md);
 #endif
       pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
       // Allocate output tensor.
@@ -297,17 +298,18 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
                              this->data_format_mkldnn_);
 
 #ifdef ENABLE_MKLDNN_V1
+      // TODO(DNNL): Find out what should be used for src_md.data.format.
       MklPoolingParams bwdParams(
           orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
           strides, padding_left, padding_right, ALGORITHM::pooling_max,
           prop_kind::forward_training,
-          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
+          static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
 #else
       MklPoolingParams bwdParams(
           orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
           strides, padding_left, padding_right, ALGORITHM::pooling_max,
           prop_kind::forward_training,
-          static_cast<MEMORY_FORMAT>(src_md.data.format));
+          static_cast<MEMORY_FORMAT>(src_md.data.format), src_md);
 #endif
       MklPoolingBwdPrimitive<T>* pooling_bwd =
           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index 29bcaf5e67c..904866f8223 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -43,6 +43,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
   //        so src format is currently hard-coded.
   //        A utility function is used to do this,
   //        which may be broken with future CPU architectures
+#ifndef ENABLE_MKLDNN_V1
   bool is_2d = (fwdParams.src_dims.size() == 4);
   if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
     context_.src_fmt = is_2d ? MEMORY_FORMAT::nhwc : MEMORY_FORMAT::ndhwc;
@@ -51,6 +52,9 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
 
   context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
                                          context_.src_fmt));
+#else
+  context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
+#endif  //  !ENABLE_MKLDNN_V1
   context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
                                          MEMORY_FORMAT::any));
 
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index 8d18c95a542..ff51282ecc6 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -48,12 +48,13 @@ struct MklPoolingParams {
   mkldnn::algorithm alg_kind;
   mkldnn::prop_kind prop_kind;
   MEMORY_FORMAT src_format;
+  memory::desc src_md;
 
   MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
                    memory::dims filter_dims, memory::dims strides,
                    memory::dims padding_left, memory::dims padding_right,
                    mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind,
-                   MEMORY_FORMAT src_format)
+                   MEMORY_FORMAT src_format, memory::desc src_md)
       : src_dims(src_dims),
         dst_dims(dst_dims),
         filter_dims(filter_dims),
@@ -62,7 +63,8 @@ struct MklPoolingParams {
         padding_right(padding_right),
         alg_kind(alg_kind),
         prop_kind(prop_kind),
-        src_format(src_format) {}
+        src_format(src_format),
+        src_md(src_md) {}
 };
 
 template <typename T>
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index 4b4705150a6..cf2b6bb1100 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -50,20 +50,10 @@ class PackOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* c) override {
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int num = values.size();
+    const int num = num_inputs();
+    const Tensor& first_input = c->input(0);
 
-    // Verify that all input shapes match
-    for (int i = 1; i < num; i++) {
-      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
-                  errors::InvalidArgument(
-                      "Shapes of all inputs must match: values[0].shape = ",
-                      values[0].shape().DebugString(), " != values[", i,
-                      "].shape = ", values[i].shape().DebugString()));
-    }
-
-    int expanded_num_dims = values[0].dims() + 1;
+    int expanded_num_dims = first_input.dims() + 1;
     int axis = axis_;
     if (axis < 0) axis += expanded_num_dims;
 
@@ -72,13 +62,13 @@ class PackOp : public OpKernel {
                                         -expanded_num_dims, ", ",
                                         expanded_num_dims, ")"));
 
-    TensorShape output_shape(values[0].shape());
+    TensorShape output_shape(first_input.shape());
     output_shape.InsertDim(axis, num);
 
     // In the num = 1 case, just reshape the input
     if (num == 1) {
       Tensor output;
-      CHECK(output.CopyFrom(values[0], output_shape));
+      CHECK(output.CopyFrom(first_input, output_shape));
       c->set_output(0, output);
       return;
     }
@@ -109,8 +99,15 @@ class PackOp : public OpKernel {
       ConstMatrixVector inputs_flat;
       inputs_flat.reserve(num);
       for (int i = 0; i < num; ++i) {
+        const Tensor& input = c->input(i);
+        OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
+                    errors::InvalidArgument(
+                        "Shapes of all inputs must match: values[0].shape = ",
+                        first_input.shape().DebugString(), " != values[", i,
+                        "].shape = ", input.shape().DebugString()));
+
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-            values[i].shaped<T, 2>({before_dim, after_dim})));
+            input.shaped<T, 2>({before_dim, after_dim})));
       }
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       if (std::is_same<Device, GPUDevice>::value) {
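PackOp now derives the output shape from the first input and defers the per-input shape check to the flattening loop, instead of materializing an OpInputList up front. The stacked shape is simply the input shape with the number of inputs inserted at the (possibly negative) axis, as in this hedged sketch using plain containers (PackedShape is an illustrative helper, not the kernel API):

#include <cstdint>
#include <vector>

// Output shape of stacking `num` equally shaped tensors along `axis`:
// the input shape with `num` inserted at that position.
std::vector<int64_t> PackedShape(const std::vector<int64_t>& input_shape,
                                 int64_t num, int axis) {
  if (axis < 0) axis += static_cast<int>(input_shape.size()) + 1;  // wrap negative axis
  std::vector<int64_t> out = input_shape;
  out.insert(out.begin() + axis, num);
  return out;
}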
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
index a97c5cb47a2..8de93cf9b30 100644
--- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
@@ -78,16 +78,23 @@ class SparseFillEmptyRowsOp : public OpKernel {
     const int64 N = indices_t.shape().dim_size(0);
     const int64 dense_rows = dense_shape(0);
 
-    Tensor* empty_row_indicator_t;
-    OP_REQUIRES_OK(context, context->allocate_output(kEmptyRowIndicatorOutput,
-                                                     TensorShape({dense_rows}),
-                                                     &empty_row_indicator_t));
-    auto empty_row_indicator = empty_row_indicator_t->vec<bool>();
-    Tensor* reverse_index_map_t;
-    OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput,
-                                                     TensorShape({N}),
-                                                     &reverse_index_map_t));
-    auto reverse_index_map = reverse_index_map_t->vec<int64>();
+    bool* empty_row_indicator = nullptr;
+    if (context->output_required(kEmptyRowIndicatorOutput)) {
+      Tensor* empty_row_indicator_t = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(kEmptyRowIndicatorOutput,
+                                              TensorShape({dense_rows}),
+                                              &empty_row_indicator_t));
+      empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
+    }
+    int64* reverse_index_map = nullptr;
+    if (context->output_required(kReverseIndexMapOutput)) {
+      Tensor* reverse_index_map_t = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput,
+                                                       TensorShape({N}),
+                                                       &reverse_index_map_t));
+      reverse_index_map = reverse_index_map_t->vec<int64>().data();
+    }
 
     int rank = indices_t.shape().dim_size(1);
 
@@ -122,8 +129,11 @@ class SparseFillEmptyRowsOp : public OpKernel {
     bool all_rows_full = true;
     for (int row = 0; row < dense_rows; ++row) {
       // csr_offset here describes the number of elements in this dense row
-      empty_row_indicator(row) = (csr_offset[row] == 0);
-      all_rows_full = all_rows_full & !empty_row_indicator(row);
+      bool row_empty = (csr_offset[row] == 0);
+      if (empty_row_indicator) {
+        empty_row_indicator[row] = row_empty;
+      }
+      all_rows_full = all_rows_full & !row_empty;
       // In filled version, each row has at least one element.
       csr_offset[row] = std::max(csr_offset[row], int64{1});
       // Update csr_offset to represent the number of elements up to and
@@ -140,8 +150,10 @@ class SparseFillEmptyRowsOp : public OpKernel {
     if (all_rows_full) {
       context->set_output(kOutputIndicesOutput, indices_t);
       context->set_output(kOutputValuesOutput, values_t);
-      for (int64 i = 0; i < N; ++i) {
-        reverse_index_map(i) = i;
+      if (reverse_index_map) {
+        for (int64 i = 0; i < N; ++i) {
+          reverse_index_map[i] = i;
+        }
       }
     } else {
       Tensor* output_indices_t;
@@ -169,7 +181,9 @@ class SparseFillEmptyRowsOp : public OpKernel {
         std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
         output_values(output_i) = values(i);
         // We'll need this reverse index map to backprop correctly.
-        reverse_index_map(i) = output_i;
+        if (reverse_index_map) {
+          reverse_index_map[i] = output_i;
+        }
       }
 
       // Fill in values for rows that are missing
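The rewrite above allocates the empty-row indicator and reverse index map only when context->output_required() says a downstream op consumes them, and the main loops then write through raw pointers that may be null. A minimal sketch of that optional-output pattern (hypothetical helper, outside the OpKernel machinery):

#include <cstdint>
#include <vector>

// Mark empty rows only when the caller actually wants the indicator;
// callers pass nullptr when the output is not required.
void MarkEmptyRows(const std::vector<int64_t>& csr_offset,
                   bool* empty_row_indicator /* may be null */) {
  for (size_t row = 0; row < csr_offset.size(); ++row) {
    const bool row_empty = (csr_offset[row] == 0);
    if (empty_row_indicator != nullptr) {
      empty_row_indicator[row] = row_empty;
    }
  }
}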
diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD
index 87b5090a59f..d03a895b429 100644
--- a/tensorflow/core/lib/io/BUILD
+++ b/tensorflow/core/lib/io/BUILD
@@ -208,6 +208,21 @@ cc_library(
     alwayslink = True,
 )
 
+cc_library(
+    name = "cache",
+    srcs = [
+        "cache.cc",
+    ],
+    hdrs = [
+        "cache.h",
+    ],
+    deps = [
+        "//tensorflow/core/platform:coding",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:stringpiece",
+    ],
+)
+
 cc_library(
     name = "table",
     srcs = [
@@ -220,6 +235,7 @@ cc_library(
     ],
     deps = [
         ":block",
+        ":cache",
         ":iterator",
         ":table_options",
         "//tensorflow/core/lib/core:coding",
@@ -290,6 +306,8 @@ filegroup(
         "block_builder.h",
         "buffered_inputstream.cc",
         "buffered_inputstream.h",
+        "cache.cc",
+        "cache.h",
         "compression.cc",
         "compression.h",
         "format.cc",
@@ -352,6 +370,7 @@ filegroup(
     name = "legacy_lib_io_all_tests",
     srcs = [
         "buffered_inputstream_test.cc",
+        "cache_test.cc",
         "inputbuffer_test.cc",
         "inputstream_interface_test.cc",
         "path_test.cc",
@@ -369,6 +388,7 @@ filegroup(
     name = "legacy_lib_io_headers",
     srcs = [
         "buffered_inputstream.h",
+        "cache.h",
         "compression.h",
         "inputstream_interface.h",
         "path.h",
diff --git a/tensorflow/core/lib/io/cache.cc b/tensorflow/core/lib/io/cache.cc
new file mode 100644
index 00000000000..b5521b1752b
--- /dev/null
+++ b/tensorflow/core/lib/io/cache.cc
@@ -0,0 +1,450 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/cache.h"
+
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "tensorflow/core/platform/coding.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+namespace table {
+
+Cache::~Cache() {}
+
+namespace {
+
+// LRU cache implementation
+//
+// Cache entries have an "in_cache" boolean indicating whether the cache has a
+// reference on the entry.  The only ways that this can become false without the
+// entry being passed to its "deleter" are via Erase(), via Insert() when
+// an element with a duplicate key is inserted, or on destruction of the cache.
+//
+// The cache keeps two linked lists of items in the cache.  All items in the
+// cache are in one list or the other, and never both.  Items still referenced
+// by clients but erased from the cache are in neither list.  The lists are:
+// - in-use:  contains the items currently referenced by clients, in no
+//   particular order.  (This list is used for invariant checking.  If we
+//   removed the check, elements that would otherwise be on this list could be
+//   left as disconnected singleton lists.)
+// - LRU:  contains the items not currently referenced by clients, in LRU order
+// Elements are moved between these lists by the Ref() and Unref() methods,
+// when they detect an element in the cache acquiring or losing its only
+// external reference.
+
+// An entry is a variable length heap-allocated structure.  Entries
+// are kept in a circular doubly linked list ordered by access time.
+struct LRUHandle {
+  void* value;
+  void (*deleter)(const Slice&, void* value);
+  LRUHandle* next_hash;
+  LRUHandle* next;
+  LRUHandle* prev;
+  size_t charge;  // TODO(opt): Only allow uint32_t?
+  size_t key_length;
+  bool in_cache;     // Whether entry is in the cache.
+  uint32_t refs;     // References, including cache reference, if present.
+  uint32_t hash;     // Hash of key(); used for fast sharding and comparisons
+  char key_data[1];  // Beginning of key
+
+  Slice key() const {
+    // next is only equal to this if the LRU handle is the list head of an
+    // empty list. List heads never have meaningful keys.
+    assert(next != this);
+
+    return Slice(key_data, key_length);
+  }
+};
+
+// We provide our own simple hash table since it removes a whole bunch
+// of porting hacks and is also faster than some of the built-in hash
+// table implementations in some of the compiler/runtime combinations
+// we have tested.  E.g., readrandom speeds up by ~5% over the g++
+// 4.4.3's builtin hashtable.
+class HandleTable {
+ public:
+  HandleTable() : length_(0), elems_(0), list_(nullptr) { Resize(); }
+  ~HandleTable() { delete[] list_; }
+
+  LRUHandle* Lookup(const Slice& key, uint32_t hash) {
+    return *FindPointer(key, hash);
+  }
+
+  LRUHandle* Insert(LRUHandle* h) {
+    LRUHandle** ptr = FindPointer(h->key(), h->hash);
+    LRUHandle* old = *ptr;
+    h->next_hash = (old == nullptr ? nullptr : old->next_hash);
+    *ptr = h;
+    if (old == nullptr) {
+      ++elems_;
+      if (elems_ > length_) {
+        // Since each cache entry is fairly large, we aim for a small
+        // average linked list length (<= 1).
+        Resize();
+      }
+    }
+    return old;
+  }
+
+  LRUHandle* Remove(const Slice& key, uint32_t hash) {
+    LRUHandle** ptr = FindPointer(key, hash);
+    LRUHandle* result = *ptr;
+    if (result != nullptr) {
+      *ptr = result->next_hash;
+      --elems_;
+    }
+    return result;
+  }
+
+ private:
+  // The table consists of an array of buckets where each bucket is
+  // a linked list of cache entries that hash into the bucket.
+  uint32_t length_;
+  uint32_t elems_;
+  LRUHandle** list_;
+
+  // Return a pointer to slot that points to a cache entry that
+  // matches key/hash.  If there is no such cache entry, return a
+  // pointer to the trailing slot in the corresponding linked list.
+  LRUHandle** FindPointer(const Slice& key, uint32_t hash) {
+    LRUHandle** ptr = &list_[hash & (length_ - 1)];
+    while (*ptr != nullptr && ((*ptr)->hash != hash || key != (*ptr)->key())) {
+      ptr = &(*ptr)->next_hash;
+    }
+    return ptr;
+  }
+
+  void Resize() {
+    uint32_t new_length = 4;
+    while (new_length < elems_) {
+      new_length *= 2;
+    }
+    LRUHandle** new_list = new LRUHandle*[new_length];
+    memset(new_list, 0, sizeof(new_list[0]) * new_length);
+    uint32_t count = 0;
+    for (uint32_t i = 0; i < length_; i++) {
+      LRUHandle* h = list_[i];
+      while (h != nullptr) {
+        LRUHandle* next = h->next_hash;
+        uint32_t hash = h->hash;
+        LRUHandle** ptr = &new_list[hash & (new_length - 1)];
+        h->next_hash = *ptr;
+        *ptr = h;
+        h = next;
+        count++;
+      }
+    }
+    assert(elems_ == count);
+    delete[] list_;
+    list_ = new_list;
+    length_ = new_length;
+  }
+};
+
+// A single shard of sharded cache.
+class LRUCache {
+ public:
+  LRUCache();
+  ~LRUCache();
+
+  // Separate from constructor so caller can easily make an array of LRUCache
+  void SetCapacity(size_t capacity) { capacity_ = capacity; }
+
+  // Like Cache methods, but with an extra "hash" parameter.
+  Cache::Handle* Insert(const Slice& key, uint32_t hash, void* value,
+                        size_t charge,
+                        void (*deleter)(const Slice& key, void* value));
+  Cache::Handle* Lookup(const Slice& key, uint32_t hash);
+  void Release(Cache::Handle* handle);
+  void Erase(const Slice& key, uint32_t hash);
+  void Prune();
+  size_t TotalCharge() const {
+    mutex_lock l(mutex_);
+    return usage_;
+  }
+
+ private:
+  void LRU_Remove(LRUHandle* e);
+  void LRU_Append(LRUHandle* list, LRUHandle* e);
+  void Ref(LRUHandle* e);
+  void Unref(LRUHandle* e);
+  bool FinishErase(LRUHandle* e) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+  // Initialized before use.
+  size_t capacity_;
+
+  // mutex_ protects the following state.
+  mutable mutex mutex_;
+  size_t usage_ GUARDED_BY(mutex_);
+
+  // Dummy head of LRU list.
+  // lru.prev is newest entry, lru.next is oldest entry.
+  // Entries have refs==1 and in_cache==true.
+  LRUHandle lru_ GUARDED_BY(mutex_);
+
+  // Dummy head of in-use list.
+  // Entries are in use by clients, and have refs >= 2 and in_cache==true.
+  LRUHandle in_use_ GUARDED_BY(mutex_);
+
+  HandleTable table_ GUARDED_BY(mutex_);
+};
+
+LRUCache::LRUCache() : capacity_(0), usage_(0) {
+  // Make empty circular linked lists.
+  lru_.next = &lru_;
+  lru_.prev = &lru_;
+  in_use_.next = &in_use_;
+  in_use_.prev = &in_use_;
+}
+
+LRUCache::~LRUCache() {
+  assert(in_use_.next == &in_use_);  // Error if caller has an unreleased handle
+  for (LRUHandle* e = lru_.next; e != &lru_;) {
+    LRUHandle* next = e->next;
+    assert(e->in_cache);
+    e->in_cache = false;
+    assert(e->refs == 1);  // Invariant of lru_ list.
+    Unref(e);
+    e = next;
+  }
+}
+
+void LRUCache::Ref(LRUHandle* e) {
+  if (e->refs == 1 && e->in_cache) {  // If on lru_ list, move to in_use_ list.
+    LRU_Remove(e);
+    LRU_Append(&in_use_, e);
+  }
+  e->refs++;
+}
+
+void LRUCache::Unref(LRUHandle* e) {
+  assert(e->refs > 0);
+  e->refs--;
+  if (e->refs == 0) {  // Deallocate.
+    assert(!e->in_cache);
+    (*e->deleter)(e->key(), e->value);
+    free(e);
+  } else if (e->in_cache && e->refs == 1) {
+    // No longer in use; move to lru_ list.
+    LRU_Remove(e);
+    LRU_Append(&lru_, e);
+  }
+}
+
+void LRUCache::LRU_Remove(LRUHandle* e) {
+  e->next->prev = e->prev;
+  e->prev->next = e->next;
+}
+
+void LRUCache::LRU_Append(LRUHandle* list, LRUHandle* e) {
+  // Make "e" newest entry by inserting just before *list
+  e->next = list;
+  e->prev = list->prev;
+  e->prev->next = e;
+  e->next->prev = e;
+}
+
+Cache::Handle* LRUCache::Lookup(const Slice& key, uint32_t hash) {
+  mutex_lock l(mutex_);
+  LRUHandle* e = table_.Lookup(key, hash);
+  if (e != nullptr) {
+    Ref(e);
+  }
+  return reinterpret_cast<Cache::Handle*>(e);
+}
+
+void LRUCache::Release(Cache::Handle* handle) {
+  mutex_lock l(mutex_);
+  Unref(reinterpret_cast<LRUHandle*>(handle));
+}
+
+Cache::Handle* LRUCache::Insert(const Slice& key, uint32_t hash, void* value,
+                                size_t charge,
+                                void (*deleter)(const Slice& key,
+                                                void* value)) {
+  mutex_lock l(mutex_);
+
+  LRUHandle* e =
+      reinterpret_cast<LRUHandle*>(malloc(sizeof(LRUHandle) - 1 + key.size()));
+  e->value = value;
+  e->deleter = deleter;
+  e->charge = charge;
+  e->key_length = key.size();
+  e->hash = hash;
+  e->in_cache = false;
+  e->refs = 1;  // for the returned handle.
+  memcpy(e->key_data, key.data(), key.size());
+
+  if (capacity_ > 0) {
+    e->refs++;  // for the cache's reference.
+    e->in_cache = true;
+    LRU_Append(&in_use_, e);
+    usage_ += charge;
+    FinishErase(table_.Insert(e));
+  } else {  // don't cache. (capacity_==0 is supported and turns off caching.)
+    // next is read by key() in an assert, so it must be initialized
+    e->next = nullptr;
+  }
+  while (usage_ > capacity_ && lru_.next != &lru_) {
+    LRUHandle* old = lru_.next;
+    assert(old->refs == 1);
+    bool erased = FinishErase(table_.Remove(old->key(), old->hash));
+    if (!erased) {  // to avoid unused variable when compiled NDEBUG
+      assert(erased);
+    }
+  }
+
+  return reinterpret_cast<Cache::Handle*>(e);
+}
+
+// If e != nullptr, finish removing *e from the cache; it has already been
+// removed from the hash table.  Return whether e != nullptr.
+bool LRUCache::FinishErase(LRUHandle* e) {
+  if (e != nullptr) {
+    assert(e->in_cache);
+    LRU_Remove(e);
+    e->in_cache = false;
+    usage_ -= e->charge;
+    Unref(e);
+  }
+  return e != nullptr;
+}
+
+void LRUCache::Erase(const Slice& key, uint32_t hash) {
+  mutex_lock l(mutex_);
+  FinishErase(table_.Remove(key, hash));
+}
+
+void LRUCache::Prune() {
+  mutex_lock l(mutex_);
+  while (lru_.next != &lru_) {
+    LRUHandle* e = lru_.next;
+    assert(e->refs == 1);
+    bool erased = FinishErase(table_.Remove(e->key(), e->hash));
+    if (!erased) {  // to avoid unused variable when compiled NDEBUG
+      assert(erased);
+    }
+  }
+}
+
+static const int kNumShardBits = 4;
+static const int kNumShards = 1 << kNumShardBits;
+
+class ShardedLRUCache : public Cache {
+ private:
+  LRUCache shard_[kNumShards];
+  mutex id_mutex_;
+  uint64_t last_id_;
+
+  static inline uint32_t HashSlice(const Slice& s) {
+    return Hash(s.data(), s.size(), 0);
+  }
+
+  static uint32_t Shard(uint32_t hash) { return hash >> (32 - kNumShardBits); }
+
+ public:
+  explicit ShardedLRUCache(size_t capacity) : last_id_(0) {
+    const size_t per_shard = (capacity + (kNumShards - 1)) / kNumShards;
+    for (int s = 0; s < kNumShards; s++) {
+      shard_[s].SetCapacity(per_shard);
+    }
+  }
+  ~ShardedLRUCache() override {}
+  Handle* Insert(const Slice& key, void* value, size_t charge,
+                 void (*deleter)(const Slice& key, void* value)) override {
+    const uint32_t hash = HashSlice(key);
+    return shard_[Shard(hash)].Insert(key, hash, value, charge, deleter);
+  }
+  Handle* Lookup(const Slice& key) override {
+    const uint32_t hash = HashSlice(key);
+    return shard_[Shard(hash)].Lookup(key, hash);
+  }
+  void Release(Handle* handle) override {
+    LRUHandle* h = reinterpret_cast<LRUHandle*>(handle);
+    shard_[Shard(h->hash)].Release(handle);
+  }
+  void Erase(const Slice& key) override {
+    const uint32_t hash = HashSlice(key);
+    shard_[Shard(hash)].Erase(key, hash);
+  }
+  void* Value(Handle* handle) override {
+    return reinterpret_cast<LRUHandle*>(handle)->value;
+  }
+  uint64_t NewId() override {
+    mutex_lock l(id_mutex_);
+    return ++(last_id_);
+  }
+  void Prune() override {
+    for (int s = 0; s < kNumShards; s++) {
+      shard_[s].Prune();
+    }
+  }
+  size_t TotalCharge() const override {
+    size_t total = 0;
+    for (int s = 0; s < kNumShards; s++) {
+      total += shard_[s].TotalCharge();
+    }
+    return total;
+  }
+
+ private:
+  // TODO(byronyi): Figure out why Hash32 fails EvictionPolicy test.
+  static uint32_t Hash(const char* data, size_t n, uint32_t seed) {
+    // Similar to murmur hash
+    const uint32_t m = 0xc6a4a793;
+    const uint32_t r = 24;
+    const char* limit = data + n;
+    uint32_t h = seed ^ (n * m);
+
+    // Pick up four bytes at a time
+    while (data + 4 <= limit) {
+      uint32_t w = core::DecodeFixed32(data);
+      data += 4;
+      h += w;
+      h *= m;
+      h ^= (h >> 16);
+    }
+
+    // Pick up remaining bytes
+    switch (limit - data) {
+      case 3:
+        h += static_cast<uint8_t>(data[2]) << 16;
+        ABSL_FALLTHROUGH_INTENDED;
+      case 2:
+        h += static_cast<uint8_t>(data[1]) << 8;
+        ABSL_FALLTHROUGH_INTENDED;
+      case 1:
+        h += static_cast<uint8_t>(data[0]);
+        h *= m;
+        h ^= (h >> r);
+        break;
+    }
+    return h;
+  }
+};
+
+}  // end anonymous namespace
+
+Cache* NewLRUCache(size_t capacity) { return new ShardedLRUCache(capacity); }
+
+}  // namespace table
+
+}  // namespace tensorflow
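For reference, the sharding arithmetic in ShardedLRUCache above: a key's shard is the top kNumShardBits bits of its 32-bit hash, and the total capacity is split across shards by rounding up. A small standalone restatement of that math:

#include <cstddef>
#include <cstdint>

constexpr int kNumShardBits = 4;
constexpr int kNumShards = 1 << kNumShardBits;  // 16 shards

// Shard index is taken from the top bits of the hash.
inline uint32_t ShardOf(uint32_t hash) { return hash >> (32 - kNumShardBits); }

// Each shard gets ceil(capacity / kNumShards), so the shards together hold at
// least the requested capacity.
inline size_t PerShardCapacity(size_t capacity) {
  return (capacity + (kNumShards - 1)) / kNumShards;
}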
diff --git a/tensorflow/core/lib/io/cache.h b/tensorflow/core/lib/io/cache.h
new file mode 100644
index 00000000000..788a637077a
--- /dev/null
+++ b/tensorflow/core/lib/io/cache.h
@@ -0,0 +1,125 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_LIB_IO_CACHE_H_
+#define TENSORFLOW_CORE_LIB_IO_CACHE_H_
+
+#include "tensorflow/core/platform/stringpiece.h"
+
+// A Cache is an interface that maps keys to values.  It has internal
+// synchronization and may be safely accessed concurrently from
+// multiple threads.  It may automatically evict entries to make room
+// for new entries.  Values have a specified charge against the cache
+// capacity.  For example, a cache where the values are variable
+// length strings, may use the length of the string as the charge for
+// the string.
+//
+// A builtin cache implementation with a least-recently-used eviction
+// policy is provided.  Clients may use their own implementations if
+// they want something more sophisticated (like scan-resistance, a
+// custom eviction policy, variable cache sizing, etc.)
+
+namespace tensorflow {
+
+using Slice = StringPiece;
+
+namespace table {
+
+class Cache;
+
+// Create a new cache with a fixed size capacity.  This implementation
+// of Cache uses a least-recently-used eviction policy.
+Cache* NewLRUCache(size_t capacity);
+
+class Cache {
+ public:
+  Cache() = default;
+
+  Cache(const Cache&) = delete;
+  Cache& operator=(const Cache&) = delete;
+
+  // Destroys all existing entries by calling the "deleter"
+  // function that was passed to the constructor.
+  virtual ~Cache();
+
+  // Opaque handle to an entry stored in the cache.
+  struct Handle {};
+
+  // Insert a mapping from key->value into the cache and assign it
+  // the specified charge against the total cache capacity.
+  //
+  // Returns a handle that corresponds to the mapping.  The caller
+  // must call this->Release(handle) when the returned mapping is no
+  // longer needed.
+  //
+  // When the inserted entry is no longer needed, the key and
+  // value will be passed to "deleter".
+  virtual Handle* Insert(const Slice& key, void* value, size_t charge,
+                         void (*deleter)(const Slice& key, void* value)) = 0;
+
+  // If the cache has no mapping for "key", returns nullptr.
+  //
+  // Else return a handle that corresponds to the mapping.  The caller
+  // must call this->Release(handle) when the returned mapping is no
+  // longer needed.
+  virtual Handle* Lookup(const Slice& key) = 0;
+
+  // Release a mapping returned by a previous Lookup().
+  // REQUIRES: handle must not have been released yet.
+  // REQUIRES: handle must have been returned by a method on *this.
+  virtual void Release(Handle* handle) = 0;
+
+  // Return the value encapsulated in a handle returned by a
+  // successful Lookup().
+  // REQUIRES: handle must not have been released yet.
+  // REQUIRES: handle must have been returned by a method on *this.
+  virtual void* Value(Handle* handle) = 0;
+
+  // If the cache contains an entry for key, erase it.  Note that the
+  // underlying entry will be kept around until all existing handles
+  // to it have been released.
+  virtual void Erase(const Slice& key) = 0;
+
+  // Return a new numeric id.  May be used by multiple clients who are
+  // sharing the same cache to partition the key space.  Typically the
+  // client will allocate a new id at startup and prepend the id to
+  // its cache keys.
+  virtual uint64_t NewId() = 0;
+
+  // Remove all cache entries that are not actively in use.  Memory-constrained
+  // applications may wish to call this method to reduce memory usage.
+  // The default implementation of Prune() does nothing.  Subclasses are
+  // strongly encouraged to override it.  A future release may change Prune()
+  // to a pure virtual method.
+  virtual void Prune() {}
+
+  // Return an estimate of the combined charges of all elements stored in the
+  // cache.
+  virtual size_t TotalCharge() const = 0;
+
+ private:
+  void LRU_Remove(Handle* e);
+  void LRU_Append(Handle* e);
+  void Unref(Handle* e);
+
+  struct Rep;
+  Rep* rep_;
+};
+
+}  // namespace table
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_LIB_IO_CACHE_H_
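
To illustrate the contract documented above, here is a hedged usage sketch of the Cache interface: heap-allocated string values with a matching per-entry deleter. All names local to this example are illustrative, not part of the patch.

#include <string>

#include "tensorflow/core/lib/io/cache.h"

namespace {

// Deleter invoked by the cache once an entry has been erased or evicted and
// no outstanding handle still pins it.
void DeleteString(const tensorflow::Slice& /*key*/, void* value) {
  delete static_cast<std::string*>(value);
}

void CacheUsageSketch() {
  tensorflow::table::Cache* cache = tensorflow::table::NewLRUCache(1 << 20);

  // Insert charges the entry's size against the 1MB capacity and returns a
  // handle that must eventually be released.
  auto* value = new std::string("cached payload");
  tensorflow::table::Cache::Handle* h =
      cache->Insert("some-key", value, value->size(), &DeleteString);
  cache->Release(h);

  // Lookup pins the entry until Release is called.
  if (tensorflow::table::Cache::Handle* found = cache->Lookup("some-key")) {
    std::string* payload = static_cast<std::string*>(cache->Value(found));
    (void)payload;
    cache->Release(found);
  }
  delete cache;
}

}  // namespace
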
diff --git a/tensorflow/core/lib/io/cache_test.cc b/tensorflow/core/lib/io/cache_test.cc
new file mode 100644
index 00000000000..38552d43b34
--- /dev/null
+++ b/tensorflow/core/lib/io/cache_test.cc
@@ -0,0 +1,238 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/cache.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace table {
+// Conversions between numeric keys/values and the types expected by Cache.
+static std::string EncodeKey(int k) {
+  std::string result;
+  core::PutFixed32(&result, k);
+  return result;
+}
+static int DecodeKey(const Slice& k) {
+  assert(k.size() == 4);
+  return core::DecodeFixed32(k.data());
+}
+static void* EncodeValue(uintptr_t v) { return reinterpret_cast<void*>(v); }
+static int DecodeValue(void* v) { return reinterpret_cast<uintptr_t>(v); }
+
+class CacheTest : public ::testing::Test {
+ public:
+  static void Deleter(const Slice& key, void* v) {
+    current_->deleted_keys_.push_back(DecodeKey(key));
+    current_->deleted_values_.push_back(DecodeValue(v));
+  }
+
+  static const int kCacheSize = 1000;
+  std::vector<int> deleted_keys_;
+  std::vector<int> deleted_values_;
+  Cache* cache_;
+
+  CacheTest() : cache_(NewLRUCache(kCacheSize)) { current_ = this; }
+
+  ~CacheTest() { delete cache_; }
+
+  int Lookup(int key) {
+    Cache::Handle* handle = cache_->Lookup(EncodeKey(key));
+    const int r = (handle == nullptr) ? -1 : DecodeValue(cache_->Value(handle));
+    if (handle != nullptr) {
+      cache_->Release(handle);
+    }
+    return r;
+  }
+
+  void Insert(int key, int value, int charge = 1) {
+    cache_->Release(cache_->Insert(EncodeKey(key), EncodeValue(value), charge,
+                                   &CacheTest::Deleter));
+  }
+
+  Cache::Handle* InsertAndReturnHandle(int key, int value, int charge = 1) {
+    return cache_->Insert(EncodeKey(key), EncodeValue(value), charge,
+                          &CacheTest::Deleter);
+  }
+
+  void Erase(int key) { cache_->Erase(EncodeKey(key)); }
+  static CacheTest* current_;
+};
+CacheTest* CacheTest::current_;
+
+TEST_F(CacheTest, HitAndMiss) {
+  ASSERT_EQ(-1, Lookup(100));
+
+  Insert(100, 101);
+  ASSERT_EQ(101, Lookup(100));
+  ASSERT_EQ(-1, Lookup(200));
+  ASSERT_EQ(-1, Lookup(300));
+
+  Insert(200, 201);
+  ASSERT_EQ(101, Lookup(100));
+  ASSERT_EQ(201, Lookup(200));
+  ASSERT_EQ(-1, Lookup(300));
+
+  Insert(100, 102);
+  ASSERT_EQ(102, Lookup(100));
+  ASSERT_EQ(201, Lookup(200));
+  ASSERT_EQ(-1, Lookup(300));
+
+  ASSERT_EQ(1, deleted_keys_.size());
+  ASSERT_EQ(100, deleted_keys_[0]);
+  ASSERT_EQ(101, deleted_values_[0]);
+}
+
+TEST_F(CacheTest, Erase) {
+  Erase(200);
+  ASSERT_EQ(0, deleted_keys_.size());
+
+  Insert(100, 101);
+  Insert(200, 201);
+  Erase(100);
+  ASSERT_EQ(-1, Lookup(100));
+  ASSERT_EQ(201, Lookup(200));
+  ASSERT_EQ(1, deleted_keys_.size());
+  ASSERT_EQ(100, deleted_keys_[0]);
+  ASSERT_EQ(101, deleted_values_[0]);
+
+  Erase(100);
+  ASSERT_EQ(-1, Lookup(100));
+  ASSERT_EQ(201, Lookup(200));
+  ASSERT_EQ(1, deleted_keys_.size());
+}
+
+TEST_F(CacheTest, EntriesArePinned) {
+  Insert(100, 101);
+  Cache::Handle* h1 = cache_->Lookup(EncodeKey(100));
+  ASSERT_EQ(101, DecodeValue(cache_->Value(h1)));
+
+  Insert(100, 102);
+  Cache::Handle* h2 = cache_->Lookup(EncodeKey(100));
+  ASSERT_EQ(102, DecodeValue(cache_->Value(h2)));
+  ASSERT_EQ(0, deleted_keys_.size());
+
+  cache_->Release(h1);
+  ASSERT_EQ(1, deleted_keys_.size());
+  ASSERT_EQ(100, deleted_keys_[0]);
+  ASSERT_EQ(101, deleted_values_[0]);
+
+  Erase(100);
+  ASSERT_EQ(-1, Lookup(100));
+  ASSERT_EQ(1, deleted_keys_.size());
+
+  cache_->Release(h2);
+  ASSERT_EQ(2, deleted_keys_.size());
+  ASSERT_EQ(100, deleted_keys_[1]);
+  ASSERT_EQ(102, deleted_values_[1]);
+}
+
+TEST_F(CacheTest, EvictionPolicy) {
+  Insert(100, 101);
+  Insert(200, 201);
+  Insert(300, 301);
+  Cache::Handle* h = cache_->Lookup(EncodeKey(300));
+
+  // Frequently used entry must be kept around,
+  // as must things that are still in use.
+  for (int i = 0; i < kCacheSize + 100; i++) {
+    Insert(1000 + i, 2000 + i);
+    ASSERT_EQ(2000 + i, Lookup(1000 + i));
+    ASSERT_EQ(101, Lookup(100));
+  }
+  ASSERT_EQ(101, Lookup(100));
+  ASSERT_EQ(-1, Lookup(200));
+  ASSERT_EQ(301, Lookup(300));
+  cache_->Release(h);
+}
+
+TEST_F(CacheTest, UseExceedsCacheSize) {
+  // Overfill the cache, keeping handles on all inserted entries.
+  std::vector<Cache::Handle*> h;
+  for (int i = 0; i < kCacheSize + 100; i++) {
+    h.push_back(InsertAndReturnHandle(1000 + i, 2000 + i));
+  }
+
+  // Check that all the entries can be found in the cache.
+  for (int i = 0; i < h.size(); i++) {
+    ASSERT_EQ(2000 + i, Lookup(1000 + i));
+  }
+
+  for (int i = 0; i < h.size(); i++) {
+    cache_->Release(h[i]);
+  }
+}
+
+TEST_F(CacheTest, HeavyEntries) {
+  // Add a bunch of light and heavy entries and then count the combined
+  // size of items still in the cache, which must be approximately the
+  // same as the total capacity.
+  const int kLight = 1;
+  const int kHeavy = 10;
+  int added = 0;
+  int index = 0;
+  while (added < 2 * kCacheSize) {
+    const int weight = (index & 1) ? kLight : kHeavy;
+    Insert(index, 1000 + index, weight);
+    added += weight;
+    index++;
+  }
+
+  int cached_weight = 0;
+  for (int i = 0; i < index; i++) {
+    const int weight = (i & 1) ? kLight : kHeavy;
+    int r = Lookup(i);
+    if (r >= 0) {
+      cached_weight += weight;
+      ASSERT_EQ(1000 + i, r);
+    }
+  }
+  ASSERT_LE(cached_weight, kCacheSize + kCacheSize / 10);
+}
+
+TEST_F(CacheTest, NewId) {
+  uint64_t a = cache_->NewId();
+  uint64_t b = cache_->NewId();
+  ASSERT_NE(a, b);
+}
+
+TEST_F(CacheTest, Prune) {
+  Insert(1, 100);
+  Insert(2, 200);
+
+  Cache::Handle* handle = cache_->Lookup(EncodeKey(1));
+  ASSERT_TRUE(handle);
+  cache_->Prune();
+  cache_->Release(handle);
+
+  ASSERT_EQ(100, Lookup(1));
+  ASSERT_EQ(-1, Lookup(2));
+}
+
+TEST_F(CacheTest, ZeroSizeCache) {
+  delete cache_;
+  cache_ = NewLRUCache(0);
+
+  Insert(1, 100);
+  ASSERT_EQ(-1, Lookup(1));
+}
+
+}  // namespace table
+}  // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc
index 1e68493bfe9..6cd1b21c14d 100644
--- a/tensorflow/core/lib/io/table.cc
+++ b/tensorflow/core/lib/io/table.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/cache.h"
 #include "tensorflow/core/lib/io/format.h"
 #include "tensorflow/core/lib/io/table_options.h"
 #include "tensorflow/core/lib/io/two_level_iterator.h"
@@ -32,7 +33,7 @@ struct Table::Rep {
   Options options;
   Status status;
   RandomAccessFile* file;
-  // XXX  uint64 cache_id;
+  uint64 cache_id;
 
   BlockHandle metaindex_handle;  // Handle to metaindex_block: saved from footer
   Block* index_block;
@@ -60,21 +61,18 @@ Status Table::Open(const Options& options, RandomAccessFile* file, uint64 size,
   Block* index_block = nullptr;
   if (s.ok()) {
     s = ReadBlock(file, footer.index_handle(), &contents);
-    if (s.ok()) {
-      index_block = new Block(contents);
-    }
   }
 
   if (s.ok()) {
     // We've successfully read the footer and the index block: we're
     // ready to serve requests.
+    index_block = new Block(contents);
     Rep* rep = new Table::Rep;
     rep->options = options;
     rep->file = file;
     rep->metaindex_handle = footer.metaindex_handle();
     rep->index_block = index_block;
-    // XXX    rep->cache_id = (options.block_cache ?
-    // options.block_cache->NewId() : 0);
+    rep->cache_id = (options.block_cache ? options.block_cache->NewId() : 0);
     *table = new Table(rep);
   } else {
     if (index_block) delete index_block;
@@ -89,13 +87,24 @@ static void DeleteBlock(void* arg, void* ignored) {
   delete reinterpret_cast<Block*>(arg);
 }
 
+static void DeleteCachedBlock(const absl::string_view&, void* value) {
+  Block* block = reinterpret_cast<Block*>(value);
+  delete block;
+}
+
+static void ReleaseBlock(void* arg, void* h) {
+  Cache* cache = reinterpret_cast<Cache*>(arg);
+  Cache::Handle* handle = reinterpret_cast<Cache::Handle*>(h);
+  cache->Release(handle);
+}
+
 // Convert an index iterator value (i.e., an encoded BlockHandle)
 // into an iterator over the contents of the corresponding block.
 Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) {
   Table* table = reinterpret_cast<Table*>(arg);
-  //  Cache* block_cache = table->rep_->options.block_cache;
+  Cache* block_cache = table->rep_->options.block_cache;
   Block* block = nullptr;
-  //  Cache::Handle* cache_handle = NULL;
+  Cache::Handle* cache_handle = nullptr;
 
   BlockHandle handle;
   StringPiece input = index_value;
@@ -105,16 +114,38 @@ Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) {
 
   if (s.ok()) {
     BlockContents contents;
-    s = ReadBlock(table->rep_->file, handle, &contents);
-    if (s.ok()) {
-      block = new Block(contents);
+    if (block_cache != nullptr) {
+      char cache_key_buffer[16];
+      core::EncodeFixed64(cache_key_buffer, table->rep_->cache_id);
+      core::EncodeFixed64(cache_key_buffer + 8, handle.offset());
+      absl::string_view key(cache_key_buffer, sizeof(cache_key_buffer));
+      cache_handle = block_cache->Lookup(key);
+      if (cache_handle != nullptr) {
+        block = reinterpret_cast<Block*>(block_cache->Value(cache_handle));
+      } else {
+        s = ReadBlock(table->rep_->file, handle, &contents);
+        if (s.ok()) {
+          block = new Block(contents);
+          cache_handle = block_cache->Insert(key, block, block->size(),
+                                             &DeleteCachedBlock);
+        }
+      }
+    } else {
+      s = ReadBlock(table->rep_->file, handle, &contents);
+      if (s.ok()) {
+        block = new Block(contents);
+      }
     }
   }
 
   Iterator* iter;
   if (block != nullptr) {
     iter = block->NewIterator();
-    iter->RegisterCleanup(&DeleteBlock, block, nullptr);
+    if (cache_handle == nullptr) {
+      iter->RegisterCleanup(&DeleteBlock, block, nullptr);
+    } else {
+      iter->RegisterCleanup(&ReleaseBlock, block_cache, cache_handle);
+    }
   } else {
     iter = NewErrorIterator(s);
   }
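
The block-cache key built above is simply the table's cache id followed by the block offset, each fixed64-encoded into a 16-byte buffer; the standalone helper below (illustrative only, not part of the patch) shows the same layout.

#include <string>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/platform/types.h"

// Illustrative: 8 bytes of per-table cache id + 8 bytes of block offset, so
// blocks from different tables sharing one Cache never collide.
inline std::string MakeBlockCacheKey(tensorflow::uint64 cache_id,
                                     tensorflow::uint64 block_offset) {
  char buf[16];
  tensorflow::core::EncodeFixed64(buf, cache_id);
  tensorflow::core::EncodeFixed64(buf + 8, block_offset);
  return std::string(buf, sizeof(buf));
}
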
diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h
index 9a36bf16315..d1b43ae7d43 100644
--- a/tensorflow/core/lib/io/table_options.h
+++ b/tensorflow/core/lib/io/table_options.h
@@ -21,6 +21,8 @@ limitations under the License.
 namespace tensorflow {
 namespace table {
 
+class Cache;
+
 // DB contents are stored in a set of blocks, each of which holds a
 // sequence of key,value pairs.  Each block may be compressed before
 // being stored in a file.  The following enum describes which
@@ -60,6 +62,12 @@ struct Options {
   // incompressible, the kSnappyCompression implementation will
   // efficiently detect that and will switch to uncompressed mode.
   CompressionType compression = kSnappyCompression;
+
+  // Control over blocks (user data is stored in a set of blocks, and
+  // a block is the unit of reading from disk).
+
+  // If non-null, use the specified cache for blocks.
+  Cache* block_cache = nullptr;
 };
 
 }  // namespace table
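
With the new block_cache field, callers can share one LRU cache across several tables. A minimal sketch of wiring it up (the caller keeps ownership of the cache, and a typical setup would create it once with table::NewLRUCache and pass the same pointer to every Table::Open call):

#include "tensorflow/core/lib/io/cache.h"
#include "tensorflow/core/lib/io/table_options.h"

// Sketch: build Options that route block reads through a shared LRU cache.
// The cache must outlive every Table opened with these options; leaving
// block_cache as nullptr (the default) disables caching entirely.
tensorflow::table::Options MakeCachedTableOptions(
    tensorflow::table::Cache* shared_cache) {
  tensorflow::table::Options options;
  options.block_cache = shared_cache;
  return options;
}
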
diff --git a/tensorflow/core/lib/monitoring/BUILD b/tensorflow/core/lib/monitoring/BUILD
index fd74298eae0..bbed9a58452 100644
--- a/tensorflow/core/lib/monitoring/BUILD
+++ b/tensorflow/core/lib/monitoring/BUILD
@@ -70,11 +70,6 @@ cc_library(
 cc_library(
     name = "counter",
     hdrs = ["counter.h"],
-    visibility = [
-        "//tensorflow/c/eager:__pkg__",
-        "//tensorflow/core:__pkg__",
-        "//tensorflow/core/platform:__subpackages__",
-    ],
     deps = [
         ":collection_registry",
         ":metric_def",
@@ -91,11 +86,6 @@ cc_library(
 cc_library(
     name = "gauge",
     hdrs = ["gauge.h"],
-    visibility = [
-        "//tensorflow/c/eager:__pkg__",
-        "//tensorflow/core:__pkg__",
-        "//tensorflow/core/platform:__subpackages__",
-    ],
     deps = [
         ":collection_registry",
         ":metric_def",
@@ -156,11 +146,6 @@ cc_library(
     name = "sampler",
     srcs = ["sampler.cc"],
     hdrs = ["sampler.h"],
-    visibility = [
-        "//tensorflow/c/eager:__pkg__",
-        "//tensorflow/core:__pkg__",
-        "//tensorflow/core/platform:__subpackages__",
-    ],
     deps = [
         ":collection_registry",
         ":metric_def",
@@ -192,11 +177,6 @@ cc_library(
     name = "percentile_sampler",
     srcs = ["percentile_sampler.cc"],
     hdrs = ["percentile_sampler.h"],
-    visibility = [
-        "//tensorflow/c/eager:__pkg__",
-        "//tensorflow/core:__pkg__",
-        "//tensorflow/core/platform:__subpackages__",
-    ],
     deps = [
         ":collection_registry",
         ":metric_def",
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 6b637c21d24..26f3ec02694 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -360,6 +360,37 @@ MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
                                         collector_);
 }
 
+class Exporter {
+ public:
+  virtual ~Exporter() {}
+  virtual void PeriodicallyExportMetrics() = 0;
+  virtual void ExportMetrics() = 0;
+};
+
+namespace exporter_registration {
+
+class ExporterRegistration {
+ public:
+  explicit ExporterRegistration(Exporter* exporter) : exporter_(exporter) {
+    exporter_->PeriodicallyExportMetrics();
+  }
+
+ private:
+  Exporter* exporter_;
+};
+
+}  // namespace exporter_registration
+
+#define REGISTER_TF_METRICS_EXPORTER(exporter) \
+  REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(__COUNTER__, exporter)
+
+#define REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(ctr, exporter) \
+  REGISTER_TF_METRICS_EXPORTER_UNIQ(ctr, exporter)
+
+#define REGISTER_TF_METRICS_EXPORTER_UNIQ(ctr, exporter)                       \
+  static ::tensorflow::monitoring::exporter_registration::ExporterRegistration \
+      exporter_registration_##ctr(new exporter())
+
 }  // namespace monitoring
 }  // namespace tensorflow
 
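
The new Exporter interface is picked up through the registration macro above; a hedged sketch of a do-nothing exporter follows (the class body is illustrative only, a real exporter would publish CollectionRegistry metrics to some backend):

#include "tensorflow/core/lib/monitoring/collection_registry.h"

namespace {

// Illustrative exporter: both hooks are intentionally empty here.
class NullExporter : public tensorflow::monitoring::Exporter {
 public:
  void PeriodicallyExportMetrics() override {}  // invoked once at registration
  void ExportMetrics() override {}
};

}  // namespace

// Constructs the exporter during static initialization and immediately calls
// PeriodicallyExportMetrics() on it.
REGISTER_TF_METRICS_EXPORTER(NullExporter);
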
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index fb40e56829d..eef097b0c75 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -16,7 +16,6 @@ load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_additional_env_hdrs",
     "tf_additional_lib_hdrs",
-    "tf_additional_monitoring_hdrs",
     "tf_additional_tensor_coding_deps",
     "tf_additional_test_srcs",
     "tf_fingerprint_deps",
@@ -24,7 +23,6 @@ load(
     "tf_google_mobile_srcs_only_runtime",
     "tf_kernel_tests_linkstatic",
     "tf_logging_deps",
-    "tf_monitoring_deps",
     "tf_platform_alias",
     "tf_platform_deps",
     "tf_protobuf_compiler_deps",
@@ -83,7 +81,6 @@ exports_files(
         "load_library.h",
         "logging.h",
         "mem.h",
-        "monitoring.h",
         "mutex.h",
         "net.h",
         "numa.h",
@@ -333,12 +330,6 @@ cc_library(
     deps = tf_logging_deps(),
 )
 
-cc_library(
-    name = "monitoring",
-    textual_hdrs = ["monitoring.h"],
-    deps = tf_monitoring_deps(),
-)
-
 cc_library(
     name = "macros",
     hdrs = ["macros.h"],
@@ -760,6 +751,62 @@ cc_library(
     ] + tf_platform_deps("unbounded_work_queue"),
 )
 
+cc_library(
+    name = "retrying_utils",
+    srcs = [
+        "retrying_utils.cc",
+    ],
+    hdrs = [
+        "retrying_utils.h",
+    ],
+    copts = tf_copts(),
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+cc_library(
+    name = "retrying_file_system",
+    hdrs = [
+        "retrying_file_system.h",
+    ],
+    copts = tf_copts(),
+    deps = [
+        ":retrying_utils",
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_cc_test(
+    name = "retrying_file_system_test",
+    size = "small",
+    srcs = ["retrying_file_system_test.cc"],
+    deps = [
+        ":retrying_file_system",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
+    ],
+)
+
+tf_cc_test(
+    name = "retrying_utils_test",
+    size = "small",
+    srcs = ["retrying_utils_test.cc"],
+    deps = [
+        ":retrying_utils",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:str_util",
+    ],
+)
+
 # This is a hacky, do-nothing, binary that makes it easy to verify ability to
 # build, link, and in some cases run all of the libraries under platform.
 # Realistically, most of this would be covered by tests but at this point
@@ -791,6 +838,8 @@ cc_binary(
         ":png",
         ":prefetch",
         ":protobuf",
+        ":retrying_utils",
+        ":retrying_file_system",
         ":scanner",
         ":setround",
         ":stacktrace",
@@ -1178,7 +1227,6 @@ filegroup(
         "init_main.h",
         "logger.h",
         "mem.h",
-        "monitoring.h",
         "mutex.h",
         "net.h",
         "notification.h",
@@ -1203,7 +1251,7 @@ filegroup(
         "subprocess.h",
         "thread_annotations.h",
         ":base_hdrs",
-    ] + tf_additional_monitoring_hdrs() + tf_additional_env_hdrs(),
+    ] + tf_additional_env_hdrs(),
     visibility = ["//tensorflow/core:__pkg__"],
 )
 
@@ -1280,7 +1328,6 @@ filegroup(
         "demangle.h",
         "denormal.h",
         "host_info.h",
-        "monitoring.h",
         "platform.h",
         "protobuf_internal.h",
         "refcount.h",
@@ -1452,7 +1499,6 @@ filegroup(
         "cpu_feature_guard.cc",
         "cpu_feature_guard.h",
         "fingerprint.h",
-        "monitoring.h",
         "notification.h",
         "platform_strings.cc",
         "platform_strings.h",
diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl
index 4a1ba38fbd8..d1fb5f829ea 100644
--- a/tensorflow/core/platform/build_config.bzl
+++ b/tensorflow/core/platform/build_config.bzl
@@ -12,7 +12,6 @@ load(
     _tf_additional_env_hdrs = "tf_additional_env_hdrs",
     _tf_additional_lib_deps = "tf_additional_lib_deps",
     _tf_additional_lib_hdrs = "tf_additional_lib_hdrs",
-    _tf_additional_monitoring_hdrs = "tf_additional_monitoring_hdrs",
     _tf_additional_rpc_deps = "tf_additional_rpc_deps",
     _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps",
     _tf_additional_test_deps = "tf_additional_test_deps",
@@ -24,7 +23,6 @@ load(
     _tf_kernel_tests_linkstatic = "tf_kernel_tests_linkstatic",
     _tf_lib_proto_parsing_deps = "tf_lib_proto_parsing_deps",
     _tf_logging_deps = "tf_logging_deps",
-    _tf_monitoring_deps = "tf_monitoring_deps",
     _tf_platform_alias = "tf_platform_alias",
     _tf_platform_deps = "tf_platform_deps",
     _tf_portable_deps_no_runtime = "tf_portable_deps_no_runtime",
@@ -56,7 +54,6 @@ tf_additional_device_tracer_srcs = _tf_additional_device_tracer_srcs
 tf_additional_env_hdrs = _tf_additional_env_hdrs
 tf_additional_lib_deps = _tf_additional_lib_deps
 tf_additional_lib_hdrs = _tf_additional_lib_hdrs
-tf_additional_monitoring_hdrs = _tf_additional_monitoring_hdrs
 tf_additional_rpc_deps = _tf_additional_rpc_deps
 tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps
 tf_additional_test_deps = _tf_additional_test_deps
@@ -68,7 +65,6 @@ tf_jspb_proto_library = _tf_jspb_proto_library
 tf_kernel_tests_linkstatic = _tf_kernel_tests_linkstatic
 tf_lib_proto_parsing_deps = _tf_lib_proto_parsing_deps
 tf_logging_deps = _tf_logging_deps
-tf_monitoring_deps = _tf_monitoring_deps
 tf_platform_alias = _tf_platform_alias
 tf_platform_deps = _tf_platform_deps
 tf_portable_deps_no_runtime = _tf_portable_deps_no_runtime
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 53c4f6cda1f..fe08edceae9 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -92,14 +92,14 @@ cc_library(
         ":google_auth_provider",
         ":http_request",
         ":ram_file_block_cache",
-        ":retrying_file_system",
-        ":retrying_utils",
         ":time_util",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/platform:numbers",
         "//tensorflow/core/platform:path",
+        "//tensorflow/core/platform:retrying_file_system",
+        "//tensorflow/core/platform:retrying_utils",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
@@ -128,14 +128,14 @@ cc_library(
         ":google_auth_provider",
         ":http_request",
         ":ram_file_block_cache",
-        ":retrying_file_system",
-        ":retrying_utils",
         ":time_util",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/platform:numbers",
         "//tensorflow/core/platform:path",
+        "//tensorflow/core/platform:retrying_file_system",
+        "//tensorflow/core/platform:retrying_utils",
         "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:stringprintf",
         "@jsoncpp_git//:jsoncpp",
@@ -200,12 +200,12 @@ cc_library(
     deps = [
         ":compute_engine_metadata_client",
         ":oauth_client",
-        ":retrying_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/platform:base64",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:path",
+        "//tensorflow/core/platform:retrying_utils",
         "//tensorflow/core/platform:status",
         "@com_google_absl//absl/strings",
         "@jsoncpp_git//:jsoncpp",
@@ -224,9 +224,9 @@ cc_library(
     deps = [
         ":curl_http_request",
         ":http_request",
-        ":retrying_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:retrying_utils",
     ],
 )
 
@@ -283,34 +283,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "retrying_utils",
-    srcs = [
-        "retrying_utils.cc",
-    ],
-    hdrs = [
-        "retrying_utils.h",
-    ],
-    copts = tf_copts(),
-    deps = [
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:lib_internal",
-    ],
-)
-
-cc_library(
-    name = "retrying_file_system",
-    hdrs = [
-        "retrying_file_system.h",
-    ],
-    copts = tf_copts(),
-    deps = [
-        ":retrying_utils",
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core:lib_internal",
-    ],
-)
-
 cc_library(
     name = "time_util",
     srcs = [
@@ -482,20 +454,6 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
-    name = "retrying_file_system_test",
-    size = "small",
-    srcs = ["retrying_file_system_test.cc"],
-    deps = [
-        ":retrying_file_system",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core/platform:str_util",
-    ],
-)
-
 tf_cc_test(
     name = "time_util_test",
     size = "small",
@@ -506,17 +464,3 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
-
-tf_cc_test(
-    name = "retrying_utils_test",
-    size = "small",
-    srcs = ["retrying_utils_test.cc"],
-    deps = [
-        ":retrying_utils",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core/platform:str_util",
-    ],
-)
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
index d7611615606..164380b4141 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
@@ -17,7 +17,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_
 
 #include "tensorflow/core/platform/cloud/http_request.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
+#include "tensorflow/core/platform/retrying_utils.h"
 #include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 054ad242692..57847d2ea38 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
 #include "tensorflow/core/platform/cloud/ram_file_block_cache.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/cloud/time_util.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
@@ -40,6 +39,7 @@ limitations under the License.
 #include "tensorflow/core/platform/numbers.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/retrying_utils.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/platform/stringprintf.h"
 #include "tensorflow/core/platform/thread_annotations.h"
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index b075cbe9828..98933532b17 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -29,8 +29,8 @@ limitations under the License.
 #include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
 #include "tensorflow/core/platform/cloud/gcs_throttle.h"
 #include "tensorflow/core/platform/cloud/http_request.h"
-#include "tensorflow/core/platform/cloud/retrying_file_system.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/retrying_file_system.h"
 #include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc
index 2c74a3e2330..e8546ca022f 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider.cc
@@ -26,10 +26,10 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "include/json/json.h"
 #include "tensorflow/core/platform/base64.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/retrying_utils.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD
index 07a057718cb..7a96f134a9d 100644
--- a/tensorflow/core/platform/default/BUILD
+++ b/tensorflow/core/platform/default/BUILD
@@ -193,17 +193,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "monitoring",
-    srcs = ["monitoring.cc"],
-    hdrs = ["//tensorflow/core/platform:monitoring.h"],
-    tags = [
-        "manual",
-        "no_oss",
-        "nobuilder",
-    ],
-)
-
 cc_library(
     name = "mutex",
     srcs = [
@@ -529,7 +518,6 @@ filegroup(
     srcs = [
         "casts.h",
         "cord.h",
-        "monitoring.cc",
         "mutex.h",
         "mutex_data.h",
         "notification.h",
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 4b9b5a896d4..dbb6ffdbc6e 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -569,9 +569,6 @@ def tf_additional_lib_hdrs():
         ],
     })
 
-def tf_additional_monitoring_hdrs():
-    return []
-
 def tf_additional_env_hdrs():
     return []
 
@@ -756,9 +753,6 @@ def tf_platform_alias(name):
 def tf_logging_deps():
     return ["//tensorflow/core/platform/default:logging"]
 
-def tf_monitoring_deps():
-    return ["//tensorflow/core/platform/default:monitoring"]
-
 def tf_resource_deps():
     return ["//tensorflow/core/platform/default:resource"]
 
diff --git a/tensorflow/core/platform/default/monitoring.cc b/tensorflow/core/platform/default/monitoring.cc
deleted file mode 100644
index 71ece3e3c14..00000000000
--- a/tensorflow/core/platform/default/monitoring.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/platform/monitoring.h"
-
-namespace tensorflow {
-namespace monitoring {
-
-void StartExporter() {}
-
-void ExportMetrics() {}
-
-}  // namespace monitoring
-}  // namespace tensorflow
diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc
index 47f4abae3bb..756e7e8a93a 100644
--- a/tensorflow/core/platform/default/port.cc
+++ b/tensorflow/core/platform/default/port.cc
@@ -332,6 +332,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
 #endif
 }
 
+bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
+                              const struct iovec* iov, size_t iov_cnt) {
+#ifdef TF_USE_SNAPPY
+  return snappy::RawUncompressToIOVec(compressed, compressed_length, iov,
+                                      iov_cnt);
+#else
+  return false;
+#endif
+}
+
 string Demangle(const char* mangled) { return mangled; }
 
 double NominalCPUFrequency() {
diff --git a/tensorflow/core/platform/monitoring.h b/tensorflow/core/platform/monitoring.h
deleted file mode 100644
index f01233933c3..00000000000
--- a/tensorflow/core/platform/monitoring.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PLATFORM_MONITORING_H_
-#define TENSORFLOW_CORE_PLATFORM_MONITORING_H_
-
-namespace tensorflow {
-namespace monitoring {
-
-// Starts exporting metrics through a platform-specific monitoring API (if
-// provided). For builds using "tensorflow/core/platform/default", this is
-// currently a no-op. This function is idempotent.
-//
-// The TensorFlow runtime will call this the first time a new session is created
-// using the NewSession() method or an Eager Context is created.
-void StartExporter();
-
-// Manually invokes a one time metrics export through a platform-specific
-// monitoring API (if provided). For builds using
-// "tensorflow/core/platform/default", this is currently a no-op.
-void ExportMetrics();
-
-}  // namespace monitoring
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PLATFORM_MONITORING_H_
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/retrying_file_system.h
similarity index 99%
rename from tensorflow/core/platform/cloud/retrying_file_system.h
rename to tensorflow/core/platform/retrying_file_system.h
index 12bbc7d6abb..df8850ace93 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/retrying_file_system.h
@@ -21,10 +21,10 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/retrying_utils.h"
 #include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/retrying_file_system_test.cc
similarity index 99%
rename from tensorflow/core/platform/cloud/retrying_file_system_test.cc
rename to tensorflow/core/platform/retrying_file_system_test.cc
index b48831ab238..b43c3375265 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/retrying_file_system_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/platform/cloud/retrying_file_system.h"
+#include "tensorflow/core/platform/retrying_file_system.h"
 
 #include <fstream>
 
diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/retrying_utils.cc
similarity index 98%
rename from tensorflow/core/platform/cloud/retrying_utils.cc
rename to tensorflow/core/platform/retrying_utils.cc
index 1f0c41824bf..1b6fa68c31c 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.cc
+++ b/tensorflow/core/platform/retrying_utils.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
+#include "tensorflow/core/platform/retrying_utils.h"
 
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/env.h"
diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/retrying_utils.h
similarity index 100%
rename from tensorflow/core/platform/cloud/retrying_utils.h
rename to tensorflow/core/platform/retrying_utils.h
diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/retrying_utils_test.cc
similarity index 98%
rename from tensorflow/core/platform/cloud/retrying_utils_test.cc
rename to tensorflow/core/platform/retrying_utils_test.cc
index 7a2dbacacc8..5b162571067 100644
--- a/tensorflow/core/platform/cloud/retrying_utils_test.cc
+++ b/tensorflow/core/platform/retrying_utils_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
+#include "tensorflow/core/platform/retrying_utils.h"
 
 #include <fstream>
 
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index a5494d5c318..d174b108279 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -33,6 +33,8 @@ tf_cc_binary(
     linkshared = 1,
     deps = [
         "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core/platform:retrying_file_system",
+        "//tensorflow/core/platform:retrying_utils",
         "@aws",
         "@com_google_protobuf//:protobuf_headers",
         "@curl",
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index acdacc306d5..3be563d2c94 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -196,6 +196,23 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
   return Status::OK();
 }
 
+static Status CheckForbiddenError(
+    const Aws::Client::AWSError<Aws::S3::S3Errors>& error) {
+  if (error.GetResponseCode() == Aws::Http::HttpResponseCode::FORBIDDEN) {
+    return errors::FailedPrecondition(
+        "AWS Credentials have not been set properly. "
+        "Unable to access the specified S3 location");
+  } else {
+    return Status::OK();
+  }
+}
+
+static Status CreateStatusFromAwsError(
+    const Aws::Client::AWSError<Aws::S3::S3Errors>& error) {
+  TF_RETURN_IF_ERROR(CheckForbiddenError(error));
+  return errors::Unknown(error.GetExceptionName(), ": ", error.GetMessage());
+}
+
 class S3RandomAccessFile : public RandomAccessFile {
  public:
   S3RandomAccessFile(const string& bucket, const string& object,
@@ -217,9 +234,14 @@ class S3RandomAccessFile : public RandomAccessFile {
     });
     auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
     if (!getObjectOutcome.IsSuccess()) {
-      n = 0;
-      *result = StringPiece(scratch, n);
-      return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
+      auto error = getObjectOutcome.GetError();
+      if (error.GetResponseCode() ==
+          Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) {
+        n = 0;
+        *result = StringPiece(scratch, n);
+        return Status(error::OUT_OF_RANGE, "Read fewer bytes than requested");
+      }
+      return CreateStatusFromAwsError(error);
     }
     n = getObjectOutcome.GetResult().GetContentLength();
     getObjectOutcome.GetResult().GetBody().read(scratch, n);
@@ -305,15 +327,10 @@ class S3WritableFile : public WritableFile {
 
     if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) {
       auto error = handle->GetLastError();
-      if (error.GetResponseCode() == Aws::Http::HttpResponseCode::FORBIDDEN) {
-        return errors::FailedPrecondition(
-            "AWS Credentials have not been set properly. "
-            "Unable to access the specified S3 location");
-      } else {
-        return errors::Unknown(
-            error.GetExceptionName(), ": ", handle->GetFailedParts().size(),
-            " failed parts. ", handle->GetLastError().GetMessage());
-      }
+      TF_RETURN_IF_ERROR(CheckForbiddenError(error));
+      return errors::Unknown(error.GetExceptionName(), ": ",
+                             handle->GetFailedParts().size(), " failed parts. ",
+                             handle->GetLastError().GetMessage());
     }
     outfile_->clear();
     outfile_->seekp(offset);
@@ -514,8 +531,7 @@ Status S3FileSystem::GetChildren(const string& dir,
     auto listObjectsOutcome =
         this->GetS3Client()->ListObjects(listObjectsRequest);
     if (!listObjectsOutcome.IsSuccess()) {
-      return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(),
-                             ": ", listObjectsOutcome.GetError().GetMessage());
+      return CreateStatusFromAwsError(listObjectsOutcome.GetError());
     }
 
     listObjectsResult = listObjectsOutcome.GetResult();
@@ -549,8 +565,7 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
     headBucketRequest.WithBucket(bucket.c_str());
     auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
     if (!headBucketOutcome.IsSuccess()) {
-      return errors::Unknown(headBucketOutcome.GetError().GetExceptionName(),
-                             ": ", headBucketOutcome.GetError().GetMessage());
+      return CreateStatusFromAwsError(headBucketOutcome.GetError());
     }
     stats->length = 0;
     stats->is_directory = 1;
@@ -570,6 +585,8 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
     stats->mtime_nsec =
         headObjectOutcome.GetResult().GetLastModified().Millis() * 1e6;
     found = true;
+  } else {
+    TF_RETURN_IF_ERROR(CheckForbiddenError(headObjectOutcome.GetError()));
   }
   string prefix = object;
   if (prefix.back() != '/') {
@@ -584,11 +601,15 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
   auto listObjectsOutcome =
       this->GetS3Client()->ListObjects(listObjectsRequest);
   if (listObjectsOutcome.IsSuccess()) {
-    if (listObjectsOutcome.GetResult().GetContents().size() > 0) {
+    auto listObjects = listObjectsOutcome.GetResult().GetContents();
+    if (listObjects.size() > 0) {
       stats->length = 0;
       stats->is_directory = 1;
+      stats->mtime_nsec = listObjects[0].GetLastModified().Millis() * 1e6;
       found = true;
     }
+  } else {
+    TF_RETURN_IF_ERROR(CheckForbiddenError(listObjectsOutcome.GetError()));
   }
   if (!found) {
     return errors::NotFound("Object ", fname, " does not exist");
@@ -611,8 +632,7 @@ Status S3FileSystem::DeleteFile(const string& fname) {
   auto deleteObjectOutcome =
       this->GetS3Client()->DeleteObject(deleteObjectRequest);
   if (!deleteObjectOutcome.IsSuccess()) {
-    return errors::Unknown(deleteObjectOutcome.GetError().GetExceptionName(),
-                           ": ", deleteObjectOutcome.GetError().GetMessage());
+    return CreateStatusFromAwsError(deleteObjectOutcome.GetError());
   }
   return Status::OK();
 }
@@ -626,6 +646,7 @@ Status S3FileSystem::CreateDir(const string& dirname) {
     headBucketRequest.WithBucket(bucket.c_str());
     auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
     if (!headBucketOutcome.IsSuccess()) {
+      TF_RETURN_IF_ERROR(CheckForbiddenError(headBucketOutcome.GetError()));
       return errors::NotFound("The bucket ", bucket, " was not found.");
     }
     return Status::OK();
@@ -634,9 +655,11 @@ Status S3FileSystem::CreateDir(const string& dirname) {
   if (filename.back() != '/') {
     filename.push_back('/');
   }
-  std::unique_ptr<WritableFile> file;
-  TF_RETURN_IF_ERROR(NewWritableFile(filename, &file));
-  TF_RETURN_IF_ERROR(file->Close());
+  if (!this->FileExists(filename).ok()) {
+    std::unique_ptr<WritableFile> file;
+    TF_RETURN_IF_ERROR(NewWritableFile(filename, &file));
+    TF_RETURN_IF_ERROR(file->Close());
+  }
   return Status::OK();
 }
 
@@ -660,7 +683,10 @@ Status S3FileSystem::DeleteDir(const string& dirname) {
     auto contents = listObjectsOutcome.GetResult().GetContents();
     if (contents.size() > 1 ||
         (contents.size() == 1 && contents[0].GetKey() != prefix.c_str())) {
-      return errors::FailedPrecondition("Cannot delete a non-empty directory.");
+      return errors::Unknown(
+          "Cannot delete a non-empty directory. "
+          "This operation will be retried in case this "
+          "is due to S3's eventual consistency.");
     }
     if (contents.size() == 1 && contents[0].GetKey() == prefix.c_str()) {
       string filename = dirname;
@@ -669,6 +695,8 @@ Status S3FileSystem::DeleteDir(const string& dirname) {
       }
       return DeleteFile(filename);
     }
+  } else {
+    TF_RETURN_IF_ERROR(CheckForbiddenError(listObjectsOutcome.GetError()));
   }
   return Status::OK();
 }
@@ -762,8 +790,7 @@ Status S3FileSystem::SimpleCopy(const Aws::String& source,
   copyObjectRequest.SetCopySource(source);
   auto copyObjectOutcome = this->GetS3Client()->CopyObject(copyObjectRequest);
   if (!copyObjectOutcome.IsSuccess()) {
-    return errors::Unknown(copyObjectOutcome.GetError().GetExceptionName(),
-                           ": ", copyObjectOutcome.GetError().GetMessage());
+    return CreateStatusFromAwsError(copyObjectOutcome.GetError());
   }
   return Status::OK();
 }
@@ -782,9 +809,7 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
   auto multipartUploadOutcome =
       this->GetS3Client()->CreateMultipartUpload(multipartUploadRequest);
   if (!multipartUploadOutcome.IsSuccess()) {
-    return errors::Unknown(multipartUploadOutcome.GetError().GetExceptionName(),
-                           ": ",
-                           multipartUploadOutcome.GetError().GetMessage());
+    return CreateStatusFromAwsError(multipartUploadOutcome.GetError());
   }
 
   Aws::String uploadID = multipartUploadOutcome.GetResult().GetUploadId();
@@ -916,8 +941,7 @@ Status S3FileSystem::AbortMultiPartCopy(Aws::String target_bucket,
       .WithUploadId(uploadID);
   auto abortOutcome = this->GetS3Client()->AbortMultipartUpload(abortRequest);
   if (!abortOutcome.IsSuccess()) {
-    return errors::Unknown(abortOutcome.GetError().GetExceptionName(), ": ",
-                           abortOutcome.GetError().GetMessage());
+    return CreateStatusFromAwsError(abortOutcome.GetError());
   }
   return Status::OK();
 }
@@ -933,8 +957,7 @@ Status S3FileSystem::CompleteMultiPartCopy(
   auto completeOutcome =
       this->GetS3Client()->CompleteMultipartUpload(completeRequest);
   if (!completeOutcome.IsSuccess()) {
-    return errors::Unknown(completeOutcome.GetError().GetExceptionName(), ": ",
-                           completeOutcome.GetError().GetMessage());
+    return CreateStatusFromAwsError(completeOutcome.GetError());
   }
   return Status::OK();
 }
@@ -969,8 +992,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
     auto listObjectsOutcome =
         this->GetS3Client()->ListObjects(listObjectsRequest);
     if (!listObjectsOutcome.IsSuccess()) {
-      return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(),
-                             ": ", listObjectsOutcome.GetError().GetMessage());
+      return CreateStatusFromAwsError(listObjectsOutcome.GetError());
     }
 
     listObjectsResult = listObjectsOutcome.GetResult();
@@ -989,9 +1011,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
       auto deleteObjectOutcome =
           this->GetS3Client()->DeleteObject(deleteObjectRequest);
       if (!deleteObjectOutcome.IsSuccess()) {
-        return errors::Unknown(
-            deleteObjectOutcome.GetError().GetExceptionName(), ": ",
-            deleteObjectOutcome.GetError().GetMessage());
+        return CreateStatusFromAwsError(deleteObjectOutcome.GetError());
       }
     }
     listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
@@ -1000,6 +1020,6 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
   return Status::OK();
 }
 
-REGISTER_FILE_SYSTEM("s3", S3FileSystem);
+REGISTER_FILE_SYSTEM("s3", RetryingS3FileSystem);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h
index 8d4dfd6b6bd..7ea01b24cf9 100644
--- a/tensorflow/core/platform/s3/s3_file_system.h
+++ b/tensorflow/core/platform/s3/s3_file_system.h
@@ -24,6 +24,7 @@ limitations under the License.
 
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/retrying_file_system.h"
 
 namespace tensorflow {
 
@@ -133,6 +134,17 @@ class S3FileSystem : public FileSystem {
   uint64 multi_part_copy_part_size_;
 };
 
+/// S3 implementation of a file system with retry on failures.
+class RetryingS3FileSystem : public RetryingFileSystem<S3FileSystem> {
+ public:
+  RetryingS3FileSystem()
+      : RetryingFileSystem(
+            std::unique_ptr<S3FileSystem>(new S3FileSystem),
+            RetryConfig(100000 /* init_delay_time_us */,
+                        32000000 /* max_delay_time_us */, 10 /* max_retries */
+                        )) {}
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_
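
For reference, the retry policy used by RetryingS3FileSystem above can be spelled out on its own; the helper below is illustrative only and simply restates the values passed in the constructor.

#include "tensorflow/core/platform/retrying_utils.h"

// Sketch: 100ms initial backoff, capped at 32s, at most 10 retries -- the same
// values RetryingS3FileSystem passes to RetryConfig above.
tensorflow::RetryConfig MakeS3StyleRetryConfig() {
  return tensorflow::RetryConfig(100000 /* init_delay_time_us */,
                                 32000000 /* max_delay_time_us */,
                                 10 /* max_retries */);
}
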
diff --git a/tensorflow/core/platform/snappy.h b/tensorflow/core/platform/snappy.h
index 5477b097ef0..df06f3dcc1e 100644
--- a/tensorflow/core/platform/snappy.h
+++ b/tensorflow/core/platform/snappy.h
@@ -18,6 +18,17 @@ limitations under the License.
 
 #include "tensorflow/core/platform/types.h"
 
+#if !defined(PLATFORM_WINDOWS)
+#include <sys/uio.h>
+#else
+namespace tensorflow {
+struct iovec {
+  void* iov_base;
+  size_t iov_len;
+};
+}  // namespace tensorflow
+#endif
+
 namespace tensorflow {
 namespace port {
 
@@ -28,6 +39,9 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
                                   size_t* result);
 bool Snappy_Uncompress(const char* input, size_t length, char* output);
 
+bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
+                              const struct iovec* iov, size_t iov_cnt);
+
 }  // namespace port
 }  // namespace tensorflow
 
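
The iovec overload added here lets callers decompress straight into scattered output buffers. A hedged sketch of calling it (the buffer lengths must add up to the uncompressed size; the helper name is illustrative):

#include <cstddef>

#include "tensorflow/core/platform/snappy.h"

namespace tensorflow {

// Sketch: decompress into two pre-sized buffers. Returns false if snappy is
// unavailable in this build or the input is corrupt.
bool UncompressIntoTwoBuffers(const char* compressed, size_t compressed_len,
                              char* buf0, size_t len0, char* buf1,
                              size_t len1) {
  struct iovec iov[2];
  iov[0].iov_base = buf0;
  iov[0].iov_len = len0;
  iov[1].iov_base = buf1;
  iov[1].iov_len = len1;
  return port::Snappy_UncompressToIOVec(compressed, compressed_len, iov, 2);
}

}  // namespace tensorflow
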
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 2303b587ce6..547af76bdf6 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -157,6 +157,17 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
 #endif
 }
 
+bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
+                              const struct iovec* iov, size_t iov_cnt) {
+#ifdef TF_USE_SNAPPY
+  const snappy::iovec* snappy_iov = reinterpret_cast<const snappy::iovec*>(iov);
+  return snappy::RawUncompressToIOVec(compressed, compressed_length, snappy_iov,
+                                      iov_cnt);
+#else
+  return false;
+#endif
+}
+
 string Demangle(const char* mangled) { return mangled; }
 
 double NominalCPUFrequency() {
diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc
index 87a71bb9ff2..e344bf6e1ba 100644
--- a/tensorflow/core/profiler/rpc/client/capture_profile.cc
+++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc
@@ -60,6 +60,7 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
   }
   request.add_tools("op_profile");
   request.add_tools("input_pipeline");
+  request.add_tools("kernel_stats");
   request.add_tools("memory_viewer");
   request.add_tools("overview_page");
   request.add_tools("pod_viewer");
diff --git a/tensorflow/core/protobuf/data/experimental/snapshot.proto b/tensorflow/core/protobuf/data/experimental/snapshot.proto
index 422602d3760..e013deb2ee1 100644
--- a/tensorflow/core/protobuf/data/experimental/snapshot.proto
+++ b/tensorflow/core/protobuf/data/experimental/snapshot.proto
@@ -3,6 +3,8 @@ syntax = "proto3";
 package tensorflow.data.experimental;
 
 import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
 
 // Each SnapshotRecord represents one batch of pre-processed input data. A batch
 // consists of a list of tensors that we encode as TensorProtos. This message
@@ -13,9 +15,29 @@ message SnapshotRecord {
 
 // This stores the metadata information present in each snapshot record.
 message SnapshotMetadataRecord {
+  // Stores the fingerprint of the graph that describes the dataset that is
+  // snapshotted.
   string graph_hash = 1;
+  // Run ID that this snapshot corresponds to.
   string run_id = 2;
+  // Time when we started creating this snapshot.
   int64 creation_timestamp = 3;
+  // Version of the snapshot data file format.
+  int64 version = 4;
+  // A list of tensor dtypes corresponding to each element of the snapshot.
+  repeated .tensorflow.DataType dtype = 5;
 
   bool finalized = 1000;
 }
+
+// Metadata for a single tensor in the Snapshot Record.
+message TensorMetadata {
+  .tensorflow.TensorShapeProto tensor_shape = 2;
+  // Number of uncompressed bytes used to store the tensor representation.
+  int64 tensor_size_bytes = 3;
+}
+
+// Metadata for all the tensors in a Snapshot Record.
+message SnapshotTensorMetadata {
+  repeated TensorMetadata tensor_metadata = 1;
+}
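
The new messages are filled through the ordinary protoc-generated accessors; a brief sketch follows, with illustrative values (one float tensor of shape [128]):

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"

// Sketch: record the format version and element dtype in the metadata record,
// and describe one tensor in the per-record tensor metadata.
void FillExampleSnapshotMetadata(
    tensorflow::data::experimental::SnapshotMetadataRecord* metadata,
    tensorflow::data::experimental::SnapshotTensorMetadata* tensors) {
  metadata->set_version(1);
  metadata->add_dtype(tensorflow::DT_FLOAT);

  auto* tensor = tensors->add_tensor_metadata();
  tensor->mutable_tensor_shape()->add_dim()->set_size(128);
  tensor->set_tensor_size_bytes(128 * sizeof(float));
}
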
diff --git a/tensorflow/core/util/mkl_types.h b/tensorflow/core/util/mkl_types.h
index cdf313d585e..8edbddac8a1 100644
--- a/tensorflow/core/util/mkl_types.h
+++ b/tensorflow/core/util/mkl_types.h
@@ -48,6 +48,7 @@ namespace tensorflow {
 #define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc()
 #define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt)
 #define GET_TF_DATA_FORMAT(shape, mem_desc) shape.GetTfDataFormat()
+#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
 #define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc()
 #define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) \
   GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
@@ -114,14 +115,13 @@ namespace tensorflow {
 #define TENSOR_FORMAT MKL_TENSOR_FORMAT
 #define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC
 #define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS
-#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
-#define BN_FLAGS mkldnn::normalization_flags
 
 #else
 
 #define ADD_MD add_pd
 #define ALGORITHM mkldnn
 #define ALGORITHM_UNDEF ALGORITHM::algorithm_undef
+#define BN_FLAGS mkldnn
 #define CPU_STREAM(engine) stream(stream::kind::eager_nostore)
 #define DATA_WITH_ENGINE(data, engine) data
 #define DST_MD dst_pd
@@ -148,6 +148,7 @@ namespace tensorflow {
   op_pd.get()->workspace_primitive_desc()
 #define GET_TENSOR_FORMAT(fmt) fmt
 #define GET_TF_DATA_FORMAT(shape, mem_desc) mem_desc.data.format
+#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
 #define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc()
 #define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) op->GetFilterMemoryFormat()
 #define IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, op_pd, op) \
@@ -215,8 +216,6 @@ namespace tensorflow {
 #define SUMMAND_MD summand_pd
 #define TENSOR_FORMAT TensorFormat
 #define TENSOR_FORMAT_NHWC FORMAT_NHWC
-#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
-#define BN_FLAGS mkldnn
 #endif  // ENABLE_MKLDNN_V1
 
 }  // namespace tensorflow
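
GET_USR_MEM_PRIM_DESC resolves differently under the two MKL-DNN versions; a hedged sketch of the intended call site (MklDnnData and the variable name are illustrative assumptions):

    // Expands to src_data.GetUsrMemDesc() when ENABLE_MKLDNN_V1 is defined,
    // and to src_data.GetUsrMemPrimDesc() otherwise; the kernel source is written once.
    auto usr_desc = GET_USR_MEM_PRIM_DESC(src_data);
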
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index 3ebb2ea5dc8..339c28dfb66 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/lib/io/table_builder.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
 #include "tensorflow/core/util/tensor_slice_util.h"
@@ -729,6 +730,7 @@ BundleReader::BundleReader(Env* env, StringPiece prefix)
       prefix_(prefix),
       metadata_(nullptr),
       table_(nullptr),
+      index_cache_(nullptr),
       iter_(nullptr),
       need_to_swap_bytes_(false) {
   const string filename = MetaFilename(prefix_);
@@ -741,7 +743,17 @@ BundleReader::BundleReader(Env* env, StringPiece prefix)
   status_ = env_->NewRandomAccessFile(filename, &wrapper);
   if (!status_.ok()) return;
   metadata_ = wrapper.release();
-  status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
+
+  table::Options o;
+  int64 cache_size;
+  Status s =
+      ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
+  if (s.ok() && cache_size > 0) {
+    index_cache_ = table::NewLRUCache(cache_size << 20);
+    o.block_cache = index_cache_;
+  }
+
+  status_ = table::Table::Open(o, metadata_, file_size, &table_);
   if (!status_.ok()) return;
   iter_ = table_->NewIterator();
 
@@ -772,6 +784,9 @@ BundleReader::~BundleReader() {
   delete metadata_;
   delete iter_;
   delete table_;
+  if (index_cache_) {
+    delete index_cache_;
+  }
   // InputBuffer does not own the underlying RandomAccessFile.
   for (auto pair : data_) {
     if (pair.second != nullptr && pair.second->file() != nullptr) {
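
A small sketch of how the new table-index cache might be enabled around checkpoint loading; checkpoint_prefix is a placeholder, and the environment variable is read once, when the BundleReader is constructed:

    // Put a 64 MiB LRU cache in front of the metadata table's index blocks.
    // Leaving the variable unset (or 0, the default) keeps the old uncached path.
    setenv("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", "64", /*overwrite=*/1);
    tensorflow::BundleReader reader(tensorflow::Env::Default(), checkpoint_prefix);
    TF_CHECK_OK(reader.status());
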
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index e1f39eccd17..c362dd41151 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -61,8 +61,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
 #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
 
-#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
-
 #include <map>
 #include <string>
 #include <unordered_map>
@@ -72,12 +70,14 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/io/cache.h"
 #include "tensorflow/core/lib/io/inputbuffer.h"
 #include "tensorflow/core/lib/io/table.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
 #include "tensorflow/core/util/tensor_bundle/naming.h"
 #include "tensorflow/core/util/tensor_slice_set.h"
 
@@ -288,6 +288,7 @@ class BundleReader {
   Status status_;
   RandomAccessFile* metadata_;  // Owned.
   table::Table* table_;
+  table::Cache* index_cache_;
   table::Iterator* iter_;
   // Owned the InputBuffer objects and their underlying RandomAccessFile's.
   std::unordered_map<int32, io::InputBuffer*> data_;
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 449a95765a5..ecdce1e627b 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -11611,7 +11611,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -11868,7 +11868,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2
 //
 // value: The cropped area of the image must have an aspect ratio =
 // width / height within this range.
-// If not specified, defaults to {f:0.75 f:1.33}
+// If not specified, defaults to {f:0.75  f:1.33}
 func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
 	return func(m optionalAttr) {
 		m["aspect_ratio_range"] = value
@@ -11879,7 +11879,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
 //
 // value: The cropped area of the image must contain a fraction of the
 // supplied image within this range.
-// If not specified, defaults to {f:0.05 f:1}
+// If not specified, defaults to {f:0.05  f:1}
 func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
 	return func(m optionalAttr) {
 		m["area_range"] = value
@@ -12085,7 +12085,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo
 //
 // value: The cropped area of the image must have an aspect ratio =
 // width / height within this range.
-// If not specified, defaults to {f:0.75 f:1.33}
+// If not specified, defaults to {f:0.75  f:1.33}
 func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
 	return func(m optionalAttr) {
 		m["aspect_ratio_range"] = value
@@ -12096,7 +12096,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
 //
 // value: The cropped area of the image must contain a fraction of the
 // supplied image within this range.
-// If not specified, defaults to {f:0.05 f:1}
+// If not specified, defaults to {f:0.05  f:1}
 func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
 	return func(m optionalAttr) {
 		m["area_range"] = value
@@ -18937,7 +18937,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr {
 // ImageSummaryBadColor sets the optional bad_color attribute to value.
 //
 // value: Color to use for pixels with non-finite values.
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
+// If not specified, defaults to {dtype:DT_UINT8  tensor_shape:{dim:{size:4}}  int_val:255  int_val:0  int_val:0  int_val:255}
 func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
 	return func(m optionalAttr) {
 		m["bad_color"] = value
@@ -20077,7 +20077,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr {
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
 func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -21345,7 +21345,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr {
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22053,7 +22053,7 @@ func Conv2DDataFormat(value string) Conv2DAttr {
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func Conv2DDilations(value []int64) Conv2DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22249,7 +22249,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy
 // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22318,7 +22318,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized
 // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22433,7 +22433,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi
 // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22492,7 +22492,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D
 // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22666,7 +22666,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann
 // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
 //
 // value: list of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22857,7 +22857,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr {
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
 func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25297,7 +25297,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi
 type Conv3DBackpropFilterAttr func(optionalAttr)
 
 // Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
 func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25629,7 +25629,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25679,7 +25679,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil
 type Conv3DBackpropInputAttr func(optionalAttr)
 
 // Conv3DBackpropInputDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
 func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25929,7 +25929,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr {
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -26559,7 +26559,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr {
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -27624,7 +27624,7 @@ func Conv3DDataFormat(value string) Conv3DAttr {
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
 func Conv3DDilations(value []int64) Conv3DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -45536,7 +45536,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr {
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1  i:1  i:1  i:1}
 func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index d09e75e052d..2aeff13b3be 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -301,6 +301,7 @@ cc_library(
         ":model_hints",
         ":opencl_wrapper",
         ":precision",
+        ":storage_type_util",
         ":tensor_type",
         "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector",
@@ -387,6 +388,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "storage_type_util",
+    srcs = ["storage_type_util.cc"],
+    hdrs = ["storage_type_util.h"],
+    deps = [
+        ":cl_context",
+        ":cl_device",
+        ":tensor_type",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+    ],
+)
+
 cc_library(
     name = "tensor",
     srcs = ["tensor.cc"],
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index a2a66cae0c9..6b0511fb267 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/cl/precision.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management.h"
@@ -109,64 +110,6 @@ void AddUsage(ValueId id, int task_index,
   }
 }
 
-TensorStorageType SelectBestStorageType(const CLContext& context,
-                                        const CLDevice& device,
-                                        const BHWC& shape,
-                                        const TensorStorageType& desired,
-                                        const DataType& data_type,
-                                        const Layout& layout) {
-  if (CanCreateTensorWithShape(context, device, shape,
-                               TensorDescriptor{data_type, desired, layout})) {
-    return desired;
-  }
-  auto GetBestTypeAfterTextureArray = [&]() {
-    if (device.SupportsImageBuffer() &&
-        CanCreateTensorWithShape(
-            context, device, shape,
-            TensorDescriptor{data_type, TensorStorageType::IMAGE_BUFFER,
-                             layout})) {
-      return TensorStorageType::IMAGE_BUFFER;
-    } else {
-      return TensorStorageType::BUFFER;
-    }
-  };
-  auto GetBestTypeAfterTexture2D = [&]() {
-    if (device.SupportsTextureArray() &&
-        CanCreateTensorWithShape(
-            context, device, shape,
-            TensorDescriptor{data_type, TensorStorageType::TEXTURE_ARRAY,
-                             layout})) {
-      return TensorStorageType::TEXTURE_ARRAY;
-    } else {
-      return GetBestTypeAfterTextureArray();
-    }
-  };
-  auto GetBestTypeAfterTexture3D = [&]() {
-    if (CanCreateTensorWithShape(
-            context, device, shape,
-            TensorDescriptor{data_type, TensorStorageType::TEXTURE_2D,
-                             layout})) {
-      return TensorStorageType::TEXTURE_2D;
-    } else {
-      return GetBestTypeAfterTexture2D();
-    }
-  };
-  switch (desired) {
-    case TensorStorageType::TEXTURE_2D:
-    case TensorStorageType::SINGLE_TEXTURE_2D:
-      return GetBestTypeAfterTexture2D();
-    case TensorStorageType::TEXTURE_ARRAY:
-      return GetBestTypeAfterTextureArray();
-    case TensorStorageType::TEXTURE_3D:
-      return GetBestTypeAfterTexture3D();
-    case TensorStorageType::IMAGE_BUFFER:
-    case TensorStorageType::BUFFER:
-      return TensorStorageType::BUFFER;
-    default:
-      return TensorStorageType::BUFFER;
-  }
-}
-
 // returns true if actual memory for this storage type will be allocated with
 // clCreateBuffer.
 bool IsBufferBased(const TensorStorageType& type) {
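
SelectBestStorageType moves into the new storage_type_util target; a hedged sketch of a call site, with the shape and channel count as placeholders:

    // Falls back from the desired TEXTURE_ARRAY towards IMAGE_BUFFER and finally
    // BUFFER when the requested shape cannot be allocated on this device.
    TensorStorageType storage = SelectBestStorageType(
        *creation_context.context, *creation_context.device,
        BHWC(1, 36, tiles_x * tiles_y, channels),  // placeholder shape
        TensorStorageType::TEXTURE_ARRAY, DataType::FLOAT16, Layout::HWC);
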
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc
index aa071910658..78e9795bc63 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc
@@ -69,6 +69,63 @@ std::string GenerateAsyncUpload(const std::string& local_ptr_name,
        offset + ", " + std::to_string(elements_to_upload) + ", 0);\n";
   return c;
 }
+
+std::string GenerateBlockCoords(const int3& block_size,
+                                const int3& work_group_launch_order,
+                                bool linear_hw) {
+  std::string c;
+  int3 launch_remap;
+  launch_remap[work_group_launch_order.x] = 0;
+  launch_remap[work_group_launch_order.y] = 1;
+  launch_remap[work_group_launch_order.z] = 2;
+  if (linear_hw) {
+    if (work_group_launch_order[0] == 0) {
+      c += "  int linear_hw = get_global_id(0);\n";
+    } else {
+      c += "  int linear_hw = get_group_id(" + std::to_string(launch_remap[0]) +
+           ") * get_local_size(0) + get_local_id(0);\n";
+    }
+    c += "  int Y = (linear_hw / task_size_x) * " +
+         std::to_string(block_size.y) + ";\n";
+    c += "  int X = (linear_hw % task_size_x) * " +
+         std::to_string(block_size.x) + ";\n";
+    if (work_group_launch_order[1] == 1) {
+      c += "  int Z = get_global_id(1) * " + std::to_string(block_size.z) +
+           ";\n";
+    } else {
+      c += "  int Z = (get_group_id(" + std::to_string(launch_remap[1]) +
+           ") * get_local_size(1) + get_local_id(1)) * " +
+           std::to_string(block_size.z) + ";\n";
+    }
+  } else {
+    if (work_group_launch_order[0] == 0) {
+      c += "  int X = get_global_id(0) * " + std::to_string(block_size.x) +
+           ";\n";
+    } else {
+      c += "  int X = (get_group_id(" + std::to_string(launch_remap[0]) +
+           ") * get_local_size(0) + get_local_id(0)) * " +
+           std::to_string(block_size.x) + ";\n";
+    }
+    if (work_group_launch_order[1] == 1) {
+      c += "  int Y = get_global_id(1) * " + std::to_string(block_size.y) +
+           ";\n";
+    } else {
+      c += "  int Y = (get_group_id(" + std::to_string(launch_remap[1]) +
+           ") * get_local_size(1) + get_local_id(1)) * " +
+           std::to_string(block_size.y) + ";\n";
+    }
+    if (work_group_launch_order[2] == 2) {
+      c += "  int Z = get_global_id(2) * " + std::to_string(block_size.z) +
+           ";\n";
+    } else {
+      c += "  int Z = (get_group_id(" + std::to_string(launch_remap[2]) +
+           ") * get_local_size(2) + get_local_id(2)) * " +
+           std::to_string(block_size.z) + ";\n";
+    }
+  }
+
+  return c;
+}
 }  // namespace
 
 ConvPowerVR::ConvPowerVR(const OperationDef& definition,
@@ -146,6 +203,11 @@ Status ConvPowerVR::BindArguments() {
         int4(kernel_dilation_.x, kernel_dilation_.y,
              kernel_dilation_.z * src_[0]->Batch(), kernel_dilation_.w)));
   }
+  if (conv_params_.linear_hw) {
+    const int grid_x = IntegralDivideRoundUp(
+        dst_[0]->Width() * dst_[0]->Batch(), conv_params_.block_size.x);
+    RETURN_IF_ERROR(kernel_.SetBytesAuto(grid_x));
+  }
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
   return OkStatus();
@@ -159,15 +221,27 @@ int3 ConvPowerVR::GetGridSize() const {
   const int grid_z =
       IntegralDivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.z);
   int3 wg;
-  wg.x = IntegralDivideRoundUp(grid_x, conv_params_.work_group_size.x);
-  wg.y = IntegralDivideRoundUp(grid_y, conv_params_.work_group_size.y);
-  wg.z = IntegralDivideRoundUp(grid_z, conv_params_.work_group_size.z);
-  return int3(wg[conv_params_.work_group_launch_order[0]] *
-                  conv_params_.work_group_size.x,
-              wg[conv_params_.work_group_launch_order[1]] *
-                  conv_params_.work_group_size.y,
-              wg[conv_params_.work_group_launch_order[2]] *
-                  conv_params_.work_group_size.z);
+
+  if (conv_params_.linear_hw) {
+    wg.x =
+        IntegralDivideRoundUp(grid_x * grid_y, conv_params_.work_group_size.x);
+    wg.y = IntegralDivideRoundUp(grid_z, conv_params_.work_group_size.y);
+    return int3(wg[conv_params_.work_group_launch_order[0]] *
+                    conv_params_.work_group_size.x,
+                wg[conv_params_.work_group_launch_order[1]] *
+                    conv_params_.work_group_size.y,
+                1);
+  } else {
+    wg.x = IntegralDivideRoundUp(grid_x, conv_params_.work_group_size.x);
+    wg.y = IntegralDivideRoundUp(grid_y, conv_params_.work_group_size.y);
+    wg.z = IntegralDivideRoundUp(grid_z, conv_params_.work_group_size.z);
+    return int3(wg[conv_params_.work_group_launch_order[0]] *
+                    conv_params_.work_group_size.x,
+                wg[conv_params_.work_group_launch_order[1]] *
+                    conv_params_.work_group_size.y,
+                wg[conv_params_.work_group_launch_order[2]] *
+                    conv_params_.work_group_size.z);
+  }
 }
 
 Status ConvPowerVR::Tune(const TuningParameters& params) {
@@ -248,33 +322,22 @@ std::string GenerateConvPowerVR1x1(
     c += "    int4 stride_padding,           \n";
     c += "    int4 kernel_dilation,          \n";
   }
+  if (conv_params.linear_hw) {
+    c += "    int task_size_x,               \n";
+  }
   c += "    int4 src_size,                   \n";
   c += "    int4 dst_size                    \n";
   c += ") {\n";
-  int3 launch_remap;
-  launch_remap[conv_params.work_group_launch_order.x] = 0;
-  launch_remap[conv_params.work_group_launch_order.y] = 1;
-  launch_remap[conv_params.work_group_launch_order.z] = 2;
-  if (conv_params.work_group_launch_order[0] == 0) {
-    c += "  int X = get_global_id(0) * " + std::to_string(block_size.x) + ";\n";
-  } else {
-    c += "  int X = (get_group_id(" + std::to_string(launch_remap[0]) +
-         ") * get_local_size(0) + get_local_id(0)) * " +
-         std::to_string(block_size.x) + ";\n";
+  c += GenerateBlockCoords(conv_params.block_size,
+                           conv_params.work_group_launch_order,
+                           conv_params.linear_hw);
+  std::vector<std::string> dst_x(conv_params.block_size.x);
+  for (int x = 0; x < conv_params.block_size.x; ++x) {
+    dst_x[x] = "(X + " + std::to_string(x) + ")";
   }
-  if (conv_params.work_group_launch_order[1] == 1) {
-    c += "  int Y = get_global_id(1) * " + std::to_string(block_size.y) + ";\n";
-  } else {
-    c += "  int Y = (get_group_id(" + std::to_string(launch_remap[1]) +
-         ") * get_local_size(1) + get_local_id(1)) * " +
-         std::to_string(block_size.y) + ";\n";
-  }
-  if (conv_params.work_group_launch_order[2] == 2) {
-    c += "  int Z = get_global_id(2) * " + std::to_string(block_size.z) + ";\n";
-  } else {
-    c += "  int Z = (get_group_id(" + std::to_string(launch_remap[2]) +
-         ") * get_local_size(2) + get_local_id(2)) * " +
-         std::to_string(block_size.z) + ";\n";
+  std::vector<std::string> dst_y(conv_params.block_size.y);
+  for (int y = 0; y < conv_params.block_size.y; ++y) {
+    dst_y[y] = "(Y + " + std::to_string(y) + ")";
   }
   if (!need_local_mem) {
     c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) {\n";
@@ -283,8 +346,12 @@ std::string GenerateConvPowerVR1x1(
   }
   if (conv_params.weights_upload_type ==
       ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
-    c += "  int lid = get_local_id(1) * " + std::to_string(work_group_size.x) +
-         " + get_local_id(0);\n";
+    if (conv_params.linear_hw) {
+      c += "  int lid = get_local_id(0);\n";
+    } else {
+      c += "  int lid = get_local_id(1) * " +
+           std::to_string(work_group_size.x) + " + get_local_id(0);\n";
+    }
   }
   for (int z = 0; z < block_size.z; ++z) {
     for (int y = 0; y < block_size.y; ++y) {
@@ -296,20 +363,18 @@ std::string GenerateConvPowerVR1x1(
   }
   if (!is1x1) {
     for (int x = 0; x < block_size.x; ++x) {
-      const std::string xc = "(X + " + std::to_string(x) + ")";
       if (stride_correction) {
         c += "  int xc" + std::to_string(x) + " = " +
-             GetXStrideCorrected(xc, "src_size.w", "stride_padding.x",
+             GetXStrideCorrected(dst_x[x], "src_size.w", "stride_padding.x",
                                  "stride_padding.z") +
              ";\n";
       } else {
-        c += "  int xc" + std::to_string(x) + " = " + xc +
+        c += "  int xc" + std::to_string(x) + " = " + dst_x[x] +
              " * stride_padding.x + stride_padding.z;\n";
       }
     }
     for (int y = 0; y < block_size.y; ++y) {
-      const std::string yc = "(Y + " + std::to_string(y) + ")";
-      c += "  int yc" + std::to_string(y) + " = " + yc +
+      c += "  int yc" + std::to_string(y) + " = " + dst_y[y] +
            " * stride_padding.y + stride_padding.w;\n";
     }
   }
@@ -373,10 +438,8 @@ std::string GenerateConvPowerVR1x1(
       const std::string yck = "yck" + std::to_string(y);
       for (int x = 0; x < block_size.x; ++x) {
         const std::string xck = "xck" + std::to_string(x);
-        std::string xc =
-            is1x1 ? "min(X + " + std::to_string(x) + ", src_size.x - 1)" : xck;
-        std::string yc =
-            is1x1 ? "min(Y + " + std::to_string(y) + ", src_size.y - 1)" : yck;
+        std::string xc = is1x1 ? "min(" + dst_x[x] + ", src_size.x - 1)" : xck;
+        std::string yc = is1x1 ? "min(" + dst_y[y] + ", src_size.y - 1)" : yck;
         std::string id = std::to_string(y) + std::to_string(x);
         c += "  int src_a_" + id + " = " + yc + " * src_size.x + " + xc + ";\n";
       }
@@ -408,10 +471,8 @@ std::string GenerateConvPowerVR1x1(
           c += "    src_a_" + id + " += src_layer_offset;\n";
         } else {
           std::string id = std::to_string(y) + std::to_string(x);
-          const std::string xc =
-              is1x1 ? "X + " + std::to_string(x) : "xck" + std::to_string(x);
-          const std::string yc =
-              is1x1 ? "Y + " + std::to_string(y) : "yck" + std::to_string(y);
+          const std::string xc = is1x1 ? dst_x[x] : "xck" + std::to_string(x);
+          const std::string yc = is1x1 ? dst_y[y] : "yck" + std::to_string(y);
           c += "    src" + id + " = " +
                src_tensor.ReadAsTypeWHS(conv_params.weights_data_type, xc, yc,
                                         "s", mode) +
@@ -522,8 +583,8 @@ std::string GenerateConvPowerVR1x1(
     c += "    FLT4 bias_val = TO_FLT4(weights_cache[" + sz + "]);\n";
     for (int y = 0; y < block_size.y; ++y) {
       for (int x = 0; x < block_size.x; ++x) {
-        const std::string xs = "X + " + std::to_string(x);
-        const std::string ys = "Y + " + std::to_string(y);
+        const std::string xs = dst_x[x];
+        const std::string ys = dst_y[y];
         const std::string zs = "Z + " + sz;
         const std::string r_id = sz + std::to_string(y) + std::to_string(x);
         bool need_x_check = x != 0;
@@ -552,18 +613,27 @@ std::string GenerateConvPowerVR1x1(
 
 ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
     const CLDevice& device, const OperationDef& definition, int src_depth,
-    int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1) const {
+    int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
+    bool different_weights_for_height) const {
   ConvParams conv_params;
+  conv_params.linear_hw = false;
   conv_params.weights_data_type =
       DeduceDataTypeFromPrecision(definition.precision);
   conv_params.x_kernel_is_1 = x_kernel_is_1;
   conv_params.y_kernel_is_1 = y_kernel_is_1;
-  conv_params.different_weights_for_height = false;
+  conv_params.different_weights_for_height = different_weights_for_height;
   if (device.IsNvidia()) {
+    if (different_weights_for_height) {
+      conv_params.work_group_size = int3(32, 1, 1);
+      conv_params.work_group_launch_order = int3(2, 0, 1);
+      conv_params.fixed_work_group_size = true;
+    } else {
+      conv_params.linear_hw = true;
+      conv_params.work_group_size = int3(32, 1, 1);
+      conv_params.work_group_launch_order = int3(1, 0, 2);
+      conv_params.fixed_work_group_size = true;
+    }
     conv_params.block_size = int3(1, 1, 4);
-    conv_params.work_group_size = int3(8, 4, 1);
-    conv_params.work_group_launch_order = int3(2, 0, 1);
-    conv_params.fixed_work_group_size = true;
     conv_params.src_depth_loop_size = 1;
     conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
     if (dst_depth % 4 == 0 || dst_depth >= 8) {
@@ -580,13 +650,20 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
       conv_params.src_depth_loop_size = 4;
     }
   } else if (device.IsPowerVR()) {
+    if (different_weights_for_height) {
+      conv_params.work_group_size = int3(32, 1, 1);
+      conv_params.work_group_launch_order = int3(2, 0, 1);
+      conv_params.fixed_work_group_size = true;
+    } else {
+      conv_params.linear_hw = true;
+      conv_params.work_group_size = int3(32, 1, 1);
+      conv_params.work_group_launch_order = int3(1, 0, 2);
+      conv_params.fixed_work_group_size = true;
+    }
     conv_params.weights_data_type =
         definition.precision == CalculationsPrecision::F16 ? DataType::FLOAT16
                                                            : DataType::FLOAT32;
     conv_params.block_size = int3(1, 1, 4);
-    conv_params.work_group_size = int3(8, 4, 1);
-    conv_params.work_group_launch_order = int3(2, 0, 1);
-    conv_params.fixed_work_group_size = true;
     conv_params.src_depth_loop_size = 1;
     conv_params.weights_upload_type =
         WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
@@ -619,16 +696,22 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
         }
       }
       conv_params.block_size.x = 2;
-      conv_params.work_group_size = int3(4, 8, 1);
     }
   } else if (device.IsAMD()) {
+    if (different_weights_for_height) {
+      conv_params.work_group_size = int3(32, 1, 1);
+      conv_params.work_group_launch_order = int3(2, 0, 1);
+      conv_params.fixed_work_group_size = true;
+    } else {
+      conv_params.work_group_size = int3(8, 4, 1);
+      conv_params.work_group_launch_order = int3(2, 0, 1);
+      conv_params.fixed_work_group_size = true;
+    }
+
     conv_params.block_size = int3(2, 1, 1);
     if (x_kernel_is_1 && y_kernel_is_1) {
       conv_params.block_size.y = 2;
     }
-    conv_params.work_group_size = int3(8, 4, 1);
-    conv_params.work_group_launch_order = int3(2, 0, 1);
-    conv_params.fixed_work_group_size = true;
     conv_params.src_depth_loop_size = 1;
     conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM;
     if (dst_depth % 8 == 0 || dst_depth >= 32) {
@@ -694,7 +777,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
                              attr.padding.prepended.h == 0 &&
                              attr.padding.appended.h == 0;
   return GuessBestParams(device, definition, src_depth, dst_depth,
-                         x_kernel_is_1, y_kernel_is_1);
+                         x_kernel_is_1, y_kernel_is_1, false);
 }
 
 ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
@@ -702,8 +785,8 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
     const FullyConnectedAttributes& attr) const {
   const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
   const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
-  ConvPowerVR::ConvParams params =
-      GuessBestParams(device, definition, src_depth, dst_depth, true, true);
+  ConvPowerVR::ConvParams params = GuessBestParams(
+      device, definition, src_depth, dst_depth, true, true, false);
   params.work_group_size.x *= params.work_group_size.y;
   params.work_group_size.y = 1;
   params.block_size.x *= params.block_size.y;
@@ -716,13 +799,10 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(
     const Convolution2DAttributes& attr) const {
   const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
   const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
-  ConvPowerVR::ConvParams params =
-      GuessBestParams(device, definition, src_depth, dst_depth, true, true);
-  params.work_group_size.x *= params.work_group_size.y;
-  params.work_group_size.y = 1;
+  ConvPowerVR::ConvParams params = GuessBestParams(
+      device, definition, src_depth, dst_depth, true, true, true);
   params.block_size.x *= params.block_size.y;
   params.block_size.y = 1;
-  params.different_weights_for_height = true;
   return params;
 }
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h
index 0832fcf91f0..110b983940a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h
@@ -70,6 +70,7 @@ class ConvPowerVR : public GPUOperation {
     int3 work_group_size;
     int3 work_group_launch_order;
     bool fixed_work_group_size;
+    bool linear_hw;
     bool different_weights_for_height;
     int src_depth_loop_size;
     WeightsUploadType weights_upload_type;
@@ -127,7 +128,8 @@ class ConvPowerVR : public GPUOperation {
   ConvParams GuessBestParams(const CLDevice& device,
                              const OperationDef& definition, int src_depth,
                              int dst_depth, bool x_kernel_is_1,
-                             bool y_kernel_is_1) const;
+                             bool y_kernel_is_1,
+                             bool different_weights_for_height) const;
 
   Status BindArguments();
   int3 GetGridSize() const;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc
index a09236f77fc..7a2e54840b9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc
@@ -248,13 +248,8 @@ Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel,
                         const int3& grid, int3* best_work_group) {
   switch (params.tuning_type) {
     case TuningType::FAST:
-      if (params.info->vendor != Vendor::QUALCOMM) {
-        *best_work_group = int3(8, 4, 1);
-        return OkStatus();
-      } else {
-        *best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize());
-        return OkStatus();
-      }
+      *best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize());
+      return OkStatus();
     case TuningType::EXHAUSTIVE:
       return GetBestWorkGroupAlignedToGrid(params, kernel, grid,
                                            best_work_group);
@@ -268,16 +263,16 @@ Status GetBestWorkGroupConv(const TuningParameters& params,
                             const CLKernel& kernel, const int3& grid,
                             int3* best_work_group) {
   switch (params.tuning_type) {
-    case TuningType::FAST:
-      if (params.info->vendor != Vendor::QUALCOMM) {
-        *best_work_group = int3(8, 4, 1);
-        return OkStatus();
-      } else {
-        int max_z_size = params.info->adreno_info.gpu_version < 400 ? 16 : 64;
-        *best_work_group =
-            GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size);
-        return OkStatus();
+    case TuningType::FAST: {
+      int max_z_size = 16;
+      if (params.info->vendor == Vendor::QUALCOMM) {
+        max_z_size = params.info->adreno_info.gpu_version < 400 ? 16 : 64;
       }
+      max_z_size = std::min(max_z_size, params.info->max_work_group_sizes.z);
+      *best_work_group =
+          GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size);
+      return OkStatus();
+    }
     case TuningType::EXHAUSTIVE:
       return GetBestWorkGroupAlignedToGrid(params, kernel, grid,
                                            best_work_group);
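
A hedged numeric illustration of the FAST-tuning change: on a non-Qualcomm GPU the work group is no longer hard-coded to int3(8, 4, 1) but derived from the actual limits, e.g.

    // kernel.GetMaxWorkGroupSize() == 256, params.info->max_work_group_sizes.z == 16
    // old: *best_work_group = int3(8, 4, 1);
    // new: *best_work_group = GetWorkGroupConv(grid, 256, /*max_z_size=*/16);
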
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
index 695077ba8a8..908c1b91583 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
@@ -83,7 +83,9 @@ cc_library(
         ":dw_convolution_selector",
         ":fully_connected_selector",
         ":simple_selectors",
+        "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:model_hints",
+        "//tensorflow/lite/delegates/gpu/cl:storage_type_util",
         "//tensorflow/lite/delegates/gpu/cl:tensor_type",
         "//tensorflow/lite/delegates/gpu/cl/kernels:elementwise",
         "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
@@ -119,8 +121,10 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/cl/kernels:resize",
         "//tensorflow/lite/delegates/gpu/cl/kernels:softmax",
         "//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:space_to_depth",
         "//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice",
         "//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:winograd",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
index 9e7876cee80..0103ca08b90 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
@@ -47,6 +47,20 @@ Status SelectConvolutionAdreno(const Convolution2DAttributes& attr,
   return OkStatus();
 }
 
+Status SelectConvolutionWinogradAdreno(const Convolution2DAttributes& attr,
+                                       const BHWC& dst_shape,
+                                       const CreationContext& creation_context,
+                                       const OperationDef& op_def,
+                                       ModelHints hints,
+                                       std::unique_ptr<GPUOperation>* ptr) {
+  ConvTexture conv;
+  RETURN_IF_ERROR(
+      CreateConvTextureWino4x4To6x6(creation_context, op_def, attr, &conv));
+  *ptr = absl::make_unique<ConvTexture>(std::move(conv));
+
+  return OkStatus();
+}
+
 Status SelectConvolutionNVidia(const Convolution2DAttributes& attr,
                                const CreationContext& creation_context,
                                const OperationDef& op_def,
@@ -89,6 +103,25 @@ Status SelectConvolutionMali(const Convolution2DAttributes& attr,
   }
   return OkStatus();
 }
+
+Status SelectConvolutionWinogradMali(const Convolution2DAttributes& attr,
+                                     const CreationContext& creation_context,
+                                     const OperationDef& op_def,
+                                     std::unique_ptr<GPUOperation>* ptr) {
+  if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
+    ConvBuffer1x1 conv;
+    RETURN_IF_ERROR(
+        CreateConvBuffer1x1Wino4x4To6x6(creation_context, op_def, attr, &conv));
+    *ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv));
+  } else {
+    ConvPowerVR conv;
+    RETURN_IF_ERROR(
+        CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv));
+    *ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
+  }
+
+  return OkStatus();
+}
 }  // namespace
 
 Status SelectConvolution(const Convolution2DAttributes& attr,
@@ -113,6 +146,33 @@ Status SelectConvolution(const Convolution2DAttributes& attr,
   }
 }
 
+Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr,
+                                    const BHWC& dst_shape,
+                                    const CreationContext& creation_context,
+                                    const OperationDef& op_def,
+                                    ModelHints hints,
+                                    std::unique_ptr<GPUOperation>* ptr) {
+  switch (creation_context.device->vendor()) {
+    case Vendor::QUALCOMM:
+      return SelectConvolutionWinogradAdreno(attr, dst_shape, creation_context,
+                                             op_def, hints, ptr);
+    case Vendor::POWERVR:
+    case Vendor::AMD:
+    case Vendor::NVIDIA: {
+      ConvPowerVR conv;
+      RETURN_IF_ERROR(
+          CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv));
+      *ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
+      return OkStatus();
+    }
+    case Vendor::MALI:
+      return SelectConvolutionWinogradMali(attr, creation_context, op_def, ptr);
+    default:
+      return SelectConvolutionWinogradAdreno(attr, dst_shape, creation_context,
+                                             op_def, hints, ptr);
+  }
+}
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
index 7dd6c79eea0..dc0657ec47c 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
@@ -34,6 +34,13 @@ Status SelectConvolution(const Convolution2DAttributes& attr,
                          const OperationDef& op_def, ModelHints hints,
                          std::unique_ptr<GPUOperation>* ptr);
 
+Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr,
+                                    const BHWC& dst_shape,
+                                    const CreationContext& creation_context,
+                                    const OperationDef& op_def,
+                                    ModelHints hints,
+                                    std::unique_ptr<GPUOperation>* ptr);
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
index 527c94b1fe2..29c246a2744 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
@@ -17,12 +17,14 @@ limitations under the License.
 
 #include "absl/strings/str_cat.h"
 #include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h"
+#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
@@ -51,6 +53,111 @@ bool IsChannelsBroadcastedForSecondInput(
          inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
          inputs[1]->tensor.shape.c == 1;
 }
+
+bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
+                                   const CLDevice& device,
+                                   const BHWC& dst_shape) {
+  const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4);
+  const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4);
+  const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
+  const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
+  const bool suitable_attributes =
+      attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
+      attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
+  const int min_depth = 32;
+  const bool recommended_channels =
+      dst_depth % 4 == 0 && src_depth >= min_depth && dst_depth >= min_depth;
+  const bool recommended_hw = tiles_x * tiles_y >= 128;
+  return suitable_attributes && recommended_channels && recommended_hw;
+}
+
+Status WinogradFromNode(const CreationContext& creation_context,
+                        const OperationDef& op_def, ModelHints hints,
+                        const BHWC& input_shape, const BHWC& output_shape,
+                        const Convolution2DAttributes& attr,
+                        GPUOperationsSubgraph* gpu_subgraph) {
+  if (!IsSuitableForWinograd4x4To6x6(attr, *creation_context.device,
+                                     output_shape)) {
+    return UnimplementedError("No implementation for this case.");
+  }
+
+  const int tiles_x = IntegralDivideRoundUp(output_shape.w, 4);
+  const int tiles_y = IntegralDivideRoundUp(output_shape.h, 4);
+  const BHWC shape_0{input_shape.b, 36, tiles_x * tiles_y, input_shape.c};
+  const BHWC shape_1{input_shape.b, 36, tiles_x * tiles_y, output_shape.c};
+  TensorDescriptor td_0;
+  td_0.storage_type = SelectBestStorageType(
+      *creation_context.context, *creation_context.device, shape_0,
+      op_def.src_tensors[0].storage_type, op_def.src_tensors[0].data_type,
+      op_def.src_tensors[0].layout);
+  td_0.data_type = op_def.src_tensors[0].data_type;
+  td_0.layout = op_def.src_tensors[0].layout;
+  TensorDescriptor td_1;
+  td_1.storage_type = SelectBestStorageType(
+      *creation_context.context, *creation_context.device, shape_1,
+      op_def.src_tensors[0].storage_type, op_def.src_tensors[0].data_type,
+      op_def.src_tensors[0].layout);
+  td_1.data_type = op_def.src_tensors[0].data_type;
+  td_1.layout = op_def.src_tensors[0].layout;
+  gpu_subgraph->new_tensors = {{shape_0, td_0}, {shape_1, td_1}};
+  gpu_subgraph->operations.clear();
+  gpu_subgraph->operations.resize(3);
+
+  OperationDef winograd_up_def;
+  winograd_up_def.precision = op_def.precision;
+  winograd_up_def.src_tensors.push_back(op_def.src_tensors[0]);
+  winograd_up_def.dst_tensors.push_back(td_0);
+  auto& winograd_up = gpu_subgraph->operations[0];
+  RETURN_IF_ERROR(SelectWinograd4x4To36(
+      creation_context, attr.padding, winograd_up_def, &winograd_up.operation));
+  winograd_up.input_ids = {0};
+  winograd_up.output_ids = {-1};
+
+  OperationDef conv_def;
+  conv_def.precision = op_def.precision;
+  conv_def.src_tensors.push_back(td_0);
+  conv_def.dst_tensors.push_back(td_1);
+  auto& conv = gpu_subgraph->operations[1];
+  conv.input_ids = {-1};
+  conv.output_ids = {-2};
+  RETURN_IF_ERROR(SelectConvolutionForWinograd(
+      attr, input_shape, creation_context, conv_def, hints, &conv.operation));
+
+  OperationDef winograd_down_def;
+  winograd_down_def.precision = op_def.precision;
+  winograd_down_def.src_tensors.push_back(td_1);
+  winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
+  auto& winograd_down = gpu_subgraph->operations[2];
+  winograd_down.input_ids = {-2};
+  winograd_down.output_ids = {0};
+  auto bias_copy = attr.bias;
+  if (bias_copy.shape.v < attr.weights.shape.o) {
+    bias_copy.shape = Linear(attr.weights.shape.o);
+    bias_copy.data.resize(attr.weights.shape.o);
+  }
+  RETURN_IF_ERROR(SelectWinograd36To4x4(creation_context, winograd_down_def,
+                                        bias_copy, &winograd_down.operation));
+
+  return OkStatus();
+}
+
+std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
+    const std::vector<Value<TensorRef<BHWC>>*>& inputs,
+    const std::vector<Value<TensorRef<BHWC>>*>& outputs,
+    GPUOperationsSubgraph* gpu_subgraph) {
+  gpu_subgraph->operations.clear();
+  gpu_subgraph->new_tensors.clear();
+  gpu_subgraph->operations.push_back({});
+  for (int i = 0; i < inputs.size(); ++i) {
+    gpu_subgraph->operations[0].input_ids.push_back(i);
+  }
+  for (int i = 0; i < outputs.size(); ++i) {
+    gpu_subgraph->operations[0].output_ids.push_back(i);
+  }
+
+  return &gpu_subgraph->operations[0].operation;
+}
+
 }  // namespace
 
 Status GPUOperationFromNode(const CreationContext& creation_context,
@@ -59,15 +166,8 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
                             const std::vector<Value<TensorRef<BHWC>>*>& outputs,
                             const Node& node,
                             GPUOperationsSubgraph* gpu_subgraph) {
-  gpu_subgraph->operations.push_back({});
   std::unique_ptr<GPUOperation>* gpu_op =
-      &gpu_subgraph->operations[0].operation;
-  for (int i = 0; i < inputs.size(); ++i) {
-    gpu_subgraph->operations[0].input_ids.push_back(i);
-  }
-  for (int i = 0; i < outputs.size(); ++i) {
-    gpu_subgraph->operations[0].output_ids.push_back(i);
-  }
+      InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
   auto op_type = OperationTypeFromString(node.operation.type);
   switch (op_type) {
     case OperationType::ADD: {
@@ -111,9 +211,17 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
     case OperationType::CONVOLUTION_2D: {
       auto attr =
           absl::any_cast<Convolution2DAttributes>(node.operation.attributes);
-      auto input = inputs[0];
-      return SelectConvolution(attr, input->tensor.shape, creation_context,
-                               op_def, hints, gpu_op);
+      auto input_shape = inputs[0]->tensor.shape;
+      auto output_shape = outputs[0]->tensor.shape;
+      if (WinogradFromNode(creation_context, op_def, hints, input_shape,
+                           output_shape, attr, gpu_subgraph)
+              .ok()) {
+        return OkStatus();
+      } else {
+        gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
+        return SelectConvolution(attr, input_shape, creation_context, op_def,
+                                 hints, gpu_op);
+      }
     }
     case OperationType::CONVOLUTION_TRANSPOSED: {
       auto attr = absl::any_cast<ConvolutionTransposedAttributes>(
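
A worked instance of the IsSuitableForWinograd4x4To6x6 heuristic above, with hypothetical sizes:

    // 3x3 convolution, stride 1, dilation 1, input 64x64x128 -> output 64x64x128:
    //   tiles_x = ceil(64 / 4) = 16, tiles_y = 16 -> 16 * 16 = 256 >= 128  (recommended_hw)
    //   src_depth = ceil(128 / 4) = 32 >= 32
    //   dst_depth = 32 >= 32 and 32 % 4 == 0                               (recommended_channels)
    // WinogradFromNode() then builds the three-op subgraph (4x4->36 transform, the
    // convolution on the 36 x (tiles_x * tiles_y) intermediate, 36->4x4 transform);
    // otherwise GPUOperationFromNode() falls back to a single SelectConvolution() op.
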
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
index 92ef85b1779..22244351bd7 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
+#include "tensorflow/lite/delegates/gpu/cl/kernels/winograd.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 
 namespace tflite {
@@ -195,6 +196,28 @@ void SelectTranspose(const TransposeAttributes& attr,
   *ptr = absl::make_unique<Transpose>(std::move(operation));
 }
 
+Status SelectWinograd4x4To36(const CreationContext& creation_context,
+                             const Padding2D& padding,
+                             const OperationDef& op_def,
+                             std::unique_ptr<GPUOperation>* ptr) {
+  Winograd4x4To36 operation;
+  RETURN_IF_ERROR(
+      CreateWinograd4x4To36(creation_context, op_def, padding, &operation));
+  *ptr = absl::make_unique<Winograd4x4To36>(std::move(operation));
+  return OkStatus();
+}
+
+Status SelectWinograd36To4x4(
+    const CreationContext& creation_context, const OperationDef& op_def,
+    const ::tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
+    std::unique_ptr<GPUOperation>* ptr) {
+  Winograd36To4x4 operation;
+  RETURN_IF_ERROR(
+      CreateWinograd36To4x4(creation_context, op_def, biases, &operation));
+  *ptr = absl::make_unique<Winograd36To4x4>(std::move(operation));
+  return OkStatus();
+}
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
index 5a830994f75..fd29ebc0e91 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
@@ -90,6 +90,16 @@ void SelectTranspose(const TransposeAttributes& attr,
                      const OperationDef& op_def,
                      std::unique_ptr<GPUOperation>* ptr);
 
+Status SelectWinograd4x4To36(const CreationContext& creation_context,
+                             const Padding2D& padding,
+                             const OperationDef& op_def,
+                             std::unique_ptr<GPUOperation>* ptr);
+
+Status SelectWinograd36To4x4(
+    const CreationContext& creation_context, const OperationDef& op_def,
+    const ::tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
+    std::unique_ptr<GPUOperation>* ptr);
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc
new file mode 100644
index 00000000000..26eb3ad3538
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc
@@ -0,0 +1,141 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
+
+#include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
+#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
+#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
+                              const BHWDC& shape,
+                              const TensorDescriptor& descriptor) {
+  const int slices = IntegralDivideRoundUp(shape.c, 4);
+  switch (descriptor.storage_type) {
+    case TensorStorageType::BUFFER: {
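+      // Size in bytes of one FLT4 value: 16 for FLOAT32, 8 for FLOAT16.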
+      const int flt4_size =
+          4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2);
+      const int buffer_size =
+          shape.b * shape.w * shape.h * shape.d * slices * flt4_size;
+      return buffer_size <= device.GetInfo().buffer_max_size;
+    }
+    case TensorStorageType::IMAGE_BUFFER:
+      return shape.b * shape.w * shape.h * shape.d * slices <=
+             device.GetInfo().image_buffer_max_size;
+    case TensorStorageType::TEXTURE_3D:
+      if (device.cl_version() < OpenCLVersion::CL_1_2 && slices == 1) {
+        // clCreateImage3D (used in CL 1.0/1.1) cannot create an image with
+        // depth = 1 per the specification.
+        return false;
+      }
+      return shape.w * shape.b <= device.GetInfo().image3d_max_width &&
+             shape.h <= device.GetInfo().image3d_max_height &&
+             slices * shape.d <= device.GetInfo().image3d_max_depth;
+    case TensorStorageType::TEXTURE_ARRAY:
+      // Bug on some Adreno. b/131099086
+      if (slices == 1 && !device.SupportsOneLayerTextureArray()) {
+        return false;
+      }
+      return shape.w * shape.b <= device.GetInfo().image2d_max_width &&
+             shape.h <= device.GetInfo().image2d_max_height &&
+             slices * shape.d <= device.GetInfo().image_array_max_layers;
+    case TensorStorageType::TEXTURE_2D:
+      return shape.w * shape.b * shape.d <=
+                 device.GetInfo().image2d_max_width &&
+             shape.h * slices <= device.GetInfo().image2d_max_height;
+    case TensorStorageType::SINGLE_TEXTURE_2D:
+      return shape.c <= 4 &&
+             context.IsFloatTexture2DSupported(shape.c, descriptor.data_type) &&
+             shape.w * shape.b * shape.d <=
+                 device.GetInfo().image2d_max_width &&
+             shape.h <= device.GetInfo().image2d_max_height;
+    default:
+      return false;
+  }
+}
+
+bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
+                              const BHWC& shape,
+                              const TensorDescriptor& descriptor) {
+  const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
+  return CanCreateTensorWithShape(context, device, shape5D, descriptor);
+}
+
+TensorStorageType SelectBestStorageType(const CLContext& context,
+                                        const CLDevice& device,
+                                        const BHWC& shape,
+                                        const TensorStorageType& desired,
+                                        const DataType& data_type,
+                                        const Layout& layout) {
+  if (CanCreateTensorWithShape(context, device, shape,
+                               TensorDescriptor{data_type, desired, layout})) {
+    return desired;
+  }
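+  // The desired storage type cannot hold this shape; fall back to the next
+  // storage type in the chain TEXTURE_3D -> TEXTURE_2D -> TEXTURE_ARRAY ->
+  // IMAGE_BUFFER -> BUFFER that is supported by the device and fits the
+  // tensor. BUFFER is the final fallback.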
+  auto GetBestTypeAfterTextureArray = [&]() {
+    if (device.SupportsImageBuffer() &&
+        CanCreateTensorWithShape(
+            context, device, shape,
+            TensorDescriptor{data_type, TensorStorageType::IMAGE_BUFFER,
+                             layout})) {
+      return TensorStorageType::IMAGE_BUFFER;
+    } else {
+      return TensorStorageType::BUFFER;
+    }
+  };
+  auto GetBestTypeAfterTexture2D = [&]() {
+    if (device.SupportsTextureArray() &&
+        CanCreateTensorWithShape(
+            context, device, shape,
+            TensorDescriptor{data_type, TensorStorageType::TEXTURE_ARRAY,
+                             layout})) {
+      return TensorStorageType::TEXTURE_ARRAY;
+    } else {
+      return GetBestTypeAfterTextureArray();
+    }
+  };
+  auto GetBestTypeAfterTexture3D = [&]() {
+    if (CanCreateTensorWithShape(
+            context, device, shape,
+            TensorDescriptor{data_type, TensorStorageType::TEXTURE_2D,
+                             layout})) {
+      return TensorStorageType::TEXTURE_2D;
+    } else {
+      return GetBestTypeAfterTexture2D();
+    }
+  };
+  switch (desired) {
+    case TensorStorageType::TEXTURE_2D:
+    case TensorStorageType::SINGLE_TEXTURE_2D:
+      return GetBestTypeAfterTexture2D();
+    case TensorStorageType::TEXTURE_ARRAY:
+      return GetBestTypeAfterTextureArray();
+    case TensorStorageType::TEXTURE_3D:
+      return GetBestTypeAfterTexture3D();
+    case TensorStorageType::IMAGE_BUFFER:
+    case TensorStorageType::BUFFER:
+      return TensorStorageType::BUFFER;
+    default:
+      return TensorStorageType::BUFFER;
+  }
+}
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.h b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h
new file mode 100644
index 00000000000..87fc2206e81
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_
+
+#include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
+#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
+#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+
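+// Returns true if a tensor with the given shape and descriptor fits the
+// storage limits (max buffer size, max image dimensions, etc.) reported by
+// the device and context for the descriptor's storage type.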
+bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
+                              const BHWDC& shape,
+                              const TensorDescriptor& descriptor);
+
+bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
+                              const BHWC& shape,
+                              const TensorDescriptor& descriptor);
+
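+// Returns `desired` if a tensor with the given shape can be created with it;
+// otherwise returns a fallback storage type that can hold the shape, ending
+// with BUFFER.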
+TensorStorageType SelectBestStorageType(const CLContext& context,
+                                        const CLDevice& device,
+                                        const BHWC& shape,
+                                        const TensorStorageType& desired,
+                                        const DataType& data_type,
+                                        const Layout& layout);
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index 8423613440e..610ba407eb9 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -331,60 +331,6 @@ Status Tensor::ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const {
   return ReadDataBHWDC(absl::MakeSpan(dst->data), queue);
 }
 
-bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
-                              const BHWC& shape,
-                              const TensorDescriptor& descriptor) {
-  const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
-  return CanCreateTensorWithShape(context, device, shape5D, descriptor);
-}
-
-bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
-                              const BHWDC& shape,
-                              const TensorDescriptor& descriptor) {
-  const int slices = IntegralDivideRoundUp(shape.c, 4);
-  switch (descriptor.storage_type) {
-    case TensorStorageType::BUFFER: {
-      const int flt4_size =
-          4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2);
-      const int buffer_size =
-          shape.b * shape.w * shape.h * shape.d * slices * flt4_size;
-      return buffer_size <= device.GetInfo().buffer_max_size;
-    }
-    case TensorStorageType::IMAGE_BUFFER:
-      return shape.b * shape.w * shape.h * shape.d * slices <=
-             device.GetInfo().image_buffer_max_size;
-    case TensorStorageType::TEXTURE_3D:
-      if (device.cl_version() < OpenCLVersion::CL_1_2 && slices == 1) {
-        // clCreateImage3D (that used in CL 1.0/1.1) can not create image with
-        // depth = 1 by specification;
-        return false;
-      }
-      return shape.w * shape.b <= device.GetInfo().image3d_max_width &&
-             shape.h <= device.GetInfo().image3d_max_height &&
-             slices * shape.d <= device.GetInfo().image3d_max_depth;
-    case TensorStorageType::TEXTURE_ARRAY:
-      // Bug on some Adreno. b/131099086
-      if (slices == 1 && !device.SupportsOneLayerTextureArray()) {
-        return false;
-      }
-      return shape.w * shape.b <= device.GetInfo().image2d_max_width &&
-             shape.h <= device.GetInfo().image2d_max_height &&
-             slices * shape.d <= device.GetInfo().image_array_max_layers;
-    case TensorStorageType::TEXTURE_2D:
-      return shape.w * shape.b * shape.d <=
-                 device.GetInfo().image2d_max_width &&
-             shape.h * slices <= device.GetInfo().image2d_max_height;
-    case TensorStorageType::SINGLE_TEXTURE_2D:
-      return shape.c <= 4 &&
-             context.IsFloatTexture2DSupported(shape.c, descriptor.data_type) &&
-             shape.w * shape.b * shape.d <=
-                 device.GetInfo().image2d_max_width &&
-             shape.h <= device.GetInfo().image2d_max_height;
-    default:
-      return false;
-  }
-}
-
 Status CreateTensor(const CLContext& context, const CLDevice& device,
                     const BHWC& shape, const TensorDescriptor& descriptor,
                     Tensor* result) {
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index efc09480a39..34a45436386 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -145,14 +145,6 @@ class Tensor {
 
 using TensorPtr = std::shared_ptr<Tensor>;
 
-bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
-                              const BHWC& shape,
-                              const TensorDescriptor& descriptor);
-
-bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device,
-                              const BHWDC& shape,
-                              const TensorDescriptor& descriptor);
-
 Status AllocateTensorMemory(const CLContext& context, const CLDevice& device,
                             const BHWC& shape,
                             const TensorDescriptor& descriptor,
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc
index 05f30329215..430aae2443a 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc
@@ -41,6 +41,10 @@ TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   int tensor_id;
 
   // Input data tensors.
+  // input_bound_minimum & input_bound_maximum track the smallest data_min and
+  // the largest data_max seen across all input tensors.
+  float input_bound_minimum = std::numeric_limits<float>::max();
+  float input_bound_maximum = std::numeric_limits<float>::lowest();
   input_minima_.reserve(inputs->size);
   input_maxima_.reserve(inputs->size);
   for (int i = 0; i < inputs->size; ++i) {
@@ -53,6 +57,8 @@ TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
         std::numeric_limits<uint8_t>::max()));
     input_minima_.push_back(data_min);
     input_maxima_.push_back(data_max);
+    if (data_min < input_bound_minimum) input_bound_minimum = data_min;
+    if (data_max > input_bound_maximum) input_bound_maximum = data_max;
   }
 
   // Minima tensors.
@@ -96,19 +102,27 @@ TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   auto* output_max_const = graph_builder_->AddConstNodeWithData(
       quant_bound_shape, (char*)&output_max_, sizeof(output_max_));
 
-  auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID());
-  requantize_op->SetOpType(OP_Requantize_8to8);
-  requantize_op->AddInput(concat_out);
-  requantize_op->AddInput(concat_out_min);
-  requantize_op->AddInput(concat_out_max);
-  requantize_op->AddInput(TensorID(output_min_const->GetID(), 0));
-  requantize_op->AddInput(TensorID(output_max_const->GetID(), 0));
-  node_output_ =
-      requantize_op->AddOutput(sizeof(uint8_t), 4,
-                               {output_batch_size, output_height_size,
-                                output_width_size, output_depth_size});
-  requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1});
-  requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+  if (output_min_ == input_bound_minimum &&
+      output_max_ == input_bound_maximum) {
+    // If the input min/max (across all tensors) are the same as the output
+    // min/max, Hexagon's Requantize causes errors in InceptionV3.
+    // TODO(b/150137234): Figure out why this is.
+    node_output_ = concat_out;
+  } else {
+    auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID());
+    requantize_op->SetOpType(OP_Requantize_8to8);
+    requantize_op->AddInput(concat_out);
+    requantize_op->AddInput(concat_out_min);
+    requantize_op->AddInput(concat_out_max);
+    requantize_op->AddInput(TensorID(output_min_const->GetID(), 0));
+    requantize_op->AddInput(TensorID(output_max_const->GetID(), 0));
+    node_output_ =
+        requantize_op->AddOutput(sizeof(uint8_t), 4,
+                                 {output_batch_size, output_height_size,
+                                  output_width_size, output_depth_size});
+    requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+    requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+  }
 
   return kTfLiteOk;
 }
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc
index bc4026d795b..c66ae17a71a 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc
@@ -12,12 +12,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <random>
+
 #include <gtest/gtest.h>
 #include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"
 
 namespace tflite {
 using testing::ElementsAreArray;
 
+void GenerateUniformRandomVector(int size, float min, float max,
+                                 std::minstd_rand* random_engine,
+                                 std::vector<float>* result) {
+  // Never use std::uniform_*_distribution in tests: its output is
+  // implementation-defined. Likewise, avoid std::default_random_engine, which
+  // is also implementation-defined. Implementation-defined behavior is bad
+  // because any toolchain update or new platform may change the generated
+  // values and break the tests.
+  // std::minstd_rand is a standard instantiation of
+  // std::linear_congruential_engine, the cheapest generator in the C++11
+  // stdlib, and it's good enough here.
+  result->resize(size);
+  for (int i = 0; i < size; i++) {
+    // We don't care whether the `max` value can ever be produced exactly.
+    // It may actually be, thanks to rounding, as std::minstd_rand::modulus
+    // (2^31 - 1) is greater than the inverse float epsilon.
+    float random_value_scaled_0_1 =
+        (*random_engine)() *
+        (1.0f / static_cast<float>(std::minstd_rand::modulus));
+    (*result)[i] = min + (max - min) * random_value_scaled_0_1;
+  }
+}
+
 class QuantizedConcatenationOpModel : public SingleOpModelWithHexagon {
  public:
   QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
@@ -37,7 +61,7 @@ class QuantizedConcatenationOpModel : public SingleOpModelWithHexagon {
   }
 
   template <typename T>
-  void SetInput(int index, std::initializer_list<float> data) {
+  void SetInput(int index, std::vector<float> data) {
     QuantizeAndPopulate<T>(index, data);
   }
 
@@ -95,4 +119,50 @@ TEST(QuantizedConcatenationOpModel, FourInputsQuantizedMixedRange) {
                   /*max_abs_error=*/0.2)));
 }
 
+// If the input min/max (across all tensors) are the same as the output
+// min/max, Hexagon's Requantize causes errors in InceptionV3.
+// So, we disable it for that case in the builder.
+// This unit test ensures that the math still works.
+TEST(QuantizedConcatenationOpModel, FourInputsQuantizedMixedRange_LargeData) {
+  // Problem specification.
+  // Adapted from CONCAT node at #15 in Inceptionv3 quantized.
+  std::vector<float> params1 = {0, 11.30514f};
+  std::vector<float> params2 = {0, 10.38416f};
+  std::vector<float> params3 = {0, 13.52495f};
+  std::vector<float> params4 = {0, 5.883808f};
+  std::vector<float> params_output = {0, 13.52495f};
+  QuantizedConcatenationOpModel m0(
+      {{TensorType_UINT8, {1, 35, 35, 64}, params1[0], params1[1]},
+       {TensorType_UINT8, {1, 35, 35, 64}, params2[0], params2[1]},
+       {TensorType_UINT8, {1, 35, 35, 96}, params3[0], params3[1]},
+       {TensorType_UINT8, {1, 35, 35, 32}, params4[0], params4[1]}},
+      /*axis=*/3, {TensorType_UINT8, {}, params_output[0], params_output[1]});
+
+  // Generate random data.
+  std::minstd_rand random_engine;
+  std::vector<float> data1, data2, data3, data4;
+  int num_elements_multiplier = 1 * 35 * 35;
+  GenerateUniformRandomVector(num_elements_multiplier * 64, params1[0],
+                              params1[1], &random_engine, &data1);
+  GenerateUniformRandomVector(num_elements_multiplier * 64, params2[0],
+                              params2[1], &random_engine, &data2);
+  GenerateUniformRandomVector(num_elements_multiplier * 96, params3[0],
+                              params3[1], &random_engine, &data3);
+  GenerateUniformRandomVector(num_elements_multiplier * 32, params4[0],
+                              params4[1], &random_engine, &data4);
+  m0.SetInput<uint8_t>(0, data1);
+  m0.SetInput<uint8_t>(1, data2);
+  m0.SetInput<uint8_t>(2, data3);
+  m0.SetInput<uint8_t>(3, data4);
+
+  // Reference output.
+  m0.Invoke();
+  std::vector<float> reference_output = m0.GetDequantizedOutput<uint8_t>();
+
+  m0.ApplyDelegateAndInvoke();
+  EXPECT_THAT(m0.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(reference_output,
+                                              /*max_abs_error=*/0.1)));
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc
index cba74b3df4f..b0507693e35 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc
@@ -132,6 +132,30 @@ class HexagonDelegate : public TfLiteDelegate {
     if (hexagon_nn == nullptr) {
       return false;
     }
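+    // Verify that the interface library (libhexagon_interface) and the
+    // hexagon_nn library report the same version; mismatched versions are
+    // rejected.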
+    if (hexagon_nn->hexagon_nn_version != nullptr &&
+        hexagon_nn->hexagon_nn_hexagon_interface_version != nullptr) {
+      int hexagon_nn_version = -1;
+      int hexagon_interface_version =
+          hexagon_nn->hexagon_nn_hexagon_interface_version();
+      if (hexagon_nn->hexagon_nn_version(&hexagon_nn_version) != 0) {
+        TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING,
+                        "Failed to fetch Hexagon NN version. This might be "
+                        "because you're using incompatible versions of "
+                        "libhexagon_interface and libhexagon_nn_skel. "
+                        "You must use compatible versions. "
+                        "Refer to the TensorFlow Lite Hexagon Delegate Guide.");
+        return false;
+      }
+      if (hexagon_nn_version != hexagon_interface_version) {
+        TFLITE_LOG_PROD(
+            tflite::TFLITE_LOG_WARNING,
+            "Incompatible versions between interface library and "
+            "libhexagon_nn_skel %d vs %d. You must use compatible versions. "
+            "Refer to the TensorFlow Lite Hexagon Delegate Guide.",
+            hexagon_interface_version, hexagon_nn_version);
+        return false;
+      }
+    }
     return hexagon_nn->hexagon_nn_is_device_supported &&
            hexagon_nn->hexagon_nn_is_device_supported();
   }
diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc
index cd86a95c86f..3cf64496e00 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc
@@ -76,6 +76,8 @@ HexagonNN CreateNewHexagonInterface() {
   LOAD_FUNCTION(libhexagon_interface, hexagon_nn_is_device_supported,
                 hexagon_nn);
   LOAD_FUNCTION(libhexagon_interface, hexagon_nn_version, hexagon_nn);
+  LOAD_FUNCTION(libhexagon_interface, hexagon_nn_hexagon_interface_version,
+                hexagon_nn);
   hexagon_nn.interface_loaded = successfully_loaded;
   return hexagon_nn;
 }
diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h
index b41f55b0a42..fb81a650205 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h
+++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h
@@ -126,6 +126,9 @@ struct HexagonNN {
   // Otherwise.
   hexagon_nn_is_device_supported_fn* hexagon_nn_is_device_supported;
 
+  // Returns the version number of the interface library.
+  hexagon_nn_hexagon_interface_version_fn* hexagon_nn_hexagon_interface_version;
+
   hexagon_nn_version_fn* hexagon_nn_version = nullptr;
 
   bool interface_loaded = false;
diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h
index 812eb792a5c..d142edd4d03 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h
+++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h
@@ -21,6 +21,7 @@ extern "C" {
 void hexagon_nn_global_teardown(void);
 void hexagon_nn_global_init(void);
 bool hexagon_nn_is_device_supported();
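+// Returns the version number of the interface library.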
+int hexagon_nn_hexagon_interface_version(void);
 #ifdef __cplusplus
 }
 #endif
diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds
index 1e254e7eb0e..7b003afc770 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds
+++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds
@@ -19,6 +19,7 @@ VERS_1.0 {
     hexagon_nn_global_init;
     hexagon_nn_is_device_supported;
     hexagon_nn_version;
+    hexagon_nn_hexagon_interface_version;
 
   # Hide everything else.
   local:
diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h
index 74f6ee11de0..cfb61a59182 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h
+++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h
@@ -56,4 +56,7 @@ using hexagon_nn_is_device_supported_fn =
 
 using hexagon_nn_version_fn = decltype(hexagon_nn_version);
 
+using hexagon_nn_hexagon_interface_version_fn =
+    decltype(hexagon_nn_hexagon_interface_version);
+
 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_HEXAGON_NN_INTERFACE_H_
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
index c7662d149e9..4a1e7d4a65e 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
@@ -24,6 +24,7 @@ import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.nio.MappedByteBuffer;
 import java.nio.channels.FileChannel;
+import java.nio.charset.Charset;
 import java.util.ArrayList;
 import java.util.List;
 import org.checkerframework.checker.nullness.qual.NonNull;
@@ -46,10 +47,28 @@ public class FileUtil {
   @NonNull
   public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
       throws IOException {
+    return loadLabels(context, filePath, Charset.defaultCharset());
+  }
+
+  /**
+   * Loads labels from the label file into a list of strings.
+   *
+   * <p>A legal label file is a plain text file whose contents are split into lines, and each line
+   * is an individual value. The file should be stored in the assets of the context.
+   *
+   * @param context The context that holds the assets.
+   * @param filePath The path of the label file, relative to the assets directory.
+   * @param cs the {@code Charset} to use when decoding the content of the label file.
+   * @return a list of labels.
+   * @throws IOException if an error occurs while opening or reading the file.
+   */
+  @NonNull
+  public static List<String> loadLabels(
+      @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
     SupportPreconditions.checkNotNull(context, "Context cannot be null.");
     SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
     InputStream inputStream = context.getAssets().open(filePath);
-    return loadLabels(inputStream);
+    return loadLabels(inputStream, cs);
   }
 
   /**
@@ -62,8 +81,23 @@ public class FileUtil {
    */
   @NonNull
   public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
+    return loadLabels(inputStream, Charset.defaultCharset());
+  }
+
+  /**
+   * Loads labels from an input stream of an opened label file. See details for label files in
+   * {@link FileUtil#loadLabels(Context, String)}.
+   *
+   * @param inputStream the input stream of an opened label file.
+   * @param cs the {@code Charset} to use when decoding the content of the label file.
+   * @return a list of labels.
+   * @throws IOException if an error occurs while opening or reading the file.
+   */
+  @NonNull
+  public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
+      throws IOException {
     List<String> labels = new ArrayList<>();
-    BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
+    BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs));
     String line;
     while ((line = reader.readLine()) != null) {
       labels.add(line);
@@ -72,6 +106,38 @@ public class FileUtil {
     return labels;
   }
 
+  /**
+   * Loads a vocabulary file (a single-column text file) into a list of strings.
+   *
+   * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
+   * and each line is an individual value. The file should be stored in the assets of the context.
+   *
+   * @param context The context that holds the assets.
+   * @param filePath The path of the vocabulary file, relative to the assets directory.
+   * @param cs the {@code Charset} to use when decoding the content of the vocabulary file.
+   * @return a list of vocabulary words.
+   * @throws IOException if an error occurs while opening or reading the file.
+   */
+  @NonNull
+  public static List<String> loadSingleColumnTextFile(
+      @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
+    return loadLabels(context, filePath, cs);
+  }
+
+  /**
+   * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
+   * text file). See details for vocabulary files in
+   * {@link FileUtil#loadSingleColumnTextFile(Context, String, Charset)}.
+   *
+   * @param inputStream the input stream of an opened vocabulary file.
+   * @param cs the {@code Charset} to use when decoding the content of the vocabulary file.
+   * @return a list of vocabulary words.
+   * @throws IOException if an error occurs while opening or reading the file.
+   */
+  @NonNull
+  public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
+      throws IOException {
+    return loadLabels(inputStream, cs);
+  }
+
   /**
    * Loads a file from the asset folder through memory mapping.
    *
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
index 7840fbfac6a..1881747870b 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
@@ -24,6 +24,12 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  * <p>Note: The data type of output tensor is always {@code FLOAT32} except when the DequantizeOp is
  * created effectively as an identity Op such as setting {@code zeroPoint} to 0 and {@code scale} to
  * 1 (in this case, the output tensor is the same instance as input).
+ *
+ * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link DequantizeOp} will be bypassed,
+ * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
+ * when passing in the quantization parameters that are extracted directly from the TFLite model
+ * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
+ * as 0.
  */
 public class DequantizeOp extends NormalizeOp implements TensorOperator {
 
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
index 25db461ede1..8ac57eed286 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
@@ -41,6 +41,11 @@ public class NormalizeOp implements TensorOperator {
    *   output = (input - mean) / stddev
    * </pre>
    *
+   * <p>In the following two cases, {@code stddev} is reset to 1 so that the normalization is
+   * bypassed (equivalent to {@code mean} = 0 and {@code stddev} = 1). <br>
+   * 1. Both {@code mean} and {@code stddev} are 0. <br>
+   * 2. {@code mean} is 0 and {@code stddev} is Infinity.
+   *
    * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
    * happen, and original input will be directly returned in execution.
    *
@@ -53,7 +58,28 @@ public class NormalizeOp implements TensorOperator {
    * @throws IllegalArgumentException if {@code stddev} is zero.
    */
   public NormalizeOp(float mean, float stddev) {
-    this(new float[] {mean}, new float[] {stddev});
+    // Make exceptions for the following cases:
+    // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
+    // from a tensor which does not have the values populated in the metadata. The same situation
+    // may also happen with the quantization parameters.
+    // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
+    // parameters from a tensor which does not have the values populated in the metadata, and then
+    // passing the parameters into the DequantizeOp.
+    // Bypass both cases by resetting stddev to 1.0f.
+    if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
+      stddev = 1.0f;
+    }
+
+    SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
+    boolean meansIsZeroAndDevsIs1 = false;
+    if (mean == 0.0f && stddev == 1.0f) {
+      meansIsZeroAndDevsIs1 = true;
+    }
+
+    this.isIdentityOp = meansIsZeroAndDevsIs1;
+    this.mean = new float[] {mean};
+    this.stddev = new float[] {stddev};
+    this.numChannels = 1;
   }
 
   /**
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
index 77a6559cb00..8b3e82aee13 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
@@ -25,6 +25,12 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  * math on top of input. The data type of output tensor is always {@code FLOAT32} except that the Op
  * is effectively an identity Op (in this case, the output tensor is the same instance as the
  * input). To connect with quantized model, a {@link CastOp} is probably needed.
+ *
+ * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link QuantizeOp} will be bypassed,
+ * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
+ * when passing in the quantization parameters that are extracted directly from the TFLite model
+ * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
+ * as 0.
  */
 public class QuantizeOp extends NormalizeOp implements TensorOperator {
 
diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index cf8e6d40f9f..bfd610b759a 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -201,7 +201,6 @@ java_test(
         "src/testdata/int32.bin",
         "src/testdata/int64.bin",
         "src/testdata/invalid_model.bin",
-        "src/testdata/quantized.bin",
         "src/testdata/string.bin",
         "src/testdata/uint8.bin",
         "src/testdata/with_custom_op.lite",
@@ -304,6 +303,7 @@ java_test(
         "src/testdata/add.bin",
         "src/testdata/int32.bin",
         "src/testdata/int64.bin",
+        "src/testdata/quantized.bin",
     ],
     javacopts = JAVACOPTS,
     tags = [
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index a6285895e4f..ca21ec5c7ea 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -254,24 +254,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
   }
 
-  /**
-   * Gets the quantization zero point of an output.
-   *
-   * @throws IllegalArgumentException if the output index is invalid.
-   */
-  int getOutputQuantizationZeroPoint(int index) {
-    return getOutputQuantizationZeroPoint(interpreterHandle, index);
-  }
-
-  /**
-   * Gets the quantization scale of an output.
-   *
-   * @throws IllegalArgumentException if the output index is invalid.
-   */
-  float getOutputQuantizationScale(int index) {
-    return getOutputQuantizationScale(interpreterHandle, index);
-  }
-
   /** Gets the number of input tensors. */
   int getInputTensorCount() {
     return inputTensors.length;
@@ -374,10 +356,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native int getOutputDataType(long interpreterHandle, int outputIdx);
 
-  private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx);
-
-  private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx);
-
   private static final int ERROR_BUFFER_SIZE = 512;
 
   private long errorHandle;
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 5d15b2c9a7e..95b9a41bf32 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -36,13 +36,55 @@ public final class Tensor {
   /**
    * Creates a Tensor wrapper from the provided interpreter instance and tensor index.
    *
-   * <p>The caller is responsible for closing the created wrapper, and ensuring the provided
-   * native interpreter is valid until the tensor is closed.
+   * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native
+   * interpreter is valid until the tensor is closed.
    */
   static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) {
     return new Tensor(create(nativeInterpreterHandle, tensorIndex));
   }
 
+  /**
+   * Quantization parameters that correspond to the table, {@code QuantizationParameters}, in the
+   * <a
+   * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
+   * Model schema file.</a>
+   *
+   * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
+   * {@code zero_point} are both single values instead of arrays.
+   *
+   * <p>For tensors that are not quantized, the values of scale and zero_point are both 0.
+   *
+   * <p>Given a quantized value q, the corresponding float value f should be: <br>
+   * f = scale * (q - zero_point) <br>
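+   *
+   * <p>For example, with {@code scale = 0.25f} and {@code zero_point = 127}, a quantized value
+   * {@code q = 131} corresponds to {@code f = 0.25f * (131 - 127) = 1.0f}.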
+   */
+  public static class QuantizationParams {
+    /** The scale value used in quantization. */
+    private final float scale;
+    /** The zero point value used in quantization. */
+    private final int zeroPoint;
+
+    /**
+     * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
+     *
+     * @param scale The scale value used in quantization.
+     * @param zeroPoint The zero point value used in quantization.
+     */
+    public QuantizationParams(final float scale, final int zeroPoint) {
+      this.scale = scale;
+      this.zeroPoint = zeroPoint;
+    }
+
+    /** Returns the scale value. */
+    public float getScale() {
+      return scale;
+    }
+
+    /** Returns the zero point value. */
+    public int getZeroPoint() {
+      return zeroPoint;
+    }
+  }
+
   /** Disposes of any resources used by the Tensor wrapper. */
   void close() {
     delete(nativeHandle);
@@ -114,6 +156,16 @@ public final class Tensor {
     return name(nativeHandle);
   }
 
+  /**
+   * Returns the quantization parameters of the tensor within the owning {@link Interpreter}.
+   *
+   * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensors that are not
+   * quantized, the values of scale and zero_point are both 0.
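+   *
+   * <p>Example (given an {@code Interpreter interpreter} and a raw quantized output value
+   * {@code quantizedValue}):
+   *
+   * <pre>{@code
+   * QuantizationParams params = interpreter.getOutputTensor(0).quantizationParams();
+   * float dequantized = params.getScale() * (quantizedValue - params.getZeroPoint());
+   * }</pre>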
+   */
+  public QuantizationParams quantizationParams() {
+    return quantizationParamsCopy;
+  }
+
   /**
    * Copies the contents of the provided {@code src} object to the Tensor.
    *
@@ -376,12 +428,14 @@ public final class Tensor {
   private final DataType dtype;
   private int[] shapeCopy;
   private final int[] shapeSignatureCopy;
+  private final QuantizationParams quantizationParamsCopy;
 
   private Tensor(long nativeHandle) {
     this.nativeHandle = nativeHandle;
     this.dtype = DataType.fromC(dtype(nativeHandle));
     this.shapeCopy = shape(nativeHandle);
     this.shapeSignatureCopy = shapeSignature(nativeHandle);
+    this.quantizationParamsCopy = quantizationParameters(nativeHandle);
   }
 
   private ByteBuffer buffer() {
@@ -413,4 +467,6 @@ public final class Tensor {
   private static native int index(long handle);
 
   private static native String name(long handle);
+
+  private static native QuantizationParams quantizationParameters(long handle);
 }
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 3701e07bd82..28d14b6da87 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -458,40 +458,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
   return static_cast<jint>(type);
 }
 
-JNIEXPORT jint JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint(
-    JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
-  if (interpreter == nullptr) return 0;
-  const int idx = static_cast<int>(output_idx);
-  if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
-    ThrowException(env, kIllegalArgumentException,
-                   "Failed to get %d-th output out of %d outputs", output_idx,
-                   interpreter->outputs().size());
-    return 0;
-  }
-  TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
-  return static_cast<jint>(target->params.zero_point);
-}
-
-JNIEXPORT jfloat JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale(
-    JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
-  if (interpreter == nullptr) return 1.0f;
-  const int idx = static_cast<int>(output_idx);
-  if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
-    ThrowException(env, kIllegalArgumentException,
-                   "Failed to get %d-th output out of %d outputs", output_idx,
-                   interpreter->outputs().size());
-    return 1.0f;
-  }
-  TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
-  return static_cast<jfloat>(target->params.scale);
-}
-
 JNIEXPORT jboolean JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 9a38e85acd1..1a50b3233ca 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -482,6 +482,26 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_index(JNIEnv* env,
   return GetTensorIndexFromHandle(env, handle);
 }
 
+JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_Tensor_quantizationParameters(JNIEnv* env,
+                                                       jclass clazz,
+                                                       jlong handle) {
+  const TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
+
+  // For tensors that are not quantized, the values of scale and zero_point are
+  // both 0.
+  jfloat scale = static_cast<jfloat>(tensor->params.scale);
+  jint zero_point = static_cast<jint>(tensor->params.zero_point);
+
+  jclass quantization_params_class =
+      env->FindClass("org/tensorflow/lite/Tensor$QuantizationParams");
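+  // "(FI)V" is the JNI signature of the
+  // QuantizationParams(float scale, int zeroPoint) constructor.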
+  jmethodID quantization_params_constructor =
+      env->GetMethodID(quantization_params_class, "<init>", "(FI)V");
+
+  return env->NewObject(quantization_params_class,
+                        quantization_params_constructor, scale, zero_point);
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 6d522197b2d..bab39793130 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -46,9 +46,6 @@ public final class NativeInterpreterWrapperTest {
   private static final String STRING_MODEL_PATH =
       "tensorflow/lite/java/src/testdata/string.bin";
 
-  private static final String QUANTIZED_MODEL_PATH =
-      "tensorflow/lite/java/src/testdata/quantized.bin";
-
   private static final String INVALID_MODEL_PATH =
       "tensorflow/lite/java/src/testdata/invalid_model.bin";
 
@@ -561,16 +558,4 @@ public final class NativeInterpreterWrapperTest {
       assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims);
     }
   }
-
-  @Test
-  public void testGetOutputQuantizationParams() {
-    try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
-      assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0);
-      assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f);
-    }
-    try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) {
-      assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127);
-      assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f);
-    }
-  }
 }
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 09e9b1cbc8f..f828f26f4c5 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -31,6 +31,7 @@ import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.tensorflow.lite.Tensor.QuantizationParams;
 
 /** Unit tests for {@link org.tensorflow.lite.Tensor}. */
 @RunWith(JUnit4.class)
@@ -45,6 +46,9 @@ public final class TensorTest {
   private static final String LONG_MODEL_PATH =
       "tensorflow/lite/java/src/testdata/int64.bin";
 
+  private static final String QUANTIZED_MODEL_PATH =
+      "tensorflow/lite/java/src/testdata/quantized.bin";
+
   private NativeInterpreterWrapper wrapper;
   private Tensor tensor;
 
@@ -451,4 +455,27 @@ public final class TensorTest {
       // Expected failure.
     }
   }
+
+  @Test
+  public void testQuantizationParameters_floatModel() {
+    QuantizationParams quantizationParams = tensor.quantizationParams();
+    float scale = quantizationParams.getScale();
+    long zeroPoint = quantizationParams.getZeroPoint();
+
+    assertThat(scale).isWithin(1e-6f).of(0.0f);
+    assertThat(zeroPoint).isEqualTo(0);
+  }
+
+  @Test
+  public void testQuantizationParameters_quantizedModel() {
+    wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH);
+    tensor = wrapper.getOutputTensor(0);
+
+    QuantizationParams quantizationParams = tensor.quantizationParams();
+    float scale = quantizationParams.getScale();
+    long zeroPoint = quantizationParams.getZeroPoint();
+
+    assertThat(scale).isWithin(1e-6f).of(0.25f);
+    assertThat(zeroPoint).isEqualTo(127);
+  }
 }
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h
index 6308131409f..9f967070413 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h
@@ -58,8 +58,6 @@ inline void ConvPerChannel(
   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
                            filter_width != 1 || filter_height != 1;
   const int8 input_zero_point = -input_offset;
-  TFLITE_DCHECK_GE(input_zero_point, output_activation_min);
-  TFLITE_DCHECK_LE(input_zero_point, output_activation_max);
   const uint8 zero_point_byte =
       *reinterpret_cast<const uint8*>(&input_zero_point);
   if (need_dilated_im2col) {
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h
index 26d7548e5c3..642d7577a1b 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h
@@ -1814,8 +1814,9 @@ inline void DepthwiseConvWithRounding(
   const auto ruy_paths = ruy_context != nullptr
                              ? ruy_context->GetRuntimeEnabledPaths()
                              : ruy::Path::kNone;
+  // TODO(b/150208140): Re-enable once erroneous activation in test is resolved.
   const bool has_dot_product_instructions =
-      (ruy_paths & ruy::Path::kNeonDotprod) != ruy::Path::kNone;
+      false && (ruy_paths & ruy::Path::kNeonDotprod) != ruy::Path::kNone;
 
   // Dispatch to dot-product 3x3 kernels when supported.
   if (has_dot_product_instructions) {
diff --git a/tensorflow/lite/kernels/internal/reference/strided_slice.h b/tensorflow/lite/kernels/internal/reference/strided_slice.h
index 921c49ea77b..ba6d4c22554 100644
--- a/tensorflow/lite/kernels/internal/reference/strided_slice.h
+++ b/tensorflow/lite/kernels/internal/reference/strided_slice.h
@@ -18,7 +18,6 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/types.h"
-
 namespace tflite {
 
 namespace reference_ops {
@@ -28,47 +27,60 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
                          const T* input_data,
                          const RuntimeShape& unextended_output_shape,
                          T* output_data) {
+  using strided_slice::LoopCondition;
+  using strided_slice::StartForAxis;
+  using strided_slice::StopForAxis;
   // Note that the output_shape is not used herein.
   tflite::StridedSliceParams params_copy = op_params;
 
-  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
   const RuntimeShape input_shape =
-      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+      RuntimeShape::ExtendedShape(5, unextended_input_shape);
   const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+      RuntimeShape::ExtendedShape(5, unextended_output_shape);
 
-  // Reverse and pad to 4 dimensions because that is what the runtime code
-  // requires (ie. all shapes must be 4D and are given backwards).
-  strided_slice::StridedSlicePadIndices(&params_copy, 4);
+  // Reverse and pad to 5 dimensions because that is what the runtime code
+  // requires (i.e. all shapes must be 5D and are given backwards).
+  strided_slice::StridedSlicePadIndices(&params_copy, 5);
 
-  const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
-  const int stop_b =
-      strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
-  const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
-  const int stop_h =
-      strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
-  const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
-  const int stop_w =
-      strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
-  const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
-  const int stop_d =
-      strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
+  const int start_0 = StartForAxis(params_copy, input_shape, 0);
+  const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
+  const int start_1 = StartForAxis(params_copy, input_shape, 1);
+  const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
+  const int start_2 = StartForAxis(params_copy, input_shape, 2);
+  const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
+  const int start_3 = StartForAxis(params_copy, input_shape, 3);
+  const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
+  const int start_4 = StartForAxis(params_copy, input_shape, 4);
+  const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
 
   T* out_ptr = output_data;
-  for (int in_b = start_b;
-       !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
-       in_b += params_copy.strides[0]) {
-    for (int in_h = start_h;
-         !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
-         in_h += params_copy.strides[1]) {
-      for (int in_w = start_w;
-           !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
-           in_w += params_copy.strides[2]) {
-        for (int in_d = start_d; !strided_slice::LoopCondition(
-                 in_d, stop_d, params_copy.strides[3]);
-             in_d += params_copy.strides[3]) {
-          *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
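+  // Accumulate a flattened offset across the outer axes so that the innermost
+  // loop can index input_data directly instead of recomputing Offset() for
+  // every element.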
+  for (int offset_0 = start_0 * input_shape.Dims(1),
+           end_0 = stop_0 * input_shape.Dims(1),
+           step_0 = params_copy.strides[0] * input_shape.Dims(1);
+       !LoopCondition(offset_0, end_0, params_copy.strides[0]);
+       offset_0 += step_0) {
+    for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
+             end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
+             step_1 = params_copy.strides[1] * input_shape.Dims(2);
+         !LoopCondition(offset_1, end_1, params_copy.strides[1]);
+         offset_1 += step_1) {
+      for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
+               end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
+               step_2 = params_copy.strides[2] * input_shape.Dims(3);
+           !LoopCondition(offset_2, end_2, params_copy.strides[2]);
+           offset_2 += step_2) {
+        for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
+                 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
+                 step_3 = params_copy.strides[3] * input_shape.Dims(4);
+             !LoopCondition(offset_3, end_3, params_copy.strides[3]);
+             offset_3 += step_3) {
+          for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
+               !LoopCondition(offset_4, end_4, params_copy.strides[4]);
+               offset_4 += params_copy.strides[4]) {
+            *out_ptr++ = input_data[offset_4];
+          }
         }
       }
     }
diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h
index 3022ac7b8e9..12dd33d3296 100644
--- a/tensorflow/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h
@@ -35,7 +35,7 @@ inline int Clamp(const int v, const int lo, const int hi) {
 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
                                    int dim_count) {
   // Add indices and mask bits to fully include extra dimensions
-  TFLITE_CHECK_LE(dim_count, 4);
+  TFLITE_CHECK_LE(dim_count, 5);
   TFLITE_CHECK_GE(dim_count, p->start_indices_count);
   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h
index e96e50209bb..f4d2b5c0dad 100644
--- a/tensorflow/lite/kernels/internal/types.h
+++ b/tensorflow/lite/kernels/internal/types.h
@@ -128,9 +128,9 @@ struct Dims {
 
 class RuntimeShape {
  public:
-  // Shapes with dimensions up to 4 are stored directly in the structure, while
+  // Shapes with dimensions up to 5 are stored directly in the structure, while
   // larger shapes are separately allocated.
-  static constexpr int kMaxSmallSize = 4;
+  static constexpr int kMaxSmallSize = 5;
 
   RuntimeShape& operator=(RuntimeShape const&) = delete;
 
@@ -207,8 +207,8 @@ class RuntimeShape {
   inline const int32* DimsData() const {
     return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
   }
-  // The caller must ensure that the shape is no bigger than 4-D.
-  inline const int32* DimsDataUpTo4D() const { return dims_; }
+  // The caller must ensure that the shape is no bigger than 5-D.
+  inline const int32* DimsDataUpTo5D() const { return dims_; }
 
   inline void Resize(int dimensions_count) {
     if (size_ > kMaxSmallSize) {
@@ -378,7 +378,7 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
 
 inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
   TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
-  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D());
+  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
   TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
   TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
   TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
@@ -1049,11 +1049,11 @@ struct SqueezeParams {
 
 struct StridedSliceParams {
   int8 start_indices_count;
-  int32 start_indices[4];
+  int32 start_indices[5];
   int8 stop_indices_count;
-  int32 stop_indices[4];
+  int32 stop_indices[5];
   int8 strides_count;
-  int32 strides[4];
+  int32 strides[5];
 
   int16 begin_mask;
   int16 ellipsis_mask;
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index e8eebd81025..51534375e5f 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -156,7 +156,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
              /* min_version */ 1,
-             /* max_version */ 3);
+             /* max_version */ 4);
   AddBuiltin(BuiltinOperator_EXP, Register_EXP());
   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
              /* min_version */ 1,
diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc
index ba39b016624..e2ca812d193 100644
--- a/tensorflow/lite/kernels/strided_slice.cc
+++ b/tensorflow/lite/kernels/strided_slice.cc
@@ -142,8 +142,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
-  TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
-                     "StridedSlice op only supports 1D-4D input arrays.");
+  TF_LITE_ENSURE_MSG(context, op_context.dims <= 5,
+                     "StridedSlice op only supports 1D-5D input arrays.");
 
   // TODO(b/138098220): Remove when bug is resolved.
   // Currently, working on using the compiler to cannonize strided_slice,
diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc
index 83093a09eed..8db98dba0e9 100644
--- a/tensorflow/lite/kernels/strided_slice_test.cc
+++ b/tensorflow/lite/kernels/strided_slice_test.cc
@@ -82,9 +82,9 @@ TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes);
 
 #ifdef GTEST_HAS_DEATH_TEST
 TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) {
-  EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0,
-                                              0, 0, 0, 0),
-               "StridedSlice op only supports 1D-4D input arrays.");
+  EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2, 2}, {5}, {5}, {5},
+                                              0, 0, 0, 0, 0),
+               "StridedSlice op only supports 1D-5D input arrays.");
 }
 
 TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) {
@@ -612,5 +612,29 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+
+TYPED_TEST(StridedSliceOpTest, In5D_Identity) {
+  StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
+                                   0);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBegin({0, 0, 0, 0, 0});
+  m.SetEnd({2, 1, 2, 1, 2});
+  m.SetStrides({1, 1, 1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1, 2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 9, 10, 11, 12}));
+}
+
+TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) {
+  StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
+                                   1);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBegin({0, 0, 0, 0, 0});
+  m.SetEnd({2, 1, 2, 1, 2});
+  m.SetStrides({1, 1, 1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
+}
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc b/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc
index f03f199d215..a658c2fad94 100644
--- a/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc
+++ b/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ifeq ($(TARGET), "esp")
+ifeq ($(TARGET), esp)
 
 # Adding some esp specific files in the main CMakeLists.txt
 ESP_MICRO_SPEECH_SRCS := \
@@ -26,34 +26,6 @@ CXXFLAGS += -Wno-return-type -Wno-strict-aliasing -Wno-ignored-qualifiers
 MICRO_SPEECH_SRCS += $(ESP_MICRO_SPEECH_SRCS)
 MICRO_SPEECH_HDRS += $(ESP_MICRO_SPEECH_HDRS)
 MAIN_SRCS += $(ESP_MICRO_SPEECH_SRCS)
-# Adding the microfrontend lib in the CMakeLists.txt of tfmicro
-PROJECT_INCLUDES += \
-tensorflow/lite/experimental/microfrontend/lib
-
-MICRO_FRONTEND_LIB_SRCS := \
-tensorflow/lite/experimental/microfrontend/lib/fft.cc \
-tensorflow/lite/experimental/microfrontend/lib/fft_util.cc \
-tensorflow/lite/experimental/microfrontend/lib/filterbank.c \
-tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c \
-tensorflow/lite/experimental/microfrontend/lib/frontend.c \
-tensorflow/lite/experimental/microfrontend/lib/frontend_util.c \
-tensorflow/lite/experimental/microfrontend/lib/log_lut.c \
-tensorflow/lite/experimental/microfrontend/lib/log_scale.c \
-tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c \
-tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c \
-tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c \
-tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c \
-tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c \
-tensorflow/lite/experimental/microfrontend/lib/window.c \
-tensorflow/lite/experimental/microfrontend/lib/window_util.c
-
-# Adding the micro frontend lib srcs into the CMakeLists.txt of tfmicro
-MICROLITE_CC_SRCS += $(MICRO_FRONTEND_LIB_SRCS)
-# Adding the kissfft srcs in the CMakeLists.txt of tfmicro
-THIRD_PARTY_CC_SRCS += $(KISSFFT_LIB_SRCS)
-# stopping microfrontend srcs from being included in the main srcs
-MICRO_SPEECH_SRCS := $(filter-out $(MICRO_FRONTEND_LIB_SRCS), $(MICRO_SPEECH_SRCS))
-MICRO_SPEECH_SRCS := $(filter-out $(KISSFFT_LIB_SRCS), $(MICRO_SPEECH_SRCS))
 
 ESP_PROJECT_FILES += \
 sdkconfig.defaults
diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc b/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc
index 5c23381c3cc..3596246d1e3 100644
--- a/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc
@@ -18,12 +18,17 @@ limitations under the License.
 #include <cstdlib>
 #include <cstring>
 
+// FreeRTOS.h must be included before some of the following dependencies.
+// Solves b/150260343.
+// clang-format off
+#include "freertos/FreeRTOS.h"
+// clang-format on
+
 #include "driver/i2s.h"
 #include "esp_log.h"
 #include "esp_spi_flash.h"
 #include "esp_system.h"
 #include "esp_timer.h"
-#include "freertos/FreeRTOS.h"
 #include "freertos/task.h"
 #include "ringbuf.h"
 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h"
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 1dc45f88cb9..5546dfc95ab 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -1,4 +1,3 @@
-
 ifneq (3.82,$(firstword $(sort $(MAKE_VERSION) 3.82)))
   $(error "Requires make version 3.82 or later (current is $(MAKE_VERSION))")
 endif
@@ -65,18 +64,24 @@ TEST_SCRIPT := tensorflow/lite/micro/testing/test_linux_binary.sh
 
 MICROLITE_LIBS := -lm
 
-# There are no rules for compiling objects for the host system (since we don't
-# generate things like the protobuf compiler that require that), so all of
-# these settings are for the target compiler.
-CXXFLAGS := -O3
-CXXFLAGS += -std=c++11 -g -DTF_LITE_STATIC_MEMORY
-CXXFLAGS += -fno-rtti
-CCFLAGS := -g -DTF_LITE_STATIC_MEMORY
-LDOPTS := -L/usr/local/lib
+# TODO(b/150240249): Add in -fno-rtti once that works for the Xtensa toolchain.
+CXXFLAGS := -std=c++11 -DTF_LITE_STATIC_MEMORY
+CCFLAGS  := -std=c11   -DTF_LITE_STATIC_MEMORY
 ARFLAGS := -r
 TARGET_TOOLCHAIN_PREFIX :=
 CC_PREFIX :=
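+
+# BUILD_TYPE (settable on the make command line) picks the remaining flags:
+# debug adds -DDEBUG -g, release adds -O3 -DNDEBUG and strips error strings,
+# and any other value falls back to plain -O3.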
 
+ifeq ($(BUILD_TYPE), debug)
+	CXXFLAGS += -DDEBUG -g
+	CCFLAGS  += -DDEBUG -g
+else ifeq ($(BUILD_TYPE), release)
+	CXXFLAGS += -DNDEBUG -O3 -DTF_LITE_STRIP_ERROR_STRINGS
+	CCFLAGS  += -DNDEBUG -O3 -DTF_LITE_STRIP_ERROR_STRINGS
+else
+	CXXFLAGS += -O3
+	CCFLAGS  += -O3
+endif
+
 # This library is the main target for this makefile. It will contain a minimal
 # runtime that can be linked in to other programs.
 MICROLITE_LIB_NAME := libtensorflow-microlite.a
diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc
index be6d943f4e9..40c8c1c0aeb 100644
--- a/tensorflow/lite/micro/tools/make/helper_functions.inc
+++ b/tensorflow/lite/micro/tools/make/helper_functions.inc
@@ -326,10 +326,15 @@ $(PRJDIR)$(2)/esp-idf/sdkconfig.defaults: tensorflow/lite/micro/examples/$(2)/es
 	@cp $$< $$@
 
 $(PRJDIR)$(2)/esp-idf/%: tensorflow/lite/micro/tools/make/templates/esp/%.tpl
-	$(eval MAIN_SRCS_RELATIVE := $(patsubst tensorflow/lite/micro/examples/$(2)/%,%,$(5)))
+  # Split the sources into two components:
+  # - The main component holds only the example's sources, relative to its directory.
+  $(eval MAIN_SRCS := $(filter tensorflow/lite/micro/examples/%,$(5)))
+  $(eval MAIN_SRCS_RELATIVE := $(patsubst tensorflow/lite/micro/examples/$(2)/%,%,$(MAIN_SRCS)))
+  # - TFL Micro component contains everything but the example sources.
+  $(eval TFLM_SRCS := $(filter-out tensorflow/lite/micro/examples/%,$(5)) $(3))
 
 	@mkdir -p $$(dir $$@)
-	@sed -E 's#\%\{COMPONENT_SRCS\}\%#$(3)#g' $$< | \
+	@sed -E 's#\%\{COMPONENT_SRCS\}\%#$(TFLM_SRCS)#g' $$< | \
 	sed -E 's#\%\{MAIN_SRCS\}\%#$(MAIN_SRCS_RELATIVE)#g' | \
 	sed -E 's#\%\{EXECUTABLE\}\%#$(2)#g' | \
 	sed -E 's#\%\{COMPONENT_INCLUDES\}\%#$(10)#g' | \
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
index 0ccad72692d..5836aea417d 100644
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
@@ -3,17 +3,14 @@
 #   - RI2019.2 Toolkit (for xt-clang/xt-clang++).
 #   - XTENSA_CORE: The name of the core to use, will cause a compiler exception
 #                  without providing a core.
+
 ifeq ($(TARGET), xtensa-xpg)
   TARGET_ARCH := xtensa-xpg
 
   PLATFORM_ARGS = \
-    -DTF_LITE_STATIC_MEMORY \
-    -DTF_LITE_STRIP_ERROR_STRINGS \
-    -DNDEBUG \
     -DTF_LITE_MCU_DEBUG_LOG \
     --xtensa-core=$(XTENSA_CORE) \
     -mcoproc \
-    -O3 \
     -DXTENSA -DMAX_RFFT_PWR=9 -DMIN_RFFT_PWR=MAX_RFFT_PWR \
     -fdata-sections \
     -ffunction-sections \
@@ -25,18 +22,10 @@ ifeq ($(TARGET), xtensa-xpg)
   CXX_TOOL := clang++
   CC_TOOL := clang
 
-  CXXFLAGS = $(PLATFORM_ARGS) -std=c++11
-  CCFLAGS = $(PLATFORM_ARGS) -std=c11
+  CXXFLAGS += $(PLATFORM_ARGS)
+  CCFLAGS += $(PLATFORM_ARGS)
 
   LDFLAGS += -Wl,-gc-sections
 
   TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_xpg_binary.sh
-
-  # These are microcontroller-specific rules for converting the ELF output
-  # of the linker into a binary image that can be loaded directly.
-  OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
-
-  $(BINDIR)/%.bin: $(BINDIR)/%
-	  @mkdir -p $(dir $@)
-	  $(OBJCOPY) $< $@ -O binary
 endif
diff --git a/tensorflow/lite/micro/xtensa-xpg/debug_log.cc b/tensorflow/lite/micro/xtensa-xpg/debug_log.cc
index a95a084978b..ca8b251c069 100644
--- a/tensorflow/lite/micro/xtensa-xpg/debug_log.cc
+++ b/tensorflow/lite/micro/xtensa-xpg/debug_log.cc
@@ -39,7 +39,10 @@ limitations under the License.
 #include <cstdio>
 
 extern "C" void DebugLog(const char* s) {
-#ifndef NDEBUG
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+  // TF_LITE_STRIP_ERROR_STRINGS is reused here to disable DebugLog completely
+  // for the maximum reduction in binary size: DebugLog calls reached via
+  // TF_LITE_CHECK are not stubbed out by TF_LITE_REPORT_ERROR.
   fprintf(stderr, "%s", s);
 #endif
 }
diff --git a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py
index 95f7acabdf7..45f2e4b867a 100644
--- a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py
+++ b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py
@@ -68,6 +68,18 @@ def make_strided_slice_np_style_tests(options):
                    [slice(1, 11, 3), Ellipsis,
                     slice(3, 7, 2)]],
       },
+      # Ellipsis 5d.
+      {
+          "dtype": [tf.float32],
+          "shape": [[11, 21, 15, 7, 9]],
+          "spec": [[
+              slice(3, 7, 2),
+              slice(None),
+              slice(None),
+              slice(None),
+              slice(None)
+          ], [Ellipsis, slice(3, 7, 2)]],
+      },
       # All combinations.
       {
           "dtype": [tf.float32],
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index 3238d8ef032..0c310d15020 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1165,6 +1165,15 @@ class StridedSlice
     op->new_axis_mask = options.new_axis_mask();
     op->shrink_axis_mask = options.shrink_axis_mask();
   }
+
+  int GetVersion(const OperatorSignature& op_signature) const override {
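+    // Forward the number of sliced dimensions so that 5-D slices are reported
+    // as op version 4 by GetBuiltinOperatorVersion.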
+    const auto& ss_op =
+        static_cast<const StridedSliceOperator&>(*op_signature.op);
+    ::tflite::OpSignature op_sig =
+        GetVersioningOpSig(builtin_op(), op_signature);
+    op_sig.options.strided_slice.num_dims = ss_op.start_indices.size();
+    return ::tflite::GetBuiltinOperatorVersion(op_sig);
+  }
 };
 
 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md
index 89a64ab2751..286ddf69cab 100644
--- a/tensorflow/lite/tools/benchmark/README.md
+++ b/tensorflow/lite/tools/benchmark/README.md
@@ -65,6 +65,14 @@ and the following optional parameters:
     This is available on recent Android devices. Note that some Android P
     devices will fail to use NNAPI for models in `/data/local/tmp/` and this
     benchmark tool will not correctly use NNAPI.
+*   `max_delegated_partitions`: `int` (default=0, i.e. no limit) \
+    The maximum number of partitions that will be delegated.
+    Currently only supported by the NNAPI delegate; this option is ignored
+    if `use_legacy_nnapi` has been selected.
+*   `disable_nnapi_cpu`: `bool` (default=false) \
+    Excludes the [NNAPI CPU reference implementation](https://developer.android.com/ndk/guides/neuralnetworks#device-assignment)
+    from the possible devices to be used by NNAPI to execute the model.
+    This option is ignored if `nnapi_accelerator_name` is specified.
 *   `use_gpu`: `bool` (default=false) \
     Whether to use the [GPU accelerator delegate](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/gpu).
     This option is currently only available on Android and iOS devices.
diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
index eccd389af6f..46620ae3372 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
@@ -240,6 +240,8 @@ void BenchmarkPerformanceOptions::ResetPerformanceOptions() {
   single_option_run_params_->Set<bool>("gpu_precision_loss_allowed", true);
   single_option_run_params_->Set<bool>("use_nnapi", false);
   single_option_run_params_->Set<std::string>("nnapi_accelerator_name", "");
+  single_option_run_params_->Set<bool>("disable_nnapi_cpu", false);
+  single_option_run_params_->Set<int>("max_delegated_partitions", 0);
 #endif
 #if defined(TFLITE_ENABLE_HEXAGON)
   single_option_run_params_->Set<bool>("use_hexagon", false);
@@ -305,6 +307,10 @@ void BenchmarkPerformanceOptions::CreatePerformanceOptions() {
         params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(true));
         params.AddParam("nnapi_accelerator_name",
                         BenchmarkParam::Create<std::string>(name));
+        params.AddParam("disable_nnapi_cpu",
+                        BenchmarkParam::Create<bool>(false));
+        params.AddParam("max_delegated_partitions",
+                        BenchmarkParam::Create<int>(0));
         all_run_params_.emplace_back(std::move(params));
       }
     }
diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc
index 81b27c092ec..0d49fc8baec 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_test.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc
@@ -77,6 +77,8 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
                   BenchmarkParam::Create<std::string>(""));
   params.AddParam("nnapi_execution_preference",
                   BenchmarkParam::Create<std::string>(""));
+  params.AddParam("disable_nnapi_cpu", BenchmarkParam::Create<bool>(false));
+  params.AddParam("max_delegated_partitions", BenchmarkParam::Create<int>(0));
   params.AddParam("profiling_output_csv_file",
                   BenchmarkParam::Create<std::string>(""));
   return params;
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 6b1e9819312..d563724efe4 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -215,6 +215,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
                           BenchmarkParam::Create<int32_t>(1024));
   default_params.AddParam("profiling_output_csv_file",
                           BenchmarkParam::Create<std::string>(""));
+  default_params.AddParam("max_delegated_partitions",
+                          BenchmarkParam::Create<int32_t>(0));
 
   for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
     delegate_util->AddParams(&default_params);
@@ -257,7 +259,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
       CreateFlag<std::string>(
           "profiling_output_csv_file", &params_,
           "File path to export profile data as CSV, if not set "
-          "prints to stdout.")};
+          "prints to stdout."),
+      CreateFlag<int>("max_delegated_partitions", &params_,
+                      "Max partitions to be delegated.")};
 
   flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
 
@@ -295,6 +299,8 @@ void BenchmarkTfLiteModel::LogParams() {
   TFLITE_LOG(INFO) << "CSV File to export profiling data to: ["
                    << params_.Get<std::string>("profiling_output_csv_file")
                    << "]";
+  TFLITE_LOG(INFO) << "Max number of delegated partitions : ["
+                   << params_.Get<int32_t>("max_delegated_partitions") << "]";
 
   for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
     delegate_util->LogParams(params_);
diff --git a/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc b/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc
index 4ac50b9771f..3f87de863e7 100644
--- a/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc
+++ b/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc
@@ -50,7 +50,10 @@ std::vector<Flag> NnapiDelegateProvider::CreateFlags(
                               "sustained_speed, low_power, undefined"),
       CreateFlag<std::string>(
           "nnapi_accelerator_name", params,
-          "the name of the nnapi accelerator to use (requires Android Q+)")};
+          "the name of the nnapi accelerator to use (requires Android Q+)"),
+      CreateFlag<bool>("disable_nnapi_cpu", params,
+                       "Disable the NNAPI CPU device")};
+
   return flags;
 }
 
@@ -60,17 +63,18 @@ void NnapiDelegateProvider::AddParams(BenchmarkParams* params) const {
                    BenchmarkParam::Create<std::string>(""));
   params->AddParam("nnapi_accelerator_name",
                    BenchmarkParam::Create<std::string>(""));
+  params->AddParam("disable_nnapi_cpu", BenchmarkParam::Create<bool>(false));
 }
 
 void NnapiDelegateProvider::LogParams(const BenchmarkParams& params) const {
 #if defined(__ANDROID__)
   TFLITE_LOG(INFO) << "Use nnapi : [" << params.Get<bool>("use_nnapi") << "]";
-  if (!params.Get<std::string>("nnapi_execution_preference").empty()) {
-    TFLITE_LOG(INFO) << "nnapi execution preference: ["
-                     << params.Get<std::string>("nnapi_execution_preference")
-                     << "]";
-  }
   if (params.Get<bool>("use_nnapi")) {
+    if (!params.Get<std::string>("nnapi_execution_preference").empty()) {
+      TFLITE_LOG(INFO) << "nnapi execution preference: ["
+                       << params.Get<std::string>("nnapi_execution_preference")
+                       << "]";
+    }
     std::string log_string = "nnapi accelerator name: [" +
                              params.Get<std::string>("nnapi_accelerator_name") +
                              "]";
@@ -80,6 +84,10 @@ void NnapiDelegateProvider::LogParams(const BenchmarkParams& params) const {
       log_string += " (Available: " + string_device_names_list + ")";
     }
     TFLITE_LOG(INFO) << log_string;
+    if (params.Get<bool>("disable_nnapi_cpu")) {
+      TFLITE_LOG(INFO) << "disable_nnapi_cpu: ["
+                       << params.Get<bool>("disable_nnapi_cpu") << "]";
+    }
   }
 #endif
 }
@@ -93,6 +101,8 @@ TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate(
         params.Get<std::string>("nnapi_accelerator_name");
     if (!accelerator_name.empty()) {
       options.accelerator_name = accelerator_name.c_str();
+    } else if (params.Get<bool>("disable_nnapi_cpu")) {
+      options.disallow_nnapi_cpu = true;
     }
     std::string string_execution_preference =
         params.Get<std::string>("nnapi_execution_preference");
@@ -121,6 +131,10 @@ TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate(
       }
       options.execution_preference = execution_preference;
     }
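+    // A value of 0 (the default) keeps the delegate's built-in partition
+    // limit.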
+    int max_delegated_partitions = params.Get<int>("max_delegated_partitions");
+    if (max_delegated_partitions > 0) {
+      options.max_number_delegated_partitions = max_delegated_partitions;
+    }
     delegate = evaluation::CreateNNAPIDelegate(options);
     if (!delegate.get()) {
       TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile
index 1597b8e7f2e..b78fb14b785 100644
--- a/tensorflow/lite/tools/make/Makefile
+++ b/tensorflow/lite/tools/make/Makefile
@@ -56,9 +56,8 @@ LIBS := \
 # There are no rules for compiling objects for the host system (since we don't
 # generate things like the protobuf compiler that require that), so all of
 # these settings are for the target compiler.
-CXXFLAGS := -O3 -DNDEBUG -fPIC
-CXXFLAGS += $(EXTRA_CXXFLAGS)
-CFLAGS := ${CXXFLAGS}
+CFLAGS := -O3 -DNDEBUG -fPIC
+CXXFLAGS := $(CFLAGS) --std=c++11 $(EXTRA_CXXFLAGS)
 LDOPTS := -L/usr/local/lib
 ARFLAGS := -r
 TARGET_TOOLCHAIN_PREFIX :=
@@ -68,10 +67,6 @@ ifeq ($(HOST_OS),windows)
 CXXFLAGS += -fext-numeric-literals -D__LITTLE_ENDIAN__
 endif
 
-ifeq ($(TARGET),ios)
-CXXFLAGS += --std=c++11
-endif
-
 # Auto-detect optimization opportunity if building natively.
 ifeq ($(HOST_OS),$(TARGET))
 ifeq ($(HOST_ARCH),$(TARGET_ARCH))
diff --git a/tensorflow/lite/tools/pip_package/Makefile b/tensorflow/lite/tools/pip_package/Makefile
index eaca6e131b3..ff18303284c 100644
--- a/tensorflow/lite/tools/pip_package/Makefile
+++ b/tensorflow/lite/tools/pip_package/Makefile
@@ -46,9 +46,10 @@ docker-build: docker-image
 		--volume $(TENSORFLOW_DIR):/tensorflow \
 		--volume $(OUT_DIR):/out \
 		$(TAG_IMAGE) \
-		/bin/bash -c "tensorflow/tensorflow/lite/tools/pip_package/build_pip_package.sh && \
-		              (cp ${MAKEFILE_DIR}/gen/tflite_pip/*.deb ${MAKEFILE_DIR}/gen/tflite_pip/python3/dist/{*.whl,*.tar.gz} /out 2>/dev/null || true)"
+		/bin/bash -c "/tensorflow/tensorflow/lite/tools/pip_package/build_pip_package.sh && \
+		              (cp /tensorflow/tensorflow/lite/tools/pip_package/gen/tflite_pip/*.deb \
+		                  /tensorflow/tensorflow/lite/tools/pip_package/gen/tflite_pip/${PYTHON}/dist/{*.whl,*.tar.gz} \
+		                  /out 2>/dev/null || true)"
 
 clean:
 	rm -rf $(CURDIR)/out
-
diff --git a/tensorflow/lite/tools/pip_package/build_pip_package.sh b/tensorflow/lite/tools/pip_package/build_pip_package.sh
index 925c6142be0..5ba2cf954f6 100755
--- a/tensorflow/lite/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/lite/tools/pip_package/build_pip_package.sh
@@ -13,8 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-set -e
-set -x
+set -ex
 
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 PYTHON="${PYTHON:-python3}"
@@ -23,7 +22,7 @@ export TENSORFLOW_DIR="${SCRIPT_DIR}/../../../.."
 TENSORFLOW_LITE_DIR="${TENSORFLOW_DIR}/tensorflow/lite"
 TENSORFLOW_VERSION=$(grep "_VERSION = " "${TENSORFLOW_DIR}/tensorflow/tools/pip_package/setup.py" | cut -d= -f2 | sed "s/[ '-]//g")
 export PACKAGE_VERSION="${TENSORFLOW_VERSION}${VERSION_SUFFIX}"
-BUILD_DIR="${SCRIPT_DIR}/gen/tflite_pip/python3"
+BUILD_DIR="${SCRIPT_DIR}/gen/tflite_pip/${PYTHON}"
 
 # Build source tree.
 rm -rf "${BUILD_DIR}" && mkdir -p "${BUILD_DIR}/tflite_runtime"
diff --git a/tensorflow/lite/tools/pip_package/setup.py b/tensorflow/lite/tools/pip_package/setup.py
index 19c9993e5fa..f99a5b043dc 100644
--- a/tensorflow/lite/tools/pip_package/setup.py
+++ b/tensorflow/lite/tools/pip_package/setup.py
@@ -39,47 +39,30 @@ PACKAGE_NAME = 'tflite_runtime'
 PACKAGE_VERSION = os.environ['PACKAGE_VERSION']
 DOCLINES = __doc__.split('\n')
 TENSORFLOW_DIR = os.environ['TENSORFLOW_DIR']
-
-# Setup cross compiling
-TARGET = os.environ.get('TENSORFLOW_TARGET', None)
-if TARGET == 'rpi':
-  os.environ['CXX'] = 'arm-linux-gnueabihf-g++'
-  os.environ['CC'] = 'arm-linux-gnueabihf-gcc'
-elif TARGET == 'aarch64':
-  os.environ['CXX'] = 'aarch64-linux-gnu-g++'
-  os.environ['CC'] = 'aarch64-linux-gnu-gcc'
-MAKE_CROSS_OPTIONS = ['TARGET=%s' % TARGET]  if TARGET else []
-
-TARGET_ARCH = (
-    os.environ['TENSORFLOW_TARGET_ARCH'] \
-    if 'TENSORFLOW_TARGET_ARCH' in os.environ
-    else None)
-MAKE_CROSS_OPTIONS += ['TARGET_ARCH=%s' % TARGET_ARCH] \
-        if TARGET_ARCH else []
-
-CC_PREFIX = (
-    os.environ['TENSORFLOW_CC_PREFIX'] \
-    if 'TENSORFLOW_CC_PREFIX' in os.environ
-    else None)
-MAKE_CROSS_OPTIONS += ['CC_PREFIX=%s' % CC_PREFIX] \
-        if CC_PREFIX else []
-
-EXTRA_CXXFLAGS = (
-    os.environ['TENSORFLOW_EXTRA_CXXFLAGS'] \
-    if 'TENSORFLOW_EXTRA_CXXFLAGS' in os.environ
-    else None)
-MAKE_CROSS_OPTIONS += ['EXTRA_CXXFLAGS=%s' % EXTRA_CXXFLAGS] \
-        if EXTRA_CXXFLAGS else []
-
 RELATIVE_MAKE_DIR = os.path.join('tensorflow', 'lite', 'tools', 'make')
 MAKE_DIR = os.path.join(TENSORFLOW_DIR, RELATIVE_MAKE_DIR)
 DOWNLOADS_DIR = os.path.join(MAKE_DIR, 'downloads')
 RELATIVE_MAKEFILE_PATH = os.path.join(RELATIVE_MAKE_DIR, 'Makefile')
 DOWNLOAD_SCRIPT_PATH = os.path.join(MAKE_DIR, 'download_dependencies.sh')
 
+# Setup cross compiling
+TARGET = os.environ.get('TENSORFLOW_TARGET')
+if TARGET == 'rpi':
+  os.environ['CXX'] = 'arm-linux-gnueabihf-g++'
+  os.environ['CC'] = 'arm-linux-gnueabihf-gcc'
+elif TARGET == 'aarch64':
+  os.environ['CXX'] = 'aarch64-linux-gnu-g++'
+  os.environ['CC'] = 'aarch64-linux-gnu-gcc'
+
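+# Forward any TENSORFLOW_<NAME> overrides (TARGET, TARGET_ARCH, CC_PREFIX,
+# EXTRA_CXXFLAGS) to the Makefile as NAME=value options.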
+MAKE_CROSS_OPTIONS = []
+for name in ['TARGET', 'TARGET_ARCH', 'CC_PREFIX', 'EXTRA_CXXFLAGS']:
+  value = os.environ.get('TENSORFLOW_%s' % name)
+  if value:
+    MAKE_CROSS_OPTIONS.append('%s=%s' % (name, value))
+
 
 # Check physical memory and if we are on a reasonable non small SOC machine
-# with more than 4GB, use all the CPUs, otherwisxe only 1.
+# with more than 4GB, use all the CPUs, otherwise only 1.
 def get_build_cpus():
   physical_bytes = os.sysconf('SC_PAGESIZE') * os.sysconf('SC_PHYS_PAGES')
   if physical_bytes < (1<<30) * 4:
@@ -156,6 +139,7 @@ ext = Extension(
              'interpreter_wrapper/numpy.cc',
              'interpreter_wrapper/python_error_reporter.cc',
              'interpreter_wrapper/python_utils.cc'],
+    extra_compile_args=['--std=c++11'],
     swig_opts=['-c++',
                '-I%s' % TENSORFLOW_DIR,
                '-module', 'interpreter_wrapper',
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index b699f0dbc9b..b3ef46503b3 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -266,6 +266,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
     case BuiltinOperator_STRIDED_SLICE:
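+      // Slices over more than 4 dimensions require the new 5-D kernel.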
+      if (op_sig.options.strided_slice.num_dims > 4) {
+        return 4;
+      }
       // If the op takes bool input, it is version 3.
       if (op_sig.input_types.at(0) == TensorType_BOOL) {
         return 3;
@@ -431,6 +434,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
             resize_bilinear_option->half_pixel_centers();
       }
     } break;
+    // TODO(b/150176627): Add tests for GetOpSignature.
+    case BuiltinOperator_STRIDED_SLICE: {
+      op_sig.options.strided_slice.num_dims =
+          subgraph->tensors()->Get(op->inputs()->Get(0))->shape()->size();
+    } break;
 
     default:
       break;
diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h
index 7fbc5a056e5..364d1a299cc 100644
--- a/tensorflow/lite/tools/versioning/op_version.h
+++ b/tensorflow/lite/tools/versioning/op_version.h
@@ -49,6 +49,9 @@ typedef struct {
     struct {
       bool half_pixel_centers;
     } resize_bilinear;
+    struct {
+      int32_t num_dims;
+    } strided_slice;
   } options;
 } OpSignature;
 
diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py
index b78695be5a5..1f89f9c5448 100644
--- a/tensorflow/lite/tools/visualize.py
+++ b/tensorflow/lite/tools/visualize.py
@@ -265,7 +265,9 @@ class TensorMapper(object):
       html += str(i) + " "
       html += NameListToString(tensor["name"]) + " "
       html += TensorTypeToName(tensor["type"]) + " "
-      html += (repr(tensor["shape"]) if "shape" in tensor else "[]") + "<br>"
+      html += (repr(tensor["shape"]) if "shape" in tensor else "[]")
+      html += (repr(tensor["shape_signature"])
+               if "shape_signature" in tensor else "[]") + "<br>"
     html += "</span>"
     html += repr(x)
     html += "</span>"
@@ -447,9 +449,9 @@ def CreateHtmlFile(tflite_input, html_output):
                           ("builtin_options", None),
                           ("opcode_index", opcode_mapper)]
     tensor_keys_to_display = [("name", NameListToString),
-                              ("type", TensorTypeToName),
-                              ("shape", None),
-                              ("buffer", None), ("quantization", None)]
+                              ("type", TensorTypeToName), ("shape", None),
+                              ("shape_signature", None), ("buffer", None),
+                              ("quantization", None)]
 
     html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
 
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3b99884c603..6f7ec6389a3 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -623,9 +623,10 @@ tf_python_pybind_extension(
         "@pybind11",
         "//third_party/python_runtime:headers",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:framework",
         "//tensorflow/core:core_cpu_headers_lib",
-        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/types:optional",
     ] + if_static(
         extra_deps = [
@@ -1642,6 +1643,7 @@ py_library(
     srcs = ["framework/auto_control_deps.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":auto_control_deps_utils",
         ":control_flow_ops",
         ":framework_ops",
         ":sparse_tensor",
@@ -1650,6 +1652,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "auto_control_deps_utils",
+    srcs = ["framework/auto_control_deps_utils.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":dtypes",
+    ],
+)
+
 tf_py_test(
     name = "auto_control_deps_test",
     size = "small",
@@ -2155,6 +2166,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":auto_control_deps_utils",
         ":constant_op",
         ":control_flow_ops",
         ":framework_ops",
@@ -2882,6 +2894,7 @@ tf_gen_op_wrapper_private_py(
     name = "resource_variable_ops_gen",
     visibility = [
         "//tensorflow/compiler/tf2xla:internal",
+        "//tensorflow/python/distribute:__pkg__",
     ],
 )
 
@@ -3352,6 +3365,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":auto_control_deps_utils",
         ":c_api_util",
         ":control_flow_util_v2",
         ":framework_ops",
@@ -3378,6 +3392,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":auto_control_deps_utils",
         ":constant_op",
         ":control_flow_ops",
         ":control_flow_util",
@@ -3983,6 +3998,7 @@ py_library(
     deps = [
         ":array_ops",
         ":array_ops_gen",
+        ":auto_control_deps_utils",
         ":dtypes",
         ":framework_ops",
         ":pywrap_tf_session",
@@ -4179,6 +4195,7 @@ cuda_py_test(
     srcs = ["ops/stateful_random_ops_test.py"],
     python_version = "PY3",
     xla_enable_strict_auto_jit = False,
+    xla_enabled = True,
     deps = [
         ":client_testlib",
         ":config",
@@ -7495,7 +7512,7 @@ tf_python_pybind_extension(
         ":pybind11_status",
         "@pybind11",
         "//tensorflow/core:core_cpu_headers_lib",
-        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:framework",
         "//tensorflow/core:gpu_id",
         "//tensorflow/core:protos_all_cc",
     ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]),  # b/148556093,
@@ -7631,9 +7648,9 @@ tf_python_pybind_extension(
     deps = [
         ":pybind11_status",
         "//tensorflow/core:core_cpu_headers_lib",
-        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:framework",
         "//tensorflow/core:gpu_id",
-        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "@pybind11",
     ],
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index e376480d188..5e0355d644f 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 2, 24)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 2, 26)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
index 96b3b764864..535cf884dc6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
@@ -161,17 +161,49 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
 
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testWriteSnapshotRepeatAfterwards(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(compression=[
+              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
+              snapshot.COMPRESSION_SNAPPY
+          ])))
+  def testWriteSnapshotRepeatAfterwards(self, compression):
     tmpdir = self.snapshot_dir
 
     dataset = dataset_ops.Dataset.range(10)
-    dataset = dataset.apply(snapshot.snapshot(tmpdir))
+    dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
     dataset = dataset.repeat(10)
     self.assertDatasetProduces(dataset, list(range(10)) * 10)
 
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(compression=[
+              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
+              snapshot.COMPRESSION_SNAPPY
+          ])))
+  def testWriteSnapshotMixTypes(self, compression):
+    tmpdir = self.snapshot_dir
+
+    dataset = dataset_ops.Dataset.range(10)
+
+    def map_fn(x):
+      return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x)
+
+    dataset = dataset.map(map_fn)
+    dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
+    dataset = dataset.repeat(10)
+
+    expected = []
+    for i in range(10):
+      expected.append((i, str(i), str(2 * i), 2 * i))
+    self.assertDatasetProduces(dataset, expected * 10)
+
+    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
+
   @combinations.generate(test_base.default_test_combinations())
   def testSpecifySnapshotNameWriteAndRead(self):
     tmpdir = self.snapshot_dir
@@ -365,8 +397,14 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
       res3 = self.evaluate(next3())
       self.assertEqual(res2, res3)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testReadSnapshotParallelAfterWrite(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(compression=[
+              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
+              snapshot.COMPRESSION_SNAPPY
+          ])))
+  def testReadSnapshotParallelAfterWrite(self, compression):
     self.setUpTFRecord(10, 4000)
     filenames = self.test_filenames
 
@@ -383,7 +421,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
             tmpdir,
             shard_size_bytes=1024 * 1024,
             num_reader_threads=2,
-            reader_buffer_size=10))
+            reader_buffer_size=10,
+            compression=compression))
     self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
 
     # remove the original files and try to read the data back only from
@@ -396,7 +435,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
             tmpdir,
             shard_size_bytes=1024 * 1024,
             num_reader_threads=2,
-            reader_buffer_size=10))
+            reader_buffer_size=10,
+            compression=compression))
     self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
 
   # Not testing Snappy here because Snappy reads currently require a lot of
@@ -514,21 +554,31 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
       self.evaluate(next2())
     self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testSpecifyShardSize(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(compression=[
+              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
+              snapshot.COMPRESSION_SNAPPY
+          ])))
+  def testSpecifyShardSize(self, compression):
     tmpdir = self.snapshot_dir
 
     dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
     dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
     dataset = dataset.repeat(10)
     dataset = dataset.apply(
-        snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024))
+        snapshot.snapshot(
+            tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression))
     next_fn = self.getNext(dataset)
 
     for _ in range(10):
       self.evaluate(next_fn())
 
-    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 3)
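+    # The repeated broadcast data compresses well, so compressed snapshots fit
+    # in a single shard file; only the uncompressed run spills into 3 files.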
+    num_files = 1
+    if compression == snapshot.COMPRESSION_NONE:
+      num_files = 3
+    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files)
 
   @combinations.generate(test_base.default_test_combinations())
   def testAdditionalOperationsAfterReadBack(self):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 2bd34b195e4..6596bfa8607 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -44,6 +44,7 @@ from tensorflow.python.data.util import traverse
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.framework import auto_control_deps
+from tensorflow.python.framework import auto_control_deps_utils as acd_utils
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -4469,43 +4470,67 @@ def _collect_resource_inputs(op):
   """Collects resource inputs for the given ops (and its variant inputs)."""
 
   def _process(op_queue, seen_ops):
-    """Processes the next element of the op queue."""
+    """Processes the next element of the op queue.
 
-    result = []
+    Args:
+      op_queue: Queue of Dataset operations to process.
+      seen_ops: Already processed set of Operations.
+
+    Returns:
+      A 2-tuple containing lists of resource handles. The first tuple entry
+      contains read-only handles and the second entry contains read-write
+      handles.
+    """
+
+    reads = []
+    writes = []
     op = op_queue.pop()
     if op in seen_ops:
-      return result
+      return reads, writes
     seen_ops.add(op)
     for t in op.inputs:
       if t.dtype == dtypes.variant:
         # Conservatively assume that any variant inputs are datasets.
         op_queue.append(t.op)
       elif t.dtype == dtypes.resource:
-        result.append(t)
-    return result
+        # TODO(b/150139257): This always returns True right now since we have
+        # not updated the functional ops to set the special attribute that ACD
+        # uses to figure out which of the op's inputs are read-only.
+        if acd_utils.op_writes_to_resource(t, op):
+          writes.append(t)
+        else:
+          reads.append(t)
+    return reads, writes
 
   op_queue = [op]
   seen_ops = set()
-  resource_inputs = []
+  all_reads = []
+  all_writes = []
   while op_queue:
-    resource_inputs.extend(_process(op_queue, seen_ops))
+    reads, writes = _process(op_queue, seen_ops)
+    all_reads.extend(reads)
+    all_writes.extend(writes)
 
-  return resource_inputs
+  return all_reads, all_writes
 
 
 @auto_control_deps.register_acd_resource_resolver
-def _resource_resolver(op, resource_inputs):
+def _resource_resolver(op, resource_reads, resource_writes):
   """Updates resource inputs for tf.data ops with indirect dependencies."""
 
   updated = False
   if op.type in [
       "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
   ]:
-    indirect_resource_inputs = _collect_resource_inputs(op)
-    for inp in indirect_resource_inputs:
-      if inp not in resource_inputs:
+    reads, writes = _collect_resource_inputs(op)
+    for inp in reads:
+      if inp not in resource_reads:
         updated = True
-        resource_inputs.add(inp)
+        resource_reads.add(inp)
+    for inp in writes:
+      if inp not in resource_writes:
+        updated = True
+        resource_writes.add(inp)
 
   if op.type in [
       "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
@@ -4516,10 +4541,14 @@ def _resource_resolver(op, resource_inputs):
     ]
 
     if len(make_iterator_ops) == 1:
-      indirect_resource_inputs = _collect_resource_inputs(make_iterator_ops[0])
-      for inp in indirect_resource_inputs:
-        if inp not in resource_inputs:
+      reads, writes = _collect_resource_inputs(make_iterator_ops[0])
+      for inp in reads:
+        if inp not in resource_reads:
           updated = True
-          resource_inputs.add(inp)
+          resource_reads.add(inp)
+      for inp in writes:
+        if inp not in resource_writes:
+          updated = True
+          resource_writes.add(inp)
 
   return updated
diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py
index 5382965ebc4..b76077e8def 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test.py
@@ -1128,7 +1128,7 @@ class TracingCallbackTest(
         # 1st element: tensor ID; 2nd element: 0 indicating no inf or nan.
         self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
     elif tensor_debug_mode == "CONCISE_HEALTH":
-      for tensor_value in tensor_values:
+      for trace in graph_exec_traces:
         tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
         # 1st element: tensor ID.
         # 2nd element: element count. Remaining elements: all zero because there
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 7f4bc9641ae..efa10d6ee45 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -63,9 +63,11 @@ py_library(
     srcs = ["cross_device_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":collective_util",
         ":cross_device_utils",
         ":device_util",
         ":reduce_util",
+        ":tpu_values",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:device_lib",
@@ -96,6 +98,7 @@ py_library(
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nccl_ops",
+        "//tensorflow/python:platform",
     ],
 )
 
@@ -144,6 +147,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":collective_util",
         ":device_util",
         ":numpy_dataset",
         ":reduce_util",
@@ -531,6 +535,7 @@ py_library(
         ":input_lib",
         ":numpy_dataset",
         ":reduce_util",
+        ":tpu_values",
         ":values",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
         "//tensorflow/python:array_ops",
@@ -578,6 +583,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "collective_util",
+    srcs = ["collective_util.py"],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+    ],
+)
+
 py_library(
     name = "shared_variable_creator",
     srcs = ["shared_variable_creator.py"],
@@ -612,18 +626,36 @@ py_library(
     deps = [
         ":device_util",
         ":distribute_lib",
+        ":reduce_util",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:composite_tensor",
         "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:tensor_util",
+        "//tensorflow/python:type_spec",
         "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/tpu:tpu_lib",
+        "//tensorflow/python/eager:tape",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
         "//tensorflow/python/training/tracking:base",
-        "@six_archive//:six",
+    ],
+)
+
+py_library(
+    name = "tpu_values",
+    srcs = ["tpu_values.py"],
+    deps = [
+        ":values",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:resource_variable_ops_gen",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:tape",
+        "//tensorflow/python/tpu:tpu_lib",
     ],
 )
 
@@ -775,7 +807,9 @@ cuda_py_test(
     name = "cross_device_utils_test",
     srcs = ["cross_device_utils_test.py"],
     deps = [
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/distribute:combinations",
@@ -795,6 +829,7 @@ cuda_py_test(
     ],
     deps = [
         ":collective_all_reduce_strategy",
+        ":collective_util",
         ":mirrored_strategy",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
@@ -883,7 +918,7 @@ distribute_py_test(
     srcs = ["values_test.py"],
     main = "values_test.py",
     tags = [
-        "no_oss",  # http://b/119349471
+        "multi_and_single_gpu",
     ],
     deps = [
         ":mirrored_strategy",
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index e77fce0f2a5..ab9721e1bfb 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -95,6 +95,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
         TFConfigClusterResolver which is instantiated from the TF_CONFIG env
         var.
     """
+    # TODO(b/150151677): consider moving communication to CollectiveHints.
     super(CollectiveAllReduceStrategy, self).__init__(
         CollectiveAllReduceExtended(
             self,
@@ -505,7 +506,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     return updated_config
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.Mirrored) and
         reduce_op == reduce_util.ReduceOp.MEAN):
       return value
@@ -526,7 +527,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, len(self.worker_devices))
     return self._get_cross_device_ops().reduce(
-        reduce_op, value, destinations=destinations)
+        reduce_op,
+        value,
+        destinations=destinations,
+        experimental_hints=experimental_hints)
 
   def _warn_nccl_no_gpu(self):
     if ((self._communication ==
diff --git a/tensorflow/python/distribute/collective_util.py b/tensorflow/python/distribute/collective_util.py
new file mode 100644
index 00000000000..fb7008d1636
--- /dev/null
+++ b/tensorflow/python/distribute/collective_util.py
@@ -0,0 +1,63 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for collectives."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("distribute.experimental.CollectiveHints")
+class Hints(object):
+  """Hints for collective operations like AllReduce.
+
+  This can be passed to methods like
+  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
+  operation performance. Note that these are only hints, which may or may not
+  change the actual behavior. Some options only apply to certain strategy and
+  are ignored by others.
+
+  One common optimization is to break gradients all-reduce into multiple packs
+  so that weight updates can overlap with gradient all-reduce.
+
+  Example:
+
+  ```python
+  hints = tf.distribute.experimental.CollectiveHints(
+      bytes_per_pack=50 * 1024 * 1024)
+  grads = tf.distribute.get_replica_context().all_reduce(
+      'sum', grads, experimental_hints=hints)
+  optimizer.apply_gradients(zip(grads, vars), all_reduce_sum_gradients=False)
+  ```
+
+  """
+
+  def __init__(self, bytes_per_pack=0):
+    """Creates a CollectiveHints.
+
+    Args:
+      bytes_per_pack: A non-negative integer. Breaks collective operations into
+        packs of a certain size. If it is zero, the value is determined
+        automatically. This only applies to all-reduce with
+        `MultiWorkerMirroredStrategy` currently.
+
+    Raises:
+      ValueError: When arguments have invalid value.
+    """
+    if bytes_per_pack < 0:
+      raise ValueError("bytes_per_pack must be non-negative")
+    self.bytes_per_pack = bytes_per_pack
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index eae15cf292c..eeb45428bab 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -19,14 +19,16 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import enum
 
+import enum
 import six
 
 from tensorflow.python.client import device_lib
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -62,8 +64,8 @@ def validate_destinations(destinations):
   if not isinstance(
       destinations,
       (value_lib.DistributedValues, ops.Tensor, value_lib.AggregatingVariable,
-       six.string_types, value_lib.TPUMirroredVariable)
-  ) and not resource_variable_ops.is_resource_variable(destinations):
+       six.string_types, tpu_values.TPUMirroredVariable
+      )) and not resource_variable_ops.is_resource_variable(destinations):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
                      " a tf.Variable object, or a device string.")
 
@@ -221,7 +223,11 @@ class CrossDeviceOps(object):
     # Returns 1 by default, the value may be overridden by sub classes.
     return 1
 
-  def reduce(self, reduce_op, per_replica_value, destinations):
+  def reduce(self,
+             reduce_op,
+             per_replica_value,
+             destinations,
+             experimental_hints=None):
     """Reduce `per_replica_value` to `destinations`.
 
     It runs the reduction operation defined by `reduce_op` and put the
@@ -230,8 +236,10 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
         per_replica_value will be reduced.
-      per_replica_value: a PerReplica object or a tensor with device set.
+      per_replica_value: A PerReplica object or a tensor with device set.
       destinations: the reduction destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a Mirrored object.
@@ -253,10 +261,15 @@ class CrossDeviceOps(object):
           per_replica_value.values,
           wrap_class=value_lib.Mirrored)
 
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
     return self.reduce_implementation(reduce_op, per_replica_value,
-                                      destinations)
+                                      destinations, experimental_hints)
 
-  def batch_reduce(self, reduce_op, value_destination_pairs):
+  def batch_reduce(self,
+                   reduce_op,
+                   value_destination_pairs,
+                   experimental_hints=None):
     """Reduce PerReplica objects in a batch.
 
     Reduce each first element in `value_destination_pairs` to each second
@@ -266,10 +279,12 @@ class CrossDeviceOps(object):
     fuse several tensors into one or multiple packs before reduction.
 
     Args:
-      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
-        the `per_replica_value` will be reduced.
-      value_destination_pairs: a list or a tuple of PerReplica objects
-        (or tensors with device set if there is one device) and destinations.
+      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
+        `per_replica_value` will be reduced.
+      value_destination_pairs: A list or a tuple of PerReplica objects (or
+        tensors with device set if there is one device) and destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a list of Mirrored objects.
@@ -298,7 +313,10 @@ class CrossDeviceOps(object):
           for v, _ in value_destination_pairs
       ]
 
-    return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
+    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
+                                            experimental_hints)
 
   def broadcast(self, tensor, destinations):
     """Broadcast the `tensor` to destinations.
@@ -314,7 +332,8 @@ class CrossDeviceOps(object):
     return self.broadcast_implementation(tensor, destinations)
 
   @doc_controls.for_subclass_implementers
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
     """The implementation of reduce of `per_replica_value` to `destinations`.
 
     Overriding this method is useful for subclass implementers.
@@ -325,8 +344,10 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how
         per_replica_value will be reduced.
-      per_replica_value: a PerReplica object or a tensor with device set.
+      per_replica_value: A PerReplica object or a tensor with device set.
       destinations: the reduction destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a Mirrored object.
@@ -339,7 +360,8 @@ class CrossDeviceOps(object):
         "_reduce method must be implemented in descendants.")
 
   @doc_controls.for_subclass_implementers
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     """Implementation of reduce PerReplica objects in a batch.
 
     Overriding this method is useful for subclass implementers.
@@ -350,8 +372,10 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
         per_replica_value will be reduced.
-      value_destination_pairs: an iterable of tuples of PerReplica objects
+      value_destination_pairs: An iterable of tuples of PerReplica objects
         (or tensors with device set if there is one device) and destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a list of Mirrored objects.
@@ -361,7 +385,8 @@ class CrossDeviceOps(object):
         tuples of PerReplica objects and destinations
     """
     raise NotImplementedError(
-        "_batch_reduce method must be implemented in descendants.")
+        "batch_reduce_implementation method must be implemented in descendants."
+    )
 
   @doc_controls.for_subclass_implementers
   def broadcast_implementation(self, tensor, destinations):
@@ -402,7 +427,9 @@ class ReductionToOneDevice(CrossDeviceOps):
     self.accumulation_fn = accumulation_fn or math_ops.add_n
     super(ReductionToOneDevice, self).__init__()
 
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
+    del experimental_hints  # Unused.
     if check_destinations(destinations):
       devices = get_devices_from(destinations)
     else:
@@ -415,9 +442,11 @@ class ReductionToOneDevice(CrossDeviceOps):
                              self.accumulation_fn, reduce_op)
     return self.broadcast(reduced, destinations)
 
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     return [
-        self.reduce_implementation(reduce_op, t, destinations=v)
+        self.reduce_implementation(
+            reduce_op, t, destinations=v, experimental_hints=experimental_hints)
         for t, v in value_destination_pairs
     ]
 
@@ -625,21 +654,24 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
     self._simple_cross_replica_ops = ReductionToOneDevice()
     super(AllReduceCrossDeviceOps, self).__init__()
 
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
+    del experimental_hints  # Unused.
     if _devices_match(per_replica_value, destinations):
       return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
     else:
       return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                    destinations)
 
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     if _all_devices_match(value_destination_pairs):
       return self._batch_all_reduce(reduce_op,
                                     [v[0] for v in value_destination_pairs])
     else:
       return [
-          self.reduce_implementation(reduce_op, t, destinations=v)
-          for t, v in value_destination_pairs
+          self.reduce_implementation(reduce_op, value, dest, experimental_hints)
+          for value, dest in value_destination_pairs
       ]
 
   def _batch_all_reduce(self, reduce_op, per_replica_values):
@@ -903,7 +935,6 @@ class CollectiveAllReduce(CrossDeviceOps):
   def __init__(self,
                num_workers=1,
                num_gpus_per_worker=0,
-               num_packs=1,
                collective_keys=None,
                communication=CollectiveCommunication.AUTO):
     """Initializes the object.
@@ -911,13 +942,11 @@ class CollectiveAllReduce(CrossDeviceOps):
     Args:
       num_workers: number of workers in the between-graph replicated training.
       num_gpus_per_worker: number of GPUs per worker.
-      num_packs: gradients will be packed into `num_packs` chunks.
       collective_keys: an optional CollectiveKey object.
       communication: indicates which collective communication to use.
     """
     self._num_workers = num_workers
     self._num_gpus_per_worker = num_gpus_per_worker
-    self._num_packs = num_packs
     self._collective_keys = (collective_keys or
                              cross_device_utils.CollectiveKeys())
     self._communication = communication
@@ -927,8 +956,10 @@ class CollectiveAllReduce(CrossDeviceOps):
   def _num_between_graph_workers(self):
     return self._num_workers
 
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
-    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
+    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
+                                         experimental_hints)[0]
     devices = get_devices_from(destinations)
 
     if (isinstance(all_reduced, value_lib.Mirrored) and
@@ -957,11 +988,13 @@ class CollectiveAllReduce(CrossDeviceOps):
             index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
     return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
 
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     all_devices_match = _all_devices_match(value_destination_pairs)
     if all_devices_match:
       return self._batch_all_reduce(reduce_op,
-                                    [v[0] for v in value_destination_pairs])
+                                    [v[0] for v in value_destination_pairs],
+                                    experimental_hints)
     else:
       if not all_devices_match:
         logging.log_first_n(
@@ -969,47 +1002,18 @@ class CollectiveAllReduce(CrossDeviceOps):
             "destinations are different.", 10)
 
       return [
-          self.reduce_implementation(reduce_op, t, destinations=v)
-          for t, v in value_destination_pairs
+          self.reduce_implementation(reduce_op, value, dest, experimental_hints)
+          for value, dest in value_destination_pairs
       ]
 
-  def _make_gradient_chunks(self, per_replica_values, num_packs):
-    """Make `per_replica_values` into chunks."""
-    chunked_by_device = _group_value_by_device(per_replica_values)
-    chunked_by_var = list(zip(*chunked_by_device))
-    # chunked_by_var is chunked by variables and takes the following format:
-    # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
-    #  ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..),
-    #  ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..),
-    #  ...
-    # ]
-
-    # No chunking if number of variables is fewer than number of packs.
-    if len(chunked_by_var) < num_packs:
-      return [chunked_by_var]
-
-    # First n-1 chunks get `chunk_size` grads, last chunk gets leftover grads.
-    # This strategy can cause the last chunk to have larger size compared to the
-    # first n-1 chunks.  Alternatively, we can increment chunk_size by 1 to get
-    # slightly larger first n-1 chunks and smaller last chunk.
-    # TODO(ayushd): compare different packing strategies.
-    chunk_size = len(chunked_by_var) // num_packs
-    leftover_size = len(chunked_by_var) - chunk_size * (num_packs - 1)
-    assert leftover_size > 0
-    chunked_gv = [
-        chunked_by_var[x:x + chunk_size]
-        for x in range(0, len(chunked_by_var) - leftover_size, chunk_size)
-    ]
-    chunked_gv.append(chunked_by_var[-leftover_size:])
-
-    return chunked_gv
-
-  def _batch_all_reduce(self, reduce_op, per_replica_values):
+  def _batch_all_reduce(self, reduce_op, per_replica_values,
+                        experimental_hints):
     """All reduce algorithm in a batch."""
     dense_values, dense_indices, sparse_values, sparse_indices = (
         cross_device_utils.split_by_sparsity(per_replica_values))
     if dense_values:
-      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values)
+      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
+                                                      experimental_hints)
     else:
       dense_results = []
     if sparse_values:
@@ -1017,83 +1021,84 @@ class CollectiveAllReduce(CrossDeviceOps):
                                                         sparse_values)
     else:
       sparse_results = []
-    return cross_device_utils.stitch_values(((dense_results, dense_indices),
-                                             (sparse_results, sparse_indices)))
+    return cross_device_utils.stitch_values(
+        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
 
-  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values):
+  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
+                                 experimental_hints):
     """All-reduce across all workers in a batch."""
 
-    chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
-    # Actual number of packs may be different from `self._num_packs`.  e.g. if
-    # there are fewer tensors than `self._num_packs`.
-    num_actual_packs = len(chunked_gv)
-
     batch_size = len(per_replica_values)
     # Pass self._communication to the runtime as a communication hint.
-    communication_hint = self._communication.value
+    communication = self._communication.value
     # For now, we use NCCL only when batch_size > 1.
     # TODO(b/132575814): switch to NCCL for all collectives when communication
     # is NCCL.
     if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
-      communication_hint = CollectiveCommunication.AUTO.value
+      communication = CollectiveCommunication.AUTO.value
+
+    # Reverse the lists so that there's a better chance that values follow
+    # the order in which they are calculated (e.g. when they're gradients), so
+    # as to overlap calculation with communication. However, this may not be
+    # optimal for cases like gradients of complicated non-sequential models.
+    #
+    # Note that we reverse the list before packing so that the first pack won't
+    # be too small, since it's more likely for first few packs to have long
+    # queuing time due to concurrent intense computation.
+    #
+    # TODO(b/147393503): explore solutions for optimal ordering.
+    packs = cross_device_utils.pack_by_size(
+        list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)
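+    # `packs` is a list of lists of PerReplica values; `batch_fn` below
+    # all-reduces each inner list under a single name scope so that the
+    # ScopedAllocator optimizer can fuse it into one collective.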
 
     if batch_size > 1:
       logging.info(
           "Collective batch_all_reduce: %d all-reduces, num_workers = %d, "
-          "communication_hint = %s, num_packs = %d" %
-          (batch_size, self._num_workers, communication_hint, num_actual_packs))
+          "communication_hint = %s, num_packs = %d", batch_size,
+          self._num_workers, communication, len(packs))
     else:
       logging.log_first_n(
           logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
           "num_workers = %d, communication_hint = %s, num_packs = %d" %
-          (batch_size, self._num_workers, communication_hint, num_actual_packs),
-          10)
+          (batch_size, self._num_workers, communication, len(packs)), 10)
 
     def batch_fn():
       """Wrapper function around batched all-reduce calls."""
-      reduced_gv_list = []
-      # Reverse the gradient lists so that the gradient grouping roughly follows
-      # the order in which gradients are calculated in backprop.  This should
-      # enable overlapping gradient all-reduce with backprop for most models.
-      # However, it is likely that for some complicated non-sequential models
-      # this grouping is not optimal.
-      #
-      # TODO(b/147393503): explore solutions for optimal gradient grouping.
-      for chunk in reversed(chunked_gv):
-        # By placing all CollectiveReduce ops in a chunk under single name
-        # scope, we ensure they will be picked up by the `ScopedAllocator`
-        # grappler optimizer and packed into a single all-reduce.
+      reduced_values = []
+      for pack in packs:
+        # By placing all CollectiveReduce ops in a pack under single name scope,
+        # we ensure they will be picked up by the `ScopedAllocator` grappler
+        # optimizer and packed into a single all-reduce.
         with ops.name_scope("allreduce"):
-          for grad_and_vars in reversed(chunk):
-            # Gradients for the same variable but from different devices.
-            grads = [g for g, _ in grad_and_vars]
+          for per_replica in pack:
             # Add control dependencies per device from the last gradients to the
             # current set, in order to serialize NCCL launches.
-            if (communication_hint == CollectiveCommunication.NCCL.value and
-                reduced_gv_list):
-              control_input_grads = [g for g, _ in reduced_gv_list[-1]]
+            if (communication == CollectiveCommunication.NCCL.value and
+                reduced_values):
+              control_inputs = [g for g in reduced_values[-1]]
             else:
-              control_input_grads = None
-            collective_reduced = cross_device_utils.build_collective_reduce(
-                grads, self._num_workers, self._collective_keys, "Add", "Id",
-                communication_hint, control_input_grads)
-            result = []
-            for (_, v), g in zip(grad_and_vars, collective_reduced):
-              result.append([g, v])
-            reduced_gv_list.append(result)
-      # Reverse the batch reduced gradients to (approximately) recover the order
-      # in the input per_replica_values.
-      reduced_gv_list.reverse()
-      return reduced_gv_list
+              control_inputs = None
+            reduced_values.append(
+                cross_device_utils.build_collective_reduce(
+                    per_replica.values, self._num_workers,
+                    self._collective_keys, "Add", "Id", communication,
+                    control_inputs))
+      return reduced_values
+
     if context.executing_eagerly():
       batch_fn = def_function.function(batch_fn)
 
-    new_device_grads = [list(x) for x in zip(*batch_fn())]
-    return _ungroup_and_make_mirrored(
-        new_device_grads,
-        per_replica_values[0],
-        reduce_op,
-        num_between_graph_workers=self._num_workers)
+    reduced_values = batch_fn()
+    mirrored = []
+    # Reverse the order of reduced values to recover the order in the input.
+    for value in reversed(reduced_values):
+      if reduce_op == reduce_util.ReduceOp.MEAN:
+        # Assume each worker has the same number of replicas.
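+        # `len(value)` is the number of replicas on this worker (one per
+        # local device).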
+        num_replicas = len(value) * self._num_workers
+        for i, v in enumerate(value):
+          with ops.device(v.device):
+            value[i] = v / num_replicas
+      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
+    return mirrored
 
   def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
     """All-reduce IndexedSlices across all workers in a batch."""
@@ -1105,48 +1110,31 @@ class CollectiveAllReduce(CrossDeviceOps):
 
     # Pass self._communication to the runtime as a communication hint.
     communication_hint = self._communication.value
-    # For now, we use NCCL only when batch_size > 1 and num_packs is 1.
-    # TODO(b/132575814): Enable NCCL if num_packs > 1.
-    # TODO(b/132575814): Switch to NCCL for all collectives when communication
+    # For now, we use NCCL only when batch_size > 1.
+    # TODO(b/132575814): switch to NCCL for all collectives when communication
     # is NCCL.
-    if self._communication == CollectiveCommunication.NCCL and (
-        len(per_replica_values) == 1 or self._num_packs != 1):
+    if self._communication == CollectiveCommunication.NCCL and len(
+        per_replica_values) == 1:
       communication_hint = CollectiveCommunication.AUTO.value
 
-    chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
+    gathered_values = []
+    with ops.name_scope("allreduce"):
+      for per_replica in per_replica_values:
+        gathered_values.append(
+            cross_device_utils.build_collective_gather_indexed_slices(
+                per_replica.values, self._num_workers, self._collective_keys,
+                communication_hint))
 
-    reduced_gv_list = []
-    for chunk in chunked_gv:
-      # By placing all CollectiveReduce ops in a chunk under single name scope,
-      # we ensure they will be picked up by the `ScopedAllocator` grappler
-      # optimizer and packed into a single all-reduce.
-      with ops.name_scope("allreduce"):
-        for grad_and_vars in chunk:
-          grads = [g for g, _ in grad_and_vars]
-
-          # Add control dependencies per device from the last gradients to the
-          # current set, in order to serialize NCCL launches.
-          if (communication_hint == CollectiveCommunication.NCCL.value and
-              reduced_gv_list):
-            control_input_grads = [g for g, _ in reduced_gv_list[-1]]
-          else:
-            control_input_grads = None
-
-          collective_reduced = (
-              cross_device_utils.build_collective_gather_indexed_slices(
-                  grads, self._num_workers, self._collective_keys,
-                  communication_hint, control_input_grads))
-          result = []
-          for (_, v), g in zip(grad_and_vars, collective_reduced):
-            result.append([g, v])
-          reduced_gv_list.append(result)
-
-    new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
-    return _ungroup_and_make_mirrored(
-        new_device_grads,
-        per_replica_values[0],
-        reduce_op,
-        num_between_graph_workers=self._num_workers)
+    mirrored = []
+    for value in gathered_values:
+      if reduce_op == reduce_util.ReduceOp.MEAN:
+        # Assume each worker has the same number of replicas.
+        num_replicas = len(value) * self._num_workers
+        for i, v in enumerate(value):
+          with ops.device(v.device):
+            value[i].values = value[i].values / num_replicas
+      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
+    return mirrored
 
 
 def choose_the_best(devices, session_config=None):
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 3a9d7b7ec44..aba5316f7b5 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import collective_all_reduce_strategy
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import cross_device_utils
@@ -463,8 +464,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
                         num_gpus=0,
                         communication=CollectiveCommunication.AUTO,
                         use_strategy_object=False,
-                        local_mode=False,
-                        num_packs=1):
+                        local_mode=False):
     collective_keys = cross_device_utils.CollectiveKeys(
         group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
         op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
@@ -487,7 +487,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
             1,
             num_gpus,
             collective_keys=collective_keys,
-            num_packs=num_packs,
             communication=communication)
         return collective_all_reduce_ops, devices, ""
     else:
@@ -520,7 +519,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
             NUM_WORKERS,
             num_gpus,
             collective_keys=collective_keys,
-            num_packs=num_packs,
             communication=communication)
         return (collective_all_reduce_ops, devices,
                 "grpc://" + self._cluster_spec[task_type][task_id])
@@ -532,15 +530,14 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
                       communication,
                       use_strategy_object=False,
                       local_mode=False,
-                      num_packs=1):
+                      hints=None):
     collective_all_reduce, devices, master_target = self._get_test_objects(
         task_type,
         task_id,
         num_gpus,
         communication=communication,
         use_strategy_object=use_strategy_object,
-        local_mode=local_mode,
-        num_packs=num_packs)
+        local_mode=local_mode)
     if local_mode:
       num_workers = 1
       worker_device = None
@@ -553,17 +550,19 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
       if use_strategy_object:
         with test_object.scope():
           return test_object.extended.reduce_to(reduce_op, per_replica,
-                                                destinations)
+                                                destinations, hints)
       else:
-        return test_object.reduce(reduce_op, per_replica, destinations)
+        return test_object.reduce(reduce_op, per_replica, destinations, hints)
 
     def _batch_reduce(test_object, reduce_op, value_destination_pairs):
       if use_strategy_object:
         with test_object.scope():
           return test_object.extended.batch_reduce_to(reduce_op,
-                                                      value_destination_pairs)
+                                                      value_destination_pairs,
+                                                      hints)
       else:
-        return test_object.batch_reduce(reduce_op, value_destination_pairs)
+        return test_object.batch_reduce(reduce_op, value_destination_pairs,
+                                        hints)
 
     with ops.Graph().as_default(), \
          ops.device(worker_device), \
@@ -724,16 +723,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
           mode=["graph"],
           required_gpus=[0, 1, 2],
           use_strategy_object=[True, False],
-          num_packs=[1, 2]))
+          bytes_per_pack=[0, 1, 4]))
   def testReductionDistributed(self, required_gpus, use_strategy_object,
-                               num_packs):
+                               bytes_per_pack):
+    hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
     self._run_between_graph_clients(
         self._test_reduction,
         self._cluster_spec,
         required_gpus,
         communication=CollectiveCommunication.RING,
         use_strategy_object=use_strategy_object,
-        num_packs=num_packs)
+        hints=hints)
 
   @combinations.generate(
       combinations.combine(
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index fa6f612af17..f9917385b59 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import collective_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nccl_ops
+from tensorflow.python.platform import tf_logging as logging
 
 
 OP_INSTANCE_KEY_START_NUMBER = 100
@@ -896,6 +897,67 @@ def stitch_values(values_and_indices_list):
   return result
 
 
+def per_replica_num_elements(per_replica):
+  """Returns the static number of elements of one replica.
+
+  Args:
+    per_replica: A PerReplica of Tensor or IndexedSlices.
+
+  Returns:
+    Number of elements. None if some replica has a different or unknown shape.
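+
+  For example, a PerReplica holding two float32 tensors of shape [2, 4, 4]
+  returns 32 (the per-replica element count).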
+  """
+
+  values = per_replica._values  # pylint: disable=protected-access
+  s0 = values[0].shape
+  for v in values:
+    assert not isinstance(v, ops.IndexedSlices)
+    if v.shape != s0:
+      return None
+  return s0.num_elements()
+
+
+def pack_by_size(per_replica_list, bytes_per_pack):
+  """Packs `per_replica_list` into chunks of `bytes_per_pack`.
+
+  The method preserves the original order of `per_replica_list`. The packing is
+  best effort: each pack could have more or fewer bytes than `bytes_per_pack`.
+  It only packs values with known shapes. Note that the usage is different from
+  `cross_device_ops._pack_tensors`: this function is intended to work with the
+  ScopedAllocator style batching used in `CollectiveAllReduce`.
+
+  Args:
+    per_replica_list: A list of PerReplica.
+    bytes_per_pack: Bytes per pack.
+
+  Returns:
+    A list of packs of PerReplica. All values are packed into one pack if
+      `bytes_per_pack` is zero or any of the values has an unknown shape.
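+
+  Example (illustrative, mirroring `PackBySizeTest.testPreferLargerPack`):
+    given four values of 128, 32, 800 and 4 bytes and `bytes_per_pack=200`,
+    the first three values form one pack (a new pack only starts once the
+    current pack already exceeds `bytes_per_pack`) and the last value forms
+    a second pack.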
+  """
+
+  if bytes_per_pack == 0:
+    return [per_replica_list]
+  packs = []
+  last_pack_size = 0
+  for value in per_replica_list:
+    num_elements = per_replica_num_elements(value)
+    if num_elements is None:
+      # Can't pack values with unknown shape.
+      logging.warning(
+          'not packing values due to the unknown or inconsistent shape of %s',
+          value)
+      return [per_replica_list]
+    size = num_elements * value._primary.dtype.size  # pylint: disable=protected-access
+    # Try to keep each pack as close to bytes_per_pack as possible, while
+    # making each pack at least bytes_per_pack large, i.e. we err on the side
+    # of having fewer but larger packs.
+    if not packs or last_pack_size > bytes_per_pack:
+      packs.append([])
+      last_pack_size = 0
+    packs[-1].append(value)
+    last_pack_size += size
+  return packs
+
+
 def _control_input(inputs, control_inputs, idx):
   """Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
 
diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py
index 217883ea21b..ae0f3b8fe76 100644
--- a/tensorflow/python/distribute/cross_device_utils_test.py
+++ b/tensorflow/python/distribute/cross_device_utils_test.py
@@ -26,8 +26,11 @@ from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import input_layer
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -133,8 +136,86 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
 
     self.assertIsInstance(result, ops.IndexedSlices)
     self._assert_values_equal(t, result)
-    self.assertEqual(device_util.resolve(destination),
-                     device_util.resolve(result.device))
+    self.assertEqual(
+        device_util.resolve(destination), device_util.resolve(result.device))
+
+
+class PackBySizeTest(test.TestCase):
+
+  def assertShape(self, per_replica, shape):
+    for v in per_replica._values:  # pylint: disable=protected-access
+      self.assertEqual(v.shape, shape)
+
+  def testPreferLargerPack(self):
+    # Each pack except the last one should be equal to or larger than
+    # bytes_per_pack.
+    values = [
+        # size = 2 * 4 * 4 * 4 = 128
+        array_ops.ones([2, 4, 4], dtype=dtypes.float32),
+        # size = 8 * 4 = 32
+        array_ops.ones([8], dtype=dtypes.int32),
+        # size = 10 * 10 * 8 = 800
+        array_ops.ones([10, 10], dtype=dtypes.int64),
+        # size = 1 * 4 = 4
+        array_ops.ones([1], dtype=dtypes.int32),
+    ]
+    per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=200)
+    self.assertLen(packs, 2)
+    self.assertLen(packs[0], 3)
+    self.assertShape(packs[0][0], [2, 4, 4])
+    self.assertShape(packs[0][1], [8])
+    self.assertShape(packs[0][2], [10, 10])
+    self.assertLen(packs[1], 1)
+    self.assertShape(packs[1][0], [1])
+
+  def testZeroBytesPerPack(self):
+    values = [
+        array_ops.ones([1], dtype=dtypes.float32),
+        array_ops.ones([2], dtype=dtypes.float32),
+    ]
+    per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=0)
+    self.assertLen(packs, 1)
+    self.assertLen(packs[0], 2)
+    self.assertShape(packs[0][0], [1])
+    self.assertShape(packs[0][1], [2])
+
+  def testUnknownShape(self):
+    per_replica_values = [
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+        ]),
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            input_layer.Input(
+                shape=(10), batch_size=None, dtype=dtypes.float32),
+        ]),
+    ]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=1)
+    self.assertLen(packs, 1)
+    self.assertEqual(packs[0], per_replica_values)
+
+  def testInconsistentShape(self):
+    per_replica_values = [
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+        ]),
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            input_layer.Input(
+                shape=(10), batch_size=None, dtype=dtypes.float32),
+        ]),
+    ]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=1)
+    self.assertLen(packs, 1)
+    self.assertEqual(packs[0], per_replica_values)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7fac3fa3d9e..1d1e44f97c9 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -108,6 +108,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import numpy_dataset
@@ -1719,10 +1720,10 @@ class StrategyExtendedV2(object):
   def _reduce(self, reduce_op, value):
     # Default implementation until we have an implementation for each strategy.
     return self._local_results(
-        self._reduce_to(reduce_op, value,
-                        device_util.current() or "/device:CPU:0"))[0]
+        self.reduce_to(reduce_op, value,
+                       device_util.current() or "/device:CPU:0"))[0]
 
-  def reduce_to(self, reduce_op, value, destinations):
+  def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
     """Combine (via e.g. sum or mean) values across replicas.
 
     Args:
@@ -1732,6 +1733,8 @@ class StrategyExtendedV2(object):
         string. The return value will be copied to all destination devices (or
         all the devices where the `destinations` value resides). To perform an
         all-reduction, pass `value` to `destinations`.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       A tensor or value mirrored to `destinations`.
@@ -1744,18 +1747,25 @@ class StrategyExtendedV2(object):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
     assert (reduce_op == reduce_util.ReduceOp.SUM or
             reduce_op == reduce_util.ReduceOp.MEAN)
-    return self._reduce_to(reduce_op, value, destinations)
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
+    return self._reduce_to(reduce_op, value, destinations, experimental_hints)
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     raise NotImplementedError("must be implemented in descendants")
 
-  def batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def batch_reduce_to(self,
+                      reduce_op,
+                      value_destination_pairs,
+                      experimental_hints=None):
     """Combine multiple `reduce_to` calls into one for faster execution.
 
     Args:
       reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
-      value_destination_pairs: A sequence of (value, destinations)
-        pairs. See `reduce_to()` for a description.
+      value_destination_pairs: A sequence of (value, destinations) pairs. See
+        `reduce_to()` for a description.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       A list of mirrored values, one per pair in `value_destination_pairs`.
@@ -1765,11 +1775,16 @@ class StrategyExtendedV2(object):
     assert not isinstance(reduce_op, variable_scope.VariableAggregation)
     if isinstance(reduce_op, six.string_types):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
-    return self._batch_reduce_to(reduce_op, value_destination_pairs)
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
+    return self._batch_reduce_to(reduce_op, value_destination_pairs,
+                                 experimental_hints)
 
-  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
+                       experimental_hints):
     return [
-        self.reduce_to(reduce_op, t, destinations=v)
+        self.reduce_to(
+            reduce_op, t, destinations=v, experimental_hints=experimental_hints)
         for t, v in value_destination_pairs
     ]
 
@@ -2267,7 +2282,7 @@ class ReplicaContext(object):
     require_replica_context(self)
     return (device_util.current(),)
 
-  def all_reduce(self, reduce_op, value):
+  def all_reduce(self, reduce_op, value, experimental_hints=None):
     """All-reduces the given `value Tensor` nest across replicas.
 
     If `all_reduce` is called in any replica, it must be called in all replicas.
@@ -2289,16 +2304,21 @@ class ReplicaContext(object):
       reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
       value: The nested structure of `Tensor`s to all-reduce. The structure must
         be compatible with `tf.nest`.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
        A `Tensor` nest with the reduced `value`s from each replica.
     """
     if isinstance(reduce_op, six.string_types):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
 
     def batch_all_reduce(strategy, *value_flat):
       return strategy.extended.batch_reduce_to(
-          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
+          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
+          experimental_hints)
 
     if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
       # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
@@ -2449,9 +2469,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       return fn(*args, **kwargs)
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     # TODO(josh11b): Use destinations?
-    del reduce_op, destinations
+    del reduce_op, destinations, experimental_hints
     return value
 
   def _update(self, var, fn, args, kwargs, group):
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index 8c7ad0ae40d..bac623ada52 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -95,8 +95,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
   def _local_results(self, value):
     return (value,)
 
-  def _reduce_to(self, reduce_op, value, destinations):
-    del reduce_op, destinations
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
+    del reduce_op, destinations, experimental_hints
     return value
 
   def _experimental_make_numpy_dataset(self, numpy_input, session):
diff --git a/tensorflow/python/distribute/distributed_file_utils.py b/tensorflow/python/distribute/distributed_file_utils.py
index 9e4a2b202f1..b8e3fd8a8c9 100644
--- a/tensorflow/python/distribute/distributed_file_utils.py
+++ b/tensorflow/python/distribute/distributed_file_utils.py
@@ -50,7 +50,6 @@ from __future__ import print_function
 
 import os
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.lib.io import file_io
 
 
@@ -80,7 +79,7 @@ def write_dirpath(dirpath, strategy=None):
     return dirpath
   if not strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
     return dirpath
-  if multi_worker_util.is_chief():
+  if strategy.extended.should_checkpoint:
     return dirpath
   # If this worker is not chief and hence should not save file, save it to a
   # temporary directory to be removed later.
@@ -96,8 +95,10 @@ def remove_temp_dirpath(dirpath, strategy=None):
     # If strategy is still not available, this is not in distributed training.
     # Fallback to no-op.
     return
-  if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
-    if not multi_worker_util.is_chief():
+  # TODO(anjalisridhar): Consider removing the check for multi worker mode since
+  # it is redundant when used with the should_checkpoint property.
+  if (strategy.extended._in_multi_worker_mode() and  # pylint: disable=protected-access
+      not strategy.extended.should_checkpoint):
       # If this worker is not chief and hence should not save file, remove
       # the temporary directory.
-      file_io.delete_recursively(_get_temp_dir(dirpath, strategy))
+    file_io.delete_recursively(_get_temp_dir(dirpath, strategy))
diff --git a/tensorflow/python/distribute/distribution_strategy_context.py b/tensorflow/python/distribute/distribution_strategy_context.py
index 24d6a67187f..29593d65c5d 100644
--- a/tensorflow/python/distribute/distribution_strategy_context.py
+++ b/tensorflow/python/distribute/distribution_strategy_context.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import threading
 
 from tensorflow.python.framework import ops
@@ -266,6 +267,20 @@ def experimental_set_strategy(strategy):
     ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access
 
 
+# ------------------------------------------------------------------------------
+# Internal helpers.
+
+
+@contextlib.contextmanager
+def enter_or_assert_strategy(strategy):
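+  """Enters `strategy`'s scope, or asserts `strategy` is the current one."""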
+  if not has_strategy():
+    with strategy.scope():
+      yield
+  else:
+    _assert_strategy(strategy)
+    yield
+
+
 # ------------------------------------------------------------------------------
 # Defaults that are used when no tf.distribute.Strategy is explicitly created.
 # We create them lazily in a function so that we can workaround the circular
@@ -284,6 +299,17 @@ _default_replica_context_lock = threading.Lock()
 _default_replica_mode_lock = threading.Lock()
 
 
+def _assert_strategy(strategy):
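+  """Raises an error unless `strategy` is the strategy currently in scope."""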
+  if not has_strategy():
+    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
+                       (strategy,))
+  current_strategy = get_strategy()
+  if current_strategy is not strategy:
+    raise RuntimeError(
+        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
+        (current_strategy, strategy))
+
+
 def _get_default_strategy():
   if _defaults["strategy"] is None:
     # Avoid race condition causing two defaults to be created
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index e57c656139a..baa6e1ac76e 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -788,7 +788,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _get_cross_device_ops(self):
     return self._cross_device_ops or self._inferred_cross_device_ops
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.Mirrored) and
         reduce_op == reduce_util.ReduceOp.MEAN):
       return value
@@ -801,11 +801,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, self._num_replicas_in_sync)
     return self._get_cross_device_ops().reduce(
-        reduce_op, value, destinations=destinations)
+        reduce_op,
+        value,
+        destinations=destinations,
+        experimental_hints=experimental_hints)
 
-  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
+                       experimental_hints):
     return self._get_cross_device_ops().batch_reduce(reduce_op,
-                                                     value_destination_pairs)
+                                                     value_destination_pairs,
+                                                     experimental_hints)
 
   def _update(self, var, fn, args, kwargs, group):
     # TODO(josh11b): In eager mode, use one thread per device.
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 9f4b07a3e75..0ab4018ce13 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -1356,14 +1356,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     def forward(x, w, b):
       return x * w + b
 
-    x = constant_op.constant([1.0], name="x_useless")
+    x = array_ops.identity([1.0], name="x_useless")
     concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
 
     with distribution.scope():
 
       def replica_fn():
         with backprop.GradientTape() as t:
-          x = constant_op.constant([1.0], name="x")
+          x = array_ops.identity([1.0], name="x")
           loss = concrete_forward(x, w._get(), b._get()) - [1.0]
           return t.gradient(loss, [w, b])
 
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 1fb3538b517..5a6973a699b 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -356,8 +356,8 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
     with ops.device(self._device), _OneDeviceReplicaContext(strategy):
       return fn(*args, **kwargs)
 
-  def _reduce_to(self, reduce_op, value, destinations):
-    del reduce_op, destinations
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
+    del reduce_op, destinations, experimental_hints
     return value
 
   def _update(self, var, fn, args, kwargs, group):
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index d27bacf6be7..f785a0c6266 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -466,20 +466,25 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
             "Cannot reduce to another worker: %r, current worker is %r" %
             (d, self._input_workers.worker_devices[0]))
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     self._verify_destinations_not_different_worker(destinations)
     if not isinstance(value, values.DistributedValues):
       # pylint: disable=protected-access
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, self._num_replicas_in_sync)
     return self._cross_device_ops.reduce(
-        reduce_op, value, destinations=destinations)
+        reduce_op,
+        value,
+        destinations=destinations,
+        experimental_hints=experimental_hints)
 
-  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
+                       experimental_hints):
     for _, destinations in value_destination_pairs:
       self._verify_destinations_not_different_worker(destinations)
     return self._cross_device_ops.batch_reduce(reduce_op,
-                                               value_destination_pairs)
+                                               value_destination_pairs,
+                                               experimental_hints)
 
   def _select_single_value(self, structured):
     """Select any single value in `structured`."""
diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py
index c889484ae68..00730959d4e 100644
--- a/tensorflow/python/distribute/strategy_test_lib.py
+++ b/tensorflow/python/distribute/strategy_test_lib.py
@@ -688,9 +688,9 @@ class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
 
   def _testDeviceScope(self, distribution):
     with distribution.scope():
-      a = constant_op.constant(1.)
+      a = array_ops.identity(1.)
       with ops.device("/cpu:0"):
-        b = constant_op.constant(1.)
+        b = array_ops.identity(1.)
       if context.executing_eagerly():
         device = "/job:worker/replica:0/task:0/device:CPU:0"
       else:
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 54e2028ccaf..7176a2e2dc9 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -34,6 +34,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
@@ -543,7 +544,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._logical_device_stack.append(logical_device_id)
     try:
-      if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
+      if tpu_values.enclosing_tpu_context() is None:
         yield
       else:
         with ops.device(tpu.core(logical_device_id)):
@@ -648,20 +649,20 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
           with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
             v = next_creator(**kwargs)
 
-          assert not isinstance(v, values.TPUMirroredVariable)
+          assert not isinstance(v, tpu_values.TPUMirroredVariable)
           value_list.append(v)
       return value_list
 
     return values.create_mirrored_variable(self._container_strategy(),
                                            _real_mirrored_creator,
-                                           values.TPUMirroredVariable,
-                                           values.TPUSyncOnReadVariable,
+                                           tpu_values.TPUMirroredVariable,
+                                           tpu_values.TPUSyncOnReadVariable,
                                            **kwargs)
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tensor(value)
-       ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+       ) and tpu_values.enclosing_tpu_context() is not None:
       if reduce_op == reduce_util.ReduceOp.MEAN:
         # TODO(jhseu):  Revisit once we support model-parallelism.
         value *= (1. / self._num_replicas_in_sync)
@@ -701,9 +702,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     return output
 
   def _update(self, var, fn, args, kwargs, group):
-    assert isinstance(var, values.TPUVariableMixin) or isinstance(
+    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
-    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+    if tpu_values.enclosing_tpu_context() is not None:
       if group:
         return fn(var, *args, **kwargs)
       else:
@@ -724,7 +725,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     return values.update_regroup(self, updates, group)
 
   def read_var(self, var):
-    assert isinstance(var, values.TPUVariableMixin) or isinstance(
+    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
     return var.read_value()
 
@@ -745,7 +746,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # since the `1` gets broadcast as an int32 but global_step is int64.
     if isinstance(tensor, (float, int)):
       return tensor
-    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+    if tpu_values.enclosing_tpu_context() is not None:
       broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
       result = tpu_ops.all_to_all(
           broadcast_tensor,
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
new file mode 100644
index 00000000000..871c85405e2
--- /dev/null
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -0,0 +1,245 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various classes representing TPU distributed values.
+
+Note that the tests are in values_test.py.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.tpu import tpu
+
+
+@contextlib.contextmanager
+def _maybe_enter_graph(tensor):
+  # Note: might have an eager tensor but not be executing eagerly when
+  # building functions.
+  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
+      ops.has_default_graph()):
+    yield
+  else:
+    with tensor.graph.as_default():
+      yield
+
+
+def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
+
+  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
+    del use_locking  # Unused.
+
+    with _maybe_enter_graph(var.handle):
+      op = raw_assign_fn(
+          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
+
+      with ops.control_dependencies([op]):
+        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
+
+  return assign_fn
+
+
+class TPUVariableMixin(object):
+  """Mixin for TPU variables."""
+
+  def __init__(self, *args, **kwargs):
+    super(TPUVariableMixin, self).__init__(*args, **kwargs)
+
+    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
+    # correctly since in eager mode different variables can have the same name.
+    if ops.executing_eagerly_outside_functions():
+      self._handle_id = self._common_name + "_" + str(id(self._primary))
+    else:
+      self._handle_id = self._common_name
+
+  def __getattr__(self, name):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).__getattr__(name)
+    else:
+      raise AttributeError(
+          "'{}' not accessible within a TPU context.".format(name))
+
+  def get(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).get()
+    else:
+      raise NotImplementedError(
+          "`TPUVariableMixin.get()` is not supported within a TPU context.")
+
+  def _get_as_operand(self):
+    return self.read_value()
+
+  def _get_closest(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self)._get_closest()
+    else:
+      return self._primary
+
+  def numpy(self):
+    if context.executing_eagerly():
+      return self.read_value().numpy()
+    else:
+      raise NotImplementedError(
+          "numpy() is only available when eager execution is enabled.")
+
+  def _is_mirrored(self):
+    raise NotImplementedError(
+        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
+
+  @property
+  def handle(self):
+    # If we're in a tpu.rewrite(), return the replicated handle.
+    tpu_context = enclosing_tpu_context()
+    if tpu_context is None:
+      return self._get_closest().handle
+    else:
+      return tpu_context.get_replicated_var_handle(self._handle_id,
+                                                   self._values,
+                                                   self._is_mirrored())
+
+  @property
+  def device(self):
+    return self.handle.device
+
+  def _read_variable_op(self):
+    if self.trainable:
+      tape.variable_accessed(self)
+    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
+
+  def read_value(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).read_value()
+    else:
+      return self._read_variable_op()
+
+  def value(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).value()
+    else:
+      return self._read_variable_op()
+
+  def _as_graph_element(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
+    else:
+      return None
+
+  @property
+  def op(self):
+    return values.DistributedVarOp(self._primary.op.name,
+                                   self._primary.op.graph,
+                                   self._primary.op.traceback,
+                                   self._primary.op.type)
+
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    # pylint: disable=protected-access
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self)._dense_var_to_tensor(
+          dtype=dtype, name=name, as_ref=as_ref)
+    # pylint: enable=protected-access
+    elif dtype is not None and dtype != self.dtype:
+      return math_ops.cast(self.read_value(), dtype)
+    else:
+      return self.handle if as_ref else self.read_value()
+
+
+def enclosing_tpu_context():
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, tpu.TPUReplicateContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
+
+
+class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
+  """Holds a map from replica to TPU variables whose values are kept in sync."""
+
+  def _assign_func(self, *args, **kwargs):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if (ds_context.in_cross_replica_context() and
+          (enclosing_tpu_context() is not None)):
+        f = kwargs.pop("f")
+        return self._distribute_strategy.extended.update(
+            self, f, args=args, kwargs=kwargs)
+      else:
+        return values.MirroredVariable._assign_func(self, *args, **kwargs)
+
+  def assign_sub(self, *args, **kwargs):
+    assign_sub_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_sub_variable_op)
+    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+  def assign_add(self, *args, **kwargs):
+    assign_add_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_add_variable_op)
+    return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+  def assign(self, *args, **kwargs):
+    assign_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_variable_op)
+    return self._assign_func(f=assign_fn, *args, **kwargs)
+
+  def _is_mirrored(self):
+    return True
+
+
+class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
+  """Holds a map from replica to variables whose values are reduced on save."""
+
+  def assign_sub(self, *args, **kwargs):
+    if enclosing_tpu_context() is None:
+      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
+    else:
+      return _make_raw_assign_fn(
+          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
+                                                            **kwargs)
+
+  def assign_add(self, *args, **kwargs):
+    if enclosing_tpu_context() is None:
+      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
+    else:
+      return _make_raw_assign_fn(
+          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
+                                                            **kwargs)
+
+  def assign(self, *args, **kwargs):
+    if enclosing_tpu_context() is None:
+      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
+    else:
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self, *args, **kwargs)
+
+  def _is_mirrored(self):
+    return False
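
Everything in the new file above follows one dispatch rule: each read-path method checks enclosing_tpu_context() and, outside a tpu.rewrite(), defers to the ordinary distributed-variable behaviour via super(); inside one, it goes through the replicated handle instead. The toy below is not TensorFlow code (all names are made up) and only illustrates that mixin-plus-context-check pattern in isolation:

_CURRENT_TPU_CONTEXT = None  # stand-in for graph._get_control_flow_context()


def toy_enclosing_tpu_context():
  return _CURRENT_TPU_CONTEXT


class ToyBaseVariable(object):

  def __init__(self, value):
    self._value = value

  def read_value(self):
    return self._value  # ordinary "closest replica" read


class ToyTPUMixin(object):

  def read_value(self):
    if toy_enclosing_tpu_context() is None:
      # Outside tpu.rewrite(): fall back to the base class behaviour.
      return super(ToyTPUMixin, self).read_value()
    # Inside tpu.rewrite(): read through the replicated handle instead.
    return ("replicated-read", super(ToyTPUMixin, self).read_value())


class ToyTPUVariable(ToyTPUMixin, ToyBaseVariable):
  pass


v = ToyTPUVariable(1.0)
assert v.read_value() == 1.0                      # no TPU context active
_CURRENT_TPU_CONTEXT = object()                   # pretend we are inside tpu.rewrite()
assert v.read_value() == ("replicated-read", 1.0)
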
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 48adb5159bf..c23819cde11 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -19,12 +19,11 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import contextlib
 import weakref
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -34,11 +33,9 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
@@ -48,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
 
 def _get_current_replica_id_as_int():
   """Returns the current replica ID as an integer, or `None`."""
-  replica_context = distribution_strategy_context.get_replica_context()
+  replica_context = ds_context.get_replica_context()
   if replica_context:
     replica_id = replica_context.replica_id_in_sync_group
     if not isinstance(replica_id, int):
@@ -362,7 +359,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
     return self._value_specs
 
   def _to_components(self, value):
-    replica_context = distribution_strategy_context.get_replica_context()
+    replica_context = ds_context.get_replica_context()
     if replica_context is not None and replica_context.num_replicas_in_sync > 1:
       raise ValueError(
           "Flattening a PerReplica to components is not supported in replica "
@@ -405,27 +402,6 @@ def _assign_sub_on_device(device, variable, tensor):
     return variable.assign_sub(tensor)
 
 
-def _assert_strategy(strategy):
-  if not distribution_strategy_context.has_strategy():
-    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
-                       (strategy,))
-  current_strategy = distribution_strategy_context.get_strategy()
-  if current_strategy is not strategy:
-    raise RuntimeError(
-        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
-        (current_strategy, strategy))
-
-
-@contextlib.contextmanager
-def _enter_or_assert_strategy(strategy):
-  if not distribution_strategy_context.has_strategy():
-    with strategy.scope():
-      yield
-  else:
-    _assert_strategy(strategy)
-    yield
-
-
 DistributedVarOp = collections.namedtuple(
     "DistributedVarOp", ["name", "graph", "traceback", "type"])
 
@@ -578,7 +554,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
     # We want cross-replica code that does some var.op.X calls
     # to work (even if the current device isn't in self._devices), but
     # other uses of var.op in a cross-replica context to fail.
-    if distribution_strategy_context.in_cross_replica_context():
+    if ds_context.in_cross_replica_context():
       return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
                               self._primary.op.traceback, self._primary.op.type)
     return self._get().op
@@ -588,7 +564,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
     return self._primary._in_graph_mode  # pylint: disable=protected-access
 
   def read_value(self):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return array_ops.identity(self._get())
 
   def value(self):
@@ -602,135 +578,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
-@contextlib.contextmanager
-def _maybe_enter_graph(tensor):
-  # Note: might have an eager tensor but not be executing eagerly when
-  # building functions.
-  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
-      ops.has_default_graph()):
-    yield
-  else:
-    with tensor.graph.as_default():
-      yield
-
-
-def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
-
-  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
-    del use_locking  # Unused.
-
-    with _maybe_enter_graph(var.handle):
-      op = raw_assign_fn(
-          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
-
-      with ops.control_dependencies([op]):
-        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
-
-  return assign_fn
-
-
-class TPUVariableMixin(object):
-  """Mixin for TPU variables."""
-
-  def __init__(self, *args, **kwargs):
-    super(TPUVariableMixin, self).__init__(*args, **kwargs)
-
-    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
-    # correctly since in eager mode different variables can have the same name.
-    if ops.executing_eagerly_outside_functions():
-      self._handle_id = self._common_name + "_" + str(id(self._primary))
-    else:
-      self._handle_id = self._common_name
-
-  def __getattr__(self, name):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).__getattr__(name)
-    else:
-      raise AttributeError(
-          "'{}' not accessible within a TPU context.".format(name))
-
-  def get(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).get()
-    else:
-      raise NotImplementedError(
-          "`TPUVariableMixin.get()` is not supported within a TPU context.")
-
-  def _get_as_operand(self):
-    return self.read_value()
-
-  def _get_closest(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._get_closest()
-    else:
-      return self._primary
-
-  def numpy(self):
-    if context.executing_eagerly():
-      return self.read_value().numpy()
-    else:
-      raise NotImplementedError(
-          "numpy() is only available when eager execution is enabled.")
-
-  def _is_mirrored(self):
-    raise NotImplementedError(
-        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
-
-  @property
-  def handle(self):
-    # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = _enclosing_tpu_context()
-    if tpu_context is None:
-      return self._get_closest().handle
-    else:
-      return tpu_context.get_replicated_var_handle(
-          self._handle_id, self._values, self._is_mirrored())
-
-  @property
-  def device(self):
-    return self.handle.device
-
-  def _read_variable_op(self):
-    if self.trainable:
-      tape.variable_accessed(self)
-    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
-
-  def read_value(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).read_value()
-    else:
-      return self._read_variable_op()
-
-  def value(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).value()
-    else:
-      return self._read_variable_op()
-
-  def _as_graph_element(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
-    else:
-      return None
-
-  @property
-  def op(self):
-    return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
-                            self._primary.op.traceback, self._primary.op.type)
-
-  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
-    """Converts a variable to a tensor."""
-    # pylint: disable=protected-access
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._dense_var_to_tensor(
-          dtype=dtype, name=name, as_ref=as_ref)
-    # pylint: enable=protected-access
-    elif dtype is not None and dtype != self.dtype:
-      return math_ops.cast(self.read_value(), dtype)
-    else:
-      return self.handle if as_ref else self.read_value()
-
-
 def _validate_colocate_extended(v, extended):
   variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
   if variable_strategy.extended is not extended:
@@ -888,9 +735,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
   # update_non_slot() function (like OptimizerV2._finish), which can
   # update several non-slot variables in one call.
   def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
+      if ds_context.in_cross_replica_context():
         update_replica_id = distribute_lib.get_update_replica_id()
         if update_replica_id is not None:
           # We are calling an assign function on the mirrored variable in an
@@ -933,7 +780,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
           return strategy.extended.update(
               self, f, args=(v,) + other_args, kwargs=other_kwargs)
 
-        return distribution_strategy_context.get_replica_context().merge_call(
+        return ds_context.get_replica_context().merge_call(
             merge_fn, args=args, kwargs=kwargs)
 
   def assign_sub(self, *args, **kwargs):
@@ -1003,60 +850,11 @@ ops.register_tensor_conversion_function(Mirrored,
                                         _tensor_conversion_mirrored_val)
 
 
-def _enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
 def is_distributed_variable(v):
   """Determine if a variable is ds variable or TPU mirrored variable."""
   return isinstance(v, DistributedVariable)
 
 
-class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
-  """Holds a map from replica to TPU variables whose values are kept in sync."""
-
-  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if (distribution_strategy_context.in_cross_replica_context() and
-          (_enclosing_tpu_context() is not None)):
-        f = kwargs.pop("f")
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        return MirroredVariable._assign_func(self, *args, **kwargs)
-
-  def assign_sub(self, *args, **kwargs):
-    assign_sub_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_sub_variable_op)
-    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    assign_add_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_add_variable_op)
-    return self._assign_func(f=assign_add_fn, *args, **kwargs)
-
-  def assign(self, *args, **kwargs):
-    assign_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_variable_op)
-    return self._assign_func(f=assign_fn, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return True
-
-
 class _SyncOnReadSaveable(saveable_object.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""
 
@@ -1094,7 +892,7 @@ class _SyncOnReadSaveable(saveable_object.SaveableObject):
 
 
 def _assert_replica_context(strategy):
-  replica_context = distribution_strategy_context.get_replica_context()
+  replica_context = ds_context.get_replica_context()
   if not replica_context:
     raise RuntimeError(
         "Replica-local variables may only be assigned in a replica context.")
@@ -1111,8 +909,8 @@ class SyncOnReadVariable(DistributedVariable):
     self._aggregation = aggregation
 
   def assign_sub(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         if self._aggregation == vs.VariableAggregation.SUM:
           raise ValueError(
               "SyncOnReadVariable does not support `assign_sub` in "
@@ -1126,8 +924,8 @@ class SyncOnReadVariable(DistributedVariable):
         return self._get().assign_sub(*args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         if self._aggregation == vs.VariableAggregation.SUM:
           raise ValueError(
               "SyncOnReadVariable does not support `assign_add` in "
@@ -1141,8 +939,8 @@ class SyncOnReadVariable(DistributedVariable):
         return self._get().assign_add(*args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         # To preserve the sum across save and restore, we have to divide the
         # total across all devices when restoring a variable that was summed
         # when saving.
@@ -1155,8 +953,8 @@ class SyncOnReadVariable(DistributedVariable):
         return self._get().assign(*args, **kwargs)
 
   def value(self):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         return self._get_cross_replica()
       else:
         # _get_closest() returns a Variable.
@@ -1177,7 +975,7 @@ class SyncOnReadVariable(DistributedVariable):
     if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
       return self._primary
 
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return self._distribute_strategy.reduce(
           reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
           self,
@@ -1185,8 +983,8 @@ class SyncOnReadVariable(DistributedVariable):
 
   def _as_graph_element(self):
     # pylint: disable=protected-access
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         return ops.convert_to_tensor(self._get_cross_replica())
     return self._get()._as_graph_element()
 
@@ -1207,7 +1005,7 @@ class SyncOnReadVariable(DistributedVariable):
 
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return ops.convert_to_tensor(
           self._get(), dtype=dtype, name=name, as_ref=as_ref)
 
@@ -1222,36 +1020,6 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
                                         _tensor_conversion_sync_on_read)
 
 
-class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
-  """Holds a map from replica to variables whose values are reduced on save."""
-
-  def assign_sub(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign_sub(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
-                                                            **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign_add(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
-                                                            **kwargs)
-
-  def assign(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
-          self, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return False
-
-
 def regroup(values, wrap_class=PerReplica):
   """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
   v0 = values[0]
@@ -1444,9 +1212,9 @@ class AggregatingVariable(variables_lib.Variable):
     return getattr(self._v, name)
 
   def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
+      if ds_context.in_cross_replica_context():
         if distribute_lib.get_update_replica_id() is not None:
           # We are calling an assign function in an update context.
           return f(self._v, *args, **kwargs)
@@ -1456,7 +1224,7 @@ class AggregatingVariable(variables_lib.Variable):
         return self._distribute_strategy.extended.update(
             self, f, args=args, kwargs=kwargs)
       else:
-        replica_context = distribution_strategy_context.get_replica_context()
+        replica_context = ds_context.get_replica_context()
         assert replica_context
         # We are calling an assign function in replica context.
         # We reduce the value we want to assign/add/sub. More details about how
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index f2922e6e53a..d66726424a1 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import context
@@ -824,7 +825,7 @@ def _make_replica_local(method, strategy=None):
           name=n, initializer=init, use_resource=True))
 
   if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
-    var_cls = values.TPUSyncOnReadVariable
+    var_cls = tpu_values.TPUSyncOnReadVariable
   else:
     var_cls = values.SyncOnReadVariable
   replica_local = var_cls(strategy, v, method)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 7aef5da11f2..a6bc0b7ec56 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -878,6 +878,24 @@ tpu_py_test(
     ],
 )
 
+tpu_py_test(
+    name = "remote_cloud_tpu_pod_test",
+    srcs = ["remote_cloud_tpu_test.py"],
+    args = ["--num_tpu_devices=32"],
+    main = "remote_cloud_tpu_test.py",
+    python_version = "PY3",
+    tags = [
+        "notap",
+        "tpu_pod",
+    ],
+    deps = [
+        ":context",
+        ":remote",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/tpu:tpu_strategy_util",
+    ],
+)
+
 cuda_py_test(
     name = "device_placement_test",
     size = "small",
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 273777090c6..47b3966827f 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -79,7 +79,6 @@ class TFETest(test_util.TensorFlowTestCase):
 
   def setUp(self):
     super(TFETest, self).setUp()
-    ops.device(None).__enter__()
     context._reset_context()
     configure_virtual_cpus()
 
@@ -395,7 +394,7 @@ class TFETest(test_util.TensorFlowTestCase):
 
   def testMultiCpuPlacement(self):
     with ops.device('cpu:1'):
-      x = constant_op.constant(1.0)
+      x = array_ops.identity(1.0)
     with ops.device('cpu:0'):
       y = array_ops.identity(x)
     self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
@@ -1064,7 +1063,6 @@ class SendRecvTest(test_util.TensorFlowTestCase):
 
   def setUp(self):
     super(SendRecvTest, self).setUp()
-    ops.device(None).__enter__()
     context._reset_context()
     configure_virtual_cpus()
 
@@ -1084,7 +1082,7 @@ class SendRecvTest(test_util.TensorFlowTestCase):
   def testLocalCrossDevice(self):
     gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0'
     with ops.device('GPU:0'):
-      t0 = constant_op.constant(1.0)
+      t0 = array_ops.identity(1.0)
       self._send(t0, 't0', self.cpu_device)
     with ops.device('cpu:0'):
       self.assertAllEqual(
@@ -1101,7 +1099,6 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
 
   def setUp(self):
     super(EagerTensorCacheTest, self).setUp()
-    ops.device(None).__enter__()
     context._reset_context()
     configure_virtual_cpus()
 
@@ -1115,4 +1112,5 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
 
 
 if __name__ == '__main__':
+  context.set_log_device_placement(True)
   test.main()
diff --git a/tensorflow/python/eager/device_placement_test.py b/tensorflow/python/eager/device_placement_test.py
index 32ca6d3a826..af6c68243b4 100644
--- a/tensorflow/python/eager/device_placement_test.py
+++ b/tensorflow/python/eager/device_placement_test.py
@@ -18,24 +18,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 
 
-class SoftDevicePlacementTest(test.TestCase):
+class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(SoftDevicePlacementTest, self).setUp()
-    context._context = None
-    ops.enable_eager_execution_internal()
+    context._reset_context()
     config.set_soft_device_placement(enabled=True)
     context.context().log_device_placement = True
 
@@ -90,13 +92,21 @@ class SoftDevicePlacementTest(test.TestCase):
     # We don't support nested device placement right now.
     self.assertIn('GPU:0', c.device)
 
+  @parameterized.named_parameters(('float', 1.0, None),
+                                  ('int32', [1], dtypes.int32),
+                                  ('string', ['a'], None))
+  def testSoftPlacedCPUConstant(self, value, dtype):
+    with ops.device('GPU:0'):
+      a = constant_op.constant(value, dtype=dtype)
+    self.assertIn('CPU:0', a.device)
+    self.assertIn('CPU:0', a.backing_device)
 
-class HardDevicePlacementTest(test.TestCase):
+
+class HardDevicePlacementTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(HardDevicePlacementTest, self).setUp()
-    context._context = None
-    ops.enable_eager_execution_internal()
+    context._reset_context()
     config.set_soft_device_placement(enabled=False)
     context.context().log_device_placement = True
     self.assertEqual(config.get_soft_device_placement(), False)
@@ -114,13 +124,27 @@ class HardDevicePlacementTest(test.TestCase):
       self.assertIn('GPU:0', y.device)
       self.assertIn('GPU:0', y.backing_device)
 
+  @parameterized.named_parameters(('float_cpu0', 'CPU:0', 1.0, None),
+                                  ('int32_cpu0', 'CPU:0', [1], dtypes.int32),
+                                  ('string_cpu0', 'CPU:0', ['a'], None),
+                                  ('float_gpu0', 'GPU:0', 1.0, None),
+                                  ('int32_gpu0', 'GPU:0', [1], dtypes.int32),
+                                  ('string_gpu0', 'GPU:0', ['a'], None),
+                                  ('float_gpu99', 'GPU:99', 1.0, None),
+                                  ('int32_gpu99', 'GPU:99', [1], dtypes.int32),
+                                  ('string_gpu99', 'GPU:99', ['a'], None))
+  def testHardPlacedCPUConstant(self, device, value, dtype):
+    with ops.device(device):
+      a = constant_op.constant(value, dtype=dtype)
+      self.assertIn('CPU:0', a.device)
+      self.assertIn('CPU:0', a.backing_device)
+
 
 class ClusterPlacementTest(test.TestCase):
 
   def setUp(self):
     super(ClusterPlacementTest, self).setUp()
-    context._context = None
-    ops.enable_eager_execution_internal()
+    context._reset_context()
     config.set_soft_device_placement(enabled=True)
     context.context().log_device_placement = True
     workers, _ = test_util.create_local_cluster(2, 0)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 0a34d4a3852..7b599a995e2 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1570,10 +1570,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
   def testColocateWithRespected(self):
     # TODO(b/113291792): Use multiple CPUs instead of a GPU.
     with ops.device('cpu:0'):
-      x = constant_op.constant(1.0)
+      x = array_ops.identity(1.0)
 
     with ops.device('gpu:0'):
-      y = constant_op.constant(1.0)
+      y = array_ops.identity(1.0)
 
     @def_function.function
     def foo():
@@ -3239,9 +3239,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
       return b, a
 
     with ops.device('/device:CPU:0'):
-      a = constant_op.constant(3.0)
+      a = array_ops.identity(3.0)
     with ops.device('/device:GPU:0'):
-      b = constant_op.constant(5.0)
+      b = array_ops.identity(5.0)
 
     m1, m2 = func(a, b)
     self.assertAllEqual(m1.numpy(), 5.0)
@@ -3306,9 +3306,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
     devices = ['/device:CPU:0', '/device:GPU:0']
     for dev1, dev2 in itertools.product(devices, devices):
       with ops.device(dev1):
-        a = constant_op.constant(1.0)
+        a = array_ops.identity(1.0)
       with ops.device(dev2):
-        b = constant_op.constant(10.0)
+        b = array_ops.identity(10.0)
 
       ra, rb = func(a, b)
       self.assertEqual(ra.numpy(), 2.0)
@@ -3469,13 +3469,13 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
     with ops.device('/device:CPU:0'):
       rc0 = resource_variable_ops.ResourceVariable(2.0)
       rc1 = resource_variable_ops.ResourceVariable(3.0)
-      cc0 = constant_op.constant(5.0)
-      cc1 = constant_op.constant(7.0)
+      cc0 = array_ops.identity(5.0)
+      cc1 = array_ops.identity(7.0)
     with ops.device('/device:GPU:0'):
       rg0 = resource_variable_ops.ResourceVariable(11.0)
       rg1 = resource_variable_ops.ResourceVariable(13.0)
-      cg0 = constant_op.constant(17.0)
-      cg1 = constant_op.constant(19.0)
+      cg0 = array_ops.identity(17.0)
+      cg1 = array_ops.identity(19.0)
 
     # Make sure tensors are on expected devices.
     for tensor in [cc0, cc1]:
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index b5c9bfb6824..f8e1fb568ac 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -278,39 +278,13 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
     }
   }
 
-  // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
-  // memory. We approximate the same behavior for eager execution - keeping
-  // int32 tensors in host memory.
+  // We always generate CPU:0 tensors, but we may need to change the device
+  // slightly, for example from /job:localhost/... to /job:worker/...
   //
-  // We do so to preclude the need for callers into such kernels from having to
-  // explicitly place the int32 tensors in host memory. For example, without
-  // this, one needed:
-  //
-  // with tf.device('/gpu:0'):
-  //   ...// code here
-  //   with tf.device('/cpu:0'):
-  //     shape = tf.constant(...)
-  //   y = tf.random_uniform(shape)
-  //
-  // Without the CPU device block, tfe.ops.random_uniform would fail since the
-  // kernel expects the shape in host memory.
-  //
-  // With this support, we simplify the code:
-  //
-  // with tf.device('/gpu:0'):
-  //   y = tf.random_uniform(...)
-  //
-  // The approximation is not exact there are GPU kernels which do not require
-  // host memory for int32 tensors. This will lead to a discrepancy between
-  // eager and graph execution.
-  //
-  // To support remote execution copy int32 tensors to another CPU device.
-  // TODO(ashankar): Fix this.
+  // Note that this is a shallow copy and will share the underlying buffer,
+  // because we are copying to the same device.
   if (device_name != nullptr &&
-      (TFE_TensorHandleDataType(handle.get()) != TF_INT32 ||
-       strstr(device_name, "/device:CPU:0") != nullptr)) {
-    // Note that this is a shallow copy and will share the underlying buffer
-    // if copying to the same device.
+      strstr(device_name, "/device:CPU:0") != nullptr) {
     handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
                                                     device_name, status.get()));
     if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
@@ -318,6 +292,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
     }
   }
 
+  // We always enable implicit mirroring for constants. Without this, code
+  // written previously under the assumption that
+  //
+  //   with tf.device('GPU:0'): x = tf.constant(1.0)
+  //
+  // will be placed on the GPU would suffer a non-trivial performance regression
+  // (measured at ~20% for certain benchmarks).
+  handle->handle->EnableImplicitMirroring();
+
   return handle.release();
 }
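
The practical upshot of this pywrap_tensor.cc change, together with the implicit mirroring enabled just above, is that a constant created under a GPU device scope now lives (and is backed) on CPU:0 but is mirrored onto the GPU the first time a GPU op consumes it, which is what the updated device_placement_test.py and tensor_test.py cases in this patch assert. A minimal sketch of that behaviour using the public eager API (assumes a machine with at least one GPU; exact device strings vary by host):

import tensorflow as tf

with tf.device('GPU:0'):
  x = tf.constant(1.0)

# The constant itself is placed and backed on the CPU ...
print(x.device)          # .../device:CPU:0
print(x.backing_device)  # .../device:CPU:0

# ... and is implicitly mirrored to the GPU when a GPU op first consumes it,
# so repeated GPU uses do not pay for the copy again.
with tf.device('GPU:0'):
  y = x * 2.0
print(y.device)          # .../device:GPU:0
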
 
diff --git a/tensorflow/python/eager/remote_cloud_tpu_test.py b/tensorflow/python/eager/remote_cloud_tpu_test.py
index d63a8924bc8..8ba11a3e6ac 100644
--- a/tensorflow/python/eager/remote_cloud_tpu_test.py
+++ b/tensorflow/python/eager/remote_cloud_tpu_test.py
@@ -31,24 +31,25 @@ flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
 flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
 flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
 
+flags.DEFINE_integer('num_tpu_devices', 8, 'The expected number of TPUs.')
+DEVICES_PER_TASK = 8
+
 EXPECTED_DEVICES_PRE_CONNECT = [
     '/device:CPU:0',
     '/device:XLA_CPU:0',
 ]
-EXPECTED_DEVICES_AFTER_CONNECT = [
-    '/device:CPU:0',
-    '/device:XLA_CPU:0',
-    '/job:worker/replica:0/task:0/device:CPU:0',
-    '/job:worker/replica:0/task:0/device:XLA_CPU:0',
-    '/job:worker/replica:0/task:0/device:TPU_SYSTEM:0',
-    '/job:worker/replica:0/task:0/device:TPU:0',
-    '/job:worker/replica:0/task:0/device:TPU:1',
-    '/job:worker/replica:0/task:0/device:TPU:2',
-    '/job:worker/replica:0/task:0/device:TPU:3',
-    '/job:worker/replica:0/task:0/device:TPU:4',
-    '/job:worker/replica:0/task:0/device:TPU:5',
-    '/job:worker/replica:0/task:0/device:TPU:6',
-    '/job:worker/replica:0/task:0/device:TPU:7',
+EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES = [
+    '/job:worker/replica:0/task:{task}/device:CPU:0',
+    '/job:worker/replica:0/task:{task}/device:XLA_CPU:0',
+    '/job:worker/replica:0/task:{task}/device:TPU_SYSTEM:0',
+    '/job:worker/replica:0/task:{task}/device:TPU:0',
+    '/job:worker/replica:0/task:{task}/device:TPU:1',
+    '/job:worker/replica:0/task:{task}/device:TPU:2',
+    '/job:worker/replica:0/task:{task}/device:TPU:3',
+    '/job:worker/replica:0/task:{task}/device:TPU:4',
+    '/job:worker/replica:0/task:{task}/device:TPU:5',
+    '/job:worker/replica:0/task:{task}/device:TPU:6',
+    '/job:worker/replica:0/task:{task}/device:TPU:7',
 ]
 
 
@@ -56,6 +57,9 @@ class RemoteCloudTPUTest(absltest.TestCase):
   """Test that we can connect to a real Cloud TPU."""
 
   def test_connect(self):
+    # Log full diff on failure.
+    self.maxDiff = None  # pylint:disable=invalid-name
+
     self.assertCountEqual(
         EXPECTED_DEVICES_PRE_CONNECT,
         [device.name for device in config.list_logical_devices()])
@@ -65,8 +69,15 @@ class RemoteCloudTPUTest(absltest.TestCase):
     )
     remote.connect_to_cluster(resolver)
 
+    expected_devices = EXPECTED_DEVICES_PRE_CONNECT
+    for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK):
+      expected_devices.extend([
+          template.format(task=task)
+          for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES
+      ])
+
     self.assertCountEqual(
-        EXPECTED_DEVICES_AFTER_CONNECT,
+        expected_devices,
         [device.name for device in config.list_logical_devices()])
 
     tpu_strategy_util.initialize_tpu_system(resolver)
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index d9ed93fc662..44af62666ee 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -264,6 +264,42 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
+    with ops.device('/job:worker/replica:0/task:1'):
+      variable_b = variables.Variable([1.0])
+
+    @def_function.function
+    def remote_function(i):
+      x = array_ops.ones([1000, 1000])
+      for _ in range(1, 1000):
+        x = x * x
+      variable_b.assign_add(i)
+      a = 1.0 + variable_b
+      return a
+
+    @def_function.function
+    def remote_function2(i):
+      variable_b.assign_add(i)
+      a = 1.0 + variable_b
+      return a
+
+    # Runs first function:
+    # - on remote device
+    # - needs remote input
+    # - has side effects
+    # - runs much slower
+    with ops.device('/job:worker/replica:0/task:0'):
+      remote_function(constant_op.constant([2.0]))
+
+    # Runs second function:
+    # - on remote device
+    # - has side effects
+    # There should be a sync point here and the next function will be executed
+    # only after the first function has completed.
+    with ops.device('/job:worker/replica:0/task:2'):
+      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index dd1f049cdcc..fe4a7933a32 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -281,9 +281,9 @@ class TFETensorTest(test_util.TensorFlowTestCase):
   @test_util.run_gpu_only
   def testStringTensorOnGPU(self):
     with ops.device("/device:GPU:0"):
-      with self.assertRaisesRegexp(
-          RuntimeError, "Can't copy Tensor with type string to device"):
-        _create_tensor("test string")
+      t = _create_tensor("test string")
+      self.assertIn("CPU", t.device)
+      self.assertIn("CPU", t.backing_device)
 
   def testInvalidUTF8ProducesReasonableError(self):
     if sys.version_info[0] < 3:
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index aef685e2390..2b18eb4ede0 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -18,7 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import enum
+
 from tensorflow.python.eager import context
+from tensorflow.python.framework import auto_control_deps_utils as utils
 from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
@@ -120,6 +124,11 @@ def op_is_stateful(op):
       op.type in _WHITELIST_STATELESS_OPS)
 
 
+class ResourceType(enum.Enum):
+  READ_ONLY = "read-only"
+  READ_WRITE = "read-write"
+
+
 class AutomaticControlDependencies(object):
   """Context manager to automatically add control dependencies.
 
@@ -197,7 +206,7 @@ class AutomaticControlDependencies(object):
     return self
 
   def _process_switch(self, switch_op, ops_which_must_run,
-                      last_op_using_resource_tensor, merge_for_resource):
+                      last_write_to_resource, merge_for_resource):
     """Processes a switch node for a resource input.
 
     When tensorflow creates a cond, it creates a control flow context for each
@@ -227,7 +236,7 @@ class AutomaticControlDependencies(object):
     Args:
       switch_op: the switch op to be processed
       ops_which_must_run: the set of ops which must run
-      last_op_using_resource_tensor: map from resource tensor to last op using
+      last_write_to_resource: map from resource tensor to last op updating
         it
       merge_for_resource: map from resource tensor to merge which must follow
         all usages of it.
@@ -235,8 +244,8 @@ class AutomaticControlDependencies(object):
     inp = switch_op.inputs[0]
     input_id = ops.tensor_id(inp)
     if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
-      self._process_switch(inp.op, ops_which_must_run,
-                           last_op_using_resource_tensor, merge_for_resource)
+      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
+                           merge_for_resource)
     output = switch_op.outputs[0]
     output_id = ops.tensor_id(output)
     if output_id in merge_for_resource:
@@ -247,11 +256,11 @@ class AutomaticControlDependencies(object):
         switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
     # Ensures the merge always runs
     ops_which_must_run.add(new_merge[0].op)
-    if input_id in last_op_using_resource_tensor:
+    if input_id in last_write_to_resource:
       # Ensures the switch executes after the previous op using the resource.
-      switch_op._add_control_input(last_op_using_resource_tensor[input_id])  # pylint: disable=protected-access
+      switch_op._add_control_input(last_write_to_resource[input_id])  # pylint: disable=protected-access
     # Ensure the next op outside the cond happens after the merge.
-    last_op_using_resource_tensor[input_id] = new_merge[0].op
+    last_write_to_resource[input_id] = new_merge[0].op
     if input_id in merge_for_resource:
       merge_for_resource[input_id]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
     for o in switch_op.outputs:
@@ -274,8 +283,11 @@ class AutomaticControlDependencies(object):
       self._graph._add_control_dependencies = False
     # pylint: enable=protected-access
 
-    # map from resource tensor to the last op which used it
-    last_op_using_resource_tensor = {}
+    # map from resource tensor to the last op which wrote to it
+    last_write_to_resource = {}
+    # map from resource tensor to the list of reads from it since the last
+    # write or since the beginning of the function.
+    reads_since_last_write_to_resource = collections.defaultdict(list)
     # set of conditional and loop exits
     ops_which_must_run = set()
     # merge which must depend on ops which use this resource
@@ -285,7 +297,7 @@ class AutomaticControlDependencies(object):
 
     # Ensures that uses of resource tensors get serialized properly and all
     # execute. This is done by keeping a map from resource tensor to the last op
-    # in graph-construction order which used it (last_op_using_resource_tensor).
+    # in graph-construction order which wrote to it (last_write_to_resource).
     #
     # Conditionals are written in TensorFlow such that every external tensor
     # accessed in the conditional goes through a switch op and every return
@@ -317,7 +329,10 @@ class AutomaticControlDependencies(object):
         continue
       control_inputs = set()
       # Ensure stateful ops run
-      if op_def_registry.get(op.type) is None or op_is_stateful(op):
+      if (op_def_registry.get(op.type) is None or
+          (op_is_stateful(op) and op.type not in utils.RESOURCE_READ_OPS)):
+        # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
+        # they only have variable reads and are otherwise stateless.
         ops_which_must_run.add(op)
       # Ignore switches (they're handled separately)
       if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
@@ -328,15 +343,16 @@ class AutomaticControlDependencies(object):
           op._add_control_input(o)  # pylint: disable=protected-access
           for inp in o.inputs:
             input_id = ops.tensor_id(inp)
-            if input_id in last_op_using_resource_tensor:
-              last_op_using_resource_tensor[input_id] = op
+            if input_id in last_write_to_resource:
+              last_write_to_resource[input_id] = op
         ops_which_must_run = set([op])
         continue
 
       resource_inputs = set()
       # Check for any resource inputs. If we find any, we update control_inputs
-      # and last_op_using_resource_tensor.
-      for inp in _get_resource_inputs(op):
+      # and last_write_to_resource.
+      for inp, resource_type in _get_resource_inputs(op):
+        is_read = resource_type == ResourceType.READ_ONLY
         input_id = ops.tensor_id(inp)
 
         # If the op receives the same resource tensor twice as an input, we skip
@@ -348,25 +364,29 @@ class AutomaticControlDependencies(object):
         # Deal with switches, finally.
         if inp.op.type == "Switch":
           self._process_switch(inp.op, ops_which_must_run,
-                               last_op_using_resource_tensor,
-                               merge_for_resource)
+                               last_write_to_resource, merge_for_resource)
         is_building_function = op.graph.building_function
         # Ensure uses of resources are serialized
-        if input_id in last_op_using_resource_tensor:
+        if input_id in last_write_to_resource:
           if is_building_function or (
-              last_op_using_resource_tensor[input_id]._control_flow_context  # pylint: disable=protected-access
+              last_write_to_resource[input_id]._control_flow_context  # pylint: disable=protected-access
               is op._control_flow_context):  # pylint: disable=protected-access
-            control_inputs.add(last_op_using_resource_tensor[input_id])
+            control_inputs.add(last_write_to_resource[input_id])
         # Ensure merges happen after the closing of a cond block
         if input_id in merge_for_resource:
           merge_for_resource[input_id]._add_control_input(op)  # pylint: disable=protected-access
-        last_op_using_resource_tensor[input_id] = op
+        if is_read:
+          reads_since_last_write_to_resource[input_id].append(op)
+        else:
+          control_inputs.update(reads_since_last_write_to_resource[input_id])
+          reads_since_last_write_to_resource[input_id] = []
+          last_write_to_resource[input_id] = op
 
       if (op_is_stateful(op) and not resource_inputs
           and op._control_flow_context is None):  # pylint: disable=protected-access
-        if None in last_op_using_resource_tensor:
-          op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
-        last_op_using_resource_tensor[None] = op
+        if None in last_write_to_resource:
+          op._add_control_input(last_write_to_resource[None])  # pylint: disable=protected-access
+        last_write_to_resource[None] = op
       control_inputs = [
           c for c in control_inputs if is_building_function or
           (c._control_flow_context is op._control_flow_context)]  # pylint: disable=protected-access
@@ -384,34 +404,43 @@ class AutomaticControlDependencies(object):
             ])
 
 
-_acd_resource_resolvers_registry = registry.Registry("acd_resouce_resolvers")
+_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")
 
 
 def register_acd_resource_resolver(f):
   """Register a function for resolving resources touched by an op.
 
+  `f` is called for every Operation added in the ACD context with the op's
+  original resource reads and writes. `f` is expected to update the sets of
+  resource reads and writes in-place and return True if it updated either of the
+  sets, False otherwise.
+
   Example:
   @register_acd_resource_resolver
-  def ResolveIdentity(op, resource_inputs):
+  def ResolveIdentity(op, resource_reads, resource_writes):
     # op: The `Operation` being processed by ACD currently.
-    # resource_inputs: An `ObjectIdentitySet` that can be updated in-place.
-    if not resource_inputs:
+    # resource_reads: An `ObjectIdentitySet` of read-only resources.
+    # resource_writes: An `ObjectIdentitySet` of read-write resources.
+    if not resource_reads or resource_writes:
       return False
-    to_add = []
-    to_remove = []
-    for t in resource_inputs:
-      if t.op.type == "Identity":
-        to_remove.append(t)
-        to_add.append(t.op.inputs[0])
-    if not to_add and not to_remove:
-      return False
-    for t in to_remove:
-      resource_inputs.discard(t)
-    resource_inputs.update(to_add)
-    return True  # `resource_inputs` was updated.
+    def update(resource_inputs):
+      to_add = []
+      to_remove = []
+      for t in resource_inputs:
+        if t.op.type == "Identity":
+          to_remove.append(t)
+          to_add.append(t.op.inputs[0])
+      if not to_add and not to_remove:
+        return False
+      for t in to_remove:
+        resource_inputs.discard(t)
+      resource_inputs.update(to_add)
+      return True
+    return update(resource_reads) or update(resource_writes)
 
   Args:
-    f: Python function
+    f: Python function with signature
+      `(Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool`.
 
   Returns:
     The function `f` after adding it to the registry.
@@ -422,8 +451,14 @@ def register_acd_resource_resolver(f):
 
 def _get_resource_inputs(op):
   """Returns an iterable of resources touched by this `op`."""
-  resource_inputs = object_identity.ObjectIdentitySet(
-      t for t in op.inputs if t.dtype == dtypes_module.resource)
+  reads = object_identity.ObjectIdentitySet()
+  writes = object_identity.ObjectIdentitySet()
+  for t in op.inputs:
+    if t.dtype == dtypes_module.resource:
+      if utils.op_writes_to_resource(t, op):
+        writes.add(t)
+      else:
+        reads.add(t)
   saturated = False
   while not saturated:
     saturated = True
@@ -432,10 +467,16 @@ def _get_resource_inputs(op):
       # resource_inputs.
       # TODO(srbs): An alternate would be to just compare the old and new set
       # but that may not be as fast.
-      updated = _acd_resource_resolvers_registry.lookup(key)(op,
-                                                             resource_inputs)
+      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
+      if updated:
+        # Conservatively remove any resources from `reads` that are also writes.
+        reads = reads.difference(writes)
       saturated = saturated and not updated
-  return resource_inputs
+
+  # Note: A resource handle that is not written to is treated as read-only. We
+  # don't have a special way of denoting an unused resource.
+  return ([(t, ResourceType.READ_ONLY) for t in reads] +
+          [(t, ResourceType.READ_WRITE) for t in writes])
 
 
 def automatic_control_dependencies(f):
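
A minimal sketch (not part of the patch; public TF 2.x APIs only) of the
behavior the new read/write tracking above is meant to produce: two reads of
the same variable are no longer serialized against each other, while a
subsequent write is still ordered after both reads.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def step():
  a = v.read_value()   # read 1
  b = v.read_value()   # read 2: no control edge from read 1
  v.assign(a + b)      # write: ordered after both reads
  return a + b

step()
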
diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py
index 19c0606eb42..a655cadbc10 100644
--- a/tensorflow/python/framework/auto_control_deps_test.py
+++ b/tensorflow/python/framework/auto_control_deps_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -30,6 +32,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras.layers import core as keras_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -50,6 +53,509 @@ class AutomaticControlDependenciesTest(test.TestCase):
         val = c.mark_as_return(val)
       self.assertAllEqual(val.eval(), 4.0)
 
+  def testNoControlDepsBetweenVariableReads(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      with acd.AutomaticControlDependencies():
+        read_op1 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        read_op2 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+      self.assertNotIn(read_op1, read_op2.control_inputs)
+      self.assertNotIn(read_op2, read_op1.control_inputs)
+
+  def testVariableReadThenWrite(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      with acd.AutomaticControlDependencies():
+        read_op1 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        read_op2 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        assign_op = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+      # Writes should have control deps from all reads since the last write
+      # (or since the start of the block, if there has been no write yet).
+      self.assertIn(read_op1, assign_op.control_inputs)
+      self.assertIn(read_op2, assign_op.control_inputs)
+      # There should be no control deps between reads.
+      self.assertNotIn(read_op1, read_op2.control_inputs)
+      self.assertNotIn(read_op2, read_op1.control_inputs)
+
+  def testVariableWriteThenRead(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      with acd.AutomaticControlDependencies():
+        assign_op = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+        read_op1 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        read_op2 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+      # Reads should have a control dep from the last write.
+      self.assertIn(assign_op, read_op1.control_inputs)
+      self.assertIn(assign_op, read_op2.control_inputs)
+      # There should be no control deps between reads.
+      self.assertNotIn(read_op1, read_op2.control_inputs)
+      self.assertNotIn(read_op2, read_op1.control_inputs)
+
+  def testVariableReadsNotInOpsWithMustRun(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      with acd.AutomaticControlDependencies() as c:
+        read_op1 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        read_op2 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        assign_op = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+      # Reads must not be in `ops_which_must_run` since those get added to the
+      # `control_outputs`.
+      self.assertNotIn(read_op1, c.ops_which_must_run)
+      self.assertNotIn(read_op2, c.ops_which_must_run)
+      # Last write must be in `ops_which_must_run`.
+      self.assertIn(assign_op, c.ops_which_must_run)
+
+  def testVariableMultipleReadsAndWrites(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      with acd.AutomaticControlDependencies() as c:
+        # 2 reads -> 2 writes -> 2 reads -> 2 writes.
+        read_op1 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        read_op2 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        assign_op1 = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+        assign_op2 = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+        read_op3 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        read_op4 = gen_resource_variable_ops.read_variable_op(
+            v.handle, v.dtype).op
+        assign_op3 = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+        assign_op4 = gen_resource_variable_ops.assign_variable_op(
+            v.handle, v + 1)
+
+      # Verify the control edges.
+      self.assertIn(read_op1, assign_op1.control_inputs)
+      self.assertIn(read_op2, assign_op1.control_inputs)
+      self.assertIn(assign_op1, assign_op2.control_inputs)
+      self.assertIn(assign_op2, read_op3.control_inputs)
+      self.assertIn(assign_op2, read_op4.control_inputs)
+      self.assertIn(read_op3, assign_op3.control_inputs)
+      self.assertIn(read_op4, assign_op3.control_inputs)
+      self.assertIn(assign_op3, assign_op4.control_inputs)
+
+      # There should be no control deps between reads.
+      read_ops = [read_op1, read_op2, read_op3, read_op4]
+      for src_op, tgt_op in itertools.product(read_ops, read_ops):
+        self.assertNotIn(src_op, tgt_op.control_inputs)
+
+      # Reads must not be in `ops_which_must_run`.
+      self.assertNotIn(read_op1, c.ops_which_must_run)
+      self.assertNotIn(read_op2, c.ops_which_must_run)
+      self.assertNotIn(read_op3, c.ops_which_must_run)
+      self.assertNotIn(read_op4, c.ops_which_must_run)
+      # Last write must be in `ops_which_must_run`.
+      self.assertIn(assign_op4, c.ops_which_must_run)
+
+  def _testVariableReadInFunctionalOp(self, build_functional_op, op_type):
+    v = resource_variable_ops.ResourceVariable(1.0)
+    self.evaluate(variables.global_variables_initializer())
+
+    @def_function.function
+    def read_var_in_while():
+      gen_resource_variable_ops.read_variable_op(
+          v.handle, v.dtype, name="read1")
+
+      result = build_functional_op(v)
+      gen_resource_variable_ops.read_variable_op(
+          v.handle, v.dtype, name="read2")
+      gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+      return result
+
+    func_graph = read_var_in_while.get_concrete_function().graph
+    assert len(func_graph.inputs) == 1
+
+    def get_op(op_type, sub_name):
+      operations = [
+          op for op in func_graph.get_operations()
+          if op.type == op_type and sub_name in op.name
+      ]
+      assert len(operations) == 1
+      return operations[0]
+
+    read1 = get_op("ReadVariableOp", "read1")
+    functional_op = get_op(op_type, "")
+    read2 = get_op("ReadVariableOp", "read2")
+    assign_op = get_op("AssignVariableOp", "")
+    # Since the functional op only has reads, previous reads (e.g. read1) do
+    # not have a control edge to it, and future reads (e.g. read2) do not
+    # have a control edge from it.
+    self.assertNotIn(read1, functional_op.control_inputs)
+    self.assertNotIn(functional_op, read2.control_inputs)
+    self.assertIn(read1, assign_op.control_inputs)
+    self.assertIn(read2, assign_op.control_inputs)
+    self.assertIn(functional_op, assign_op.control_inputs)
+
+  def testVariableReadInWhileLoop(self):
+
+    def build_functional_op(v):
+
+      def body(_):
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return control_flow_ops.while_loop(
+          lambda i: True, body, [0.0], maximum_iterations=1)
+
+    self._testVariableReadInFunctionalOp(build_functional_op, "While")
+
+  def testVariableReadInCondTrueBranch(self):
+
+    def build_functional_op(v):
+
+      def then_branch():
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      def else_branch():
+        return array_ops.zeros([], v.dtype)
+
+      return control_flow_ops.cond(
+          constant_op.constant(True), then_branch, else_branch)
+
+    self._testVariableReadInFunctionalOp(build_functional_op, "If")
+
+  def testVariableReadInCondFalseBranch(self):
+
+    def build_functional_op(v):
+
+      def then_branch():
+        return array_ops.zeros([], v.dtype)
+
+      def else_branch():
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return control_flow_ops.cond(
+          constant_op.constant(False), then_branch, else_branch)
+
+    self._testVariableReadInFunctionalOp(build_functional_op, "If")
+
+  def testVariableReadInCaseBranch0(self):
+
+    def build_functional_op(v):
+
+      def branch0():
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      def branch1():
+        return array_ops.zeros([], v.dtype)
+
+      return control_flow_ops.switch_case(
+          constant_op.constant(0), [branch0, branch1])
+
+    self._testVariableReadInFunctionalOp(build_functional_op, "Case")
+
+  def testVariableReadInCaseBranch1(self):
+
+    def build_functional_op(v):
+
+      def branch0():
+        return array_ops.zeros([], v.dtype)
+
+      def branch1():
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return control_flow_ops.switch_case(
+          constant_op.constant(0), [branch0, branch1])
+
+    self._testVariableReadInFunctionalOp(build_functional_op, "Case")
+
+  def testVariableReadInFunction(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_read():
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return fn_with_read()
+
+    self._testVariableReadInFunctionalOp(build_functional_op,
+                                         "StatefulPartitionedCall")
+
+  def testVariableReadInNestedFunction(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_read():
+
+        @def_function.function
+        def inner_fn():
+          return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+        return inner_fn()
+
+      return fn_with_read()
+
+    self._testVariableReadInFunctionalOp(build_functional_op,
+                                         "StatefulPartitionedCall")
+
+  def testVariableReadInWhileInInnerFunc(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_read():
+
+        @def_function.function
+        def inner_fn():
+
+          def body(_):
+            return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+          return control_flow_ops.while_loop(
+              lambda i: True, body, [0.0], maximum_iterations=1)
+
+        return inner_fn()
+
+      return fn_with_read()
+
+    self._testVariableReadInFunctionalOp(build_functional_op,
+                                         "StatefulPartitionedCall")
+
+  def testVariableReadInCondInInnerFunc(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_read():
+
+        @def_function.function
+        def inner_fn():
+
+          def then_branch():
+            return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+          def else_branch():
+            return array_ops.zeros([], v.dtype)
+
+          return control_flow_ops.cond(
+              constant_op.constant(True), then_branch, else_branch)
+
+        return inner_fn()
+
+      return fn_with_read()
+
+    self._testVariableReadInFunctionalOp(build_functional_op,
+                                         "StatefulPartitionedCall")
+
+  def _testVariableWriteInFunctionalOp(self, build_functional_op, op_type):
+    v = resource_variable_ops.ResourceVariable(1.0)
+    self.evaluate(variables.global_variables_initializer())
+
+    @def_function.function
+    def write_var_in_while():
+      gen_resource_variable_ops.read_variable_op(
+          v.handle, v.dtype, name="read1")
+
+      result = build_functional_op(v)
+      gen_resource_variable_ops.read_variable_op(
+          v.handle, v.dtype, name="read2")
+      gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+      return result
+
+    func_graph = write_var_in_while.get_concrete_function().graph
+    assert len(func_graph.inputs) == 1
+
+    def get_op(op_type, sub_name):
+      operations = [
+          op for op in func_graph.get_operations()
+          if op.type == op_type and sub_name in op.name
+      ]
+      assert len(operations) == 1
+      return operations[0]
+
+    read1 = get_op("ReadVariableOp", "read1")
+    functional_op = get_op(op_type, "")
+    read2 = get_op("ReadVariableOp", "read2")
+    assign_op = get_op("AssignVariableOp", "")
+    # Since the functional op has writes, it has control edges from previous
+    # reads (`read1`) and to future reads (`read2`) and writes (`assign_op`).
+    self.assertIn(read1, functional_op.control_inputs)
+    self.assertIn(functional_op, read2.control_inputs)
+    self.assertIn(read2, assign_op.control_inputs)
+    self.assertIn(functional_op, assign_op.control_inputs)
+
+  def testVariableWriteInWhileLoop(self):
+
+    def build_functional_op(v):
+
+      def body(_):
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return control_flow_ops.while_loop(
+          lambda i: True, body, [0.0], maximum_iterations=1)
+
+    self._testVariableWriteInFunctionalOp(build_functional_op, "While")
+
+  def testVariableWriteInCondTrueBranch(self):
+
+    def build_functional_op(v):
+
+      def then_branch():
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      def else_branch():
+        return array_ops.zeros([], v.dtype)
+
+      return control_flow_ops.cond(
+          constant_op.constant(True), then_branch, else_branch)
+
+    self._testVariableWriteInFunctionalOp(build_functional_op, "If")
+
+  def testVariableWriteInCondFalseBranch(self):
+
+    def build_functional_op(v):
+
+      def then_branch():
+        return array_ops.zeros([], v.dtype)
+
+      def else_branch():
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return control_flow_ops.cond(
+          constant_op.constant(False), then_branch, else_branch)
+
+    self._testVariableWriteInFunctionalOp(build_functional_op, "If")
+
+  def testVariableWriteInCaseBranch0(self):
+
+    def build_functional_op(v):
+
+      def branch0():
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      def branch1():
+        return array_ops.zeros([], v.dtype)
+
+      return control_flow_ops.switch_case(
+          constant_op.constant(0), [branch0, branch1])
+
+    self._testVariableWriteInFunctionalOp(build_functional_op, "Case")
+
+  def testVariableWriteInCaseBranch1(self):
+
+    def build_functional_op(v):
+
+      def branch0():
+        return array_ops.zeros([], v.dtype)
+
+      def branch1():
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return control_flow_ops.switch_case(
+          constant_op.constant(0), [branch0, branch1])
+
+    self._testVariableWriteInFunctionalOp(build_functional_op, "Case")
+
+  def testVariableWriteInFunction(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_write():
+        gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+        return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+      return fn_with_write()
+
+    self._testVariableWriteInFunctionalOp(build_functional_op,
+                                          "StatefulPartitionedCall")
+
+  def testVariableWriteInNestedFunction(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_write():
+
+        @def_function.function
+        def inner_fn():
+          gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+          return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+        return inner_fn()
+
+      return fn_with_write()
+
+    self._testVariableWriteInFunctionalOp(build_functional_op,
+                                          "StatefulPartitionedCall")
+
+  def testVariableWriteInWhileInInnerFunc(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_write():
+
+        @def_function.function
+        def inner_fn():
+
+          def body(_):
+            gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+            return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+          return control_flow_ops.while_loop(
+              lambda i: True, body, [0.0], maximum_iterations=1)
+
+        return inner_fn()
+
+      return fn_with_write()
+
+    self._testVariableWriteInFunctionalOp(build_functional_op,
+                                          "StatefulPartitionedCall")
+
+  def testVariableWriteInCondInInnerFunc(self):
+
+    def build_functional_op(v):
+
+      @def_function.function
+      def fn_with_write():
+
+        @def_function.function
+        def inner_fn():
+
+          def then_branch():
+            gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
+            return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)
+
+          def else_branch():
+            return array_ops.zeros([], v.dtype)
+
+          return control_flow_ops.cond(
+              constant_op.constant(True), then_branch, else_branch)
+
+        return inner_fn()
+
+      return fn_with_write()
+
+    self._testVariableWriteInFunctionalOp(build_functional_op,
+                                          "StatefulPartitionedCall")
+
   @test_util.run_v1_only("b/120545219")
   def testCondMustRun(self):
     with context.graph_mode(), self.cached_session():
diff --git a/tensorflow/python/framework/auto_control_deps_utils.py b/tensorflow/python/framework/auto_control_deps_utils.py
new file mode 100644
index 00000000000..d8b4656e3d9
--- /dev/null
+++ b/tensorflow/python/framework/auto_control_deps_utils.py
@@ -0,0 +1,91 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for AutomaticControlDependencies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+
+READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
+RESOURCE_READ_OPS = set()
+
+
+def register_read_only_resource_op(op_type):
+  """Declares that `op_type` does not update its touched resource."""
+  RESOURCE_READ_OPS.add(op_type)
+
+
+def resource_has_writes(handle):
+  """Returns whether any of the consumers of handle write to it.
+
+  Args:
+    handle: Tensor of type DT_RESOURCE.
+
+  Returns:
+    True if at least one consumer of `handle` writes to it.
+    False if no consumer of `handle` writes to it or if `handle` has no
+    consumers.
+  """
+  assert handle.dtype == dtypes.resource
+  return any(op_writes_to_resource(handle, op) for op in handle.consumers())
+
+
+def op_writes_to_resource(handle, op):
+  """Returns whether op writes to resource handle.
+
+  Args:
+    handle: Resource handle. Must be an input of `op`.
+    op: Operation.
+
+  Returns:
+    False if `op` is a read-only op registered via
+    `register_read_only_resource_op`, or if `handle` is received at one of
+    the indices listed in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR` attr;
+    True otherwise.
+
+  Raises:
+    ValueError: if `handle` is not an input of `op`.
+  """
+  if op.type in RESOURCE_READ_OPS:
+    return False
+  input_index = _input_index(op, handle)
+  try:
+    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
+  except ValueError:
+    # Attr was not set. Conservatively assume that the resource is written to.
+    return True
+  return input_index not in read_only_input_indices
+
+
+def _input_index(op, handle):
+  """Returns the index of `handle` in `op.inputs`.
+
+  Args:
+    op: Operation.
+    handle: Resource handle.
+
+  Returns:
+    Index in `op.inputs` receiving the resource `handle`.
+
+  Raises:
+    ValueError: If `handle` is not an input of `op`, i.e. it does not appear
+      in `op.inputs`.
+  """
+  for i, t in enumerate(op.inputs):
+    if handle is t:
+      return i
+  raise ValueError("%s not in list" % str(handle))
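
A hedged sketch of how the helper above might be exercised (the module alias
and op type below are illustrative): an op type that never mutates the
resource handles it consumes can be declared read-only, so ACD treats its
resource inputs as reads rather than writes.

from tensorflow.python.framework import auto_control_deps_utils as acd_utils

# ReadVariableOp only ever reads the variable handle passed to it, so it is
# safe for ACD to treat its resource input as a read.
acd_utils.register_read_only_resource_op("ReadVariableOp")
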
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index 72612a21cbf..2ef7d737d73 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -380,10 +381,10 @@ class DeviceTest(test.TestCase):
     with ops.device('/device:CPU:1'):
       b = constant_op.constant(1.0)
       self.evaluate(b)
-    with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
-      with ops.device('/device:CPU:2'):
-        c = constant_op.constant(1.0)
-        self.evaluate(c)
+    with ops.device('/device:CPU:2'):
+      c = constant_op.constant(1.0)
+      self.evaluate(c)
+    self.assertIn('CPU:0', c.device)
 
     # Ensure we can place ops on each of the device names
     for vcpu in vcpus:
@@ -408,6 +409,7 @@ class DeviceTest(test.TestCase):
   @test_util.run_gpu_only
   @reset_eager
   def testGpuNone(self):
+    config.set_soft_device_placement(False)
     gpus = config.list_physical_devices('GPU')
     self.assertGreater(len(gpus), 0)
 
@@ -427,14 +429,16 @@ class DeviceTest(test.TestCase):
     self.assertEqual(len(config.get_visible_devices('GPU')), 0)
     self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0)
 
-    with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 'Could not satisfy'):
       with ops.device('/device:GPU:0'):
-        a = constant_op.constant(1.0)
+        a = array_ops.identity(1.0)
         self.evaluate(a)
 
-    with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 'Could not satisfy'):
       with ops.device('/device:XLA_GPU:0'):
-        a = constant_op.constant(1.0)
+        a = array_ops.identity(1.0)
         self.evaluate(a)
 
     # Modifying the visible devices is not supported
@@ -465,6 +469,7 @@ class DeviceTest(test.TestCase):
   @test_util.run_gpu_only
   @reset_eager
   def testVirtualGpu(self):
+    config.set_soft_device_placement(False)
     gpus = config.list_physical_devices('GPU')
     self.assertNotEqual(len(gpus), 0)
 
@@ -479,12 +484,13 @@ class DeviceTest(test.TestCase):
     self.assertTrue(len(logical_gpus), len(gpus) + 1)
     for i in range(0, len(logical_gpus)):
       with ops.device('/device:GPU:' + str(i)):
-        a = constant_op.constant(1.0)
+        a = array_ops.identity(1.0)
         self.evaluate(a)
 
-    with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 'Could not satisfy'):
       with ops.device('/device:GPU:' + str(len(logical_gpus))):
-        a = constant_op.constant(1.0)
+        a = array_ops.identity(1.0)
         self.evaluate(a)
 
     # Modifying the GPU configuration is not supported
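
A short sketch (assumes eager TF 2.x on a machine with a single physical CPU)
of the soft-placement behavior the updated tests above rely on: with soft
device placement enabled, an op pinned to a device that does not exist falls
back to an available device instead of raising.

import tensorflow as tf

tf.config.set_soft_device_placement(True)
with tf.device('/device:CPU:2'):   # CPU:2 does not exist
  c = tf.constant(1.0)
print(c.device)                    # e.g. '/job:localhost/.../device:CPU:0'
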
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 4d9aa29ad60..9736bb8b78b 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -224,6 +224,10 @@ def constant(value, dtype=None, shape=None, name="Const"):
   ...
   NotImplementedError: ...
 
+  `tf.constant` will _always_ create CPU (host) tensors. In order to create
+  tensors on other devices, use `tf.identity`. (If the `value` is an eager
+  Tensor, however, the tensor will be returned unmodified as mentioned above.)
+
   Related Ops:
 
   * `tf.convert_to_tensor` is similar but:
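
A sketch of the guidance added to the `tf.constant` docstring above (assumes a
machine with a visible '/GPU:0'): wrapping the constant in `tf.identity`
yields a tensor placed on the requested device.

import tensorflow as tf

with tf.device('/GPU:0'):
  host_t = tf.constant([1.0, 2.0])  # tf.constant always creates a host tensor
  dev_t = tf.identity(host_t)       # the Identity op is placed on the GPU
print(host_t.device, dev_t.device)
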
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f716dfa33dd..d3df8cb973d 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3171,6 +3171,8 @@ class Graph(object):
     Raises:
       ValueError: if another function is defined with the same name.
     """
+    self._check_not_finalized()
+
     name = function.name
     # Sanity checks on gradient definition.
     if (function.grad_func_name is not None) and (function.python_grad_func is
@@ -3455,6 +3457,8 @@ class Graph(object):
     Returns:
       A list of the new `Operation` objects.
     """
+    self._check_not_finalized()
+
     # Create all Operation objects before accessing their inputs since an op may
     # be created before its inputs.
     new_ops = [
@@ -6738,3 +6742,9 @@ class _TensorIterator(object):
     return result
 
   next = __next__  # python2.x compatibility.
+
+
+def set_int_list_attr(op, attr_name, ints):
+  """TF internal method used to set a list(int) attribute in the node_def."""
+  ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
+  op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list))  # pylint:disable=protected-access
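
A hedged sketch of how `set_int_list_attr` is intended to be combined with the
`READ_ONLY_RESOURCE_INPUTS_ATTR` added in auto_control_deps_utils.py (the
helper below is illustrative, not part of the patch): an op builder can record
which of its resource inputs are read-only so ACD does not treat them as
writes.

from tensorflow.python.framework import auto_control_deps_utils as acd_utils
from tensorflow.python.framework import ops

def mark_read_only_resource_inputs(op, read_only_indices):
  """Records the read-only resource input indices on `op`'s NodeDef."""
  ops.set_int_list_attr(op, acd_utils.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)
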
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index ab8ab71e3b0..7408b314bb6 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -370,9 +370,34 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
 
 @keras_export('keras.applications.inception_resnet_v2.preprocess_input')
 def preprocess_input(x, data_format=None):
+  """Preprocesses a numpy array encoding a batch of images.
+
+  Arguments:
+    x: A 4D numpy array consisting of RGB values within [0, 255].
+
+  Returns:
+    Preprocessed array.
+
+  Raises:
+    ValueError: In case of unknown `data_format` argument.
+  """
   return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
 
 
 @keras_export('keras.applications.inception_resnet_v2.decode_predictions')
 def decode_predictions(preds, top=5):
+  """Decodes the prediction result from the model.
+
+  Arguments:
+    preds: Numpy array encoding a batch of predictions.
+    top: Integer, how many top-guesses to return. Defaults to 5.
+
+  Returns:
+    A list of lists of top class prediction tuples
+    `(class_name, class_description, score)`.
+    One list of tuples per sample in the batch input.
+
+  Raises:
+    ValueError: In case of invalid shape of the `preds` array (must be 2D).
+  """
   return imagenet_utils.decode_predictions(preds, top=top)
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index 224e8c84496..128553f0d39 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -436,9 +436,34 @@ def _depthwise_conv_block(inputs,
 
 @keras_export('keras.applications.mobilenet.preprocess_input')
 def preprocess_input(x, data_format=None):
+  """Preprocesses a numpy array encoding a batch of images.
+
+  Arguments:
+    x: A 4D numpy array consisting of RGB values within [0, 255].
+
+  Returns:
+    Preprocessed array.
+
+  Raises:
+    ValueError: In case of unknown `data_format` argument.
+  """
   return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
 
 
 @keras_export('keras.applications.mobilenet.decode_predictions')
 def decode_predictions(preds, top=5):
+  """Decodes the prediction result from the model.
+
+  Arguments:
+    preds: Numpy array encoding a batch of predictions.
+    top: Integer, how many top-guesses to return. Defaults to 5.
+
+  Returns:
+    A list of lists of top class prediction tuples
+    `(class_name, class_description, score)`.
+    One list of tuples per sample in the batch input.
+
+  Raises:
+    ValueError: In case of invalid shape of the `preds` array (must be 2D).
+  """
   return imagenet_utils.decode_predictions(preds, top=top)
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 8651cf27375..aea77ca2e47 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -894,9 +894,12 @@ class History(Callback):
   gets returned by the `fit` method of models.
   """
 
+  def __init__(self):
+    super(History, self).__init__()
+    self.history = {}
+
   def on_train_begin(self, logs=None):
     self.epoch = []
-    self.history = {}
 
   def on_epoch_end(self, epoch, logs=None):
     logs = logs or {}
@@ -904,6 +907,10 @@ class History(Callback):
     for k, v in logs.items():
       self.history.setdefault(k, []).append(v)
 
+    # Set the history attribute on the model at the end of every epoch so
+    # that the state stored on the model is always the most recent one.
+    self.model.history = self
+
 
 @keras_export('keras.callbacks.ModelCheckpoint')
 class ModelCheckpoint(Callback):
@@ -1672,6 +1679,8 @@ class TensorBoard(Callback):
     self._writers = {}
     self._start_batch, self._stop_batch = self._init_profile_batch(
         profile_batch)
+    if self._start_batch > 0:
+      profiler.warmup()  # Improve the profiling accuracy.
     # True when a trace is running.
     self._is_tracing = False
 
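
A small sketch of what the `History` changes above make possible: the
`history` dict exists as soon as the callback is constructed, and
`model.history` always points at the `History` object from the latest `fit`
call.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='sgd', loss='mse')

hist = model.fit(np.zeros((8, 3)), np.zeros((8, 1)), epochs=2, verbose=0)
assert hist is model.history        # updated at the end of every epoch
print(hist.history['loss'])         # one loss value per epoch
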
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index e13ab8f0b92..148df242e48 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import tf_logging as logging
@@ -238,6 +239,9 @@ class Network(base_layer.Layer):
       outputs = outputs[0]
     self._nested_outputs = outputs
     self._nested_inputs = inputs
+    self._nested_inputs_are_flat_list = (
+        isinstance(self._nested_inputs, (list, tuple)) and
+        not any(nest.is_sequence(t) for t in self._nested_inputs))
     self.inputs = nest.flatten(inputs)
     self.outputs = nest.flatten(outputs)
 
@@ -814,40 +818,18 @@ class Network(base_layer.Layer):
     # use masking, it does not interfere with regular behavior at all and you
     # can ignore it.
 
-    if isinstance(inputs, dict) and isinstance(self._nested_inputs,
-                                               (list, tuple)):
-      # Backwards compat: Allows passing a dict to a Model constructed with a
-      # list. Matches dict keys to input names.
-      inputs = [
-          inputs[inp._keras_history.layer.name] for inp in self._nested_inputs
-      ]
-    else:
-      inputs = nest.flatten(inputs)
-
+    inputs = self._flatten_to_reference_inputs(inputs)
     if mask is None:
       masks = [None for _ in range(len(inputs))]
     else:
-      masks = nest.flatten(mask)
-
+      masks = self._flatten_to_reference_inputs(mask)
     for input_t, mask in zip(inputs, masks):
       input_t._keras_mask = mask
 
     # Dictionary mapping reference tensors to computed tensors.
     tensor_dict = {}
-
     for x, y in zip(self.inputs, inputs):
-      # Set shape and dtype based on `keras.Input`s.
-      if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
-        try:
-          y.set_shape(y.shape.merge_with(x.shape))
-        except ValueError:
-          logging.warning(
-              'Model was constructed with shape {} for input {}, but it was '
-              're-called on a Tensor with incompatible shape {}.'
-              .format(x, x.shape, y.shape))
-      if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)):
-        y = math_ops.cast(y, dtype=x.dtype)
-
+      y = self._conform_to_reference_input(y, ref_input=x)
       x_id = str(id(x))
       tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
 
@@ -925,6 +907,53 @@ class Network(base_layer.Layer):
     output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
     return output_tensors
 
+  def _flatten_to_reference_inputs(self, tensors):
+    """Maps `tensors` to their respective `keras.Input`."""
+    if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
+      # Backwards compat: Allows passing a dict to a Model constructed with a
+      # list. Matches dict keys to input names.
+      tensors = [
+          tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
+      ]
+    else:
+      # Otherwise both self.inputs and tensors are flattened in the same order.
+      tensors = nest.flatten(tensors)
+    return tensors
+
+  def _conform_to_reference_input(self, tensor, ref_input):
+    """Set shape and dtype based on `keras.Input`s."""
+    # Shape handling (only for non-CompositeTensors).
+    if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor):
+      # Allow (None,) and (None, 1) Tensors to be passed interchangeably.
+      # Use the shape specified by the `keras.Input`.
+      if tensor.shape.rank is not None and ref_input.shape.rank is not None:
+        should_squeeze_last_dim = (
+            tensor.shape.rank == ref_input.shape.rank + 1 and
+            tensor.shape[-1] == 1)
+        should_expand_last_dim = (
+            tensor.shape.rank == ref_input.shape.rank - 1 and
+            ref_input.shape[-1] == 1)
+        if should_squeeze_last_dim:
+          tensor = array_ops.squeeze_v2(tensor, axis=-1)
+        elif should_expand_last_dim:
+          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
+
+      # Add shape hints to Tensors that might have None shape dims but have
+      # shapes defined by the `keras.Input`.
+      try:
+        tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
+      except ValueError:
+        logging.warning(
+            'Model was constructed with shape {} for input {}, but it was '
+            'called on an input with incompatible shape {}.'.format(
+                ref_input.shape, ref_input, tensor.shape))
+
+    # Dtype handling.
+    if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)):
+      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
+
+    return tensor
+
   def get_config(self):
     if not self._is_graph_network:
       raise NotImplementedError
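
A sketch of the shape-conforming behavior added in
`_conform_to_reference_input` (it mirrors the tests added below): a functional
model built with a scalar-shaped `Input` now accepts `(batch, 1)` data by
squeezing the trailing size-1 dimension, and a model built with `shape=(1,)`
accepts `(batch,)` data by expanding it.

import numpy as np
import tensorflow as tf

inp = tf.keras.Input(shape=())        # per-example scalar input
out = inp * 2.
net = tf.keras.Model(inp, out)

y = net(np.ones((10, 1)))             # trailing size-1 dim is squeezed
print(y.shape)                        # (10,)
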
diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py
index 17f08889936..d890cc118ae 100644
--- a/tensorflow/python/keras/engine/network_test.py
+++ b/tensorflow/python/keras/engine/network_test.py
@@ -1879,6 +1879,24 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
     for i in range(999, 1024):
       self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
 
+  def test_2d_inputs_squeezed_to_1d(self):
+    input_1d = input_layer_lib.Input(shape=())
+    outputs = input_1d * 2.
+    net = network_lib.Network(input_1d, outputs)
+
+    x = np.ones((10, 1))
+    y = net(x)
+    self.assertEqual(y.shape.rank, 1)
+
+  def test_1d_inputs_expanded_to_2d(self):
+    input_1d = input_layer_lib.Input(shape=(1,))
+    outputs = input_1d * 2.
+    net = network_lib.Network(input_1d, outputs)
+
+    x = np.ones((10,))
+    y = net(x)
+    self.assertEqual(y.shape.rank, 2)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index f9ec6f37b45..c487ba6a322 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -174,6 +174,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
 
     # Fault-tolerance handler. Set in `ModelCheckpoint`.
     self._training_state = None
+    self.history = None
 
   def get_weights(self):
     """Retrieves the weights of the model.
@@ -768,7 +769,11 @@ class Model(network.Network, version_utils.ModelVersionSelector):
                 step_num=step,
                 batch_size=batch_size):
               callbacks.on_train_batch_begin(step)
-              logs = train_function(iterator)
+              tmp_logs = train_function(iterator)
+              # Catch possible OutOfRangeError here.
+              # TODO(b/150292341): Allow multiple async steps.
+              context.async_wait()
+              logs = tmp_logs
               callbacks.on_train_batch_end(step, logs)
         epoch_logs = {m.name: m.result() for m in self.metrics}
 
@@ -995,7 +1000,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
                 graph_type='test',
                 step_num=step):
               callbacks.on_test_batch_begin(step)
-              logs = test_function(iterator)
+              tmp_logs = test_function(iterator)
+              context.async_wait()  # Possible OutOfRangeError here.
+              logs = tmp_logs
               callbacks.on_test_batch_end(step, logs)
       callbacks.on_test_end()
 
@@ -1175,7 +1182,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
         with data_handler.catch_stop_iteration():
           for step in data_handler.steps():
             callbacks.on_predict_batch_begin(step)
-            batch_outputs = predict_function(iterator)
+            tmp_batch_outputs = predict_function(iterator)
+            context.async_wait()  # Possible OutOfRangeError here.
+            batch_outputs = tmp_batch_outputs
             if outputs is None:
               outputs = nest.map_structure(lambda batch_output: [batch_output],
                                            batch_outputs)
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 8e0590b9f80..ca4734ef8cc 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -675,8 +675,10 @@ class BatchNormalizationBase(Layer):
         training = bool(training)
       if base_layer_utils.is_in_keras_graph():
         training = math_ops.logical_and(training, self._get_trainable_var())
-      else:
-        training = math_ops.logical_and(training, self.trainable)
+      elif not self.trainable:
+        # When the layer is not trainable, it overrides the `training` value
+        # passed from the model.
+        training = self.trainable
     return training
 
   def call(self, inputs, training=None):
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 1c851581a05..e5610309bec 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
@@ -48,6 +49,7 @@ from tensorflow.python.keras.losses import mean_squared_logarithmic_error
 from tensorflow.python.keras.losses import poisson
 from tensorflow.python.keras.losses import sparse_categorical_crossentropy
 from tensorflow.python.keras.losses import squared_hinge
+from tensorflow.python.keras.saving.saved_model import metric_serialization
 from tensorflow.python.keras.utils import metrics_utils
 from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
@@ -63,7 +65,9 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.ops import weights_broadcast_ops
 from tensorflow.python.ops.losses import util as tf_losses_utils
+from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
 from tensorflow.tools.docs import doc_controls
 
@@ -154,10 +158,13 @@ class Metric(base_layer.Layer):
     # custom metrics in v1 need not worry about control dependencies and
     # return ops.
     if (base_layer_utils.is_in_eager_or_tf_function() or
-        cls.__module__ == Metric.__module__):
+        is_built_in(cls)):
       update_state_fn = obj.update_state
     else:
-      update_state_fn = def_function.function(obj.update_state)
+      if isinstance(obj.update_state, def_function.Function):
+        update_state_fn = obj.update_state
+      else:
+        update_state_fn = def_function.function(obj.update_state)
 
     obj.update_state = types.MethodType(
         metrics_utils.update_state_wrapper(update_state_fn), obj)
@@ -278,6 +285,10 @@ class Metric(base_layer.Layer):
 
   ### End: For use by subclasses ###
 
+  @property
+  def _trackable_saved_model_saver(self):
+    return metric_serialization.MetricSavedModelSaver(self)
+
 
 class Reduce(Metric):
   """Encapsulates metrics that perform a reduce operation on the values."""
@@ -595,11 +606,27 @@ class MeanMetricWrapper(Mean):
 
   def get_config(self):
     config = {}
+
+    if type(self) is MeanMetricWrapper:  # pylint: disable=unidiomatic-typecheck
+      # Only include function argument when the object is a MeanMetricWrapper
+      # and not a subclass.
+      config['fn'] = self._fn
+
     for k, v in six.iteritems(self._fn_kwargs):
       config[k] = K.eval(v) if is_tensor_or_variable(v) else v
     base_config = super(MeanMetricWrapper, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
+  @classmethod
+  def from_config(cls, config):
+    # Note that while MeanMetricWrapper itself isn't public, objects of this
+    # class may be created and added to the model by calling model.compile.
+    if cls is MeanMetricWrapper:
+      fn = get(config.pop('fn'))
+      return cls(fn, **config)
+
+    return super(MeanMetricWrapper, cls).from_config(config)
+
 
 @keras_export('keras.metrics.Accuracy')
 class Accuracy(MeanMetricWrapper):
@@ -1883,7 +1910,7 @@ class AUC(Metric):
     else:
       variable_shape = tensor_shape.TensorShape(
           [tensor_shape.Dimension(self.num_thresholds)])
-
+    self._build_input_shape = shape
     # Create metric variables
     self.true_positives = self.add_weight(
         'true_positives',
@@ -1927,7 +1954,7 @@ class AUC(Metric):
     """
     deps = []
     if not self._built:
-      self._build(y_pred.shape)
+      self._build(tensor_shape.TensorShape(y_pred.shape))
 
     if self.multi_label or (self.label_weights is not None):
       # y_true should have shape (number of examples, number of labels).
@@ -2735,6 +2762,7 @@ class MeanTensor(Metric):
 
   def _build(self, shape):
     self._shape = tensor_shape.TensorShape(shape)
+    self._build_input_shape = self._shape
     # Create new state variables
     self._total = self.add_weight(
         'total', shape=shape, initializer=init_ops.zeros_initializer)
@@ -3251,3 +3279,7 @@ def get(identifier):
     error_msg = 'Could not interpret metric function identifier: {}'.format(
         identifier)
     raise ValueError(error_msg)
+
+
+def is_built_in(cls):
+  return cls.__module__ == Metric.__module__
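
A hedged sketch (the internal `MeanMetricWrapper` is used purely for
illustration) of the serialization round trip that the new `get_config` /
`from_config` pair enables for metrics created implicitly by `compile`.

from tensorflow.python.keras import metrics as metrics_module

m = metrics_module.MeanMetricWrapper(metrics_module.binary_accuracy,
                                     name='acc')
config = m.get_config()             # now includes the wrapped 'fn'
m2 = metrics_module.MeanMetricWrapper.from_config(config)
assert m2.name == 'acc'
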
diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py
index 58b670d0b0e..1542dab6d35 100644
--- a/tensorflow/python/keras/preprocessing/__init__.py
+++ b/tensorflow/python/keras/preprocessing/__init__.py
@@ -23,12 +23,14 @@ from __future__ import print_function
 import keras_preprocessing
 
 from tensorflow.python.keras import backend
+from tensorflow.python.keras.preprocessing import image
+from tensorflow.python.keras.preprocessing import sequence
+from tensorflow.python.keras.preprocessing import text
 from tensorflow.python.keras.utils import all_utils as utils
 
 # This exists for compatibility with prior version of keras_preprocessing.
 keras_preprocessing.set_keras_submodules(backend=backend, utils=utils)
 
-
 del absolute_import
 del division
 del print_function
diff --git a/tensorflow/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py
index 41a3ab6c1be..5ba2e2b47d5 100644
--- a/tensorflow/python/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/preprocessing/sequence.py
@@ -24,7 +24,6 @@ from keras_preprocessing import sequence
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.util.tf_export import keras_export
 
-pad_sequences = sequence.pad_sequences
 make_sampling_table = sequence.make_sampling_table
 skipgrams = sequence.skipgrams
 # TODO(fchollet): consider making `_remove_long_seq` public.
@@ -34,6 +33,7 @@ _remove_long_seq = sequence._remove_long_seq  # pylint: disable=protected-access
 @keras_export('keras.preprocessing.sequence.TimeseriesGenerator')
 class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence):
   """Utility class for generating batches of temporal data.
+
   This class takes in a sequence of data-points gathered at
   equal intervals, along with time series parameters such as
   stride, length of history, etc., to produce batches for
@@ -89,7 +89,74 @@ class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence):
   pass
 
 
-keras_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences)
+@keras_export('keras.preprocessing.sequence.pad_sequences')
+def pad_sequences(sequences, maxlen=None, dtype='int32',
+                  padding='pre', truncating='pre', value=0.):
+  """Pads sequences to the same length.
+
+  This function transforms a list (of length `num_samples`)
+  of sequences (lists of integers)
+  into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
+  `num_timesteps` is either the `maxlen` argument if provided,
+  or the length of the longest sequence in the list.
+
+  Sequences that are shorter than `num_timesteps`
+  are padded with `value` until they are `num_timesteps` long.
+
+  Sequences longer than `num_timesteps` are truncated
+  so that they fit the desired length.
+
+  The position where padding or truncation happens is determined by
+  the arguments `padding` and `truncating`, respectively.
+  Pre-padding or removing values from the beginning of the sequence is the
+  default.
+
+  >>> sequence = [[1], [2, 3], [4, 5, 6]]
+  >>> tf.keras.preprocessing.sequence.pad_sequences(sequence)
+  array([[0, 0, 1],
+         [0, 2, 3],
+         [4, 5, 6]], dtype=int32)
+
+  >>> tf.keras.preprocessing.sequence.pad_sequences(sequence, value=-1)
+  array([[-1, -1,  1],
+         [-1,  2,  3],
+         [ 4,  5,  6]], dtype=int32)
+
+  >>> tf.keras.preprocessing.sequence.pad_sequences(sequence, padding='post')
+  array([[1, 0, 0],
+         [2, 3, 0],
+         [4, 5, 6]], dtype=int32)
+
+  >>> tf.keras.preprocessing.sequence.pad_sequences(sequence, maxlen=2)
+  array([[0, 1],
+         [2, 3],
+         [5, 6]], dtype=int32)
+
+  Arguments:
+      sequences: List of sequences (each sequence is a list of integers).
+      maxlen: Optional Int, maximum length of all sequences. If not provided,
+          sequences will be padded to the length of the longest individual
+          sequence.
+      dtype: (Optional, defaults to int32). Type of the output sequences.
+          To pad sequences with variable length strings, you can use `object`.
+      padding: String, 'pre' or 'post' (optional, defaults to 'pre'):
+          pad either before or after each sequence.
+      truncating: String, 'pre' or 'post' (optional, defaults to 'pre'):
+          remove values from sequences larger than
+          `maxlen`, either at the beginning or at the end of the sequences.
+      value: Float or String, padding value. (Optional, defaults to 0.)
+
+  Returns:
+      Numpy array with shape `(len(sequences), maxlen)`
+
+  Raises:
+      ValueError: In case of invalid values for `truncating` or `padding`,
+          or in case of invalid shape for a `sequences` entry.
+  """
+  return sequence.pad_sequences(
+      sequences, maxlen=maxlen, dtype=dtype,
+      padding=padding, truncating=truncating, value=value)
+
 keras_export(
     'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table)
 keras_export('keras.preprocessing.sequence.skipgrams')(skipgrams)
diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD
index 7ab6639d118..eda5776c7a6 100644
--- a/tensorflow/python/keras/saving/BUILD
+++ b/tensorflow/python/keras/saving/BUILD
@@ -22,6 +22,7 @@ py_library(
         "saved_model/json_utils.py",
         "saved_model/layer_serialization.py",
         "saved_model/load.py",
+        "saved_model/metric_serialization.py",
         "saved_model/model_serialization.py",
         "saved_model/network_serialization.py",
         "saved_model/save.py",
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index 6dffcc65c7e..66a70e55c3b 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -52,13 +52,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         dtype=policy.serialize(self.obj._dtype_policy),  # pylint: disable=protected-access
         batch_input_shape=getattr(self.obj, '_batch_input_shape', None))
 
-    with generic_utils.skip_failed_serialization():
-      # Store the config dictionary, which may be used when reviving the object.
-      # When loading, the program will attempt to revive the object from config,
-      # and if that fails, the object will be revived from the SavedModel.
-      config = generic_utils.serialize_keras_object(self.obj)['config']
-      if config is not None:
-        metadata['config'] = config
+    metadata.update(get_config(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
@@ -109,6 +103,20 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
     return objects, functions
 
 
+# TODO(kathywu): Move serialization utils (and related utils from
+# generic_utils.py) to a separate file.
+def get_config(obj):
+  with generic_utils.skip_failed_serialization():
+    # Store the config dictionary, which may be used when reviving the object.
+    # When loading, the program will attempt to revive the object from config,
+    # and if that fails, the object will be revived from the SavedModel.
+    config = generic_utils.serialize_keras_object(obj)['config']
+
+  if config is not None:
+    return {'config': config}
+  return {}
+
+
 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
   """InputLayer serialization."""
 
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index af511e6586a..fddd6c46759 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import re
+import types
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as defun
@@ -32,6 +33,7 @@ from tensorflow.python.keras.saving.saved_model import json_utils
 from tensorflow.python.keras.saving.saved_model import utils
 from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
 from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import metrics_utils
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import load as tf_load
 from tensorflow.python.saved_model import nested_structure_coder
@@ -69,6 +71,8 @@ training_lib = LazyLoader(
 training_lib_v1 = LazyLoader(
     "training_lib_v1", globals(),
     "tensorflow.python.keras.engine.training_v1")
+metrics = LazyLoader("metrics", globals(),
+                     "tensorflow.python.keras.metrics")
 # pylint:enable=g-inconsistent-quotes
 
 
@@ -77,9 +81,9 @@ PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
 PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
 
 
-KERAS_OBJECT_IDENTIFIERS = (
-    '_tf_keras_layer', '_tf_keras_input_layer', '_tf_keras_network',
-    '_tf_keras_model', '_tf_keras_sequential')
+KERAS_OBJECT_IDENTIFIERS = ('_tf_keras_layer', '_tf_keras_input_layer',
+                            '_tf_keras_network', '_tf_keras_model',
+                            '_tf_keras_sequential', '_tf_keras_metric')
 
 
 def load(path, compile=True):  # pylint: disable=redefined-builtin
@@ -179,6 +183,20 @@ class KerasObjectLoader(tf_load.Loader):
 
     super(KerasObjectLoader, self).__init__(*args, **kwargs)
 
+    # Now that the node object has been fully loaded, and the checkpoint has
+    # been restored, the object no longer needs to track objects added from
+    # SerializedAttributes. (Note that saving a training checkpoint still
+    # functions correctly, because layers and variables are tracked separately
+    # by the Layer object.)
+    # TODO(kathywu): Instead of outright deleting these nodes (which would
+    # make restoring from a different checkpoint tricky), mark them as extra
+    # dependencies that are OK to overwrite.
+    for node in self._nodes:
+      if not isinstance(node, base_layer.Layer):
+        continue
+      for name in PUBLIC_ATTRIBUTES:
+        delete_tracking(node, name)
+
   def _load_all(self):
     """Reconstruct the object graph from the SavedModel."""
     # Load layer and model objects from either config or SavedModel. The objects
@@ -192,19 +210,6 @@ class KerasObjectLoader(tf_load.Loader):
     # Finish setting up layers and models. See function docstring for more info.
     self._finalize_objects()
 
-    # Now that the node object has been fully loaded, the object no longer needs
-    # to track objects added from SerializedAttributes. (Note that saving a
-    # training checkpoint still functions correctly, because layers and
-    # variables are tracked separately by the Layer object.)
-    # TODO(kathywu): Instead of outright deleting these nodes (which would
-    # make restoring from a different checkpoint tricky), mark them as extra
-    # dependencies that are OK to overwrite.
-    for node in self._nodes:
-      if not isinstance(node, base_layer.Layer):
-        continue
-      for name in PUBLIC_ATTRIBUTES:
-        delete_tracking(node, name)
-
   @property
   def _expect_partial_checkpoint(self):
     return True
@@ -230,10 +235,30 @@ class KerasObjectLoader(tf_load.Loader):
       return
     self._traversed_nodes_from_config.append(node_id)
     obj._maybe_initialize_trackable()
+    if isinstance(obj, base_layer.Layer) and not obj.built:
+      metadata = json_utils.decode(proto.user_object.metadata)
+      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))
 
+    # Create list of all possible children
+    children = []
+    # Look for direct children
     for reference in proto.children:
       obj_child = obj._lookup_dependency(reference.local_name)
-      child_id = reference.node_id
+      children.append((obj_child, reference.node_id))
+
+    # Add metrics that may have been added to the layer._metrics list.
+    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
+    # SavedModels created after TF 2.2.
+    metric_list_node_id = self._search_for_child_node(
+        node_id, [constants.KERAS_ATTR, 'layer_metrics'], raise_error=False)
+    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
+      obj_metrics = {m.name: m for m in obj._metrics}
+      for reference in self._proto.nodes[metric_list_node_id].children:
+        metric = obj_metrics.get(reference.local_name)
+        if metric is not None:
+          children.append((metric, reference.node_id))
+
+    for (obj_child, child_id) in children:
       child_proto = self._proto.nodes[child_id]
 
       if not isinstance(obj_child, trackable.Trackable):
@@ -263,10 +288,23 @@ class KerasObjectLoader(tf_load.Loader):
 
   def _load_layers(self):
     layers = {}
+
+    # Load metrics after models and layers, since it's likely that models
+    # and layers will create the metric when initialized (this avoids wasting
+    # time by creating objects multiple times).
+    metric_list = []
     for node_id, proto in enumerate(self._proto.nodes):
-      if (proto.WhichOneof('kind') == 'user_object' and
-          proto.user_object.identifier in KERAS_OBJECT_IDENTIFIERS):
-        layers[node_id] = self._load_layer(proto.user_object, node_id)
+      if (proto.WhichOneof('kind') != 'user_object' or
+          proto.user_object.identifier not in KERAS_OBJECT_IDENTIFIERS):
+        continue
+      if proto.user_object.identifier == '_tf_keras_metric':
+        metric_list.append((node_id, proto))
+        continue
+
+      layers[node_id] = self._load_layer(proto.user_object, node_id)
+
+    for node_id, proto in metric_list:
+      layers[node_id] = self._load_layer(proto.user_object, node_id)
     return layers
 
   def _load_layer(self, proto, node_id):
@@ -277,8 +315,6 @@ class KerasObjectLoader(tf_load.Loader):
     if node_id in self._nodes_recreated_from_config:
       node, setter = self._nodes_recreated_from_config[node_id]
 
-      self._try_build_layer(node, node_id, metadata.get('build_input_shape'))
-
       # Revive setter requires the object to have a `_serialized_attributes`
       # property. Add it here.
       _maybe_add_serialized_attributes(node, metadata)
@@ -291,7 +327,7 @@ class KerasObjectLoader(tf_load.Loader):
 
     # Detect whether this object can be revived from the config. If not, then
     # revive from the SavedModel instead.
-    obj, setter = self._revive_from_config(metadata, node_id)
+    obj, setter = self._revive_from_config(proto.identifier, metadata, node_id)
     if obj is None:
       obj, setter = revive_custom_object(proto.identifier, metadata)
 
@@ -302,10 +338,15 @@ class KerasObjectLoader(tf_load.Loader):
     _maybe_add_serialized_attributes(obj, metadata)
     return obj, setter
 
-  def _revive_from_config(self, metadata, node_id):
+  def _revive_from_config(self, identifier, metadata, node_id):
     """Revives a layer/model from config, or returns None."""
-    obj = (self._revive_graph_network(metadata, node_id) or
-           self._revive_layer_from_config(metadata, node_id))
+    if identifier == '_tf_keras_metric':
+      obj = self._revive_metric_from_config(metadata, node_id)
+    else:
+      obj = (
+          self._revive_graph_network(metadata, node_id) or
+          self._revive_layer_from_config(metadata, node_id))
+
     if obj is None:
       return None, None
 
@@ -382,6 +423,25 @@ class KerasObjectLoader(tf_load.Loader):
 
     return obj
 
+  def _revive_metric_from_config(self, metadata, node_id):
+    class_name = compat.as_str(metadata['class_name'])
+    config = metadata.get('config')
+
+    if not generic_utils.validate_config(config):
+      return None
+
+    try:
+      obj = metrics.deserialize(
+          generic_utils.serialize_keras_class_and_config(class_name, config))
+    except ValueError:
+      return None
+
+    build_input_shape = metadata.get('build_input_shape')
+    if build_input_shape is not None and hasattr(obj, '_build'):
+      obj._build(build_input_shape)  # pylint: disable=protected-access
+
+    return obj
+
   def _try_build_layer(self, obj, node_id, build_input_shape):
     """Attempts to build the layer."""
     if obj.built or hasattr(obj.build, '_is_default'):
@@ -424,16 +484,18 @@ class KerasObjectLoader(tf_load.Loader):
           node_id in self.model_layer_dependencies):
         continue
 
-      # No need to apply the finalizing steps to input layers.
+      self._unblock_model_reconstruction(node_id, node)
+
       if isinstance(node, input_layer.InputLayer):
-        self._unblock_model_reconstruction(node_id, node)
+        continue
+      elif isinstance(node, metrics.Metric):
         continue
 
       if node_id in self._nodes_recreated_from_config:
         layers_revived_from_config.append(node)
       else:
         layers_revived_from_saved_model.append(node)
-      self._unblock_model_reconstruction(node_id, node)
+
     _finalize_saved_model_layers(layers_revived_from_saved_model)
     _finalize_config_layers(layers_revived_from_config)
 
@@ -503,29 +565,63 @@ class KerasObjectLoader(tf_load.Loader):
     self._unblock_model_reconstruction(model_id, model)
 
   def _get_child_layer_node_ids(self, node_id, name):
-    # First, retrieve the node.keras_api.layers attribute, which is a list of
-    # all the layers in the node.
-    keras_attr = self._search_for_child_node(node_id, constants.KERAS_ATTR,
-                                             name)
-    layers_node = self._search_for_child_node(keras_attr, 'layers', name)
-    return [node.node_id for node in self._proto.nodes[layers_node].children]
+    """Returns the node ids of the children layers of a node."""
+    # Retrieve the node id of layer.keras_api.layers.
+    layer_list = self._search_for_child_node(
+        node_id, [constants.KERAS_ATTR, 'layers'], name)
+    return [node.node_id for node in self._proto.nodes[layer_list].children]
 
-  def _search_for_child_node(self, node_id, child_name, debugging_name):
-    for child in self._proto.nodes[node_id].children:
-      if child.local_name == child_name:
-        return child.node_id
-    raise ValueError(
-        'Error when loading {}: could not find attribute {}.\n'
-        'Most likely this object was serialized incorrectly.'
-        .format(debugging_name, child_name))
+  def _search_for_child_node(
+      self, parent_id, path_to_child, debugging_name=None, raise_error=True):
+    """Returns node id of child node.
+
+    A helper method for traversing the object graph proto.
+
+    As an example, say that the object graph proto in the SavedModel contains an
+    object with the following child and grandchild attributes:
+
+    `parent.child_a.child_b`
+
+    This method can be used to retrieve the node id of `child_b` using the
+    parent's node id by calling:
+
+    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.
+
+    Args:
+      parent_id: node id of the parent node.
+      path_to_child: list of child names.
+      debugging_name: the name to print out when raising an error.
+      raise_error: whether to raise an error if the child isn't found.
+
+    Returns:
+      node_id of child, or None if child isn't found.
+
+    Raises:
+      ValueError: if child isn't found and raise_error is True.
+    """
+    if not path_to_child:
+      return parent_id
+
+    for child in self._proto.nodes[parent_id].children:
+      if child.local_name == path_to_child[0]:
+        return self._search_for_child_node(child.node_id, path_to_child[1:],
+                                           debugging_name, raise_error)
+
+    if raise_error:
+      raise ValueError(
+          'Error when loading {}: could not find attribute {}.\n'
+          'Most likely this object was serialized incorrectly.'
+          .format(debugging_name or path_to_child[0], path_to_child[0]))
+    else:
+      return None
 
   def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
     """Infers input shape of layer from SavedModel functions."""
     coder = nested_structure_coder.StructureCoder()
-    try:
-      call_fn_id = self._search_for_child_node(
-          layer_node_id, 'call_and_return_all_conditional_losses', None)
-    except ValueError:
+    call_fn_id = self._search_for_child_node(
+        layer_node_id, ['call_and_return_all_conditional_losses'], None,
+        raise_error=False)
+    if call_fn_id is None:
       return None
 
     concrete_functions = (
@@ -579,6 +675,10 @@ def _finalize_saved_model_layers(layers):
     # 3. Add losses that aren't generated by the layer.call function.
     _restore_layer_unconditional_losses(layer)
     _restore_layer_activation_loss(layer)
+
+    # 4. Restore metrics list
+    _restore_layer_metrics(layer)
+
   # pylint: enable=protected-access
 
 
@@ -600,6 +700,15 @@ def _finalize_config_layers(layers):
     # loading behavior between HDF5 and SavedModel.
     _restore_layer_activation_loss(layer)
 
+    # Restore metrics list.
+    _restore_layer_metrics(layer)
+
+
+def _finalize_metric(metric):
+  metric.update_state = types.MethodType(metrics_utils.update_state_wrapper(
+      metric.keras_api.update_state), metric)
+  metric.result = metric.keras_api.result
+
 
 def _restore_layer_unconditional_losses(layer):
   """Restore unconditional losses from SavedModel."""
@@ -641,7 +750,7 @@ def revive_custom_object(identifier, metadata):
       '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
       '_tf_keras_network': (RevivedNetwork, network_lib.Network),
       '_tf_keras_model': (RevivedNetwork, model_class),
-      '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential)
+      '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential),
   }
 
   parent_classes = revived_classes.get(identifier, None)
@@ -651,6 +760,21 @@ def revive_custom_object(identifier, metadata):
     revived_cls = type(
         compat.as_str(metadata['class_name']), parent_classes, {})
     return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
+  else:
+    raise ValueError('Unable to restore custom object of type {} currently. '
+                     'Please make sure that the layer implements `get_config` '
+                     'and `from_config` when saving. In addition, please use '
+                     'the `custom_objects` arg when calling `load_model()`.'
+                     .format(identifier))
+
+
+def _restore_layer_metrics(layer):
+  metrics_list = getattr(_get_keras_attr(layer), 'layer_metrics', {})
+  layer_metrics = {m.name: m for m in layer._metrics}  # pylint: disable=protected-access
+  for name, metric in metrics_list.items():
+    if name not in layer_metrics:
+      # Metrics may be added during initialization/building of custom layers.
+      layer._metrics.append(metric)  # pylint: disable=protected-access
 
 
 # TODO(kathywu): Centrally define keys and functions for both  serialization and
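# --- Illustrative sketch (editorial, not part of the patch) ------------------
# _search_for_child_node above walks the SavedModel object graph proto by
# following a path of local names down from a parent node. The same traversal
# over a plain dict that stands in for the proto (node id -> list of
# (local_name, child_node_id) pairs); all names and ids here are made up:

nodes = {
    0: [('keras_api', 1)],
    1: [('layers', 2), ('layer_metrics', 3)],
    2: [],
    3: [('inputs_sum', 4)],
    4: [],
}


def search_for_child_node(parent_id, path_to_child, raise_error=True):
  """Returns the node id reached by following `path_to_child` from the parent."""
  if not path_to_child:
    return parent_id
  for local_name, child_id in nodes[parent_id]:
    if local_name == path_to_child[0]:
      return search_for_child_node(child_id, path_to_child[1:], raise_error)
  if raise_error:
    raise ValueError('could not find attribute {}'.format(path_to_child[0]))
  return None


assert search_for_child_node(0, ['keras_api', 'layer_metrics']) == 3
assert search_for_child_node(0, ['keras_api', 'missing'], raise_error=False) is None
# -----------------------------------------------------------------------------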
diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
new file mode 100644
index 00000000000..efe977ec55f
--- /dev/null
+++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -0,0 +1,45 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes and functions implementing Metrics SavedModel serialization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.saving.saved_model import layer_serialization
+from tensorflow.python.training.tracking import data_structures
+
+
+class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
+  """Metric serialization."""
+
+  @property
+  def object_identifier(self):
+    return '_tf_keras_metric'
+
+  def _python_properties_internal(self):
+    metadata = dict(
+        class_name=type(self.obj).__name__,
+        name=self.obj.name,
+        dtype=self.obj.dtype)
+    metadata.update(layer_serialization.get_config(self.obj))
+    if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
+      metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
+    return metadata
+
+  def _get_serialized_attributes_internal(self, unused_serialization_cache):
+    return (dict(variables=data_structures.ListWrapper(self.obj.variables)),
+            dict())  # TODO(b/135550038): save functions to enable saving
+                     # custom metrics.
diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py
index ca3ecfc5a77..575bf0fc4bc 100644
--- a/tensorflow/python/keras/saving/saved_model/revive_test.py
+++ b/tensorflow/python/keras/saving/saved_model/revive_test.py
@@ -90,6 +90,8 @@ class CustomLayerNoConfig(keras.layers.Layer):
     def a_regularizer():
       return self.a * 2
     self.add_loss(a_regularizer)
+    self.sum_metric = keras.metrics.Sum(name='inputs_sum')
+    self.unused_metric = keras.metrics.Sum(name='not_added_to_metrics')
 
   def build(self, input_shape):
     self.c = variables.Variable(
@@ -97,6 +99,9 @@ class CustomLayerNoConfig(keras.layers.Layer):
 
   def call(self, inputs):
     self.add_loss(math_ops.reduce_sum(inputs), inputs)
+    self.add_metric(self.sum_metric(inputs))
+    self.add_metric(inputs, aggregation='mean', name='mean')
+
     return inputs + self.c
 
 
@@ -143,6 +148,9 @@ class TestModelRevive(keras_parameterized.TestCase):
     self.assertAllClose(model(input_arr), revived(input_arr))
     self.assertAllClose(sum(model.losses), sum(revived.losses))
     self.assertAllClose(len(model.losses), len(revived.losses))
+    self.assertEqual(len(model.metrics), len(revived.metrics))
+    self.assertAllClose([m.result() for m in model.metrics],
+                        [m.result() for m in revived.metrics])
     model_layers = {layer.name: layer for layer in model.layers}
     revived_layers = {layer.name: layer for layer in revived.layers}
     self.assertAllEqual(model_layers.keys(), revived_layers.keys())
diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py
index 7bd2b52fe84..9d9ff2b0ed2 100644
--- a/tensorflow/python/keras/saving/saved_model/save_impl.py
+++ b/tensorflow/python/keras/saving/saved_model/save_impl.py
@@ -53,6 +53,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader
 base_layer = LazyLoader(
     "base_layer", globals(),
     "tensorflow.python.keras.engine.base_layer")
+metrics = LazyLoader("metrics", globals(),
+                     "tensorflow.python.keras.metrics")
 input_layer = LazyLoader(
     "input_layer", globals(),
     "tensorflow.python.keras.engine.input_layer")
@@ -108,6 +110,9 @@ def wrap_layer_objects(layer, serialization_cache):
       wrapped_loss_functions.append(wrapped_loss)
   wrapped_layer_losses = [keras_loss_cache[fn]
                           for fn in layer._callable_losses[:]]  # pylint: disable=protected-access
+
+  layer_metrics = data_structures._DictWrapper(  # pylint: disable=protected-access
+      {m.name: m for m in layer._metrics})  # pylint: disable=protected-access
   return dict(
       variables=data_structures.ListWrapper(layer.variables),
       trainable_variables=data_structures.ListWrapper(
@@ -119,7 +124,9 @@ def wrap_layer_objects(layer, serialization_cache):
       regularization_losses=data_structures.ListWrapper(
           wrapped_loss_functions),
       layer_regularization_losses=data_structures.ListWrapper(
-          wrapped_layer_losses))
+          wrapped_layer_losses),
+      layer_metrics=layer_metrics)
+  # pylint: disable=protected-access
 
 
 def wrap_layer_functions(layer, serialization_cache):
@@ -218,24 +225,55 @@ def _replace_child_layer_functions(layer, serialization_cache):
       { Child layer 1: {
           'losses': Original losses,
           'call': Original call function
-          'activity_regularizer': Original activity regularizer},
+          '_activity_regularizer': Original activity regularizer},
         Child layer 2: ...
       }
   """
   # pylint: disable=protected-access
   original_fns = {}
+
+  def replace_layer_functions(child_layer, serialized_fns):
+    """Replaces layer call and activity regularizer with wrapped functions."""
+    original_fns[child_layer] = {
+        'call': child_layer.call,
+        '_activity_regularizer': child_layer._activity_regularizer
+    }
+    with trackable.no_automatic_dependency_tracking_scope(child_layer):
+      try:
+        child_layer._activity_regularizer = serialized_fns.get(
+            'activity_regularizer_fn')
+      except AttributeError:
+        # Some layers have an unsettable activity regularizer.
+        pass
+      child_layer.call = utils.use_wrapped_call(
+          child_layer,
+          serialized_fns['call_and_return_conditional_losses'],
+          default_training_value=False)
+
+  def replace_metric_functions(child_layer, serialized_fns):
+    """Replaces metric functions with wrapped functions."""
+    original_fns[child_layer] = {
+        '__call__': child_layer.__call__,
+        'result': child_layer.result,
+        'update_state': child_layer.update_state
+    }
+    with trackable.no_automatic_dependency_tracking_scope(child_layer):
+      child_layer.__call__ = serialized_fns['__call__']
+      child_layer.result = serialized_fns['result']
+      child_layer.update_state = serialized_fns['update_state']
+
   for child_layer in utils.list_all_layers(layer):
     if isinstance(child_layer, input_layer.InputLayer):
       continue
 
     if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
-      layer_fns = (
+      serialized_functions = (
           child_layer._trackable_saved_model_saver._get_serialized_attributes(
               serialization_cache).functions)
     else:
-      layer_fns = (
+      serialized_functions = (
           serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
-    if not layer_fns:
+    if not serialized_functions:
       # This indicates either:
       #   - circular dependency, which means the current layer's functions
       #     should be wrapped first.
@@ -243,20 +281,12 @@ def _replace_child_layer_functions(layer, serialization_cache):
       #     wrapped. In this case, no replacement is necessary so move on to the
       #     next child.
       continue
-    original_fns[child_layer] = {
-        'call': child_layer.call,
-        'activity_regularizer': child_layer._activity_regularizer
-    }
-    with trackable.no_automatic_dependency_tracking_scope(child_layer):
-      try:
-        child_layer._activity_regularizer = layer_fns.get(
-            'activity_regularizer_fn')
-      except AttributeError:
-        # Some layers have an unsettable activity regularizer.
-        pass
-      child_layer.call = utils.use_wrapped_call(
-          child_layer, layer_fns['call_and_return_conditional_losses'],
-          default_training_value=False)
+
+    if isinstance(child_layer, metrics.Metric):
+      replace_metric_functions(child_layer, serialized_functions)
+    else:
+      replace_layer_functions(child_layer, serialized_functions)
+
   return original_fns
   # pylint: enable=protected-access
 
@@ -265,11 +295,12 @@ def _restore_child_layer_functions(original_fns):
   """Restores attributes replaced with `_replace_child_layer_functions`."""
   for child_layer, fns in original_fns.items():
     with trackable.no_automatic_dependency_tracking_scope(child_layer):
-      child_layer.call = fns['call']
-      try:
-        child_layer._activity_regularizer = fns['activity_regularizer']  # pylint: disable=protected-access
-      except AttributeError:
-        pass
+      for fn_name, fn in fns.items():
+        try:
+          setattr(child_layer, fn_name, fn)  # pylint: disable=protected-access
+        except AttributeError:
+          # In the case of _activity_regularizer, setting the attribute may be
+          # disallowed.
+          pass
 
 
 # pylint: disable=protected-access
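# --- Illustrative sketch (editorial, not part of the patch) ------------------
# _replace_child_layer_functions and _restore_child_layer_functions above use a
# record-swap-restore pattern: remember the original attributes in a dict keyed
# by layer, install the wrapped functions, then setattr the originals back. A
# minimal, self-contained version of that pattern on a plain Python object
# (the names here are hypothetical):

import contextlib


@contextlib.contextmanager
def swapped_attributes(obj, **replacements):
  """Temporarily replaces attributes on `obj`, restoring them afterwards."""
  original = {name: getattr(obj, name) for name in replacements}
  try:
    for name, value in replacements.items():
      setattr(obj, name, value)
    yield obj
  finally:
    for name, value in original.items():
      setattr(obj, name, value)


class Child(object):

  def call(self, x):
    return x


child = Child()
with swapped_attributes(child, call=lambda x: 2 * x):
  assert child.call(3) == 6  # The wrapped function is in place while tracing.
assert child.call(3) == 3    # The original is restored afterwards.
# -----------------------------------------------------------------------------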
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index da86a7cdac1..63eeceb2a6b 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -27,6 +27,7 @@ from __future__ import print_function
 import os
 import shutil
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.example import example_pb2
@@ -49,6 +50,7 @@ from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.saving.saved_model import load as keras_load
 from tensorflow.python.keras.saving.saved_model import save_impl as keras_save
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -746,5 +748,128 @@ class TestLayerCallTracing(test.TestCase):
     self.assertAllEqual(previous_losses, layer.losses)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class MetricTest(test.TestCase, parameterized.TestCase):
+
+  def _save_model_dir(self, dirname='saved_model'):
+    temp_dir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+    return os.path.join(temp_dir, dirname)
+
+  def generate_inputs(self, num_tensor_args, shape=(1, 5)):
+    return [
+        np.random.uniform(0, 1, shape).astype('float32')
+        for _ in range(num_tensor_args)
+    ]
+
+  def _test_metric_save_and_load(self,
+                                 metric,
+                                 save_dir,
+                                 num_tensor_args,
+                                 shape=(1, 5),
+                                 test_sample_weight=True):
+    tf_save.save(metric, save_dir)
+    loaded = keras_load.load(save_dir)
+    self.evaluate([v.initializer for v in loaded.variables])
+    self.assertEqual(metric.name, loaded.name)
+    self.assertEqual(metric.dtype, loaded.dtype)
+
+    inputs = self.generate_inputs(num_tensor_args, shape)
+    actual = self.evaluate(metric(*inputs))
+    self.assertAllClose(actual, loaded(*inputs))
+    self.assertAllClose(metric.variables, loaded.variables)
+
+    # Test with separate calls to update state and result.
+    inputs = self.generate_inputs(num_tensor_args, shape)
+    self.evaluate(metric.update_state(*inputs))
+    self.evaluate(loaded.update_state(*inputs))
+    actual = self.evaluate(metric.result())
+    self.assertAllClose(actual, loaded.result())
+
+    if test_sample_weight:
+      # Test with sample weights input.
+      inputs = self.generate_inputs(num_tensor_args, shape)
+      sample_weight = self.generate_inputs(1, [])[0]
+      inputs.append(sample_weight)
+
+      actual = self.evaluate(metric(*inputs))
+      self.assertAllClose(actual, loaded(*inputs))
+    return loaded
+
+  @parameterized.named_parameters([
+      ('mean', keras.metrics.Mean, 1, (1, 5)),
+      ('false_positives', keras.metrics.FalsePositives, 2, (1, 5)),
+      ('precision_at_top_k', keras.metrics.Precision, 2, (2, 3, 4), {
+          'top_k': 2,
+          'class_id': 1
+      }),
+      ('precision_at_recall', keras.metrics.PrecisionAtRecall, 2, (1, 5), {
+          'recall': .8
+      }), ('auc', keras.metrics.AUC, 2, (1, 5), {
+          'multi_label': True
+      }), ('cosine_similarity', keras.metrics.CosineSimilarity, 2, (2, 3, 1))
+  ])
+  def test_metric(self, metric_cls, num_tensor_args, shape, init_kwargs=None):
+    init_kwargs = init_kwargs or {}
+    metric = metric_cls(**init_kwargs)
+    metric(*self.generate_inputs(num_tensor_args, shape))
+    self.evaluate([v.initializer for v in metric.variables])
+    loaded = self._test_metric_save_and_load(metric, self._save_model_dir(),
+                                             num_tensor_args, shape)
+    self.assertEqual(type(loaded), type(metric))
+
+  @parameterized.named_parameters([
+      ('mean', keras.metrics.Mean, 1, False),
+      ('auc', keras.metrics.AUC, 2, False),
+      ('mean_tensor', keras.metrics.MeanTensor, 1, True)])
+  def test_custom_metric(self, base_cls, num_tensor_args, requires_build):
+
+    class CustomMetric(base_cls):
+
+      def update_state(self, *args):  # pylint: disable=useless-super-delegation
+        # Sometimes built-in metrics return an op in update_state. Custom
+        # metrics don't support returning ops, so wrap the update_state method
+        # while returning nothing.
+        super(CustomMetric, self).update_state(*args)
+
+    metric = CustomMetric()
+    save_dir = self._save_model_dir('first_save')
+
+    if requires_build:
+      metric(*self.generate_inputs(num_tensor_args))  # pylint: disable=not-callable
+
+    self.evaluate([v.initializer for v in metric.variables])
+
+    with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'):
+      self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
+    with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
+      loaded = self._test_metric_save_and_load(
+          metric,
+          save_dir,
+          num_tensor_args,
+          test_sample_weight=False)
+
+      self._test_metric_save_and_load(
+          loaded,
+          self._save_model_dir('second_save'),
+          num_tensor_args,
+          test_sample_weight=False)
+
+  def test_custom_metric_wrapped_call(self):
+
+    class NegativeMean(keras.metrics.Mean):
+
+      @def_function.function(
+          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
+      def update_state(self, value):
+        super(NegativeMean, self).update_state(-value)
+
+    metric = NegativeMean()
+    self.evaluate([v.initializer for v in metric.variables])
+    with generic_utils.CustomObjectScope({'NegativeMean': NegativeMean}):
+      self._test_metric_save_and_load(
+          metric, self._save_model_dir(), 1, test_sample_weight=False)
+
+
 if __name__ == '__main__':
   test.main()
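# --- Illustrative sketch (editorial, not part of the patch) ------------------
# The loader's _finalize_metric (load.py above) re-attaches the restored
# update_state function to the loaded metric as a bound method via
# types.MethodType, which is the behavior test_custom_metric_wrapped_call
# exercises. The same binding trick on a plain class, with a stand-in for the
# restored function (all names here are hypothetical):

import types


class Counter(object):

  def __init__(self):
    self.total = 0


def restored_update_state(self, value):
  # Stand-in for the wrapped update_state restored from the SavedModel.
  self.total += value


counter = Counter()
counter.update_state = types.MethodType(restored_update_state, counter)
counter.update_state(3)
assert counter.total == 3
# -----------------------------------------------------------------------------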
diff --git a/tensorflow/python/keras/saving/saved_model/serialized_attributes.py b/tensorflow/python/keras/saving/saved_model/serialized_attributes.py
index c62dbb697a5..c87a0d64cf7 100644
--- a/tensorflow/python/keras/saving/saved_model/serialized_attributes.py
+++ b/tensorflow/python/keras/saving/saved_model/serialized_attributes.py
@@ -34,6 +34,8 @@ base_layer = LazyLoader(
 training_lib = LazyLoader(
     "training_lib", globals(),
     "tensorflow.python.keras.engine.training")
+metrics = LazyLoader("metrics", globals(),
+                     "tensorflow.python.keras.metrics")
 # pylint:enable=g-inconsistent-quotes
 
 
@@ -138,6 +140,8 @@ class SerializedAttributes(object):
   def new(obj):
     if isinstance(obj, training_lib.Model):
       return ModelAttributes()
+    elif isinstance(obj, metrics.Metric):
+      return MetricAttributes()
     elif isinstance(obj, base_layer.Layer):
       return LayerAttributes()
     else:
@@ -203,7 +207,8 @@ class SerializedAttributes(object):
         self._object_dict[key] = object_dict[key]
         setattr(self._keras_trackable, key, object_dict[key])
       else:
-        raise ValueError('Object {} missing from serialized object dict.')
+        raise ValueError(
+            'Object {} missing from serialized object dict.'.format(key))
     return self.checkpointable_objects
 
 
@@ -233,7 +238,7 @@ class CommonEndpoints(SerializedAttributes.with_attributes(
 class LayerAttributes(SerializedAttributes.with_attributes(
     'LayerAttributes',
     checkpointable_objects=['non_trainable_variables', 'layers', 'metrics',
-                            'layer_regularization_losses'],
+                            'layer_regularization_losses', 'layer_metrics'],
     functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'],
     copy_from=[CommonEndpoints]
     )):
@@ -252,6 +257,7 @@ class LayerAttributes(SerializedAttributes.with_attributes(
       activity regularizer.
     activity_regularizer_fn: Callable that returns the activity regularizer loss
     layer_regularization_losses: List of losses owned only by this layer.
+    layer_metrics: List of metrics owned by this layer.
   """
 
 
@@ -265,3 +271,17 @@ class ModelAttributes(SerializedAttributes.with_attributes(
   """
   # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
   #  list all losses and metrics defined by `model.compile`.
+
+
+class MetricAttributes(
+    SerializedAttributes.with_attributes(
+        'MetricAttributes',
+        checkpointable_objects=['variables'],
+        functions=[],
+    )):
+  """Attributes that are added to Metric objects when saved to SavedModel.
+
+  List of all attributes:
+    variables: list of all variables
+  """
+  pass
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 60d34a0e299..94a0e73e64f 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -780,21 +780,29 @@ class CondV2Test(test.TestCase):
 
     self.evaluate(variables.global_variables_initializer())
 
+    def update_v1():
+      v1.assign(v1)
+      return v1
+
+    def update_v2():
+      v2.assign(v2)
+      return v2
+
     @def_function.function
     def fn_with_cond():
       cond_v2.cond_v2(
           constant_op.constant(True),
-          lambda: v1,
+          update_v1,
           lambda: constant_op.constant(0.),
           name="cond_1")
       cond_2 = cond_v2.cond_v2(
           constant_op.constant(False),
           lambda: constant_op.constant(0.),
-          lambda: v1,
+          update_v1,
           name="cond_2")
       cond_v2.cond_v2(
           constant_op.constant(True),
-          lambda: v2,
+          update_v2,
           lambda: constant_op.constant(0.),
           name="cond_3")
       cond_4 = cond_v2.cond_v2(
@@ -841,24 +849,34 @@ class CondV2Test(test.TestCase):
 
     @def_function.function
     def fn_with_cond():
+
+      def update_v1():
+        v1.assign(v1)
+        return v1
+
+      def update_v2():
+        v2.assign(v2)
+        return v2
+
       cond_v2.cond_v2(
           constant_op.constant(True),
-          lambda: v1,
+          update_v1,
           lambda: constant_op.constant(0.),
           name="cond_1")
       cond_2 = cond_v2.cond_v2(
           constant_op.constant(False),
           lambda: constant_op.constant(0.),
-          lambda: v1,
+          update_v1,
           name="cond_2")
       cond_v2.cond_v2(
           constant_op.constant(True),
-          lambda: v2,
+          update_v2,
           lambda: constant_op.constant(0.),
           name="cond_3")
 
       @def_function.function
       def cond_4_false_branch():
+        v2.assign(v2)
         return v2
 
       cond_4 = cond_v2.cond_v2(
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index 15a389294e5..8f131723cd2 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -340,14 +340,24 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 
     @def_function.function
     def Fn():
+
+      def Body1(v):
+        x.assign(x)
+        return v * x
+
       ret1 = while_loop_v2(
           lambda v: v < 4.,
-          lambda v: v * x, [c],
+          Body1, [c],
           return_same_structure=False,
           name="while_1")  # 2x
+
+      def Body2(v):
+        x.assign(x)
+        return v * x * x
+
       ret2 = while_loop_v2(
           lambda v: v < 16.,
-          lambda v: v * x * x, [c],
+          Body2, [c],
           return_same_structure=False,
           name="while_2")  # 4x
       return ret1, ret2
@@ -368,24 +378,44 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 
     @def_function.function
     def Fn():
+
+      def Body1(v):
+        x1.assign(x1)
+        return v * x1
+
       ret1 = while_loop_v2(
           lambda v: v < 4.,
-          lambda v: v * x1, [c],
+          Body1, [c],
           return_same_structure=False,
           name="while_1")  # 2x
+
+      def Body2(v):
+        x1.assign(x1)
+        return v * x1 * x1
+
       ret2 = while_loop_v2(
           lambda v: v < 16.,
-          lambda v: v * x1 * x1, [c],
+          Body2, [c],
           return_same_structure=False,
           name="while_2")  # 4x
+
+      def Body3(v):
+        x2.assign(x2)
+        return v * x2
+
       ret3 = while_loop_v2(
           lambda v: v < 4.,
-          lambda v: v * x2, [c],
+          Body3, [c],
           return_same_structure=False,
           name="while_3")  # 3x
+
+      def Body4(v):
+        x2.assign(x2)
+        return v * x2 * x2
+
       ret4 = while_loop_v2(
           lambda v: v < 16.,
-          lambda v: v * x2 * x2, [c],
+          Body4, [c],
           return_same_structure=False,
           name="while_4")  # 9x
       ret5 = while_loop_v2(
diff --git a/tensorflow/python/module/BUILD b/tensorflow/python/module/BUILD
index 4585d39e592..fea6fe123ad 100644
--- a/tensorflow/python/module/BUILD
+++ b/tensorflow/python/module/BUILD
@@ -28,6 +28,7 @@ tf_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:variables",
         "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/distribute:tpu_values",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py
index 7fa4fc14d7f..b2fc4ff9645 100644
--- a/tensorflow/python/module/module_test.py
+++ b/tensorflow/python/module/module_test.py
@@ -26,6 +26,7 @@ from absl.testing import parameterized
 import six
 
 from tensorflow.python import tf2
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values as distributed_values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -249,10 +250,8 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
   def test_supports_distributed_variables(self):
     mirrored = distributed_values.MirroredVariable(
         None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
-    tpu = distributed_values.TPUMirroredVariable(
-        strategy=None,
-        values=[variables.Variable(42.)],
-        aggregation=None)
+    tpu = tpu_values.TPUMirroredVariable(
+        strategy=None, values=[variables.Variable(42.)], aggregation=None)
     aggregating = distributed_values.AggregatingVariable(
         strategy=None, v=variables.Variable(1.), aggregation=None)
 
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index fd0102328d1..19f777115db 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -26,6 +26,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
@@ -279,6 +280,7 @@ def _build_cond(pred,
     if_op._false_graph = false_graph
     util.maybe_set_lowering_attr(if_op)
     util.maybe_propagate_compile_time_consts_in_xla(if_op)
+    _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
     # Prevent fetching since the variant outputs can't be fetched directly.
     if_op.graph.prevent_fetching(if_op)
 
@@ -1105,6 +1107,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
   if case_op is not None:
     util.maybe_set_lowering_attr(case_op)
     util.maybe_propagate_compile_time_consts_in_xla(case_op)
+    _set_read_only_resource_inputs_attr(case_op, branch_graphs)
     # Prevent fetching since the variant outputs can't be fetched directly.
     case_op.graph.prevent_fetching(case_op)
 
@@ -1119,3 +1122,28 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
   tensors = [array_ops.identity(t) for t in tensors]
 
   return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
+
+
+def _set_read_only_resource_inputs_attr(op, branch_graphs):
+  """Sets the list of resource inputs which are read-only.
+
+  This is used by AutomaticControlDependencies.
+
+  Args:
+    op: If or Case Operation.
+    branch_graphs: List of branch FuncGraphs.
+  """
+  read_only_indices = []
+  for i in range(1, len(op.inputs)):
+    if op.inputs[i].dtype != dtypes.resource:
+      continue
+    has_write = False
+    for branch_graph in branch_graphs:
+      handle = branch_graph.inputs[i - 1]
+      if acd.resource_has_writes(handle):
+        has_write = True
+        break
+    if not has_write:
+      read_only_indices.append(i)
+  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
+                        read_only_indices)
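# --- Illustrative sketch (editorial, not part of the patch) ------------------
# _set_read_only_resource_inputs_attr above marks a resource input of the
# If/Case op as read-only when no branch graph writes to the corresponding
# branch input (input 0 is the predicate, hence the i - 1 shift). The same
# computation over plain Python stand-ins for dtypes and per-branch write
# information (acd.resource_has_writes is replaced by a precomputed table):

RESOURCE = 'resource'


def read_only_resource_indices(op_input_dtypes, branch_writes, first_input=1):
  """branch_writes[g][j] is True if branch graph g writes its input j."""
  read_only = []
  for i in range(first_input, len(op_input_dtypes)):
    if op_input_dtypes[i] != RESOURCE:
      continue
    if not any(writes[i - first_input] for writes in branch_writes):
      read_only.append(i)
  return read_only


# Predicate plus two resource inputs; only the second resource is written, and
# only by the true branch, so only input 1 is marked read-only.
dtypes_ = ['bool', RESOURCE, RESOURCE]
writes = [[False, True],    # true branch
          [False, False]]   # false branch
assert read_only_resource_indices(dtypes_, writes) == [1]
# -----------------------------------------------------------------------------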
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 4698f870785..a43cfc0be65 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
@@ -1178,4 +1179,25 @@ def partitioned_call(args,
     op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
   op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
   outputs = op.outputs
+  if hasattr(f, "graph"):
+    _set_read_only_resource_inputs_attr(op, f.graph)
   return outputs if outputs else op
+
+
+def _set_read_only_resource_inputs_attr(op, func_graph):
+  """Sets the list of resource inputs which are read-only.
+
+  This is used by AutomaticControlDependencies.
+
+  Args:
+    op: PartitionedCall Operation.
+    func_graph: FuncGraph.
+  """
+  read_only_indices = []
+  for i in range(len(op.inputs)):
+    handle = func_graph.inputs[i]
+    if handle.dtype != dtypes.resource or acd.resource_has_writes(handle):
+      continue
+    read_only_indices.append(i)
+  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
+                        read_only_indices)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index cab7573256b..e6a67efa301 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -1735,7 +1735,8 @@ class SpectralTest(PForTestCase, parameterized.TestCase):
       (fft_ops.irfft2d,),
       (fft_ops.irfft3d,),
   )
-  def test_irfft(self, op_func):
+  # TODO(agarwal): Re-enable this once the test flakiness is fixed.
+  def disabled_test_irfft(self, op_func):
     for dtype in (dtypes.complex64, dtypes.complex128):
       shape = [2, 3, 4, 3, 4]
       x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
diff --git a/tensorflow/python/ops/ragged/ragged_string_ops.py b/tensorflow/python/ops/ragged/ragged_string_ops.py
index d5a508a08be..493e5b97cd6 100755
--- a/tensorflow/python/ops/ragged/ragged_string_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_string_ops.py
@@ -457,8 +457,8 @@ def string_split_v2(input, sep=None, maxsplit=-1, name=None):  # pylint: disable
   """Split elements of `input` based on `sep` into a `RaggedTensor`.
 
   Let N be the size of `input` (typically N will be the batch size). Split each
-  element of `input` based on `sep` and return a `SparseTensor` or
-  `RaggedTensor` containing the split tokens. Empty tokens are ignored.
+  element of `input` based on `sep` and return a `RaggedTensor` containing the 
+  split tokens. Empty tokens are ignored.
 
   Example:
 
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 6e02531fccb..20e2b87c69b 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python import _pywrap_utils
 from tensorflow.python.client import pywrap_tf_session
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
@@ -53,6 +54,13 @@ from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.deprecation import deprecated_args
 
 
+acd.register_read_only_resource_op("ReadVariableOp")
+acd.register_read_only_resource_op("VariableShape")
+acd.register_read_only_resource_op("ResourceGather")
+acd.register_read_only_resource_op("ResourceGatherNd")
+acd.register_read_only_resource_op("_ReadVariablesOp")
+
+
 def get_resource_handle_data(graph_op):
   assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck
 
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 4b8f45e2617..c7ec0668ef0 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -26,6 +26,7 @@ from __future__ import print_function
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import pywrap_tf_session as c_api
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
@@ -275,13 +276,8 @@ def while_loop(cond,
       # This is needed so we do not compute derivative wrt these extra outputs.
       outputs[0].op._set_attr("_num_original_outputs",
                               attr_value_pb2.AttrValue(i=num_original_outputs))
-
     outputs[0].op._cond_graph = cond_graph
     outputs[0].op._body_graph = body_graph
-    _copy_handle_data(body_graph.outputs, outputs)
-    util.maybe_set_lowering_attr(outputs[0].op)
-    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)
-
     if not ops.get_default_graph().building_function:
       # In V1 graph mode, return identities for each output of the While op,
       # rather than the output of the While op directly. This makes pruning work
@@ -401,11 +397,6 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
       output_shapes=[t.shape for t in body_grad_graph.outputs],
       parallel_iterations=parallel_iterations,
       name="%s_grad" % while_op.name)
-  grad_op = outputs[0].op
-
-  _copy_handle_data(body_grad_graph.outputs, outputs)
-  util.maybe_set_lowering_attr(grad_op)
-  util.maybe_propagate_compile_time_consts_in_xla(grad_op)
 
   # See comment in while_loop.
   outputs = [array_ops.identity(t) for t in outputs]
@@ -426,13 +417,19 @@ def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
   else:
     op_fn = gen_functional_ops.stateless_while
 
-  return op_fn(
+  outputs = op_fn(
       loop_vars,
       util.create_new_tf_function(cond_graph),
       util.create_new_tf_function(body_graph),
       output_shapes=output_shapes,
       parallel_iterations=parallel_iterations,
       name=name)
+  while_op = outputs[0].op
+  _copy_handle_data(body_graph.outputs, outputs)
+  util.maybe_set_lowering_attr(while_op)
+  util.maybe_propagate_compile_time_consts_in_xla(while_op)
+  _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
+  return outputs
 
 
 def _get_intermediates(func_graph):
@@ -1294,4 +1291,28 @@ class _OperationWithOutputs(ops.Operation):
     self._is_stateful = False
 
 
+def _set_read_only_resource_inputs_attr(op, branch_graphs):
+  """Sets the list of resource inputs which are read-only.
+
+  This is used by AutomaticControlDependencies.
+
+  Args:
+    op: While Operation.
+    branch_graphs: List of branch FuncGraphs.
+  """
+  read_only_indices = []
+  for i in range(len(op.inputs)):
+    if op.inputs[i].dtype != dtypes.resource:
+      continue
+    has_write = False
+    for branch_graph in branch_graphs:
+      handle = branch_graph.inputs[i]
+      if acd.resource_has_writes(handle):
+        has_write = True
+        break
+    if not has_write:
+      read_only_indices.append(i)
+  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
+                        read_only_indices)
+
 # pylint: enable=protected-access
diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc
index 5c11fbb1cff..ef1a1c032f7 100644
--- a/tensorflow/python/profiler/internal/profiler_wrapper.cc
+++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc
@@ -40,6 +40,7 @@ tensorflow::ProfileRequest MakeProfileRequest() {
   tensorflow::ProfileRequest request;
   request.add_tools("overview_page");
   request.add_tools("input_pipeline");
+  request.add_tools("kernel_stats");
   request.add_tools("tensorflow_stats");
   return request;
 }
diff --git a/tensorflow/python/profiler/profiler_v2.py b/tensorflow/python/profiler/profiler_v2.py
index afbe1ec5881..ba65aea7621 100644
--- a/tensorflow/python/profiler/profiler_v2.py
+++ b/tensorflow/python/profiler/profiler_v2.py
@@ -106,6 +106,18 @@ def stop(save=True):
     _profiler = None
 
 
+def warmup():
+  """Warm-up the profiler session.
+
+  The profiler session will set up profiling context, including loading CUPTI
+  library for GPU profiling. This is used for improving the accuracy of
+  the profiling results.
+
+  """
+  start('')
+  stop(save=False)
+
+
 @tf_export('profiler.experimental.server.start', v1=[])
 def start_server(port):
   """Start a profiler grpc server that listens to given port.
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index 96789d2cea5..7c4f40c6b66 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -208,13 +208,14 @@ def _enclosing_tpu_device_assignment():
 
 
 @auto_control_deps.register_acd_resource_resolver
-def tpu_replicated_input_resolver(op, resource_inputs):
+def tpu_replicated_input_resolver(op, resource_reads, resource_writes):
   """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
   # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
   # control deps on the replicated inputs.
   if op.type == "TPUReplicatedInput":
-    if resource_inputs:
-      resource_inputs.clear()
+    if resource_reads or resource_writes:
+      resource_reads.clear()
+      resource_writes.clear()
       return True
     else:
       return False
@@ -222,18 +223,21 @@ def tpu_replicated_input_resolver(op, resource_inputs):
   # with the actual replicated inputs. This allows ACD to correct add control
   # deps when there are multiple calls to `experimental_run_v2` in a
   # `tf.function`.
-  to_remove = []
-  to_add = []
-  for resource in resource_inputs:
-    if resource.op.type == "TPUReplicatedInput":
-      to_remove.append(resource)
-      to_add.extend(resource.op.inputs)
-  if not to_add and not to_remove:
-    return False
-  for t in to_remove:
-    resource_inputs.discard(t)
-  resource_inputs.update(to_add)
-  return True
+  def replace_with_unreplicated_resources(resource_inputs):
+    """Replaces handles in `resource_inputs` with their unreplicated inputs."""
+    to_remove = []
+    to_add = []
+    for resource in resource_inputs:
+      if resource.op.type == "TPUReplicatedInput":
+        to_remove.append(resource)
+        to_add.extend(resource.op.inputs)
+    for t in to_remove:
+      resource_inputs.discard(t)
+    resource_inputs.update(to_add)
+    return to_add or to_remove
+
+  return (replace_with_unreplicated_resources(resource_reads) or
+          replace_with_unreplicated_resources(resource_writes))
 
 
 class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
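# --- Illustrative sketch (editorial, not part of the patch) ------------------
# replace_with_unreplicated_resources above rewrites a set of resource handles
# in place: every TPUReplicatedInput output is removed and replaced by the
# per-replica handles that feed it, and the resolver reports whether anything
# changed. The same rewrite with strings standing in for resource tensors
# (all names here are made up):

def replace_replicated(resources, replicated_inputs):
  """replicated_inputs maps a replicated handle to its per-replica handles."""
  to_remove = [r for r in resources if r in replicated_inputs]
  to_add = [h for r in to_remove for h in replicated_inputs[r]]
  for r in to_remove:
    resources.discard(r)
  resources.update(to_add)
  return bool(to_add or to_remove)


reads = {'replicated_var', 'plain_var'}
changed = replace_replicated(
    reads, {'replicated_var': ['var_replica_0', 'var_replica_1']})
assert changed
assert reads == {'plain_var', 'var_replica_0', 'var_replica_1'}
# -----------------------------------------------------------------------------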
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index 035c416d793..a6ca3c6fda8 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -761,7 +761,7 @@ class DeprecatedArgValuesTest(test.TestCase):
       deprecation.deprecated_arg_values(date, None, deprecated=True)
     with self.assertRaisesRegexp(ValueError, "instructions"):
       deprecation.deprecated_arg_values(date, "", deprecated=True)
-    with self.assertRaisesRegexp(ValueError, "argument", deprecated=True):
+    with self.assertRaisesRegexp(ValueError, "argument"):
       deprecation.deprecated_arg_values(date, instructions)
 
   @test.mock.patch.object(logging, "warning", autospec=True)
diff --git a/tensorflow/stream_executor/allocator_stats.cc b/tensorflow/stream_executor/allocator_stats.cc
index 440d6f46a3c..8a45efdef83 100644
--- a/tensorflow/stream_executor/allocator_stats.cc
+++ b/tensorflow/stream_executor/allocator_stats.cc
@@ -18,7 +18,7 @@ limitations under the License.
 
 namespace stream_executor {
 
-string AllocatorStats::DebugString() const {
+std::string AllocatorStats::DebugString() const {
   return absl::StrFormat(
       "Limit:        %20lld\n"
       "InUse:        %20lld\n"
diff --git a/tensorflow/stream_executor/allocator_stats.h b/tensorflow/stream_executor/allocator_stats.h
index 62edfff3c1b..9a99c1099c9 100644
--- a/tensorflow/stream_executor/allocator_stats.h
+++ b/tensorflow/stream_executor/allocator_stats.h
@@ -51,7 +51,7 @@ struct AllocatorStats {
         bytes_reserved(0),
         peak_bytes_reserved(0) {}
 
-  string DebugString() const;
+  std::string DebugString() const;
 };
 
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
index 9b8fdf1efe8..f499b3003d0 100644
--- a/tensorflow/stream_executor/blas.cc
+++ b/tensorflow/stream_executor/blas.cc
@@ -20,7 +20,7 @@ limitations under the License.
 namespace stream_executor {
 namespace blas {
 
-string TransposeString(Transpose t) {
+std::string TransposeString(Transpose t) {
   switch (t) {
     case Transpose::kNoTranspose:
       return "NoTranspose";
@@ -33,7 +33,7 @@ string TransposeString(Transpose t) {
   }
 }
 
-string UpperLowerString(UpperLower ul) {
+std::string UpperLowerString(UpperLower ul) {
   switch (ul) {
     case UpperLower::kUpper:
       return "Upper";
@@ -44,7 +44,7 @@ string UpperLowerString(UpperLower ul) {
   }
 }
 
-string DiagonalString(Diagonal d) {
+std::string DiagonalString(Diagonal d) {
   switch (d) {
     case Diagonal::kUnit:
       return "Unit";
@@ -55,7 +55,7 @@ string DiagonalString(Diagonal d) {
   }
 }
 
-string SideString(Side s) {
+std::string SideString(Side s) {
   switch (s) {
     case Side::kLeft:
       return "Left";
@@ -68,9 +68,11 @@ string SideString(Side s) {
 
 // -- AlgorithmConfig
 
-string AlgorithmConfig::ToString() const { return absl::StrCat(algorithm_); }
+std::string AlgorithmConfig::ToString() const {
+  return absl::StrCat(algorithm_);
+}
 
-string ComputationTypeString(ComputationType ty) {
+std::string ComputationTypeString(ComputationType ty) {
   switch (ty) {
     case ComputationType::kF16:
       return "f16";
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index d361343c381..5018d487ed1 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -67,27 +67,27 @@ namespace blas {
 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };
 
 // Returns a name for t.
-string TransposeString(Transpose t);
+std::string TransposeString(Transpose t);
 
 // Specifies whether the upper or lower triangular part of a
 // symmetric/Hermitian matrix is used.
 enum class UpperLower { kUpper, kLower };
 
 // Returns a name for ul.
-string UpperLowerString(UpperLower ul);
+std::string UpperLowerString(UpperLower ul);
 
 // Specifies whether a matrix is unit triangular.
 enum class Diagonal { kUnit, kNonUnit };
 
 // Returns a name for d.
-string DiagonalString(Diagonal d);
+std::string DiagonalString(Diagonal d);
 
 // Specifies whether a Hermitian matrix appears on the left or right in
 // operation.
 enum class Side { kLeft, kRight };
 
 // Returns a name for s.
-string SideString(Side s);
+std::string SideString(Side s);
 
 // Type with which intermediate computations of a blas routine are performed.
 //
@@ -104,7 +104,7 @@ enum class ComputationType {
 };
 
 // Converts a ComputationType to a string.
-string ComputationTypeString(ComputationType ty);
+std::string ComputationTypeString(ComputationType ty);
 
 std::ostream &operator<<(std::ostream &os, ComputationType ty);
 
@@ -157,7 +157,7 @@ class AlgorithmConfig {
   bool operator!=(const AlgorithmConfig &other) const {
     return !(*this == other);
   }
-  string ToString() const;
+  std::string ToString() const;
 
  private:
   AlgorithmType algorithm_;
@@ -1383,7 +1383,7 @@ class BlasSupport {
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
 
-  virtual port::Status GetVersion(string *version) = 0;
+  virtual port::Status GetVersion(std::string *version) = 0;
 
  protected:
   BlasSupport() {}
@@ -2196,7 +2196,7 @@ class BlasSupport {
                   uint64 n, std::complex<double> alpha,                        \
                   const DeviceMemory<std::complex<double>> &a, int lda,        \
                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
-  port::Status GetVersion(string *version) override;
+  port::Status GetVersion(std::string *version) override;
 
 }  // namespace blas
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc
index 5bdfb7ef1d0..9ee6e6837d7 100644
--- a/tensorflow/stream_executor/device_description.cc
+++ b/tensorflow/stream_executor/device_description.cc
@@ -55,10 +55,11 @@ DeviceDescription::DeviceDescription()
       core_count_(-1),
       ecc_enabled_(false) {}
 
-std::unique_ptr<std::map<string, string>> DeviceDescription::ToMap() const {
-  std::unique_ptr<std::map<string, string>> owned_result{
-      new std::map<string, string>};
-  std::map<string, string> &result = *owned_result;
+std::unique_ptr<std::map<std::string, std::string>> DeviceDescription::ToMap()
+    const {
+  std::unique_ptr<std::map<std::string, std::string>> owned_result{
+      new std::map<std::string, std::string>};
+  std::map<std::string, std::string> &result = *owned_result;
   result["Device Vendor"] = device_vendor();
   result["Platform Version"] = platform_version();
   result["Driver Version"] = driver_version();
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index db14b516a19..fa7426eb04b 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -42,22 +42,22 @@ class DeviceDescription {
   // Returns the platform being run on; this value is primarily intended for
   // printing, and comes out something like "OpenCL 1.2" or "Compute Capability
   // 3.5".
-  const string &platform_version() const { return platform_version_; }
+  const std::string &platform_version() const { return platform_version_; }
 
   // Returns the driver version interfacing with the underlying platform. Vendor
   // dependent format.
-  const string &driver_version() const { return driver_version_; }
+  const std::string &driver_version() const { return driver_version_; }
 
   // Return the runtime version, if one is provided by the underlying platform.
   // Vendor dependent format / usefulness.
-  const string &runtime_version() const { return runtime_version_; }
+  const std::string &runtime_version() const { return runtime_version_; }
 
   // Returns the name that the device reports. Vendor dependent.
-  const string &name() const { return name_; }
+  const std::string &name() const { return name_; }
 
   // Returns the PCI bus identifier for this device, of the form
   // [domain]:[bus]:[device].[function]
-  const string &pci_bus_id() const { return pci_bus_id_; }
+  const std::string &pci_bus_id() const { return pci_bus_id_; }
 
   // Returns the NUMA node associated with this device, for use in
   // determining socket locality. If the NUMA node could not be determined, -1
@@ -126,7 +126,7 @@ class DeviceDescription {
 
   // Returns the device vendor string, e.g., "NVIDIA Corporation", "Advanced
   // Micro Devices, Inc.", or "GenuineIntel".
-  const string &device_vendor() const { return device_vendor_; }
+  const std::string &device_vendor() const { return device_vendor_; }
 
   // Returns the CUDA compute capability if we're running on the CUDA platform.
   // If a CUDA compute capability is not available, the major version will be
@@ -150,7 +150,7 @@ class DeviceDescription {
   // TODO(leary): resident blocks per core will be useful.
 
   // Convenience typedef for the string-based DeviceDescription mapping.
-  typedef std::map<string, string> Map;
+  typedef std::map<std::string, std::string> Map;
 
   // Returns a mapping from readable names to readable values that describe the
   // device. This is useful for things like printing.
@@ -169,12 +169,12 @@ class DeviceDescription {
   // above.
   //
   // N.B. If another field is added, update ToMap() above.
-  string device_vendor_;
-  string platform_version_;
-  string driver_version_;
-  string runtime_version_;
-  string pci_bus_id_;
-  string name_;
+  std::string device_vendor_;
+  std::string platform_version_;
+  std::string driver_version_;
+  std::string runtime_version_;
+  std::string pci_bus_id_;
+  std::string name_;
 
   ThreadDim thread_dim_limit_;
   BlockDim block_dim_limit_;
@@ -221,22 +221,24 @@ class DeviceDescriptionBuilder {
   // For descriptions of the following fields, see comments on the corresponding
   // DeviceDescription::* accessors above.
 
-  void set_device_vendor(const string &value) {
+  void set_device_vendor(const std::string &value) {
     device_description_->device_vendor_ = value;
   }
-  void set_platform_version(const string &value) {
+  void set_platform_version(const std::string &value) {
     device_description_->platform_version_ = value;
   }
-  void set_driver_version(const string &value) {
+  void set_driver_version(const std::string &value) {
     device_description_->driver_version_ = value;
   }
-  void set_runtime_version(const string &value) {
+  void set_runtime_version(const std::string &value) {
     device_description_->runtime_version_ = value;
   }
-  void set_pci_bus_id(const string &value) {
+  void set_pci_bus_id(const std::string &value) {
     device_description_->pci_bus_id_ = value;
   }
-  void set_name(const string &value) { device_description_->name_ = value; }
+  void set_name(const std::string &value) {
+    device_description_->name_ = value;
+  }
 
   void set_thread_dim_limit(const ThreadDim &value) {
     device_description_->thread_dim_limit_ = value;
diff --git a/tensorflow/stream_executor/device_options.h b/tensorflow/stream_executor/device_options.h
index 2646950f42e..b195bc84e14 100644
--- a/tensorflow/stream_executor/device_options.h
+++ b/tensorflow/stream_executor/device_options.h
@@ -71,13 +71,13 @@ struct DeviceOptions {
     return !(*this == other);
   }
 
-  string ToString() {
+  std::string ToString() {
     return flags_ == 0 ? "none" : "kDoNotReclaimStackAllocation";
   }
 
   // Platform-specific device options. Expressed as key-value pairs to avoid
   // DeviceOptions subclass proliferation.
-  std::map<string, string> non_portable_tags;
+  std::map<std::string, std::string> non_portable_tags;
 
  private:
   unsigned flags_;
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index e1289165252..567cca5f6a2 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -33,7 +33,7 @@ uint64 AlgorithmDesc::hash() const {
   return absl::Hash<decltype(p)>()(p);
 }
 
-string AlgorithmDesc::ToString() const {
+std::string AlgorithmDesc::ToString() const {
   if (tensor_ops_enabled()) {
     return absl::StrCat(algo_id(), "#TC");
   } else {
@@ -74,7 +74,7 @@ bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
   return false;
 }
 
-string QuantizedActivationModeString(QuantizedActivationMode mode) {
+std::string QuantizedActivationModeString(QuantizedActivationMode mode) {
   switch (mode) {
     case dnn::QuantizedActivationMode::k8Bit:
       return "uint8";
@@ -89,7 +89,7 @@ string QuantizedActivationModeString(QuantizedActivationMode mode) {
   return "unknown quantized_activation_mode";
 }
 
-string ActivationModeString(ActivationMode mode) {
+std::string ActivationModeString(ActivationMode mode) {
   switch (mode) {
     case ActivationMode::kSigmoid:
       return "sigmoid";
@@ -109,7 +109,7 @@ string ActivationModeString(ActivationMode mode) {
   return "unknown activation_mode";
 }
 
-string ElementwiseOperationString(ElementwiseOperation op) {
+std::string ElementwiseOperationString(ElementwiseOperation op) {
   switch (op) {
     case ElementwiseOperation::kAdd:
       return "add";
@@ -121,7 +121,7 @@ string ElementwiseOperationString(ElementwiseOperation op) {
   return "unknown element wise op";
 }
 
-string DataLayoutString(DataLayout layout) {
+std::string DataLayoutString(DataLayout layout) {
   switch (layout) {
     case DataLayout::kYXDepthBatch:
       return "YXDepthBatch";
@@ -139,7 +139,7 @@ string DataLayoutString(DataLayout layout) {
   return "unknown data layout";
 }
 
-string FilterLayoutString(FilterLayout layout) {
+std::string FilterLayoutString(FilterLayout layout) {
   switch (layout) {
     case FilterLayout::kOutputInputYX:
       return "OutputInputYX";
@@ -157,7 +157,7 @@ string FilterLayoutString(FilterLayout layout) {
   return "unknown filter layout";
 }
 
-string PadAlignmentString(PadAlignment alignment) {
+std::string PadAlignmentString(PadAlignment alignment) {
   switch (alignment) {
     case PadAlignment::kDefault:
       return "default";
@@ -173,7 +173,7 @@ std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
   return str << PadAlignmentString(alignment);
 }
 
-string ShortPoolingModeString(PoolingMode mode) {
+std::string ShortPoolingModeString(PoolingMode mode) {
   switch (mode) {
     case PoolingMode::kMaximum:
       return "Max";
@@ -247,12 +247,12 @@ std::vector<int64> ReorderDims(const std::vector<int64>& input,
 
 // -- AlgorithmConfig
 
-string AlgorithmConfig::ToString() const {
-  string algo = "none";
+std::string AlgorithmConfig::ToString() const {
+  std::string algo = "none";
   if (algorithm().has_value()) {
     algo = algorithm()->ToString();
   }
-  string algo_no_scratch = "none";
+  std::string algo_no_scratch = "none";
   if (algorithm_no_scratch().has_value()) {
     algo_no_scratch = algorithm_no_scratch()->ToString();
   }
@@ -306,8 +306,8 @@ void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
   quantized_activation_mode_ = other.quantized_activation_mode_;
 }
 
-string BatchDescriptor::ToString() const {
-  string spatial;
+std::string BatchDescriptor::ToString() const {
+  std::string spatial;
   for (int i = 0; i < ndims(); i++) {
     absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
   }
@@ -318,19 +318,19 @@ string BatchDescriptor::ToString() const {
       DataLayoutString(layout()));
 }
 
-string BatchDescriptor::ToShortString() const {
+std::string BatchDescriptor::ToShortString() const {
   // All the constituent strings are less than 15 characters, so the
   // small string optimization ensures that there will be at most one
   // heap memory allocation.
-  string depth = absl::StrCat("d", feature_map_count());
-  string batch = absl::StrCat("b", count());
+  std::string depth = absl::StrCat("d", feature_map_count());
+  std::string batch = absl::StrCat("b", count());
 
-  string spatial = "s";
+  std::string spatial = "s";
   for (int i = 0; i < ndims(); i++) {
     absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
   }
 
-  string suffix;
+  std::string suffix;
   if (value_min() != value_max()) {
     absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
   }
@@ -419,8 +419,8 @@ void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
   tensor_ = other.tensor_;
 }
 
-string FilterDescriptor::ToString() const {
-  string desc = absl::StrFormat(
+std::string FilterDescriptor::ToString() const {
+  std::string desc = absl::StrFormat(
       "{output_feature_map_count: %d input_feature_map_count: %d "
       "layout: %s shape: ",
       output_feature_map_count(), input_feature_map_count(),
@@ -433,14 +433,14 @@ string FilterDescriptor::ToString() const {
   return desc;
 }
 
-string FilterDescriptor::ToShortString() const {
+std::string FilterDescriptor::ToShortString() const {
   // All the constituent strings are less than 15 characters, so the
   // small string optimization ensures that there will be at most one
   // heap memory allocation.
-  string od = absl::StrCat("od", output_feature_map_count());
-  string id = absl::StrCat("id", input_feature_map_count());
+  std::string od = absl::StrCat("od", output_feature_map_count());
+  std::string id = absl::StrCat("id", input_feature_map_count());
 
-  string spatial = "s";
+  std::string spatial = "s";
   for (int i = 0; i < ndims(); i++) {
     absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
   }
@@ -491,10 +491,10 @@ ConvolutionDescriptor::ConvolutionDescriptor()
 
 ConvolutionDescriptor::~ConvolutionDescriptor() {}
 
-string ConvolutionDescriptor::ToString() const {
-  string padding;
-  string strides;
-  string dilations;
+std::string ConvolutionDescriptor::ToString() const {
+  std::string padding;
+  std::string strides;
+  std::string dilations;
   for (int i = 0; i < ndims(); i++) {
     absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
     absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
@@ -507,8 +507,8 @@ string ConvolutionDescriptor::ToString() const {
       padding, PadAlignmentString(pad_alignment()), strides, dilations);
 }
 
-string ConvolutionDescriptor::ToShortString() const {
-  string desc;
+std::string ConvolutionDescriptor::ToShortString() const {
+  std::string desc;
   for (int i = 0; i < ndims(); i++) {
     if (i > 0) absl::StrAppend(&desc, "_");
     absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
@@ -543,11 +543,11 @@ void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
   propagate_nans_ = other.propagate_nans_;
 }
 
-string PoolingDescriptor::ToString() const {
+std::string PoolingDescriptor::ToString() const {
   const char* mode_string =
       mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
 
-  string window, strides, padding;
+  std::string window, strides, padding;
   for (int i = 0; i < ndims_; i++) {
     absl::StrAppendFormat(&window, "%d ", window_[i]);
     absl::StrAppendFormat(&strides, "%d ", strides_[i]);
@@ -561,8 +561,8 @@ string PoolingDescriptor::ToString() const {
       mode_string, window, strides, padding, propagate_string);
 }
 
-string PoolingDescriptor::ToShortString() const {
-  string window, strides, padding;
+std::string PoolingDescriptor::ToShortString() const {
+  std::string window, strides, padding;
   for (int i = 0; i < ndims_; i++) {
     absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
     absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
@@ -592,14 +592,14 @@ void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
   segment_size_ = other.segment_size_;
 }
 
-string NormalizeDescriptor::ToString() const {
+std::string NormalizeDescriptor::ToString() const {
   return absl::StrFormat(
       "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
       "segment_size: %d}",
       bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
 }
 
-string NormalizeDescriptor::ToShortString() const {
+std::string NormalizeDescriptor::ToShortString() const {
   return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
                       "_beta:", beta_, "_wrap:", wrap_around_,
                       "_size:", segment_size_);
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 8885e568ed1..771dc908a12 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -101,7 +101,7 @@ inline absl::Span<int64> AsInt64Slice(T* repeated_field) {
 }
 
 // Returns a string representation of the given data layout.
-string DataLayoutString(DataLayout layout);
+std::string DataLayoutString(DataLayout layout);
 
 // Specifies a quantization for activations in a given BatchDescriptor.
 enum class QuantizedActivationMode {
@@ -209,7 +209,7 @@ class RnnStateTensorDescriptor {
 };
 
 // Returns a string representation of the given quantization mode.
-string QuantizedActivationModeString(QuantizedActivationMode mode);
+std::string QuantizedActivationModeString(QuantizedActivationMode mode);
 
 // Describes the dimensions that a layer consumes/produces.
 //
@@ -260,8 +260,8 @@ class BatchDescriptor {
   // Clones values from 'other' for initialization.
   void CloneFrom(const BatchDescriptor& other);
 
-  string ToString() const;
-  string ToShortString() const;
+  std::string ToString() const;
+  std::string ToShortString() const;
 
   // Pre-condition:
   //   value_max_ == 0
@@ -374,7 +374,7 @@ class BatchDescriptor {
 };
 
 // Returns a string representation of the given filter layout.
-string FilterLayoutString(FilterLayout layout);
+std::string FilterLayoutString(FilterLayout layout);
 
 // Describes a filter for the convolution. This is the "window" from
 // height-by-width patches of each of the feature maps in the input layer to the
@@ -439,8 +439,8 @@ class FilterDescriptor {
 
   void CloneFrom(const FilterDescriptor& other);
 
-  string ToString() const;
-  string ToShortString() const;
+  std::string ToString() const;
+  std::string ToShortString() const;
   TensorDescriptorProto ToProto(DataType data_type) const;
 
   // Returns the number of weights required as parameters for a convolution
@@ -486,7 +486,7 @@ enum class PadAlignment : int64 {
 };
 
 // Returns a string representation of the given padding alignment.
-string PadAlignmentString(PadAlignment alignment);
+std::string PadAlignmentString(PadAlignment alignment);
 
 // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
@@ -529,8 +529,8 @@ class ConvolutionDescriptor {
   explicit ConvolutionDescriptor(int ndims);
   ~ConvolutionDescriptor();
 
-  string ToString() const;
-  string ToShortString() const;
+  std::string ToString() const;
+  std::string ToShortString() const;
   ConvolutionDescriptorProto ToProto() const { return proto_; }
 
   ConvolutionDescriptor& set_zero_padding_height(int64 value) {
@@ -578,7 +578,7 @@ class ConvolutionDescriptor {
                                      : ConvolutionMode::CROSS_CORRELATION);
     return *this;
   }
-  ConvolutionDescriptor& set_name(const string& name) {
+  ConvolutionDescriptor& set_name(const std::string& name) {
     proto_.set_name(name);
     return *this;
   }
@@ -621,7 +621,7 @@ class ConvolutionDescriptor {
     return AsInt64Slice(proto_.paddings());
   }
 
-  string name() const { return proto_.name(); }
+  std::string name() const { return proto_.name(); }
 
  private:
   absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); }
@@ -658,7 +658,7 @@ enum class SpaceConcatenateMode : int64 {
 };
 
 // Returns a short name for the pooling mode, e.g. "Avg".
-string ShortPoolingModeString(PoolingMode mode);
+std::string ShortPoolingModeString(PoolingMode mode);
 
 // Describes a pooling operation to be enqueued onto a stream via a platform's
 // DnnSupport.
@@ -722,7 +722,7 @@ class PoolingDescriptor {
     propagate_nans_ = value;
     return *this;
   }
-  PoolingDescriptor& set_name(const string& name) {
+  PoolingDescriptor& set_name(const std::string& name) {
     name_ = name;
     return *this;
   }
@@ -730,8 +730,8 @@ class PoolingDescriptor {
   int ndims() const { return ndims_; }
   void CloneFrom(const PoolingDescriptor& other);
 
-  string ToString() const;
-  string ToShortString() const;
+  std::string ToString() const;
+  std::string ToShortString() const;
 
   PoolingMode mode() const { return mode_; }
   int64 window_height() const { return GetDim(window_, DimIndex::Y); }
@@ -747,13 +747,13 @@ class PoolingDescriptor {
   absl::Span<const int64> padding() const { return padding_; }
   absl::Span<const int64> strides() const { return strides_; }
   bool propagate_nans() const { return propagate_nans_; }
-  string name() const { return name_; }
+  std::string name() const { return name_; }
 
  private:
   PoolingMode mode_;
   int ndims_;
   bool propagate_nans_;
-  string name_;  // Name as in Tensorflow NodeDef, for debugging purposes.
+  std::string name_;  // Name as in Tensorflow NodeDef, for debugging purposes.
 
   // Stored as: ..., y, x.
   std::vector<int64> window_;
@@ -783,7 +783,7 @@ class AlgorithmDesc {
 
   AlgorithmProto ToProto() const { return proto_; }
 
-  string ToString() const;
+  std::string ToString() const;
 
  private:
   AlgorithmProto proto_;
@@ -860,7 +860,7 @@ class AlgorithmConfig {
   bool operator!=(const AlgorithmConfig& other) const {
     return !(*this == other);
   }
-  string ToString() const;
+  std::string ToString() const;
 
  private:
   absl::optional<AlgorithmDesc> algorithm_;
@@ -927,8 +927,8 @@ class NormalizeDescriptor {
 
   void CloneFrom(const NormalizeDescriptor& other);
 
-  string ToString() const;
-  string ToShortString() const;
+  std::string ToString() const;
+  std::string ToShortString() const;
 
   float bias() const { return bias_; }
   int32 range() const { return range_; }
@@ -947,13 +947,13 @@ class NormalizeDescriptor {
 };
 
 // Returns a string representation of the given activation mode.
-string ActivationModeString(ActivationMode mode);
+std::string ActivationModeString(ActivationMode mode);
 
 // Describes the operation that DoElementwiseOperation should perform on its
 // inputs.
 enum class ElementwiseOperation { kAdd, kMultiply };
 
-string ElementwiseOperationString(ElementwiseOperation op);
+std::string ElementwiseOperationString(ElementwiseOperation op);
 
 // A simple class representing the version of the backing library, to
 // work around the "too perfect forwarding" issue in gcc6+ compilers.
diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc
index 2aee9617bff..ec78c120965 100644
--- a/tensorflow/stream_executor/kernel.cc
+++ b/tensorflow/stream_executor/kernel.cc
@@ -91,7 +91,7 @@ KernelCacheConfig KernelBase::GetPreferredCacheConfig() const {
 }
 
 void KernelBase::set_name(absl::string_view name) {
-  name_ = string(name);
+  name_ = std::string(name);
 
   // CUDA splitter prefixes stub functions with __device_stub_.
   demangled_name_ =
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index 1e4f375073e..d3d386bca4a 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -178,8 +178,8 @@ class KernelBase {
   KernelCacheConfig GetPreferredCacheConfig() const;
 
   void set_name(absl::string_view name);
-  const string &name() const { return name_; }
-  const string &demangled_name() const { return demangled_name_; }
+  const std::string &name() const { return name_; }
+  const std::string &demangled_name() const { return demangled_name_; }
 
  private:
   // The StreamExecutor that loads this kernel object.
@@ -188,8 +188,8 @@ class KernelBase {
   // Implementation delegated to for platform-specific functionality.
   std::unique_ptr<internal::KernelInterface> implementation_;
 
-  string name_;
-  string demangled_name_;
+  std::string name_;
+  std::string demangled_name_;
 
   KernelMetadata metadata_;
 
diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc
index d7e00205103..de10d8ec9f1 100644
--- a/tensorflow/stream_executor/kernel_spec.cc
+++ b/tensorflow/stream_executor/kernel_spec.cc
@@ -19,11 +19,11 @@ limitations under the License.
 namespace stream_executor {
 
 KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernelname)
-    : kernelname_(string(kernelname)) {}
+    : kernelname_(std::string(kernelname)) {}
 
 OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename,
                                                absl::string_view kernelname)
-    : KernelLoaderSpec(kernelname), filename_(string(filename)) {}
+    : KernelLoaderSpec(kernelname), filename_(std::string(filename)) {}
 
 CudaPtxOnDisk::CudaPtxOnDisk(absl::string_view filename,
                              absl::string_view kernelname)
@@ -77,13 +77,13 @@ CudaPtxInMemory::CudaPtxInMemory(
   }
 }
 
-string CudaPtxInMemory::DecompressPtx(const char *ptx) {
+std::string CudaPtxInMemory::DecompressPtx(const char *ptx) {
   // Get the length of the PTX string from the beginning of the buffer.
   uint64 ptx_length = *reinterpret_cast<const uint64 *>(ptx);
   // Get the PTX string from the buffer with offset and length.
-  string compressed_ptx(ptx + sizeof(uint64),
-                        ptx + sizeof(uint64) + ptx_length);
-  string decompressed_ptx;
+  std::string compressed_ptx(ptx + sizeof(uint64),
+                             ptx + sizeof(uint64) + ptx_length);
+  std::string decompressed_ptx;
   // Decompress the PTX string with bzip2.
   LOG(FATAL) << "bzip2 decompression is not supported yet.";
   return decompressed_ptx;
diff --git a/tensorflow/stream_executor/kernel_spec.h b/tensorflow/stream_executor/kernel_spec.h
index 7199f60e4ca..5138e3b3a2c 100644
--- a/tensorflow/stream_executor/kernel_spec.h
+++ b/tensorflow/stream_executor/kernel_spec.h
@@ -73,7 +73,7 @@ class KernelLoaderSpec {
   virtual ~KernelLoaderSpec() {}
 
   // Returns the kernel name to load out of the program.
-  const string &kernelname() const { return kernelname_; }
+  const std::string &kernelname() const { return kernelname_; }
 
  protected:
   explicit KernelLoaderSpec(absl::string_view kernelname);
@@ -81,7 +81,7 @@ class KernelLoaderSpec {
  private:
   // The kernel name that should be loaded out of the program description given
   // above.
-  string kernelname_;
+  std::string kernelname_;
 
   SE_DISALLOW_COPY_AND_ASSIGN(KernelLoaderSpec);
 };
@@ -94,7 +94,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec {
   ~OnDiskKernelLoaderSpec() override {}
 
   // Returns the path to the on-disk loadable kernel file.
-  const string &filename() const { return filename_; }
+  const std::string &filename() const { return filename_; }
 
   // Returns the canonical suffix for this on-disk kernel loader spec format;
   // e.g. PTX files on disk have a canonical suffix of ".ptx".
@@ -104,7 +104,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec {
   OnDiskKernelLoaderSpec(absl::string_view filename,
                          absl::string_view kernelname);
 
-  string filename_;
+  std::string filename_;
 
  private:
   SE_DISALLOW_COPY_AND_ASSIGN(OnDiskKernelLoaderSpec);
@@ -128,12 +128,12 @@ class CudaCubinOnDisk : public OnDiskKernelLoaderSpec {
   CudaCubinOnDisk(absl::string_view filename, absl::string_view kernelname);
   ~CudaCubinOnDisk() override {}
 
-  const string &filename() const { return filename_; }
+  const std::string &filename() const { return filename_; }
 
   const char *CanonicalSuffix() const override { return ".cubin"; }
 
  private:
-  string filename_;
+  std::string filename_;
 
   SE_DISALLOW_COPY_AND_ASSIGN(CudaCubinOnDisk);
 };
@@ -192,7 +192,7 @@ class CudaPtxInMemory : public KernelLoaderSpec {
                             int compute_capability_minor) const;
 
   // Decompresses the PTX string using bzip2.
-  static string DecompressPtx(const char *ptx);
+  static std::string DecompressPtx(const char *ptx);
 
  private:
   // PTX translation unit text contents in memory. The key is a tuple
@@ -205,7 +205,7 @@ class CudaPtxInMemory : public KernelLoaderSpec {
 
   // Stores all decompressed ptx strings, with the original ptx strings as keys.
   // It is marked as mutable for lazy decompression.
-  mutable std::map<const char *, string> decompressed_ptx_;
+  mutable std::map<const char *, std::string> decompressed_ptx_;
   mutable absl::Mutex mu_;
 
   // Defines the minimum compute capability possible. Used when PTX has no
@@ -246,11 +246,11 @@ class OpenCLTextInMemory : public KernelLoaderSpec {
   ~OpenCLTextInMemory() override {}
 
   // Returns the OpenCL text contents.
-  const string &text() const { return text_; }
+  const std::string &text() const { return text_; }
 
  private:
   // OpenCL translation unit text contents in memory.
-  string text_;
+  std::string text_;
 
   SE_DISALLOW_COPY_AND_ASSIGN(OpenCLTextInMemory);
 };
diff --git a/tensorflow/stream_executor/launch_dim.h b/tensorflow/stream_executor/launch_dim.h
index 4a3c882d9f7..95643ecf9bd 100644
--- a/tensorflow/stream_executor/launch_dim.h
+++ b/tensorflow/stream_executor/launch_dim.h
@@ -56,7 +56,7 @@ struct ThreadDim : public Dim3D {
       : Dim3D(x, y, z) {}
 
   // Returns a string representation of the thread dimensionality.
-  string ToString() const {
+  std::string ToString() const {
     return absl::StrCat("ThreadDim{", x, ", ", y, ", ", z, "}");
   }
 };
@@ -68,7 +68,7 @@ struct BlockDim : public Dim3D {
       : Dim3D(x, y, z) {}
 
   // Returns a string representation of the block dimensionality.
-  string ToString() const {
+  std::string ToString() const {
     return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}");
   }
 };
diff --git a/tensorflow/stream_executor/lib/demangle.cc b/tensorflow/stream_executor/lib/demangle.cc
index adb6b4f2d11..8fb5db0777e 100644
--- a/tensorflow/stream_executor/lib/demangle.cc
+++ b/tensorflow/stream_executor/lib/demangle.cc
@@ -33,8 +33,8 @@ namespace port {
 // The API reference of abi::__cxa_demangle() can be found in
 // libstdc++'s manual.
 // https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-html-USERS-4.3/a01696.html
-string Demangle(const char *mangled) {
-  string demangled;
+std::string Demangle(const char *mangled) {
+  std::string demangled;
   int status = 0;
   char *result = nullptr;
 #if HAS_CXA_DEMANGLE
diff --git a/tensorflow/stream_executor/lib/demangle.h b/tensorflow/stream_executor/lib/demangle.h
index af16fa7d8cb..4c1007a1947 100644
--- a/tensorflow/stream_executor/lib/demangle.h
+++ b/tensorflow/stream_executor/lib/demangle.h
@@ -21,7 +21,7 @@ limitations under the License.
 namespace stream_executor {
 namespace port {
 
-string Demangle(const char* mangled);
+std::string Demangle(const char* mangled);
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h
index a5eb8ef1d43..f8ccac75e98 100644
--- a/tensorflow/stream_executor/lib/env.h
+++ b/tensorflow/stream_executor/lib/env.h
@@ -27,12 +27,12 @@ namespace port {
 using tensorflow::Env;
 using tensorflow::Thread;
 
-inline Status FileExists(const string& filename) {
+inline Status FileExists(const std::string& filename) {
   return Env::Default()->FileExists(filename);
 }
 
 inline Status FileExists(const absl::string_view& filename) {
-  return Env::Default()->FileExists(string(filename));
+  return Env::Default()->FileExists(std::string(filename));
 }
 
 }  // namespace port
diff --git a/tensorflow/stream_executor/lib/human_readable.h b/tensorflow/stream_executor/lib/human_readable.h
index 5e5525e6b5b..a8ede952fd1 100644
--- a/tensorflow/stream_executor/lib/human_readable.h
+++ b/tensorflow/stream_executor/lib/human_readable.h
@@ -28,7 +28,7 @@ namespace port {
 
 class HumanReadableNumBytes {
  public:
-  static string ToString(int64 num_bytes) {
+  static std::string ToString(int64 num_bytes) {
     if (num_bytes == std::numeric_limits<int64>::min()) {
       // Special case for the number whose negation is not representable.
       return "-8E";
diff --git a/tensorflow/stream_executor/lib/numbers.cc b/tensorflow/stream_executor/lib/numbers.cc
index b670c42ec84..96bd9c3868e 100644
--- a/tensorflow/stream_executor/lib/numbers.cc
+++ b/tensorflow/stream_executor/lib/numbers.cc
@@ -32,7 +32,7 @@ bool safe_strto32(const char* str, int32* value) {
 // Convert strings to 32-bit integer values.
 // Leading and trailing spaces are allowed.
 // Values may be rounded on over- and underflow.
-bool safe_strto32(const string& str, int32* value) {
+bool safe_strto32(const std::string& str, int32* value) {
   return port::safe_strto32(str.c_str(), value);
 }
 
diff --git a/tensorflow/stream_executor/lib/numbers.h b/tensorflow/stream_executor/lib/numbers.h
index 2f48281d2d6..15fecfbfca9 100644
--- a/tensorflow/stream_executor/lib/numbers.h
+++ b/tensorflow/stream_executor/lib/numbers.h
@@ -24,7 +24,7 @@ namespace port {
 // Convert strings to 32-bit integer values.
 // Leading and trailing spaces are allowed.
 // Values may be rounded on over- and underflow.
-bool safe_strto32(const string& str, int32* value);
+bool safe_strto32(const std::string& str, int32* value);
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc
index 47eedbc6a16..3388843cfa2 100644
--- a/tensorflow/stream_executor/lib/path.cc
+++ b/tensorflow/stream_executor/lib/path.cc
@@ -27,14 +27,14 @@ static bool IsAbsolutePath(absl::string_view path) {
 
 // For an array of paths of length count, append them all together,
 // ensuring that the proper path separators are inserted between them.
-string JoinPathImpl(std::initializer_list<absl::string_view> paths) {
-  string result;
+std::string JoinPathImpl(std::initializer_list<absl::string_view> paths) {
+  std::string result;
 
   for (absl::string_view path : paths) {
     if (path.empty()) continue;
 
     if (result.empty()) {
-      result = string(path);
+      result = std::string(path);
       continue;
     }
 
diff --git a/tensorflow/stream_executor/lib/path.h b/tensorflow/stream_executor/lib/path.h
index 902331b4273..17ede80d25c 100644
--- a/tensorflow/stream_executor/lib/path.h
+++ b/tensorflow/stream_executor/lib/path.h
@@ -25,7 +25,7 @@ namespace port {
 namespace internal {
 // TODO(rspringer): Move to cc/implementation file.
 // Not part of the public API.
-string JoinPathImpl(std::initializer_list<absl::string_view> paths);
+std::string JoinPathImpl(std::initializer_list<absl::string_view> paths);
 }  // namespace internal
 
 // Join multiple paths together.
@@ -47,7 +47,7 @@ string JoinPathImpl(std::initializer_list<absl::string_view> paths);
 // string path = file::JoinPath("/var/log", dirname, filename);
 // string path = file::JoinPath(FLAGS_test_srcdir, filename);
 template <typename... T>
-inline string JoinPath(const T&... args) {
+inline std::string JoinPath(const T&... args) {
   return internal::JoinPathImpl({args...});
 }
 
diff --git a/tensorflow/stream_executor/lib/process_state.cc b/tensorflow/stream_executor/lib/process_state.cc
index 5a351e7a8d5..3bf4c5aecf4 100644
--- a/tensorflow/stream_executor/lib/process_state.cc
+++ b/tensorflow/stream_executor/lib/process_state.cc
@@ -29,14 +29,14 @@ limitations under the License.
 namespace stream_executor {
 namespace port {
 
-string Hostname() {
+std::string Hostname() {
   char hostname[1024];
   gethostname(hostname, sizeof hostname);
   hostname[sizeof hostname - 1] = 0;
   return std::string(hostname);
 }
 
-bool GetCurrentDirectory(string* dir) {
+bool GetCurrentDirectory(std::string* dir) {
   size_t len = 128;
   std::unique_ptr<char[]> a(new char[len]);
   for (;;) {
diff --git a/tensorflow/stream_executor/lib/process_state.h b/tensorflow/stream_executor/lib/process_state.h
index 248218c759e..68f06f969e8 100644
--- a/tensorflow/stream_executor/lib/process_state.h
+++ b/tensorflow/stream_executor/lib/process_state.h
@@ -21,8 +21,8 @@ limitations under the License.
 namespace stream_executor {
 namespace port {
 
-string Hostname();
-bool GetCurrentDirectory(string* dir);
+std::string Hostname();
+bool GetCurrentDirectory(std::string* dir);
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc
index 56584e18920..16480b30789 100644
--- a/tensorflow/stream_executor/lib/statusor_test.cc
+++ b/tensorflow/stream_executor/lib/statusor_test.cc
@@ -147,19 +147,19 @@ TEST(StatusOr, TestMoveOnlyVector) {
 }
 
 TEST(StatusOr, TestMoveWithValuesAndErrors) {
-  StatusOr<string> status_or(string(1000, '0'));
-  StatusOr<string> value1(string(1000, '1'));
-  StatusOr<string> value2(string(1000, '2'));
-  StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
-  StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
+  StatusOr<std::string> status_or(std::string(1000, '0'));
+  StatusOr<std::string> value1(std::string(1000, '1'));
+  StatusOr<std::string> value2(std::string(1000, '2'));
+  StatusOr<std::string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
+  StatusOr<std::string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
 
   ASSERT_TRUE(status_or.ok());
-  EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie());
 
   // Overwrite the value in status_or with another value.
   status_or = std::move(value1);
   ASSERT_TRUE(status_or.ok());
-  EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie());
 
   // Overwrite the value in status_or with an error.
   status_or = std::move(error1);
@@ -174,23 +174,23 @@ TEST(StatusOr, TestMoveWithValuesAndErrors) {
   // Overwrite the error with a value.
   status_or = std::move(value2);
   ASSERT_TRUE(status_or.ok());
-  EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie());
 }
 
 TEST(StatusOr, TestCopyWithValuesAndErrors) {
-  StatusOr<string> status_or(string(1000, '0'));
-  StatusOr<string> value1(string(1000, '1'));
-  StatusOr<string> value2(string(1000, '2'));
-  StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
-  StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
+  StatusOr<std::string> status_or(std::string(1000, '0'));
+  StatusOr<std::string> value1(std::string(1000, '1'));
+  StatusOr<std::string> value2(std::string(1000, '2'));
+  StatusOr<std::string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
+  StatusOr<std::string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
 
   ASSERT_TRUE(status_or.ok());
-  EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie());
 
   // Overwrite the value in status_or with another value.
   status_or = value1;
   ASSERT_TRUE(status_or.ok());
-  EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie());
 
   // Overwrite the value in status_or with an error.
   status_or = error1;
@@ -205,13 +205,13 @@ TEST(StatusOr, TestCopyWithValuesAndErrors) {
   // Overwrite the error with a value.
   status_or = value2;
   ASSERT_TRUE(status_or.ok());
-  EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie());
 
   // Verify original values unchanged.
-  EXPECT_EQ(string(1000, '1'), value1.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '1'), value1.ValueOrDie());
   EXPECT_EQ("error1", error1.status().error_message());
   EXPECT_EQ("error2", error2.status().error_message());
-  EXPECT_EQ(string(1000, '2'), value2.ValueOrDie());
+  EXPECT_EQ(std::string(1000, '2'), value2.ValueOrDie());
 }
 
 TEST(StatusOr, TestDefaultCtor) {
diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc
index ad6437709c4..cbbe0654953 100644
--- a/tensorflow/stream_executor/multi_platform_manager.cc
+++ b/tensorflow/stream_executor/multi_platform_manager.cc
@@ -39,10 +39,10 @@ class MultiPlatformManagerImpl {
       LOCKS_EXCLUDED(mu_);
 
   port::StatusOr<Platform*> InitializePlatformWithName(
-      absl::string_view target, const std::map<string, string>& options)
-      LOCKS_EXCLUDED(mu_);
+      absl::string_view target,
+      const std::map<std::string, std::string>& options) LOCKS_EXCLUDED(mu_);
   port::StatusOr<Platform*> InitializePlatformWithId(
-      const Platform::Id& id, const std::map<string, string>& options)
+      const Platform::Id& id, const std::map<std::string, std::string>& options)
       LOCKS_EXCLUDED(mu_);
 
   port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
@@ -66,13 +66,13 @@ class MultiPlatformManagerImpl {
   absl::Mutex mu_;
   std::vector<std::unique_ptr<Listener>> listeners_ GUARDED_BY(mu_);
   absl::flat_hash_map<Platform::Id, Platform*> id_map_ GUARDED_BY(mu_);
-  absl::flat_hash_map<string, Platform*> name_map_ GUARDED_BY(mu_);
+  absl::flat_hash_map<std::string, Platform*> name_map_ GUARDED_BY(mu_);
 };
 
 port::Status MultiPlatformManagerImpl::RegisterPlatform(
     std::unique_ptr<Platform> platform) {
   CHECK(platform != nullptr);
-  string key = absl::AsciiStrToLower(platform->Name());
+  std::string key = absl::AsciiStrToLower(platform->Name());
   absl::MutexLock lock(&mu_);
   if (name_map_.find(key) != name_map_.end()) {
     return port::Status(port::error::INTERNAL,
@@ -118,7 +118,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
 }
 
 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
-    absl::string_view target, const std::map<string, string>& options) {
+    absl::string_view target,
+    const std::map<std::string, std::string>& options) {
   absl::MutexLock lock(&mu_);
 
   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
@@ -134,7 +135,7 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
 }
 
 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
-    const Platform::Id& id, const std::map<string, string>& options) {
+    const Platform::Id& id, const std::map<std::string, std::string>& options) {
   absl::MutexLock lock(&mu_);
 
   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
@@ -224,13 +225,14 @@ MultiPlatformManagerImpl& Impl() {
 
 /*static*/ port::StatusOr<Platform*>
 MultiPlatformManager::InitializePlatformWithName(
-    absl::string_view target, const std::map<string, string>& options) {
+    absl::string_view target,
+    const std::map<std::string, std::string>& options) {
   return Impl().InitializePlatformWithName(target, options);
 }
 
 /*static*/ port::StatusOr<Platform*>
 MultiPlatformManager::InitializePlatformWithId(
-    const Platform::Id& id, const std::map<string, string>& options) {
+    const Platform::Id& id, const std::map<std::string, std::string>& options) {
   return Impl().InitializePlatformWithId(id, options);
 }
 
diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h
index 6e6617a6da9..556015de790 100644
--- a/tensorflow/stream_executor/multi_platform_manager.h
+++ b/tensorflow/stream_executor/multi_platform_manager.h
@@ -111,10 +111,12 @@ class MultiPlatformManager {
   // Ownership of the platform is NOT transferred to the caller --
   // the MultiPlatformManager owns the platforms in a singleton-like fashion.
   static port::StatusOr<Platform*> InitializePlatformWithName(
-      absl::string_view target, const std::map<string, string>& options);
+      absl::string_view target,
+      const std::map<std::string, std::string>& options);
 
   static port::StatusOr<Platform*> InitializePlatformWithId(
-      const Platform::Id& id, const std::map<string, string>& options);
+      const Platform::Id& id,
+      const std::map<std::string, std::string>& options);
 
   // Retrieves the platforms for which the given filter returns true.
   // Returned Platforms are always initialized.
diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc
index 9c995814386..fce9a1c0cd2 100644
--- a/tensorflow/stream_executor/platform.cc
+++ b/tensorflow/stream_executor/platform.cc
@@ -24,7 +24,7 @@ limitations under the License.
 
 namespace stream_executor {
 
-string PlatformKindString(PlatformKind kind) {
+std::string PlatformKindString(PlatformKind kind) {
   switch (kind) {
     case PlatformKind::kCuda:
       return "CUDA";
@@ -41,7 +41,7 @@ string PlatformKindString(PlatformKind kind) {
   }
 }
 
-PlatformKind PlatformKindFromString(string kind) {
+PlatformKind PlatformKindFromString(std::string kind) {
   for (int i = 0; i < static_cast<int>(PlatformKind::kSize); ++i) {
     if (kind == PlatformKindString(static_cast<PlatformKind>(i))) {
       return static_cast<PlatformKind>(i);
@@ -91,7 +91,7 @@ Platform::~Platform() {}
 bool Platform::Initialized() const { return true; }
 
 port::Status Platform::Initialize(
-    const std::map<string, string> &platform_options) {
+    const std::map<std::string, std::string> &platform_options) {
   if (!platform_options.empty()) {
     return port::Status(port::error::UNIMPLEMENTED,
                         "this platform does not support custom initialization");
diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h
index aefb94f6192..8f80ee837d0 100644
--- a/tensorflow/stream_executor/platform.h
+++ b/tensorflow/stream_executor/platform.h
@@ -60,11 +60,11 @@ bool PlatformIsRunnable(PlatformKind kind);
 bool PlatformIsRunnableOnDevice(PlatformKind kind);
 
 // Returns a printable description of a PlatformKind.
-string PlatformKindString(PlatformKind kind);
+std::string PlatformKindString(PlatformKind kind);
 
 // Returns the PlatformKind corresponding to the input string; returns kInvalid
 // in the case of no match.
-PlatformKind PlatformKindFromString(string platform_string);
+PlatformKind PlatformKindFromString(std::string platform_string);
 
 // Checks that kind takes on a valid value.
 void CheckPlatformKindIsValid(PlatformKind kind);
@@ -114,7 +114,7 @@ class Platform {
   virtual Id id() const = 0;
 
   // Name of this platform.
-  virtual const string& Name() const = 0;
+  virtual const std::string& Name() const = 0;
 
   // Returns the number of devices accessible on this platform.
   //
@@ -133,7 +133,7 @@ class Platform {
   // MultiPlatformManager, this method will be called automatically by
   // InitializePlatformWithId/InitializePlatformWithName.
   virtual port::Status Initialize(
-      const std::map<string, string>& platform_options);
+      const std::map<std::string, std::string>& platform_options);
 
   // Returns a populated DeviceDescription for the device at the given ordinal.
   // This should not require device initialization. Note that not all platforms
diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc
index 1e6a2d4f2a9..4f85f92ca1c 100644
--- a/tensorflow/stream_executor/plugin_registry.cc
+++ b/tensorflow/stream_executor/plugin_registry.cc
@@ -27,7 +27,7 @@ namespace stream_executor {
 const PluginId kNullPlugin = nullptr;
 
 // Returns the string representation of the specified PluginKind.
-string PluginKindString(PluginKind plugin_kind) {
+std::string PluginKindString(PluginKind plugin_kind) {
   switch (plugin_kind) {
     case PluginKind::kBlas:
       return "BLAS";
@@ -70,7 +70,7 @@ void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
 
 template <typename FACTORY_TYPE>
 port::Status PluginRegistry::RegisterFactoryInternal(
-    PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
+    PluginId plugin_id, const std::string& plugin_name, FACTORY_TYPE factory,
     std::map<PluginId, FACTORY_TYPE>* factories) {
   absl::MutexLock lock{&GetPluginRegistryMutex()};
 
@@ -110,7 +110,7 @@ bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
   if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
     port::StatusOr<Platform*> status =
         MultiPlatformManager::PlatformWithId(platform_id);
-    string platform_name = "<unregistered platform>";
+    std::string platform_name = "<unregistered platform>";
     if (status.ok()) {
       platform_name = status.ValueOrDie()->Name();
     }
@@ -194,7 +194,7 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
                                                                               \
   template <>                                                                 \
   port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
-      Platform::Id platform_id, PluginId plugin_id, const string& name,       \
+      Platform::Id platform_id, PluginId plugin_id, const std::string& name,  \
       PluginRegistry::FACTORY_TYPE factory) {                                 \
     return RegisterFactoryInternal(plugin_id, name, factory,                  \
                                    &factories_[platform_id].FACTORY_VAR);     \
@@ -202,7 +202,8 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
                                                                               \
   template <>                                                                 \
   port::Status PluginRegistry::RegisterFactoryForAllPlatforms<                \
-      PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name,   \
+      PluginRegistry::FACTORY_TYPE>(PluginId plugin_id,                       \
+                                    const std::string& name,                  \
                                     PluginRegistry::FACTORY_TYPE factory) {   \
     return RegisterFactoryInternal(plugin_id, name, factory,                  \
                                    &generic_factories_.FACTORY_VAR);          \
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 7e4407d4b27..2ea39c9cafc 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -62,13 +62,13 @@ class PluginRegistry {
   // with that platform (but execution should be otherwise unaffected).
   template <typename FactoryT>
   port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
-                               const string& name, FactoryT factory);
+                               const std::string& name, FactoryT factory);
 
   // Registers the specified factory as usable by _all_ platform types.
   // Reports errors just as RegisterFactory.
   template <typename FactoryT>
   port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
-                                              const string& name,
+                                              const std::string& name,
                                               FactoryT factory);
 
   // TODO(b/22689637): Setter for temporary mapping until all users are using
@@ -122,7 +122,7 @@ class PluginRegistry {
   // Actually performs the work of registration.
   template <typename FactoryT>
   port::Status RegisterFactoryInternal(PluginId plugin_id,
-                                       const string& plugin_name,
+                                       const std::string& plugin_name,
                                        FactoryT factory,
                                        std::map<PluginId, FactoryT>* factories);
 
@@ -155,7 +155,7 @@ class PluginRegistry {
   std::map<Platform::Id, DefaultFactories> default_factories_;
 
   // Lookup table for plugin names.
-  std::map<PluginId, string> plugin_names_;
+  std::map<PluginId, std::string> plugin_names_;
 
   SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
 };
@@ -164,7 +164,7 @@ class PluginRegistry {
 #define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE)                          \
   template <>                                                                 \
   port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
-      Platform::Id platform_id, PluginId plugin_id, const string& name,       \
+      Platform::Id platform_id, PluginId plugin_id, const std::string& name,  \
       PluginRegistry::FACTORY_TYPE factory);                                  \
   template <>                                                                 \
   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 4e430b11af5..c4564a613e1 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -35,55 +35,57 @@ namespace {
 // will be VLOG'ed. We need overloads, instead of
 // e.g. BatchDescriptorToVlogString(), as the code that calls these
 // functions does not know what the type of the parameter is.
-string ToVlogString(const dnn::BatchDescriptor &descriptor) {
+std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
   return descriptor.ToShortString();
 }
 
-string ToVlogString(const dnn::FilterDescriptor &descriptor) {
+std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
   return descriptor.ToShortString();
 }
 
-string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
+std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
   return descriptor.ToShortString();
 }
 
-string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
+std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
   return descriptor.ToShortString();
 }
 
-string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
+std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
   return descriptor.ToShortString();
 }
 
-string ToVlogString(dnn::ActivationMode mode) {
+std::string ToVlogString(dnn::ActivationMode mode) {
   return dnn::ActivationModeString(mode);
 }
 
-string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
+std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
   return algo_config.ToString();
 }
 
-string ToVlogString(dnn::ElementwiseOperation op) {
+std::string ToVlogString(dnn::ElementwiseOperation op) {
   return dnn::ElementwiseOperationString(op);
 }
 
-string ToVlogString(dnn::QuantizedActivationMode mode) {
+std::string ToVlogString(dnn::QuantizedActivationMode mode) {
   return dnn::QuantizedActivationModeString(mode);
 }
 
-string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
+std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
 
-string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
+std::string ToVlogString(blas::UpperLower ul) {
+  return blas::UpperLowerString(ul);
+}
 
-string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
+std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
 
-string ToVlogString(blas::Side s) { return blas::SideString(s); }
+std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
 
-string ToVlogString(blas::ComputationType ty) {
+std::string ToVlogString(blas::ComputationType ty) {
   return blas::ComputationTypeString(ty);
 }
 
-string ToVlogString(const void *ptr) {
+std::string ToVlogString(const void *ptr) {
   if (ptr == nullptr) {
     return "null";
   }
@@ -95,7 +97,7 @@ string ToVlogString(const void *ptr) {
 }
 
 template <class T>
-string ToVlogString(const std::complex<T> &c) {
+std::string ToVlogString(const std::complex<T> &c) {
   // StrCat does not convert std::complex to text.
   std::ostringstream out;
   out << c;
@@ -103,36 +105,36 @@ string ToVlogString(const std::complex<T> &c) {
 }
 
 template <class T>
-string ToVlogString(const std::function<T> &f) {
+std::string ToVlogString(const std::function<T> &f) {
   return f == nullptr ? "null" : "<non-null function>";
 }
 
-string ToVlogString(const DeviceMemoryBase &memory) {
+std::string ToVlogString(const DeviceMemoryBase &memory) {
   return ToVlogString(memory.opaque());
 }
 
-string ToVlogString(const DeviceMemoryBase *memory) {
+std::string ToVlogString(const DeviceMemoryBase *memory) {
   return memory == nullptr ? "null" : ToVlogString(*memory);
 }
 
-string ToVlogString(const Eigen::half &h) {
+std::string ToVlogString(const Eigen::half &h) {
   return absl::StrCat(static_cast<float>(h));
 }
 
-string ToVlogString(int i) { return absl::StrCat(i); }
+std::string ToVlogString(int i) { return absl::StrCat(i); }
 
-string ToVlogString(uint32 i) { return absl::StrCat(i); }
+std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
 
-string ToVlogString(uint64 i) { return absl::StrCat(i); }
+std::string ToVlogString(uint64 i) { return absl::StrCat(i); }
 
-string ToVlogString(int64 i) { return absl::StrCat(i); }
+std::string ToVlogString(int64 i) { return absl::StrCat(i); }
 
-string ToVlogString(float f) { return absl::StrCat(f); }
+std::string ToVlogString(float f) { return absl::StrCat(f); }
 
-string ToVlogString(double d) { return absl::StrCat(d); }
+std::string ToVlogString(double d) { return absl::StrCat(d); }
 
 template <typename T>
-string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
+std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
   if (memory_or_constant.is_pointer()) {
     return ToVlogString(memory_or_constant.pointer());
   }
@@ -140,8 +142,8 @@ string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
 }
 
 template <class T>
-string ToVlogString(port::ArraySlice<T> elements) {
-  string str = absl::StrCat(
+std::string ToVlogString(port::ArraySlice<T> elements) {
+  std::string str = absl::StrCat(
       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
       elements.size(), "]{");
   const char *separator = "";
@@ -166,11 +168,11 @@ string ToVlogString(port::ArraySlice<T> elements) {
 }
 
 template <class T>
-string ToVlogString(port::MutableArraySlice<T> elements) {
+std::string ToVlogString(port::MutableArraySlice<T> elements) {
   return ToVlogString(port::ArraySlice<T>(elements));
 }
 
-string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
+std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
   switch (depth_to_space_layout) {
     case dnn::DepthToSpaceLayout::DepthHeightWidth:
       return "DepthToSpaceLayout::DepthHeightWidth";
@@ -178,7 +180,7 @@ string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
   return "unknown DepthToSpaceLayout";
 }
 
-string ToVlogString(dnn::DataType data_type) {
+std::string ToVlogString(dnn::DataType data_type) {
   switch (data_type) {
     case dnn::DataType::kFloat:
       return "dnn::DataType::kFloat";
@@ -205,14 +207,14 @@ string ToVlogString(dnn::DataType data_type) {
 // See VLOG_CALL for a short-hand for this. This way of doing it saves
 // a tremendous amount of boilerplate code given how many functions
 // there are on Stream and how many parameters they each have.
-string CallStr(const char *function_name, Stream *stream,
-               std::vector<std::pair<const char *, string>> params) {
+std::string CallStr(const char *function_name, Stream *stream,
+                    std::vector<std::pair<const char *, std::string>> params) {
   // Do not call this function unless VLOG is on since just
   // constructing all the strings in params is expensive.
   CHECK(VLOG_IS_ON(1));
 
-  string str = absl::StrCat(stream->DebugStreamPointers(),
-                            " Called Stream::", function_name, "(");
+  std::string str = absl::StrCat(stream->DebugStreamPointers(),
+                                 " Called Stream::", function_name, "(");
   const char *separator = "";
   for (const auto &param : params) {
     absl::StrAppend(&str, separator, param.first, "=", param.second);
@@ -5470,7 +5472,7 @@ void Stream::RunAfterBlockHostUntilDoneCallbacks() {
   }
 }
 
-string Stream::DebugStreamPointers() const {
+std::string Stream::DebugStreamPointers() const {
   // Relies on the ToVlogString(const void*) overload above.
   return absl::StrCat("[stream=", ToVlogString(this),
                       ",impl=", ToVlogString(implementation_.get()), "]");
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index d5ff7ea206c..7e4f2e627ee 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -2014,7 +2014,7 @@ class Stream {
   internal::TemporaryMemoryManager *temporary_memory_manager();
 
   // Returns a debugging string "[stream=0x...,impl=0x...]".
-  string DebugStreamPointers() const;
+  std::string DebugStreamPointers() const;
 
  private:
   friend class host::HostBlas;  // for parent_.
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 793fdeb6d56..408b4fc8207 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -286,8 +286,9 @@ class StreamExecutorInterface {
   // If ModuleHandle is set then we search for `symbol_name` only within the
   // module corresponding to `module_handle`.  Otherwise all loaded modules are
   // searched.
-  virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
-                         void **mem, size_t *bytes) {
+  virtual bool GetSymbol(const std::string &symbol_name,
+                         ModuleHandle module_handle, void **mem,
+                         size_t *bytes) {
     return false;
   }
 
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 3b7f2b9760e..dc777c7a2d4 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -47,7 +47,7 @@ bool FLAGS_check_device_leaks = false;
 namespace stream_executor {
 namespace {
 
-string StackTraceIfVLOG10() {
+std::string StackTraceIfVLOG10() {
   if (VLOG_IS_ON(10)) {
     return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
   } else {
@@ -149,7 +149,7 @@ StreamExecutor::StreamExecutor(
       mem_alloc_bytes_(0),
       memory_limit_bytes_(GetMemoryLimitBytes()),
       allocator_(this) {
-  string name = absl::AsciiStrToLower(platform_->Name());
+  std::string name = absl::AsciiStrToLower(platform_->Name());
   if (name == "cuda") {
     platform_kind_ = PlatformKind::kCuda;
   } else if (name == "rocm") {
@@ -239,7 +239,7 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
   if (config != SharedMemoryConfig::kDefault &&
       config != SharedMemoryConfig::kFourByte &&
       config != SharedMemoryConfig::kEightByte) {
-    string error_msg = absl::StrFormat(
+    std::string error_msg = absl::StrFormat(
         "Invalid shared memory config specified: %d", static_cast<int>(config));
     LOG(ERROR) << error_msg;
     return port::Status(port::error::INVALID_ARGUMENT, error_msg);
@@ -492,7 +492,7 @@ DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
 }
 
 port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
-    const string &symbol_name, ModuleHandle module_handle) {
+    const std::string &symbol_name, ModuleHandle module_handle) {
   // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
   // be nullptr/0 for consistency with DeviceMemory semantics.
   void *opaque = nullptr;
@@ -515,7 +515,7 @@ port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
   }
 }
 
-bool StreamExecutor::GetSymbol(const string &symbol_name,
+bool StreamExecutor::GetSymbol(const std::string &symbol_name,
                                ModuleHandle module_handle, void **mem,
                                size_t *bytes) {
   return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 9fd78c54ebb..391ae52ebfd 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -50,7 +50,7 @@ struct AllocRecord {
   // Holds a representation of the stack at the time the associated buffer was
   // allocated. Produced in a form described in
   // //util/symbolize/symbolized_stacktrace.h.
-  string stack_trace;
+  std::string stack_trace;
 };
 
 // Forward declaration of private friend class.
@@ -175,12 +175,12 @@ class StreamExecutor {
   // If `module_handle` is set then searches only within the module
   // corresponding to `module_handle`.
   template <typename T>
-  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
+  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name,
                                             ModuleHandle module_handle = {});
 
   // An untyped version of GetSymbol.
   port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
-      const string &symbol_name, ModuleHandle module_handle = {});
+      const std::string &symbol_name, ModuleHandle module_handle = {});
 
   // Deallocate the DeviceMemory previously allocated via this interface.
   // Deallocation of a nullptr-representative value is permitted.
@@ -554,7 +554,7 @@ class StreamExecutor {
 
   // Finds and retrieves device memory for the symbol on the underlying
   // platform.
-  bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+  bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle,
                  void **mem, size_t *bytes);
 
   // Entrains a memcpy operation onto stream, with a host destination location
@@ -805,7 +805,7 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
 
 template <typename T>
 inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
-    const string &symbol_name, ModuleHandle module_handle) {
+    const std::string &symbol_name, ModuleHandle module_handle) {
   port::StatusOr<DeviceMemoryBase> untyped_symbol =
       GetUntypedSymbol(symbol_name, module_handle);
   if (!untyped_symbol.ok()) {
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 5671fdba5ac..22709278d0c 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -969,7 +969,7 @@ def tf_gen_op_wrapper_py(
             name = name + "_pygenrule",
             outs = [out],
             srcs = api_def_srcs + [hidden_file],
-            exec_tools = [tool_name] + tf_binary_additional_srcs(),
+            tools = [tool_name] + tf_binary_additional_srcs(),
             cmd = ("$(location " + tool_name + ") " + api_def_args_str +
                    " @$(location " + hidden_file + ") > $@"),
         )
@@ -978,7 +978,7 @@ def tf_gen_op_wrapper_py(
             name = name + "_pygenrule",
             outs = [out],
             srcs = api_def_srcs,
-            exec_tools = [tool_name] + tf_binary_additional_srcs(),
+            tools = [tool_name] + tf_binary_additional_srcs(),
             cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " +
                    op_list_arg + " " +
                    ("1" if op_list_is_whitelist else "0") + " > $@"),
@@ -2691,6 +2691,9 @@ def if_cuda_or_rocm(if_true, if_false = []):
         "//conditions:default": if_false,
     })
 
+def tf_monitoring_deps():
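+    """Returns extra deps for TF monitoring support; empty by default."""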
+    return []
+
 def tf_jit_compilation_passes_extra_deps():
     return []
 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt
index a2ea2343241..0a8e0b4421a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt
@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
index a38c4b21d56..b5ccb39d075 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt
index bdc09bcd84b..1a039b10501 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt
index f5ade9f86ba..7876166dc40 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
index c3b79911757..0101212e4cc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
index 8d5f6daff47..b2d9d4ee2cb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
@@ -37,7 +37,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "broadcast_to"
@@ -69,7 +69,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt
new file mode 100644
index 00000000000..c010134466c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt
index 879cceec3ac..9247db37925 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt
index a2ea2343241..0a8e0b4421a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt
@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
index a38c4b21d56..b5ccb39d075 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt
index bdc09bcd84b..1a039b10501 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt
index f5ade9f86ba..7876166dc40 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
index c3b79911757..0101212e4cc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
index c5e87b33070..f3fa80427a4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
@@ -20,7 +20,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"
@@ -32,7 +32,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt
new file mode 100644
index 00000000000..c010134466c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt
index 879cceec3ac..9247db37925 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh
index dcac9688bf0..9e4a59ccab3 100644
--- a/tensorflow/tools/ci_build/release/common.sh
+++ b/tensorflow/tools/ci_build/release/common.sh
@@ -103,25 +103,6 @@ function update_bazel_linux {
 # LINT.ThenChange(
 #   //tensorflow_estimator/google/kokoro/common.sh)
 
-# Install the given bazel version on macos
-function update_bazel_macos {
-  if [[ -z "$1" ]]; then
-    BAZEL_VERSION=${LATEST_BAZEL_VERSION}
-  else
-    BAZEL_VERSION=$1
-  fi
-  BAZEL_COMMAND="curl -L https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh -O && \
-  chmod +x bazel-*.sh && ./bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh --user && \
-  rm -f bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh"
-  # If the bazel update fails retry again in 60 seconds.
-  run_with_retry "${BAZEL_COMMAND}"
-  # Add new bazel installation to path
-  PATH="/Users/kbuilder/bin:$PATH"
-  set_bazel_outdir
-  which bazel
-  bazel version
-}
-
 function install_pip2 {
   curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
   sudo python2 get-pip.py
@@ -153,6 +134,7 @@ function install_pip_deps {
   ${SUDO_CMD} ${PIP_CMD} install astunparse==1.6.3
   ${SUDO_CMD} ${PIP_CMD} install keras_preprocessing==1.1.0 --no-deps
   "${PIP_CMD}" install numpy==1.16.0 --user
+  "${PIP_CMD}" install PyYAML==3.12 --user
   ${SUDO_CMD} ${PIP_CMD} install gast==0.3.3
   ${SUDO_CMD} ${PIP_CMD} install h5py==2.10.0
   ${SUDO_CMD} ${PIP_CMD} install six==1.12.0
@@ -193,6 +175,7 @@ function install_ubuntu_16_pip_deps {
   "${PIP_CMD}" install portpicker --user
   "${PIP_CMD}" install scipy --user
   "${PIP_CMD}" install scikit-learn --user
+  "${PIP_CMD}" install PyYAML==3.12 --user
   "${PIP_CMD}" install --user --upgrade tf-estimator-nightly
   "${PIP_CMD}" install --user --upgrade tb-nightly
   # LINT.ThenChange(:ubuntu_pip_installations)
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py b/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
index ec2914f9638..7a53fd87169 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
@@ -47,19 +47,12 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
 
     # List module renames. If changed, please update max_submodule_depth.
     self.import_renames = {
-        "tensorflow.google":
-            ast_edits.ImportRename(
-                "tensorflow.google.compat.v1",
-                excluded_prefixes=[
-                    "tensorflow.google.compat.v1",
-                    "tensorflow.google.compat.v2",
-                ],
-            ),
         "tensorflow":
             ast_edits.ImportRename(
                 "tensorflow.compat.v1",
                 excluded_prefixes=[
                     "tensorflow.contrib", "tensorflow.flags",
+                    "tensorflow.compat",
                     "tensorflow.compat.v1", "tensorflow.compat.v2",
                     "tensorflow.google"
                 ],
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
index b974a6963ae..eee946b0cb2 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
@@ -70,14 +70,12 @@ class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
 
   def testTensorFlowGoogleImport(self):
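+    # tensorflow.google imports should be left untouched.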
     text = "import tensorflow.google as tf"
-    expected_text = "import tensorflow.google.compat.v1 as tf"
     _, _, _, new_text = self._upgrade(text)
-    self.assertEqual(expected_text, new_text)
+    self.assertEqual(text, new_text)
 
     text = "import tensorflow.google"
-    expected_text = "import tensorflow.google.compat.v1"
     _, _, _, new_text = self._upgrade(text)
-    self.assertEqual(expected_text, new_text)
+    self.assertEqual(text, new_text)
 
     text = "import tensorflow.google.compat.v1 as tf"
     expected_text = "import tensorflow.google.compat.v1 as tf"
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index d8a45098b78..5d349da84bf 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -34,6 +34,7 @@ py_test(
     deps = [
         ":tf_doctest_lib",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/preprocessing",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/tools/docs/tf_doctest.py b/tensorflow/tools/docs/tf_doctest.py
index 2dce63541f5..19624659e37 100644
--- a/tensorflow/tools/docs/tf_doctest.py
+++ b/tensorflow/tools/docs/tf_doctest.py
@@ -28,6 +28,7 @@ import numpy as np
 
 import tensorflow.compat.v2 as tf
 
+from tensorflow.python.keras import preprocessing
 from tensorflow.tools.docs import tf_doctest_lib
 
 # We put doctest after absltest so that it picks up the unittest monkeypatch.
@@ -36,6 +37,9 @@ import doctest  # pylint: disable=g-bad-import-order
 
 tf.compat.v1.enable_v2_behavior()
 
+# Inject keras.preprocessing files into `tf.keras.preprocessing` namespace.
+tf.keras.preprocessing = preprocessing
+
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
diff --git a/tensorflow/tools/git/BUILD b/tensorflow/tools/git/BUILD
index 8a47f4c4c2d..c1f0577f33b 100644
--- a/tensorflow/tools/git/BUILD
+++ b/tensorflow/tools/git/BUILD
@@ -13,5 +13,4 @@ py_binary(
     srcs = ["gen_git_source.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    deps = ["@six_archive//:six"],
 )
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index 011406e2288..0cb1d006142 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -35,8 +35,6 @@ import os
 import shutil
 import subprocess
 
-import six
-
 
 def parse_branch_ref(filename):
   """Given a filename of a .git/HEAD file return ref path.
@@ -169,8 +167,8 @@ def get_git_version(git_base_path, git_tag_override):
         subprocess.check_output([
             "git",
             str("--git-dir=%s/.git" % git_base_path),
-            str("--work-tree=" + six.ensure_str(git_base_path)), "describe",
-            "--long", "--tags"
+            str("--work-tree=%s" % git_base_path), "describe", "--long",
+            "--tags"
         ]).strip())
     version_separator = b"-"
     if git_tag_override and val:
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 0a9e5d151e0..1479c0177e7 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -161,7 +161,6 @@ genrule(
         "@nasm//:LICENSE",
         "@nsync//:LICENSE",
         "@png//:LICENSE",
-        "@six_archive//:LICENSE",
         "@snappy//:COPYING",
         "@sobol_data//:LICENSE",
         "@zlib//:zlib.h",
@@ -244,7 +243,6 @@ genrule(
         "@nasm//:LICENSE",
         "@nsync//:LICENSE",
         "@png//:LICENSE",
-        "@six_archive//:LICENSE",
         "@snappy//:COPYING",
         "@sobol_data//:LICENSE",
         "@zlib//:zlib.h",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b8248399d5..682d022d691 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -201,11 +201,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
-        sha256 = "71905cca5553804beee85e9ab8b254931d3cbeda8df1a40e5af3773f5b657179",  # SHARED_EIGEN_SHA
-        strip_prefix = "eigen-3fda850c46e5e589668a85d89299433e0686eec9",
+        sha256 = "88e95180a7eae9acd3e79d2efeea1026eefad9f515a44418b63b189a1887108c",  # SHARED_EIGEN_SHA
+        strip_prefix = "eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/3fda850c46e5e589668a85d89299433e0686eec9/eigen-3fda850c46e5e589668a85d89299433e0686eec9.tar.gz",
-            "https://gitlab.com/libeigen/eigen/-/archive/3fda850c46e5e589668a85d89299433e0686eec9/eigen-3fda850c46e5e589668a85d89299433e0686eec9.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c/eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c.tar.gz",
+            "https://gitlab.com/libeigen/eigen/-/archive/52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c/eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c.tar.gz",
         ],
     )
 
@@ -305,12 +305,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "org_sqlite",
         build_file = clean_dep("//third_party:sqlite.BUILD"),
-        sha256 = "adf051d4c10781ea5cfabbbc4a2577b6ceca68590d23b58b8260a8e24cc5f081",
-        strip_prefix = "sqlite-amalgamation-3300100",
+        sha256 = "f3c79bc9f4162d0b06fa9fe09ee6ccd23bb99ce310b792c5145f87fbcc30efca",
+        strip_prefix = "sqlite-amalgamation-3310100",
         system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2019/sqlite-amalgamation-3300100.zip",
-            "https://www.sqlite.org/2019/sqlite-amalgamation-3300100.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2020/sqlite-amalgamation-3310100.zip",
+            "https://www.sqlite.org/2020/sqlite-amalgamation-3310100.zip",
         ],
     )
 
@@ -576,12 +576,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
 
     tf_http_archive(
         name = "com_github_nanopb_nanopb",
-        sha256 = "8bbbb1e78d4ddb0a1919276924ab10d11b631df48b657d960e0c795a25515735",
+        sha256 = "18234d9f01b57248472a9bfa65c3379352b5d66c15b0ef1c2b4feece4b5670fe",
         build_file = "@com_github_grpc_grpc//third_party:nanopb.BUILD",
-        strip_prefix = "nanopb-f8ac463766281625ad710900479130c7fcb4d63b",
+        strip_prefix = "nanopb-0.4.1",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/nanopb/nanopb/archive/f8ac463766281625ad710900479130c7fcb4d63b.tar.gz",
-            "https://github.com/nanopb/nanopb/archive/f8ac463766281625ad710900479130c7fcb4d63b.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/nanopb/nanopb/archive/0.4.1.tar.gz",
+            "https://github.com/nanopb/nanopb/archive/0.4.1.tar.gz",
         ],
     )
 
@@ -597,8 +597,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "78be61871704a451a5d9462d7e96ed6c982746d4"
-    LLVM_SHA256 = "35ce5950da1c83c91ab10ea5ac6e88c46fca160044cfd9dc6e50c83879d963ef"
+    LLVM_COMMIT = "fee41517fe0f7ff9f0e204dd9200ebf32ca03cb8"
+    LLVM_SHA256 = "dceb84396e8c30348dbd426c53eeae6657f5c67a24830c9a610a037fffcbe5cf"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/hexagon/workspace.bzl b/third_party/hexagon/workspace.bzl
index 79529f3fb9c..d60fcaf82ac 100644
--- a/third_party/hexagon/workspace.bzl
+++ b/third_party/hexagon/workspace.bzl
@@ -5,9 +5,9 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 def repo():
     third_party_http_archive(
         name = "hexagon_nn",
-        sha256 = "4cbf3c18834e24b1f64cc507f9c2f22b4fe576c6ff938d55faced5d8f1bddf62",
+        sha256 = "281d46b47f7191f03a8a4071c4c8d2af9409bb9d59573dc2e42f04c4fd61f1fd",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/storage.cloud.google.com/download.tensorflow.org/tflite/hexagon_nn_headers_v1.10.3.1.2.tgz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/storage.cloud.google.com/download.tensorflow.org/tflite/hexagon_nn_headers_v1.10.3.1.3.tgz",
         ],
         build_file = "//third_party/hexagon:BUILD",
     )
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index f8bc4cda891..a89838ebac9 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -575,7 +575,8 @@ gentbl(
     name = "amdgpu_isel_target_gen",
     tbl_outs = [
         ("-gen-global-isel", "lib/Target/AMDGPU/AMDGPUGenGlobalISel.inc"),
-        ("-gen-global-isel-combiner -combiners=AMDGPUPreLegalizerCombinerHelper", "lib/Target/AMDGPU/AMDGPUGenGICombiner.inc"),
+        ("-gen-global-isel-combiner -combiners=AMDGPUPreLegalizerCombinerHelper", "lib/Target/AMDGPU/AMDGPUGenPreLegalizeGICombiner.inc"),
+        ("-gen-global-isel-combiner -combiners=AMDGPUPostLegalizerCombinerHelper", "lib/Target/AMDGPU/AMDGPUGenPostLegalizeGICombiner.inc"),
     ],
     tblgen = ":llvm-tblgen",
     td_file = "lib/Target/AMDGPU/AMDGPUGISel.td",
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index cd213f4046a..5729afac05d 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -199,15 +199,17 @@ gentbl(
 
 cc_library(
     name = "LoopOpsTransforms",
-    srcs = ["lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp"],
+    srcs = glob(["lib/Dialect/LoopOps/Transforms/*.cpp"]),
     hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"],
     includes = ["include"],
     deps = [
+        ":AffineOps",
         ":IR",
         ":LoopOps",
         ":Pass",
         ":StandardOps",
         ":Transforms",
+        "@llvm-project//llvm:support",
     ],
     alwayslink = 1,
 )
@@ -216,7 +218,7 @@ filegroup(
     name = "StdOpsTdFiles",
     srcs = [
         "include/mlir/Analysis/CallInterfaces.td",
-        "include/mlir/Dialect/StandardOps/Ops.td",
+        "include/mlir/Dialect/StandardOps/IR/Ops.td",
         "include/mlir/IR/OpAsmInterface.td",
         ":OpBaseTdFiles",
     ],
@@ -228,23 +230,23 @@ gentbl(
     tbl_outs = [
         (
             "-gen-op-decls",
-            "include/mlir/Dialect/StandardOps/Ops.h.inc",
+            "include/mlir/Dialect/StandardOps/IR/Ops.h.inc",
         ),
         (
             "-gen-op-defs",
-            "include/mlir/Dialect/StandardOps/Ops.cpp.inc",
+            "include/mlir/Dialect/StandardOps/IR/Ops.cpp.inc",
         ),
         (
             "-gen-enum-decls",
-            "include/mlir/Dialect/StandardOps/OpsEnums.h.inc",
+            "include/mlir/Dialect/StandardOps/IR/OpsEnums.h.inc",
         ),
         (
             "-gen-enum-defs",
-            "include/mlir/Dialect/StandardOps/OpsEnums.cpp.inc",
+            "include/mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/StandardOps/Ops.td",
+    td_file = "include/mlir/Dialect/StandardOps/IR/Ops.td",
     td_srcs = [
         ":StdOpsTdFiles",
     ],
@@ -384,13 +386,13 @@ cc_library(
     name = "StandardOps",
     srcs = glob(
         [
-            "lib/Dialect/StandardOps/*.cpp",
-            "lib/Dialect/StandardOps/*.h",
+            "lib/Dialect/StandardOps/IR/*.cpp",
+            "lib/Dialect/StandardOps/IR/*.h",
             "lib/Dialect/StandardOps/EDSC/*.cpp",
         ],
     ),
     hdrs = glob([
-        "include/mlir/Dialect/StandardOps/*.h",
+        "include/mlir/Dialect/StandardOps/IR/*.h",
         "include/mlir/Dialect/StandardOps/EDSC/*.h",
     ]) + [
         "include/mlir/Analysis/CallInterfaces.h",
@@ -1371,11 +1373,13 @@ cc_library(
     includes = ["include"],
     deps = [
         ":AffineOps",
+        ":GPUDialect",
         ":LoopOps",
         ":LoopsToGPU",
         ":Pass",
         ":StandardOps",
         ":Support",
+        ":Transforms",
         "@llvm-project//llvm:support",
     ],
 )
@@ -1693,6 +1697,7 @@ cc_library(
         "@llvm-project//mlir/test:TestDialect",
         "@llvm-project//mlir/test:TestIR",
         "@llvm-project//mlir/test:TestPass",
+        "@llvm-project//mlir/test:TestSPIRV",
         "@llvm-project//mlir/test:TestTransforms",
     ],
 )
@@ -1820,6 +1825,7 @@ cc_binary(
         "@llvm-project//mlir/test:TestDialect",
         "@llvm-project//mlir/test:TestIR",
         "@llvm-project//mlir/test:TestPass",
+        "@llvm-project//mlir/test:TestSPIRV",
         "@llvm-project//mlir/test:TestTransforms",
     ],
 )
@@ -2519,7 +2525,7 @@ exports_files(
         "include/mlir/Analysis/CallInterfaces.h",
         "include/mlir/Analysis/CallInterfaces.td",
         "include/mlir/Dialect/LLVMIR/LLVMOpBase.td",
-        "include/mlir/Dialect/StandardOps/Ops.td",
+        "include/mlir/Dialect/StandardOps/IR/Ops.td",
         "include/mlir/IR/OpAsmInterface.td",
         "include/mlir/IR/OpBase.td",
         "include/mlir/Transforms/InliningUtils.h",
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index 1e89b553ac4..657d011254b 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -166,3 +166,16 @@ cc_library(
         "@llvm-project//mlir:VectorToLoops",
     ],
 )
+
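+# Test-only library built from the SPIR-V test sources in lib/Dialect/SPIRV.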
+cc_library(
+    name = "TestSPIRV",
+    srcs = glob([
+        "lib/Dialect/SPIRV/*.cpp",
+    ]),
+    deps = [
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SPIRVDialect",
+        "@llvm-project//mlir:SPIRVLowering",
+    ],
+)
diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD
index d93f0307690..a2ab4924f29 100644
--- a/third_party/snappy.BUILD
+++ b/third_party/snappy.BUILD
@@ -27,6 +27,10 @@ cc_library(
             "-Wno-implicit-function-declaration",
         ],
     }),
+    defines = select({
+        "@org_tensorflow//tensorflow:windows": [],
+        "//conditions:default": ["HAVE_SYS_UIO_H"],
+    }),
 )
 
 genrule(
diff --git a/third_party/toolchains/remote_config/configs.bzl b/third_party/toolchains/remote_config/configs.bzl
index 4945db280b6..973efb40af1 100644
--- a/third_party/toolchains/remote_config/configs.bzl
+++ b/third_party/toolchains/remote_config/configs.bzl
@@ -22,6 +22,18 @@ def initialize_rbe_configs():
         tensorrt_version = "5.1",
     )
 
+    tensorflow_rbe_config(
+        name = "ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0",
+        compiler = "/dt7/usr/bin/gcc",
+        compiler_prefix = "/usr/bin",
+        cuda_version = "10.1",
+        cudnn_version = "7",
+        os = "ubuntu16.04-manylinux2010",
+        python_version = "3",
+        tensorrt_install_path = "/usr",
+        tensorrt_version = "6.0",
+    )
+
     tensorflow_rbe_config(
         name = "ubuntu16.04-py3_opt-gcc5-rocm",
         compiler = "gcc",
diff --git a/third_party/toolchains/remote_config/containers.bzl b/third_party/toolchains/remote_config/containers.bzl
index 27e0cceee1c..26cb3ea3367 100644
--- a/third_party/toolchains/remote_config/containers.bzl
+++ b/third_party/toolchains/remote_config/containers.bzl
@@ -17,6 +17,13 @@ containers = {
         "digest": container_digests["cuda10.0-cudnn7-ubuntu16.04-manylinux2010"],
     },
 
+    # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010.
+    "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": {
+        "registry": "gcr.io",
+        "repository": "tensorflow-testing/nosla-cuda10.1-cudnn7-ubuntu16.04-manylinux2010",
+        "digest": container_digests["cuda10.1-cudnn7-ubuntu16.04-manylinux2010"],
+    },
+
     # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu16.04
     "rocm-ubuntu16.04": {
         "registry": "gcr.io",