diff --git a/ISSUES.md b/ISSUES.md
index aabd3158b39..a6c77f76950 100644
--- a/ISSUES.md
+++ b/ISSUES.md
@@ -1,7 +1,9 @@
-If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
-issue or a feature request or a build issue or a documentation issue (for small
-doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
-out. 3. The issue should be related to the repo it is created in.
+If you open a GitHub Issue, here is our policy:
+
+1.  It must be a bug/performance issue, a feature request, a build issue, or
+    a documentation issue (for small doc fixes please send a PR instead).
+1.  Make sure the Issue Template is filled out.
+1.  The issue should be related to the repo it is created in.
 
 **Here's why we have this policy:** We want to focus on the work that benefits
 the whole community, e.g., fixing bugs and adding features. Individual support
diff --git a/RELEASE.md b/RELEASE.md
index 237e03f11fa..9ccef55583a 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -54,6 +54,7 @@
     *   Corrected higher-order gradients of control flow constructs (`tf.cond`,
         `tf.while_loop`, and compositions like `tf.foldl`) computed with
         `tf.GradientTape` inside a `tf.function`.
+    *   Changed the default step size in
+        `gradient_checker_v2.compute_gradients` to be exactly representable
+        as a binary floating point number (powers of two are exactly
+        representable, while a decimal step such as 1e-3 is not). This avoids
+        needlessly polluting the gradient approximations, which in some cases
+        leads to false negatives in op gradient tests.
 
 *   `tf.summary`:
   *   New `tf.summary.graph` allows manual write of TensorFlow graph
@@ -65,6 +66,19 @@
     supported MSVC version to 16.4 (current: 16.8).
     *   See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
 
+*   TensorRT
+    *   Removed the deprecated `session_config` parameter for the TF1-TRT
+        converter `TrtGraphConverter`. Previously, we issued a warning when
+        the value of the parameter was not None.
+    *   The TF2-TRT converter `TrtGraphConverterV2` takes an object of class
+        `TrtConversionParams` as a parameter. Removed three deprecated fields
+        from this class: `rewriter_config_template`, `is_dynamic_op`, and
+        `max_batch_size`. Previously, we issued a warning when the value of
+        `rewriter_config_template` was not None, and an error when the value
+        of `is_dynamic_op` was not True. The value of `max_batch_size` was not
+        used for building TensorRT engines.
+    *   Issue a warning when the function `get_tensorrt_rewriter_config` is
+        used.
+
 ## Thanks to our Contributors
 
 This release contains contributions from many people at Google, as well as:
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 659bed6e6b8..5e78db99a52 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -72,6 +72,14 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+# Config setting that disables the default logger, only logging
+# to registered TFLogSinks.
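+# Enable with: bazel build --define=no_default_logger=true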
+config_setting(
+    name = "no_default_logger",
+    define_values = {"no_default_logger": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Config setting for determining if we are building for Android.
 config_setting(
     name = "android",
@@ -732,6 +740,7 @@ tf_cc_shared_object(
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
         "//tensorflow/c:kernels_hdrs",
+        "//tensorflow/c:logging",
         "//tensorflow/c:ops_hdrs",
         "//tensorflow/cc/saved_model:loader_lite_impl",
         "//tensorflow/core/common_runtime:core_cpu_impl",
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 08dd5d0820e..279e3108318 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -301,7 +301,7 @@ tf_cuda_cc_test(
     ],
     args = ["--heap_check=local"],
     linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags() + ["no_cuda_asan"],  # b/173654156
     deps = [
         ":c_api_experimental",
         ":c_api_unified_internal",
@@ -469,6 +469,7 @@ tf_cuda_cc_test(
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + [
         "nomac",
+        "no_cuda_asan",  # b/173825513
     ],
     deps = [
         ":abstract_tensor_handle",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 88705cf3058..00b45209edb 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -61,6 +61,7 @@ limitations under the License.
 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
+#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h"
 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -101,7 +102,14 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   if (opts->use_tfrt) {
 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
-    return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
+    tfrt::tf::ContextInterface* tfrt_context =
+        new tfrt::tf::ContextInterface(opts->async);
+#if !defined(IS_MOBILE_PLATFORM)
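+    // Attach a distributed manager so the TFRT context can take part in
+    // distributed execution (not available on mobile platforms).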
+    tfrt_context->SetDistributedManager(
+        std::make_unique<tfrt::tf::DistributedManagerContextInterface>(
+            tfrt_context->GetCoreRuntime()->GetHostContext()));
+#endif  // !IS_MOBILE_PLATFORM
+    return tensorflow::wrap(tfrt_context);
 #else
     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
     return nullptr;
diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc
index 58ffcf247cf..8adc51328fa 100644
--- a/tensorflow/c/eager/gradients.cc
+++ b/tensorflow/c/eager/gradients.cc
@@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
 
 // Helper functions which delegate to `AbstractOperation`, update
 // the state of the ForwardOperation and call the tape as appropriate.
-// These APIs are mainly to faciliate testing and are subject to change.
+// These APIs are mainly to facilitate testing and are subject to change.
 namespace internal {
 Status Reset(AbstractOperation* op_, const char* op,
              const char* raw_device_name, ForwardOperation* forward_op_) {
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index a0860da7b04..b54dcf942c7 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -39,7 +39,7 @@ struct XlaAutoJitFlag {
   int32 optimization_level_general;
 };
 
-// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
+// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
 // is:
 // <number>: sets general and single_gpu setting to the provided number.
 // single-gpu(<number>): sets the single_gpu setting to the provided number.
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index a549c99d9b7..baca8b99088 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -103,7 +103,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
   if (flr->config_proto()) {
     config_proto = *flr->config_proto();
   }
-  if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
+  MlirBridgeRolloutPolicy policy =
+      GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
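+  // Run the compilability checks below unless the user explicitly enabled
+  // the MLIR bridge.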
+  if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
     RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
     if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
       std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 0ebf3ac6bd1..dd6d3b3fbaf 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
 
   let results = (outs HLO_StaticShapeTensor);
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 
   let hasCustomHLOConverter = 1;
 }
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
index b3708bf4ff1..4613627c36d 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
@@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$input,
     Arg<LHLO_Buffer, "", [MemWrite]>:$output,
-    Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
     Arg<I32Buffer, "", [MemWrite]>:$info,
-    BoolAttr:$is_upper);
+    BoolAttr:$is_lower);
 }
 
 #endif // LHLO_GPU_OPS
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 0ee44f3202f..ff052b53098 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -197,10 +197,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO
 //===----------------------------------------------------------------------===//
 
 // TODO(b/139813999): specify required function signature in a type-safe way.
-def LHLO_ReduceOp: LHLO_Op<"reduce", [
-      SameVariadicOperandSize,
-      SingleBlockImplicitTerminator<"TerminatorOp">
-    ]>, BASE_HLO_ReduceOp {
+//
+// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
+// moving towards mhlo.ReturnOp, but some code that needs cleanup still
+// assumes lmhlo.TerminatorOp.
+// TODO(timshen): cleanup lmhlo.TerminatorOp.
+def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
   let arguments = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 389c5794c91..2b92afe956b 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
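+// Registers patterns that drop reshapes of broadcasts when the pair is an
+// identity (IdentityBroadcastReshape and IdentityBroadcastInDimReshape in
+// hlo_patterns.td).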
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+                                            MLIRContext* context) {
+  results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
+      context);
+}
+
 //===----------------------------------------------------------------------===//
 // Case Op
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td
index 776732b178b..01564b86381 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td
@@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
 def ShapeOfDynamicReshape : Pat<
   (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
   (replaceWithValue $shape)>;
+
+def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
+
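+// A reshape of a broadcast is an identity when the result type matches the
+// broadcast operand's type; in that case replace the pair with the operand.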
+def IdentityBroadcastReshape : Pat<
+  (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
+  (replaceWithValue $input),
+  [(HasSameType $input, $op)]>;
+
+def IdentityBroadcastInDimReshape : Pat<
+  (HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
+  (replaceWithValue $input),
+  [(HasSameType $input, $op)]>;
diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc
index 0751d2c626c..a29f0a628c4 100644
--- a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc
@@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   // mapValues always takes a function returning APInt, even when the output
   // is actually float.
   using func_type = llvm::APInt(const llvm::APInt&);
+
+  // TODO(hinsu): Correctly handle unsigned element types.
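+  // For i1 (bool) values, getSExtValue() would sign-extend true (1) to -1,
+  // so zero-extend bools instead to keep the expected 0/1 values.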
+  bool is_bool = old_type.isInteger(1);
   if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
     // Int -> Float
     return elements.mapValues(
-        new_type, llvm::function_ref<func_type>([&newFloatType](
+        new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
                                                     const llvm::APInt& intVal) {
-          llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
+          int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
+          llvm::APFloat newDouble(static_cast<double>(val));
           bool loses_info = false;
           newDouble.convert(newFloatType.getFloatSemantics(),
                             llvm::APFloat::rmNearestTiesToEven, &loses_info);
@@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   // new_type is Integer
   // Int -> Int
   return elements.mapValues(
-      new_type,
-      llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
-        return llvm::APInt(bit_width, intVal.getSExtValue());
+      new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
+                                                  const llvm::APInt& intVal) {
+        int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
+        return llvm::APInt(bit_width, val);
       }));
 }
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
index 8470f363fcb..41eedeeabe5 100644
--- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
@@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
   // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
   // CHECK-SAME: ]> : tensor<4x5xi32>
 }
+
+// CHECK-LABEL: @identity_broadcast_reshape
+func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = "mhlo.broadcast"(%arg0) {
+    broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
+  return %1 : tensor<128xf32>
+  // CHECK: return %arg0 : tensor<128xf32>
+}
+
+// CHECK-LABEL: @identity_broadcast_in_dim_reshape
+func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = "mhlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
+  return %1 : tensor<128xf32>
+  // CHECK: return %arg0 : tensor<128xf32>
+}
diff --git a/tensorflow/compiler/mlir/hlo/tests/convert.mlir b/tensorflow/compiler/mlir/hlo/tests/convert.mlir
index dab395c52cd..246cf415d27 100644
--- a/tensorflow/compiler/mlir/hlo/tests/convert.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/convert.mlir
@@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
 
 // -----
 
+// CHECK-LABEL: func @const_bool_f32
+func @const_bool_f32() -> tensor<2xf32> {
+  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
+  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
+  %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @const_bf16_int
 func @const_bf16_int() -> tensor<i16> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
@@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
 
 // -----
 
 // CHECK-LABEL: func @const_int_widening
 func @const_int_widening() -> tensor<i64> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
   %cst = mhlo.constant dense<42> : tensor<i32>
   %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
@@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
 
 // -----
 
+// CHECK-LABEL: func @const_bool_widening
+func @const_bool_widening() -> tensor<2xi32> {
+  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
+  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
+  %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @const_negative_int_widening
 func @const_negative_int_widening() -> tensor<i64> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
index 9e5ce67f39a..a939cab6d10 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
@@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
 func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
   %scratch = alloc() : memref<32xi8>
   %info = alloc() : memref<32xi32>
-  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
+  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
       : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
   return
 }
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 63f7a26899f..978d0bbbfa9 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -704,6 +704,7 @@ cc_library(
         ":convert_type",
         ":flatbuffer_tflite_operator_lib",
         ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:mangling_util",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index a3cd4a8ad51..2861c14b32b 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -64,6 +64,7 @@ limitations under the License.
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -149,7 +150,7 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
 
   int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
                             is_signed, storage_type.getWidth()) +
-                        is_weight_buffer;
+                        static_cast<int>(is_weight_buffer);
   int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
       is_signed, storage_type.getWidth());
   uint32_t flags =
@@ -177,12 +178,25 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
       quant_params.zero_point.at(0), storage_min, storage_max);
 }
 
+// Imports a float tensor with calibration values into a calibrated quantized
+// type.
+StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
+                                                   Builder builder) {
+  if (tensor.quantization == nullptr) {
+    return errors::InvalidArgument("The tensor is not quantized.");
+  }
+  auto raw_elem_type = ConvertElementType(tensor.type, builder);
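+  // Only the per-tensor calibration values (the first min/max entries) are
+  // used here.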
+  float min = tensor.quantization->min[0];
+  float max = tensor.quantization->max[0];
+  return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
+}
+
 // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
 // make that distinction and don't have to rely on context
 // (input to main and constants must have static shape)
 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                                          bool shapeless_are_scalars = false,
-                                         bool is_constant = false) {
+                                         bool is_constant = false,
+                                         bool is_intermediate = false) {
   mlir::Type elem_type = ConvertElementType(tensor.type, builder);
   // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
   // if it's set
@@ -191,6 +205,13 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                         GetQuantizedType(tensor, builder, is_constant));
   }
 
+  // Intermediate tensors with calibration values (but no scale and zero
+  // point) should return a calibrated quantized type.
+  if (is_intermediate && tensor.quantization != nullptr &&
+      !IsQuantized(tensor)) {
+    TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
+  }
+
   if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
     return RankedTensorType::get({}, elem_type);
   }
@@ -1033,7 +1054,7 @@ StatusOr<FuncOp> ConvertSubgraph(
       }
     }
 
-    // Intermediate tensors for tfl.lstm are used to carry quantization range
-    // in their types, so we only need and extract their types.
+    // Intermediate tensors for LSTMs are used to carry quantization ranges
+    // in their types, so we only need to extract their types.
     std::vector<mlir::TensorType> intermediate_types;
     intermediate_types.reserve(5);
@@ -1041,7 +1062,8 @@ StatusOr<FuncOp> ConvertSubgraph(
       TF_ASSIGN_OR_RETURN(
           auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
                                    /*shapeless_are_scalars=*/true,
-                                   /*is_constant=*/true));
+                                   /*is_constant=*/false,
+                                   /*is_intermediate=*/true));
       intermediate_types.emplace_back(type);
     }
 
@@ -1135,7 +1157,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
 
   auto builder = Builder(context);
 
-
   std::vector<std::string> func_names;
   for (auto& subgraph : model->subgraphs) {
     func_names.push_back(subgraph->name);
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 215812a6d1d..1ec6a4c28a7 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1978,6 +1978,10 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1);
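+  // A cast between identical element types is an identity; fold to the input.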
+  if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
+    return input();
+  }
+
   // For now, only supports cast between integer types.
   auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements_attr) {
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 51300978c39..d6ab33e0673 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -483,9 +483,9 @@ value of each element in `x`. For example, if x is an input element and y is
 an output element, this operation computes \\(y = |x|\\).
   }];
 
-  let arguments = (ins TFL_FpTensor:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
 
-  let results = (outs TFL_FpTensor:$y);
+  let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
 
   let hasFolder = 1;
 }
@@ -587,15 +587,15 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
 
   let arguments = (ins
     TFL_I32Tensor:$output_shape,
-    TFL_TensorOf<[F32, QI8, QUI8]>:$weights,
-    TFL_TensorOf<[F32, QI8, QUI8]>:$input,
-    TFL_TensorOfOrNone<[F32, QI32]>:$bias,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
+    TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias,
     TFL_PaddingAttr:$padding,
     Confined<I32Attr, [IntPositive]>:$stride_h,
     Confined<I32Attr, [IntPositive]>:$stride_w
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
 
   let hasOptions = 1;
 
@@ -624,7 +624,7 @@ def TFL_AveragePool2DOp:
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
     I32Attr:$filter_height,
     I32Attr:$filter_width,
     TFL_PaddingAttr:$padding,
@@ -633,7 +633,7 @@ def TFL_AveragePool2DOp:
     TFL_AFAttr:$fused_activation_function
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "Pool2DOptions";
@@ -947,7 +947,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
 
   let arguments = (ins
     TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
-    TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter,
     TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
 
     TFL_AFAttr:$fused_activation_function,
@@ -999,14 +999,14 @@ in the batch dimensions and broadcasting.
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, QI8]>:$x,
-    TFL_TensorOf<[F32, QI8]>:$y,
+    TFL_TensorOf<[F32, QI8, QI16]>:$x,
+    TFL_TensorOf<[F32, QI8, QI16]>:$y,
     DefaultValuedAttr<BoolAttr, "false">:$adj_x,
     DefaultValuedAttr<BoolAttr, "false">:$adj_y
   );
 
    let results = (outs
-    TFL_TensorOf<[F32, QI8]>:$output
+    TFL_TensorOf<[F32, QI8, QI16]>:$output
   );
 
   let hasOptions = 1;
@@ -1026,7 +1026,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params,
+    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
     TFL_TensorOf<[I32, I64]>:$indices,
     I32Attr:$axis
   );
@@ -1038,7 +1038,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   ];
 
   let results = (outs
-    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output
+    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output
   );
 
   let hasOptions = 1;
@@ -1750,12 +1750,12 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
+    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input,
     // Slope of the activation function at x < 0.
     F32Attr:$alpha
   );
 
-  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 0b1;
 }
@@ -1977,12 +1977,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
+    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
   );
 
   let builders = [TFL_BroadcastableBinaryBuilder];
@@ -2005,13 +2005,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input,
     TFL_TensorOf<[I32, I64]>:$axis,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2090,13 +2090,13 @@ equivalent to setting:
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32OrI64Tensor:$begin,
     TFL_I32OrI64Tensor:$size
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -2116,12 +2116,12 @@ def TFL_SumOp: TFL_Op<"sum", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
-  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2139,13 +2139,13 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2163,13 +2163,13 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2186,13 +2186,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2210,12 +2210,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
+    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
   );
 
   let builders = [TFL_BroadcastableBinaryBuilder];
@@ -2364,10 +2364,10 @@ def TFL_PadOp : TFL_Op<"pad", [
     ```
   }];
 
-  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32OrI64Tensor:$padding);
 
-  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
 }
@@ -2500,9 +2500,9 @@ def TFL_ReluOp: TFL_Op<"relu", [
       x -> max(0, x)
   }];
 
-  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x);
 
-  let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y);
 
   // This builder doesn't work with quantized type, so it can only be used by
   // non-quantization tablegen patterns. Currently, it is used by the
@@ -2828,11 +2828,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     F32Attr:$beta
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
 
@@ -3058,12 +3058,12 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input,
+    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input,
     TFL_TensorOf<[I32]>:$perm
   );
 
   let results = (outs
-    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output
+    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -3330,14 +3330,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
+    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
     TFL_I32Tensor:$size,
     BoolAttr:$align_corners,
     DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
+    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
   );
 
   let hasOptions = 1;
@@ -3830,6 +3830,9 @@ Ba et al. 'Layer Normalization'
 
     // Types of the optional intermediate tensors, which exist for fully
     // quantized LSTM op and hold the ranges of the intermediate tensors.
+    // The types of the intermediate tensors are quant.calibrated when
+    // imported, storing only the calibrated min and max values. The proper
+    // quantization spec is determined while going through the quantization
+    // passes.
     OptionalAttr<TypeAttr>:$input_to_input_intermediate,
     OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
     OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
@@ -3945,6 +3948,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
 
     // Types of the optional intermediate tensors, which exist for fully
     // quantized op and hold the ranges of the intermediate tensors.
+    // The types of the intermediate tensors are quant.calibrated when
+    // imported, storing only the calibrated min and max values. The proper
+    // quantization spec is determined while going through the quantization
+    // passes.
     OptionalAttr<TypeAttr>:$input_to_input_intermediate,
     OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
     OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 897e623b37b..015323df5e3 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -89,6 +89,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
   pass_config.lower_tensor_list_ops = true;
+  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
+  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
+  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
+    pass_config.unfold_batch_matmul = false;
+  }
 
   return internal::ConvertMLIRToTFLiteFlatBuffer(
       toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index e16dbdea02a..01f8070b37a 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -174,6 +174,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
   pass_config.lower_tensor_list_ops = true;
+  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
+  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
+  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
+    pass_config.unfold_batch_matmul = false;
+  }
 
   // TODO(b/153507667): Pass the session object when importing logic is removed.
   auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
index 7343853fea0..0e2f4906a7a 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
@@ -53,10 +53,11 @@ struct QuantizationSpecs {
   bool disable_per_channel = false;
 
   // When set to true, the fixed output ranges of the activation ops (tanh,
-  // sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
-  // emulation ops should be specified after the ops in the input graph. This
-  // flag should be set to false for post-training quantization.
-  bool disable_enforced_fixed_output_range = false;
+  // sigmoid, etc.) and the ranges of the weight constants are not inferred.
+  // Then, to quantize these ops, quantization emulation ops should be placed
+  // after the ops in the input graph. This flag should be set to false for
+  // post-training quantization.
+  bool disable_infer_tensor_range = false;
 
   // The node type when the model is exported. Currently this is limited to
   // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index 53674972f5b..63afab0cf12 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -100,13 +100,13 @@ class QuantizationDriver {
   explicit QuantizationDriver(FuncOp fn, bool is_signed,
                               bool disable_per_channel,
                               OpQuantSpecGetter op_quant_spec_getter,
-                              bool enforce_fixed_output_range)
+                              bool infer_tensor_range)
       : fn_(fn),
         builder_(fn.getBody()),
         is_signed_(is_signed),
         disable_per_channel_(disable_per_channel),
         op_quant_spec_getter_(op_quant_spec_getter),
-        enforce_fixed_output_range_(enforce_fixed_output_range) {}
+        infer_tensor_range_(infer_tensor_range) {}
 
   // The entry point of the quantization parameters propagation.
   void Run();
@@ -384,7 +384,9 @@ class QuantizationDriver {
 
   OpQuantSpecGetter op_quant_spec_getter_;
 
-  bool enforce_fixed_output_range_;
+  // Whether to infer output ranges for activation ops and constants. This is
+  // usually required for post-training quantization.
+  bool infer_tensor_range_;
 };
 }  // namespace
 
@@ -670,33 +672,43 @@ void QuantizationDriver::PreprocessConstantOps() {
 
     Value value = cst.getResult();
     builder_.setInsertionPoint(cst);
-    for (auto indexed_use : llvm::enumerate(value.getUses())) {
-      auto &use = indexed_use.value();
-      auto spec = GetQuantSpec(use.getOwner());
-      auto biases = spec->biases_params;
-      Operation *user = use.getOwner();
-      int operand_num = use.getOperandNumber();
 
+    // The following loop will change the value's uses, so we cache all the
+    // uses that need to be changed.
+    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
+    for (auto &use : value.getUses()) {
+      uses.push_back({use.getOwner(), use.getOperandNumber()});
+    }
+    for (auto indexed_use : llvm::enumerate(uses)) {
+      Operation *user = indexed_use.value().first;
+      int operand_num = indexed_use.value().second;
+
+      auto spec = GetQuantSpec(user);
+      auto biases = spec->biases_params;
+
+      // The quantization parameters of a `weight` shouldn't be determined by
+      // other values. So any constant that is not a bias, is not an operand
+      // of an op with same-scale requirements, and hasn't been quantized is
+      // a weight.
       if (biases.find(operand_num) == biases.end() &&
           !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
           !llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
-        // Needs to scan the content to get the quantization parameters if there
-        // are no quantization parameters (FakeQuant ops).
-        // For this case, the weight isn't duplicated.
+        // We need to scan the content of the weight to get the quantization
+        // parameters if none are annotated (e.g., by FakeQuant ops). In this
+        // case, the weight is not duplicated.
         weights_.insert(cst);
         auto affine_user =
             llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
-        if (affine_user &&
-            affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
+        if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
             affine_user.RequiredNarrowRangeAffineOperand()) {
           optimized_weights_.insert(
               {cst, affine_user.GetQuantizationDimIndex()});
         }
       } else {
-        // This is a bias, so the quantization parameter isn't determined by the
-        // local content. Same if the user can have quantization parameter
-        // propagated from other places.
-        // Duplicate this constant in case it is shared by different users.
+        // This is a bias or an operand of an op with same-scale requirements,
+        // so the quantization parameters are propagated from or determined by
+        // other values. Duplicate this constant in case it is shared by
+        // different users.
         if (indexed_use.index() > 0) {
           cst = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
         }
@@ -786,12 +798,14 @@ bool QuantizationDriver::PropagateParams() {
     quantized_.insert(op);
 
     if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
-      // If it isn't a weight or has been quantized, skip.
-      if (!IsWeight(cst) || IsQuantized(op)) continue;
-
-      // The quantization parameters are determined by the content of the
-      // constant.
-      changed |= SetConstantResultParams(op);
+      // If the workflow requires inferring ranges from the content
+      // (post-training quantization) and the constant is a weight (filter)
+      // that hasn't been quantized, we infer the quantization parameters
+      // from its content.
+      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
+        // The quantization parameters are determined by the content of the
+        // constant.
+        changed |= SetConstantResultParams(op);
+      }
       continue;
     }
 
@@ -826,7 +840,9 @@ bool QuantizationDriver::PropagateParams() {
 
     // TODO(fengliuai): make the bit width configurable.
     auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
-    if (restricted && enforce_fixed_output_range_) {
+    if (restricted && infer_tensor_range_) {
+      // Infer ranges from the activation ops. This is usually required for
+      // the post-training quantization workflow.
       // TODO(fengliuai): different result can have different fixed range.
       auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
       for (auto i = 0; i < op->getNumResults(); ++i) {
@@ -903,9 +919,9 @@ void QuantizationDriver::Run() {
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
                                         OpQuantSpecGetter op_quant_spec_getter,
-                                        bool post_training_quantization) {
+                                        bool infer_tensor_ranges) {
   QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
-                     post_training_quantization)
+                     infer_tensor_ranges)
       .Run();
 }
 
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 5463ba15e18..cb131453a7b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -490,13 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
 // and the propagation results are materialized by inserting pairs of quantize
 // and dequantize ops to this function. Set `disable_per_channel` to true to not
 // use per channel quantization even the op supports it.
-// Setting `enforce_fixed_output_range` to true, to infer quantization
-// parameters from the fixed output range ops. This is only used for
-// post-training quantization.
+// Set `infer_tensor_ranges` to true to infer quantization parameters from the
+// activation ops and weight constants. This is only used for post-training
+// quantization.
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
                                         OpQuantSpecGetter op_quant_spec_getter,
-                                        bool enforce_fixed_output_range);
+                                        bool infer_tensor_ranges);
 
 // The function might contain more stats ops than required, and it will
 // introduce requantize if the calibration stats have conflicts. This method
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index 27a7068cda6..88b084220e8 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -638,4 +638,10 @@ func @cast_ui8_to_i32() -> tensor<4xi32> {
 // CHECK:  return %[[CST]]
 }
 
+// CHECK-LABEL: @cast_identity
+func @cast_identity(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
+  %0 = "tfl.cast"(%arg0) : (tensor<7xf32>) -> tensor<7xf32>
+  return %0 : tensor<7xf32>
+  // CHECK: return %arg0 : tensor<7xf32>
+}
 
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json
new file mode 100644
index 00000000000..7052d289796
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json
@@ -0,0 +1,335 @@
+// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
+
+// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01, 5.000000e-01>>>
+// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00, 4.000000e+00>>>
+// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01, 1.600000e+01>>>
+// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01, 3.200000e+01>>>
+// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00, 1.000000e+00>>>
+
+{
+  "version": 3,
+  "operator_codes": [
+    {
+      "builtin_code": "UNIDIRECTIONAL_SEQUENCE_LSTM"
+    }
+  ],
+  "subgraphs": [
+    {
+      "tensors": [
+        {
+          "shape": [1, 5],
+          "name": "input0"
+        },
+        {
+          "shape": [2, 5],
+          "buffer": 1,
+          "name": "input2input_weights1"
+        },
+        {
+          "shape": [2, 5],
+          "buffer": 2,
+          "name": "input2forget_weights2"
+        },
+        {
+          "shape": [2, 5],
+          "buffer": 3,
+          "name": "input2cell_weights3"
+        },
+        {
+          "shape": [2, 5],
+          "buffer": 4,
+          "name": "input2output_weights4"
+        },
+        {
+          "shape": [2, 4],
+          "buffer": 5,
+          "name": "rec2input_weights5"
+        },
+        {
+          "shape": [2, 4],
+          "buffer": 6,
+          "name": "rec2forget_weights6"
+        },
+        {
+          "shape": [2, 4],
+          "buffer": 7,
+          "name": "rec2cell_weights7"
+        },
+        {
+          "shape": [2, 4],
+          "buffer": 8,
+          "name": "rec2output_weights8"
+        },
+        {
+          "shape": [2],
+          "buffer": 9,
+          "name": "cell2input_weights9"
+        },
+        {
+          "shape": [2],
+          "buffer": 10,
+          "name": "cell2forget_weights10"
+        },
+        {
+          "shape": [2],
+          "buffer": 11,
+          "name": "cell2output_weights11"
+        },
+        {
+          "shape": [2],
+          "buffer": 12,
+          "name": "input_gate_bias12"
+        },
+        {
+          "shape": [2],
+          "buffer": 13,
+          "name": "forget_gate_bias13"
+        },
+        {
+          "shape": [2],
+          "buffer": 14,
+          "name": "cell_gate_bias14"
+        },
+        {
+          "shape": [2],
+          "buffer": 15,
+          "name": "output_gate_bias15"
+        },
+        {
+          "shape": [4, 2],
+          "buffer": 16,
+          "name": "proj_weights16"
+        },
+        {
+          "shape": [4],
+          "buffer": 17,
+          "name": "proj_bias17"
+        },
+        {
+          "shape": [1, 4],
+          "name": "input_activation_state18",
+          "is_variable": true,
+          "quantization": {
+            "min": [-0.9],
+            "max": [0.9]
+          }
+        },
+        {
+          "shape": [1, 2],
+          "name": "input_cell_state19",
+          "is_variable": true,
+          "quantization": {
+            "min": [-0.8],
+            "max": [0.8]
+          }
+        },
+        {
+          "shape": [2],
+          "buffer": 18,
+          "name": "input_norm20"
+        },
+        {
+          "shape": [2],
+          "buffer": 19,
+          "name": "forget_norm21"
+        },
+        {
+          "shape": [2],
+          "buffer": 20,
+          "name": "cell_norm22"
+        },
+        {
+          "shape": [2],
+          "buffer": 21,
+          "name": "output_norm23"
+        },
+        {
+          "shape": [],
+          "name": "output24"
+        },
+        {
+          "shape": [],
+          "name": "intermediate_0",
+          "is_variable": true,
+          "quantization": {
+            "min": [-32],
+            "max": [32]
+          }
+        },
+        {
+          "shape": [],
+          "name": "intermediate_1",
+          "is_variable": true,
+          "quantization": {
+            "min": [-16],
+            "max": [16]
+          }
+        },
+        {
+          "shape": [],
+          "name": "intermediate_2",
+          "is_variable": true,
+          "quantization": {
+            "min": [-4],
+            "max": [4]
+          }
+        },
+        {
+          "shape": [],
+          "name": "intermediate_3",
+          "is_variable": true,
+          "quantization": {
+            "min": [-1.0],
+            "max": [1.0]
+          }
+        },
+        {
+          "shape": [],
+          "name": "intermediate_4",
+          "is_variable": true,
+          "quantization": {
+            "min": [-0.5],
+            "max": [0.5]
+          }
+        }
+      ],
+      "inputs": [0],
+      "outputs": [24],
+      "operators": [
+        {
+          "inputs": [
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
+          ],
+          "outputs": [24],
+          "intermediates": [
+            25, 26, 27, 28, 29
+          ],
+          "builtin_options_type": "UnidirectionalSequenceLSTMOptions",
+          "builtin_options": {
+            "fused_activation_function": "TANH",
+            "cell_clip": 50.0
+          },
+          "mutating_variable_inputs": [
+            false,
+            false, false, false, false,
+            false, false, false, false,
+            false, false, false,
+            false, false, false, false,
+            true, true,
+            false, false, false, false
+          ]
+        }
+      ]
+    }
+  ],
+  "buffers": [
+    {
+      "data": []
+    },
+    {
+      "data": [
+        36, 167, 168, 63, 0, 140, 72, 191, 120, 20, 147, 62, 20, 152, 196, 190, 121, 98, 82, 187, 95, 128, 213, 61, 189, 3, 138, 63, 54, 103, 13, 62, 46, 224, 66, 63, 157, 204, 180, 191
+      ]
+    },
+    {
+      "data": [
+        223, 20, 21, 64, 246, 166, 31, 191, 6, 51, 157, 188, 114, 90, 167, 62, 118, 240, 59, 63, 49, 162, 255, 62, 17, 91, 160, 63, 32, 47, 26, 63, 40, 136, 178, 191, 243, 154, 236, 61
+      ]
+    },
+    {
+      "data": [
+        137, 231, 86, 63, 41, 154, 16, 63, 239, 37, 77, 191, 55, 189, 24, 189, 86, 63, 18, 63, 42, 55, 13, 191, 110, 139, 138, 191, 219, 148, 181, 63, 71, 232, 108, 191, 66, 226, 145, 191
+      ]
+    },
+    {
+      "data": [
+        245, 179, 225, 190, 51, 202, 176, 189, 132, 47, 53, 191, 155, 25, 50, 191, 197, 130, 240, 191, 98, 125, 45, 62, 243, 70, 83, 62, 85, 155, 139, 63, 113, 239, 11, 192, 35, 251, 139, 62
+      ]
+    },
+    {
+      "data": [
+        248, 188, 211, 191, 142, 11, 73, 62, 36, 8, 84, 63, 186, 208, 11, 191, 76, 208, 190, 191, 223, 200, 210, 63, 183, 170, 103, 63, 116, 129, 145, 63
+      ]
+    },
+    {
+      "data": [
+        235, 202, 222, 190, 159, 201, 112, 191, 217, 248, 166, 63, 165, 199, 131, 191, 130, 59, 47, 63, 179, 11, 186, 62, 55, 168, 18, 192, 152, 213, 26, 64
+      ]
+    },
+    {
+      "data": [
+        245, 123, 138, 62, 213, 106, 231, 59, 211, 218, 250, 62, 25, 157, 134, 63, 147, 22, 164, 63, 25, 221, 139, 62, 1, 230, 247, 62, 210, 185, 142, 63
+      ]
+    },
+    {
+      "data": [
+        197, 123, 23, 192, 45, 96, 178, 190, 174, 87, 165, 62, 213, 225, 200, 191, 119, 248, 15, 191, 128, 125, 171, 189, 90, 125, 222, 63, 4, 76, 95, 62
+      ]
+    },
+    {
+      "data": [
+        210, 73, 183, 63, 248, 177, 13, 191
+      ]
+    },
+    {
+      "data": [
+        78, 251, 212, 191, 169, 29, 147, 63
+      ]
+    },
+    {
+      "data": [
+        178, 227, 203, 191, 247, 155, 103, 63
+      ]
+    },
+    {
+      "data": [
+        206, 111, 165, 190, 153, 77, 227, 63
+      ]
+    },
+    {
+      "data": [
+        255, 114, 132, 191, 253, 202, 140, 191
+      ]
+    },
+    {
+      "data": [
+        90, 247, 1, 192, 125, 120, 209, 191
+      ]
+    },
+    {
+      "data": [
+        65, 75, 243, 191, 58, 122, 146, 190
+      ]
+    },
+    {
+      "data": [
+        40, 135, 20, 63, 109, 50, 220, 191, 56, 241, 189, 63, 65, 12, 92, 63, 61, 14, 162, 62, 157, 138, 81, 63, 125, 61, 191, 61, 102, 231, 20, 63
+      ]
+    },
+    {
+      "data": [
+        145, 79, 49, 189, 175, 235, 220, 190, 182, 111, 157, 190, 144, 236, 97, 191
+      ]
+    },
+    {
+      "data": [
+        76, 188, 109, 63, 228, 150, 201, 190
+      ]
+    },
+    {
+      "data": [
+        6, 146, 66, 191, 122, 127, 100, 191
+      ]
+    },
+    {
+      "data": [
+        216, 59, 169, 190, 161, 178, 215, 191
+      ]
+    },
+    {
+      "data": [
+        208, 144, 101, 191, 127, 233, 195, 190
+      ]
+    }
+  ]
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir
index 07962e15c33..b2a721c947d 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir
@@ -20,20 +20,20 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
 
 func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
     %cst = constant unit
-    %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+    %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
     return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
 // CHECK-LABEL: testFullyQuantizedLSTM
 // CHECK: %cst = constant unit
 // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
-// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
 }
 
 // -----
 
 // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
 func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
-  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 2542c6c9af7..d991452b187 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -318,6 +318,15 @@ func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
 // CHECK:  "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
 }
 
+func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi1> {
+  %0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi64>) -> tensor<?xi1>
+  return %0 : tensor<?xi1>
+
+  // CHECK-LABEL: any_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
+}
+
 func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
@@ -972,6 +981,15 @@ func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32
   // CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }
 
+func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: sum_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
   %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -988,6 +1006,15 @@ func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
   // CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }
 
+func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: reduce_min_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
   %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -1004,6 +1031,15 @@ func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
   // CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }
 
+func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: reduce_max_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
   %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -1020,6 +1056,15 @@ func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens
   // CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }
 
+func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: reduce_prod_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
   %0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -1172,10 +1217,10 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
   %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
   return %0 : tensor<10x10xf32>
   // CHECK-LABEL: strided_slice_with_constant_attributes
-  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
-  // CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
-  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
-  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
+  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
+  // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
 }
 
 func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
@@ -1185,6 +1230,39 @@ func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tenso
   // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
 }
 
+func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> {
+  %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters
+  // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
+}
+
+func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> {
+  %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32>
+  return %0 : tensor<1x2x2x5xf32>
+  // CHECK-LABEL: strided_slice_with_i64_parameters
+  // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
+}
+
+func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> {
+  %cst = constant dense<-1> : tensor<1xi64>
+  %cst_1 = constant dense<0> : tensor<1xi64>
+  %cst_2 = constant dense<1> : tensor<1xi64>
+  %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32>
+  return %0 : tensor<10x10xf32>
+  // CHECK-LABEL: strided_slice_with_i64_constant_attributes
+  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
+  // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
+}
+
 func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
   %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
   return %0 : tensor<?x3x5xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index 53caf15bc8f..5407e8dfdae 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -625,6 +625,35 @@ func @QuantizeSharedBiases2(
 // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
 }
 
+// Make sure constants are duplicated for all users.
+// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers
+func @QuantizeSharedConstantsMultipleUsers(
+    %arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
+    %arg1: tensor<32x!quant.uniform<u8:f32, 2.0>>,
+    %arg2: tensor<32x!quant.uniform<u8:f32, 3.0>>,
+    %arg3: tensor<32x!quant.uniform<u8:f32, 4.0>>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) {
+  %cst = constant dense<0.0> : tensor<32xf32>
+  %0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
+  %1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform<u8:f32, 2.0>>) -> tensor<32xf32>
+  %2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<u8:f32, 3.0>>) -> tensor<32xf32>
+  %3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform<u8:f32, 4.0>>) -> tensor<32xf32>
+
+  %4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  %5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  %6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  %7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>
+
+// CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
+// CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
+// CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
+// CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
+// CHECK: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+// CHECK: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+}
+
 // Make sure quantization parameters are scanned from weight, but not from bias.
 // CHECK-LABEL: QuantizeWeight
 func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 53fc1bf88be..50a93ecf4b5 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -551,6 +551,19 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
   return %0 : tensor<8x4x16x1xf32>
 }
 
+func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> {
+  %cst = constant dense<-1> : tensor<1xi32>
+  %cst_1 = constant dense<0> : tensor<1xi32>
+  %cst_2 = constant dense<1> : tensor<1xi32>
+  %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
+  return %0 : tensor<10x10xf32>
+  // CHECK-LABEL: strided_slice_with_constant_attributes
+  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
+  // CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
+  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
+  // CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
+}
+
 func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
   %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
   return %0: tensor<3x3xf32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index 5b35c71b0bb..f4bc89b9801 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -328,23 +328,28 @@ def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2),
                        (TFL_MeanOp $arg0, $arg1, $arg2)>;
 
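+// TFL reduce ops expect i32 reduction indices; cast (possibly i64) axes to
+// i32 when legalizing the TF reduction ops below.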
 def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2),
-                      (TFL_SumOp $arg, $axes, $arg2)>;
+                      (TFL_SumOp $arg, (CreateTFCastToInt32Op $axes), $arg2)>;
 
 // TopK in TFL is always sorted so we ignore that attribute here.
 def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted),
                          (TFL_TopKV2Op $input, $k)>;
 
-def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2),
-                      (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
+def LegalizeMin : Pat<
+  (TF_MinOp $arg0, $axes, BoolAttr:$arg2),
+  (TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
 
-def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2),
-                      (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
+def LegalizeMax : Pat<
+  (TF_MaxOp $arg0, $axes, BoolAttr:$arg2),
+  (TFL_ReduceMaxOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
 
-def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2),
-                       (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
+def LegalizeProd : Pat<
+  (TF_ProdOp $arg0, $axes, BoolAttr:$arg2),
+  (TFL_ReduceProdOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
 
-def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
-                      (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
+def LegalizeAny : Pat<
+  (TF_AnyOp $input, $reduction_indices, $keep_dims),
+  (TFL_ReduceAnyOp $input, (CreateTFCastToInt32Op $reduction_indices),
+                   $keep_dims)>;
 
 def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
 
@@ -471,3 +476,14 @@ def LegalizeCumsum : Pat<
 def LegalizeReshape : Pat<
   (TF_ReshapeOp $input, $shape),
   (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;
+
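+// TFL's strided_slice requires i32 operands and i32 mask attributes, so cast
+// the begin/end/strides operands to i32 and narrow the i64 mask attributes.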
+def LegalizeStridedSlice : Pat<
+  (TF_StridedSliceOp
+    $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask,
+    $new_axis_mask, $shrink_axis_mask),
+  (TFL_StridedSliceOp $input,
+    (CreateTFCastToInt32Op $begin), (CreateTFCastToInt32Op $end),
+    (CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask),
+    (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
+    (convertIntAttrTo32Bit $new_axis_mask),
+    (convertIntAttrTo32Bit $shrink_axis_mask))>;
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index bf623976423..febbd7d2c83 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -148,7 +148,6 @@ DECL_CONVERT_OP(MatrixDiagV3);
 DECL_CONVERT_OP(Pack);
 DECL_CONVERT_OP(Split);
 DECL_CONVERT_OP(SplitV);
-DECL_CONVERT_OP(StridedSlice);
 DECL_CONVERT_OP(Unpack);
 DECL_CONVERT_OP(RandomUniform);
 
@@ -325,81 +324,6 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
   return success();
 }
 
-Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
-                                    Value attribute,
-                                    ArrayRef<int32_t> padding_val, int* mask) {
-  DenseIntElementsAttr dense_elem_attr;
-  SmallVector<int32_t, 8> padded_val;
-
-  auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
-  if (!ranked_attr_type ||
-      !matchPattern(attribute, m_Constant(&dense_elem_attr))) {
-    // If the input attribute is neither ranked type nor constant, we
-    // can't do any padding. Instead we just return it.
-    return attribute;
-  }
-  for (const auto& idx : dense_elem_attr.getIntValues()) {
-    padded_val.push_back(idx.getSExtValue());
-  }
-  auto attr_dim_count = ranked_attr_type.getShape()[0];
-  int full_dim_count = padding_val.size();
-  for (int i = attr_dim_count; i < full_dim_count; ++i) {
-    padded_val.push_back(padding_val[i]);
-    if (mask) *mask |= 1 << i;
-  }
-  auto type =
-      RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
-  auto attr = DenseElementsAttr::get<int32_t>(type, padded_val);
-  return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
-}
-
-LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
-  auto ranked_input_type =
-      tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
-  if (!ranked_input_type) {
-    // If input is not a ranked tensor, we can't deduce the padding dimensions
-    // from it, so we just do a plain conversion here.
-    rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
-        op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
-        tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
-        tf_strided_slice_op.strides(),
-        rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()),
-        rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()),
-        rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
-        rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
-        rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
-    return success();
-  }
-
-  int num_input_dims = ranked_input_type.getRank();
-  // Pad `begin` array with zero values and update the `begin_mask`.
-  SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
-  int begin_mask = tf_strided_slice_op.begin_mask();
-  Value padded_begin = PadStridedSliceAttributeArray(
-      op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
-  // Pad `end` array with `input_shape` and update the `end_mask`.
-  int end_mask = tf_strided_slice_op.end_mask();
-  auto input_shape = ranked_input_type.getShape();
-  SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
-  Value padded_end = PadStridedSliceAttributeArray(
-      op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
-  // Pad `strides` array with ones.
-  SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
-  Value padded_strides = PadStridedSliceAttributeArray(
-      op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
-  rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
-      op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
-      padded_begin, padded_end, padded_strides,
-      rewriter.getI32IntegerAttr(begin_mask),
-      rewriter.getI32IntegerAttr(end_mask),
-      rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
-      rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
-      rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
-  return success();
-}
-
 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_unpack_op = cast<TF::UnpackOp>(op);
@@ -769,8 +693,8 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
   patterns
       .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
               ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
-              ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
-              ConvertTFAssertOp, ConvertTFRandomUniformOp>(context);
+              ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
+              ConvertTFRandomUniformOp>(context);
 
   // Ophint python converter converted tf node pattern.
   patterns.insert<LegalizeUnidirectionalSequenceLstm,
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index a2af9b6e9d7..235117fcb92 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -376,14 +376,17 @@ void PrepareQuantizePass::runOnFunction() {
   OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
   int bit_width = quant_specs_.GetQuantizationTypeWidth();
-  bool quantization_aware_training_mode = ContainsQuantizeOps(func);
-  // Enforce fixed output range for post-training quantization and
-  // when the model has quantization emulation ops, unless it was disabled
-  // explicitly by the flag.
-  bool enforced_output_range =
-      (quant_specs_.post_training_quantization ||
-       quantization_aware_training_mode) &&
-      !quant_specs_.disable_enforced_fixed_output_range;
+  // When this is true, the quantizer will try its best to extract the
+  // quantization parameters from the op quantization property and constant
+  // content. This is also set to true when either the `quantize_allowlist` or
+  // `quantize_signed` test flag is set.
+  bool eager_quantize = ContainsQuantizeOps(func) ||
+                        (!quantize_allowlist.empty() || quantize_signed);
+  // Infer the tensor range for the activation ops and weight constants unless
+  // it is disabled explicitly.
+  bool infer_tensor_range =
+      (quant_specs_.post_training_quantization || eager_quantize) &&
+      !quant_specs_.disable_infer_tensor_range;
   if (is_signed) {
     patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
     // Convert quant stats to int8 quantization parameters.
@@ -403,7 +406,7 @@ void PrepareQuantizePass::runOnFunction() {
   // values (tensors).
   ApplyQuantizationParamsPropagation(
       func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
-      GetOpQuantSpec, enforced_output_range);
+      GetOpQuantSpec, infer_tensor_range);
 
   ConvertMlirQuantOpsToTFLQuantOps(func);
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index bb44788e961..f56c8bc0d06 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -712,6 +712,23 @@ struct ConvertTFStridedSlice : public RewritePattern {
     return success();
   }
 
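+  // Reads the integer values of `dense_elem_attr` into `val` and `padded_val`,
+  // then extends `padded_val` with `padding_val` up to the full input rank,
+  // setting the corresponding bit in `mask` (when provided) for every padded
+  // dimension.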
+  void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
+                                     SmallVectorImpl<int32_t> &val,
+                                     SmallVectorImpl<int32_t> &padded_val,
+                                     ArrayRef<int32_t> padding_val,
+                                     int *mask) const {
+    for (const auto &idx : dense_elem_attr.getIntValues()) {
+      val.push_back(idx.getSExtValue());
+      padded_val.push_back(idx.getSExtValue());
+    }
+    int attr_dim_count = val.size();
+    int full_dim_count = padding_val.size();
+    for (int i = attr_dim_count; i < full_dim_count; ++i) {
+      padded_val.push_back(padding_val[i]);
+      if (mask) *mask |= 1 << i;
+    }
+  }
+
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
@@ -719,17 +736,102 @@ struct ConvertTFStridedSlice : public RewritePattern {
     // Handle new axis mask.
     if (strided_slice_op.new_axis_mask() != 0) {
       // We currently don't handle simultaneous shrink_ and new_axis masks.
-      if (strided_slice_op.shrink_axis_mask()) {
-        return failure();
+      if (!strided_slice_op.shrink_axis_mask()) {
+        return RewriteNewAxisMask(strided_slice_op, rewriter);
       }
-      return RewriteNewAxisMask(strided_slice_op, rewriter);
     }
 
     // Handle ellipsis mask.
     if (strided_slice_op.ellipsis_mask() != 0) {
       return RewriteEllipsisMask(strided_slice_op, rewriter);
     }
-    return failure();
+
+    auto ranked_input_type =
+        strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
+    if (!ranked_input_type) {
+      return failure();
+    }
+
+    auto begin_attr = strided_slice_op.begin();
+    auto end_attr = strided_slice_op.end();
+    auto strides_attr = strided_slice_op.strides();
+
+    auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
+    auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
+    auto strides_attr_type =
+        strides_attr.getType().dyn_cast<RankedTensorType>();
+
+    DenseIntElementsAttr begin_elem_attr;
+    DenseIntElementsAttr end_elem_attr;
+    DenseIntElementsAttr strides_elem_attr;
+
+    if (!begin_attr_type ||
+        !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
+      return failure();
+    }
+    if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
+      return failure();
+    }
+    if (!strides_attr_type ||
+        !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
+      return failure();
+    }
+
+    SmallVector<int32_t, 4> begin, end, strides;
+    SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
+
+    int num_input_dims = ranked_input_type.getRank();
+    SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
+    auto input_shape = ranked_input_type.getShape();
+    SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
+    SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
+
+    int begin_mask = strided_slice_op.begin_mask();
+    int end_mask = strided_slice_op.end_mask();
+
+    PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
+                                  padding_begin, &begin_mask);
+    PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
+                                  &end_mask);
+    PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
+                                  padding_strides, nullptr);
+
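+    // If padding changed nothing and the masks are untouched, the op is
+    // already in canonical form; fail so this pattern does not keep matching
+    // its own output.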
+    if (begin == padded_begin && end == padded_end &&
+        strides == padded_strides &&
+        begin_mask == strided_slice_op.begin_mask() &&
+        end_mask == strided_slice_op.end_mask()) {
+      return failure();
+    }
+
+    auto begin_end_type =
+        RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
+    auto new_begin_attr = rewriter.create<ConstantOp>(
+        op->getLoc(), begin_end_type,
+        DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
+    auto new_end_attr = rewriter.create<ConstantOp>(
+        op->getLoc(), begin_end_type,
+        DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
+    auto strides_type =
+        RankedTensorType::get({static_cast<long>(padded_strides.size())},
+                              rewriter.getIntegerType(32));
+    auto new_strides_attr = rewriter.create<ConstantOp>(
+        op->getLoc(), strides_type,
+        DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
+
+    auto attribute_type = rewriter.getIntegerType(64);
+    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
+        op, strided_slice_op.output().getType(), strided_slice_op.input(),
+        new_begin_attr, new_end_attr, new_strides_attr,
+        rewriter.getIntegerAttr(attribute_type, begin_mask),
+        rewriter.getIntegerAttr(attribute_type, end_mask),
+        rewriter.getIntegerAttr(attribute_type,
+                                strided_slice_op.ellipsis_mask()),
+        rewriter.getIntegerAttr(attribute_type,
+                                strided_slice_op.new_axis_mask()),
+        rewriter.getIntegerAttr(attribute_type,
+                                strided_slice_op.shrink_axis_mask()));
+
+    return success();
   }
 };
 
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index bedb8d0c5ea..d26f09c5c91 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -337,6 +337,7 @@ cc_library(
         ":tensorflow",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:Dialect",
@@ -859,6 +860,7 @@ cc_library(
         "transforms/mark_ops_for_outside_compilation.cc",
         "transforms/materialize_mlir_passthrough_op.cc",
         "transforms/optimize.cc",
+        "transforms/outside_compiled_to_host_launch.cc",
         "transforms/parallel_execute_to_islands.cc",
         "transforms/parallelize_embedding_params_ops_pass.cc",
         "transforms/promote_resources_to_args.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index c8d1a83058f..76e8836fef1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -31,7 +31,7 @@ limitations under the License.
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
-def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes the absolute value of a tensor.";
 
   let description = [{
@@ -1002,13 +1002,13 @@ reverse of SpaceToBatch.  See below for a precise description.
     TF_Tensor:$output
   );
 
-  let verifier = [{
-    return Verify(*this);
-  }];
-
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
   TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
@@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
   let hasFolder = 1;
 }
 
-def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns element-wise smallest integer not less than x.";
 
   let arguments = (ins
@@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True,  True])
   }];
 
   let arguments = (ins
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
+    TF_Tensor:$x,
+    TF_Tensor:$y,
 
     DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
   );
@@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
   let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";
 
   let arguments = (ins
-    F32Tensor:$gradients,
-    F32Tensor:$inputs,
+    TF_Float32Tensor:$gradients,
+    TF_Float32Tensor:$inputs,
 
     DefaultValuedAttr<F32Attr, "-6.0f">:$min,
     DefaultValuedAttr<F32Attr, "6.0f">:$max,
@@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
   );
 
   let results = (outs
-    F32Tensor:$backprops
+    TF_Float32Tensor:$backprops
   );
 }
 
@@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
   let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
 
   let arguments = (ins
-    F32Tensor:$gradients,
-    F32Tensor:$inputs,
-    F32Tensor:$min,
-    F32Tensor:$max,
+    TF_Float32Tensor:$gradients,
+    TF_Float32Tensor:$inputs,
+    TF_Float32Tensor:$min,
+    TF_Float32Tensor:$max,
 
     DefaultValuedAttr<I64Attr, "8">:$num_bits,
     DefaultValuedAttr<BoolAttr, "false">:$narrow_range
   );
 
   let results = (outs
-    F32Tensor:$backprops_wrt_input,
-    F32Tensor:$backprop_wrt_min,
-    F32Tensor:$backprop_wrt_max
+    TF_Float32Tensor:$backprops_wrt_input,
+    TF_Float32Tensor:$backprop_wrt_min,
+    TF_Float32Tensor:$backprop_wrt_max
   );
 }
 
@@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
   ];
 }
 
-def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns element-wise largest integer not greater than x.";
 
   let arguments = (ins
@@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
   }];
 
   let arguments = (ins
-    F32Tensor:$predictions,
+    TF_Float32Tensor:$predictions,
     TF_I32OrI64Tensor:$targets,
     TF_I32OrI64Tensor:$k
   );
 
   let results = (outs
-    I1Tensor:$precision
+    TF_BoolTensor:$precision
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
@@ -6912,7 +6912,7 @@ retained with length 1.
   ];
 }
 
-def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
+def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
   let summary = "Performs max pooling on the input.";
 
   let arguments = (ins
@@ -6936,6 +6936,9 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
     SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
     SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
     LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+    // TF_LayoutSensitiveInterface:
+    StringRef GetOptimalLayout(const RuntimeDevices& devices);
+    LogicalResult UpdateDataFormat(StringRef data_format);
   }];
 }
 
@@ -7852,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
   }];
 
   let arguments = (ins
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
+    TF_Tensor:$x,
+    TF_Tensor:$y,
 
     DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
   );
@@ -8031,7 +8034,7 @@ times by rerunning "MakeIterator".
   );
 }
 
-def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns a tensor of ones with the same shape and type as x.";
 
   let arguments = (ins
@@ -8429,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
 Computes the QR decomposition of each inner matrix in `tensor` such that
 `tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
 
+Currently, the gradient for the QR decomposition is well-defined only when
+the first `P` columns of the inner matrix are linearly independent, where
+`P` is the minimum of `M` and `N`, the two innermost dimensions of `tensor`.
+
 ```python
 # a is a tensor.
 # q is a tensor of orthonormal matrices.
@@ -9114,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
   TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
 }
 
-def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
+def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
   let summary = "Computes rectified linear: `max(features, 0)`.";
 
   let description = [{
@@ -9140,7 +9147,7 @@ array([ 0.,  0., -0.,  3.], dtype=float32)
   }];
 }
 
-def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
 
   let arguments = (ins
@@ -10538,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns element-wise integer closest to x.";
 
   let description = [{
@@ -10605,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
   TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
 }
 
-def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = [{
 Rounds the values of a tensor to the nearest integer, element-wise.
   }];
@@ -11335,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns an element-wise indication of the sign of a number.";
 
   let description = [{
@@ -12345,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
 }
 
 def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
-  let summary = [{
-Picks the best counter-based RNG algorithm based on device.
-  }];
+  let summary = "Picks the best counter-based RNG algorithm based on device.";
 
   let description = [{
 This op picks the best counter-based RNG algorithm based on device.
   }];
 
+  let arguments = (ins);
+
   let results = (outs
     TF_Int32Tensor:$alg
   );
@@ -14088,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
 scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
 for the existing tensor cannot be re-used, a copy is made and updated.
 
-If `indices` contains duplicates, then their updates are accumulated (summed).
+If `indices` contains duplicates, then we pick the last update for the index.
 
-**WARNING**: The order in which updates are applied is nondeterministic, so the
-output will be nondeterministic if `indices` contains duplicates -- because
-of some numerical approximation issues, numbers summed in different order
-may yield different results.
+If an out-of-bounds index is found on CPU, an error is returned.
+
+**WARNING**: This operation has some GPU-specific semantics.
+- If an out-of-bounds index is found, the index is ignored.
+- The order in which updates are applied is nondeterministic, so the output
+will be nondeterministic if `indices` contains duplicates.
 
 `indices` is an integer tensor containing indices into a new tensor of shape
-`shape`.  The last dimension of `indices` can be at most the rank of `shape`:
+`shape`.
 
-    indices.shape[-1] <= shape.rank
+* `indices` must have at least 2 axes: `(num_updates, index_depth)`.
+* The last axis of `indices` is how deep to index into `tensor`, so this index
+  depth must be at most the rank of `tensor`: `indices.shape[-1] <= tensor.rank`
 
-The last dimension of `indices` corresponds to indices into elements
-(if `indices.shape[-1] = shape.rank`) or slices
-(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
-`shape`.  `updates` is a tensor with shape
+If `indices.shape[-1] = tensor.rank`, this Op indexes and updates scalar elements.
+If `indices.shape[-1] < tensor.rank`, it indexes and updates slices of the input
+`tensor`.
 
-    indices.shape[:-1] + shape[indices.shape[-1]:]
+Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
+The overall shape of `updates` is:
 
-The simplest form of scatter is to insert individual elements in a tensor by
-index. For example, say we want to insert 4 scattered elements in a rank-1
-tensor with 8 elements.
+```
+indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
+```
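+
+For example, scalar updates into a rank-1 tensor (`index_depth == tensor.rank`):
+
+```python
+>>> indices = tf.constant([[4], [3], [1], [7]])
+>>> updates = tf.constant([9, 10, 11, 12])
+>>> tensor = tf.ones([8], dtype=tf.int32)
+>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
+tf.Tensor([ 1 11  1 10  9  1  1 12], shape=(8,), dtype=int32)
+```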
 
-<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
-</div>
-
-In Python, this scatter operation would look like this:
-
-    >>> indices = tf.constant([[4], [3], [1], [7]])
-    >>> updates = tf.constant([9, 10, 11, 12])
-    >>> tensor = tf.ones([8], dtype=tf.int32)
-    >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
-    tf.Tensor([ 1 11  1 10  9  1  1 12], shape=(8,), dtype=int32)
-
-We can also, insert entire slices of a higher rank tensor all at once. For
-example, if we wanted to insert two slices in the first dimension of a
-rank-3 tensor with two matrices of new values.
-
-In Python, this scatter operation would look like this:
-
-    >>> indices = tf.constant([[0], [2]])
-    >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
-    ...                         [7, 7, 7, 7], [8, 8, 8, 8]],
-    ...                        [[5, 5, 5, 5], [6, 6, 6, 6],
-    ...                         [7, 7, 7, 7], [8, 8, 8, 8]]])
-    >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
-    >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
-    [[[5 5 5 5]
-      [6 6 6 6]
-      [7 7 7 7]
-      [8 8 8 8]]
-     [[1 1 1 1]
-      [1 1 1 1]
-      [1 1 1 1]
-      [1 1 1 1]]
-     [[5 5 5 5]
-      [6 6 6 6]
-      [7 7 7 7]
-      [8 8 8 8]]
-     [[1 1 1 1]
-      [1 1 1 1]
-      [1 1 1 1]
-      [1 1 1 1]]]
-
-Note that on CPU, if an out of bound index is found, an error is returned.
-On GPU, if an out of bound index is found, the index is ignored.
+For usage examples, see the Python [tf.tensor_scatter_nd_update](
+https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function.
   }];
 
   let arguments = (ins
@@ -15080,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
   }];
 
   let arguments = (ins
-    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
+    Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
     Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
     Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
 
@@ -15089,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
   );
 
   let results = (outs
-    TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
   );
 
   TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
@@ -15457,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns a tensor of zeros with the same shape and type as x.";
 
   let arguments = (ins
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 817dfc4cfe2..20a3a9d8a7e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another
 
     FlatSymbolRefAttr:$cond,
     FlatSymbolRefAttr:$body,
-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
     DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
 
     // Used to map StatelessWhile and While op defined in TensorFlow to a common
     // op.
-    BoolAttr:$is_stateless
+    BoolAttr:$is_stateless,
+
+    // In TensorFlow, While has a special behavior where if the `output_shapes`
+    // attribute is not empty, those shapes are used as result shapes in its
+    // shape function instead of propagating operand shapes as result shapes.
+    // This allows result shapes to differ from operand shapes. While these
+    // shapes are imported and set as a part of the result type, there is no
+    // indicator differentiating between having no output shapes and having
+    // all unranked shapes. Thus this attribute is used to determine which
+    // shape function behavior to use for this op: operand shapes are
+    // propagated as result shapes when the attribute is not set, and result
+    // shapes are preserved as is when it is set.
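+    // For example (illustrative), a shape-invariant While can return an
+    // unranked tensor for a ranked operand:
+    //   "tf.While"(%arg) {cond = @cond, body = @body, is_stateless = false,
+    //     shape_invariant} : (tensor<5xf32>) -> (tensor<*xf32>)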
+    UnitAttr:$shape_invariant
   );
 
   let results = (outs
@@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
   );
 
   TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
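+  // `output_shapes` is now derived from the op's result shapes instead of
+  // being stored as an explicit attribute; it is materialized when exporting
+  // back to a GraphDef.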
+  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
 
   let verifier = [{
     return Verify(*this);
@@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
   let arguments = (ins
     Variadic<AnyTensor>:$input,
 
-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
     DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
 
     // Used to map StatelessWhile and While op defined in TensorFlow to a common
     // op.
-    BoolAttr:$is_stateless
+    BoolAttr:$is_stateless,
+
+    // In TensorFlow, While has a special behavior where if the `output_shapes`
+    // attribute is not empty, those shapes are used as result shapes in its
+    // shape function instead of propagating operand shapes as result shapes.
+    // This allows result shapes to differ from operand shapes. While these
+    // shapes are imported and set as a part of the result type, there is no
+    // indicator differentiating between having no output shapes and having
+    // all unranked shapes. Thus this attribute is used to determine which
+    // shape function behavior to use for this op: operand shapes are
+    // propagated as result shapes when the attribute is not set, and result
+    // shapes are preserved as is when it is set.
+    UnitAttr:$shape_invariant
   );
   let results = (outs Variadic<AnyTensor>:$output);
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 9dde6c5769e..b48338e3ced 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -2467,6 +2467,33 @@ LogicalResult MaxPoolOp::FoldOperandsPermutation(
       permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
 }
 
+LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
+  StringRef src_data_format = data_format();
+
+  auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
+  if (perm.empty()) return failure();
+
+  // Update data_format attribute and result types.
+  if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
+    return failure();
+
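+  // Permute the layout-dependent attributes. `explicit_paddings` holds a
+  // (low, high) pair per dimension, hence the group size of 2 below.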
+  stridesAttr(ShuffleArrayAttr(strides(), perm));
+  explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
+  ksizeAttr(ShuffleArrayAttr(ksize(), perm));
+
+  return success();
+}
+
+StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
+  // Keep the current data format if no GPUs are available or if explicit
+  // placement does not allow using the GPU for this operation.
+  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
+    return data_format();
+
+  // Otherwise, default to NCHW.
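+  // cuDNN-backed GPU kernels generally perform best with NCHW.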
+  return "NCHW";
+}
+
 //===----------------------------------------------------------------------===//
 // MaxPoolGradOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
index afc9e1e51ed..b4738ed5605 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
@@ -41,3 +41,34 @@ func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32
   // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
   // CHECK: %[[V1]] : tensor<3x5x7xf32>
 }
+
+// CHECK-LABEL: @broadcast_eq
+func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
+  %cst = constant dense<[5, 7]> : tensor<2xi32>
+  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
+  %1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
+  return %1 : tensor<5x7xi1>
+  // CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
+  // CHECK: %[[V0]] : tensor<5x7xi1>
+}
+
+// CHECK-LABEL: @broadcast_neq
+func @broadcast_neq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
+  %cst = constant dense<[5, 7]> : tensor<2xi32>
+  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
+  %1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
+  return %1 : tensor<5x7xi1>
+  // CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
+  // CHECK: %[[V0]] : tensor<5x7xi1>
+}
+
+// CHECK-LABEL: @broadcast_both_operand
+func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> {
+  %cst = constant dense<[5, 7]> : tensor<2xi64>
+  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi64>) -> tensor<5x7xf32>
+  %1 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5x1xf32>, tensor<2xi64>) -> tensor<5x7xf32>
+  %2 = "tf.Add"(%0, %1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
+  return %2 : tensor<5x7xf32>
+  // CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
+  // CHECK: %[[V0]] : tensor<5x7xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
index dbfc42a11ae..890ca03b4a7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
 
 # Verify that TensorFlow While and StatelessWhile ops are mapped to the
 # composite While op in MLIR with is_stateless attribute set accordingly to
@@ -6,6 +6,7 @@
 
 # CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
 # CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")
 
 node {
   name: "StatefulWhile"
@@ -73,6 +74,51 @@ node {
   experimental_debug_info {
   }
 }
+node {
+  name: "WhileWithOutputShapes"
+  op: "While"
+  input: "iter"
+  input: "val"
+  attr {
+    key: "T"
+    value {
+      list {
+        type: DT_INT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "body"
+    value {
+      func {
+        name: "body"
+      }
+    }
+  }
+  attr {
+    key: "cond"
+    value {
+      func {
+        name: "cond"
+      }
+    }
+  }
+  attr {
+    key: "output_shapes"
+    value {
+      list {
+        shape {
+        }
+        shape {
+          unknown_rank: true
+        }
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
 node {
   name: "main"
   op: "_Retval"
@@ -107,6 +153,23 @@ node {
     }
   }
 }
+node {
+  name: "main2"
+  op: "_Retval"
+  input: "WhileWithOutputShapes:1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 2
+    }
+  }
+}
 node {
   name: "iter"
   op: "Placeholder"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
index 0034d3f4308..342804f1230 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
@@ -82,3 +82,20 @@ func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) ->
   %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32>
   return %0 : tensor<1x256x150x150xf32>
 }
+
+// CHECK-LABEL: maxpool_nchw
+func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> {
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]])
+  // CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}
+  // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
+  %0 = "tf.MaxPool"(%arg0)
+       {
+         data_format = "NCHW",
+         ksize = [1, 1, 3, 3],
+         padding = "SAME",
+         strides = [1, 1, 2, 2]
+       } : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
+  return %0 : tensor<1x64x56x56xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 4be65bd0a3b..311b84f0374 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1703,3 +1703,110 @@ func @convert_iota_3d() -> tensor<5x7x9xi32> {
   return %0 : tensor<5x7x9xi32>
 }
 
+// CHECK-LABEL:   func @convert_avgpool_valid(
+// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
+// CHECK:         }
+func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<f32>
+  %1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
+  %2 = "mhlo.reduce_window"(%arg0, %0) ( {
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%5) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
+  return %3 : tensor<4x7x7x8xf32>
+}
+
+// CHECK-LABEL:   func @convert_avgpool_valid_rw(
+// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
+// CHECK:         }
+func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+  %1 = mhlo.constant dense<0.0> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ( {
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ( {
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
+  return %4 : tensor<4x7x7x8xf32>
+}
+
+// CHECK-LABEL:   func @convert_avgpool_valid_3d(
+// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
+// CHECK:           return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
+// CHECK:         }
+func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<f32>
+  %1 = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32>
+  %2 = "mhlo.reduce_window"(%arg0, %0) ( {
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%5) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<5xi64>,
+    padding = dense<0> : tensor<5x2xi64>,
+    window_dilations = dense<1> : tensor<5xi64>,
+    window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+    window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
+  %3 = mhlo.divide %2, %1 : tensor<4x7x7x7x8xf32>
+  return %3 : tensor<4x7x7x7x8xf32>
+}
+
+// CHECK-LABEL:   func @convert_avgpool_same(
+// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
+// CHECK:         }
+func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+  %1 = mhlo.constant dense<0.0> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ( {
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ( {
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
+  return %4 : tensor<4x8x8x8xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index dc2a6e8bd05..52937fc0a5b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -244,9 +244,9 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
   // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
   // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
   // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
-  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
+  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1)
   // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
-  // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
+  // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
   // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
   // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
   // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
@@ -338,10 +338,10 @@ func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]])
   // CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]])
   // CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]])
-  // CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]])
+  // CHECK-DAG: [[VAL8:%.+]] = "tf.AddV2"([[VAL7]], [[VAL4]])
   // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
   // CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]])
-  // CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]])
+  // CHECK-DAG: [[VAL11:%.+]] = "tf.AddV2"([[VAL10]], [[VAL2]])
   %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
 
   // CHECK: return [[VAL11]]
@@ -361,7 +361,7 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
   // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
   // CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]])
   // CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]])
-  // CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]])
+  // CHECK-DAG: [[VAL12:%.+]] = "tf.AddV2"([[VAL9]], [[VAL2]])
   // CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]])
   // CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) :
   // CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]])
@@ -371,10 +371,10 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
   // CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]])
   // CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]])
   // CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]])
-  // CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]])
+  // CHECK-DAG: [[VAL22:%.+]] = "tf.AddV2"([[VAL21]], [[VAL3]])
   // CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]])
   // CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]])
-  // CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]])
+  // CHECK-DAG: [[VAL25:%.+]] = "tf.AddV2"([[VAL24]], [[VAL17]])
   %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
 
   // CHECK: return [[VAL25]]
@@ -746,7 +746,7 @@ func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
   // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
   // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
-  // CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
+  // CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
   // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
   %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
 
@@ -761,7 +761,7 @@ func @round_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
   // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
   // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
-  // CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
+  // CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
   // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
   %0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir
index 4bb324d0d85..8f5066d4162 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir
@@ -1,12 +1,13 @@
 // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
 
-func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  %0:2 = tf_executor.graph {
-    %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
-    %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
-    tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
+func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
+  %0:3 = tf_executor.graph {
+    %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
+    %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
+    %outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
+    tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
   }
-  return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
+  return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
 }
 
 func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
@@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK-NOT: name:
 // CHECK: op: "While"
 // CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
 // CHECK:  attr {
 // CHECK:    key: "output_shapes"
 // CHECK:    value {
@@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK-NOT: name:
 // CHECK: op: "StatelessWhile"
 // CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
 // CHECK:  attr {
 // CHECK:    key: "output_shapes"
 // CHECK:    value {
@@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK:    }
 // CHECK:  }
 
+// CHECK: name: "WhileWithOutputShapes"
+// CHECK-NOT: name:
+// CHECK: op: "While"
+// CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
+// CHECK:  attr {
+// CHECK:    key: "output_shapes"
+// CHECK:    value {
+// CHECK:      list {
+// CHECK:        shape {
+// CHECK:          dim {
+// CHECK:            size: 5
+// CHECK:          }
+// CHECK:        }
+// CHECK:      }
+// CHECK:    }
+// CHECK:  }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir
new file mode 100644
index 00000000000..44aee930fc3
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir
@@ -0,0 +1,153 @@
+// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s
+
+// expected-error@+1 {{'module' op bad 'tf.devices' attribute at index 0, not a string}}
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = [1]} {
+  // Tests that a malformed `tf.devices` attribute results in an error.
+  func @invalid_device_attribute() -> tensor<?xi32> {
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      %2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+  // Tests that an empty `_xla_outside_compilation` attribute value results in an error.
+  func @empty_outside_compilation_attribute() -> tensor<?xi32> {
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      // expected-error@+1 {{'tf.B' op requires non empty '_xla_outside_compilation' string attribute}}
+      %2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+
+  // Tests that a TPU cluster with no outside compilation does not generate a launch op.
+
+  // CHECK-LABEL: func @no_outside_compilation
+  // CHECK-NOT: "tf_device.launch"
+  func @no_outside_compilation() -> tensor<?xi32> {
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+
+
+  // Tests the launch wrap of a single outside compiled cluster with no input or output dependencies.
+
+  // CHECK-LABEL: func @nodep_single_outside_compilation
+  func @nodep_single_outside_compilation() -> () {
+    // CHECK:      "tf.A"
+    // CHECK:      "tf_device.launch"
+    // CHECK-NEXT:   "tf.B"
+    // CHECK-NOT:    _xla_outside_compilation
+    // CHECK-NEXT: tf_device.return
+    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    // CHECK: device_assignment =  [], num_cores_per_replica = 1 : i64, topology =  ""
+    "tf_device.cluster"() ( {
+      "tf.A"() : () -> ()
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      "tf.C"() : () -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
+    return
+  }
+
+  // Tests the launch wrap of a single outside compiled cluster with data parallelism.
+
+  // CHECK-LABEL: func @single_outside_compilation_with_replicate
+  func @single_outside_compilation_with_replicate(%arg0: tensor<?xi32>) -> () {
+    // CHECK:      "tf.A"
+    // CHECK:      tf_device.replicate
+    // CHECK-NEXT:   "tf_device.cluster"
+    // CHECK-NEXT:     "tf.B"
+    // CHECK-NEXT:     "tf_device.launch"
+    // CHECK-NEXT:       "tf.C"
+    // CHECK-NOT:        _xla_outside_compilation
+    // CHECK:            tf_device.return
+    // CHECK-NEXT:     device = "TPU_REPLICATED_HOST"
+    // CHECK: device_assignment =  [], num_cores_per_replica = 1 : i64, topology =  ""
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      "tf_device.cluster"() ( {
+        "tf.B"() : () -> ()
+        "tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
+        "tf.D"() : () -> ()
+        tf_device.return
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
+      tf_device.return
+    }
+    return
+  }
+
+  // Tests the launch wrap of a single outside compiled cluster with input/output.
+
+  // CHECK-LABEL: func @single_outside_compilation_input_output
+  func @single_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    // CHECK:      %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK:          "tf_device.cluster"
+    // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK-NEXT:     %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
+    // CHECK:            tf_device.return %[[B_OUTPUT]]
+    // CHECK:          "tf.C"(%[[LAUNCH_OUTPUT]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<?xi32>)
+        %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
+        %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+
+  // Tests the launch wrap of multiple outside compiled clusters with input/output.
+
+  // CHECK-LABEL: func @multiple_outside_compilation_input_output
+  func @multiple_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    // CHECK:      %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK:          "tf_device.cluster"
+    // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK-NEXT:     %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
+    // CHECK:            tf_device.return %[[B_OUTPUT]]
+    // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]])
+    // CHECK-NEXT:     %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch"
+    // CHECK:            %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
+    // CHECK:            %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[D_OUTPUT]])
+    // CHECK:            tf_device.return %[[E_OUTPUT]]
+    // CHECK:          "tf.F"(%[[LAUNCH_OUTPUT2]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<?xi32>)
+        %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
+        %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
+        %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
+        %7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
+        %8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<?xi32>
+        tf_device.return %8 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
index c2760000e82..3f58514f6e9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -38,6 +39,12 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
 
   LogicalResult matchAndRewrite(Operation* op,
                                 PatternRewriter& rewriter) const override;
+
+ private:
+  template <typename Op>
+  LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
+
+  LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
 };
 
 class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
@@ -47,7 +54,27 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
 
 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
-  if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure();
+  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
+    return RewriteOp(op, rewriter);
+
+  // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
+  // incompatible_shape_error is `true` (which is also checked by the verifier).
+  if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
+  if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
+
+  return failure();
+}
+
+template <typename Op>
+LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto eq_op = llvm::dyn_cast_or_null<Op>(op);
+  if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
+  return failure();
+}
+
+LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
+    Operation* op, PatternRewriter& rewriter) const {
   if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
     return failure();
 
@@ -56,6 +83,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
       op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
   if (!result_type || !result_type.hasStaticShape()) return failure();
 
+  bool changed = false;
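+  // Try to fold a broadcast into each operand below; report success only if
+  // at least one operand was actually rewritten.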
   for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
     // Check that the i'th operand is a broadcast.
     auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
@@ -89,10 +117,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
     // Update the operand of the op to be the operand of the broadcast.
     rewriter.updateRootInPlace(
         op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
-    return success();
+    changed = true;
   }
-
-  return failure();
+  return success(changed);
 }
 
 void BroadcastFoldPass::runOnFunction() {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
index a92d3f367cf..e865597c92e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
@@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
 LogicalResult ConvertWhileOp(WhileOp while_op) {
   auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
       while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
-      while_op.output_shapes(), while_op.parallel_iterations(),
-      while_op.is_stateless());
+      while_op.parallel_iterations(), while_op.is_stateless(),
+      while_op.shape_invariant());
   CopyDeviceAndUnderscoredAttributes(while_op, while_region);
 
   YieldOp cond_yield =
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
index 4e9f9871964..418f6c0f1c6 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -21,15 +21,18 @@ limitations under the License.
 #include <numeric>
 #include <vector>
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
@@ -44,6 +47,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
+#include "tensorflow/core/lib/math/math_util.h"
 
 namespace mlir {
 namespace TF {
@@ -479,18 +483,16 @@ template <typename ReductionOp>
 LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
   Block &body = function.front();
   if (body.getNumArguments() != 2) return failure();
-  if (body.getOperations().size() != 2) return failure();
-
-  ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
-  if (!reduce_op) return failure();
-  if (reduce_op.lhs() != body.getArgument(0) ||
-      reduce_op.rhs() != body.getArgument(1))
-    return failure();
 
   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
   if (!return_op) return failure();
-  if (return_op.getNumOperands() != 1 ||
-      return_op.results().front() != reduce_op)
+  if (return_op.getNumOperands() != 1) return failure();
+
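+  // Locate the reduction op through the value returned by the body, so the
+  // body may contain ops other than the reduction and the return.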
+  ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
+      return_op.getOperands().front().getDefiningOp());
+  if (!reduce_op) return failure();
+  if (reduce_op.lhs() != body.getArgument(0) ||
+      reduce_op.rhs() != body.getArgument(1))
     return failure();
 
   return success();
@@ -654,6 +656,190 @@ class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
   }
 };
 
+// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
+// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
+// padding:
+// * div(reduce_sum_window(x), constant(sizeof(window)))
+// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
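+// The first form requires zero padding and maps to VALID padding; the second
+// also matches TensorFlow's SAME padding when the per-dimension padding
+// follows the SAME formula.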
+class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::DivOp div_op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const final {
+    auto rw =
+        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
+    if (!rw) return failure();
+
+    // Check that the reduce-window is a sum-reduce-window.
+    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
+      return failure();
+
+    // Check that this is a floating point reduce window with a rank of 4 or 5.
+    RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
+    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
+        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
+      return failure();
+
+    // Check that the Div op doesn't do broadcasting on the output of the reduce
+    // window.
+    if (div_op.getType() != rw.getType()) return failure();
+
+    // tf.avg_pool needs at least 3 dimensions (batch, spatial, channel).
+    const uint64_t rank = rw.window_dimensions().size();
+    if (rank <= 2) return failure();
+
+    // If the init value isn't zero then it can't be an average pool.
+    if (!isFloatZero(rw.init_value())) return failure();
+
+    llvm::SmallVector<int64_t, 5> window_strides;
+    if (rw.window_strides().hasValue()) {
+      window_strides.insert(window_strides.end(),
+                            rw.window_strides()->getValues<int64_t>().begin(),
+                            rw.window_strides()->getValues<int64_t>().end());
+    } else {
+      window_strides.resize(rank, 1);
+    }
+
+    llvm::SmallVector<int64_t, 10> padding;
+    if (rw.padding().hasValue()) {
+      padding.insert(padding.begin(),
+                     rw.padding()->getValues<int64_t>().begin(),
+                     rw.padding()->getValues<int64_t>().end());
+    } else {
+      padding.resize(2 * rank, 0);
+    }
+
+    // Check that we don't do any reduction along the batch (first) and channel
+    // (last) dimensions.
+    const uint64_t batch_dim = 0;
+    const uint64_t channel_dim = rank - 1;
+    if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
+        rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
+        window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
+        padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
+        padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
+      return failure();
+
+    if (rw.window_dilations().hasValue() &&
+        !(rw.window_dilations()->isSplat() &&
+          rw.window_dilations()->getSplatValue<APInt>() == 1))
+      return failure();
+
+    if (rw.base_dilations().hasValue() &&
+        !(rw.base_dilations()->isSplat() &&
+          rw.base_dilations()->getSplatValue<APInt>() == 1))
+      return failure();
+
+    DenseFPElementsAttr divisor;
+    if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
+      // If the divisor is a constant, check that it matches the number of
+      // elements inside the window, as required for a VALID AvgPool.
+      if (!divisor.isSplat()) return failure();
+      int64_t window_size = 1;
+      for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
+        window_size *= w;
+      }
+      if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
+        return failure();
+
+      // Check that we have no padding.
+      if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
+        return failure();
+
+      return replaceWithAvgPool(
+          div_op, rw.operand(),
+          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
+          window_strides, "VALID", rewriter);
+    }
+
+    auto rw_rhs =
+        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
+    if (rw_rhs) {
+      // Check that RHS is a sum-reduce-window.
+      if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
+        return failure();
+
+      // Check that the RHS is a reduce_window over a constant 1 input with 0 as
+      // the init value.
+      DenseFPElementsAttr rhs_input;
+      if (!isFloatZero(rw_rhs.init_value()) ||
+          !matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
+          !rhs_input.isSplat() ||
+          !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
+        return failure();
+
+      // Check that the two reduce windows have the same window configuration.
+      if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
+          rw.window_strides() != rw_rhs.window_strides() ||
+          rw.window_dilations() != rw_rhs.window_dilations() ||
+          rw.base_dilations() != rw_rhs.base_dilations() ||
+          rw.padding() != rw_rhs.padding())
+        return failure();
+
+      if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
+        return replaceWithAvgPool(
+            div_op, rw.operand(),
+            llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
+            window_strides, "VALID", rewriter);
+
+      RankedTensorType input_type =
+          rw.operand().getType().dyn_cast<RankedTensorType>();
+      RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
+      if (!input_type || !output_type) return failure();
+
+      // Check that the individual padding values correspond to TensorFlow's
+      // SAME padding.
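+      // For example, with input size 16, window 3, stride 2, output size 8:
+      // padding_size = (8 - 1) * 2 + 3 - 16 = 1, so (low, high) = (0, 1).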
+      for (uint64_t i = 1; i < rank - 1; ++i) {
+        int64_t padding_size =
+            (output_type.getShape()[i] - 1) * window_strides[i] +
+            rw.window_dimensions().getValue<int64_t>({i}) -
+            input_type.getShape()[i];
+        if (padding[2 * i] !=
+                tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
+            padding[2 * i + 1] !=
+                tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
+          return failure();
+      }
+      return replaceWithAvgPool(
+          div_op, rw.operand(),
+          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
+          window_strides, "SAME", rewriter);
+    }
+    return failure();
+  }
+
+ private:
+  bool isFloatZero(Value value) const {
+    DenseFPElementsAttr initial_value;
+    return matchPattern(value, m_Constant(&initial_value)) &&
+           initial_value.getNumElements() == 1 &&
+           initial_value.getValue<APFloat>({}).isZero();
+  }
+
+  LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
+                                   llvm::ArrayRef<int64_t> ksizes,
+                                   llvm::ArrayRef<int64_t> kstrides,
+                                   llvm::StringRef padding,
+                                   ConversionPatternRewriter &rewriter) const {
+    if (ksizes.size() == 4) {
+      rewriter.replaceOpWithNewOp<AvgPoolOp>(
+          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
+          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
+          rewriter.getStringAttr("NHWC"));
+      return success();
+    } else if (ksizes.size() == 5) {
+      rewriter.replaceOpWithNewOp<AvgPool3DOp>(
+          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
+          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
+          rewriter.getStringAttr("NDHWC"));
+      return success();
+    }
+    return failure();
+  }
+};
+
 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<TF::TensorFlowDialect>();
@@ -794,10 +980,10 @@ static PassRegistration<LegalizeHloToTf> pass(
 
 void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
                                      MLIRContext *context) {
+  patterns->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertSliceOp,
+                   ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
+                   ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
   populateWithGenerated(context, *patterns);
-  patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
-                   ConvertReduceOpToTfMin, ConvertReduceOpToTfSum,
-                   ConvertIotaOpToTfRange>(context);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index 9d43eb17624..2d9c9023ac1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -332,12 +332,13 @@ class LowerDynamicStitchOp : public RewritePattern {
 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
  public:
   explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
-      : RewritePattern(FakeQuantWithMinMaxVarsOp::getOperationName(),
-                       {SubOp::getOperationName(), ConstOp::getOperationName(),
-                        MulOp::getOperationName(), FloorOp::getOperationName(),
-                        ClipByValueOp::getOperationName(),
-                        DivOp::getOperationName(), RoundOp::getOperationName()},
-                       1, context) {}
+      : RewritePattern(
+            FakeQuantWithMinMaxVarsOp::getOperationName(),
+            {AddV2Op::getOperationName(), SubOp::getOperationName(),
+             ConstOp::getOperationName(), MulOp::getOperationName(),
+             FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
+             DivOp::getOperationName(), RoundOp::getOperationName()},
+            1, context) {}
 
   LogicalResult matchAndRewrite(Operation *src_op,
                                 PatternRewriter &rewriter) const override {
@@ -419,8 +420,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
         op.getLoc(),
         DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
 
-    quantized_input = rewriter.create<AddOp>(op.getLoc(), input_ty,
-                                             quantized_input, half_val);
+    quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
+                                               quantized_input, half_val);
 
     quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
 
@@ -428,8 +429,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
     Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
                                           quantized_input, quant_to_float);
 
-    output =
-        rewriter.create<AddOp>(op.getLoc(), input_ty, output, nudged_float_min);
+    output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
+                                      nudged_float_min);
 
     rewriter.replaceOp(op, {output});
     return success();
@@ -811,7 +812,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
                            CastOp::getOperationName(),
                            ConstOp::getOperationName(),
                            ConcatV2Op::getOperationName(),
-                           AddOp::getOperationName(),
+                           AddV2Op::getOperationName(),
                            PadOp::getOperationName(),
                            SplitOp::getOperationName(),
                            UnpackOp::getOperationName(),
@@ -907,8 +908,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
     auto paddings_split = rewriter.create<UnpackOp>(
         loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
         rewriter.getI64IntegerAttr(1));
-    auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0),
-                                               paddings_split.getResult(1));
+    auto paddings_sum = rewriter.create<AddV2Op>(
+        loc, paddings_split.getResult(0), paddings_split.getResult(1));
 
     auto input_shape_tensor = rewriter.create<ConstOp>(
         loc,
@@ -918,7 +919,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
 
     // padded_shape_tensor is the shape of padded.
     auto padded_shape_tensor =
-        rewriter.create<AddOp>(loc, paddings_sum, input_shape_tensor);
+        rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
 
     auto zero_i32 = rewriter.create<ConstOp>(
         loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index d0bd97501e1..f667080f0c2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -237,7 +237,7 @@ def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
             (TF_SubOp $input, (TF_FloorOp:$floor $input)),
             (TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
            $floor,
-           (TF_AddOp
+           (TF_AddV2Op
             (TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc
new file mode 100644
index 00000000000..225f6c0cc15
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc
@@ -0,0 +1,184 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
+
+namespace mlir {
+namespace TFTPU {
+
+namespace {
+
+constexpr char kDeviceAttr[] = "device";
+constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
+
+// Maps an `_xla_outside_compilation` attribute value to the ops of its cluster.
+using OutsideClusterMap =
+    llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
+
+// This pass wraps ops with the same `_xla_outside_compilation`
+// attribute value in a tf_device.launch op with host device assignment.
+//
+// A simple example:
+//   "tf_device.cluster"() ( {
+//     "tf.A"()
+//     "tf.B"() {_xla_outside_compilation = "cluster1"}
+//     "tf.C"()
+//     tf_device.return
+//   }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []}
+//
+// Would become the following ops (unimportant attributes and types are omitted):
+//   "tf_device.cluster"() ( {
+//     "tf.A"()
+//     "tf_device.launch"() {
+//       "tf.B"() {_xla_outside_compilation = "cluster1"}
+//       tf_device.return
+//     } {device = "TPU_REPLICATED_HOST"} : () -> ()
+//     "tf.C"()
+//     tf_device.return
+//   }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []}
+//
+
+struct OutsideCompiledToHostLaunch
+    : public PassWrapper<OutsideCompiledToHostLaunch, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+};
+
+// Collects and groups ops in `block` into `clusters`, keyed by their
+// `_xla_outside_compilation` attribute value. Returns an error if an op's
+// `_xla_outside_compilation` attribute is empty.
+LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
+                                               OutsideClusterMap* clusters) {
+  auto walk_result = block->walk([&](Operation* op) {
+    if (auto attr = op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
+      if (attr.getValue().empty()) {
+        op->emitOpError() << "requires non-empty '"
+                          << kXlaOutsideCompilationAttr << "' string attribute";
+        return WalkResult::interrupt();
+      }
+
+      auto it = clusters->try_emplace(attr.getValue());
+      it.first->getSecond().push_back(op);
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walk_result.wasInterrupted());
+}
+
+// Extracts all externally used outputs of `cluster_ops`.
+llvm::SmallSetVector<Value, 4> GetExternalOutputs(
+    llvm::ArrayRef<Operation*> cluster_ops) {
+  llvm::SmallSetVector<Value, 4> external_outputs;
+  llvm::SmallPtrSet<Operation*, 4> host_cluster_ops_set;
+  for (auto op : cluster_ops) {
+    op->walk([&](Operation* host_cluster_op) {
+      host_cluster_ops_set.insert(host_cluster_op);
+    });
+  }
+
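+  // A result is an external output if at least one of its users is not part of
+  // the host cluster collected above.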
+  for (Operation* op : cluster_ops) {
+    for (Operation* user : op->getUsers()) {
+      bool is_external = llvm::none_of(
+          host_cluster_ops_set,
+          [&](Operation* cluster_op) { return user == cluster_op; });
+      if (!is_external) continue;
+      for (Value v : user->getOperands()) {
+        if (v.getDefiningOp() == op) external_outputs.insert(v);
+      }
+    }
+  }
+
+  return external_outputs;
+}
+
+void WrapClusterInLaunch(llvm::ArrayRef<Operation*> cluster_ops,
+                         llvm::StringRef host_device) {
+  auto* last_cluster_op = cluster_ops.back();
+  OpBuilder builder(last_cluster_op);
+  llvm::SmallVector<Type, 4> launch_output_types;
+  auto external_outputs = GetExternalOutputs(cluster_ops);
+  for (const auto& external_output : external_outputs)
+    launch_output_types.push_back(external_output.getType());
+
+  auto launch_op = builder.create<tf_device::LaunchOp>(
+      last_cluster_op->getLoc(), builder.getStringAttr(host_device),
+      /*result_types=*/launch_output_types);
+  for (auto result : llvm::zip(external_outputs, launch_op.getResults())) {
+    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
+  }
+
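+  // Build the launch body: terminate it with a return of the external outputs,
+  // then move the cluster ops in front of the terminator, stripping attributes
+  // that no longer apply inside the launch.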
+  launch_op.body().push_back(new Block);
+  builder.setInsertionPointToEnd(&launch_op.GetBody());
+  auto* return_op =
+      builder
+          .create<tf_device::ReturnOp>(last_cluster_op->getLoc(),
+                                       external_outputs.getArrayRef())
+          .getOperation();
+  MLIRContext* context = launch_op.getContext();
+  for (Operation* cluster_op : cluster_ops) {
+    cluster_op->removeAttr(
+        Identifier::get(kXlaOutsideCompilationAttr, context));
+    cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
+    cluster_op->moveBefore(return_op);
+  }
+}
+
+void OutsideCompiledToHostLaunch::runOnOperation() {
+  // Get runtime devices information from the closest parent module.
+  auto module = getOperation();
+  mlir::TF::RuntimeDevices devices;
+  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
+    return signalPassFailure();
+
+  auto result = module.walk([&](tf_device::ClusterOp tpu_cluster) {
+    OutsideClusterMap clusters;
+    if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
+                                                &clusters)))
+      return WalkResult::interrupt();
+
+    if (clusters.empty()) return WalkResult::advance();
+
+    std::string host_device;
+    tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
+                                                &host_device);
+    for (const auto& cluster : clusters) {
+      WrapClusterInLaunch(cluster.getSecond(), host_device);
+    }
+    return WalkResult::advance();
+  });
+  if (result.wasInterrupted()) return signalPassFailure();
+}
+
+}  // anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateOutsideCompiledToHostLaunchPass() {
+  return std::make_unique<OutsideCompiledToHostLaunch>();
+}
+
+static PassRegistration<OutsideCompiledToHostLaunch> pass(
+    "tf-outside-compiled-to-host-launch",
+    "Wraps ops with the same _xla_outside_compilation attribute in a "
+    "tf_device.launch op on the replicated host device.");
+
+}  // namespace TFTPU
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 3293ef7810e..25945182c20 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -345,6 +345,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps();
 // run-time according to compilation result.
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
 
+// Creates a pass that wraps ops with the same `_xla_outside_compilation`
+// attribute value in a tf_device.launch op with host device assignment.
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateOutsideCompiledToHostLaunchPass();
+
 // Creates a pass that groups outside compiled operations (CPU ops inside TPU
 // cluster) into clusters that can be extracted and run on the CPU.
 std::unique_ptr<OperationPass<ModuleOp>>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
index 2363663cb5a..a61fa650973 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
@@ -398,8 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
   OpBuilder builder(while_region);
   auto while_op = builder.create<WhileOp>(
       while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
-      while_region.output_shapes(), while_region.parallel_iterations(),
-      while_region.is_stateless());
+      while_region.parallel_iterations(), while_region.is_stateless(),
+      while_region.shape_invariant());
   CopyDeviceAndUnderscoredAttributes(while_region, while_op);
 
   // Redirect old results to new results.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 7953dfe1832..caf4e6b9197 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -255,8 +255,7 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
                                   OpBuilder& builder) {
   auto host_side_while = builder.create<TF::WhileRegionOp>(
       loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
-      /*output_shapes=*/builder.getArrayAttr({}), parallel_iterations,
-      is_stateless);
+      parallel_iterations, is_stateless, /*shape_invariant=*/false);
 
   // Create empty else branch region.
   auto& body = host_side_while.body();
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
index 0057e498cea..125994b928d 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
@@ -155,6 +155,9 @@ StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
   if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
     attrs_to_ignore.insert("is_stateless");
 
+  if (llvm::isa<mlir::TF::WhileOp>(inst))
+    attrs_to_ignore.insert("shape_invariant");
+
   return attrs_to_ignore;
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 4242cdd349e..986b570f4fc 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -971,6 +971,16 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
                                        etype.getContext()));
   }
 
+  if (node.IsWhileNode()) {
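+    // Prefer the `output_shapes` attribute, when populated, over the default
+    // unranked result type so While results get more precise shapes.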
+    auto* output_shapes = node.attrs().Find("output_shapes");
+    auto* element_types = node.attrs().Find("T");
+    if (output_shapes && !output_shapes->list().shape().empty()) {
+      const auto& output_shape = output_shapes->list().shape(idx);
+      const auto& element_type = element_types->list().type(idx);
+      return ConvertToMlirTensorType(output_shape, element_type, &builder);
+    }
+  }
+
   // Returns a simple, more conservative unranked tensor type.
   auto default_type = [&]() -> StatusOr<mlir::Type> {
     mlir::Type element_type;
@@ -1907,7 +1917,13 @@ Status ImporterBase::ConvertNode(const Node& node) {
   // Case/If/While op in MLIR and add the differentiating attribute.
   if (node.IsCaseNode()) composite_control_flow_op("Case");
   if (node.IsIfNode()) composite_control_flow_op("If");
-  if (node.IsWhileNode()) composite_control_flow_op("While");
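+  // A While node carrying a non-empty `output_shapes` attribute is imported
+  // with the `shape_invariant` unit attribute set.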
+  if (node.IsWhileNode()) {
+    composite_control_flow_op("While");
+    auto* output_shapes = node.attrs().Find("output_shapes");
+    if (output_shapes && !output_shapes->list().shape().empty())
+      result.attributes.push_back(
+          builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
+  }
 
   // Register the mapping between the TF node and the newly created operation.
   node_values_[node.id()] =
@@ -3420,6 +3436,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
     const std::vector<std::pair<std::string, TensorInfo>>& inputs,
     const std::vector<std::pair<std::string, TensorInfo>>& outputs,
     const std::vector<std::string> control_outputs) {
+  VLOG(1) << "Importing Signature: " << name;
+
   GraphImportConfig specs;
   specs.prune_unused_nodes = true;
   specs.inputs = ParseInputArrays(inputs);
@@ -3491,6 +3509,9 @@ SavedModelSignatureDefImporterLite::ParseInputArrays(
     // Only dense tensor is supported.
     DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
 
+    VLOG(1) << "Importing Signature Input: input_name = " << iter.first
+            << ", tensor_info = " << tensor_info.DebugString();
+
     ArrayInfo array_info;
     array_info.imported_dtype = tensor_info.dtype();
     array_info.shape = tensor_info.tensor_shape();
diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
index 735ee64596e..d088adc3f9b 100644
--- a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
+++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
@@ -68,6 +68,7 @@ distribute_py_test(
     tags = [
         "no_cuda_asan",  # b/173431253
         "no_oss",
+        "notap",  # b/173661843
         "notsan",  # b/173246447
     ],
     xla_enable_strict_auto_jit = False,  # b/173254861
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index e7e9025c4aa..a8a99f1cd10 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -90,8 +90,17 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
     pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   }
 
-  // Legalize only hlo operations to lhlo, keep the rest as tensors.
-  pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass());
+  // Partial bufferization: transforms in particular HLO operations to their
+  // corresponding LHLO operations and converts the function signature. Leaves
+  // shape operations untouched.
+  pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass(
+      /*allow_partial_bufferization=*/true));
+  // Run CSE to ensure that loads and stores to the same location get recognized
+  // as such.
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Forward stores to buffers to loads.
+  pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
+
   // Clean up the IR for further processing.
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
@@ -100,18 +109,22 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   // needed.
   llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
   llvm::SmallVector<int64_t, 4> as_int64;
-  if (!unroll_factors.empty()) {
-    tiling_for_unrolling.reserve(tile_sizes.size());
-    for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
-      tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
-      as_int64.push_back(std::get<1>(pair));
-    }
-  } else {
-    tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
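+  // Tile sizes are multiplied by their unroll factors; llvm::zip stops at the
+  // shorter list, so tile sizes without a matching unroll factor are appended
+  // unchanged below.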
+  tiling_for_unrolling.reserve(tile_sizes.size());
+  for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
+    tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
+    as_int64.push_back(std::get<1>(pair));
   }
+  tiling_for_unrolling.append(
+      tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end());
   // Transform LHLO operations to LinAlg.
   pm.addNestedPass<mlir::FuncOp>(
       ::mlir::lmhlo::createLegalizeLhloToLinalgPass());
+  if (!gpu_binary_only) {
+    // Find candidates for buffer reuse. This is only successful if buffer size
+    // equality can be determined based on `linalg.generic` operations.
+    pm.addNestedPass<mlir::FuncOp>(
+        mlir::kernel_gen::transforms::CreateBufferReusePass());
+  }
   // Fuse linalg operations.
   pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
       /*use_parallel_loops=*/true, tiling_for_unrolling));
@@ -141,11 +154,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   // Some basic cleanup.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
-  if (!gpu_binary_only) {
-    // Find candidates for buffer reuse.
-    pm.addNestedPass<mlir::FuncOp>(
-        mlir::kernel_gen::transforms::CreateBufferReusePass());
-  }
   // Greedily map the remaining loop to GPU hardware dimensions.
   pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
 
@@ -157,7 +165,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
-  pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
+  pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64));
   // TODO(herhut): Enabled this to avoid leaks once fixed.
   // pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
@@ -189,10 +197,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   pm.addPass(::mlir::createLowerAffinePass());
   // Constraints are removed as late as possible and before lowering to CFG.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass());
-  if (embed_memref_prints) {
-    pm.addNestedPass<::mlir::FuncOp>(
-        mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
-  }
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   // TODO(herhut): Remove this pass once the LowerToCFG pass can handle it.
   pm.addNestedPass<mlir::FuncOp>(
@@ -200,6 +204,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   pm.addPass(::mlir::createLowerToCFGPass());
   // Map allocs, asserts, etc. to the tensorflow framework.
   pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
+  if (embed_memref_prints) {
+    pm.addNestedPass<::mlir::FuncOp>(
+        mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
+  }
   if (failed(pm.run(module))) {
     return InternalError("Lowering to GPU kernels failed.");
   }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir
index b297c72af5d..bb86e25414d 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir
@@ -49,12 +49,58 @@ func @local_reuse_with_memref_maps(
     iterator_types = ["parallel"]
   } ins(%arg : memref<?xi64, offset: 2, strides: [3]>)
     outs(%result : memref<?xi64, offset: 2, strides: [3]>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     linalg.yield %a : i64
   }
   return %result : memref<?xi64, offset: 2, strides: [3]>
 }
 
+// CHECK-LABEL: @memref_reinterpret_cast_alias
+func @memref_reinterpret_cast_alias(%arg : memref<f32>, %n : index)
+    -> memref<?xf32> attributes {tf_entry} {
+  %c0 = constant 0 : index
+  %reinterpreted = memref_reinterpret_cast %arg to
+      offset: [0],
+      sizes: [%n],
+      strides: [%c0]: memref<f32> to memref<?xf32>
+
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = [0 : index]
+  %result = alloc(%n) : memref<?xf32>
+
+  // reinterpreted (arg) and result are of the same size.
+  linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%reinterpreted : memref<?xf32>) outs(%result : memref<?xf32>) {
+  ^bb0(%a : f32, %b : f32):
+    linalg.yield %a : f32
+  }
+
+  return %result : memref<?xf32>
+}
+
+// CHECK-LABEL: @memref_cast_alias
+func @memref_cast_alias(%arg : memref<*xf32>, %n : index)
+    -> memref<?xf32> attributes {tf_entry} {
+  %casted = memref_cast %arg : memref<*xf32> to memref<?xf32>
+
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = [0 : index]
+  %result = alloc(%n) : memref<?xf32>
+
+  // casted (arg) and result are of the same size.
+  linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%casted : memref<?xf32>) outs(%result : memref<?xf32>) {
+  ^bb0(%a : f32, %b : f32):
+    linalg.yield %a : f32
+  }
+
+  return %result : memref<?xf32>
+}
+
 // CHECK-LABEL: @indirect_size_equality
 func @indirect_size_equality(%arg0 : memref<?xi64>,
                              %arg1 : memref<?xi64>,
@@ -65,7 +111,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>,
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]
   } ins(%arg0 : memref<?xi64>) outs(%arg1 : memref<?xi64>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     linalg.yield %a : i64
   }
 
@@ -78,7 +124,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>,
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]
   } ins(%arg0 : memref<?xi64>) outs(%result : memref<?xi64>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     linalg.yield %a : i64
   }
 
@@ -317,7 +363,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]
   } ins(%flat_arg : memref<?xi64>) outs(%flat_result : memref<?xi64>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     %c0 = constant 0 : i64
     %a_pos = cmpi "sge", %a, %c0 : i64
     %a_neg = subi %c0, %a : i64
@@ -360,3 +406,41 @@ func @index_element_type(%arg : memref<2x3xindex>) -> memref<2x3xindex>
   %result = alloc() : memref<2x3xindex>
   return %result : memref<2x3xindex>
 }
+
+// Example as it occurs in the `tf.Abs` kernel for `f32`.
+// CHECK-LABEL: @abs_f32
+func @abs_f32(%arg0: memref<*xf32>) -> memref<*xf32>
+    attributes {llvm.emit_c_interface, tf_entry} {
+  %c0 = constant 0 : index
+  %0 = shape.shape_of %arg0 : memref<*xf32> -> tensor<?xindex>
+  %1 = shape.num_elements %0 : tensor<?xindex> -> index
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = []
+  %2 = alloc() : memref<1xindex>
+  store %1, %2[%c0] : memref<1xindex>
+  %3 = memref_reshape %arg0(%2)
+      : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
+  %4 = dim %3, %c0 : memref<?xf32>
+  %5 = index_cast %4 : index to i64
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = []
+  %6 = alloc() : memref<1xi64>
+  store %5, %6[%c0] : memref<1xi64>
+  %7 = load %6[%c0] : memref<1xi64>
+  %8 = index_cast %7 : i64 to index
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = [0 : index]
+  %9 = alloc(%8) : memref<?xf32>
+  linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%3 : memref<?xf32>) outs(%9 : memref<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+    %12 = absf %arg1 : f32
+    linalg.yield %12 : f32
+  }
+  %10 = tensor_to_memref %0 : memref<?xindex>
+  %11 = memref_reshape %9(%10)
+      : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  return %11 : memref<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
index 244fa24dc6b..eca3d586294 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: kernel-gen-opt %s --final-bufferize | FileCheck %s
+// RUN: kernel-gen-opt %s --bufferize | FileCheck %s
 
 // CHECK-LABEL: @extract_element
 // CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir
index d251c40981a..27912cb7b57 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --func-bufferize | FileCheck %s
+// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --bufferize | FileCheck %s
 
 // Test whether all shape computations required for isinf can be lowered to
 // the standard dialect, scf and descriptors.
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir
index 1b119f56377..c89e1ca7362 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir
@@ -24,7 +24,7 @@ func @print_memrefs(
   return %output : memref<*xf16>
 }
 
-// CHECK:   func private @print_memref_index(memref<*xindex>)
+// CHECK:   func private @print_memref_i64(memref<*xi64>)
 
 // CHECK-LABEL: func @print_memrefs
 
@@ -33,12 +33,16 @@ func @print_memrefs(
 // CHECK: [[NUM_ELEM:%.*]] = alloca() : memref<1xindex>
 // CHECK: store {{%.*}}, [[NUM_ELEM]]
 
-// CHECK: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM]]
-// CHECK-NEXT: call @print_memref_index([[UNRANKED_NUM_ELEM]])
+// CHECK: [[NUM_ELEM_I64:%.*]] = index_cast [[NUM_ELEM]]
+// CHECK-SAME: : memref<1xindex> to memref<1xi64>
+// CHECK-NEXT: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM_I64]]
+// CHECK-NEXT: call @print_memref_i64([[UNRANKED_NUM_ELEM]])
 
 // CHECK: memref_reshape
 // CHECK: tf_framework.alloc
 
-// CHECK: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE]]
-// CHECK-NEXT: call @print_memref_index([[UNRANKED_SHAPE]])
+// CHECK: [[SHAPE_I64:%.*]] = index_cast [[SHAPE]]
+// CHECK-SAME: : memref<?xindex> to memref<?xi64>
+// CHECK-NEXT: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE_I64]]
+// CHECK-NEXT: call @print_memref_i64([[UNRANKED_SHAPE]])
 // CHECK: memref_reshape
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
index 6a8f4107b65..6b9f011a109 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --scf-bufferize --func-bufferize | FileCheck %s
+// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --bufferize | FileCheck %s
 
 // Test whether all shape computations required for tanh can be lowered to
 // the standard dialect, scf and descriptors. We check for a sparse pattern here,
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir
index 9b3baf69488..2fc585d9e9d 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --hlo-bufferize
+// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize
 
 func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
index 7d263169d60..62dd39b8897 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
@@ -75,7 +75,7 @@ extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
 extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx,
                                              int32_t error_code, char* msg) {
   Optional<ErrorCode> symbol = symbolizeErrorCode(error_code);
-  if (symbol.hasValue()) {
+  if (!symbol.hasValue()) {
     LOG(ERROR) << "No valid conversion from integer value = " << error_code
               << " to ErrorCode attribute";
     return;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc
index a184b80ea44..ac72f6b157c 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc
@@ -55,12 +55,14 @@ namespace {
 /// A temporary buffer size analysis that is correct but may be incomplete.
 class BufferSizeAnalysis {
  public:
-  explicit BufferSizeAnalysis(FuncOp f) { build(f); }
+  BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) {
+    build(f, aliases);
+  }
 
   bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); }
 
  private:
-  void build(FuncOp &f) {
+  void build(FuncOp &f, const BufferAliasAnalysis &aliases) {
     auto buffers = find_buffer_values(f);
 
     // Memrefs with statically known same shape and same symbol-free affine maps
@@ -102,10 +104,16 @@ class BufferSizeAnalysis {
       }
     });
 
-    // Operand and result of `reshape_memref_cast` must be of same size.
-    f.walk([&](MemRefReshapeOp reshapeOp) {
-      ecs_.unionSets(reshapeOp.result(), reshapeOp.source());
-    });
+    // All aliases of a memref must be of the same underlying buffer size.
+    for (auto e : aliases) {
+      Value value = e.getFirst();
+      if (!value.getType().isa<BaseMemRefType>()) continue;
+      for (Value alias : e.getSecond()) {
+        assert(alias.getType().isa<BaseMemRefType>() &&
+               "Expected aliases of memref to be memrefs.");
+        ecs_.unionSets(value, alias);
+      }
+    }
   }
 
   bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as,
@@ -181,7 +189,7 @@ class BufferReuseAnalysis {
 
   void find_reuse_candiates(FuncOp &f, BufferAliasAnalysis &aliases) {
     Liveness liveness(f);
-    BufferSizeAnalysis size_equivalences(f);
+    BufferSizeAnalysis size_equivalences(f, aliases);
     f.walk([&](Block *block) {
       find_reuse_candiates(block, aliases, liveness.getLiveness(block),
                            size_equivalences, f.getArguments());
@@ -204,50 +212,53 @@ class BufferReuseAnalysis {
 
       // Find reuse candidates for the regarded allocation.
       SmallVector<int64_t, 2> local_reuse_candidates;
-      for (auto it : llvm::enumerate(arguments)) {
-        int64_t old_buffer_index = it.index();
-        Value old_buffer = it.value();
+      for (BlockArgument old_buffer : arguments) {
         if (!old_buffer.getType().isa<BaseMemRefType>()) continue;
 
-        // Will not reuse buffers of different size as they may be too small.
+        // Size criterion: Do not reuse buffers of different size as they may be
+        // too small.
         if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue;
 
-        // Only reuse buffers that are no longer used on first reuse, i.e. they
-        // are no longer alive.
-        bool livetimes_compatible = true;
+        // Lifetime criterion: Only reuse buffers that are no longer used on
+        // first reuse, i.e. they are no longer alive.
+        bool lifetimes_compatible = true;
         for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
           if (first_reuse == nullptr) {
             // If the first use is beyond the end of this block we look at the
             // block end. An argument buffer that is already reusable there is
-            // certainly reusable at any later actual use.
+            // certainly reusable at any later actual use. Otherwise, lifetimes
+            // are incompatible.
             if (liveness->isLiveOut(old_buffer_alias)) {
-              livetimes_compatible = false;
+              lifetimes_compatible = false;
               break;
             }
           } else {
-            // A buffer is *not* reusable if
-            //   i)  its last use is after the point of reuse, or
-            //   ii) its last use is also its first reuse but the operation
-            //       does not allow for local reuse.
+            // A buffer is reusable if
+            //   i)  its last use is before the point of reuse, or
+            //   ii) its last use is also its first reuse and the operation
+            //       allows for local reuse.
+            // Otherwise, lifetimes are incompatible.
             Operation *last_use =
                 liveness->getEndOperation(old_buffer_alias, &block->front());
             assert(last_use != nullptr && last_use->getBlock() == block &&
                    "Expected last use in same block.");
             if (first_reuse->isBeforeInBlock(last_use)) {
-              livetimes_compatible = false;
+              lifetimes_compatible = false;
               break;
             }
             if (first_reuse == last_use &&
                 !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
-              livetimes_compatible = false;
+              lifetimes_compatible = false;
               break;
             }
           }
         }
 
-        // All criteria are fulfilled 🙂.
-        if (livetimes_compatible)
+        if (lifetimes_compatible) {
+          // All criteria are fulfilled 🙂.
+          int64_t old_buffer_index = old_buffer.getArgNumber();
           local_reuse_candidates.push_back(old_buffer_index);
+        }
       }
 
       reuse_candidates_[&op] = local_reuse_candidates;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index 46dbfbb8405..6199ea14b4b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -48,40 +48,6 @@ namespace kernel_gen {
 namespace transforms {
 namespace {
 
-#define GEN_PASS_CLASSES
-#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
-
-struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
-  // TODO(b/173201243): Move to tablegen.
-  void getDependentDialects(DialectRegistry& registry) const override {
-    registry.insert<lmhlo::LmhloDialect>();
-  }
-
- public:
-  void runOnOperation() override {
-    OwningRewritePatternList patterns;
-    auto& context = getContext();
-    ConversionTarget target(context);
-    target.addLegalDialect<lmhlo::LmhloDialect>();
-    target.addLegalDialect<StandardOpsDialect>();
-    target.addIllegalDialect<mhlo::MhloDialect>();
-
-    BufferizeTypeConverter converter;
-    // Configure bufferize pattern for functions and lhlo.
-    mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
-
-    // Configure legality and structural patterns.
-    populateBufferizeMaterializationLegality(target);
-    populateShapeStructuralTypeConversionsAndLegality(&context, converter,
-                                                      patterns, target);
-    scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
-                                                         patterns, target);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
-  }
-};
-
 // TODO(herhut) : This could become a real pattern in bufferize pass. What we
 // would need to do is insert a copy to model the semantics correctly. The same
 // is true for the TensorLoad pattern that is already in there.  Then buffer
@@ -104,11 +70,34 @@ class UnrankedTensorStoreTestOnlyPattern
   }
 };
 
-struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
+// TODO(frgossen): Move this upstream to `populateFuncOpTypeConversionPattern`
+// This pattern is merely needed to materialize type casts for return values so
+// that they match the function signature conversion.
+class ReturnOpTypeConversionPattern
+    : public OpConversionPattern<mlir::ReturnOp> {
+ public:
+  using OpConversionPattern<mlir::ReturnOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      ReturnOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
+    return success();
+  }
+};
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
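+// Bufferization pass with two modes: full conversion (the default) rejects any
+// remaining tensor_load/tensor_to_memref ops, while partial bufferization
+// leaves them legal so shape computations can stay on tensors.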
+struct BufferizePass : public BufferizePassBase<BufferizePass> {
+  explicit BufferizePass(bool allow_partial_bufferization) {
+    allow_partial_bufferization_ = allow_partial_bufferization;
+  }
+
   // TODO(b/173201243): Move to tablegen.
   void getDependentDialects(DialectRegistry& registry) const override {
     registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect,
-                    tf_framework::TFFrameworkDialect>();
+                    tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>();
   }
 
  public:
@@ -117,12 +106,17 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
     ConversionTarget target(context);
     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
                            tf_framework::TFFrameworkDialect, AffineDialect,
-                           shape::ShapeDialect>();
+                           shape::ShapeDialect, lmhlo::LmhloDialect>();
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
     target.addIllegalDialect<mhlo::MhloDialect>();
     target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
-                        TensorFromElementsOp, TensorCastOp, TensorLoadOp,
-                        TensorToMemrefOp>();
+                        TensorFromElementsOp, TensorCastOp>();
+
+    if (!allow_partial_bufferization_) {
+      target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+    }
+
     // Certain operations are no longer legal on tensors but otherwise are.
     target.addDynamicallyLegalOp<ConstantOp, SelectOp>([&](Operation* op) {
       return llvm::none_of(op->getResultTypes(),
@@ -144,8 +138,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
       return converter.isLegal(inputs) && converter.isLegal(results) &&
              converter.isLegal(&op.getBody());
     });
-    target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp>(
-        typesAreLegal);
+    target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp,
+                                 ReturnOp>(typesAreLegal);
 
     OwningRewritePatternList patterns;
     mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
@@ -160,22 +154,27 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
     scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
                                                          patterns, target);
     patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context);
+    patterns.insert<ReturnOpTypeConversionPattern>(converter, &context);
 
     auto module = getOperation();
-    if (failed(applyFullConversion(module, target, std::move(patterns)))) {
-      signalPassFailure();
+    if (allow_partial_bufferization_) {
+      if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
+        signalPassFailure();
+      }
+    } else {
+      if (failed(
+              mlir::applyFullConversion(module, target, std::move(patterns)))) {
+        signalPassFailure();
+      }
     }
   }
 };
 
 }  // namespace
 
-std::unique_ptr<OperationPass<ModuleOp> > CreateHloBufferizePass() {
-  return std::make_unique<HloBufferizePass>();
-}
-
-std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
-  return std::make_unique<FinalBufferizePass>();
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass(
+    bool allow_partial_bufferization) {
+  return std::make_unique<BufferizePass>(allow_partial_bufferization);
 }
 
 }  // namespace transforms
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
index 1d845f5ae5d..50d514eeb0a 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
@@ -37,9 +37,8 @@ namespace {
 
 using tf_framework::TFFrameworkDialect;
 
-Operation* emitCallToFunc(Location loc, StringRef func_name,
-                          ArrayRef<Type> return_types, ValueRange args,
-                          OpBuilder* b) {
+Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg,
+                           OpBuilder* b) {
   auto caller_func =
       b->getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
   auto callee_func =
@@ -49,12 +48,12 @@ Operation* emitCallToFunc(Location loc, StringRef func_name,
 
     auto module = caller_func.getParentOfType<ModuleOp>();
     b->setInsertionPointToStart(module.getBody());
-    auto func_type =
-        FunctionType::get(args.getTypes(), return_types, b->getContext());
+    auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
+                                       b->getContext());
     callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
     callee_func.setPrivate();
   }
-  return b->create<CallOp>(loc, callee_func, args);
+  return b->create<CallOp>(loc, callee_func, arg);
 }
 
 void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) {
@@ -70,28 +69,32 @@ void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) {
       liveness.getLiveness(op->getBlock())->getEndOperation(memref, op);
   b->setInsertionPoint(end_op);
 
+  if (element_type.isIndex()) {
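+    // Index-typed memrefs are printed via print_memref_i64; cast the buffer's
+    // element type to i64 first.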
+    element_type = b->getI64Type();
+    memref_type = MemRefType::get(memref_type.getShape(), element_type,
+                                  memref_type.getAffineMaps(),
+                                  memref_type.getMemorySpace());
+    memref = b->create<IndexCastOp>(loc, memref, memref_type);
+  }
+
   auto unranked_type =
       UnrankedMemRefType::get(element_type, memref_type.getMemorySpace());
   Value unranked_memref = b->create<MemRefCastOp>(loc, memref, unranked_type);
 
   if (element_type.isF32()) {
-    emitCallToFunc(loc, "print_memref_f32", {}, {unranked_memref}, b);
+    emitCallToPrint(loc, "print_memref_f32", unranked_memref, b);
     return;
   }
   if (element_type.isF64()) {
-    emitCallToFunc(loc, "print_memref_f64", {}, {unranked_memref}, b);
-    return;
-  }
-  if (element_type.isIndex()) {
-    emitCallToFunc(loc, "print_memref_index", {}, {unranked_memref}, b);
+    emitCallToPrint(loc, "print_memref_f64", unranked_memref, b);
     return;
   }
   if (element_type.isInteger(32)) {
-    emitCallToFunc(loc, "print_memref_i32", {}, {unranked_memref}, b);
+    emitCallToPrint(loc, "print_memref_i32", unranked_memref, b);
     return;
   }
-  if (element_type.isInteger(64)) {
-    emitCallToFunc(loc, "print_memref_i64", {}, {unranked_memref}, b);
+  if (element_type.isInteger(64) || element_type.isIndex()) {
+    emitCallToPrint(loc, "print_memref_i64", unranked_memref, b);
     return;
   }
 }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index e2d92c3c8c3..c067d2b779d 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -47,13 +47,10 @@ std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass();
 // using memref descriptors.
 std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
 
-// Pass to tranform hlo-level computations on values to their corresponding
-// parts on buffers.
-std::unique_ptr<OperationPass<ModuleOp>> CreateHloBufferizePass();
-
-// Pass to tranform late-dialect level computations (essentially all non-hlo
-// dialects) on values to their corresponding parts on buffers.
-std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass();
+// Pass to transform operations on values to their corresponding parts on
+// buffers.
+std::unique_ptr<OperationPass<ModuleOp>> CreateBufferizePass(
+    bool allow_partial_bufferization = false);
 
 // Pass to materialize broadcasts.
 std::unique_ptr<FunctionPass> CreateMaterializeBroadcastsPass();
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index b6f8202217a..96e66297771 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -38,16 +38,14 @@ def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> {
   let constructor = "transforms::CreateShapeToDescriptorsPass()";
 }
 
-def HloBufferizePass : Pass<"hlo-bufferize", "ModuleOp"> {
-  let summary = "Pass to transform hlo operations on values to buffer based "
-                "ones.";
-  let constructor = "transforms::CreateHloBufferizePass()";
-}
-
-def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> {
-  let summary = "Pass to transform operations from all non-hlo dialects on "
-                "values to buffer-based ones.";
-  let constructor = "transforms::CreateFinalBufferizePass()";
+def BufferizePass : Pass<"bufferize", "ModuleOp"> {
+  let summary = "Pass to transform operations on values to buffer-based ones.";
+  let options = [
+    Option<"allow_partial_bufferization_", "allow-partial-bufferization",
+           "bool", /*default=*/"false", "Allow partial bufferization. "
+           "Value-based operations may remain, e.g. for shape operations.">,
+  ];
+  let constructor = "transforms::CreateBufferizePass()";
 }
 
 def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> {
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index c1e08284d21..5dc89adffc4 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -140,6 +140,7 @@ cc_library(
         ":translate_cl_options",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
@@ -148,6 +149,8 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
+        "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 95595d850d2..395d4bb8f9f 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -63,8 +63,11 @@ HloModule SelectAndScatter
   ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10)
 }
 
+// CHECK-LABEL: module
+// CHECK: global_memref "private" constant @[[$GLOBAL:.*]] : memref<f32> = dense<0.000000e+00>
 // CHECK-LABEL: func @main
-// CHECK: "lmhlo.select_and_scatter"
+// CHECK: %[[GLOBAL_MEMREF:.*]] = get_global_memref @[[$GLOBAL]] : memref<f32>
+// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}})
 // CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
 // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"}
 // CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> ()
@@ -78,7 +81,7 @@ HloModule SelectAndScatter
 ENTRY main () -> f32[6] {
   %operand = f32[6]{0} parameter(0)
   %source = f32[2]{0} parameter(1)
-  %init = f32[] parameter(2)
+  %init = f32[] constant(0)
   ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32
 }
 
@@ -101,4 +104,20 @@ ENTRY main {
                                                   s32[] %static),
                                       custom_call_target="SliceToDynamic",
                                       backend_config=""
-}
\ No newline at end of file
+}
+
+// -----
+
+
+HloModule Cholesky
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.cholesky"
+// CHECK-SAME: is_lower = true
+ENTRY main {
+  %param = f32[3,3] parameter(0)
+  ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param),
+                                custom_call_target="__cusolver$cholesky",
+                                operand_layout_constraints={f32[3,3]},
+                                backend_config="{\"lower\":true}"
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
index e7312e2114c..cd72707123c 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
@@ -374,3 +374,23 @@ func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f3
 
   return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
 }
+
+// -----
+
+// CHECK-LABEL: func @main
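+// Variadic reduce: two operands and two init values; the reduction region is
+// lowered to an lmhlo.reduce body that yields a tuple, as checked below.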
+// CHECK:   "lmhlo.reduce"({{.*}}) ( {
+// CHECK:   ^bb0(%[[VAL1:.*]]: tensor<f32>, %[[VAL2:.*]]: tensor<i32>, %[[VAL3:.*]]: tensor<f32>, %[[VAL4:.*]]: tensor<i32>):  // no predecessors
+// CHECK:     %[[VAL5:.*]] = mhlo.maximum %[[VAL1]], %[[VAL3]] : tensor<f32>
+// CHECK:     %[[VAL6:.*]] = mhlo.maximum %[[VAL2]], %[[VAL4]] : tensor<i32>
+// CHECK:     %[[VAL7:.*]] = "mhlo.tuple"(%[[VAL5]], %[[VAL6]]) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
+// CHECK:     "mhlo.return"(%[[VAL7]]) : (tuple<tensor<f32>, tensor<i32>>) -> ()
+// CHECK:   })
+func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
+  %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
+    ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>):   // no predecessors
+      %fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      %imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      "mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
+    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
+  return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
index 7f37dbb0479..f24b1b6e6bc 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
@@ -13,10 +13,8 @@
 // CHECK-LABEL: func @add
 func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   // CHECK-NEXT:  %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32>
-  // CHECK-NEXT:  %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32>
-  // CHECK-NEXT:  return %[[SUM1]] : tensor<2xi32>
-  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  // CHECK-NEXT:  return %[[SUM0]] : tensor<2xi32>
+  %1 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %1: tensor<2xi32>
 }
 
@@ -27,7 +25,7 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
   // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }
 
@@ -37,7 +35,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
 func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
   // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
   // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
   return %0: tensor<4x4x4x4xi32>
 }
 
@@ -55,7 +53,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
   // CHECK-NEXT:   %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
   // CHECK-NEXT:   %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
   // CHECK-NEXT:   shape.assuming_yield %[[RESULT]]
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
   return %0: tensor<?x?xi32>
 }
 
@@ -64,7 +62,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
 func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
   // CHECK: tf.Add
   // CHLO: chlo.broadcast_add %arg0, %arg1
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
   return %0: tensor<*xi32>
 }
 
@@ -264,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
   return %0: tensor<*xi1>
 }
 
+// CHECK-LABEL: func @equal_unsupported_type
+func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> {
+  // CHECK: "tf.Equal"
+  %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1>
+  return %0: tensor<*xi1>
+}
+
 // CHECK-LABEL: func @notequal
 func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
   // CHECK-NEXT:  "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir
index 0660af4ed1c..0646a391eca 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir
@@ -26,7 +26,7 @@ func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 // -----
 
 func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
 }
 
@@ -38,7 +38,7 @@ func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   // expected-error@+1 {{'tf.OpA' op is not legalizable}}
   %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %2 = "tf.AddV2"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %2: tensor<2xi32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index fd6ed5a0219..5061de06c6f 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -3984,31 +3984,35 @@ func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
 // tf.Unpack legalization
 //===----------------------------------------------------------------------===//
 
-// TODO(b/156340000): Re-enable when fixed.
-// // C-HECK-LABEL: @unpack
-// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
-//   // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
-//   // C-HECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32>
-//   // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
-//   // C-HECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
-//   // C-HECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
-//   // C-HECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
+// CHECK-LABEL: @unpack
+func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
+  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
+  // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
+  // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
+  // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
 
-//   %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
-//   // return %[[RES1]], %[[RES2]], %[[RES3]]
-//   return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>
-// }
+  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
+  // CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
+  return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
+}
 
-// // C-HECK-LABEL: @unpack_dynamic
-// func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
-//   // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
-//   // C-HECK: "mhlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
-//   // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
-//   // C-HECK: "mhlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: @unpack_dynamic
+func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
 
-//   %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
-//   return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
-// }
+  // CHECK: tf.Unpack
+  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @unpack_unranked
+func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+
+  // CHECK: tf.Unpack
+  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
 
 //===----------------------------------------------------------------------===//
 // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index fe75a2a8dff..f30fdeb0991 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -3311,7 +3311,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
       auto input_val = GetScalarConstOfType(begin_element_ty, loc,
                                             input_shape[d], &rewriter);
       auto wrapped_index =
-          rewriter.create<TF::AddOp>(loc, input_val, reshaped_index);
+          rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
       auto final_index = rewriter.create<SelectOp>(
           loc, type, index_negative, wrapped_index, reshaped_index);
       slice_begin_indices.push_back(final_index);
@@ -4808,7 +4808,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
 
   LogicalResult matchAndRewrite(TF::UnpackOp op,
                                 PatternRewriter &rewriter) const override {
-    auto value_type = op.value().getType().cast<RankedTensorType>();
+    auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
     if (!value_type) return failure();
 
     int64_t value_rank = value_type.getRank();
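
Switching from `cast` to `dyn_cast` here is what makes the new `@unpack_unranked` test above pass: `cast<RankedTensorType>` asserts when the type is unranked, while `dyn_cast` returns a null value that the pattern can turn into a graceful `failure()`. A minimal sketch of the guard:

    // cast<> would assert on tensor<*xf32>; dyn_cast<> lets the pattern bail
    // out and leave the op for a later pass instead of crashing the compiler.
    auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!value_type) return failure();  // unranked input: not our pattern
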
@@ -4820,7 +4820,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
     auto end_indices = llvm::to_vector<4>(value_type.getShape());
     SmallVector<int64_t, 4> strides(value_rank, 1);
 
-    // All HLO slice+reshape results used to replace the original tf.Unpack op.
+    // All HLO slice+squeeze results used to replace the original tf.Unpack op.
     SmallVector<Value, 4> results;
     results.reserve(op.getNumResults());
 
@@ -4833,9 +4833,10 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
           GetI64ElementsAttr(end_indices, &rewriter),
           GetI64ElementsAttr(strides, &rewriter));
-      // Reshape to drop the axis dimension.
+      // Squeeze to drop the axis dimension.
-      auto reshape_op = rewriter.create<mhlo::ReshapeOp>(
-          op.getLoc(), op.getType(i), slice_op);
-      results.push_back(reshape_op);
+      auto result =
+          rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
+                                         rewriter.getI64ArrayAttr(op.axis()));
+      results.push_back(result);
     }
 
     rewriter.replaceOp(op, results);
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 11af809ffb7..113d88158d2 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -89,8 +89,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp>
   : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
         (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
 
-foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp],
-                         [TF_AddV2Op, HLOClient_BroadcastAddOp],
+foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
                          [TF_DivOp, HLOClient_BroadcastDivOp],
                          [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
                          [TF_MaximumOp, HLOClient_BroadcastMaxOp],
@@ -225,7 +224,7 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction>
         (HLOClient_BroadcastCompareOp
          $l, $r, (BinBroadcastDimensions $l, $r), direction,
          (HLO_DEFAULT_COMPARISON_TYPE)),
-        [(AreBroadcastCompatible $l, $r)]>;
+        [(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>;
 
 def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
 def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;
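
The added `(HLO_Tensor $l)` constraint is what keeps `tf.Equal` on `!tf.string` operands unlowered, as the `equal_unsupported_type` test earlier in this patch checks. A rough C++ reading of the predicate, under the assumption that int/float/complex are the element types `HLO_Tensor` admits here:

    // Hypothetical helper, not TF code: mirrors what the TableGen constraint
    // asks of the lhs operand before the equality pattern may fire.
    // (Sketch fragment; assumes the usual mlir/IR includes.)
    static bool IsHloCompatibleTensor(mlir::Value v) {
      auto tensor_ty = v.getType().dyn_cast<mlir::TensorType>();
      if (!tensor_ty) return false;
      mlir::Type elt = tensor_ty.getElementType();
      return elt.isIntOrFloat() || elt.isa<mlir::ComplexType>();
    }
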
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index edc7edf8068..505bdcdcf48 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
 
+#include <climits>
 #include <memory>
 #include <tuple>
 
@@ -38,6 +39,7 @@ limitations under the License.
 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
@@ -46,18 +48,21 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 
 using xla::BufferAllocation;
 using xla::BufferAssignment;
 using xla::HloComputation;
+using xla::HloCustomCallInstruction;
 using xla::HloInstruction;
 using xla::HloModule;
 using xla::HloModuleProto;
@@ -140,8 +145,9 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
 class XlaHloToLhloPass
     : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry& registry) const override {
-    registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
-                    mlir::lmhlo::LmhloDialect>();
+    registry
+        .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
+                mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
   }
 
  public:
@@ -274,6 +280,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
       return EmitSelectAndScatterOp(instr);
     case HloOpcode::kCustomCall:
       return EmitCustomCallOp(instr);
+    case HloOpcode::kConstant:
+      return EmitConstant(instr);
+    case HloOpcode::kReduce:
+      return EmitReduceOp(instr);
     default:
       llvm::errs() << instr->ToString();
       return tensorflow::errors::Internal(
@@ -485,13 +495,18 @@ StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
   return select_and_scatter;
 }
 
-StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp(
+StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
     HloInstruction* instr) {
+  auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr);
+
+  if (xla::gpu::IsCustomCallToCusolver(*instr)) {
+    return EmitCholesky(custom_call_instr);
+  }
+
   size_t num_arguments, num_results;
   TF_ASSIGN_OR_RETURN(auto custom_call,
                       CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
                           instr, num_arguments, num_results));
-  auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr);
   custom_call.call_target_nameAttr(
       builder_.getStringAttr(custom_call_instr->custom_call_target()));
   custom_call.backend_configAttr(
@@ -500,12 +515,102 @@ StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp(
                                static_cast<int32_t>(num_results)};
   custom_call.setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
                       builder_.getI32VectorAttr(segments));
-  return custom_call;
+  return custom_call.getOperation();
+}
+
+StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
+    HloCustomCallInstruction* custom_call) {
+  TF_ASSIGN_OR_RETURN(auto cholesky_op,
+                      CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
+  TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
+                      custom_call->backend_config<xla::CholeskyOptions>());
+  cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
+  return cholesky_op;
+}
+
+// Convert an XLA HLO constant to a global_memref + get_global_memref pair.
+StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
+    const HloInstruction* instr) {
+  // Insert a global_memref in the module.
+  Location loc = getLocation(instr);
+
+  auto const_instr = ::xla::Cast<::xla::HloConstantInstruction>(instr);
+  TF_RET_CHECK(const_instr->shape().IsArray() &&
+               const_instr->shape().is_static());
+  TF_ASSIGN_OR_RETURN(Type type, ::xla::ConvertShapeToType<MemRefType>(
+                                     const_instr->shape(), builder_));
+  auto memref_type = type.dyn_cast<MemRefType>();
+  TF_RET_CHECK(memref_type != nullptr);
+
+  TF_ASSIGN_OR_RETURN(
+      DenseElementsAttr initial_value,
+      CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
+
+  std::string constant_name = ::xla::llvm_ir::ConstantHloToGlobalName(*instr);
+
+  // Insert the global memref at the top level.
+  {
+    OpBuilder::InsertionGuard guard(builder_);
+    builder_.clearInsertionPoint();
+    auto global_var = builder_.create<GlobalMemrefOp>(
+        loc, constant_name, builder_.getStringAttr("private"),
+        TypeAttr::get(memref_type), initial_value, true);
+    SymbolTable(module_).insert(global_var);
+    global_var.getOperation()->moveBefore(&module_.front());
+  }
+
+  auto get_global_memref =
+      builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name);
+
+  // For operations that do not fold this constant value in their codegen, we
+  // still need to materialize it into a buffer. Since buffer allocation is
+  // already done, annotate the get_global_memref with the information needed
+  // to locate the allocated buffer slice for this constant if need be.
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
+                      assignment_.GetUniqueTopLevelSlice(instr));
+  get_global_memref.setAttr("lmhlo.alloc",
+                            builder_.getIndexAttr(slice.index()));
+  get_global_memref.setAttr("lmhlo.slice_offset",
+                            builder_.getI64IntegerAttr(slice.offset()));
+  get_global_memref.setAttr("lmhlo.slice_size",
+                            builder_.getI64IntegerAttr(slice.size()));
+
+  // Update the cache to remember this value.
+  auto& cached_value = slices_[std::make_pair(instr, ::xla::ShapeIndex())];
+  TF_RET_CHECK(cached_value == nullptr);
+  cached_value = get_global_memref;
+  return get_global_memref;
+}
+
+StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto reduce_op,
+                      CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
+  auto* reduce = ::xla::Cast<::xla::HloReduceInstruction>(instr);
+  std::vector<int64_t> dimensions(reduce->dimensions().begin(),
+                                  reduce->dimensions().end());
+  reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
+  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+      *instr->called_computations()[0], &reduce_op.body(), &builder_));
+  return reduce_op;
 }
 
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
     const ::xla::ShapeIndex& shape_index) {
+  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
+  // shape_index).
+  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
+  if (cached_value) {
+    return cached_value;
+  }
+
+  if (instr->IsConstant() && shape_index.empty()) {
+    TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
+    return cached_value = constant_memref;
+  }
+
   // If the shape happens to have dynamic dimensions, create the memref using
   // the underlying static shape.
   // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
@@ -518,7 +623,7 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
                       assignment_.GetUniqueSlice(instr, shape_index));
   Value alloc = allocations_[slice.allocation()];
   if (alloc.getType() == out_type && slice.offset() == 0) {
-    return alloc;
+    return cached_value = alloc;
   }
 
   auto out_memref_type = out_type.dyn_cast<MemRefType>();
@@ -527,13 +632,6 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
         "Expected memref type when creating a view for leaf type of a "
         "tuple.");
 
-  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
-  // shape_index).
-  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
-  if (cached_value) {
-    return cached_value;
-  }
-
   Value byte_shift =
       builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());
 
@@ -695,8 +793,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
 Status HloToLhloModule(const BufferAssignment& assignment,
                        const HloModule& hlo_module, ModuleOp module) {
   module.getContext()
-      ->loadDialect<StandardOpsDialect, mhlo::MhloDialect,
-                    lmhlo::LmhloDialect>();
+      ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
+                    lmhlo_gpu::LmhloGpuDialect>();
   HloComputation* computation = hlo_module.entry_computation();
 
   LhloDialectEmitter emitter(assignment, *computation, module);
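
The restructuring of `GetOrCreateArrayView` above moves the cache lookup to the very top and makes every early exit (constants, whole-allocation views) assign through the cached slot. A minimal standalone sketch of that contract, with plain ints standing in for MLIR `Value`s:

    #include <map>
    #include <utility>

    using Key = std::pair<const void*, int>;  // stands in for (instr, ShapeIndex)

    // Look up first; every path that produces a value writes it back through
    // the reference, so later queries for the same key hit the cache.
    int GetOrCreateView(std::map<Key, int>& slices, Key key, int (*emit)(Key)) {
      int& cached = slices[key];   // default 0 plays the role of "no Value yet"
      if (cached) return cached;   // hit: reuse the previously emitted value
      return cached = emit(key);   // miss: emit once, record, return
    }
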
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 29a4c67d366..169cbdc93c5 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -16,13 +16,15 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
 #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 
 namespace mlir {
@@ -55,8 +57,18 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   ::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
       ::xla::HloInstruction* instr);
-  ::xla::StatusOr<lmhlo::CustomCallOp> EmitCustomCallOp(
-      ::xla::HloInstruction* instr);
+
+  ::xla::StatusOr<Operation*> EmitCustomCallOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
+      ::xla::HloCustomCallInstruction* custom_call);
+
+  ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
+      ::xla::HloInstruction* instr) {
+    return EmitConstant(static_cast<const ::xla::HloInstruction*>(instr));
+  }
+  ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
+      const ::xla::HloInstruction* instr);
 
   ::xla::Status CreateOperands(::xla::HloInstruction* instr,
                                SmallVectorImpl<Value>& operands,
@@ -122,7 +134,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
                                               OpBuilder* b, Location loc);
 
   // Return an MLIR location for an HLO instruction.
-  Location getLocation(::xla::HloInstruction* inst) {
+  Location getLocation(const ::xla::HloInstruction* inst) {
     return NameLoc::get(builder_.getIdentifier(inst->name()),
                         builder_.getContext());
   }
@@ -180,7 +192,7 @@ tensorflow::Status HloToLhloModule(const ::xla::BufferAssignment& assignment,
                                    ModuleOp module);
 
 OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
-                                               mlir::MLIRContext* context);
+                                               MLIRContext* context);
 
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 99d7730bfe5..db631b9c5bb 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -85,6 +85,9 @@ namespace tensorflow {
 namespace tensorrt {
 namespace convert {
 
+using absl::StrAppend;
+using absl::StrCat;
+
 bool IsEngineInput(absl::string_view name) {
   return absl::StartsWith(name, IONamePrefixes::kInputPHName);
 }
@@ -92,47 +95,6 @@ bool IsEngineOutput(absl::string_view name) {
   return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
 }
 
-using absl::StrAppend;
-using absl::StrCat;
-
-inline Status TfDataTypeToTrt(DataType tf_dtype,
-                              nvinfer1::DataType* trt_dtype) {
-  switch (tf_dtype) {
-    case DataType::DT_FLOAT:
-      *trt_dtype = nvinfer1::DataType::kFLOAT;
-      break;
-    case DataType::DT_HALF:
-      *trt_dtype = nvinfer1::DataType::kHALF;
-      break;
-    case DataType::DT_INT32:
-      *trt_dtype = nvinfer1::DataType::kINT32;
-      break;
-    default:
-      return errors::InvalidArgument("Unsupported data type ",
-                                     DataTypeString(tf_dtype));
-  }
-  return Status::OK();
-}
-
-inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype,
-                              DataType* tf_dtype) {
-  switch (trt_dtype) {
-    case nvinfer1::DataType::kFLOAT:
-      *tf_dtype = DataType::DT_FLOAT;
-      break;
-    case nvinfer1::DataType::kHALF:
-      *tf_dtype = DataType::DT_HALF;
-      break;
-    case nvinfer1::DataType::kINT32:
-      *tf_dtype = DataType::DT_INT32;
-      break;
-    default:
-      return errors::InvalidArgument("Unsupported data type ",
-                                     DebugString(trt_dtype));
-  }
-  return Status::OK();
-}
-
 class TFAttrs {
  public:
   explicit TFAttrs(const NodeDef& tf_node) {
@@ -182,7 +144,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
 template <>
 nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
   nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
-  TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype));
+  TF_CHECK_OK(TfTypeToTrtType(this->at(key)->type(), &trt_dtype));
   return trt_dtype;
 }
 
@@ -271,7 +233,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
                                 nvinfer1::DataType* trt_dtype,
                                 nvinfer1::Dims* trt_dims, int* batch_size) {
   // Convert data type.
-  TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype));
+  TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, trt_dtype));
 
   // Convert shape.
   if (shape.dims() < 0) {
@@ -512,7 +474,7 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
   TFAttrs attrs(params->node_def);
   if (attrs.count(dtype_attr_name)) {
     DataType dtype = attrs.get<DataType>(dtype_attr_name);
-    TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_type));
+    TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type));
   }
 
   // In order to be broadcastable, the number of dims has to match.
@@ -1091,7 +1053,7 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
   DataType tf_dtype;
   // TODO(laigd): make it return a status.
   TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
-  TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype));
+  TF_CHECK_OK(TrtTypeToTfType(trt_dtype, &tf_dtype));
   // TODO(jie): check weights size_bytes. 0 means type error
   Tensor tensor(tf_dtype, shape);
   TRT_ShapedWeights weights(trt_dtype, dims, tensor);
@@ -2621,6 +2583,7 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input,
   nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
       const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
       concat_inputs.size());
+  SetLayerName(concat_layer, params->node_def, "concat", op_instance);
   concat_layer->setAxis(0);
   nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
   // Reshape input using new shape
@@ -4219,7 +4182,7 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
 
   // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
   nvinfer1::DataType trt_dtype;
-  TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype));
+  TF_RETURN_IF_ERROR(TfTypeToTrtType(converted_dtype, &trt_dtype));
 
   if (tensor.NumElements() == 0) {
     // Return empty weights.
@@ -6244,7 +6207,7 @@ Status ConvertGraphDefToEngine(
       TFAttrs attrs(node_def);
       DataType tf_dtype = attrs.get<DataType>("T");
       nvinfer1::DataType trt_dtype;
-      TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype));
+      TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, &trt_dtype));
       if (output_tensors.size() <= slot_number) {
         output_tensors.resize(slot_number + 1);
       }
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index a33d5c28cb2..1d60ebbd047 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -135,20 +135,6 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
   return os;
 }
 
-nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) {
-  nvinfer1::DataType trt_type;
-  Status status = TfTypeToTrtType(tf_type, &trt_type);
-  EXPECT_EQ(status, Status::OK());
-  return trt_type;
-}
-
-DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) {
-  DataType tf_type;
-  Status status = TrtTypeToTfType(trt_type, &tf_type);
-  EXPECT_EQ(status, Status::OK());
-  return tf_type;
-}
-
 NodeDef MakeNodeDef(const string& name, const string& op,
                     const std::vector<string>& inputs,
                     const std::map<string, AttrValue> attrs = {}) {
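
The wrappers deleted above could only `EXPECT` that the conversion succeeded and then return a possibly uninitialized dtype, because gtest's `ASSERT_*` macros work only in functions returning `void`. Moving the check to each call site lets the test abort cleanly, which is the pattern adopted throughout the rest of this file:

    // Sketch of the call-site pattern (dtype and the tensor arguments are
    // placeholders taken from the surrounding tests):
    nvinfer1::DataType trt_type;
    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));  // aborts the test on error
    test->AddTestTensor("value", /*dims=*/{1, 2}, /*batch_size=*/1, trt_type);
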
@@ -1048,8 +1034,10 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
 
 template <typename T>
 void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
-  TRT_ShapedWeights weights = weight_store->GetTempWeights(
-      TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3}));
+  nvinfer1::DataType trt_type;
+  TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type));
+  TRT_ShapedWeights weights =
+      weight_store->GetTempWeights(trt_type, GetTestDims({2, 3}));
   const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
   memcpy(weights.GetValues(), values.data(), weights.size_bytes());
 
@@ -1445,7 +1433,8 @@ class OpConverterTest : public ::testing::Test {
       ASSERT_NE(-1, input_index);
       const nvinfer1::DataType trt_dtype =
           engine_->getBindingDataType(input_index);
-      const DataType tf_type = TrtDataTypeToTf(trt_dtype);
+      DataType tf_type;
+      TF_ASSERT_OK(TrtTypeToTfType(trt_dtype, &tf_type));
       ASSERT_EQ(data.tensor.dtype(), tf_type)
           << DataTypeString(data.tensor.dtype()) << " vs. "
           << DataTypeString(tf_type);
@@ -1457,8 +1446,9 @@ class OpConverterTest : public ::testing::Test {
     // Mark the output tensor as TRT engine output.
     std::vector<Converter::EngineOutputInfo> output_info;
     for (const auto& data : *output_data) {
-      output_info.push_back(
-          {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())});
+      nvinfer1::DataType trt_type;
+      TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type));
+      output_info.push_back({data.name, data.name, trt_type});
     }
     TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info));
 
@@ -1519,7 +1509,8 @@ class OpConverterTest : public ::testing::Test {
       const string& name, const std::vector<int32>& dims,
       nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT,
       Status add_input_status = Status::OK()) {
-    DataType tf_type = TrtDataTypeToTf(trt_type);
+    DataType tf_type;
+    TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type));
     ops::Placeholder::Attrs attrs;
     TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
 
@@ -1566,7 +1557,8 @@ class OpConverterTest : public ::testing::Test {
     node_inputs_[name] = ops::Const(scope_.WithOpName(name), t);
 
     // Add weights for conversion.
-    const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v());
+    nvinfer1::DataType dtype;
+    TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &dtype));
     const nvinfer1::Dims trt_dims = GetTestDims(dims);
     const int64_t num_elements = TrtWeightDimsNumElements(trt_dims);
     QCHECK_EQ(num_elements, values.size())
@@ -1800,8 +1792,9 @@ class ParameterizedOpConverterTestBase
         partial_shape = dims;
       }
     }
-    AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_type),
-                            add_input_status);
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(tf_type, &trt_type));
+    AddTestTensorWithTFDims(name, partial_shape, trt_type, add_input_status);
     if (!values.empty()) {
       VLOG(2) << "Adding test tensor: " << name << " "
               << DataTypeString(tf_type);
@@ -2032,7 +2025,7 @@ TEST_F(OpConverterTest, ConvertConst) {
     Reset();
     NodeDef node_def = MakeConstNodeDef<double>("my_const", {});
     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
-                               "Unsupported data type double");
+                               "Unsupported tensorflow data type double");
   }
   {
     Reset();
@@ -2805,8 +2798,9 @@ void TestAddN(OpConverterTest* test) {
     test->Reset();
     DataVec input_data;
     for (const auto name : {"inp1", "inp2", "inp3"}) {
-      test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2,
-                          TfDataTypeToTrt(dtype));
+      nvinfer1::DataType trt_type;
+      TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+      test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, trt_type);
       input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2),
                                                          CType(3), CType(4)})});
     }
@@ -2828,8 +2822,9 @@ void TestAddN(OpConverterTest* test) {
     test->Reset();
     DataVec input_data;
     for (const auto name : {"inp1", "inp2"}) {
-      test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1,
-                          TfDataTypeToTrt(dtype));
+      nvinfer1::DataType trt_type;
+      TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+      test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, trt_type);
       input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2)})});
     }
     test->AddTestWeights("inp3", /*dims=*/{1, 1, 2},
@@ -4252,8 +4247,9 @@ TEST_P(OpConverterTest1, ConvertConv2D) {
     Reset();
     NodeDef node_def = get_conv2d_nodedef();
     // Channel dim unknown, should fail.
-    AddTestTensorWithTFDims("input", {-1, -1, -1, -1},
-                            TfDataTypeToTrt(tf_type_));
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type));
+    AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, trt_type);
     AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
     RunValidationAndConversion(
         node_def, error::INVALID_ARGUMENT,
@@ -5018,8 +5014,9 @@ TEST_F(OpConverterTest, ConvertTopK) {
     {
       // K is a tensor, should fail.
       Reset();
-      AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1,
-                    /*trt_dtype=*/TfDataTypeToTrt(dtype));
+      nvinfer1::DataType trt_type;
+      TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+      AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, trt_type);
       AddTestTensor("weights", {2});
       RunValidationAndConversion(
           node_def, error::UNIMPLEMENTED,
@@ -5590,8 +5587,10 @@ void TestConvertConcat(OpConverterTest* test) {
     NodeDef node_def = get_concat_nodedef(dtype, num_inputs);
     // Create inputs.
     for (int j = 0; j < num_inputs; ++j) {
+      nvinfer1::DataType trt_type;
+      TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
       test->AddTestTensor(StrCat("values_", j), ok_params[i].input_shapes[j], 1,
-                          TfDataTypeToTrt(dtype));
+                          trt_type);
     }
     test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
     test->RunValidationAndConversion(node_def);
@@ -5752,8 +5751,9 @@ void TestConvertSplit(OpConverterTest* test) {
     NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split);
     // Create inputs.
     test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
-    test->AddTestTensor("value", ok_params[i].input_shape, 1,
-                        TfDataTypeToTrt(dtype));
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+    test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
     // Convert.
     test->RunValidationAndConversion(node_def);
 
@@ -5929,8 +5929,9 @@ void TestConvertUnpack(OpConverterTest* test) {
     NodeDef node_def =
         get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis);
     // Create inputs.
-    test->AddTestTensor("value", ok_params[i].input_shape, 1,
-                        TfDataTypeToTrt(dtype));
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+    test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
     // Convert.
     test->RunValidationAndConversion(node_def);
 
@@ -6272,8 +6273,10 @@ void TestConvertArgMinMax(OpConverterTest* test) {
 
     NodeDef node_def = GetArgMinMaxNodeDef<OpType>(dtype, DT_INT32);
     // Create inputs.
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
     test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1,
-                        /*trt_dtype=*/TfDataTypeToTrt(dtype));
+                        /*trt_dtype=*/trt_type);
     test->AddTestWeights<int32>("dimension", {1}, {params[i].axis});
     test->RunValidationAndConversion(node_def);
 
@@ -6374,8 +6377,9 @@ void TestConvertDepthSpaceShuffle(
 
     NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
         dtype, params[i].block_size, params[i].data_format);
-    test->AddTestTensor("input", params[i].input_dims, 1,
-                        TfDataTypeToTrt(dtype));
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+    test->AddTestTensor("input", params[i].input_dims, 1, trt_type);
     test->RunValidationAndConversion(node_def);
 
     TRT_TensorOrWeights output;
@@ -6648,7 +6652,9 @@ void TestConvertClipByValue(OpConverterTest* test) {
     test->Reset();
 
     NodeDef node_def = GetClipByValueNodeDef(dtype);
-    test->AddTestTensor("t", params[i].dims, 1, TfDataTypeToTrt(dtype));
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+    test->AddTestTensor("t", params[i].dims, 1, trt_type);
     test->AddTestWeights<CType>("clip_value_min", {1},
                                 {params[i].clip_value_min});
     test->AddTestWeights<CType>("clip_value_max", {1},
@@ -6848,9 +6854,11 @@ void TestConvertResize(OpConverterTest* test) {
     // Create resize node.
     NodeDef node_def =
         MakeResizeNodeDef<OpType>("my_resize", dtype, params[i].align_corners);
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
     // Create input tensor
     test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
-                        /*trt_dtype=*/TfDataTypeToTrt(dtype));
+                        /*trt_dtype=*/trt_type);
     // Create output size.
     test->AddTestWeights<int32>("size", {2}, params[i].output_resize_dims);
 
@@ -6949,8 +6957,10 @@ void TestConvertPad(OpConverterTest* test) {
     // Create pad node.
     NodeDef node_def = MakePadNodeDef("my_pad", dtype);
     // Create input tensor
+    nvinfer1::DataType trt_type;
+    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
     test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
-                        /*trt_dtype=*/TfDataTypeToTrt(dtype));
+                        /*trt_dtype=*/trt_type);
     // Create output size.
     test->AddTestWeights<int32>("padding", params[i].pad_dims,
                                 {0, 0, 1, 0, 0, 1, 0, 0});
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
index 1fc0d13c993..650aff5836f 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
@@ -198,7 +198,8 @@ Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
       *trt_type = nvinfer1::DataType::kINT32;
       break;
     default:
-      return errors::Internal("Unsupported tensorflow type");
+      return errors::InvalidArgument("Unsupported tensorflow data type ",
+                                     DataTypeString(tf_type));
   }
   return Status::OK();
 }
@@ -215,7 +216,7 @@ Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
       *tf_type = DT_INT32;
       break;
     default:
-      return errors::Internal("Invalid TRT type");
+      return errors::InvalidArgument("Invalid TRT data type");
   }
   return Status::OK();
 }
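
Reclassifying these failures from `Internal` to `InvalidArgument` matters for callers: an unsupported dtype originates in user-provided input rather than a broken converter invariant, and the message now names the offending type, which is exactly what the updated `ConvertConst` expectation checks. A self-contained sketch of the distinction, with absl standing in for TF's `Status`:

    #include <string>
    #include "absl/status/status.h"

    // Caller error: the graph asked for a dtype TensorRT cannot represent.
    absl::Status UnsupportedDtype(const std::string& name) {
      return absl::InvalidArgumentError("Unsupported tensorflow data type " + name);
    }
    // By contrast, Internal stays reserved for "should never happen" states.
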
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index e0a731e502e..2d56209a068 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -103,13 +103,16 @@ class TRTEngineOp : public AsyncOpKernel {
                           TRTEngineCacheResource* cache_res,
                           AsyncHelper* helper);
 
-  // Construct a function handle for executing native funcdef graph
-  // These are the exact same function.
+  // Constructs a function handle for the segment of the TRTEngineOp.
+  StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
+      FunctionLibraryRuntime* lib, const string& device_name,
+      bool allow_soft_placement = false, size_t num_inputs = 0,
+      size_t num_outputs = 0);
 
-  Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                 const string& device_name,
-                                 bool allow_soft_placement = false,
-                                 size_t num_inputs = 0, size_t num_outputs = 0);
+  // Imports the GraphDef for the segment of the TRTEngineOp to
+  // segment_graph_def_.
+  Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
+                               const string& device_name);
 
   // Executes replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -175,12 +178,16 @@ class TRTEngineOp : public AsyncOpKernel {
   // Whether to build TensorRT engines at runtime.
   bool allow_build_at_runtime_;
 
+  // Whether to allow soft placement when the graph is executed with native
+  // TensorFlow.
+  bool allow_soft_placement_;
+
   // Maximum number of cached engines.
   int max_cached_engines_;
 
   int64 workspace_size_;
   mutex engine_mutex_;
-  FunctionLibraryRuntime::Handle func_handle_;
+  FunctionLibraryRuntime::Handle native_execution_func_handle_;
 
   // The finalized calibrator for inference.
   std::unique_ptr<TRTInt8Calibrator> calibrator_;
@@ -260,11 +267,9 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
   return Status::OK();
 }
 
-Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                            const string& device_name,
-                                            bool allow_soft_placement,
-                                            size_t num_inputs,
-                                            size_t num_outputs) {
+StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
+    FunctionLibraryRuntime* lib, const string& device_name,
+    bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
   VLOG(1) << "Constructing function handle";
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
@@ -298,8 +303,20 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
       inst_ops.config_proto.set_allow_soft_placement(true);
     }
   }
-  return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
-                          &func_handle_);
+  FunctionLibraryRuntime::Handle func_handle;
+  Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+                                   inst_ops, &func_handle);
+  if (status.ok()) {
+    return func_handle;
+  }
+  return status;
+}
+
+Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
+                                          const string& device_name) {
+  TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
+                      ConstructFunctionHandle(lib, device_name));
+  return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
 }
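
Returning `StatusOr<Handle>` instead of writing into a member handle lets callers compose with `TF_ASSIGN_OR_RETURN`, as `ImportSegmentGraphDef` above now does, and removes the hidden ordering dependency on the member. A self-contained sketch with `absl::StatusOr` standing in for TF's:

    #include "absl/status/status.h"
    #include "absl/status/statusor.h"

    absl::StatusOr<int> ConstructHandle(bool lib_present) {
      if (!lib_present) return absl::InternalError("function library is null");
      return 42;  // a freshly instantiated handle
    }

    absl::Status ImportGraph(bool lib_present) {
      // TF_ASSIGN_OR_RETURN collapses the next three lines into one.
      absl::StatusOr<int> handle = ConstructHandle(lib_present);
      if (!handle.ok()) return handle.status();
      // ... use *handle to convert the function to a GraphDef ...
      return absl::OkStatus();
    }
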
 
 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
@@ -335,14 +352,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
             << context->device()->name()
             << ", thus setting _allow_build_at_runtime=true";
     allow_build_at_runtime_ = true;
+  } else {
+    OP_REQUIRES_OK(context, status);
   }
-  func_handle_ = kInvalidHandle;
+
+  status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
+  if (status.code() == tensorflow::error::NOT_FOUND) {
+    allow_soft_placement_ = true;
+  } else {
+    OP_REQUIRES_OK(context, status);
+  }
+
+  native_execution_func_handle_ = kInvalidHandle;
   if (!static_engine_) {
-    FunctionLibraryRuntime* lib = context->function_library();
-    OP_REQUIRES_OK(context,
-                   ConstructFunctionHandle(lib, context->device()->name()));
-    OP_REQUIRES_OK(
-        context, FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_));
+    OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
+                                                  context->device()->name()));
   }
   // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
   // backward compatibility reasons. Remove it once all known users switch to
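
Both `_allow_build_at_runtime` and the new `_allow_soft_placement` attribute are read with the same "optional attribute with a default" shape: `NOT_FOUND` means the node predates the attribute and the default applies, while any other status is a genuine failure. A hypothetical helper (not part of the TF API) capturing the pattern:

    template <typename T>
    Status GetAttrOrDefault(OpKernelConstruction* ctx, const char* name,
                            T default_value, T* out) {
      Status s = ctx->GetAttr(name, out);
      if (s.code() == tensorflow::error::NOT_FOUND) {
        *out = default_value;  // attribute absent: fall back silently
        return Status::OK();
      }
      return s;  // OK, or a real error worth surfacing
    }
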
@@ -411,13 +435,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                        AsyncHelper* helper) {
   std::vector<Tensor> inputs;
   std::vector<Tensor>* outputs = new std::vector<Tensor>();
-  if (func_handle_ == kInvalidHandle) {
-    OP_REQUIRES_OK_ASYNC(
-        ctx,
+  if (native_execution_func_handle_ == kInvalidHandle) {
+    StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
         ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
-                                /*allow_soft_placement=*/true,
-                                ctx->num_inputs(), ctx->num_outputs()),
-        *helper);
+                                allow_soft_placement_, ctx->num_inputs(),
+                                ctx->num_outputs());
+    OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
+    native_execution_func_handle_ = status_or_handle.ValueOrDie();
   }
   auto lib = ctx->function_library();
   FunctionLibraryRuntime::Options opts;
@@ -430,7 +454,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   }
   helper->Ref();  // Increment count for calculating native graph
   VLOG(1) << "Executing native segment: " << name();
-  lib->Run(opts, func_handle_, inputs, outputs,
+  lib->Run(opts, native_execution_func_handle_, inputs, outputs,
            [this, ctx, outputs, helper](const Status& s) {
              core::ScopedUnref sc(helper);
              OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
@@ -854,12 +878,8 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
         return std::pair<EngineContext*, int>(&empty_context, 0);
       }
       if (segment_graph_def_.node().empty()) {
-        FunctionLibraryRuntime* lib = ctx->function_library();
-        auto status = ConstructFunctionHandle(lib, ctx->device()->name());
-        if (status.ok()) {
-          status =
-              FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_);
-        }
+        Status status = ImportSegmentGraphDef(ctx->function_library(),
+                                              ctx->device()->name());
         if (!status.ok()) {
           LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
                                             << name() << " failed. "
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
index 71193dc24cf..e7beb589f7b 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
@@ -91,6 +91,13 @@ class TRTEngineOpTestBase : public OpsTestBase {
     OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
     NameAttrList function;
     function.set_name(StrCat(op_name, "_native_segment"));
+    // We disable allow_soft_placement when executing the native segment of the
+    // TRTEngineOp for the following reasons:
+    //    - OpsTestBase only allows one device in the device manager.
+    //    - We need to define the GPU device to test the TRTEngineOp.
+    //    - When allow_soft_placement is true, the TensorFlow runtime produces
+    //      an error if a CPU device is not defined
+    //      (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
     TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
                      .Input(FakeInput(1, dtype))
                      .Attr("input_shapes", {shape})
@@ -105,6 +112,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
                      .Attr("use_calibration", false)
                      .Attr("_use_implicit_batch", use_implicit_batch)
                      .Attr("_allow_build_at_runtime", allow_build_at_runtime)
+                     .Attr("_allow_soft_placement", false)
                      .Attr("OutT", {dtype})
                      .Finalize(OpsTestBase::node_def()));
     TF_ASSERT_OK(InitOpWithFunctionLibrary());
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
index 350c198ee70..ba3dea924de 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
@@ -52,7 +52,11 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
 
   bool IsEnabled(const ConfigProto& config_proto,
                  const Graph& graph) const override {
-    return IsMlirBridgePassEnabled(graph, config_proto);
+    // Do not run the bridge when it is merely enabled by graph analysis;
+    // only run it when the user has enabled it explicitly.
+    MlirBridgeRolloutPolicy policy =
+        GetMlirBridgeRolloutPolicy(graph, config_proto);
+    return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
   }
 
   // This should be used as a thin mapper around mlir::ModulePass::runOnModule
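
The behavioural change here is that graph-analysis-driven rollout no longer turns on the V1-compat bridge; only an explicit user opt-in does. A self-contained sketch of the gate, where every enumerator except `kEnabledByUser` (which appears in the code above) is an assumed name:

    enum class MlirBridgeRolloutPolicy {
      kDisabledByUser,
      kEnabledByUser,
      kEnabledAfterGraphAnalysis,  // previously this also ran the bridge
    };

    bool ShouldRunV1CompatBridge(MlirBridgeRolloutPolicy policy) {
      return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
    }
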
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 9883a3f47c5..d44f9991936 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -361,6 +361,7 @@ tf_cc_test(
         ":util",
         ":xla_data_proto_cc",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 28f76829f14..171afa42351 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -234,6 +234,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/service:shape_inference",
         "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 92d222f32b2..c797f58274c 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -59,7 +59,6 @@ cc_library(
     srcs = ["comparators.cc"],
     hdrs = [
         "comparators.h",
-        "//tensorflow/compiler/xla:literal_util",
     ],
     deps = [
         ":constants",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 1043528e752..33506535ddf 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -46,6 +47,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -1266,7 +1268,8 @@ StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
 }
 
 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
-                      const PrecisionConfig* precision_config) {
+                      const PrecisionConfig* precision_config,
+                      absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
 
@@ -1278,15 +1281,17 @@ XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
   });
 }
 
-XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs,
-                             const DotDimensionNumbers& dimension_numbers,
-                             const PrecisionConfig* precision_config) {
+XlaOp XlaBuilder::DotGeneral(
+    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
-    TF_ASSIGN_OR_RETURN(Shape shape,
-                        ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape,
-                                                        dimension_numbers));
+    TF_ASSIGN_OR_RETURN(
+        Shape shape,
+        ShapeInference::InferDotOpShape(
+            *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
     return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
                               precision_config);
   });
@@ -1353,28 +1358,33 @@ Status XlaBuilder::VerifyConvolution(
 XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64> window_strides, Padding padding,
                        int64 feature_group_count, int64 batch_group_count,
-                       const PrecisionConfig* precision_config) {
+                       const PrecisionConfig* precision_config,
+                       absl::optional<PrimitiveType> preferred_element_type) {
   return ConvWithGeneralDimensions(
       lhs, rhs, window_strides, padding,
       CreateDefaultConvDimensionNumbers(window_strides.size()),
-      feature_group_count, batch_group_count, precision_config);
+      feature_group_count, batch_group_count, precision_config,
+      preferred_element_type);
 }
 
 XlaOp XlaBuilder::ConvWithGeneralPadding(
     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ConvGeneral(lhs, rhs, window_strides, padding,
                      CreateDefaultConvDimensionNumbers(window_strides.size()),
-                     feature_group_count, batch_group_count, precision_config);
+                     feature_group_count, batch_group_count, precision_config,
+                     preferred_element_type);
 }
 
 XlaOp XlaBuilder::ConvWithGeneralDimensions(
     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@@ -1402,7 +1412,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
                        MakePadding(base_area_dimensions, window_dimensions,
                                    window_strides, padding),
                        dimension_numbers, feature_group_count,
-                       batch_group_count, precision_config);
+                       batch_group_count, precision_config,
+                       preferred_element_type);
   });
 }
 
@@ -1411,10 +1422,12 @@ XlaOp XlaBuilder::ConvGeneral(
     absl::Span<const std::pair<int64, int64>> padding,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                             dimension_numbers, feature_group_count,
-                            batch_group_count, precision_config);
+                            batch_group_count, precision_config,
+                            preferred_element_type);
 }
 
 XlaOp XlaBuilder::ConvGeneralDilated(
@@ -1423,7 +1436,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@@ -1442,10 +1456,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
                         ShapeInference::InferWindowFromDimensions(
                             window_dimensions, window_strides, padding,
                             lhs_dilation, rhs_dilation));
-    TF_ASSIGN_OR_RETURN(Shape shape,
-                        ShapeInference::InferConvolveShape(
-                            *lhs_shape, *rhs_shape, feature_group_count,
-                            batch_group_count, window, dimension_numbers));
+    TF_ASSIGN_OR_RETURN(
+        Shape shape,
+        ShapeInference::InferConvolveShape(
+            *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
+            window, dimension_numbers, preferred_element_type));
     return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
                                       padding, lhs_dilation, rhs_dilation,
                                       dimension_numbers, feature_group_count,
@@ -1459,7 +1474,8 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
   TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
   std::vector<int64> window_dimensions(
@@ -1472,10 +1488,11 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
   TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                          window_dimensions, window_strides,
                                          padding, lhs_dilation, rhs_dilation));
-  TF_ASSIGN_OR_RETURN(Shape shape,
-                      ShapeInference::InferConvolveShape(
-                          *lhs_shape, *rhs_shape, feature_group_count,
-                          batch_group_count, window, dimension_numbers));
+  TF_ASSIGN_OR_RETURN(
+      Shape shape,
+      ShapeInference::InferConvolveShape(
+          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
+          window, dimension_numbers, preferred_element_type));
 
   HloInstructionProto instr;
   *instr.mutable_shape() = shape.ToProto();
@@ -1499,14 +1516,15 @@ XlaOp XlaBuilder::DynamicConvInputGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
         HloInstructionProto instr,
-        DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
-                               rhs_dilation, dimension_numbers,
-                               feature_group_count, batch_group_count,
-                               precision_config, padding_type));
+        DynamicConvInstruction(
+            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
+            dimension_numbers, feature_group_count, batch_group_count,
+            precision_config, padding_type, preferred_element_type));
 
     instr.set_custom_call_target("DynamicConvolutionInputGrad");
 
@@ -1521,14 +1539,16 @@ XlaOp XlaBuilder::DynamicConvKernelGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
         HloInstructionProto instr,
         DynamicConvInstruction(activations, gradients, window_strides, padding,
                                lhs_dilation, rhs_dilation, dimension_numbers,
                                feature_group_count, batch_group_count,
-                               precision_config, padding_type));
+                               precision_config, padding_type,
+                               preferred_element_type));
 
     instr.set_custom_call_target("DynamicConvolutionKernelGrad");
     // The gradient of kernel has kernel shape and shouldn't have any dynamic
@@ -1545,14 +1565,15 @@ XlaOp XlaBuilder::DynamicConvForward(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
         HloInstructionProto instr,
-        DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
-                               rhs_dilation, dimension_numbers,
-                               feature_group_count, batch_group_count,
-                               precision_config, padding_type));
+        DynamicConvInstruction(
+            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
+            dimension_numbers, feature_group_count, batch_group_count,
+            precision_config, padding_type, preferred_element_type));
     instr.set_custom_call_target("DynamicConvolutionForward");
 
     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
@@ -3331,6 +3352,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     }
 
     if (!need_rewrite) {
+      if (opcode == HloOpcode::kCompare) {
+        CHECK(!instr_proto->comparison_type().empty());
+        new_instr->set_comparison_type(
+            ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
+      }
       *new_instr->mutable_name() =
           GetFullName(instr_proto->opcode(), kNameSeparator, id);
       return Status::OK();
@@ -3990,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
 }
 
+static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
+                               absl::Span<const int64> broadcast_dimensions,
+                               ComparisonDirection comparison_direction) {
+  auto b = lhs.builder();
+  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
+    auto operand_element_type = operand_shape.element_type();
+    auto compare_type =
+        primitive_util::IsFloatingPointType(operand_element_type)
+            ? Comparison::Type::kFloatTotalOrder
+            : Comparison::DefaultComparisonType(operand_element_type);
+    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
+                   compare_type);
+  });
+}
+
 XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kEq);
 }
 
 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
@@ -4004,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
 
 XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kNe);
 }
 
 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
@@ -4016,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
 
 XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kGe);
 }
 
 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
@@ -4028,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
 
 XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kGt);
 }
 
 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
@@ -4040,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs,
 
 XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kLe);
 }
+
 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
@@ -4051,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
 
 XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
-                 Comparison::Type::kFloatTotalOrder);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kLt);
 }
 
 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
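The CompareTotalOrder helper above falls back to the default comparison type
for non-floating-point operands, so the *TotalOrder entry points are now safe
to call on integer inputs as well. A host-side sketch of the float total order
they request, assuming kFloatTotalOrder matches the usual IEEE-754 totalOrder
bit trick (an assumption, not stated by this diff):

    #include <cstdint>
    #include <cstring>

    // Monotone remap of float bits: unsigned comparison of keys matches
    // -NaN < -Inf < ... < -0 < +0 < ... < +Inf < +NaN.
    uint32_t TotalOrderKey(float x) {
      uint32_t u;
      std::memcpy(&u, &x, sizeof(u));
      return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
    }

    bool TotalOrderLt(float a, float b) {
      return TotalOrderKey(a) < TotalOrderKey(b);
    }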
@@ -4074,44 +4112,49 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
 }
 
 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
-          const PrecisionConfig* precision_config) {
-  return lhs.builder()->Dot(lhs, rhs, precision_config);
+          const PrecisionConfig* precision_config,
+          absl::optional<PrimitiveType> preferred_element_type) {
+  return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
 }
 
 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
                  const DotDimensionNumbers& dimension_numbers,
-                 const PrecisionConfig* precision_config) {
+                 const PrecisionConfig* precision_config,
+                 absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
-                                   precision_config);
+                                   precision_config, preferred_element_type);
 }
 
 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64> window_strides, Padding padding,
            int64 feature_group_count, int64 batch_group_count,
-           const PrecisionConfig* precision_config) {
+           const PrecisionConfig* precision_config,
+           absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
                              feature_group_count, batch_group_count,
-                             precision_config);
+                             precision_config, preferred_element_type);
 }
 
-XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs,
-                             absl::Span<const int64> window_strides,
-                             absl::Span<const std::pair<int64, int64>> padding,
-                             int64 feature_group_count, int64 batch_group_count,
-                             const PrecisionConfig* precision_config) {
+XlaOp ConvWithGeneralPadding(
+    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->ConvWithGeneralPadding(
       lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
-      precision_config);
+      precision_config, preferred_element_type);
 }
 
 XlaOp ConvWithGeneralDimensions(
     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->ConvWithGeneralDimensions(
       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
-      batch_group_count, precision_config);
+      batch_group_count, precision_config, preferred_element_type);
 }
 
 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
@@ -4119,10 +4162,11 @@ XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const std::pair<int64, int64>> padding,
                   const ConvolutionDimensionNumbers& dimension_numbers,
                   int64 feature_group_count, int64 batch_group_count,
-                  const PrecisionConfig* precision_config) {
-  return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
-                                    dimension_numbers, feature_group_count,
-                                    batch_group_count, precision_config);
+                  const PrecisionConfig* precision_config,
+                  absl::optional<PrimitiveType> preferred_element_type) {
+  return lhs.builder()->ConvGeneral(
+      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
+      batch_group_count, precision_config, preferred_element_type);
 }
 
 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
@@ -4132,26 +4176,27 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
                          absl::Span<const int64> rhs_dilation,
                          const ConvolutionDimensionNumbers& dimension_numbers,
                          int64 feature_group_count, int64 batch_group_count,
-                         const PrecisionConfig* precision_config) {
+                         const PrecisionConfig* precision_config,
+                         absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->ConvGeneralDilated(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
       dimension_numbers, feature_group_count, batch_group_count,
-      precision_config);
+      precision_config, preferred_element_type);
 }
 
-XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           absl::Span<const int64> lhs_dilation,
-                           absl::Span<const int64> rhs_dilation,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count, int64 batch_group_count,
-                           const PrecisionConfig* precision_config,
-                           PaddingType padding_type) {
+XlaOp DynamicConvInputGrad(
+    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
+    absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->DynamicConvInputGrad(
       input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
       rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
-      precision_config, padding_type);
+      precision_config, padding_type, preferred_element_type);
 }
 
 XlaOp DynamicConvKernelGrad(
@@ -4160,11 +4205,12 @@ XlaOp DynamicConvKernelGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return activations.builder()->DynamicConvKernelGrad(
       activations, gradients, window_strides, padding, lhs_dilation,
       rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
-      precision_config, padding_type);
+      precision_config, padding_type, preferred_element_type);
 }
 
 XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
@@ -4175,11 +4221,12 @@ XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
                          const ConvolutionDimensionNumbers& dimension_numbers,
                          int64 feature_group_count, int64 batch_group_count,
                          const PrecisionConfig* precision_config,
-                         PaddingType padding_type) {
+                         PaddingType padding_type,
+                         absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->DynamicConvForward(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
       dimension_numbers, feature_group_count, batch_group_count,
-      precision_config, padding_type);
+      precision_config, padding_type, preferred_element_type);
 }
 
 XlaOp Fft(const XlaOp operand, FftType fft_type,
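Taken together, these free functions only forward the new optional to the
builder; passing absl::nullopt (the default in the declarations below)
preserves the old shape-inference result. A usage sketch that mirrors the
DotWithPreferredElementType test added later in this diff:

    xla::XlaBuilder b("preferred_element_type_example");
    auto lhs = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::U8, {2, 3}), "lhs");
    auto rhs = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::U16, {3, 2}), "rhs");
    xla::DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);
    dnums.add_rhs_contracting_dimensions(0);
    // With no preferred type the result element type follows the operands;
    // requesting U32 makes shape inference produce a U32 {2, 2} result.
    xla::DotGeneral(lhs, rhs, dnums, /*precision_config=*/nullptr,
                    /*preferred_element_type=*/xla::U32);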
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index e54ed12c655..9e6a46f77b7 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -521,56 +521,63 @@ class XlaBuilder {
                                                   XlaOp tuple_data,
                                                   int64 index);
 
-  XlaOp Dot(XlaOp lhs, XlaOp rhs,
-            const PrecisionConfig* precision_config = nullptr);
+  XlaOp Dot(
+      XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-  XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
-                   const DotDimensionNumbers& dimension_numbers,
-                   const PrecisionConfig* precision_config = nullptr);
+  XlaOp DotGeneral(
+      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-  XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
-             Padding padding, int64 feature_group_count = 1,
-             int64 batch_group_count = 1,
-             const PrecisionConfig* precision_config = nullptr);
+  XlaOp Conv(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      Padding padding, int64 feature_group_count = 1,
+      int64 batch_group_count = 1,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
   XlaOp ConvWithGeneralPadding(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       int64 feature_group_count = 1, int64 batch_group_count = 1,
-      const PrecisionConfig* precision_config = nullptr);
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
   XlaOp ConvWithGeneralDimensions(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count = 1, int64 batch_group_count = 1,
-      const PrecisionConfig* precision_config = nullptr);
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-  XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
-                    absl::Span<const int64> window_strides,
-                    absl::Span<const std::pair<int64, int64>> padding,
-                    const ConvolutionDimensionNumbers& dimension_numbers,
-                    int64 feature_group_count = 1, int64 batch_group_count = 1,
-                    const PrecisionConfig* precision_config = nullptr);
+  XlaOp ConvGeneral(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count = 1, int64 batch_group_count = 1,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-  XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           absl::Span<const int64> lhs_dilation,
-                           absl::Span<const int64> rhs_dilation,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count = 1,
-                           int64 batch_group_count = 1,
-                           const PrecisionConfig* precision_config = nullptr);
+  XlaOp ConvGeneralDilated(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      absl::Span<const int64> lhs_dilation,
+      absl::Span<const int64> rhs_dilation,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count = 1, int64 batch_group_count = 1,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-  XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           absl::Span<const int64> lhs_dilation,
-                           absl::Span<const int64> rhs_dilation,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count, int64 batch_group_count,
-                           const PrecisionConfig* precision_config,
-                           PaddingType padding_type);
+  XlaOp DynamicConvForward(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      absl::Span<const int64> lhs_dilation,
+      absl::Span<const int64> rhs_dilation,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count, int64 batch_group_count,
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
   XlaOp DynamicConvInputGrad(
       XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
@@ -580,7 +587,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
   XlaOp DynamicConvKernelGrad(
       XlaOp activations, XlaOp gradients,
@@ -590,7 +598,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
   StatusOr<HloInstructionProto> DynamicConvInstruction(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@@ -599,7 +608,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
   virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
       const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
@@ -1098,10 +1108,12 @@ class XlaBuilder {
                        ComparisonDirection direction,
                        Comparison::Type compare_type);
   friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
-                   const PrecisionConfig* precision_config);
+                   const PrecisionConfig* precision_config,
+                   absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                           const DotDimensionNumbers& dimension_number,
-                          const PrecisionConfig* precision_config);
+                          const PrecisionConfig* precision_config,
+                          absl::optional<PrimitiveType> preferred_element_type);
   virtual StatusOr<XlaOp> DotGeneralInternal(
       const Shape& shape, XlaOp lhs, XlaOp rhs,
       const DotDimensionNumbers& dimension_number,
@@ -1109,23 +1121,27 @@ class XlaBuilder {
   friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                     absl::Span<const int64> window_strides, Padding padding,
                     int64 feature_group_count, int64 batch_group_count,
-                    const PrecisionConfig* precision_config);
+                    const PrecisionConfig* precision_config,
+                    absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp ConvWithGeneralPadding(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp ConvWithGeneralDimensions(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_confige);
-  friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count, int64 batch_group_count,
-                           const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
+  friend XlaOp ConvGeneral(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count, int64 batch_group_count,
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DynamicConvForward(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
@@ -1133,7 +1149,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DynamicConvKernelGrad(
       XlaOp activations, XlaOp gradients,
       absl::Span<const int64> window_strides,
@@ -1142,7 +1159,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DynamicConvInputGrad(
       XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
       absl::Span<const int64> window_strides,
@@ -1151,7 +1169,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type);
 
   friend XlaOp ConvKernelGrad(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@@ -1160,7 +1179,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
 
   friend XlaOp ConvGeneralDilated(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@@ -1169,7 +1189,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp Fft(XlaOp operand, FftType fft_type,
                    absl::Span<const int64> fft_length);
   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
@@ -1813,28 +1834,31 @@ XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
 
 // Enqueues a dot instruction onto the computation.
 XlaOp Dot(XlaOp lhs, XlaOp rhs,
-          const PrecisionConfig* precision_config = nullptr);
+          const PrecisionConfig* precision_config = nullptr,
+          absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues a general dot instruction onto the computation.
-XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
-                 const DotDimensionNumbers& dimension_numbers,
-                 const PrecisionConfig* precision_config = nullptr);
+XlaOp DotGeneral(
+    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues a convolution instruction onto the computation, which uses the
 // default convolution dimension numbers.
-XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
-           Padding padding, int64 feature_group_count = 1,
-           int64 batch_group_count = 1,
-           const PrecisionConfig* precision_config = nullptr);
+XlaOp Conv(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs,
-                             absl::Span<const int64> window_strides,
-                             absl::Span<const std::pair<int64, int64>> padding,
-                             int64 feature_group_count = 1,
-                             int64 batch_group_count = 1,
-                             const PrecisionConfig* precision_config = nullptr);
+XlaOp ConvWithGeneralPadding(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided dimension numbers configuration.
@@ -1842,47 +1866,48 @@ XlaOp ConvWithGeneralDimensions(
     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count = 1, int64 batch_group_count = 1,
-    const PrecisionConfig* precision_config = nullptr);
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration as well as the dimension numbers.
-XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
-                  absl::Span<const std::pair<int64, int64>> padding,
-                  const ConvolutionDimensionNumbers& dimension_numbers,
-                  int64 feature_group_count = 1, int64 batch_group_count = 1,
-                  const PrecisionConfig* precision_config = nullptr);
+XlaOp ConvGeneral(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
-                         absl::Span<const int64> window_strides,
-                         absl::Span<const std::pair<int64, int64>> padding,
-                         absl::Span<const int64> lhs_dilation,
-                         absl::Span<const int64> rhs_dilation,
-                         const ConvolutionDimensionNumbers& dimension_numbers,
-                         int64 feature_group_count = 1,
-                         int64 batch_group_count = 1,
-                         const PrecisionConfig* precision_config = nullptr);
+XlaOp ConvGeneralDilated(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
-                         absl::Span<const int64> window_strides,
-                         absl::Span<const std::pair<int64, int64>> padding,
-                         absl::Span<const int64> lhs_dilation,
-                         absl::Span<const int64> rhs_dilation,
-                         const ConvolutionDimensionNumbers& dimension_numbers,
-                         int64 feature_group_count, int64 batch_group_count,
-                         const PrecisionConfig* precision_config,
-                         PaddingType padding_type);
+XlaOp DynamicConvForward(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
-XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           absl::Span<const int64> lhs_dilation,
-                           absl::Span<const int64> rhs_dilation,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count, int64 batch_group_count,
-                           const PrecisionConfig* precision_config,
-                           PaddingType padding_type);
+XlaOp DynamicConvInputGrad(
+    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
+    absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 XlaOp DynamicConvKernelGrad(
     XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
@@ -1890,7 +1915,8 @@ XlaOp DynamicConvKernelGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type);
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
 
 // Enqueues an FFT instruction onto the computation, of the given type and
 // with the given FFT length.
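Since each new parameter above is defaulted to absl::nullopt, existing call
sites compile unchanged; the knob is purely opt-in. Both of the following
remain valid after this change (sketch; lhs and rhs as in the dot example
earlier):

    xla::XlaOp d_old = xla::Dot(lhs, rhs);  // unchanged legacy call
    xla::XlaOp d_new = xla::Dot(lhs, rhs, /*precision_config=*/nullptr,
                                /*preferred_element_type=*/xla::F32);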
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 4fc6c848a38..cd21c6dc414 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
                  /*broadcast_dimensions=*/{0, 1, 2});
   auto statusor = BuildHloModule(&b);
   ASSERT_FALSE(statusor.ok());
-  EXPECT_THAT(statusor.status().error_message(),
-              HasSubstr("shape's dimensions must not be < 0"));
+  EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape"));
 }
 
 TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
@@ -1066,6 +1065,56 @@ TEST_F(XlaBuilderTest, DynamicTranspose) {
       << result_shape;
 }
 
+TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
+  XlaBuilder b(TestName());
+  Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3});
+  Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2});
+  auto p0 = Parameter(&b, 0, p0_shape, "p0");
+  auto p1 = Parameter(&b, 1, p1_shape, "p1");
+
+  DotDimensionNumbers dnums;
+  dnums.add_lhs_contracting_dimensions(1);
+  dnums.add_rhs_contracting_dimensions(0);
+  DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
+             /*preferred_element_type=*/U32);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  ASSERT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape));
+}
+
+TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
+  XlaBuilder b(TestName());
+  Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
+  Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8});
+  auto p0 = Parameter(&b, 0, p0_shape, "p0");
+  auto p1 = Parameter(&b, 1, p1_shape, "p1");
+
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_input_batch_dimension(0);
+  dnums.set_output_batch_dimension(0);
+  dnums.add_input_spatial_dimensions(1);
+  dnums.add_output_spatial_dimensions(1);
+  dnums.add_input_spatial_dimensions(2);
+  dnums.add_output_spatial_dimensions(2);
+  dnums.set_input_feature_dimension(3);
+  dnums.set_output_feature_dimension(3);
+  dnums.add_kernel_spatial_dimensions(0);
+  dnums.add_kernel_spatial_dimensions(1);
+  dnums.set_kernel_input_feature_dimension(2);
+  dnums.set_kernel_output_feature_dimension(3);
+  ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums,
+                            /*feature_group_count=*/1, /*batch_group_count=*/1,
+                            /*precision_config=*/nullptr,
+                            /*preferred_element_type=*/S32);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  ASSERT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape));
+}
+
 TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
   XlaBuilder b(TestName());
   AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD
index 5a14e229fa3..8b6eee1f75e 100644
--- a/tensorflow/compiler/xla/pjrt/BUILD
+++ b/tensorflow/compiler/xla/pjrt/BUILD
@@ -243,6 +243,7 @@ cc_library(
     hdrs = ["nvidia_gpu_device.h"],
     deps = [
         ":pjrt_client",
+        "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
index 057f8d498a7..570c78c2d70 100644
--- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
@@ -15,12 +15,15 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
 
+#include "absl/container/flat_hash_map.h"
+
 #ifdef NCCL_ENABLED
 #include "third_party/nccl/nccl.h"
 #endif  // NCCL_ENABLED
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device/device_host_allocator.h"
 #include "tensorflow/core/common_runtime/device/device_id.h"
@@ -166,56 +169,57 @@ std::unique_ptr<tensorflow::BFCAllocator> GetGpuHostAllocator(
 
 // A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings.
 // In a distributed setup the table of NCCL IDs is kept on the master node
-// (node 0). Currently node 0 is the only node that generates ncclUniqueIds;
-// see the TODO below.
+// (node 0). The node that owns the clique's first participating device
+// creates the unique id.
 class NcclIdStore {
  public:
-  NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client)
-      : node_id_(node_id), client_(std::move(client)) {}
+  NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client,
+              absl::flat_hash_map<GlobalDeviceId, int> device_to_node)
+      : node_id_(node_id),
+        client_(std::move(client)),
+        device_to_node_(std::move(device_to_node)) {}
 
   StatusOr<std::string> GetNcclUniqueId(const NcclCliqueKey& key);
 
  private:
   const int node_id_;
   const std::shared_ptr<DistributedRuntimeClient> client_;
+  const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
 
   absl::Mutex mu_;
-  absl::flat_hash_map<std::string, std::string> cache_ ABSL_GUARDED_BY(mu_);
+  absl::flat_hash_map<NcclCliqueKey, std::string> cache_ ABSL_GUARDED_BY(mu_);
 };
 
 StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
-  std::string key_string = GlobalDeviceIdsToString(key.devices());
+  // The caller must ensure that threads calling this method concurrently use
+  // unique keys; otherwise the global key-value store may hold the wrong value.
   {
     absl::MutexLock lock(&mu_);
-    auto it = cache_.find(key_string);
+    auto it = cache_.find(key);
     if (it != cache_.end()) {
       return it->second;
     }
   }
-  auto result = [&]() -> StatusOr<std::string> {
-    // TODO(phawkins): this will deadlock if node 0 is not involved in the
-    // computation. Add support for computations that only use a subset of
-    // replicas.
-    if (node_id_ == 0) {
+  std::string id_string;
+  int primary_node_id = device_to_node_.at(key.devices()[0]);
+  if (node_id_ == primary_node_id) {
 #ifdef NCCL_ENABLED
-      ncclUniqueId id;
-      ncclResult_t r = ncclGetUniqueId(&id);
-      TF_RET_CHECK(r == ncclSuccess);
-      std::string value(id.internal, NCCL_UNIQUE_ID_BYTES);
-      TF_RETURN_IF_ERROR(client_->KeyValueSet(key_string, value));
-      return value;
+    ncclUniqueId id;
+    ncclResult_t r = ncclGetUniqueId(&id);
+    TF_RET_CHECK(r == ncclSuccess);
+    id_string = std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
+    TF_RETURN_IF_ERROR(client_->KeyValueSet(key.ToString(), id_string));
 #else
-      return FailedPrecondition("NCCL support was not built into XLA binary.");
+    return FailedPrecondition("NCCL support was not built into XLA binary.");
 #endif
-    } else {
-      return client_->BlockingKeyValueGet(key_string, absl::Minutes(5));
-    }
-  }();
-  if (!result.ok()) {
-    return result.status();
+  } else {
+    TF_ASSIGN_OR_RETURN(id_string, client_->BlockingKeyValueGet(
+                                       key.ToString(), absl::Minutes(5)));
   }
   absl::MutexLock lock(&mu_);
-  return cache_.emplace(key_string, result.ValueOrDie()).first->second;
+  auto result = cache_.emplace(key, std::move(id_string));
+  TF_RET_CHECK(result.second) << "Unique ID already in cache.";
+  return result.first->second;
 }
 
 std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
@@ -258,8 +262,11 @@ Status BuildDistributedDevices(
       distributed_client->EnumerateDevices(local_topology, &global_topology));
 
   std::vector<GlobalDeviceId> gpu_device_ids(local_device_states.size());
+  absl::flat_hash_map<GlobalDeviceId, int> device_to_node;
   for (const LocalTopologyProto& node : global_topology.nodes()) {
     for (const DeviceProto& device_proto : node.devices()) {
+      GlobalDeviceId global_device_id(device_proto.global_device_id());
+      device_to_node[global_device_id] = node.node_id();
       std::unique_ptr<LocalDeviceState> local_device;
       if (node.node_id() == node_id) {
         TF_RET_CHECK(device_proto.local_device_ordinal() >= 0 &&
@@ -269,8 +276,7 @@ Status BuildDistributedDevices(
                      nullptr);
         local_device =
             std::move(local_device_states[device_proto.local_device_ordinal()]);
-        gpu_device_ids[device_proto.local_device_ordinal()] =
-            GlobalDeviceId(device_proto.global_device_id());
+        gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id;
       }
       auto device = absl::make_unique<GpuDevice>(
           device_proto.global_device_id(), std::move(local_device),
@@ -283,8 +289,8 @@ Status BuildDistributedDevices(
   }
   gpu_executable_run_options->set_gpu_global_device_ids(
       std::move(gpu_device_ids));
-  auto nccl_id_store =
-      std::make_shared<NcclIdStore>(node_id, distributed_client);
+  auto nccl_id_store = std::make_shared<NcclIdStore>(
+      node_id, distributed_client, device_to_node);
   gpu_executable_run_options->set_nccl_unique_id_callback(
       [nccl_id_store](const NcclCliqueKey& key) {
         return nccl_id_store->GetNcclUniqueId(key);
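The net effect of the NcclIdStore rework: instead of node 0 always generating
ids (and deadlocking when it is not a participant), the node owning the
clique's first device generates and publishes the id, and every other
participant blocks on the key-value store. A compressed sketch of that
exchange; KeyValueService and GenerateNcclUniqueId are hypothetical stand-ins
for the DistributedRuntimeClient calls and the ncclGetUniqueId wrapper:

    #include <string>
    #include "absl/time/time.h"

    // Hypothetical stand-in for the distributed key-value interface.
    struct KeyValueService {
      void Set(const std::string& key, const std::string& value);
      std::string BlockingGet(const std::string& key, absl::Duration timeout);
    };

    std::string GenerateNcclUniqueId();  // assumed wrapper over ncclGetUniqueId

    std::string ExchangeNcclId(int my_node, int primary_node,
                               const std::string& clique_key,
                               KeyValueService& kv) {
      if (my_node == primary_node) {
        std::string id = GenerateNcclUniqueId();
        kv.Set(clique_key, id);  // publish for the other participants
        return id;
      }
      // Non-primary participants block until the primary publishes the id.
      return kv.BlockingGet(clique_key, absl::Minutes(5));
    }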
diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc
index 5dcfc3b0dcc..5f96c494c25 100644
--- a/tensorflow/compiler/xla/python/bfloat16.cc
+++ b/tensorflow/compiler/xla/python/bfloat16.cc
@@ -597,19 +597,25 @@ struct TypeDescriptor<uint16> {
   static int Dtype() { return NPY_UINT16; }
 };
 
+// We register "int", "long", and "long long" types for portability: NumPy
+// treats all three as distinct dtypes even though "long" has the same width
+// as "long long" on LP64 Linux and the same width as "int" on LLP64 Windows.
 template <>
-struct TypeDescriptor<uint32> {
-  typedef uint32 T;
-  static int Dtype() { return NPY_UINT32; }
+struct TypeDescriptor<unsigned int> {
+  typedef unsigned int T;
+  static int Dtype() { return NPY_UINT; }
 };
 
-template <typename Uint64Type>
-struct TypeDescriptor<
-    Uint64Type, typename std::enable_if<std::is_integral<Uint64Type>::value &&
-                                        !std::is_signed<Uint64Type>::value &&
-                                        sizeof(Uint64Type) == 8>::type> {
-  typedef Uint64Type T;
-  static int Dtype() { return NPY_UINT64; }
+template <>
+struct TypeDescriptor<unsigned long> {  // NOLINT
+  typedef unsigned long T;              // NOLINT
+  static int Dtype() { return NPY_ULONG; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long long> {  // NOLINT
+  typedef unsigned long long T;              // NOLINT
+  static int Dtype() { return NPY_ULONGLONG; }
 };
 
 template <>
@@ -625,18 +631,21 @@ struct TypeDescriptor<int16> {
 };
 
 template <>
-struct TypeDescriptor<int32> {
-  typedef int32 T;
-  static int Dtype() { return NPY_INT32; }
+struct TypeDescriptor<int> {
+  typedef int T;
+  static int Dtype() { return NPY_INT; }
 };
 
-template <typename Int64Type>
-struct TypeDescriptor<
-    Int64Type, typename std::enable_if<std::is_integral<Int64Type>::value &&
-                                       std::is_signed<Int64Type>::value &&
-                                       sizeof(Int64Type) == 8>::type> {
-  typedef Int64Type T;
-  static int Dtype() { return NPY_INT64; }
+template <>
+struct TypeDescriptor<long> {  // NOLINT
+  typedef long T;              // NOLINT
+  static int Dtype() { return NPY_LONG; }
+};
+
+template <>
+struct TypeDescriptor<long long> {  // NOLINT
+  typedef long long T;              // NOLINT
+  static int Dtype() { return NPY_LONGLONG; }
 };
 
 template <>
@@ -1354,7 +1363,15 @@ bool Initialize() {
   if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<uint32>(NPY_UINT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG,  // NOLINT
+                                           /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long long>(  // NOLINT
+          NPY_ULONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
@@ -1366,14 +1383,15 @@ bool Initialize() {
   if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<long>(NPY_LONG,  // NOLINT
+                                  /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<npy_longlong>(NPY_LONGLONG,
-                                          /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<long long>(  // NOLINT
+          NPY_LONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   // Following the numpy convention. imag part is dropped when converting to
diff --git a/tensorflow/compiler/xla/python/bfloat16_test.py b/tensorflow/compiler/xla/python/bfloat16_test.py
index 9aaa955d546..4c7321a5b7f 100644
--- a/tensorflow/compiler/xla/python/bfloat16_test.py
+++ b/tensorflow/compiler/xla/python/bfloat16_test.py
@@ -293,7 +293,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
     for dtype in [
         np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
         np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
-        np.uint64
+        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
     ]:
       x = np.array([[1, 2, 3]], dtype=dtype)
       y = x.astype(bfloat16)
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index af02c1ef0d4..27bbde09f09 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -595,7 +595,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
 
   static const auto* xla_module =
       new py::module(py::module::import("jax.interpreters.xla"));
-  const auto& device_array = xla_module->attr("DeviceArray");
+  const auto& device_array = xla_module->attr("_DeviceArray");
 
   static const auto* numpy_module = new py::module(py::module::import("numpy"));
   const auto& np_array = numpy_module->attr("array");
@@ -613,7 +613,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
     for (py::handle arg : arguments.flat_dynamic_args) {
      // We specifically only deal with DeviceArray (not ShardedDeviceArray).
       // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
-      if (arg.get_type().is(device_array)) {
+      if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
         xla::PyBuffer* buffer;
         if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
           continue;
@@ -653,7 +653,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
   xla::PjRtClient* pjrt_client = data_device->client();
 
   for (py::handle arg : arguments.flat_dynamic_args) {
-    if (arg.get_type().is(device_array)) {
+    if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
       if (!HasTrivialLazyExpr(arg)) {
         return InvalidArgument(
             "Non-trivial lazy expression not supported in C++. "
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
index 04e68f9a563..60d47081e21 100644
--- a/tensorflow/compiler/xla/python/ops.cc
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -108,7 +108,8 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("lhs_dilation"), py::arg("rhs_dilation"),
           py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
           py::arg("batch_group_count") = 1,
-          py::arg("precision_config") = nullptr);
+          py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
           py::arg("new_element_type"));
   ops.def(
@@ -136,9 +137,11 @@ void BuildOpsSubmodule(py::module* m) {
       py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
       py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
   ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
-          py::arg("precision_config") = nullptr);
+          py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
-          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
+          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("DynamicSlice",
           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
                                 absl::Span<const int64>)>(&DynamicSlice),
diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc
index cac14142b75..1f39266a989 100644
--- a/tensorflow/compiler/xla/python/py_buffer.cc
+++ b/tensorflow/compiler/xla/python/py_buffer.cc
@@ -238,4 +238,22 @@ PyBufferProcs PjRtBufferProcs = []() {
   return &PjRtBufferProcs;
 }
 
+void PyBuffer::SetStickyDevice(pybind11::object sticky_device) {
+  if (sticky_device_ && !sticky_device_->equal(sticky_device)) {
+    throw std::invalid_argument(
+        "The sticky device of a buffer cannot be changed; create a new "
+        "buffer or a `_DeviceArray` instead.");
+  }
+  sticky_device_ = sticky_device;
+}
+
+void PyBuffer::SetAval(pybind11::object aval) {
+  if (aval_ && !aval_->equal(aval)) {
+    throw std::invalid_argument(
+        "The aval of a buffer cannot be changed; create a new buffer or a "
+        "`_DeviceArray` instead.");
+  }
+  aval_ = aval;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h
index 0562beaa14a..c7c62d83225 100644
--- a/tensorflow/compiler/xla/python/py_buffer.h
+++ b/tensorflow/compiler/xla/python/py_buffer.h
@@ -17,8 +17,11 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
 
 #include <memory>
+#include <stdexcept>
 #include <vector>
 
+#include "absl/types/optional.h"
+#include "pybind11/pybind11.h"
 #include "tensorflow/compiler/xla/python/py_client.h"
 #include "tensorflow/compiler/xla/python/traceback.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -41,6 +44,9 @@ class DeviceArrayBase {
 // Python wrapper around PjRtBuffer. We use a wrapper class:
 // a) to keep the PjRtClient alive via a std::shared_ptr<>
 // b) to add Python-specific functionality.
+//
+// A `PyBuffer` can be used from Python without being wrapped in a Python
+// `DeviceArray` object, provided that it has no associated LazyExpr.
 class PyBuffer : public DeviceArrayBase {
  public:
   PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
@@ -57,8 +63,12 @@ class PyBuffer : public DeviceArrayBase {
   StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
       const ClientAndPtr<PjRtDevice>& dst_device) const;
 
-  void Delete() { return buffer_->Delete(); }
+  void Delete() {
+    buffer_->Delete();
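+    // Drop the cached host copy so a deleted buffer does not pin stale data.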
+    npy_value_ = pybind11::none();
+  }
 
+  // Returns xla::InvalidArgument if the buffer has been deleted.
   Status BlockHostUntilReady();
   Status CopyToHostAsync() { return buffer_->CopyToHostAsync(); }
 
@@ -75,13 +85,33 @@ class PyBuffer : public DeviceArrayBase {
 
   Traceback* traceback() { return traceback_.get(); }
 
+  // Returns the size (i.e. number of elements) of the (host) numpy array.
+  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_host_shape()); }
+
+  // Returns the number of dimensions of the (host) numpy array.
+  int ndim() const { return buffer()->on_host_shape().dimensions_size(); }
+
+  void SetStickyDevice(pybind11::object sticky_device);
+  pybind11::object GetStickyDevice() const { return sticky_device_.value(); }
+
+  void SetNpyValue(pybind11::object npy_value) { npy_value_ = npy_value; }
+  pybind11::object GetNpyValue() const { return npy_value_; }
+
+  void SetAval(pybind11::object aval);
+  pybind11::object GetAval() const { return aval_.value(); }
+
  private:
   friend class PyClient;
 
   std::shared_ptr<PyClient> client_;
   std::unique_ptr<PjRtBuffer> buffer_;
   std::shared_ptr<Traceback> traceback_;
-
+  // Caches the host numpy array once the value has been copied to the host.
+  pybind11::object npy_value_ = pybind11::none();
+  absl::optional<pybind11::object> sticky_device_ = absl::nullopt;
+  // TODO(jblespiau): It's currently there for convenience but maybe we can do
+  // without it (adding `weak_type` instead).
+  absl::optional<pybind11::object> aval_ = absl::nullopt;
   // Doubly-linked list of all buffers known to the client. Protected by the
   // GIL.
   PyBuffer* next_;
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index b7c5148a4cf..00af81b7064 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -73,6 +73,25 @@ bool IsOptimizedBuild() {
 #endif  // NDEBUG
 }
 
+StatusOr<py::object> BufferToPython(PyBuffer* buffer, py::handle& buffer_obj) {
+  GlobalPyRefManager()->CollectGarbage();
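+  // On-CPU array buffers (other than BF16, which NumPy lacks) are exported
+  // zero-copy through the Python buffer protocol; all other buffers are
+  // copied to the host via ToLiteral with the GIL released.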
+  if (buffer->buffer()->IsOnCpu() &&
+      buffer->buffer()->on_device_shape().IsArray() &&
+      buffer->buffer()->on_device_shape().element_type() != BF16) {
+    py::object out =
+        py::reinterpret_steal<py::object>(PyArray_FROM_O(buffer_obj.ptr()));
+    CHECK(out.ptr() != nullptr) << buffer->buffer()->on_host_shape().ToString(
+        /*print_layout=*/true);
+    return out;
+  }
+  std::shared_ptr<Literal> literal;
+  {
+    py::gil_scoped_release gil_release;
+    TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral());
+  }
+  return LiteralToPython(std::move(literal));
+}
+
 }  // namespace
 
 PYBIND11_MODULE(xla_extension, m) {
@@ -261,33 +280,65 @@ PYBIND11_MODULE(xla_extension, m) {
   // TODO(phawkins): alias for backward compatibility. Remove after JAX no
   // longer uses this name.
   m.add_object("PyLocalBuffer", buffer);
-  buffer.def("copy_to_device", &PyBuffer::CopyToDevice)
+  buffer
+      .def_property_readonly("__array_priority__",
+                             [](py::object) { return 100; })
+      .def_property("_device", &PyBuffer::GetStickyDevice,
+                    &PyBuffer::SetStickyDevice)
+      .def_property("aval", &PyBuffer::GetAval, &PyBuffer::SetAval)
+      .def_property_readonly("_lazy_expr",
+                             [](py::object buffer) { return py::none(); })
+      .def_property_readonly("device_buffer",
+                             [](py::object buffer) { return buffer; })
+      .def_property_readonly(
+          "shape",
+          [](const PyBuffer& pybuffer) -> pybind11::tuple {
+            return IntSpanToTuple(
+                pybuffer.buffer()->on_host_shape().dimensions());
+          })
+      .def_property_readonly(
+          "dtype",
+          [](const PyBuffer& buffer) {
+            PrimitiveType primitive =
+                buffer.buffer()->on_host_shape().element_type();
+            return PrimitiveTypeToDtype(primitive).ValueOrDie();
+          })
+      .def_property_readonly("size", &PyBuffer::size)
+      .def_property_readonly("ndim", &PyBuffer::ndim)
+      .def_property_readonly(
+          "_value",
+          [](py::handle buffer_obj) -> pybind11::object {
+            PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
+            if (buffer->is_deleted()) {
+              throw std::runtime_error("DeviceArray has been deleted.");
+            }
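+            // Copy to host at most once; subsequent reads hit the cache.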
+            py::object npy_value_ = buffer->GetNpyValue();
+            if (npy_value_.is_none()) {
+              npy_value_ = BufferToPython(buffer, buffer_obj).ValueOrDie();
+              // TODO(jblespiau): Change `LiteralToPython` to return a
+              // `py::array`, so we can set the attribute more easily.
+              npy_value_.attr("flags").attr("writeable") = Py_False;
+              buffer->SetNpyValue(npy_value_);
+            }
+            return npy_value_;
+          })
+      .def("copy_to_device", &PyBuffer::CopyToDevice)
       .def("delete", &PyBuffer::Delete)
+      // The GIL is released within BlockHostUntilReady.
+      .def("block_until_ready",
+           [](py::object buffer_obj) -> xla::StatusOr<py::object> {
+             PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
+             TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
+             return buffer_obj;
+           })
       .def("block_host_until_ready", &PyBuffer::BlockHostUntilReady)
       .def("copy_to_host_async", &PyBuffer::CopyToHostAsync,
            py::call_guard<py::gil_scoped_release>())
       .def("to_py",
-           [](py::object buffer_obj) -> StatusOr<py::object> {
-             GlobalPyRefManager()->CollectGarbage();
+           [](py::handle buffer_obj) {
              PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
-             if (buffer->buffer()->IsOnCpu() &&
-                 buffer->buffer()->on_device_shape().IsArray() &&
-                 buffer->buffer()->on_device_shape().element_type() != BF16) {
-               py::object out = py::reinterpret_steal<py::object>(
-                   PyArray_FROM_O(buffer_obj.ptr()));
-               CHECK(out.ptr() != nullptr)
-                   << buffer->buffer()->on_host_shape().ToString(
-                          /*print_layout=*/true);
-               return out;
-             }
-             std::shared_ptr<Literal> literal;
-             {
-               py::gil_scoped_release gil_release;
-               TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral());
-             }
-             return LiteralToPython(std::move(literal));
+             return BufferToPython(buffer, buffer_obj);
            })
-      .def("shape", &PyBuffer::shape)
       .def("xla_shape", &PyBuffer::shape)
       .def_property_readonly("client", &PyBuffer::client)
       .def("device", &PyBuffer::device)
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 63891c537ff..0e538fa8ca8 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -28,6 +28,8 @@ import inspect
 import os
 from typing import List, Sequence, Tuple, Union
 
+from . import xla_extension as _xla
+
 from absl import logging
 import numpy as np
 
@@ -36,8 +38,6 @@ import numpy as np
 # of TensorFlow. If we use protocol buffers here, then importing both jaxlib
 # and TensorFlow may fail with duplicate protocol buffer message definitions.
 
-from tensorflow.compiler.xla.python import xla_extension as _xla
-
 # Most functions are snake_case for consistency with other modules, some
 # method names are CamelCase for consistency with XLA.
 # pylint: disable=invalid-name
diff --git a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py
index 180bb040cc4..eb9c90941b6 100644
--- a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py
@@ -38,8 +38,7 @@ ops = xla_client.ops
 class ShapeTest(absltest.TestCase):
 
   def testInvalidShapes(self):
-    with self.assertRaisesRegex(RuntimeError,
-                                "shape's dimensions must not be < 0.*"):
+    with self.assertRaisesRegex(RuntimeError, "invalid shape"):
       xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
 
     with self.assertRaisesRegex(
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd2a065c34f..fc30883880e 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -14,10 +14,6 @@
 # ==============================================================================
 """Backend-dependent tests for the Python XLA client."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import functools
 import itertools
 import re
@@ -445,17 +441,10 @@ def TestFactory(xla_backend, cloud_tpu=False):
       with self.assertRaises(RuntimeError):
         compiled_c.execute([arg_buffer])
 
-    def testShape(self):
-      pyval = np.array([[1., 2.]], np.float32)
-      local_buffer = self.backend.buffer_from_pyval(pyval)
-      xla_shape = local_buffer.shape()
-      self.assertEqual(xla_shape.dimensions(), (1, 2))
-      self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
-
     def testXlaShape(self):
       pyval = np.array([[1., 2.]], np.float32)
-      buffer = self.backend.buffer_from_pyval(pyval)
-      xla_shape = buffer.xla_shape()
+      local_buffer = self.backend.buffer_from_pyval(pyval)
+      xla_shape = local_buffer.xla_shape()
       self.assertEqual(xla_shape.dimensions(), (1, 2))
       self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
 
@@ -476,6 +465,27 @@ def TestFactory(xla_backend, cloud_tpu=False):
               "BlockHostUntilReady() called on deleted or donated buffer")):
         buffer.block_host_until_ready()
 
+    def testDeviceArrayBaseSignatures(self):
+      # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray`
+      # and thus needs to correctly implement the following methods.
+      arg = np.array([[1., 2., 3.]], np.float32)
+      buffer = self.backend.buffer_from_pyval(arg)
+      if not isinstance(buffer, xla_client.DeviceArrayBase):
+        raise unittest.SkipTest(
+            "The object of type {} does not extend DeviceArrayBase".format(
+                type(buffer)))
+
+      self.assertEqual(buffer.__array_priority__, 100)
+      self.assertEqual(buffer.shape, (1, 3))
+      self.assertEqual(buffer.dtype, np.float32)
+      self.assertEqual(buffer.size, 3)
+      self.assertEqual(buffer.ndim, 2)
+
+      self.assertIs(buffer, buffer.block_until_ready())
+      buffer.delete()
+      with self.assertRaises(RuntimeError):
+        buffer.block_until_ready()
+
     def testCopyToHost(self):
       arg0 = np.array([[1., 2.]], np.float32)
       arg1 = np.array([[3., 4.]], np.float32)
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 007ba568527..c1e6b1b1a1d 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -465,7 +465,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
   const Shape& shape =
       ShapeInference::InferConvolveShape(
           lhs_literal.shape(), rhs_literal.shape(),
-          /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums)
+          /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
+          /*preferred_element_type=*/absl::nullopt)
           .ConsumeValueOrDie();
 
   HloInstruction* lhs_instruction =
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 85de316e569..5fe89e84995 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3721,6 +3721,7 @@ cc_library(
         ":hlo_casting_utils",
         ":hlo_pass",
         ":shape_inference",
+        "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
@@ -3774,6 +3775,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 39d1eddc569..9a725cd541d 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1798,10 +1798,10 @@ StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
                 ShapeUtil::DropDegenerateDimensions(rhs_shape),
                 dot->mutable_operand(1)))
           : dot->mutable_operand(1);
-  TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums,
-                                               dot->precision_config()));
-  // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
-  new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
+  TF_ASSIGN_OR_RETURN(
+      auto new_dot,
+      MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
+                 /*preferred_element_type=*/dot->shape().element_type()));
   if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
     TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
   } else {
@@ -4678,10 +4678,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
       }
     }
     TF_ASSIGN_OR_RETURN(
-        auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config()));
+        auto new_dot,
+        MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
+                   /*preferred_element_type=*/dot->shape().element_type()));
     dot->SetupDerivedInstruction(new_dot);
-    // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
-    new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
     if (reduce_dims.empty()) {
       return ReplaceInstruction(hlo, new_dot);
     }
@@ -5312,10 +5312,13 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
   if (!reverse_dimensions.empty()) {
     TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
   }
-  TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution,
-                      MakeConvolveHlo(kernel, input, /*feature_group_count=*/1,
-                                      /*batch_group_count=*/1, swapped_window,
-                                      swapped_dnums, precision_config));
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * new_convolution,
+      MakeConvolveHlo(
+          kernel, input, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, swapped_window, swapped_dnums,
+          precision_config,
+          /*preferred_element_type=*/convolution->shape().element_type()));
 
   convolution->SetupDerivedInstruction(new_convolution);
   TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 91c6c29ee80..1050ba9ad64 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3785,9 +3785,11 @@ TEST_P(ConvInputPaddingTest, DoTest) {
       ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
           .ValueOrDie();
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
-                                         /*feature_group_count=*/1,
-                                         /*batch_group_count=*/1, window, dnums)
+      ShapeInference::InferConvolveShape(
+          lhs_pad->shape(), filter->shape(),
+          /*feature_group_count=*/1,
+          /*batch_group_count=*/1, window, dnums,
+          /*preferred_element_type=*/absl::nullopt)
           .ValueOrDie(),
       lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
       window, dnums, DefaultPrecisionConfig(2)));
@@ -3902,9 +3904,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
   precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
 
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
-                                         /*feature_group_count=*/1,
-                                         /*batch_group_count=*/1, window, dnums)
+      ShapeInference::InferConvolveShape(
+          input->shape(), rhs_pad->shape(),
+          /*feature_group_count=*/1,
+          /*batch_group_count=*/1, window, dnums,
+          /*preferred_element_type=*/absl::nullopt)
           .ValueOrDie(),
       input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
       window, dnums, precision_config));
@@ -4050,7 +4054,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
         b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
     Shape out_shape = ShapeInference::InferConvolveShape(
                           in_shape, f_shape, /*feature_group_count=*/1,
-                          /*batch_group_count=*/1, window, dnums)
+                          /*batch_group_count=*/1, window, dnums,
+                          /*preferred_element_type=*/absl::nullopt)
                           .ValueOrDie();
     if (options.output_minor_to_major_layout) {
       out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index 72112585cb3..942ee59a3aa 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -87,9 +87,11 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
       0,
       new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
 
-  TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
-                      MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
-                                 batch_dot->precision_config()));
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * new_dot,
+      MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
+                 batch_dot->precision_config(),
+                 /*preferred_element_type=*/batch_dot->shape().element_type()));
 
   TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
                       MakeReshapeHlo(batch_dot->shape(), new_dot));
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index 199bc787b83..bff68574e1a 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -144,6 +144,15 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
         << conditional->ToShortString();
     return false;
   }
+
+  bool branch_empty =
+      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
+      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
+  // An empty branch is faster to execute than a select.
+  if (branch_empty) {
+    return false;
+  }
+
   HloInstruction* true_call_op = create_call(0);
   HloInstruction* false_call_op = create_call(1);
   auto condition_broadcast = [&](const Shape& shape) {
@@ -160,13 +169,6 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
         hlo->shape().tuple_shapes(i), hlo, i));
   };
 
-  bool branch_empty =
-      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
-      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
-  // Empty branch is faster to execute than select.
-  if (branch_empty) {
-    return false;
-  }
   std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
       [&](HloInstruction* t, HloInstruction* f) {
         if (f->shape().IsToken()) {
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc
index f5506b894fd..b202c3adc41 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc
@@ -299,9 +299,11 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
     window_dim->set_window_reversal(false);
     window_dim->set_window_dilation(1);
     HloInstruction* new_convolution =
-        MakeConvolveHlo(activation, filter, convolution->feature_group_count(),
-                        /*batch_group_count=*/1, window, dim_numbers,
-                        convolution->precision_config())
+        MakeConvolveHlo(
+            activation, filter, convolution->feature_group_count(),
+            /*batch_group_count=*/1, window, dim_numbers,
+            convolution->precision_config(),
+            /*preferred_element_type=*/convolution->shape().element_type())
             .ValueOrDie();
     convolution->SetupDerivedInstruction(new_convolution);
     TF_CHECK_OK(computation_->ReplaceInstruction(
@@ -650,9 +652,11 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
   window_dim->set_window_reversal(false);
   window_dim->set_window_dilation(1);
   HloInstruction* new_convolution =
-      MakeConvolveHlo(activation, filter, /*feature_group_count=*/1,
-                      /*batch_group_count=*/1, window, dim_numbers,
-                      convolution->precision_config())
+      MakeConvolveHlo(
+          activation, filter, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, window, dim_numbers,
+          convolution->precision_config(),
+          /*preferred_element_type=*/convolution->shape().element_type())
           .ValueOrDie();
   convolution->SetupDerivedInstruction(new_convolution);
   changed_ = true;
diff --git a/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h b/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h
index 857de4a8143..bfd04ab60a7 100644
--- a/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h
+++ b/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h
@@ -21,8 +21,13 @@ limitations under the License.
 static const char kTargetCpuForHost[] = "ppc";
 static const char kTargetTripleForHost[] = "ppc64le-ibm-linux-gnu";
 #else
+#if defined(__s390x__)
+static const char kTargetCpuForHost[] = "s390x";
+static const char kTargetTripleForHost[] = "systemz-none-linux-gnu";
+#else
 static const char kTargetCpuForHost[] = "";
 static const char kTargetTripleForHost[] = "x86_64-pc-linux";
 #endif
+#endif
 
 #endif
diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
index 3adde5f7d48..2abbe887a51 100644
--- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
+++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
@@ -142,7 +142,8 @@ CreateShardedConvForDotGeneralConvolution(
       ShapeInference::InferConvolveShape(
           sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
           /*feature_group_count=*/conv.feature_group_count(),
-          /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums));
+          /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
+          /*preferred_element_type=*/conv.shape().element_type()));
   *sharded_conv_shape.mutable_layout() = conv.shape().layout();
   return HloInstruction::CreateConvolve(
       sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5edbfebddda..1ebb82d7e81 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -268,6 +268,7 @@ cc_library(
         "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
         "//tensorflow/compiler/mlir/xla:hlo_module_importer",
         "//tensorflow/compiler/mlir/xla:hlo_utils",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
index b10c33da2fd..e9e65c9ebfa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
@@ -110,7 +110,8 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
       ShapeInference::InferConvolveShape(
           activations->shape(), gradients->shape(), /*feature_group_count=*/1,
           /*batch_group_count=*/1, conv_window,
-          tf_default_dnums_for_backward_filter_)
+          tf_default_dnums_for_backward_filter_,
+          /*preferred_element_type=*/absl::nullopt)
           .ConsumeValueOrDie(),
       activations, gradients, /*feature_group_count=*/1,
       /*batch_group_count=*/1, conv_window,
@@ -150,7 +151,8 @@ TEST_F(GpuConvRewriterTest,
       ShapeInference::InferConvolveShape(
           activations->shape(), gradients->shape(), /*feature_group_count=*/1,
           /*batch_group_count=*/1, conv_window,
-          tf_default_dnums_for_backward_filter_)
+          tf_default_dnums_for_backward_filter_,
+          /*preferred_element_type=*/absl::nullopt)
           .ConsumeValueOrDie(),
       activations, gradients, /*feature_group_count=*/1,
       /*batch_group_count=*/1, conv_window,
@@ -292,11 +294,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
       DefaultPrecisionConfig(2)));
   // Verify the convolution's shape is consistent with ShapeInference.
   CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(),
-                         /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, conv_dnums)
-                         .ValueOrDie()));
+      conv->shape(),
+      ShapeInference::InferConvolveShape(
+          output->shape(), reverse_kernel->shape(),
+          /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+          conv_dnums, /*preferred_element_type=*/absl::nullopt)
+          .ValueOrDie()));
 
   auto module = CreateNewVerifiedModule();
   HloComputation* entry_computation =
@@ -337,10 +340,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
   conv_window.mutable_dimensions(1)->set_base_dilation(2);
 
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
-                                         /*feature_group_count=*/1,
-                                         /*batch_group_count=*/1, conv_window,
-                                         tf_default_dnums_for_backward_input_)
+      ShapeInference::InferConvolveShape(
+          output->shape(), kernel->shape(),
+          /*feature_group_count=*/1,
+          /*batch_group_count=*/1, conv_window,
+          tf_default_dnums_for_backward_input_,
+          /*preferred_element_type=*/absl::nullopt)
           .ConsumeValueOrDie(),
       /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
       /*batch_group_count=*/1, conv_window,
@@ -374,7 +379,8 @@ TEST_F(GpuConvRewriterTest,
       ShapeInference::InferConvolveShape(
           output->shape(), kernel->shape(), /*feature_group_count=*/1,
           /*batch_group_count=*/1, default_conv_window_,
-          tf_default_dnums_for_backward_input_)
+          tf_default_dnums_for_backward_input_,
+          /*preferred_element_type=*/absl::nullopt)
           .ConsumeValueOrDie(),
       /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
       /*batch_group_count=*/1, default_conv_window_,
@@ -431,7 +437,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
       conv->shape(), ShapeInference::InferConvolveShape(
                          output->shape(), reverse_kernel->shape(),
                          /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_)
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/absl::nullopt)
                          .ValueOrDie()));
 
   auto module = CreateNewVerifiedModule();
@@ -481,7 +488,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
       conv->shape(), ShapeInference::InferConvolveShape(
                          output->shape(), reverse_kernel->shape(),
                          /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_)
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/absl::nullopt)
                          .ValueOrDie()));
 
   auto module = CreateNewVerifiedModule();
@@ -535,7 +543,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
       conv->shape(), ShapeInference::InferConvolveShape(
                          output->shape(), reverse_kernel->shape(),
                          /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_)
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/absl::nullopt)
                          .ValueOrDie()));
 
   auto module = CreateNewVerifiedModule();
@@ -590,7 +599,8 @@ TEST_F(GpuConvRewriterTest,
       conv->shape(), ShapeInference::InferConvolveShape(
                          output->shape(), reverse_kernel->shape(),
                          /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_)
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/absl::nullopt)
                          .ValueOrDie()));
 
   auto module = CreateNewVerifiedModule();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc
index b152962eb99..3566742a03b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc
@@ -37,6 +37,10 @@ NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
       << GlobalDeviceIdsToString(devices_);
 }
 
+std::string NcclCliqueKey::ToString() const {
+  return GlobalDeviceIdsToString(devices_);
+}
+
 GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
     absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids) {
   gpu_global_device_ids_ = std::move(gpu_global_device_ids);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
index 7a43c80121b..58284b2292b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
@@ -53,6 +53,8 @@ class NcclCliqueKey {
 
   const std::vector<GlobalDeviceId>& devices() const { return devices_; }
 
+  std::string ToString() const;
+
  private:
   std::vector<GlobalDeviceId> devices_;
 };
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
index 34b93ca5b3f..81cfaabd63a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
@@ -50,9 +51,9 @@ class IrEmitterContext {
         profile_index_map_(profile_index_map),
         mlir_context_(mlir_context),
         llvm_module_(llvm_module) {
-    mlir_context_
-        ->loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
-                      mlir::StandardOpsDialect>();
+    mlir_context_->loadDialect<
+        mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
+        mlir::StandardOpsDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
   }
   // Disallow copy and assign.
   IrEmitterContext(const IrEmitterContext&) = delete;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 68e29bf68c6..6b7a0f03e23 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -167,18 +167,28 @@ int64_t GetAllocationIndex(mlir::BlockArgument func_arg) {
       .getSExtValue();
 }
 
+static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
+  // For i1 memrefs, each element occupies 8 bits of the allocation.
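+  // (e.g. memref<10xi1> occupies 10 bytes, whereas getSizeInBits() / 8 would
+  // report only 1 byte).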
+  if (type.getElementType().isInteger(/*width=*/1)) {
+    return type.getNumElements();
+  } else {
+    return type.getSizeInBits() / CHAR_BIT;
+  }
+}
+
 StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
     mlir::Value v, absl::Span<const BufferAllocation> allocations) {
-  int64 size = v.getType().cast<mlir::MemRefType>().getSizeInBits() / 8;
+  int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());
 
   if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
     return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0,
                                    size);
   }
 
-  // We match two patterns here:
-  // * v = ViewOp(arg);
-  // * v = MemRefReinterpretCastOp(ViewOp(arg));
+  // We match the following patterns here:
+  //  base := ViewOp(arg) | get_global_memref (global_memref)
+  //  root := base | MemRefReinterpretCastOp(base)
+
   if (mlir::Operation* op = v.getDefiningOp()) {
     if (auto cast = mlir::dyn_cast<mlir::MemRefReinterpretCastOp>(op)) {
       mlir::Value source = cast.getViewSource();
@@ -197,6 +207,14 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
               .getValue()
               .getSExtValue(),
           size);
+    } else if (mlir::isa<mlir::GetGlobalMemrefOp>(op)) {
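+      // Constants are lowered to get_global_memref ops that carry their
+      // buffer assignment in lmhlo.{alloc,slice_offset,slice_size} attributes.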
+      int64_t index =
+          op->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
+      int64_t offset =
+          op->getAttrOfType<mlir::IntegerAttr>("lmhlo.slice_offset").getInt();
+      int64_t size =
+          op->getAttrOfType<mlir::IntegerAttr>("lmhlo.slice_size").getInt();
+      return BufferAllocation::Slice(&allocations[index], offset, size);
     }
     return Unimplemented("MemRefReinterpretCastOp has to wrap a ViewOp");
   }
@@ -554,13 +572,24 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
   // because we don't have a fully functioning LMHLO graph yet.
 
   mlir::Location loc = input.op->getLoc();
-  mlir::lmhlo::FusionOp fusion = nullptr;
+  mlir::lmhlo::FusionOp fusion =
+      mlir::OpBuilder(input.op).create<mlir::lmhlo::FusionOp>(
+          loc, llvm::ArrayRef<mlir::NamedAttribute>());
   Shape output_shape;
+  mlir::OpBuilder b(&fusion.region());
+
+  const auto load_memrefs = [loc, &b](mlir::ValueRange range) {
+    std::vector<mlir::Value> operands;
+    for (mlir::Value memref : range) {
+      auto load = b.create<mlir::TensorLoadOp>(loc, memref);
+      HloFunctionImporter::SetLayoutForMlir(load,
+                                            TypeToShape(memref.getType()));
+      operands.push_back(load);
+    }
+    return operands;
+  };
+
   if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
-    fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>(
-        loc, llvm::ArrayRef<mlir::NamedAttribute>());
-    copy.getOperation()->moveBefore(&fusion.region().front().back());
-    mlir::OpBuilder b(copy);
     auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
     HloFunctionImporter::SetLayoutForMlir(
         operand, TypeToShape(copy.operand().getType()));
@@ -568,15 +597,41 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
     output_shape = TypeToShape(copy.output().getType());
     HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
     b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
-    copy.getOperation()->erase();
+  } else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(input.op)) {
+    std::vector<mlir::Value> operands = load_memrefs(reduce.operands());
+    std::vector<mlir::Value> init_values = load_memrefs(reduce.init_values());
+    auto fused_reduce = b.create<mlir::mhlo::ReduceOp>(
+        loc, operands, init_values, reduce.dimensions());
+    fused_reduce.body().takeBody(reduce.body());
+    CHECK_EQ(fused_reduce.getNumResults(), reduce.out().size());
+    std::vector<Shape> output_shapes;
+    for (int i = 0; i < reduce.out().size(); i++) {
+      b.create<mlir::TensorStoreOp>(loc, fused_reduce.getResult(i),
+                                    reduce.out()[i]);
+      auto shape = TypeToShape(reduce.out()[i].getType());
+      if (i == 0) {
+        HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape);
+      }
+      output_shapes.push_back(shape);
+    }
+    if (output_shapes.size() == 1) {
+      output_shape = output_shapes[0];
+    } else {
+      output_shape = ShapeUtil::MakeTupleShape(output_shapes);
+    }
   } else {
     input.op->dump();
     LOG(FATAL) << "Unimplemented default action for mlir op";
   }
+  input.op->erase();
   input.op = fusion;
-  auto ret = EmitLoopFusionFromMlir(
-      input, output_shape,
-      ComputeMaxUnrollFactor(output_shape, hlo_module_config_));
+  int unroll_factor = 1;
+  // TODO(timshen): Port MayPreventVectorization as we add more ops into this
+  // function.
+  if (output_shape.IsArray()) {
+    unroll_factor = ComputeMaxUnrollFactor(output_shape, hlo_module_config_);
+  }
+  auto ret = EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
   return ret;
 }
 
@@ -893,7 +948,8 @@ Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
 // instead.
 static Status ProcessFusionForConversion(mlir::Region* region,
-                                         std::vector<Shape>* operand_shapes) {
+                                         std::vector<Shape>* operand_shapes,
+                                         std::vector<Shape>* output_shapes) {
   std::vector<mlir::TensorLoadOp> loads;
   std::vector<mlir::TensorStoreOp> stores;
 
@@ -913,8 +969,7 @@ static Status ProcessFusionForConversion(mlir::Region* region,
     auto arg = region->addArgument(load.getType());
     load.replaceAllUsesWith(arg);
     Shape shape = TypeToShape(load.getType());
-    auto attr = mlir::GetLayoutFromMlirHlo(load);
-    if (attr) {
+    if (auto attr = mlir::GetLayoutFromMlirHlo(load)) {
       std::vector<int64> minor_to_major;
       absl::c_transform(
           attr, std::back_inserter(minor_to_major),
@@ -930,6 +985,16 @@ static Status ProcessFusionForConversion(mlir::Region* region,
 
   std::vector<mlir::Value> returned_values;
   for (auto store : stores) {
+    Shape shape = TypeToShape(store.memref().getType());
+    if (auto attr = mlir::GetLayoutFromMlirHlo(store)) {
+      std::vector<int64> minor_to_major;
+      absl::c_transform(
+          attr, std::back_inserter(minor_to_major),
+          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
+      *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+    }
+    output_shapes->push_back(shape);
+
     returned_values.push_back(store.tensor());
     store.erase();
   }
@@ -1254,12 +1319,14 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
 }
 
 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(reduce));
+
   if (IsReductionFromOrToContiguousDimensions(*reduce) &&
       reduce->shape().IsArray()) {
     return EmitReductionFromOrToContiguousDimensions(reduce, {reduce});
   }
 
-  return IrEmitter::HandleReduce(reduce);
+  return DefaultActionForMlir(input);
 }
 
 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
@@ -1325,23 +1392,23 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
         "Dilation for SelectAndScatter not implemented on GPU.");
   }
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
-                      BuildInitializerThunk(select_and_scatter));
-
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(select_and_scatter));
-  return EmitSelectAndScatterFromMlir(input, std::move(initializer_thunk));
+  return EmitSelectAndScatterFromMlir(input);
 }
 
 Status IrEmitterUnnested::EmitSelectAndScatterFromMlir(
-    MlirEmitterInput mlir_input, std::unique_ptr<Thunk>&& initializer_thunk) {
+    MlirEmitterInput mlir_input) {
   auto select_and_scatter_op =
       ::mlir::cast<::mlir::lmhlo::SelectAndScatterOp>(mlir_input.op);
 
   std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc());
 
   std::vector<std::unique_ptr<Thunk>> thunks;
-  thunks.push_back(std::move(initializer_thunk));
-
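+  // Reserve the slot first; TF_ASSIGN_OR_RETURN assigns into an existing
+  // lvalue, so the element must already be in the vector.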
+  thunks.emplace_back();
+  TF_ASSIGN_OR_RETURN(thunks.back(),
+                      BuildInitializerThunkForMlir(
+                          mlir_input.op, select_and_scatter_op.init_value(),
+                          select_and_scatter_op.out()));
   absl::Span<const BufferAllocation> allocations(
       ir_emitter_context_->buffer_assignment().Allocations());
 
@@ -1862,9 +1929,10 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
                                                        bool is_fusion) {
   std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
   if (module == nullptr) {
-    std::vector<Shape> operand_shapes;
+    std::vector<Shape> operand_shapes, output_shapes;
     if (is_fusion) {
-      TF_RETURN_IF_ERROR(ProcessFusionForConversion(region, &operand_shapes));
+      TF_RETURN_IF_ERROR(
+          ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
     }
 
     xla::XlaComputation xla_computation;
@@ -1878,6 +1946,62 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
         module, HloModule::CreateFromProto(xla_computation.proto(),
                                            HloModuleConfig(program_shape)));
 
+    if (is_fusion) {
+      HloComputation* fused_computation = module->entry_computation();
+      CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
+      for (int i = 0; i < fused_computation->num_parameters(); i++) {
+        *fused_computation->parameter_instruction(i)
+             ->mutable_shape()
+             ->mutable_layout() = operand_shapes[i].layout();
+      }
+      HloInstruction* root = fused_computation->root_instruction();
+      // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
+      // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
+      // as much as possible.
+      if (root->opcode() == HloOpcode::kTuple) {
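+        // Run the pattern checks in an immediately-invoked lambda so any
+        // failed check can bail out with a plain return.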
+        [&] {
+          HloInstruction* real_root = nullptr;
+          int expected_tuple_index = 0;
+          for (HloInstruction* operand : root->operands()) {
+            if (operand->opcode() != HloOpcode::kGetTupleElement) {
+              return;
+            }
+            if (real_root == nullptr) {
+              real_root = operand->mutable_operand(0);
+            } else if (real_root != operand->operand(0)) {
+              return;
+            }
+            if (expected_tuple_index != operand->tuple_index()) {
+              return;
+            }
+            expected_tuple_index++;
+          }
+          fused_computation->set_root_instruction(real_root);
+          std::vector<HloInstruction*> to_be_removed;
+          to_be_removed.push_back(root);
+          for (HloInstruction* operand : root->operands()) {
+            to_be_removed.push_back(operand);
+          }
+          for (auto instr : to_be_removed) {
+            TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
+          }
+
+          root = real_root;
+        }();
+      }
+
+      if (output_shapes.size() > 1) {
+        CHECK(root->shape().IsTuple());
+        CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
+
+        for (int i = 0; i < output_shapes.size(); i++) {
+          *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
+        }
+      } else {
+        CHECK_EQ(1, output_shapes.size());
+        *root->mutable_shape() = output_shapes[0];
+      }
+    }
     // Post-process the generated computation:
     // * Sanitize constant names, so that they can be used as LLVM global
     // symbols.
@@ -1887,22 +2011,13 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
         if (instr->opcode() == HloOpcode::kConstant) {
           instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(*instr));
         }
-        if (instr->shape().IsTuple()) {
-          TF_ASSIGN_OR_RETURN(*instr->mutable_shape(),
-                              ShapeInference::InferVariadicOpShape(
-                                  instr->opcode(), instr->operands()));
+        if (instr->shape().IsTuple() &&
+            computation == module->entry_computation() &&
+            instr != computation->root_instruction()) {
+          return InternalError("Non-root tuple types are not handled.");
         }
       }
     }
-    if (is_fusion) {
-      HloComputation* fused_computation = module->entry_computation();
-      CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
-      for (int i = 0; i < fused_computation->num_parameters(); i++) {
-        *fused_computation->parameter_instruction(i)
-             ->mutable_shape()
-             ->mutable_layout() = operand_shapes[i].layout();
-      }
-    }
   }
   return module->entry_computation();
 }
@@ -2620,6 +2735,45 @@ IrEmitterUnnested::BuildKernelThunkForMlir(
                                  ir_arrays);
 }
 
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
+    absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
+    const Shape& output_shape) {
+  int64 num_bytes = init_value.size();
+  if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) {
+    return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
+  }
+
+  // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
+  // repeating the literal 4 or 2 times, so long as the destination buffer is
+  // an even multiple of 32 bits long.
+  if ((num_bytes == 1 || num_bytes == 2) &&
+      ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
+    uint16 pattern16;
+    if (num_bytes == 1) {
+      uint8 b = init_value.front();
+      pattern16 = uint16{b} | (uint16{b} << 8);
+    } else {
+      memcpy(&pattern16, init_value.data(), sizeof(pattern16));
+    }
+    uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
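+    // e.g. an 8-bit init value 0xAB yields pattern16 = 0xABAB and
+    // pattern32 = 0xABABABAB.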
+    return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
+                                                    pattern32, dest);
+  }
+
+  // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
+  // memset so long as all 32-bit words of the scalar are equal to each other.
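+  // Comparing the buffer against itself shifted by one word checks that each
+  // 32-bit word equals the next one, and hence that all words are identical.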
+  if (num_bytes >= 4 && num_bytes % 4 == 0 &&
+      memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
+          0) {
+    uint32 word;
+    memcpy(&word, init_value.data(), sizeof(word));
+    return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
+                                                    dest);
+  }
+
+  return nullptr;
+}
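
As an aside, the pattern-replication above can be illustrated with a minimal standalone sketch (plain C++, independent of the XLA thunk machinery; the `0xAB` initializer is a hypothetical example value):

```cpp
#include <cstdint>
#include <cstdio>

// Expands an 8-bit init value into the 32-bit memset pattern built above:
// replicate the byte into 16 bits, then the 16-bit pattern into 32 bits.
uint32_t Pattern32FromByte(uint8_t b) {
  uint16_t pattern16 = uint16_t{b} | (uint16_t{b} << 8);
  return uint32_t{pattern16} | (uint32_t{pattern16} << 16);
}

int main() {
  // 0xAB -> 0xABAB -> 0xABABABAB, the value a Memset32BitValueThunk would use.
  std::printf("0x%08X\n", static_cast<unsigned>(Pattern32FromByte(0xAB)));
  return 0;
}
```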
+
 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
     HloInstruction* hlo, const ShapeIndex& index) {
   bool fused = HloOpcode::kFusion == hlo->opcode();
@@ -2661,43 +2815,14 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
     CHECK(ShapeUtil::IsScalar(init_value->shape()));
     int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape());
     const auto& literal = init_value->literal();
-
-    // Are all the bytes of this scalar equal to 0?  If so, we can create a
-    // MemzeroThunk.
     absl::Span<const uint8> literal_bytes(
         reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
-    if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
-      return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(),
-                                              GetAllocationSlice(*hlo, index))};
-    }
-
-    // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
-    // repeating the literal 4 or 2 times, so long as the destination buffer is
-    // an even multiple of 32 bits long.
     const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
-    if ((num_bytes == 1 || num_bytes == 2) &&
-        ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
-      uint16 pattern16;
-      if (num_bytes == 1) {
-        uint8 b = literal_bytes.front();
-        pattern16 = uint16{b} | (uint16{b} << 8);
-      } else {
-        memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
-      }
-      uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
-      return {absl::make_unique<Memset32BitValueThunk>(
-          Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))};
-    }
 
-    // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
-    // memset so long as all 32-bit words of the scalar are equal to each other.
-    if (num_bytes >= 4 && num_bytes % 4 == 0 &&
-        memcmp(literal_bytes.data(), literal_bytes.data() + 4,
-               literal_bytes.size() - 4) == 0) {
-      uint32 word;
-      memcpy(&word, literal_bytes.data(), sizeof(word));
-      return {absl::make_unique<Memset32BitValueThunk>(
-          Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))};
+    auto thunk = BuildConstantInitializerThunk(
+        literal_bytes, GetAllocationSlice(*hlo, index), output_shape);
+    if (thunk) {
+      return thunk;
     }
   }
 
@@ -2744,6 +2869,83 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
   return {std::move(kernel_thunk)};
 }
 
+StatusOr<std::unique_ptr<Thunk>>
+IrEmitterUnnested::BuildInitializerThunkForMlir(mlir::Operation* op,
+                                                mlir::Value init_value,
+                                                mlir::Value dest) {
+  // The initial value must be a scalar memref.
+  auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
+  TF_RET_CHECK(init_type && init_type.getRank() == 0);
+
+  const Shape dest_shape = TypeToShape(dest.getType());
+  if (auto get_global_memref = mlir::dyn_cast_or_null<mlir::GetGlobalMemrefOp>(
+          init_value.getDefiningOp())) {
+    auto global_memref =
+        mlir::SymbolTable::lookupNearestSymbolFrom<mlir::GlobalMemrefOp>(
+            get_global_memref, get_global_memref.name());
+    if (global_memref.constant() && global_memref.initial_value()) {
+      // If the initial value happens to be a constant, generate a specialized
+      // thunk.
+      auto const_init = global_memref.initial_value()
+                            .getValue()
+                            .cast<mlir::DenseElementsAttr>();
+
+      Shape init_shape = TypeToShape(init_value.getType());
+      CHECK(ShapeUtil::IsScalar(init_shape));
+      int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_shape);
+      bool bool_init;
+      absl::Span<const uint8> literal_bytes(
+          reinterpret_cast<const uint8*>(const_init.getRawData().data()),
+          num_bytes);
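+      // i1 constants are bit-packed inside the DenseElementsAttr, so the raw
+      // data is not byte-addressable per element; materialize the boolean
+      // into a full byte before building the thunk.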
+      if (init_type.getElementTypeBitWidth() == 1) {
+        TF_RET_CHECK(num_bytes == 1);
+        bool_init = *const_init.getBoolValues().begin();
+        literal_bytes =
+            absl::MakeSpan(reinterpret_cast<const uint8_t*>(&bool_init), 1);
+      }
+
+      absl::Span<const BufferAllocation> allocations(
+          ir_emitter_context_->buffer_assignment().Allocations());
+      TF_ASSIGN_OR_RETURN(auto dest_slice,
+                          GetAllocationSliceForMlir(dest, allocations));
+
+      auto thunk =
+          BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
+      if (thunk) {
+        return {std::move(thunk)};
+      }
+    }
+  }
+
+  // Otherwise fall back to our slow initializer code. The thunk in this case
+  // will just need the IR arrays for the initial value and the destination.
+  std::vector<llvm_ir::IrArray> ir_arrays;
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<KernelThunk> kernel_thunk,
+      BuildKernelThunkForMlir(op, {init_value, dest}, Thunk::ThunkInfo(), {},
+                              &ir_arrays));
+  const llvm_ir::IrArray init_array = ir_arrays[0];
+  const llvm_ir::IrArray dest_array = ir_arrays[1];
+
+  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+      dest_shape, ir_emitter_context_->gpu_device_info());
+  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
+                         ir_emitter_context_->llvm_module());
+
+  // TODO(jurahul): Handle fused cases.
+  // In the unfused case the initial value is already materialized; just read
+  // it from the input array.
+  std::string name = mlir::GetNameFromLoc(op->getLoc());
+  TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+                         [=](const IrArray::Index& index) {
+                           return init_array.EmitReadArrayElement(index, &b_);
+                         },
+                         dest_array, launch_dimensions, &b_)
+                         .EmitLoop(name));
+
+  // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
+  return {std::move(kernel_thunk)};
+}
+
 namespace {
 
 // Checks that the buffers corresponding to the given two HLOs share the same
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 197a25e2cc1..68661ce3034 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -175,8 +175,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
   Status HandleReduce(HloInstruction* reduce) override;
   Status HandleSelectAndScatter(HloInstruction* instruction) override;
-  Status EmitSelectAndScatterFromMlir(
-      MlirEmitterInput mlir_input, std::unique_ptr<Thunk>&& initializer_thunk);
+  Status EmitSelectAndScatterFromMlir(MlirEmitterInput mlir_input);
   Status HandleTuple(HloInstruction* tuple) override;
   Status HandleWhile(HloInstruction* xla_while) override;
   Status HandleInfeed(HloInstruction* xla_infeed) override;
@@ -633,6 +632,13 @@ class IrEmitterUnnested : public IrEmitter,
   StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
       HloInstruction* hlo, const ShapeIndex& index = {});
 
+  std::unique_ptr<Thunk> BuildConstantInitializerThunk(
+      absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
+      const Shape& output_shape);
+
+  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunkForMlir(
+      mlir::Operation* op, mlir::Value init_value, mlir::Value dest);
+
   // Returns a WhileThunk that invokes thunk sequences for 'condition' and
   // 'body' sub-computations of while instruction 'hlo'.
   StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(const HloInstruction* hlo);
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo b/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo
index baf3ccf03d5..8e9ee58fbf2 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo
@@ -13,15 +13,15 @@
 // CHECK:         %[[VAL_9:.*]] = alloca float, align 4
 // CHECK:         %[[VAL_10:.*]] = alloca float, align 4
 // CHECK:         %[[VAL_11:.*]] = getelementptr inbounds i8, i8* %[[VAL_12:.*]], i64 0
-// CHECK:         %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [2 x i8*]*
+// CHECK:         %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [100 x [200 x [300 x float]]]*
 // CHECK:         %[[VAL_14:.*]] = getelementptr inbounds i8, i8* %[[VAL_15:.*]], i64 0
-// CHECK:         %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [200 x float]*
+// CHECK:         %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [100 x [200 x [300 x float]]]*
 // CHECK:         %[[VAL_17:.*]] = getelementptr inbounds i8, i8* %[[VAL_18:.*]], i64 0
 // CHECK:         %[[VAL_19:.*]] = bitcast i8* %[[VAL_17]] to [200 x float]*
 // CHECK:         %[[VAL_20:.*]] = getelementptr inbounds i8, i8* %[[VAL_21:.*]], i64 0
-// CHECK:         %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [100 x [200 x [300 x float]]]*
+// CHECK:         %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [200 x float]*
 // CHECK:         %[[VAL_23:.*]] = getelementptr inbounds i8, i8* %[[VAL_24:.*]], i64 0
-// CHECK:         %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [100 x [200 x [300 x float]]]*
+// CHECK:         %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [2 x i8*]*
 // CHECK:         %[[VAL_26:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
 // CHECK:         %[[VAL_27:.*]] = icmp eq i32 0, %[[VAL_26]]
 // CHECK:         %[[VAL_28:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
@@ -41,11 +41,11 @@
 // CHECK:       d.in_bounds-after:                                ; preds = %[[VAL_43:.*]], %[[VAL_32]]
 // CHECK:         ret void
 // CHECK:       emit_mof_tuple-true:                              ; preds = %[[VAL_33]]
-// CHECK:         %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_16]] to i8*
-// CHECK:         %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 0
+// CHECK:         %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8*
+// CHECK:         %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_25]], i64 0, i64 0
 // CHECK:         store i8* %[[VAL_44]], i8** %[[VAL_45]], align 8
-// CHECK:         %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8*
-// CHECK:         %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 1
+// CHECK:         %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_22]] to i8*
+// CHECK:         %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_25]], i64 0, i64 1
 // CHECK:         store i8* %[[VAL_46]], i8** %[[VAL_47]], align 8
 // CHECK:         br label %[[VAL_32]]
 // CHECK:       d.in_bounds-true:                                 ; preds = %[[VAL_32]]
@@ -55,23 +55,23 @@
 // CHECK:         store float %[[VAL_49]], float* %[[VAL_9]], align 4
 // CHECK:         store i32 0, i32* %[[VAL_8]], align 4
 // CHECK:         br label %[[VAL_50:.*]]
-// CHECK:       d.inner.loop_header.reduction_dim.0:              ; preds = %[[VAL_51:.*]], %[[VAL_41]]
+// CHECK:       reduce.13.inner.loop_header.reduction_dim.0:      ; preds = %[[VAL_51:.*]], %[[VAL_41]]
 // CHECK:         %[[VAL_52:.*]] = load i32, i32* %[[VAL_8]], align 4
 // CHECK:         %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], 100
 // CHECK:         br i1 %[[VAL_53]], label %[[VAL_43]], label %[[VAL_54:.*]]
-// CHECK:       d.inner.loop_body.reduction_dim.0:                ; preds = %[[VAL_50]]
+// CHECK:       reduce.13.inner.loop_body.reduction_dim.0:        ; preds = %[[VAL_50]]
 // CHECK:         store i32 0, i32* %[[VAL_7]], align 4
 // CHECK:         br label %[[VAL_55:.*]]
-// CHECK:       d.inner.loop_header.reduction_dim.2:              ; preds = %[[VAL_56:.*]], %[[VAL_54]]
+// CHECK:       reduce.13.inner.loop_header.reduction_dim.2:      ; preds = %[[VAL_56:.*]], %[[VAL_54]]
 // CHECK:         %[[VAL_57:.*]] = load i32, i32* %[[VAL_7]], align 4
 // CHECK:         %[[VAL_58:.*]] = icmp uge i32 %[[VAL_57]], 300
 // CHECK:         br i1 %[[VAL_58]], label %[[VAL_51]], label %[[VAL_56]]
-// CHECK:       d.inner.loop_body.reduction_dim.2:                ; preds = %[[VAL_55]]
+// CHECK:       reduce.13.inner.loop_body.reduction_dim.2:        ; preds = %[[VAL_55]]
 // CHECK:         %[[VAL_59:.*]] = load float, float* %[[VAL_10]], align 4
 // CHECK:         %[[VAL_60:.*]] = load float, float* %[[VAL_9]], align 4
-// CHECK:         %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_22]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
+// CHECK:         %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_13]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
 // CHECK:         %[[VAL_62:.*]] = load float, float* %[[VAL_61]], align 4, !invariant.load !4
-// CHECK:         %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_25]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
+// CHECK:         %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_16]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
 // CHECK:         %[[VAL_64:.*]] = load float, float* %[[VAL_63]], align 4, !invariant.load !4
 // CHECK:         store float %[[VAL_59]], float* %[[VAL_5]], align 4
 // CHECK:         store float %[[VAL_60]], float* %[[VAL_4]], align 4
@@ -83,7 +83,7 @@
 // CHECK:         %[[VAL_67:.*]] = bitcast float* %[[VAL_1]] to i8*
 // CHECK:         %[[VAL_68:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_6]], i64 0, i64 1
 // CHECK:         store i8* %[[VAL_67]], i8** %[[VAL_68]], align 8
-// CHECK:         call void @Add(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]])
+// CHECK:         call void @region_1_5(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]])
 // CHECK:         %[[VAL_69:.*]] = load float, float* %[[VAL_0]], align 4
 // CHECK:         %[[VAL_70:.*]] = load float, float* %[[VAL_1]], align 4
 // CHECK:         store float %[[VAL_69]], float* %[[VAL_10]], align 4
@@ -91,21 +91,21 @@
 // CHECK:         %[[VAL_71:.*]] = add nuw nsw i32 %[[VAL_57]], 1
 // CHECK:         store i32 %[[VAL_71]], i32* %[[VAL_7]], align 4
 // CHECK:         br label %[[VAL_55]]
-// CHECK:       d.inner.loop_exit.reduction_dim.2:                ; preds = %[[VAL_55]]
+// CHECK:       reduce.13.inner.loop_exit.reduction_dim.2:        ; preds = %[[VAL_55]]
 // CHECK:         %[[VAL_72:.*]] = add nuw nsw i32 %[[VAL_52]], 1
 // CHECK:         store i32 %[[VAL_72]], i32* %[[VAL_8]], align 4
 // CHECK:         br label %[[VAL_50]]
-// CHECK:       d.inner.loop_exit.reduction_dim.0:                ; preds = %[[VAL_50]]
+// CHECK:       reduce.13.inner.loop_exit.reduction_dim.0:        ; preds = %[[VAL_50]]
 // CHECK:         %[[VAL_73:.*]] = load float, float* %[[VAL_10]], align 4
 // CHECK:         %[[VAL_74:.*]] = insertvalue { float, float } undef, float %[[VAL_73]], 0
 // CHECK:         %[[VAL_75:.*]] = load float, float* %[[VAL_9]], align 4
 // CHECK:         %[[VAL_76:.*]] = insertvalue { float, float } %[[VAL_74]], float %[[VAL_75]], 1
 // CHECK:         %[[VAL_77:.*]] = extractvalue { float, float } %[[VAL_76]], 0
-// CHECK:         %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_16]] to float*
+// CHECK:         %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_19]] to float*
 // CHECK:         %[[VAL_79:.*]] = getelementptr inbounds float, float* %[[VAL_78]], i32 %[[VAL_37]]
 // CHECK:         store float %[[VAL_77]], float* %[[VAL_79]], align 4
 // CHECK:         %[[VAL_80:.*]] = extractvalue { float, float } %[[VAL_76]], 1
-// CHECK:         %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_19]] to float*
+// CHECK:         %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_22]] to float*
 // CHECK:         %[[VAL_82:.*]] = getelementptr inbounds float, float* %[[VAL_81]], i32 %[[VAL_37]]
 // CHECK:         store float %[[VAL_80]], float* %[[VAL_82]], align 4
 // CHECK:         br label %[[VAL_42]]
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index a90d2828000..30ee9853903 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -94,13 +94,15 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
     HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
     int64 batch_group_count, const Window& window,
     const ConvolutionDimensionNumbers& dimension_numbers,
-    const PrecisionConfig& precision_config) {
+    const PrecisionConfig& precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   HloComputation* computation = lhs->parent();
   CHECK_EQ(computation, rhs->parent());
-  TF_ASSIGN_OR_RETURN(Shape convolve_shape,
-                      ShapeInference::InferConvolveShape(
-                          lhs->shape(), rhs->shape(), feature_group_count,
-                          batch_group_count, window, dimension_numbers));
+  TF_ASSIGN_OR_RETURN(
+      Shape convolve_shape,
+      ShapeInference::InferConvolveShape(
+          lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
+          window, dimension_numbers, preferred_element_type));
   return computation->AddInstruction(HloInstruction::CreateConvolve(
       convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
       dimension_numbers, precision_config));
@@ -281,14 +283,17 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
       HloInstruction::CreateIota(shape, iota_dimension));
 }
 
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
-                                     const DotDimensionNumbers& dim_numbers,
-                                     const PrecisionConfig& precision_config) {
+StatusOr<HloInstruction*> MakeDotHlo(
+    HloInstruction* lhs, HloInstruction* rhs,
+    const DotDimensionNumbers& dim_numbers,
+    const PrecisionConfig& precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   HloComputation* computation = lhs->parent();
   CHECK_EQ(computation, rhs->parent());
   TF_ASSIGN_OR_RETURN(
       Shape dot_shape,
-      ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
+      ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
+                                      preferred_element_type));
   return computation->AddInstruction(HloInstruction::CreateDot(
       dot_shape, lhs, rhs, dim_numbers, precision_config));
 }
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 53eeeffb858..1fa2a3faeea 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -59,11 +59,14 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
 
 // Creates a convolution HLO instruction and adds it to the computation
 // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
+// If the result shape has integral element type, an optional
+// preferred_element_type can be specified to override the element type.
 StatusOr<HloInstruction*> MakeConvolveHlo(
     HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
     int64 batch_group_count, const Window& window,
     const ConvolutionDimensionNumbers& dimension_numbers,
-    const PrecisionConfig& precision_config);
+    const PrecisionConfig& precision_config,
+    absl::optional<PrimitiveType> preferred_element_type);
 
 // Creates a transpose HLO instruction and adds it to the computation containing
 // `operand`.
@@ -128,10 +131,14 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                             int64 iota_dimension);
 
 // Creates a Dot HLO instruction and adds it to the computation containing `lhs`
-// and `rhs` (both must be in the same computation).
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
-                                     const DotDimensionNumbers& dim_numbers,
-                                     const PrecisionConfig& precision_config);
+// and `rhs` (both must be in the same computation). If the result shape has
+// integral element type, an optional preferred_element_type can be specified to
+// override the element type.
+StatusOr<HloInstruction*> MakeDotHlo(
+    HloInstruction* lhs, HloInstruction* rhs,
+    const DotDimensionNumbers& dim_numbers,
+    const PrecisionConfig& precision_config,
+    absl::optional<PrimitiveType> preferred_element_type);
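
For illustration, a hedged sketch of a caller of the extended helper (not part of this patch; it assumes the `xla` namespace, the usual includes, and that `S32` is a legal widening of the dot's inferred integral element type):

```cpp
// Hypothetical helper: build a dot that accumulates in s32, e.g. for s8
// operands. Shape inference rejects the request if S32 is not a valid
// widening (wrong signedness, or narrower than the inferred type).
StatusOr<HloInstruction*> MakeS32AccumulatingDot(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config) {
  return MakeDotHlo(lhs, rhs, dim_numbers, precision_config,
                    /*preferred_element_type=*/S32);
}
```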
 
 // Creates a Map HLO instruction and adds it to the computation containing the
 // operands. All operands must be in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 1480fdd53ac..57f96c8eea2 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -388,9 +388,10 @@ StatusOr<Literal> HloEvaluator::EvaluateDotOp(
   std::unique_ptr<HloInstruction> rhs_instr =
       HloInstruction::CreateConstant(rhs.Clone());
 
-  TF_ASSIGN_OR_RETURN(
-      Shape dot_shape,
-      ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
+  TF_ASSIGN_OR_RETURN(Shape dot_shape,
+                      ShapeInference::InferDotOpShape(
+                          lhs.shape(), rhs.shape(), dim_numbers,
+                          /*preferred_element_type=*/absl::nullopt));
 
   std::unique_ptr<HloInstruction> cloned_instruction =
       HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 01ad536e033..7f2157a85bd 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -4567,6 +4567,50 @@ TEST_F(HloEvaluatorTest, MapBF16) {
   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
+TEST_F(HloEvaluatorTest, MapS16) {
+  const absl::string_view hlo_text = R"(
+  HloModule test
+
+  map_computation {
+    p = s16[] parameter(0)
+    add = s16[] add(p, p)
+    ROOT conv = f32[] convert(add)
+  }
+
+  ENTRY MapS16 {
+    c = s16[3] constant({1, 2, 3})
+    ROOT map = f32[3] map(c), to_apply=map_computation
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+  Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
+  TF_ASSERT_OK_AND_ASSIGN(
+      Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+TEST_F(HloEvaluatorTest, MapU16) {
+  const absl::string_view hlo_text = R"(
+  HloModule test
+
+  map_computation {
+    p = u16[] parameter(0)
+    add = u16[] add(p, p)
+    ROOT conv = f32[] convert(add)
+  }
+
+  ENTRY MapU16 {
+    c = u16[3] constant({1, 2, 3})
+    ROOT map = f32[3] map(c), to_apply=map_computation
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+  Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
+  TF_ASSERT_OK_AND_ASSIGN(
+      Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
 TEST_F(HloEvaluatorTest, DotUpcast) {
   const absl::string_view hlo_text = R"(
   HloModule test
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 808c2a4b769..953145df474 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1290,7 +1290,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                         ShapeInference::InferConvolveShape(
                             lhs_shape, rhs_shape, conv->feature_group_count(),
-                            conv->batch_group_count(), window, dnums));
+                            conv->batch_group_count(), window, dnums,
+                            /*preferred_element_type=*/absl::nullopt));
     CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
         << "return shape set to: " << ShapeUtil::HumanString(result_shape)
         << " but is inferred to be: "
@@ -1769,6 +1770,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
         break;
       }
+      case U16: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint16>(map));
+        break;
+      }
       case U32: {
         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
         break;
@@ -1781,6 +1786,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
         break;
       }
+      case S16: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int16>(map));
+        break;
+      }
       case S32: {
         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
         break;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index f2ea03f063a..e43f68fd257 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 59b1ac31e9b..864f293420b 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -16,10 +16,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
 
 #include <algorithm>
+#include <iterator>
 #include <memory>
 #include <set>
 #include <string>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
@@ -153,7 +155,22 @@ struct Item {
   int64 position;
 };
 
+// Records a user of the buffer defined by an Item. It also records the
+// operand_number through which the use occurs, so that indirect uses (for
+// example, a buffer used through a bitcast) can be identified precisely.
+struct ItemUse {
+  Item* user;
+  int operand_number;
+
+  ItemUse(Item* user, int op_num) : user(user), operand_number(op_num) {}
+  bool operator==(const ItemUse& other) const {
+    return user == other.user && operand_number == other.operand_number;
+  }
+};
+
 using ItemList = absl::InlinedVector<Item*, 3>;
+using UsesList = absl::InlinedVector<ItemUse, 3>;
 
 // Class which maintains an ordered list of instructions with fast insertion
 // before arbitrary elements.
@@ -412,11 +429,11 @@ class InstructionList {
 // has_indirect_users to whether any of the uses is indirect. A use is indirect
 // if the instruction defining logical_buffer is not an operand of the use. This
 // can happen via buffer aliasing (eg, tuples).
-ItemList GetUsers(const InstructionList& instruction_list,
+UsesList GetUsers(const InstructionList& instruction_list,
                   const LogicalBuffer* logical_buffer,
                   const TuplePointsToAnalysis& points_to_analysis,
                   bool* has_indirect_users) {
-  ItemList users;
+  UsesList users;
   // To identify uses iterate through all HloInstruction users of the
   // BufferAliases of the logical buffer.
   *has_indirect_users = false;
@@ -431,14 +448,18 @@ ItemList GetUsers(const InstructionList& instruction_list,
         // instruction (the GTE instruction only uses the pointer vector).
         continue;
       }
-      if (buffer_alias.instruction() != logical_buffer->instruction()) {
+      if (buffer_alias.instruction() != logical_buffer->instruction() &&
+          buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) {
         *has_indirect_users = true;
       }
       // A buffer may be used by the instruction via more than one alias. For
       // example, a buffer which appears in more than one element of a tuple.
       Item* user_item = instruction_list.GetItem(user);
-      if (!absl::c_linear_search(users, user_item)) {
-        users.push_back(user_item);
+      for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
+        if (!absl::c_linear_search(
+                users, ItemUse{user_item, static_cast<int>(op_idx)})) {
+          users.push_back(ItemUse{user_item, static_cast<int>(op_idx)});
+        }
       }
     }
   }
@@ -516,7 +537,8 @@ class MemoryUsageTracker {
   // is remat_item. This method should be called after the HLO graph has
   // been transformed (rematerialization instruction created and connected
   // to uses).
-  Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
+  Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
+                                      absl::Span<Item*> bitcasts);
 
   // Selects and returns the best candidate instructions for rematerialization.
   // A sequence of candidate instructions of length between min_block_size and
@@ -538,6 +560,9 @@ class MemoryUsageTracker {
   // Returns whether 'item' has any unplaced users.
   bool HasUnplacedUsers(Item* item) const;
 
+  // Returns the list of uses for a specific 'item'.
+  const UsesList GetItemUses(Item* item) const;
+
   // Returns whether 'item' is currently in progress.
   bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
 
@@ -588,7 +613,7 @@ class MemoryUsageTracker {
     bool has_indirect_uses;
 
     // The instructions which use this buffer.
-    ItemList users;
+    UsesList users;
 
     // The number of users (HloInstructions) of this buffer which have not yet
     // been placed in the sequence.
@@ -611,7 +636,7 @@ class MemoryUsageTracker {
       const LogicalBuffer* logical_buffer,
       const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
     bool has_indirect_uses = false;
-    ItemList users = GetUsers(instruction_list_, logical_buffer,
+    UsesList users = GetUsers(instruction_list_, logical_buffer,
                               points_to_analysis, &has_indirect_uses);
     return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                      logical_buffer->shape(), std::move(users), live_out,
@@ -621,13 +646,13 @@ class MemoryUsageTracker {
   // Create a new buffer representing a rematerialization of given buffer for
   // the given uses.
   Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
-                              ItemList&& rematerialized_uses) {
+                              UsesList&& rematerialized_uses) {
     CHECK(original_buffer.defining_instruction->placed)
         << original_buffer.defining_instruction->instruction->name();
     CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
     CHECK(!original_buffer.live_out) << original_buffer.ToString();
-    for (Item* use : rematerialized_uses) {
-      CHECK(!use->placed) << use->instruction->name();
+    for (ItemUse& use : rematerialized_uses) {
+      CHECK(!use.user->placed) << use.user->instruction->name();
     }
     return NewBuffer(remat_item, original_buffer.shape,
                      std::move(rematerialized_uses), /*live_out=*/false,
@@ -692,11 +717,18 @@ class MemoryUsageTracker {
 
   // Create a new buffer, add it to buffers_, and return a reference.
   Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
-                    ItemList&& users, bool live_out, bool has_indirect_uses) {
+                    UsesList&& uses, bool live_out, bool has_indirect_uses) {
     int buffer_id = buffers_.size();
+    auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
+      absl::flat_hash_set<Item*> users_set;
+      for (const ItemUse& use : uses) {
+        users_set.insert(use.user);
+      }
+      return users_set.size();
+    };
     buffers_.push_back(Buffer{
         buffer_id, defining_instruction, size_function_(shape), shape, live_out,
-        has_indirect_uses, users, static_cast<int64>(users.size())});
+        has_indirect_uses, uses, get_num_of_unique_users(uses)});
     return buffers_.back();
   }
 
@@ -771,12 +803,15 @@ MemoryUsageTracker::MemoryUsageTracker(
 
         // Add users of while to Buffer users.
         bool unused;
-        for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
-                                        points_to_analysis, &unused)) {
-          if (!absl::c_linear_search(buffer->users, user_item)) {
-            buffer->users.push_back(user_item);
+        for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
+                                           points_to_analysis, &unused)) {
+          auto existing_user_it = absl::c_find_if(
+              buffer->users,
+              [&](const ItemUse& use) { return user_item.user == use.user; });
+          if (existing_user_it == buffer->users.end()) {
             buffer->unfinished_user_count++;
-            user_item->buffers_used.push_back(buffer->id);
+            user_item.user->buffers_used.push_back(buffer->id);
+            buffer->users.push_back(user_item);
           }
         }
       } else {
@@ -784,8 +819,10 @@ MemoryUsageTracker::MemoryUsageTracker(
             logical_buffer, points_to_analysis,
             ContainsKey(live_out_set, logical_buffer));
         item->buffers_defined.push_back(buffer->id);
-        for (Item* user : buffer->users) {
-          user->buffers_used.push_back(buffer->id);
+        for (ItemUse& user : buffer->users) {
+          if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
+            user.user->buffers_used.push_back(buffer->id);
+          }
         }
       }
 
@@ -1003,14 +1040,14 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
   // Compressed buffer is now alive.
   memory_usage_ += size_function_(compressed_item->instruction->shape());
 
-  ItemList placed_users;
-  ItemList unplaced_users;
+  UsesList placed_users;
+  UsesList unplaced_users;
   CHECK_EQ(original_item->buffers_output.size(), 1);
   BufferId original_buffer_id = original_item->buffers_output[0];
   Buffer& original_buffer = buffers_.at(original_buffer_id);
-  for (Item* user : original_buffer.users) {
-    if (user->placed) {
-      CHECK(IsFinished(user)) << user->instruction->name();
+  for (ItemUse& user : original_buffer.users) {
+    if (user.user->placed) {
+      CHECK(IsFinished(user.user)) << user.user->instruction->name();
       placed_users.push_back(user);
     } else {
       unplaced_users.push_back(user);
@@ -1018,10 +1055,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
   }
   original_buffer.users = std::move(placed_users);
   original_buffer.unfinished_user_count = 0;
-  original_buffer.users.push_back(compressed_item);
+  original_buffer.users.push_back(ItemUse{compressed_item, 0});
   Buffer& compressed_buffer =
       NewBuffer(compressed_item, compressed_item->instruction->shape(),
-                {uncompressed_item}, /*live_out=*/false,
+                {ItemUse{uncompressed_item, 0}}, /*live_out=*/false,
                 /*has_indirect_uses=*/false);
   compressed_item->buffers_used = original_item->buffers_output;
   compressed_item->buffers_output = {compressed_buffer.id};
@@ -1036,8 +1073,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
   uncompressed_item->buffers_output = {uncompressed_buffer.id};
   uncompressed_item->buffers_defined = {uncompressed_buffer.id};
 
-  for (Item* user : uncompressed_buffer.users) {
-    BufferIdList& buffers_used = user->buffers_used;
+  for (ItemUse& user : uncompressed_buffer.users) {
+    BufferIdList& buffers_used = user.user->buffers_used;
     std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
                  uncompressed_buffer.id);
   }
@@ -1045,8 +1082,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
   return Status::OK();
 }
 
-Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
-                                                        Item* remat_item) {
+Status MemoryUsageTracker::AddRematerializedInstruction(
+    Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) {
   VLOG(3) << "AddRematerializedInstruction: original_instruction = "
           << original_item->instruction->name()
           << ", remat_instruction = " << remat_item->instruction->name();
@@ -1067,9 +1104,23 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
       // Buffer used by this instruction was dead, now is alive.
       memory_usage_ += AllocatedSize(buffer.id);
     }
-
     buffer.unfinished_user_count++;
-    buffer.users.push_back(remat_item);
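+    // Collect the uses attributed to original_item into a separate vector
+    // first; appending to buffer.users while iterating over it could
+    // invalidate iterators when the InlinedVector reallocates.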
+    absl::InlinedVector<ItemUse, 2> filtered_users;
+    std::copy_if(buffer.users.begin(), buffer.users.end(),
+                 std::back_inserter(filtered_users),
+                 [&](const ItemUse& iu) { return iu.user == original_item; });
+    for (ItemUse& u : filtered_users) {
+      buffer.users.push_back(ItemUse{remat_item, u.operand_number});
+    }
+  }
+
+  for (Item* bitcast : bitcasts) {
+    CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast);
+    for (BufferId buffer_id : bitcast->buffers_used) {
+      Buffer& buffer = buffers_.at(buffer_id);
+      buffer.unfinished_user_count++;
+      buffer.users.push_back(ItemUse{bitcast, 0});
+    }
   }
 
   // Create a new set of Buffers defined by the new rematerialization
@@ -1078,10 +1129,10 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
   for (BufferId old_buffer_id : original_item->buffers_defined) {
     Buffer& old_buffer = buffers_.at(old_buffer_id);
 
-    ItemList placed_users;
-    ItemList unplaced_users;
-    for (Item* user : old_buffer.users) {
-      if (user->placed) {
+    UsesList placed_users;
+    UsesList unplaced_users;
+    for (ItemUse& user : old_buffer.users) {
+      if (user.user->placed) {
         placed_users.push_back(user);
       } else {
         unplaced_users.push_back(user);
@@ -1097,8 +1148,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
         RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
 
     remat_item->buffers_defined.push_back(new_buffer.id);
-    for (Item* user : new_buffer.users) {
-      BufferIdList& buffers_used = user->buffers_used;
+    for (ItemUse& user : new_buffer.users) {
+      BufferIdList& buffers_used = user.user->buffers_used;
       std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                    new_buffer.id);
     }
@@ -1131,6 +1182,10 @@ string MemoryUsageTracker::ToString() const {
       absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
                       buffer.unfinished_user_count, " unfinished uses\n");
     }
+    absl::StrAppend(&output, "    Outputs:\n");
+    for (BufferId buffer_id : item->buffers_output) {
+      absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
+    }
     absl::StrAppend(&output, "    Uses:\n");
     for (BufferId buffer_id : item->buffers_used) {
       absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
@@ -1190,12 +1245,14 @@ bool MemoryUsageTracker::Check() const {
   }
   for (const Buffer& buffer : buffers_) {
     int64 unfinished_uses = 0;
-    for (Item* user : buffer.users) {
-      const BufferIdList& used_buffers = user->buffers_used;
+    absl::flat_hash_set<Item*> already_counted_user;
+    for (const ItemUse& user : buffer.users) {
+      const BufferIdList& used_buffers = user.user->buffers_used;
       CHECK(absl::c_linear_search(used_buffers, buffer.id))
-          << "Instruction " << user->instruction->name()
+          << "Instruction " << user.user->instruction->name()
           << " used buffers is missing " << buffer.ToString();
-      if (!IsFinished(user)) {
+      if (!IsFinished(user.user) &&
+          already_counted_user.insert(user.user).second) {
         unfinished_uses++;
       }
     }
@@ -1397,8 +1454,8 @@ MemoryUsageTracker::PickRematerializationCandidates(
 bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
   for (BufferId buffer_id : item->buffers_defined) {
     const Buffer& buffer = buffers_.at(buffer_id);
-    for (Item* user : buffer.users) {
-      if (!user->placed) {
+    for (const ItemUse& user : buffer.users) {
+      if (!user.user->placed) {
         return true;
       }
     }
@@ -1406,6 +1463,17 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
   return false;
 }
 
+const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
+  UsesList combined_users;
+  for (BufferId buffer_id : item->buffers_defined) {
+    const Buffer& buffer = buffers_.at(buffer_id);
+    for (const ItemUse& user : buffer.users) {
+      combined_users.push_back(user);
+    }
+  }
+  return combined_users;
+}
+
 StatusOr<int64> RematerializeInstructions(
     MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
     absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
@@ -1443,18 +1511,30 @@ StatusOr<int64> RematerializeInstructions(
     Item* remat_item = instruction_list->CreateItem(remat);
 
     // Replace each remaining use of 'best' with the rematerialization.
-    std::vector<HloInstruction*> best_users_copy = best->users();
-    for (HloInstruction* user : best_users_copy) {
-      if (!memory_tracker->IsPlaced(user)) {
+    absl::InlinedVector<Item*, 4> bitcasts;
+    for (auto& user : memory_tracker->GetItemUses(best_item)) {
+      if (!memory_tracker->IsPlaced(user.user->instruction)) {
         VLOG(2) << "  Replacing use of " << best->name() << " in "
-                << user->name() << " with " << remat->name();
-        TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
+                << user.user->instruction->name() << " with " << remat->name();
+        const int op_idx = user.operand_number;
+        auto* remat_use = remat;
+        if (user.user->instruction->operand(op_idx)->shape() !=
+            remat->shape()) {
+          remat_use = computation->AddInstruction(HloInstruction::CreateUnary(
+              user.user->instruction->operand(op_idx)->shape(),
+              HloOpcode::kBitcast, remat));
+          bitcasts.push_back(instruction_list->CreateItem(remat_use));
+          bitcasts.back()->buffers_output = remat_item->buffers_defined;
+          bitcasts.back()->buffers_used = remat_item->buffers_defined;
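+          // The bitcast aliases the rematerialized instruction's output, so
+          // record remat's defined buffers as both its uses and its outputs.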
+        }
+        TF_RETURN_IF_ERROR(
+            user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
       }
     }
 
     // Account for the rematerialization in the memory tracker.
-    TF_RETURN_IF_ERROR(
-        memory_tracker->AddRematerializedInstruction(best_item, remat_item));
+    TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
+        best_item, remat_item, absl::MakeSpan(bitcasts)));
 
     // Insert rematerialized instruction right before the earliest unplaced
     // use of the instruction *and* the earliest unplaced last use of any
@@ -1463,7 +1543,14 @@ StatusOr<int64> RematerializeInstructions(
     // this could increase memory usage.
     ItemList place_before;
     for (auto user : remat->users()) {
-      place_before.push_back(instruction_list->GetItem(user));
+      if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) {
+        place_before.push_back(instruction_list->GetItem(user));
+      }
+    }
+    for (auto* bitcast : bitcasts) {
+      for (auto user : bitcast->instruction->users()) {
+        place_before.push_back(instruction_list->GetItem(user));
+      }
     }
     for (auto* operand : remat->operands()) {
       for (auto* operand_user : operand->users()) {
@@ -1486,6 +1573,9 @@ StatusOr<int64> RematerializeInstructions(
     }
     instruction_list->InsertBeforeInstructions(remat_item, place_before);
 
+    for (auto* bitcast : bitcasts) {
+      instruction_list->InsertBeforeInstructions(bitcast, place_before);
+    }
     // If the rematerialized instruction is dead then rematerialization is
     // essentially a move. Don't delete the instruction now because we don't
     // want duplicate HloInstruction* values during the course of the
@@ -1501,8 +1591,12 @@ StatusOr<int64> RematerializeInstructions(
         instruction_list->Denylist(remat);
       }
       remat_move_instructions->insert(remat);
+      net_instructions_added += bitcasts.size();
     } else {
-      net_instructions_added++;
+      net_instructions_added += bitcasts.size() + 1;
+    }
+    for (auto* bitcast : bitcasts) {
+      instruction_list->Denylist(bitcast->instruction);
     }
   }
   VLOG(1) << "Rematerializing instructions ["
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 35f39e9a342..5e747f9076a 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -748,6 +748,73 @@ ENTRY %entry {
               op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
 }
 
+// Tests rematerialization of values through bitcasts. The broadcast is
+// expected to be rematerialized.
+TEST_F(HloRematerializationTest, ThroughBitcastRemat) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+ENTRY %mycomp (param: f32[1]) -> f32[1] {
+  %param = f32[1]{0} parameter(0)
+  %reshape = f32[] reshape(f32[1]{0} %param)
+  %broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
+  %bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
+  %negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %broadcast)
+  %concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate, f32[1024,1]{1,0} %negate), dimensions={0}
+  %slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate), slice={[0:1], [0:1]}
+  %bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
+  %concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast, f32[1]{0} %bitcast.1), dimensions={0}
+  ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  auto* computation = module->entry_computation();
+  // Find and save the original broadcast instruction which should be
+  // rematerialized.
+  const HloInstruction* slice = computation->root_instruction();
+  ASSERT_THAT(slice,
+              op::Slice(op::Concatenate(op::Bitcast(op::Broadcast(_)), _)));
+  const HloInstruction* concat = slice->operand(0);
+  const HloInstruction* bcast = concat->operand(0)->operand(0);
+
+  LOG(INFO) << module->ToString();
+  // Computation requires 16KB without rematerialization, but uses only 12KB
+  // with rematerialization so pick a memory limit between these values (14KB).
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/14 * 1024, module.get()));
+  LOG(INFO) << module->ToString();
+  EXPECT_TRUE(changed);
+
+  // Root should not have changed.
+  EXPECT_EQ(computation->root_instruction(), slice);
+
+  // The bitcast for the rematerialized broadcast
+  const HloInstruction* remat_bitcast = concat->operand(0);
+  // The broadcast should have been rematerialized.
+  const HloInstruction* remat_broadcast = remat_bitcast->operand(0);
+
+  EXPECT_THAT(remat_broadcast, op::Broadcast(::testing::Ne(bcast)));
+
+  // The rematerialized broadcast should be immediately before its bitcast
+  // and the bitcast before the concatenate in the sequence.
+  EXPECT_EQ(module->schedule()
+                .sequence(computation)
+                .instructions()[computation->instruction_count() - 2],
+            concat);
+  EXPECT_EQ(module->schedule()
+                .sequence(computation)
+                .instructions()[computation->instruction_count() - 3],
+            remat_bitcast);
+  EXPECT_EQ(module->schedule()
+                .sequence(computation)
+                .instructions()[computation->instruction_count() - 4],
+            remat_broadcast);
+}
+
 }  // namespace
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 13c754438eb..84e4fe6e3fd 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -136,13 +137,12 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
 }
 
 Status ShapeVerifier::HandleDot(HloInstruction* dot) {
-  TF_ASSIGN_OR_RETURN(Shape expected,
-                      ShapeInference::InferDotOpShape(
-                          dot->operand(0)->shape(), dot->operand(1)->shape(),
-                          dot->dot_dimension_numbers()));
-  if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) {
-    expected.set_element_type(dot->shape().element_type());
-  }
+  TF_ASSIGN_OR_RETURN(
+      const Shape expected,
+      ShapeInference::InferDotOpShape(
+          dot->operand(0)->shape(), dot->operand(1)->shape(),
+          dot->dot_dimension_numbers(),
+          /*preferred_element_type=*/dot->shape().element_type()));
   return CheckShape(dot, expected);
 }
 
@@ -152,10 +152,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
       ShapeInference::InferConvolveShape(
           convolution->operand(0)->shape(), convolution->operand(1)->shape(),
           convolution->feature_group_count(), convolution->batch_group_count(),
-          convolution->window(), convolution->convolution_dimension_numbers()));
-  if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) {
-    expected.set_element_type(convolution->shape().element_type());
-  }
+          convolution->window(), convolution->convolution_dimension_numbers(),
+          /*preferred_element_type=*/convolution->shape().element_type()));
   return CheckShape(convolution, expected);
 }
 
@@ -1749,6 +1747,31 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) {
           ShapeUtil::HumanString(operand_shape));
     }
   }
+  if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
+    const Shape& operand_shape = comparison->operand(1)->shape();
+    PrimitiveType operand_element_type = operand_shape.element_type();
+    Comparison::Type default_comparison_type =
+        Comparison::DefaultComparisonType(operand_element_type);
+    if (primitive_util::IsFloatingPointType(operand_element_type)) {
+      if (comparison->type() != Comparison::Type::kFloat &&
+          comparison->type() != Comparison::Type::kFloatTotalOrder) {
+        return FailedPrecondition(
+            "Expected comparison type %s or %s.\n"
+            "actual: %s\noperand: %s\n",
+            ComparisonTypeToString(Comparison::Type::kFloat),
+            ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
+            ComparisonTypeToString(comparison->type()),
+            ShapeUtil::HumanString(operand_shape));
+      }
+    } else if (comparison->type() != default_comparison_type) {
+      return FailedPrecondition(
+          "Expected comparison type %s.\n"
+          "actual: %s\noperand: %s\n",
+          ComparisonTypeToString(default_comparison_type),
+          ComparisonTypeToString(comparison->type()),
+          ShapeUtil::HumanString(operand_shape));
+    }
+  }
   return Status::OK();
 }
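
To make the rule concrete, here is a minimal self-contained model of the check above, using stand-in enums rather than the real XLA `Comparison` and shape machinery (hypothetical code, for illustration only):

```cpp
#include <cassert>

enum class ElementType { kF32, kS32, kU32, kPred };
enum class CmpType { kFloat, kFloatTotalOrder, kSigned, kUnsigned };

// Mirrors Comparison::DefaultComparisonType for the four types modeled here
// (u32 and pred both default to unsigned comparisons).
CmpType DefaultCmpType(ElementType t) {
  switch (t) {
    case ElementType::kF32: return CmpType::kFloat;
    case ElementType::kS32: return CmpType::kSigned;
    default:                return CmpType::kUnsigned;
  }
}

bool CmpTypeIsValid(ElementType operand, CmpType cmp) {
  if (operand == ElementType::kF32) {
    // Floating-point operands admit FLOAT and TOTALORDER comparisons.
    return cmp == CmpType::kFloat || cmp == CmpType::kFloatTotalOrder;
  }
  return cmp == DefaultCmpType(operand);
}

int main() {
  assert(CmpTypeIsValid(ElementType::kF32, CmpType::kFloatTotalOrder));
  assert(!CmpTypeIsValid(ElementType::kS32, CmpType::kUnsigned));
  assert(!CmpTypeIsValid(ElementType::kPred, CmpType::kSigned));
  return 0;
}
```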
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0df30166a1c..c6c09e3dee1 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1220,5 +1220,77 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
                         "needs to be collective-permute-start, found tuple"));
 }
 
+TEST_F(HloVerifierTest, ComparisonTypeFloat) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY ComparisonTypeFloat {
+   p0 = f32[] parameter(0)
+   ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypeSigned) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY ComparisonTypeSigned {
+   p0 = s32[] parameter(0)
+   ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type SIGNED"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY ComparisonTypeUnsigned {
+   p0 = u32[] parameter(0)
+   ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type UNSIGNED"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypePred) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY ComparisonTypePred {
+   p0 = pred[] parameter(0)
+   ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type UNSIGNED"));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/integral_upcaster.cc b/tensorflow/compiler/xla/service/integral_upcaster.cc
index d8383b25c84..9bb8e468ad4 100644
--- a/tensorflow/compiler/xla/service/integral_upcaster.cc
+++ b/tensorflow/compiler/xla/service/integral_upcaster.cc
@@ -26,12 +26,14 @@ StatusOr<absl::optional<Shape>> MaybeInferShape(
     case HloOpcode::kDot:
       return ShapeInference::InferDotOpShape(
           instruction->operand(0)->shape(), instruction->operand(1)->shape(),
-          instruction->dot_dimension_numbers());
+          instruction->dot_dimension_numbers(),
+          /*preferred_element_type=*/absl::nullopt);
     case HloOpcode::kConvolution:
       return ShapeInference::InferConvolveShape(
           instruction->operand(0)->shape(), instruction->operand(1)->shape(),
           instruction->feature_group_count(), instruction->batch_group_count(),
-          instruction->window(), instruction->convolution_dimension_numbers());
+          instruction->window(), instruction->convolution_dimension_numbers(),
+          /*preferred_element_type=*/absl::nullopt);
     default:
       return absl::make_optional<Shape>();
   }
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
index 1457fa5df1d..e92cd54e8fd 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
@@ -345,7 +345,9 @@ Status LhloDialectEmitter::HandleReduce(HloInstruction* instr) {
       CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_);
   auto reduce_op = builder.create<lhlo::ReduceOp>(loc, inputs, init_values,
                                                   results, dimensions_attr);
-  reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(instr));
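+  // Create the region's entry block explicitly, then add the terminator to it.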
+  builder.createBlock(&reduce_op.body());
+  OpBuilder::atBlockEnd(&reduce_op.body().front())
+      .create<lhlo::TerminatorOp>(getLocation(instr));
   return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc,
                               *instr->to_apply(), emission_context_);
 }
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 02895d9627b..badfe81625e 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -214,6 +214,37 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                                        output_is_dynamic);
 }
 
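+// Checks the rules for upcasting the result element type of a dot/convolution
+// when a preferred_element_type is given: the preference must keep the
+// inferred type's numeric class (integral vs. floating point) and signedness,
+// and integral types may only be widened. A narrower floating-point
+// preference is ignored and the inferred type is kept.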
+StatusOr<PrimitiveType> MaybeUpcast(
+    PrimitiveType from_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
+  if (!preferred_element_type.has_value() ||
+      *preferred_element_type == from_type) {
+    return from_type;
+  }
+  if (primitive_util::IsIntegralType(from_type) !=
+      primitive_util::IsIntegralType(*preferred_element_type)) {
+    return InvalidArgument(
+        "`preferred_element_type` and the original type must both be integral "
+        "or both be floating point.");
+  }
+  if (primitive_util::IsSignedIntegralType(from_type) !=
+      primitive_util::IsSignedIntegralType(*preferred_element_type)) {
+    return InvalidArgument(
+        "`preferred_element_type` must have the same signedness as the "
+        "original type.");
+  }
+  if (primitive_util::BitWidth(*preferred_element_type) <
+      primitive_util::BitWidth(from_type)) {
+    if (primitive_util::IsFloatingPointType(from_type)) {
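+      // Keep the wider inferred floating-point type rather than narrowing it.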
+      return from_type;
+    }
+    return InvalidArgument(
+        "`preferred_element_type` must not be narrower than the original "
+        "type.");
+  }
+  return *preferred_element_type;
+}
+
 }  // namespace
 
 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
@@ -622,7 +653,8 @@ Status ValidateDotDimensionNumbers(
 
 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
     const Shape& lhs, const Shape& rhs,
-    const DotDimensionNumbers& dimension_numbers) {
+    const DotDimensionNumbers& dimension_numbers,
+    absl::optional<PrimitiveType> preferred_element_type) {
   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
 
@@ -700,8 +732,11 @@ Status ValidateDotDimensionNumbers(
       is_dynamic.push_back(rhs.is_dynamic_dimension(i));
     }
   }
-  Shape result = ShapeUtil::MakeShape(
-      ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
+  TF_ASSIGN_OR_RETURN(
+      PrimitiveType type,
+      MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
+                  preferred_element_type));
+  Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
 
   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
@@ -1586,7 +1621,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
     const Shape& lhs, const Shape& rhs, int64 feature_group_count,
     int64 batch_group_count, const Window& window,
-    const ConvolutionDimensionNumbers& dnums) {
+    const ConvolutionDimensionNumbers& dnums,
+    absl::optional<PrimitiveType> preferred_element_type) {
   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
 
@@ -1833,8 +1869,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
       }
     }
   }
-  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
-                              dimensions, is_dynamic);
+  TF_ASSIGN_OR_RETURN(
+      PrimitiveType type,
+      MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
+                  preferred_element_type));
+  return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 8ef72f7ac3f..f6d55c45334 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -105,12 +105,14 @@ class ShapeInference {
                                                  const Shape& output_grad_shape,
                                                  int64 feature_index);
 
-  // Infers the shape produced by applying the given convolutional
-  // filter (rhs) to lhs in the way specified by the fields on window.
+  // Infers the shape produced by applying the given convolutional filter (rhs)
+  // to lhs in the way specified by the fields on window. An optional
+  // preferred_element_type can be specified to upcast the element type: the
+  // preference must keep the inferred type's numeric class (integral vs.
+  // floating point) and signedness, and integral types may only be widened.
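+  // For example, an S8 x S16 convolution with preferred_element_type S32
+  // yields an S32 result shape.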
   static StatusOr<Shape> InferConvolveShape(
       const Shape& lhs, const Shape& rhs, int64 feature_group_count,
       int64 batch_group_count, const Window& window,
-      const ConvolutionDimensionNumbers& dimension_numbers);
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      absl::optional<PrimitiveType> preferred_element_type);
 
   // Infers the shape produced by the given FFT type on the given operand.
   static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
@@ -298,10 +300,12 @@ class ShapeInference {
       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
 
   // Helper that infers the shape produced by performing a dot operation with
-  // the given LHS and RHS shapes.
+  // the given LHS and RHS shapes. An optional preferred_element_type can be
+  // specified to upcast the element type, under the same constraints as
+  // InferConvolveShape.
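+  // For example, a BF16 x BF16 dot with preferred_element_type F32 yields an
+  // F32 result shape.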
   static StatusOr<Shape> InferDotOpShape(
       const Shape& lhs, const Shape& rhs,
-      const DotDimensionNumbers& dimension_numbers);
+      const DotDimensionNumbers& dimension_numbers,
+      absl::optional<PrimitiveType> preferred_element_type);
 
   // Helper that infers the shape of the tensor produced by a gather operation
   // with the given input shape, gather indices shape and gather dimension
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 73bbe5cb3bd..77f84a69205 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -437,7 +437,7 @@ TEST_F(ShapeInferenceTest, Convolve) {
   dim1->set_base_dilation(1);
   auto inferred_status = ShapeInference::InferConvolveShape(
       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
-      window, dnums);
+      window, dnums, /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status.status());
   Shape inferred_shape = inferred_status.ValueOrDie();
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -483,7 +483,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
   dim1->set_base_dilation(1);
   auto inferred_status = ShapeInference::InferConvolveShape(
       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
-      window, dnums);
+      window, dnums, /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status.status());
   Shape inferred_shape = inferred_status.ValueOrDie();
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -529,7 +529,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
   dim1->set_base_dilation(2);
   auto inferred_status = ShapeInference::InferConvolveShape(
       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
-      window, dnums);
+      window, dnums, /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status.status());
   Shape inferred_shape = inferred_status.ValueOrDie();
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -568,7 +568,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
   dim1->set_padding_high(1);
   auto inferred_status = ShapeInference::InferConvolveShape(
       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
-      window, dnums);
+      window, dnums, /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("each dimension exactly once"));
@@ -605,12 +605,150 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
   dim1->set_window_dilation(2);
   auto inferred_status = ShapeInference::InferConvolveShape(
       lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
-      window, dnums);
+      window, dnums, /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("to be a multiple of batch group count"));
 }
 
+struct ConvolveArgs {
+  Shape lhs_shape;
+  Shape rhs_shape;
+  ConvolutionDimensionNumbers dnums;
+  Window window;
+};
+
+ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
+  ConvolveArgs args;
+  ConvolutionDimensionNumbers& dnums = args.dnums;
+
+  // Dimension order: batch, feature, x0, x1
+  args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
+  dnums.set_input_batch_dimension(0);
+  dnums.set_output_batch_dimension(0);
+  dnums.set_input_feature_dimension(1);
+  dnums.set_output_feature_dimension(1);
+  dnums.add_input_spatial_dimensions(2);
+  dnums.add_output_spatial_dimensions(2);
+  dnums.add_input_spatial_dimensions(3);
+  dnums.add_output_spatial_dimensions(3);
+
+  // Dimension order: x1, batch, feature, x0
+  args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
+  dnums.set_kernel_input_feature_dimension(2);
+  dnums.set_kernel_output_feature_dimension(1);
+  dnums.add_kernel_spatial_dimensions(3);
+  dnums.add_kernel_spatial_dimensions(0);
+
+  auto dim0 = args.window.add_dimensions();
+  auto dim1 = args.window.add_dimensions();
+  dim0->set_size(3);
+  dim0->set_stride(2);
+  dim0->set_padding_low(1);
+  dim0->set_padding_high(1);
+  dim0->set_window_dilation(1);
+  dim0->set_base_dilation(1);
+  dim1->set_size(2);
+  dim1->set_stride(1);
+  dim1->set_padding_low(0);
+  dim1->set_padding_high(0);
+  dim1->set_window_dilation(1);
+  dim1->set_base_dilation(1);
+  return args;
+}
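+// With the shapes and window above, the inferred output dimensions are
+// {10, 12, 2, 3}.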
+
+TEST_F(ShapeInferenceTest,
+       ConvolveWithPreferredElementTypeSameAsInferredType) {
+  ConvolveArgs args = MakeConvolveArgs(S8, S16);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/S16));
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithIntegralPreferredElementType) {
+  ConvolveArgs args = MakeConvolveArgs(S8, S16);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/S32));
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest,
+       FloatingPointConvolveWithNarrowerPreferredElementType) {
+  ConvolveArgs args = MakeConvolveArgs(F32, F32);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/BF16));
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest,
+       FloatingPointConvolveWithInvalidPreferredElementType) {
+  ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
+  auto inferred_status =
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/S32)
+          .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must both be integral or both be floating point"));
+}
+
+TEST_F(ShapeInferenceTest,
+       IntegralConvolveWithFloatingPointPreferredElementType) {
+  ConvolveArgs args = MakeConvolveArgs(S8, S16);
+  auto inferred_status =
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/F32)
+          .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must both be integral or both be floating point"));
+}
+
+TEST_F(ShapeInferenceTest,
+       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
+  ConvolveArgs args = MakeConvolveArgs(S8, S16);
+  auto inferred_status =
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/U32)
+          .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must have the same signedness as the original type"));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
+  ConvolveArgs args = MakeConvolveArgs(S8, S16);
+  auto inferred_status =
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/S8)
+          .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must not be narrower than the original type"));
+}
+
 namespace fft {
 
 static const char* unsupported_rank = "only supports ranks 1-3";
@@ -1282,8 +1420,8 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) {
 // scalar <dot> vector: ok
 TEST_F(ShapeInferenceTest, ScalarDotVector) {
   DotDimensionNumbers dot_dnums;
-  auto inferred_status =
-      ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
+  auto inferred_status = ShapeInference::InferDotOpShape(
+      f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
   EXPECT_TRUE(inferred_status.ok());
   EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
 }
@@ -1294,7 +1432,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
   auto inferred_status = ShapeInference::InferDotOpShape(
-      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
+      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
+      /*preferred_element_type=*/absl::nullopt);
   EXPECT_TRUE(inferred_status.ok());
   EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
                                ShapeUtil::MakeShape(F32, {32, 32, 64})));
@@ -1306,11 +1445,13 @@ TEST_F(ShapeInferenceTest, VectorDotVector) {
   dot_dnums.add_lhs_contracting_dimensions(0);
   dot_dnums.add_rhs_contracting_dimensions(0);
   auto inferred_status =
-      ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
+      ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status.status());
   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
   auto inferred_status_mismatch =
-      ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
+      ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status_mismatch.ok());
 }
 
@@ -1320,11 +1461,13 @@ TEST_F(ShapeInferenceTest, MatrixDotVector) {
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
   auto inferred_status =
-      ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
+      ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status.status());
   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
   auto inferred_status_mismatch =
-      ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
+      ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status_mismatch.ok());
 }
 
@@ -1334,11 +1477,13 @@ TEST_F(ShapeInferenceTest, VectorDotMatrix) {
   dot_dnums.add_lhs_contracting_dimensions(0);
   dot_dnums.add_rhs_contracting_dimensions(0);
   auto inferred_status =
-      ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
+      ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status.status());
   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
   auto inferred_status_mismatch =
-      ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
+      ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status_mismatch.ok());
 }
 
@@ -1348,7 +1493,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
   auto inferred_status_match =
-      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
+      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status_match.status());
   ASSERT_TRUE(
       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
@@ -1356,7 +1502,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
       << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
   auto inferred_status_mismatch =
-      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
+      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status_mismatch.ok());
 }
 
@@ -1376,7 +1523,8 @@ TEST_F(ShapeInferenceTest, DotGeneral) {
   dot_dnums.add_rhs_batch_dimensions(1);
 
   auto inferred_status_match =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_IS_OK(inferred_status_match.status());
   ASSERT_TRUE(
       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
@@ -1399,7 +1547,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
   dot_dnums.add_rhs_batch_dimensions(0);
 
   auto inferred_status =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("Must specify the same number of contracting "
@@ -1421,7 +1570,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
   dot_dnums.add_rhs_batch_dimensions(0);
 
   auto inferred_status =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   EXPECT_TRUE(inferred_status.ok());
   EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
 }
@@ -1461,7 +1611,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
   dot_dnums.add_rhs_batch_dimensions(0);
 
   auto inferred_status =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("Batch dimension sizes must match"));
@@ -1480,7 +1631,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
   dot_dnums.add_rhs_batch_dimensions(1);
 
   auto inferred_status =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_TRUE(inferred_status.ok());
   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
                                ShapeUtil::MakeShape(F32, {2, 11, 14})));
@@ -1499,7 +1651,8 @@ TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
   dot_dnums.add_rhs_batch_dimensions(1);
 
   auto inferred_status =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("A dimension number is out of range"));
@@ -1518,12 +1671,108 @@ TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
   dot_dnums.add_rhs_batch_dimensions(1);
 
   auto inferred_status =
-      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
+                                      /*preferred_element_type=*/absl::nullopt);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("A dimension number is not unique"));
 }
 
+TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
+                          ShapeInference::InferDotOpShape(
+                              ShapeUtil::MakeShape(S8, {32, 32}),
+                              ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
+                              /*preferred_element_type=*/S32));
+  EXPECT_TRUE(
+      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32})));
+}
+
+TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
+                          ShapeInference::InferDotOpShape(
+                              ShapeUtil::MakeShape(BF16, {32, 32}),
+                              ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
+                              /*preferred_element_type=*/F32));
+  EXPECT_TRUE(
+      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
+}
+
+TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
+                          ShapeInference::InferDotOpShape(
+                              ShapeUtil::MakeShape(BF16, {32, 32}),
+                              ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
+                              /*preferred_element_type=*/BF16));
+  EXPECT_TRUE(
+      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
+}
+
+TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto inferred_status = ShapeInference::InferDotOpShape(
+                             ShapeUtil::MakeShape(BF16, {32, 32}),
+                             ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums,
+                             /*preferred_element_type=*/S32)
+                             .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must both be integral or both be floating point"));
+}
+
+TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto inferred_status = ShapeInference::InferDotOpShape(
+                             ShapeUtil::MakeShape(S8, {32, 32}),
+                             ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
+                             /*preferred_element_type=*/F32)
+                             .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must both be integral or both be floating point"));
+}
+
+TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto inferred_status = ShapeInference::InferDotOpShape(
+                             ShapeUtil::MakeShape(S8, {32, 32}),
+                             ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
+                             /*preferred_element_type=*/U32)
+                             .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must have the same signedness as the original type"));
+}
+
+TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto inferred_status = ShapeInference::InferDotOpShape(
+                             ShapeUtil::MakeShape(S8, {32, 32}),
+                             ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
+                             /*preferred_element_type=*/S8)
+                             .status();
+  ASSERT_FALSE(inferred_status.ok());
+  ASSERT_THAT(inferred_status.error_message(),
+              HasSubstr("must not be narrower than the original type"));
+}
+
 TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
   // Test variations of broadcasting a vector for a binary add with a
   // matrix.
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index 1b67f3563bc..05cdc6e24b7 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <map>
 #include <memory>
 #include <queue>
+#include <tuple>
 #include <unordered_set>
 #include <utility>
 
@@ -64,6 +65,19 @@ class ConvolutionVisitor {
   // Top-level function to begin space-to-batch conversion.
   Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);
 
+  // Struct containing details about a convolution.
+  struct ConvDetails {
+    int64 spatial_dimension_to_split, inherent_low_padding,
+        inherent_high_padding, stride, spatial_size, base_dilation_factor,
+        halo_size, high_padding_for_conv, low_padding_for_conv,
+        kernel_spatial_dim_size, input_dim_size;
+  };
+
+  // Returns a struct containing the details of a convolution that are needed
+  // to perform space-to-batch on it.
+  ConvDetails GetConvolutionDetails(HloInstruction* convolution,
+                                    ConvolutionDimensionNumbers& dim_numbers);
+
   // Function that determines if space-to-batch can be propagated into the
   // consumer. Such propagation is only possible when all required operands are
   // space-to-batch'ed.
@@ -225,11 +239,29 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
     return false;
   }
 
-  // TODO(b/168316428): Support base dilations.
-  if (convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .base_dilation() != 1) {
-    return false;
+  const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
+
+  const int64 low_pad = convolution->window()
+                            .dimensions(get_chosen_spatial_dim(convolution))
+                            .padding_low();
+
+  // TODO(b/168316428): Support base dilations more generically.
+  if (c.base_dilation_factor != 1) {
+    if (c.stride != 1) {
+      return false;
+    }
+    // For low pad of 0, only support a pointwise kernel.
+    if (low_pad == 0) {
+      if (c.kernel_spatial_dim_size != 1) {
+        return false;
+      }
+    } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 ||
+               low_pad != c.base_dilation_factor - 1) {
+      // Otherwise, only support base dilations where the kernel spatial size
+      // equals base_dilation_factor + 1 and the low pad equals
+      // base_dilation_factor - 1 (e.g. kernel size 3 and low pad 1 for a base
+      // dilation of 2), as required by HaloDuplicateWithSlice.
+      return false;
+    }
   }
 
   int64 activations_batch_dim = dim_numbers.input_batch_dimension();
@@ -240,42 +272,17 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
   if (old_batch_size > limit_on_batch_size_) {
     return false;
   }
-
-  auto kernel = convolution->mutable_operand(1);
-  const auto& kernel_shape = kernel->shape();
-  const int64 kernel_spatial_dim_size =
-      kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
-
-  auto activations = convolution->mutable_operand(0);
-
-  const int64 input_dim_size =
-      activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
-
-  const int64 inherent_low_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_low();
-  const int64 inherent_high_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_high();
-
-  const int64 spatial_size =
-      input_dim_size + inherent_low_padding + inherent_high_padding;
-  VLOG(1) << "spatial size " << spatial_size;
-
-  const int64 num_splits = kNewBatchSize / old_batch_size;
-
   // We currently only cater to evenly divisible cases.
   if (kNewBatchSize % old_batch_size != 0) {
     return false;
   }
 
-  // Splitting will be incorrect in these cases.
-  if (spatial_size < num_splits ||
-      input_dim_size / num_splits < kernel_spatial_dim_size) {
+  VLOG(1) << "spatial size " << c.spatial_size;
+
+  const int64 num_splits = kNewBatchSize / old_batch_size;
+  // If the halo needed is larger than a single spatial split, we cannot
+  // halo-pad from the neighboring split.
+  if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
     return false;
   }
   VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
@@ -292,8 +299,8 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
       activations->shape().dimensions(spatial_dimension_to_split);
   const int64 batch_size =
       activations->shape().dimensions(activations_batch_dim);
-  CHECK_LT(low_padding, spatial_split_size);
 
+  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);
   VLOG(1) << "In HaloDuplicateWithSlice with activations "
           << activations->ToString() << " batch_size " << batch_size
           << " spatial_split_size " << spatial_split_size << " low_padding "
@@ -439,6 +446,7 @@ StatusOr<bool> ConvolutionVisitor::Run() {
   // Iterate through all instructions that we could not propagate through, and
   // turn their operands from batch-to-space as needed.
   for (auto instr : non_propagatable_instrs_) {
+    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
     absl::flat_hash_map<int64, HloInstruction*> operand_map;
     for (int64 i = 0; i < instr->operand_count(); ++i) {
       if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
@@ -480,8 +488,9 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
 
       if (!old_to_new_instrs_.contains(old_producer) &&
           !broadcast_or_constant) {
-        VLOG(1) << "Cannot propagate on elementwise op "
-                << consumer->ToString();
+        VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
+                << " because operand " << old_producer->ToString()
+                << " isn't ready ";
         return false;
       } else {
         if (broadcast_or_constant) {
@@ -496,10 +505,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
           pivot_operand = old_producer;
           VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
         } else {
-          VLOG(2) << "Elementwise op: checking for shape equivalence "
-                  << consumer->ToString();
           if (instr_to_dim_map_[pivot_operand] !=
               instr_to_dim_map_[old_producer]) {
+            VLOG(2) << "Elementwise op: checking for shape equivalence "
+                    << consumer->ToString()
+                    << " failed due to changed batch space ordering ";
             return false;
           }
           auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
@@ -509,13 +519,22 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
           for (int j = 0; j < pivot_permute_dims.size(); ++j) {
             // Ensure the dimension mapping is the same.
             if (pivot_permute_dims[j] != permute_dims[j]) {
+              VLOG(2) << "Elementwise op: checking for shape equivalence "
+                      << consumer->ToString()
+                      << " failed due to permuted dimensions ";
               return false;
             }
 
             // Make sure all other dimensions are of the same size.
             if (pivot_new_instr->shape().dimensions(j) !=
                 new_instr->shape().dimensions(j)) {
-              return false;
+              if (!(consumer->IsElementwiseBinary() &&
+                    j == instr_to_dim_map_[pivot_operand].second)) {
+                VLOG(2) << "Elementwise op: checking for shape equivalence "
+                        << consumer->ToString()
+                        << " failed due to changed shape sizes ";
+                return false;
+              }
             }
           }
         }
@@ -769,6 +788,28 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
   if (IsTrivialElementwise(consumer)) {
     auto dim_map_val = instr_to_dim_map_[producer];
     auto new_consumer = computation->AddInstruction(consumer->Clone());
+    if (consumer->IsElementwiseBinary()) {
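+      // For binary elementwise ops, choose the operand with the larger
+      // space-to-batch'ed shape as the producer so that the smaller operand
+      // can be padded up to it below.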
+      for (int64 i = 0; i < 2; ++i) {
+        if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
+          break;
+        }
+        CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
+        if (i == 1) {
+          // Choose the larger shape to be used as the producer.
+          if (old_to_new_instrs_[consumer->mutable_operand(0)]
+                  ->shape()
+                  .dimensions() >
+              old_to_new_instrs_[consumer->mutable_operand(1)]
+                  ->shape()
+                  .dimensions()) {
+            producer = consumer->mutable_operand(0);
+          } else {
+            producer = consumer->mutable_operand(1);
+          }
+        }
+      }
+    }
+
     for (int64 i = 0; i < consumer->operand_count(); ++i) {
       if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
         CHECK(old_to_new_instrs_.contains(producer));
@@ -786,8 +827,66 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
             new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
       } else {
         CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
-        TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
-            i, old_to_new_instrs_[consumer->mutable_operand(i)]));
+        HloInstruction* operand_to_use = nullptr;
+
+        auto result = instr_to_dim_map_[producer];
+        const int64 old_batch_dim = result.first;
+        const int64 old_space_dim = result.second;
+        const int64 old_batch_size =
+            producer->shape().dimensions(old_batch_dim);
+        HloInstruction* new_instr =
+            old_to_new_instrs_[consumer->mutable_operand(i)];
+        HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];
+
+        auto permute_dims = instr_to_dim_permute_map_[new_instr];
+        const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
+        const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
+        const int64 batch_size = new_instr->shape().dimensions(batch_dim);
+
+        if (new_instr->shape().dimensions(space_dim) !=
+            pivot_new_instr->shape().dimensions(space_dim)) {
+          // Because we do not propagate through transposes, the batch should
+          // always be followed by the split space dimension.
+          CHECK_EQ(batch_dim + 1, space_dim);
+
+          // Fold the extra batch factor back into the space dimension, pad
+          // the space dimension up to the producer's size, and then reshape
+          // back to the producer's shape.
+          std::vector<int64> new_dimensions(
+              new_instr->shape().dimensions().begin(),
+              new_instr->shape().dimensions().end());
+          new_dimensions[space_dim] *= (batch_size / old_batch_size);
+          new_dimensions[batch_dim] = old_batch_size;
+
+          TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
+                              MakeReshapeHlo(new_dimensions, new_instr));
+
+          const int64 pivot_space_size =
+              pivot_new_instr->shape().dimensions(space_dim) * batch_size /
+              old_batch_size;
+
+          CHECK_GT(pivot_space_size, new_dimensions[space_dim]);
+
+          PaddingConfig padding_config =
+              MakeNoPaddingConfig(reshape->shape().dimensions_size());
+          padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
+              pivot_space_size - new_dimensions[space_dim]);
+          padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
+          HloInstruction* padding =
+              computation_->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::Zero(reshape->shape().element_type())));
+
+          TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
+                              MakePadHlo(reshape, padding, padding_config));
+
+          TF_ASSIGN_OR_RETURN(
+              operand_to_use,
+              MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
+                             padded_operand));
+
+        } else {
+          operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
+        }
+        TF_CHECK_OK(
+            new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
       }
     }
     auto old_type = new_consumer->mutable_shape()->element_type();
@@ -1329,25 +1428,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
                      original_conv_dims.input_spatial_dimensions(i)));
   }
 
-  int64 spatial_dimension_to_split =
-      permuted_conv_dims_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution));
-
   const int64 old_batch_dim = original_conv_dims.input_batch_dimension();
   const int64 old_batch_size =
       activations_old->shape().dimensions(old_batch_dim);
 
-  const int64 input_dim_size = activations_old->shape().dimensions(
-      permuted_conv_dims_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
+  ConvDetails c =
+      GetConvolutionDetails(convolution, permuted_conv_dims_numbers);
 
   VLOG(1) << "Propagating on conv activations_batch_dim "
           << activations_batch_dim << " spatial_dimension_to_split "
-          << spatial_dimension_to_split << " old_batch_size " << old_batch_size;
-  TF_ASSIGN_OR_RETURN(
-      activations_new,
-      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
-                            spatial_dimension_to_split, activations_batch_dim));
+          << c.spatial_dimension_to_split << " old_batch_size "
+          << old_batch_size;
+  TF_ASSIGN_OR_RETURN(activations_new,
+                      BringSpaceNextToBatch(
+                          activations_new, permuted_conv_dims_numbers,
+                          c.spatial_dimension_to_split, activations_batch_dim));
 
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
@@ -1355,32 +1450,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   TF_ASSIGN_OR_RETURN(
       activations_new,
       SelectValidPortion(activations_new, activations_old, select_val,
-                         activations_batch_dim, spatial_dimension_to_split,
+                         activations_batch_dim, c.spatial_dimension_to_split,
                          old_batch_dim, old_space_dim));
   // Create the new convolution dim numbers.
   auto new_dim_numbers = permuted_conv_dims_numbers;
 
-  auto kernel = convolution->mutable_operand(1);
-  const auto& kernel_shape = kernel->shape();
-  const int64 kernel_spatial_dim_size = kernel_shape.dimensions(
-      permuted_conv_dims_numbers.kernel_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
-
-  const int64 inherent_low_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_low();
-  const int64 inherent_high_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_high();
-  const int64 stride = convolution->window()
-                           .dimensions(get_chosen_spatial_dim(convolution))
-                           .stride();
-
-  const int64 spatial_size =
-      input_dim_size + inherent_low_padding + inherent_high_padding;
-  VLOG(1) << "spatial size " << spatial_size;
+  VLOG(1) << "spatial size " << c.spatial_size;
 
   const int64 num_splits = kNewBatchSize / old_batch_size;
 
@@ -1390,18 +1465,18 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   const int64 output_offsets_per_split =
       CeilOfRatio(output_offsets, num_splits);
 
-  int64 spatial_split_size = output_offsets_per_split * stride;
-  const int64 halo_size =
-      std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
+  int64 spatial_split_size =
+      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
+
   // Keep increasing the split size so that overall size isn't smaller than the
   // original spatial dimension. Unlike for the first space-to-batch'ed
   // convolution, while propagating, we can use the last halo_size as available
   // spatial size.
-  while (spatial_split_size * num_splits + halo_size - spatial_size < 0) {
-    spatial_split_size += stride;
+  while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
+    spatial_split_size += c.stride;
   }
 
-  int64 slice_size = spatial_split_size + halo_size;
+  int64 slice_size = spatial_split_size + c.halo_size;
 
   VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
           << slice_size;
@@ -1409,7 +1484,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   const int64 new_batch_size =
       activations_new->shape().dimensions(activations_batch_dim);
   const int64 new_space_size =
-      activations_new->shape().dimensions(spatial_dimension_to_split);
+      activations_new->shape().dimensions(c.spatial_dimension_to_split);
   // In the below case, we cannot use the activations directly for Halo
   // Duplication. We must reshape them.
   if (spatial_split_size > new_space_size) {
@@ -1418,7 +1493,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
         activations_new->shape().dimensions().end());
     const int64 reshaped_space_size =
         new_space_size * new_batch_size / old_batch_size;
-    new_dimensions[spatial_dimension_to_split] = reshaped_space_size;
+    new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
     new_dimensions[activations_batch_dim] = old_batch_size;
 
     // Reshape the output of the new conv into the old convolutions shape.
@@ -1427,10 +1502,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
 
     PaddingConfig padding_config =
         MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
         ->set_edge_padding_high(spatial_split_size * new_batch_size -
                                 reshaped_space_size);
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
         ->set_edge_padding_low(0);
     HloInstruction* padding =
         computation_->AddInstruction(HloInstruction::CreateConstant(
@@ -1444,7 +1519,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
         reshaped_activations->shape().dimensions().begin(),
         reshaped_activations->shape().dimensions().end());
 
-    reshape_back_dims[spatial_dimension_to_split] = spatial_split_size;
+    reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
     reshape_back_dims[activations_batch_dim] = new_batch_size;
 
     TF_ASSIGN_OR_RETURN(
@@ -1453,34 +1528,38 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
 
     TF_ASSIGN_OR_RETURN(
         activations_new,
-        HaloDuplicateWithSlice(reshaped_activations, spatial_dimension_to_split,
-                               activations_batch_dim, old_batch_size,
-                               /*low_padding=*/inherent_low_padding,
-                               /*high_padding=*/inherent_high_padding,
-                               slice_size - spatial_split_size,
-                               old_split_dim_size));
+        HaloDuplicateWithSlice(
+            reshaped_activations, c.spatial_dimension_to_split,
+            activations_batch_dim, old_batch_size,
+            /*low_padding=*/c.base_dilation_factor != 1 &&
+                    c.inherent_low_padding != 0
+                ? c.base_dilation_factor - 1
+                : c.inherent_low_padding,
+            c.inherent_high_padding, slice_size - spatial_split_size,
+            old_split_dim_size));
   } else {
     // If the ideal spatial_split_size was smaller than the incoming spatial
     // dimension size, we don't need reshaping. Instead, we determine the
     // additional space available, and adjust the required slice size (and
-    // thereby the halo size).'t need reshaping. Instead, we determine the
-    // additional space available, and adjust the required slice size (and
     // thereby the halo size).
     if (spatial_split_size < new_space_size) {
-      const int64 additional_space_present = spatial_split_size % stride;
+      const int64 additional_space_present = spatial_split_size % c.stride;
       spatial_split_size = new_space_size;
       slice_size =
-          spatial_split_size +
-          std::max(kernel_spatial_dim_size - stride - additional_space_present,
-                   static_cast<int64>(0));
+          spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
+                                            additional_space_present,
+                                        static_cast<int64>(0));
     }
 
     TF_ASSIGN_OR_RETURN(
         activations_new,
-        HaloDuplicateWithSlice(activations_new, spatial_dimension_to_split,
+        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
                                activations_batch_dim, old_batch_size,
-                               /*low_padding=*/inherent_low_padding,
-                               /*high_padding=*/inherent_high_padding,
+                               /*low_padding=*/c.base_dilation_factor != 1 &&
+                                       c.inherent_low_padding != 0
+                                   ? c.base_dilation_factor - 1
+                                   : c.inherent_low_padding,
+                               c.inherent_high_padding,
                                slice_size - spatial_split_size,
                                old_split_dim_size));
   }
@@ -1515,15 +1594,16 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
 
   auto new_window = convolution->window();
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_high(0);
+      ->set_padding_high(c.high_padding_for_conv);
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_low(0);
+      ->set_padding_low(c.low_padding_for_conv);
   TF_ASSIGN_OR_RETURN(
       HloInstruction * new_conv,
-      MakeConvolveHlo(activations_new, /*rhs=*/convolution->mutable_operand(1),
-                      convolution->feature_group_count(),
-                      convolution->batch_group_count(), new_window,
-                      new_dim_numbers, convolution->precision_config()));
+      MakeConvolveHlo(
+          activations_new, /*rhs=*/convolution->mutable_operand(1),
+          convolution->feature_group_count(), convolution->batch_group_count(),
+          new_window, new_dim_numbers, convolution->precision_config(),
+          /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
   old_to_new_instrs_[convolution] = new_conv;
@@ -1800,10 +1880,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * new_conv,
-      MakeConvolveHlo(activations_new, kernel_new,
-                      convolution->feature_group_count(),
-                      convolution->batch_group_count(), new_window,
-                      new_dim_numbers, convolution->precision_config()));
+      MakeConvolveHlo(
+          activations_new, kernel_new, convolution->feature_group_count(),
+          convolution->batch_group_count(), new_window, new_dim_numbers,
+          convolution->precision_config(),
+          /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
   std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
@@ -1853,19 +1934,9 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
   return nullptr;
 }
 
-Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
-    HloInstruction* convolution) {
-  VLOG(1) << "Handling conv " << convolution->ToString();
-
-  changed_ = false;
-
-  ConvolutionDimensionNumbers dim_numbers =
-      convolution->convolution_dimension_numbers();
-
-  int64 activations_batch_dim = dim_numbers.input_batch_dimension();
-
-  const int64 old_batch_size =
-      convolution->operand(0)->shape().dimensions(activations_batch_dim);
+ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
+    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
+  auto activations = convolution->mutable_operand(0);
 
   auto kernel = convolution->mutable_operand(1);
   const auto& kernel_shape = kernel->shape();
@@ -1873,14 +1944,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
           get_chosen_spatial_dim(convolution)));
 
-  auto activations = convolution->mutable_operand(0);
-
-  int64 spatial_dimension_to_split =
+  const int64 spatial_dimension_to_split =
       dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution));
 
   const int64 input_dim_size =
-      activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
+      activations->shape().dimensions(spatial_dimension_to_split);
 
   const int64 inherent_low_padding =
       convolution->window()
@@ -1890,26 +1958,75 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       convolution->window()
           .dimensions(get_chosen_spatial_dim(convolution))
           .padding_high();
-  const bool inherent_padding_needed =
-      inherent_low_padding != 0 || inherent_high_padding != 0;
 
   const int64 stride = convolution->window()
                            .dimensions(get_chosen_spatial_dim(convolution))
                            .stride();
 
+  const int64 base_dilation_factor =
+      convolution->window()
+          .dimensions(get_chosen_spatial_dim(convolution))
+          .base_dilation();
+
   const int64 spatial_size =
-      input_dim_size + inherent_low_padding + inherent_high_padding;
-  VLOG(1) << "spatial size " << spatial_size;
+      input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
+      inherent_high_padding;
+
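+  // Overlap that each spatial split needs from its neighbor so that the
+  // rewritten convolution can produce the same outputs as the original.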
+  const int64 halo_size =
+      std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1),
+               static_cast<int64>(0));
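+  // Padding applied to the rewritten convolution itself; only nonzero when
+  // the convolution is base-dilated.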
+  const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0
+                                      : inherent_low_padding == 0
+                                          ? base_dilation_factor - 1
+                                          : 0;
+  const int64 low_padding_for_conv =
+      base_dilation_factor == 1 ? 0 : inherent_low_padding;
+
+  return ConvDetails{spatial_dimension_to_split,
+                     inherent_low_padding,
+                     inherent_high_padding,
+                     stride,
+                     spatial_size,
+                     base_dilation_factor,
+                     halo_size,
+                     high_padding_for_conv,
+                     low_padding_for_conv,
+                     kernel_spatial_dim_size,
+                     input_dim_size};
+}
+
+Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
+    HloInstruction* convolution) {
+  VLOG(1) << "Handling conv " << convolution->ToString();
+
+  changed_ = false;
+
+  ConvolutionDimensionNumbers dim_numbers =
+      convolution->convolution_dimension_numbers();
+
+  ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
+
+  int64 activations_batch_dim = dim_numbers.input_batch_dimension();
+
+  const int64 old_batch_size =
+      convolution->operand(0)->shape().dimensions(activations_batch_dim);
+
+  auto activations = convolution->mutable_operand(0);
+
+  const bool inherent_padding_needed =
+      c.inherent_low_padding != 0 || c.inherent_high_padding != 0;
+
+  VLOG(1) << "spatial size " << c.spatial_size;
 
   const int64 num_splits = kNewBatchSize / old_batch_size;
   auto original_conv = convolution;
 
   // We'd need transposition of activations here such that the batch dim and
   // the space dim being split are adjacent (in that order).
-  TF_ASSIGN_OR_RETURN(
-      activations,
-      BringSpaceNextToBatch(activations, dim_numbers,
-                            spatial_dimension_to_split, activations_batch_dim));
+  TF_ASSIGN_OR_RETURN(activations,
+                      BringSpaceNextToBatch(activations, dim_numbers,
+                                            c.spatial_dimension_to_split,
+                                            activations_batch_dim));
   // Create the new convolution dim numbers.
   auto new_dim_numbers = dim_numbers;
 
@@ -1920,11 +2037,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   const int64 output_offsets_per_split =
       CeilOfRatio(output_offsets, num_splits);
 
-  int64 spatial_split_size = output_offsets_per_split * stride;
+  int64 spatial_split_size =
+      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
   // Keep increasing the split size so that overall size isn't smaller than the
   // original spatial dimension.
-  while (spatial_split_size * num_splits - spatial_size < 0) {
-    spatial_split_size += stride;
+  while (spatial_split_size * num_splits - c.spatial_size < 0) {
+    spatial_split_size += c.stride;
   }
 
   auto reduce_window = DoesConvolutionFeedReduceWindow(convolution);
@@ -1936,33 +2054,32 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
     // windows.
     const int64 red_win_stride =
         reduce_window->window().dimensions(output_spatial_dim).stride();
-    while ((spatial_split_size / stride) % red_win_stride != 0) {
-      spatial_split_size += stride;
+    while ((spatial_split_size / c.stride) % red_win_stride != 0) {
+      spatial_split_size += c.stride;
     }
   }
 
-  const int64 slice_size =
-      spatial_split_size +
-      std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
+  const int64 slice_size = spatial_split_size + c.halo_size;
 
   // Pad spatial dim.
-  const int64 pad_size = spatial_split_size * num_splits - spatial_size;
+  const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
 
   VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
-          << stride;
-  VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split
+          << c.stride << " slice_size " << slice_size;
+  VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
           << " num_splits " << num_splits << " kernel_spatial_dim_size "
-          << kernel_spatial_dim_size;
+          << c.kernel_spatial_dim_size;
 
   // Because we are splitting the spatial dimension, if the convolution needed
   // padding in the spatial dimension, we materialize it.
   if (pad_size != 0 || inherent_padding_needed) {
     PaddingConfig padding_config =
         MakeNoPaddingConfig(activations->shape().dimensions_size());
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
-        ->set_edge_padding_high(inherent_high_padding + pad_size);
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
-        ->set_edge_padding_low(inherent_low_padding);
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
+        ->set_edge_padding_high(c.inherent_high_padding + pad_size);
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
+        ->set_edge_padding_low(
+            c.base_dilation_factor == 1 ? c.inherent_low_padding : 0);
     HloInstruction* padding =
         computation_->AddInstruction(HloInstruction::CreateConstant(
             LiteralUtil::Zero(activations->shape().element_type())));
@@ -1989,7 +2106,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       activations->shape().dimensions().begin(),
       activations->shape().dimensions().end());
 
-  reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
+  reshape_dimensions[c.spatial_dimension_to_split] = spatial_split_size;
   reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
 
   TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
@@ -1998,12 +2115,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
 
   VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();
 
-  TF_ASSIGN_OR_RETURN(activations,
-                      HaloDuplicateWithSlice(
-                          batch_increased_reshape, spatial_dimension_to_split,
-                          activations_batch_dim, old_batch_size,
-                          /*low_padding=*/0, /*high_padding=*/0,
-                          slice_size - spatial_split_size, input_dim_size));
+  TF_ASSIGN_OR_RETURN(
+      activations, HaloDuplicateWithSlice(batch_increased_reshape,
+                                          c.spatial_dimension_to_split,
+                                          activations_batch_dim, old_batch_size,
+                                          /*low_padding=*/0, /*high_padding=*/0,
+                                          c.halo_size, c.input_dim_size));
 
   VLOG(1) << "Batch merge done " << activations->ToString();
 
@@ -2038,15 +2155,16 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           << " batch dim " << new_dim_numbers.input_batch_dimension();
   auto new_window = convolution->window();
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_high(0);
+      ->set_padding_high(c.high_padding_for_conv);
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_low(0);
+      ->set_padding_low(c.low_padding_for_conv);
   TF_ASSIGN_OR_RETURN(
       HloInstruction * new_conv,
-      MakeConvolveHlo(activations, /*rhs=*/convolution->mutable_operand(1),
-                      convolution->feature_group_count(),
-                      convolution->batch_group_count(), new_window,
-                      new_dim_numbers, convolution->precision_config()));
+      MakeConvolveHlo(
+          activations, /*rhs=*/convolution->mutable_operand(1),
+          convolution->feature_group_count(), convolution->batch_group_count(),
+          new_window, new_dim_numbers, convolution->precision_config(),
+          /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
   VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
index d53bb7d75f3..8921d98cad0 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
@@ -113,7 +113,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
   EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
 }
 
-TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {
+TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithBaseDilation) {
   string hlo_string = R"(
   
   HloModule module
@@ -129,8 +129,22 @@ ENTRY computation {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_string));
 
+  auto computation = module->entry_computation();
   ConvolutionSpaceToBatchConverter converter;
-  ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_THAT(root, op::Transpose());
+  EXPECT_THAT(root->operand(0), op::Slice());
+  auto reshape = root->operand(0)->operand(0);
+  EXPECT_THAT(reshape, op::Reshape());
+  EXPECT_THAT(reshape->operand(0)->operand(1), op::Convolution());
+  const int64 batch_dim = reshape->operand(0)
+                              ->operand(1)
+                              ->convolution_dimension_numbers()
+                              .output_batch_dimension();
+
+  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
index 0d34c5b62e9..7130ba3dd5b 100644
--- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
@@ -950,7 +950,8 @@ StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
       Shape sharded_conv_shape,
       ShapeInference::InferConvolveShape(
           sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
-          feature_group_count, batch_group_count, window, conv_dnums));
+          feature_group_count, batch_group_count, window, conv_dnums,
+          /*preferred_element_type=*/conv.shape().element_type()));
   *sharded_conv_shape.mutable_layout() = conv.shape().layout();
   return HloInstruction::CreateConvolve(
       sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index a346d8778d6..bb666221a56 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -80,8 +80,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
           const Window& conv_window) -> StatusOr<HloInstruction*> {
     TF_ASSIGN_OR_RETURN(
         auto sharded_dot_shape,
-        ShapeInference::InferDotOpShape(l->shape(), r->shape(),
-                                        hlo->dot_dimension_numbers()));
+        ShapeInference::InferDotOpShape(
+            l->shape(), r->shape(), hlo->dot_dimension_numbers(),
+            /*preferred_element_type=*/hlo->shape().element_type()));
     return b->AddInstruction(HloInstruction::CreateDot(
         sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
         hlo->precision_config()));
@@ -1289,6 +1290,14 @@ StatusOr<HloInstruction*> PartitionDot(
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
         windowed_dot_general_loops) {
+  // If lhs' hlo and rhs' hlo are identical, make a copy for rhs.
+  if (lhs.hlo() == rhs.hlo()) {
+    auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
+        rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
+    copy_hlo->set_sharding(rhs.sharding());
+    rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
+  }
+
   // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
   auto get_partitions_for_dims =
       [&](const HloSharding& sharding,
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index a70fdd9b828..4c1fb336439 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -6538,6 +6538,35 @@ ENTRY entry {
                           op::Shape("c64[1,1,3]")));
 }
 
+TEST_F(SpmdPartitioningTest, DotInputsAreIdentical) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %parameter.1 = f32[4000,4000]{1,0} parameter(0),
+    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
+  ROOT %convolution = f32[4000,4000]{1,0} convolution(
+    f32[4000,4000]{1,0} %parameter.1, f32[4000,4000]{1,0} %parameter.1),
+    dim_labels=bf_io->bf, sharding={devices=[2,4]0,1,2,3,4,5,6,7}
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+  auto root = module->entry_computation()->root_instruction();
+  auto param = AllOf(op::Parameter(), op::Shape("f32[2000, 1000]"));
+  auto resharded_lhs =
+      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, param, _, _)),
+            op::Shape("f32[2000, 4000]"));
+  auto resharded_rhs =
+      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, op::Copy(param), _, _)),
+            op::Shape("f32[4000, 1000]"));
+  EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, resharded_rhs),
+                          op::Shape("f32[2000, 1000]")));
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 3fe69d22e9c..18223406da8 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -242,7 +242,8 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
       x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
-      /*batch_group_count=*/1, window, dnums);
+      /*batch_group_count=*/1, window, dnums,
+      /*preferred_element_type=*/absl::nullopt);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), x, transpose_y,
@@ -298,7 +299,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
       x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
-      /*batch_group_count=*/1, window, dnums);
+      /*batch_group_count=*/1, window, dnums,
+      /*preferred_element_type=*/absl::nullopt);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), x, transpose_y,
@@ -359,7 +361,8 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
       transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
-      /*batch_group_count=*/1, window, dnums);
+      /*batch_group_count=*/1, window, dnums,
+      /*preferred_element_type=*/absl::nullopt);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), transpose_x, y,
@@ -426,7 +429,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
       transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
-      /*batch_group_count=*/1, window, dnums);
+      /*batch_group_count=*/1, window, dnums,
+      /*preferred_element_type=*/absl::nullopt);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), transpose_x, y,
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index cb0edfb6be6..e84a2591707 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -52,6 +52,33 @@ namespace xla {
 using absl::StrAppend;
 using absl::StrCat;
 
+namespace {
+// An array indexed by PrimitiveType, giving the byte size of each element
+// of that primitive type, or 0 if the type has no fixed element size
+// (e.g. TUPLE, OPAQUE_TYPE, TOKEN) or is invalid.
+constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
+    0,                  // PRIMITIVE_TYPE_INVALID = 0,
+    sizeof(int8),       // PRED = 1
+    sizeof(int8),       // S8 = 2
+    sizeof(int16),      // S16 = 3
+    sizeof(int32),      // S32 = 4
+    sizeof(int64),      // S64 = 5
+    sizeof(uint8),      // U8 = 6
+    sizeof(uint16),     // U16 = 7
+    sizeof(uint32),     // U32 = 8
+    sizeof(uint64),     // U64 = 9
+    sizeof(float) / 2,  // F16 = 10
+    sizeof(float),      // F32 = 11
+    sizeof(double),     // F64 = 12
+    0,                  // TUPLE = 13
+    0,                  // OPAQUE_TYPE = 14
+    sizeof(complex64),  // C64 = 15
+    sizeof(float) / 2,  // BF16 = 16
+    0,                  // TOKEN = 17
+    sizeof(complex128)  // C128 = 18
+};
+}  // namespace
+
 string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
 
 string ShapeIndexView::ToString() const {
@@ -175,6 +202,42 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
   return accum;
 }
 
+/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
+                                          absl::Span<const int64> dimensions,
+                                          Shape* shape) {
+  const int eint = static_cast<int>(element_type);
+  int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
+                                ? primitive_byte_size[eint]
+                                : 0);  // Out of range: force a failure
+  if (dense_shape_size <= 0) {
+    return false;
+  }
+
+  // Verify that array-based lookup is consistent with public API.
+  DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
+      << element_type;
+
+  shape->set_element_type(element_type);
+  const int ndims = dimensions.size();
+  auto layout = shape->mutable_layout();
+  layout->set_format(DENSE);
+  auto* minor_to_major = layout->mutable_minor_to_major();
+  for (int i = 0; i < ndims; i++) {
+    const int64 d = dimensions[i];
+    if (d < 0) {
+      return false;
+    }
+    dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
+    if (dense_shape_size < 0) {
+      return false;
+    }
+
+    shape->add_dimensions(d);
+    minor_to_major->push_back(ndims - 1 - i);
+  }
+  return true;
+}
+
 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
     std::initializer_list<Shape> parameters, Shape result) {
   ProgramShape program_shape;
@@ -187,7 +250,9 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
 
 /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
                                         absl::Span<const int64> dimensions) {
-  return MakeValidatedShape(element_type, dimensions).ValueOrDie();
+  Shape shape;
+  CHECK(FillNewShape(element_type, dimensions, &shape));
+  return shape;
 }
 
 /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
@@ -210,18 +275,31 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
 
 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
     PrimitiveType element_type, absl::Span<const int64> dimensions) {
-  CHECK(IsArrayPrimitiveType(element_type)) << element_type;
-  Shape result;
-  TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
-  return result;
+  Shape shape;
+  if (!FillNewShape(element_type, dimensions, &shape)) {
+    return InvalidArgument("invalid shape type=%d, dims=[%s]",
+                           static_cast<int>(element_type),
+                           absl::StrJoin(dimensions, ","));
+  }
+  return shape;
 }
 
 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
     PrimitiveType element_type, absl::Span<const int64> dimensions,
     const std::vector<bool>& dynamic_dimensions) {
-  TF_ASSIGN_OR_RETURN(Shape shape,
-                      MakeValidatedShape(element_type, dimensions));
-  for (int i = 0; i < dynamic_dimensions.size(); ++i) {
+  if (dynamic_dimensions.size() != dimensions.size()) {
+    return InvalidArgument(
+        "dynamic dimensions size %d did not match number of dimensions %d",
+        dynamic_dimensions.size(), dimensions.size());
+  }
+
+  Shape shape;
+  if (!FillNewShape(element_type, dimensions, &shape)) {
+    return InvalidArgument("invalid shape type=%d, dims=[%s]",
+                           static_cast<int>(element_type),
+                           absl::StrJoin(dimensions, ","));
+  }
+  for (int i = 0, n = dimensions.size(); i < n; i++) {
     shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
   }
   return shape;
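
A minimal usage sketch of the behavior above (assuming only the XLA header named here; not part of the patch): `MakeShape` keeps its CHECK-on-failure contract via `FillNewShape`, while `MakeValidatedShape` surfaces the same failures as a `Status`:

```c++
#include "tensorflow/compiler/xla/shape_util.h"

void ShapeCreationDemo() {
  using xla::ShapeUtil;

  // Fast path: fills the Shape directly, no Status machinery on success.
  xla::Shape ok = ShapeUtil::MakeShape(xla::F32, {2, 3});

  // Negative dimensions (and byte-size overflow) are rejected as a Status.
  auto bad = ShapeUtil::MakeValidatedShape(xla::F32, {-1});
  CHECK(!bad.ok());  // "invalid shape type=11, dims=[-1]"

  // The dynamic-dimension overload now also validates the vector length.
  auto mismatched = ShapeUtil::MakeValidatedShape(
      xla::F32, {2, 3}, /*dynamic_dimensions=*/{true});  // size 1 != 2
  CHECK(!mismatched.ok());
}
```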
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index c1a6a2c8b1d..ff47ab6ea80 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -792,6 +792,11 @@ class ShapeUtil {
   static bool CanUpcastIntegral(const Shape& from, const Shape& to);
 
  private:
+  // Fills *shape. Returns true on success.
+  // REQUIRES: *shape is empty.
+  static bool FillNewShape(PrimitiveType element_type,
+                           absl::Span<const int64> dimensions, Shape* shape);
+
   // Validates the shape size is sane. This makes sure it's safe to do
   // calculations in int64 without overflowing.
   static Status ValidateShapeSize(const Shape& shape);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 4e2030667ee..1a944d01941 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test_benchmark.h"
 
 namespace xla {
 namespace {
@@ -827,5 +828,19 @@ TEST(AlignmentTest,
   EXPECT_FALSE(aligned_shape);
 }
 
+void BM_MakeShape(::testing::benchmark::State& state) {
+  for (auto s : state) {
+    ShapeUtil::MakeShape(F32, {2});
+  }
+}
+BENCHMARK(BM_MakeShape);
+
+void BM_MakeValidatedShape(::testing::benchmark::State& state) {
+  for (auto s : state) {
+    ShapeUtil::MakeValidatedShape(F32, {2}).ValueOrDie();
+  }
+}
+BENCHMARK(BM_MakeValidatedShape);
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index fe27a8c6963..69916f6abc5 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -554,7 +554,7 @@ ENTRY jit_broken.874 {
   abs.129 = f32[4]{0} abs(subtract.126)
   constant.130 = f32[] constant(inf)
   broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
-  compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED
+  compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
   not.133 = pred[4]{0} not(compare.132)
   and.134 = pred[4]{0} and(not.128, not.133)
   add.135 = f32[4]{0} add(add.124, add.89)
@@ -577,7 +577,7 @@ ENTRY jit_broken.874 {
   abs.219 = f32[4]{0} abs(subtract.216)
   constant.220 = f32[] constant(inf)
   broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
-  compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED
+  compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
   not.223 = pred[4]{0} not(compare.222)
   and.224 = pred[4]{0} and(not.218, not.223)
   add.225 = f32[4]{0} add(add.214, add.179)
@@ -600,7 +600,7 @@ ENTRY jit_broken.874 {
   abs.309 = f32[4]{0} abs(subtract.306)
   constant.310 = f32[] constant(inf)
   broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
-  compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED
+  compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
   not.313 = pred[4]{0} not(compare.312)
   and.314 = pred[4]{0} and(not.308, not.313)
   add.315 = f32[4]{0} add(add.304, add.269)
@@ -623,7 +623,7 @@ ENTRY jit_broken.874 {
   abs.399 = f32[4]{0} abs(subtract.396)
   constant.400 = f32[] constant(inf)
   broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
-  compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED
+  compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
   not.403 = pred[4]{0} not(compare.402)
   and.404 = pred[4]{0} and(not.398, not.403)
   add.405 = f32[4]{0} add(add.394, add.359)
@@ -646,7 +646,7 @@ ENTRY jit_broken.874 {
   abs.489 = f32[4]{0} abs(subtract.486)
   constant.490 = f32[] constant(inf)
   broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
-  compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED
+  compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
   not.493 = pred[4]{0} not(compare.492)
   and.494 = pred[4]{0} and(not.488, not.493)
   add.495 = f32[4]{0} add(add.484, add.449)
@@ -669,7 +669,7 @@ ENTRY jit_broken.874 {
   abs.579 = f32[4]{0} abs(subtract.576)
   constant.580 = f32[] constant(inf)
   broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
-  compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED
+  compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
   not.583 = pred[4]{0} not(compare.582)
   and.584 = pred[4]{0} and(not.578, not.583)
   add.585 = f32[4]{0} add(add.574, add.539)
@@ -692,7 +692,7 @@ ENTRY jit_broken.874 {
   abs.669 = f32[4]{0} abs(subtract.666)
   constant.670 = f32[] constant(inf)
   broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
-  compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED
+  compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
   not.673 = pred[4]{0} not(compare.672)
   and.674 = pred[4]{0} and(not.668, not.673)
   add.675 = f32[4]{0} add(add.664, add.629)
@@ -715,7 +715,7 @@ ENTRY jit_broken.874 {
   abs.759 = f32[4]{0} abs(subtract.756)
   constant.760 = f32[] constant(inf)
   broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
-  compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED
+  compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
   not.763 = pred[4]{0} not(compare.762)
   and.764 = pred[4]{0} and(not.758, not.763)
   add.765 = f32[4]{0} add(add.754, add.719)
@@ -738,7 +738,7 @@ ENTRY jit_broken.874 {
   abs.849 = f32[4]{0} abs(subtract.846)
   constant.850 = f32[] constant(inf)
   broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
-  compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED
+  compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
   not.853 = pred[4]{0} not(compare.852)
   and.854 = pred[4]{0} and(not.848, not.853)
   add.855 = f32[4]{0} add(add.844, add.809)
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 3e9a3ec2314..aa377b57fe6 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -73,6 +73,8 @@ int main(int argc, char** argv) {
     triple_string = "x86_64-pc-windows-msvc19";
   } else if (target_cpu == "ppc") {
     triple_string = "ppc64le-ibm-linux-gnu";
+  } else if (target_cpu == "s390x") {
+    triple_string = "systemz-none-linux-gnu";
   } else if (target_cpu == "local") {
     triple_string = llvm::sys::getDefaultTargetTriple();
   } else {
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index 6e7deda13f0..f39bd269ef4 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -367,15 +367,15 @@ string SanitizeFileName(string file_name) {
 //     precision, Numerische Mathematik, vol. 18, pp. 224–242, 1971.
 std::pair<float, float> SplitF64ToF32(double x) {
   const float x_f32 = static_cast<float>(x);
-  // Early return if x is an infinity or NaN.
-  if (!std::isfinite(x)) {
-    return std::make_pair(x_f32, 0.0f);
-  }
 
-  // Only values within the range of F32 are supported, unless it is infinity.
-  // Small values with large negative exponents would be rounded to zero.
+  // Early return if x_f32 is non-finite (x is inf/NaN or overflows F32).
   if (!std::isfinite(x_f32)) {
-    LOG(WARNING) << "Out of range F64 constant detected: " << x;
+    // Only values within the range of F32 are supported, unless the input is
+    // itself infinite. Small values with large negative exponents round to zero.
+    if (std::isfinite(x)) {
+      LOG(WARNING) << "Out of range F64 constant detected: " << x;
+    }
+    return std::make_pair(x_f32, 0.0f);
   }
 
   // The high float is simply the double rounded to the nearest float. Because
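
For context, a standalone sketch of the overall technique (a classic float-float split, per the comment above; this approximates the real function, whose tail computation in util.cc is not shown in this hunk):

```c++
#include <cmath>
#include <cstdio>
#include <utility>

// Approximate sketch: the head is the double rounded to float; the tail
// captures the rounding error when the head is finite, and is zero otherwise.
std::pair<float, float> SplitF64ToF32Sketch(double x) {
  const float hi = static_cast<float>(x);
  if (!std::isfinite(hi)) return {hi, 0.0f};  // inf/NaN/overflow: [hi, 0]
  const float lo = static_cast<float>(x - static_cast<double>(hi));
  return {hi, lo};  // hi + lo reproduces x to roughly double-float precision
}

int main() {
  const auto [hi, lo] = SplitF64ToF32Sketch(0.1);
  std::printf("%.9g + %.9g ~= 0.1\n", hi, lo);
}
```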
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
index 5477dfba18d..6f60e241d92 100644
--- a/tensorflow/compiler/xla/util_test.cc
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -126,5 +126,13 @@ TEST(UtilTest, RoundTripFpToString) {
             "-nan");
 }
 
+TEST(UtilTest, SplitF64ToF32) {
+  // Overflowing the F32 exponent in SplitF64ToF32 should result in a pair of
+  // [∞,0].
+  EXPECT_EQ(SplitF64ToF32(std::numeric_limits<double>::max()).first,
+            std::numeric_limits<float>::infinity());
+  EXPECT_EQ(SplitF64ToF32(std::numeric_limits<double>::max()).second, 0.0f);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index eade7c2426d..01de56bf85d 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -85,6 +85,7 @@ enum PrimitiveType {
   // Next = 19
 }
 // LINT.ThenChange(
+//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
 //   https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
 // )
 
diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD
index b1ced361927..4cab7b79034 100644
--- a/tensorflow/core/api_def/BUILD
+++ b/tensorflow/core/api_def/BUILD
@@ -6,7 +6,6 @@
 #   :python_api_def
 #   :java_api_def
 
-load("//tensorflow:tensorflow.bzl", "filegroup")
 load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//tensorflow:tensorflow.bzl",
@@ -27,9 +26,9 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
-filegroup(
+alias(
     name = "base_api_def",
-    srcs = glob(["base_api/*"]),
+    actual = "//tensorflow/core/api_def/base_api:base_api_def",
     visibility = ["//tensorflow:internal"],
 )
 
diff --git a/tensorflow/core/api_def/base_api/BUILD b/tensorflow/core/api_def/base_api/BUILD
new file mode 100644
index 00000000000..22cc342b1a1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/BUILD
@@ -0,0 +1,21 @@
+# Description:
+#   Expose TensorFlow base api.
+
+load("//tensorflow:tensorflow.bzl", "filegroup")
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "base_api_def",
+    srcs = glob(
+        [
+            "*",
+        ],
+        exclude = [
+            "BUILD",
+        ],
+    ),
+    visibility = ["//tensorflow:internal"],
+)
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 11f28655f05..06af68bdb65 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -125,7 +125,8 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
 
   // Try allocating.
   size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes);
-  void* mem_addr = sub_allocator_->Alloc(alignment, bytes);
+  size_t bytes_received;
+  void* mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received);
   if (mem_addr == nullptr && !started_backpedal_) {
     // Only backpedal once.
     started_backpedal_ = true;
@@ -136,7 +137,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
     while (mem_addr == nullptr) {
       bytes = RoundedBytes(bytes * kBackpedalFactor);
       if (bytes < rounded_bytes) break;
-      mem_addr = sub_allocator_->Alloc(alignment, bytes);
+      mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received);
     }
   }
 
@@ -158,7 +159,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
 
   VLOG(1) << "Allocated memory at " << mem_addr << " to "
           << static_cast<void*>(static_cast<char*>(mem_addr) + bytes);
-  region_manager_.AddAllocationRegion(mem_addr, bytes);
+  region_manager_.AddAllocationRegion(mem_addr, bytes_received);
 
   // Create one large chunk for the whole memory space that will
   // be chunked later.
diff --git a/tensorflow/core/common_runtime/device/device_host_allocator.h b/tensorflow/core/common_runtime/device/device_host_allocator.h
index 539dab18266..9d11705eadf 100644
--- a/tensorflow/core/common_runtime/device/device_host_allocator.h
+++ b/tensorflow/core/common_runtime/device/device_host_allocator.h
@@ -36,8 +36,10 @@ class DeviceHostAllocator : public SubAllocator {
   }
   ~DeviceHostAllocator() override {}
 
-  void* Alloc(size_t alignment, size_t num_bytes) override {
+  void* Alloc(size_t alignment, size_t num_bytes,
+              size_t* bytes_received) override {
     void* ptr = nullptr;
+    *bytes_received = num_bytes;
     if (num_bytes > 0) {
       ptr = stream_exec_->HostMemoryAllocate(num_bytes);
       if (ptr == nullptr) {
diff --git a/tensorflow/core/common_runtime/device/device_mem_allocator.h b/tensorflow/core/common_runtime/device/device_mem_allocator.h
index 52803102562..16a14c7114b 100644
--- a/tensorflow/core/common_runtime/device/device_mem_allocator.h
+++ b/tensorflow/core/common_runtime/device/device_mem_allocator.h
@@ -41,8 +41,10 @@ class DeviceMemAllocator : public SubAllocator {
   }
   ~DeviceMemAllocator() override {}
 
-  void* Alloc(size_t alignment, size_t num_bytes) override {
+  void* Alloc(size_t alignment, size_t num_bytes,
+              size_t* bytes_received) override {
     void* ptr = nullptr;
+    *bytes_received = num_bytes;
     if (num_bytes > 0) {
       if (use_unified_memory_) {
         ptr = stream_exec_->UnifiedMemoryAllocate(num_bytes);
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index a02ef9188af..1922cdf0937 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -90,6 +90,8 @@ EagerContext::EagerContext(
       log_device_placement_(opts.config.log_device_placement()),
       allow_soft_placement_(opts.config.allow_soft_placement()),
       num_active_steps_(0),
+      step_container_(std::make_unique<ScopedStepContainer>(
+          0, [this](const string& name) { ClearResourceContainer(name); })),
       default_executor_(async),
       log_memory_(LogMemory::IsEnabled()),
       env_(opts.env),
@@ -385,6 +387,11 @@ void EagerContext::ClearCachesAndDefaultExecutor() {
   for (auto& entry : registered_functions_) {
     entry.second->cached_kernel_keys->clear();
   }
+  {
+    mutex_lock ml(metadata_mu_);
+    step_container_.reset(new ScopedStepContainer(
+        0, [this](const string& name) { ClearResourceContainer(name); }));
+  }
 }
 
 void EagerContext::SetThreadLocalDevicePlacementPolicy(
@@ -600,29 +607,20 @@ void EagerContext::ListDevices(
 void EagerContext::StartStep() {
   mutex_lock ml(metadata_mu_);
   num_active_steps_++;
-  if (step_container_ == nullptr) {
-    step_container_.reset(
-        new ScopedStepContainer(0, [this](const string& name) {
-          auto local_devices = local_device_mgr()->ListDevices();
-          for (Device* device : local_devices) {
-            device->resource_manager()->Cleanup(name).IgnoreError();
-          }
-        }));
-  }
 }
 
 void EagerContext::EndStep() {
   mutex_lock ml(metadata_mu_);
   num_active_steps_--;
   if (num_active_steps_ == 0) {
-    step_container_.reset();
+    // TODO(b/139809335): This does not properly clean up remote resources
+    // Clean up the previous step container and create a new one.
+    step_container_.reset(new ScopedStepContainer(
+        0, [this](const string& name) { ClearResourceContainer(name); }));
   }
 }
 
 ScopedStepContainer* EagerContext::StepContainer() {
-  if (num_active_steps_.load() == 0) {
-    return nullptr;
-  }
   mutex_lock ml(metadata_mu_);
   return step_container_.get();
 }
@@ -988,6 +986,15 @@ Status EagerContext::CPUDeviceOnTask(const Device* device,
   return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
 }
 
+void EagerContext::ClearResourceContainer(const string& name) {
+  // TODO(b/139809335): This does not properly clean up remote resources
+  auto local_devices = local_device_mgr()->ListDevices();
+  for (Device* device : local_devices) {
+    // Ignore cleanup errors; the container may not exist on this device.
+    device->resource_manager()->Cleanup(name).IgnoreError();
+  }
+}
+
 namespace {
 Status GetTaskName(Device* d, string* task_name) {
   string ignored;
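
A hedged lifecycle sketch of the new invariant (names taken from this diff; context construction and op execution are elided): the step container now always exists, so callers no longer need a null check.

```c++
#include "tensorflow/core/common_runtime/eager/context.h"

void RunOneStep(tensorflow::EagerContext* ctx) {
  ctx->StartStep();  // bumps num_active_steps_; the container already exists
  tensorflow::ScopedStepContainer* container = ctx->StepContainer();
  CHECK(container != nullptr);  // never null anymore, even between steps
  // ... run ops that stash per-step resources in `container` ...
  ctx->EndStep();  // at zero active steps, a fresh container is swapped in,
                   // and ClearResourceContainer cleans up local devices
}
```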
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 1e9516d5a69..dc833a71650 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -506,6 +506,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
 
   void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
 
+  void ClearResourceContainer(const string& name);
+
   template <typename T>
   struct OwnedOrUnownedHelper {
    public:
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index d2327ff9592..3b73f018c0d 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -285,13 +285,7 @@ Status KernelAndDeviceOp::Run(
 
   params.runner = get_runner();
 
-  params.step_container =
-      step_container == nullptr ? &step_container_ : step_container;
-  auto step_container_cleanup = gtl::MakeCleanup([step_container, this] {
-    if (step_container == nullptr) {
-      this->step_container_.CleanUp();
-    }
-  });
+  params.step_container = step_container;
 
   params.collective_executor =
       collective_executor_ ? collective_executor_->get() : nullptr;
@@ -392,8 +386,7 @@ void KernelAndDeviceFunc::RunAsync(
     opts->cancellation_manager = local_cm.get();
   }
   opts->allow_dead_tensors = true;
-  opts->step_container =
-      step_container == nullptr ? &step_container_ : step_container;
+  opts->step_container = step_container;
   opts->collective_executor =
       collective_executor_ ? collective_executor_->get() : nullptr;
 
@@ -406,9 +399,6 @@ void KernelAndDeviceFunc::RunAsync(
              [opts, rendezvous, local_cm, step_container, this,
               done = std::move(done)](const Status& s) {
                rendezvous->Unref();
-               if (step_container == nullptr) {
-                 this->step_container_.CleanUp();
-               }
                done(s);
              });
 }
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index f48452aa46b..f0d018ee093 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -201,10 +201,7 @@ class KernelAndDeviceOp final : public KernelAndDevice {
       : KernelAndDevice(flr, runner, std::move(collective_executor),
                         host_cpu_device),
         rendezvous_(rendezvous),
-        log_memory_(log_memory),
-        step_container_(0, [this](const string& name) {
-          device_->resource_manager()->Cleanup(name).IgnoreError();
-        }) {}
+        log_memory_(log_memory) {}
 
   ~KernelAndDeviceOp() override {}
 
@@ -252,7 +249,6 @@ class KernelAndDeviceOp final : public KernelAndDevice {
   Rendezvous* const rendezvous_;
   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
   const bool log_memory_;
-  ScopedStepContainer step_container_;
 };
 
 // Represents a multi-device function. Functions can also be run using
@@ -286,15 +282,7 @@ class KernelAndDeviceFunc : public KernelAndDevice {
             std::move(input_resource_dtypes_and_shapes)),
         name_(name),
         rendezvous_creator_(std::move(rendezvous_creator)),
-        get_op_id_(std::move(get_op_id)),
-        step_container_(0, [this](const string& name) {
-          // TODO(b/139809335): This does not properly clean up remote resources
-          const std::vector<Device*> devices =
-              pflr_->device_mgr()->ListDevices();
-          for (Device* device : devices) {
-            device->resource_manager()->Cleanup(name).IgnoreError();
-          }
-        }) {}
+        get_op_id_(std::move(get_op_id)) {}
 
   ~KernelAndDeviceFunc() override;
 
@@ -362,8 +350,6 @@ class KernelAndDeviceFunc : public KernelAndDevice {
 
   std::function<Rendezvous*(const int64)> rendezvous_creator_;
   std::function<int64()> get_op_id_;
-
-  ScopedStepContainer step_container_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 6b40fcc4c70..6c66069b275 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -127,8 +127,9 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
     delete pr;
     return PrepareChunk(r, alignment, num_bytes);
   } else {
-    void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
-    return PrepareChunk(ptr, alignment, num_bytes);
+    size_t bytes_received;
+    void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes, &bytes_received);
+    return PrepareChunk(ptr, alignment, bytes_received);
   }
 }
 
@@ -256,8 +257,10 @@ void PoolAllocator::EvictOne() {
   }
 }
 
-void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
+void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes,
+                               size_t* bytes_received) {
   void* ptr = nullptr;
+  *bytes_received = num_bytes;
   if (num_bytes > 0) {
     if (numa_node_ == port::kNUMANoAffinity) {
       ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 7c896cd4261..da1a3796830 100644
--- a/tensorflow/core/common_runtime/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <map>
 #include <memory>
 #include <vector>
+
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/platform/logging.h"
@@ -154,7 +155,8 @@ class BasicCPUAllocator : public SubAllocator {
 
   ~BasicCPUAllocator() override {}
 
-  void* Alloc(size_t alignment, size_t num_bytes) override;
+  void* Alloc(size_t alignment, size_t num_bytes,
+              size_t* bytes_received) override;
 
   void Free(void* ptr, size_t num_bytes) override;
 
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 50f3b52e4c6..4a8f37eca1f 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -1426,6 +1426,10 @@ void ProcessFunctionLibraryRuntime::Run(
                                       InternalArgs* comp_args) -> Status {
       // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
       for (const auto& it : comp_data.arg_indices) {
+        if (it.index >= args.size()) {
+          return errors::InvalidArgument(
+              "index ", it.index, " is out of range [0, ", args.size(), ")");
+        }
         if (it.sub_index >= 0) {
           const Tensor& t = args[it.index];
           if (t.dtype() != DT_RESOURCE) {
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index d386dc6ec6b..eb93402b1a4 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -393,7 +393,7 @@ cc_library(
     hdrs = [
         "test_util.h",
     ],
-    data = glob(["testdata/*.pbtxt"]),
+    data = ["//tensorflow/core/data/service/testdata"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index a119331c56a..680edcce33a 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -148,34 +148,16 @@ Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
   return Status::OK();
 }
 
-Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
-                                              ProcessingMode processing_mode,
-                                              int64& job_client_id) {
-  TF_RETURN_IF_ERROR(EnsureInitialized());
-  CreateJobRequest req;
-  req.set_dataset_id(dataset_id);
-  req.set_processing_mode(ProcessingModeDef(processing_mode));
-  CreateJobResponse resp;
-  grpc::ClientContext client_ctx;
-  grpc::Status status = stub_->CreateJob(&client_ctx, req, &resp);
-  if (!status.ok()) {
-    return grpc_util::WrapError(
-        absl::StrCat("Failed to create job for dataset with id ", dataset_id),
-        status);
-  }
-  job_client_id = resp.job_client_id();
-  return Status::OK();
-}
-
 Status DataServiceDispatcherClient::GetOrCreateJob(
     int64 dataset_id, ProcessingMode processing_mode,
-    const std::string& job_name, int job_name_index, int64& job_client_id) {
+    const absl::optional<JobKey>& job_key, int64& job_client_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrCreateJobRequest req;
   req.set_dataset_id(dataset_id);
   req.set_processing_mode(ProcessingModeDef(processing_mode));
-  req.set_job_name(job_name);
-  req.set_job_name_index(job_name_index);
+  if (job_key.has_value()) {
+    *req.mutable_job_key() = job_key.value();
+  }
   GetOrCreateJobResponse resp;
   grpc::ClientContext client_ctx;
   grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index f10abae8acf..935e549b303 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -100,16 +100,11 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
   // dataset id in `dataset_id`.
   Status RegisterDataset(GraphDef dataset, int64& dataset_id);
 
-  // Creates a new tf.data service job for the specified dataset. The id for the
-  // created job will be stored in `job_client_id`.
-  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
-                   int64& job_client_id);
-
-  // Gets the job id for the job represented by the tuple
-  // (job_name, job_name_index), and stores the id in `job_client_id`. If the
-  // job doesn't exist yet, it will be created.
+  // Gets the job id for the job identified by `job_key`, and stores the id in
+  // `job_client_id`. If `job_key` is unset, a new job is always created.
+  // Otherwise, the named job is created if it doesn't yet exist.
   Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
-                        const std::string& job_name, int job_name_index,
+                        const absl::optional<JobKey>& job_key,
                         int64& job_client_id);
 
   // Releases a job client id, indicating that the id will no longer be used to
diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto
index be423b23eb5..3592efc48a8 100644
--- a/tensorflow/core/data/service/dispatcher.proto
+++ b/tensorflow/core/data/service/dispatcher.proto
@@ -57,29 +57,23 @@ message GetOrRegisterDatasetResponse {
   int64 dataset_id = 1;
 }
 
-message CreateJobRequest {
-  // The id of the dataset to create a job for.
-  int64 dataset_id = 1;
-  // A mode controlling how the tf.data service produces data for the job.
-  ProcessingModeDef processing_mode = 2;
-}
-
-message CreateJobResponse {
-  // An id for the client that will read from the job. When the client is done
-  // with the job, they should call ReleaseJobClient with this id.
-  int64 job_client_id = 1;
+message JobKey {
+  // A name for the job.
+  string job_name = 1;
+  // An index for the job. Multiple jobs can be created for the same name if
+  // they have different indices.
+  int64 job_name_index = 2;
 }
 
 message GetOrCreateJobRequest {
+  reserved 3, 4;
   // The id of the dataset to create a job for.
   int64 dataset_id = 1;
   // A mode controlling how the tf.data service produces data for the job.
   ProcessingModeDef processing_mode = 2;
-  // A name for the job.
-  string job_name = 3;
-  // An index for the job. Multiple jobs can be created for the same name, if
-  // they have different indices.
-  int64 job_name_index = 4;
+  // Optional job key identifying a shared job. If not set, the RPC will always
+  // create a new job.
+  JobKey job_key = 5;
 }
 
 message GetOrCreateJobResponse {
@@ -144,9 +138,6 @@ service DispatcherService {
   // Gets a job if it already exists, otherwise creates it.
   rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse);
 
-  // Creates a job for reading from the tf.data service.
-  rpc CreateJob(CreateJobRequest) returns (CreateJobResponse);
-
   // Releases a job client so that a job may eventually be cleaned up.
   rpc ReleaseJobClient(ReleaseJobClientRequest)
       returns (ReleaseJobClientResponse);
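
A hedged client-side sketch (assuming the generated C++ API for this proto; the dataset id is illustrative): set `job_key` to share a job across consumers, or leave it unset to always create an anonymous job.

```c++
#include "tensorflow/core/data/service/dispatcher.pb.h"

tensorflow::data::GetOrCreateJobRequest MakeSharedJobRequest() {
  tensorflow::data::GetOrCreateJobRequest req;
  req.set_dataset_id(42);  // illustrative dataset id
  auto* key = req.mutable_job_key();
  key->set_job_name("train");    // consumers with the same key share the job
  key->set_job_name_index(0);    // bump the index to force a separate job
  return req;                    // omitting job_key => a new job every call
}
```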
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index 5e865242fd1..1e0ee62d27c 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -404,55 +404,36 @@ Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
   return Apply(update);
 }
 
-Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
-                                            CreateJobResponse* response) {
-  TF_RETURN_IF_ERROR(CheckStarted());
-  VLOG(3) << "Received create job request for dataset id "
-          << request->dataset_id();
-  ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
-  std::shared_ptr<const Job> job;
-  std::vector<std::shared_ptr<const Task>> tasks;
-  {
-    mutex_lock l(mu_);
-    TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode,
-                                 absl::optional<NamedJobKey>(), job));
-    int64 job_client_id;
-    TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
-    response->set_job_client_id(job_client_id);
-    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
-  }
-  TF_RETURN_IF_ERROR(AssignTasks(tasks));
-
-  VLOG(3) << "Creating job " << job->job_id << " for dataset "
-          << request->dataset_id();
-  return Status::OK();
-}
-
 Status DataServiceDispatcherImpl::GetOrCreateJob(
     const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
   TF_RETURN_IF_ERROR(CheckStarted());
-  VLOG(3) << "Received get or create job request for dataset id "
-          << request->dataset_id() << " with name " << request->job_name()
-          << " and index " << request->job_name_index();
-  NamedJobKey key(request->job_name(), request->job_name_index());
+  VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
+  absl::optional<NamedJobKey> key;
+  if (request->has_job_key()) {
+    key.emplace(request->job_key().job_name(),
+                request->job_key().job_name_index());
+  }
   ProcessingMode requested_processing_mode =
       ProcessingMode(request->processing_mode());
   std::shared_ptr<const Job> job;
   std::vector<std::shared_ptr<const Task>> tasks;
   {
     mutex_lock l(mu_);
-    Status s = state_.NamedJobByKey(key, job);
-    if (s.ok()) {
-      TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode,
-                                             request->dataset_id()));
-      int64 job_client_id;
-      TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
-      response->set_job_client_id(job_client_id);
-      VLOG(3) << "Found existing job for name=" << key.name
-              << ", index=" << key.index << ". job_id: " << job->job_id;
-      return Status::OK();
-    } else if (!errors::IsNotFound(s)) {
-      return s;
+    if (key.has_value()) {
+      Status s = state_.NamedJobByKey(key.value(), job);
+      if (s.ok()) {
+        TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode,
+                                               request->dataset_id()));
+        int64 job_client_id;
+        TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
+        response->set_job_client_id(job_client_id);
+        VLOG(3) << "Found existing job for name=" << key.value().name
+                << ", index=" << key.value().index
+                << ". job_id: " << job->job_id;
+        return Status::OK();
+      } else if (!errors::IsNotFound(s)) {
+        return s;
+      }
     }
     TF_RETURN_IF_ERROR(
         CreateJob(request->dataset_id(), requested_processing_mode, key, job));
@@ -462,8 +443,8 @@ Status DataServiceDispatcherImpl::GetOrCreateJob(
     TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
   }
   TF_RETURN_IF_ERROR(AssignTasks(tasks));
-  VLOG(3) << "Created job " << job->job_id << " for dataset "
-          << request->dataset_id() << " and name " << request->job_name();
+  VLOG(3) << "Created job " << job->job_id << " for CreateJob("
+          << request->DebugString() << ")";
   return Status::OK();
 }
 
diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
index e8cd3954d59..31c8b874ef9 100644
--- a/tensorflow/core/data/service/dispatcher_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -68,8 +68,6 @@ class DataServiceDispatcherImpl {
   /// Client-facing API.
   Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
                               GetOrRegisterDatasetResponse* response);
-  Status CreateJob(const CreateJobRequest* request,
-                   CreateJobResponse* response);
   Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                         GetOrCreateJobResponse* response);
   Status ReleaseJobClient(const ReleaseJobClientRequest* request,
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
index 456fdfcc583..5d8b2c79a3e 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
@@ -45,7 +45,6 @@ HANDLER(WorkerUpdate);
 HANDLER(GetDatasetDef);
 HANDLER(GetSplit);
 HANDLER(GetOrRegisterDataset);
-HANDLER(CreateJob);
 HANDLER(ReleaseJobClient);
 HANDLER(GetOrCreateJob);
 HANDLER(GetTasks);
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index ec6ffbb2d3f..3faa7bc36de 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -44,7 +44,6 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
   HANDLER(GetDatasetDef);
   HANDLER(GetSplit);
   HANDLER(GetOrRegisterDataset);
-  HANDLER(CreateJob);
   HANDLER(ReleaseJobClient);
   HANDLER(GetOrCreateJob);
   HANDLER(GetTasks);
diff --git a/tensorflow/core/data/service/testdata/BUILD b/tensorflow/core/data/service/testdata/BUILD
new file mode 100644
index 00000000000..b32f47b91eb
--- /dev/null
+++ b/tensorflow/core/data/service/testdata/BUILD
@@ -0,0 +1,14 @@
+# Description:
+# Data service test data.
+
+load("//tensorflow:tensorflow.bzl", "filegroup")
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "testdata",
+    srcs = glob(["*.pbtxt"]),
+    visibility = ["//tensorflow/core/data/service:__pkg__"],
+)
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index d615b36dc42..dd675d7e8d3 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -1251,6 +1251,31 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "extension_type_variant",
+    srcs = ["extension_type_variant.cc"],
+    hdrs = ["extension_type_variant.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "extension_type_variant_test",
+    size = "small",
+    srcs = ["extension_type_variant_test.cc"],
+    deps = [
+        ":extension_type_variant",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 # All framework protos are self-contained, i.e. they only import other
 # protos from the same package, so we can build the protos here and then
 # link them from core:protos_all without circular dependencies.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index f7402f7b293..2899a59c548 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -439,7 +439,11 @@ class SubAllocator {
                const std::vector<Visitor>& free_visitors);
 
   virtual ~SubAllocator() {}
-  virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
+  // Allocates at least num_bytes. Returns the actual number of bytes
+  // allocated in bytes_received. The caller can safely use the full
+  // bytes_received-sized buffer starting at the returned pointer.
+  virtual void* Alloc(size_t alignment, size_t num_bytes,
+                      size_t* bytes_received) = 0;
   virtual void Free(void* ptr, size_t num_bytes) = 0;
 
  protected:
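For orientation, a minimal sketch (not part of the patch) of a `SubAllocator` subclass adapting to the new contract; `PageRoundingSubAllocator` and its page size are hypothetical, but they show why a caller can receive more than `num_bytes`:

```cpp
#include <cstddef>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/mem.h"

// Hypothetical sub-allocator that rounds every request up to a page and
// reports the rounded size through bytes_received, so the caller may safely
// use the whole buffer.
class PageRoundingSubAllocator : public tensorflow::SubAllocator {
 public:
  PageRoundingSubAllocator() : SubAllocator({}, {}) {}

  void* Alloc(size_t alignment, size_t num_bytes,
              size_t* bytes_received) override {
    constexpr size_t kPageSize = 4096;  // assumed page size
    const size_t rounded = (num_bytes + kPageSize - 1) / kPageSize * kPageSize;
    *bytes_received = rounded;  // caller may use all `rounded` bytes
    return tensorflow::port::AlignedMalloc(rounded,
                                           static_cast<int>(alignment));
  }

  void Free(void* ptr, size_t num_bytes) override {
    tensorflow::port::AlignedFree(ptr);
  }
};
```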
diff --git a/tensorflow/core/framework/cpu_allocator_impl.cc b/tensorflow/core/framework/cpu_allocator_impl.cc
index 511cfce8ab5..567454a3295 100644
--- a/tensorflow/core/framework/cpu_allocator_impl.cc
+++ b/tensorflow/core/framework/cpu_allocator_impl.cc
@@ -156,7 +156,9 @@ class CPUAllocatorFactory : public AllocatorFactory {
     explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
         : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {}
 
-    void* Alloc(size_t alignment, size_t num_bytes) override {
+    void* Alloc(size_t alignment, size_t num_bytes,
+                size_t* bytes_received) override {
+      *bytes_received = num_bytes;
       return cpu_allocator_->AllocateRaw(alignment, num_bytes);
     }
 
diff --git a/tensorflow/core/framework/extension_type_variant.cc b/tensorflow/core/framework/extension_type_variant.cc
new file mode 100644
index 00000000000..bd6a6f9d42f
--- /dev/null
+++ b/tensorflow/core/framework/extension_type_variant.cc
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/extension_type_variant.h"
+
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+
+constexpr const char ExtensionTypeVariant::kTypeName[];
+
+void ExtensionTypeVariant::Encode(VariantTensorData* data) const {
+  data->set_type_name(TypeName());
+  metadata_.type_spec_proto().SerializeToString(&data->metadata_string());
+  for (const Tensor& tensor : flat_components_) {
+    data->add_tensor(tensor);
+  }
+}
+
+bool ExtensionTypeVariant::Decode(const VariantTensorData& data) {
+  if (!metadata_.mutable_type_spec_proto()->ParseFromString(
+          data.metadata_string())) {
+    return false;
+  }
+  flat_components_ = data.tensors();
+  return true;
+}
+
+string ExtensionTypeVariant::DebugString() const {
+  string type_spec;
+  ::tensorflow::protobuf::TextFormat::Printer printer;
+  printer.SetSingleLineMode(true);
+  printer.PrintToString(metadata_.type_spec_proto(), &type_spec);
+  string result("<ExtensionTypeVariant type_spec={");
+  result.append(type_spec.empty() ? "none" : type_spec);
+  result.append("}, components=[");
+  for (const auto& tensor : flat_components_) {
+    if (&tensor != &flat_components_[0]) {
+      result.append(", ");
+    }
+    result.append(tensor.DebugString());
+  }
+  result.append("]>");
+  return result;
+}
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ExtensionTypeVariant,
+                                       ExtensionTypeVariant::kTypeName);
+
+}  // namespace tensorflow
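As a quick illustration of the Encode/Decode contract above, a hedged sketch (not part of the patch) that round-trips a value through `VariantTensorData`, the same path the test file below exercises via a DT_VARIANT tensor:

```cpp
#include "tensorflow/core/framework/extension_type_variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/platform/logging.h"

// Round-trips an ExtensionTypeVariant through VariantTensorData, mirroring
// what the variant serialization machinery does for DT_VARIANT tensors.
tensorflow::ExtensionTypeVariant RoundTrip(
    const tensorflow::ExtensionTypeVariant& original) {
  tensorflow::VariantTensorData data;
  original.Encode(&data);  // writes the serialized TypeSpec + components

  tensorflow::ExtensionTypeVariant decoded;
  CHECK(decoded.Decode(data));  // restores metadata and tensor components
  return decoded;
}
```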
diff --git a/tensorflow/core/framework/extension_type_variant.h b/tensorflow/core/framework/extension_type_variant.h
new file mode 100644
index 00000000000..d50abb5ffad
--- /dev/null
+++ b/tensorflow/core/framework/extension_type_variant.h
@@ -0,0 +1,96 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
+#define TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/protobuf/extension_type_variant.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace tensorflow {
+
+// Encoding for a `tf.ExtensionType` value that can be saved as a Variant.
+//
+// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class
+// used to define Python types that are supported by TensorFlow APIs.  Example
+// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`.
+//
+// `ExtensionTypeVariant` decomposes the `ExtensionType` value into two
+// parts:
+//
+//   * `components`: A list of Tensors, which encodes the value's dynamic
+//     data -- i.e., data that may change for different executions of a graph.
+//   * `type_spec_proto`: A serialized TypeSpec, which encodes the value's
+//     static data -- i.e., data that is the same for all executions of a graph.
+//
+// ExtensionTypeVariant can be stored in a Tensor with dtype=DT_VARIANT.
+// Typically, extension type values are encoded with a scalar tensor containing
+// a single ExtensionTypeVariant value.
+class ExtensionTypeVariant {
+ public:
+  ExtensionTypeVariant(const TypeSpecProto& type_spec_proto,
+                       absl::Span<Tensor> flat_components)
+      : flat_components_(flat_components.begin(), flat_components.end()) {
+    *metadata_.mutable_type_spec_proto() = type_spec_proto;
+  }
+
+  // This type is default-constructible, copyable, assignable, and movable.
+  ExtensionTypeVariant() = default;
+  ExtensionTypeVariant(const ExtensionTypeVariant& other) = default;
+  ExtensionTypeVariant& operator=(ExtensionTypeVariant&& other) = default;
+  ExtensionTypeVariant& operator=(const ExtensionTypeVariant& other) = default;
+
+  // Returns the list of Tensor components that encode this value's dynamic
+  // data.
+  absl::Span<const Tensor> flat_components() const {
+    return absl::MakeConstSpan(flat_components_);
+  }
+
+  // Returns the serialized TypeSpec that encodes the value's static data.
+  TypeSpecProto type_spec_proto() const { return metadata_.type_spec_proto(); }
+
+  // Variant methods.
+  string TypeName() const { return kTypeName; }
+
+  // Updates `VariantTensorData` with an encoding for this value.
+  void Encode(VariantTensorData* data) const;
+
+  // Updates this value to match the encoding in a given `VariantTensorData`.
+  bool Decode(const VariantTensorData& data);
+
+  // Returns a string summary for this value.
+  string DebugString() const;
+
+  // Name of this type (used for variant serialization).
+  static constexpr const char kTypeName[] = "ExtensionTypeVariant";
+
+ private:
+  // Tensor components for this value.
+  std::vector<Tensor> flat_components_;
+
+  // TypeSpec for this value.  ExtensionTypeVariantMetadata is a thin wrapper
+  // around a TypeSpecProto; the wrapper retains the flexibility to change the
+  // variant encoding in the future.
+  ExtensionTypeVariantMetadata metadata_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
diff --git a/tensorflow/core/framework/extension_type_variant_test.cc b/tensorflow/core/framework/extension_type_variant_test.cc
new file mode 100644
index 00000000000..cd30b320bad
--- /dev/null
+++ b/tensorflow/core/framework/extension_type_variant_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/extension_type_variant.h"
+
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// TypeSpecProto for a 2D Ragged Tensor.
+constexpr const char* k2DRaggedTensorSpec = R"(
+type_spec_class: RAGGED_TENSOR_SPEC
+type_state: {
+  tuple_value: {
+    values: [
+      {tensor_shape_value: {dim: [{size: -1}, {size: -1}]}},  # shape
+      {tensor_dtype_value: DT_INT32},                         # dtype
+      {int64_value: 1},                                       # ragged_rank
+      {tensor_dtype_value: DT_INT64}                          # row_splits_dtype
+    ]
+  }
+}
+)";
+
+// Returns an ExtensionTypeVariant encoding for a 2D ragged tensor with
+// the specified values and row_splits.
+ExtensionTypeVariant Make2DRaggedTensor(const std::vector<int32>& values,
+                                        const std::vector<int64>& splits) {
+  TypeSpecProto type_spec;
+  EXPECT_TRUE(
+      protobuf::TextFormat::ParseFromString(k2DRaggedTensorSpec, &type_spec));
+  std::vector<Tensor> components;
+  components.push_back(test::AsTensor<int32>(values));
+  components.push_back(test::AsTensor<int64>(splits));
+  ExtensionTypeVariant v(type_spec, absl::MakeSpan(components));
+  return v;
+}
+
+TEST(ExtensionTypeVariantTest, EncodeAndDecodeRagged) {
+  ExtensionTypeVariant v = Make2DRaggedTensor(
+      /* values = */ {5, 5, 3, 4, 1, 8},
+      /* splits = */ {0, 2, 3, 6});
+  Tensor t(DT_VARIANT, {});
+
+  t.flat<Variant>()(0) = v;  // Encode to variant.
+  auto* decoded = t.flat<Variant>()(0).get<ExtensionTypeVariant>();
+
+  EXPECT_EQ(v.type_spec_proto().SerializeAsString(),
+            decoded->type_spec_proto().SerializeAsString());
+  EXPECT_EQ(v.flat_components().size(), 2);
+  test::ExpectTensorEqual<int32>(v.flat_components()[0],
+                                 decoded->flat_components()[0]);
+  test::ExpectTensorEqual<int64>(v.flat_components()[1],
+                                 decoded->flat_components()[1]);
+}
+
+TEST(ExtensionTypeVariantTest, DebugStringForDefaultConstructed) {
+  ExtensionTypeVariant v;
+  EXPECT_EQ(v.DebugString(),
+            "<ExtensionTypeVariant type_spec={none}, components=[]>");
+}
+
+TEST(ExtensionTypeVariantTest, DebugStringForRagged) {
+  ExtensionTypeVariant v = Make2DRaggedTensor(
+      /* values = */ {5, 5, 3, 4, 1},
+      /* splits = */ {0, 2, 3, 5});
+  EXPECT_EQ(v.DebugString(),
+            "<ExtensionTypeVariant type_spec={type_spec_class: "
+            "RAGGED_TENSOR_SPEC type_state { tuple_value { values { "
+            "tensor_shape_value { dim { size: -1 } dim { size: -1 } } } "
+            "values { tensor_dtype_value: DT_INT32 } values "
+            "{ int64_value: 1 } values { tensor_dtype_value: DT_INT64 } } } }, "
+            "components=[Tensor<type: int32 shape: [5] values: 5 5 3...>, "
+            "Tensor<type: int64 shape: [4] values: 0 2 3...>]>");
+}
+
+TEST(ExtensionTypeVariantTest, TypeName) {
+  ExtensionTypeVariant v;
+  EXPECT_EQ(v.TypeName(), "ExtensionTypeVariant");
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 961e6df005b..6061d9737ec 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -94,14 +94,16 @@ class ScopedStepContainer {
   // prefix: optional string prefix to disambiguate step containers.
   ScopedStepContainer(const int64 step_id,
                       std::function<void(const string&)> cleanup)
-      : container_(strings::StrCat("__per_step_", step_id)),
+      : step_id_(step_id),
+        container_(strings::StrCat("__per_step_", step_id)),
         cleanup_(cleanup),
         dirty_(false) {}
 
   ScopedStepContainer(const int64 step_id,
                       std::function<void(const string&)> cleanup,
                       const std::string& prefix)
-      : container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
+      : step_id_(step_id),
+        container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
         cleanup_(cleanup),
         dirty_(false) {}
 
@@ -141,8 +143,10 @@ class ScopedStepContainer {
   template <typename T>
   Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
+  int64 StepId() const { return step_id_; }
 
  private:
+  const int64 step_id_;
   const std::string container_;
   const std::function<void(const string&)> cleanup_;
   mutex mu_;
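A minimal sketch (not part of the patch; the no-op cleanup callback is illustrative) of the new `StepId()` accessor, which exposes the step id directly instead of forcing callers to parse it back out of the `__per_step_<id>` container name:

```cpp
#include <iostream>
#include <string>

#include "tensorflow/core/framework/resource_mgr.h"

void StepIdDemo() {
  auto cleanup = [](const std::string& container) { /* no-op for the demo */ };
  tensorflow::ScopedStepContainer container(/*step_id=*/42, cleanup);
  // The container is named "__per_step_42"; StepId() recovers 42 directly.
  std::cout << "step id: " << container.StepId() << std::endl;
}
```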
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
index 01b598591e4..e5f33036dcd 100644
--- a/tensorflow/core/framework/types.proto
+++ b/tensorflow/core/framework/types.proto
@@ -84,4 +84,6 @@ enum SpecializedType {
   ST_INVALID = 0;
   // "tensorflow::TensorList" in the variant type registry.
   ST_TENSOR_LIST = 1;
+  // "tensorflow::data::Optional" in the variant type registry.
+  ST_OPTIONAL = 2;
 }
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index 8ac60a0916e..c49e3937cb9 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -149,6 +149,7 @@ tf_cc_test(
     tags = [
         "no_cuda_on_cpu_tap",
         "no_gpu",
+        "no_mac",  # b/174055645
         "nomsan",  # TODO(b/160921160): broken by NOAUTOROLLBACK CL
     ],
     deps = [
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 79b3405712f..8bda3f60913 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -16,12 +16,9 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
-filegroup(
+alias(
     name = "graph_properties_testdata",
-    srcs = glob([
-        "graph_properties_testdata/*.pbtxt",
-        "graph_properties_testdata/*.pbtxt.html",
-    ]),
+    actual = "//tensorflow/core/grappler/costs/graph_properties_testdata:graph_properties_testdata",
     visibility = ["//visibility:public"],
 )
 
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/BUILD b/tensorflow/core/grappler/costs/graph_properties_testdata/BUILD
new file mode 100644
index 00000000000..3ba0979ca6c
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/BUILD
@@ -0,0 +1,17 @@
+# Description:
+# Graph properties test data.
+
+load("//tensorflow:tensorflow.bzl", "filegroup")
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "graph_properties_testdata",
+    srcs = glob([
+        "*.pbtxt",
+        "*.pbtxt.html",
+    ]),
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 14cfe09944f..b7d1d7eb597 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -41,6 +41,7 @@ const char kAttrSrcDevice[] = "send_device";
 const char kAttrDstDevice[] = "recv_device";
 const char kAttrTensorName[] = "tensor_name";
 const char kChannelDevice[] = "Channel";
+const char kStreaming[] = "_streaming";
 
 namespace {
 
@@ -109,14 +110,10 @@ void UpdateDeviceAnnotationState(const NodeDef* node,
       (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0;
 }
 
-bool IsStreamingNode(const NodeDef& node) {
-  return node.attr().contains("_streaming");
-}
-
 bool IsStreamingPort(const NodeDef& node, const int port) {
-  if (!IsStreamingNode(node)) return false;
+  if (!node.attr().contains(kStreaming)) return false;
 
-  auto& attr_list = node.attr().at("_streaming").list();
+  auto& attr_list = node.attr().at(kStreaming).list();
   bool is_streaming_port = false;
   if (port >= 0 && port < attr_list.b().size()) {
     is_streaming_port = attr_list.b(port);
@@ -662,10 +659,12 @@ std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
 
   auto input_node_port_num = NodePosition(input_name);
   string src_name;
+  bool control_input = false;
   if (input_node_port_num >= 0) {
     src_name = absl::StrCat(from->name(), "_", input_node_port_num);
   } else {
     src_name = absl::StrCat(from->name(), "_minus1");
+    control_input = true;
   }
 
   // _Send op.
@@ -695,16 +694,24 @@ std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
   recv->add_input(send->name());
   recv->set_device(DeviceName(to));
   auto& recv_attr = *(recv->mutable_attr());
-  if (from->attr().contains("_streaming")) {
-    *(recv_attr["_streaming"].mutable_list()) =
-        from->attr().at("_streaming").list();
-  }
   recv_attr[kAttrInputSrc].set_s(input_name);
   if (input_node->attr().count(kAttrTensorName)) {
     recv_attr[kAttrTensorName].set_s(
         input_node->attr().at(kAttrTensorName).s());
   }
 
+  // Propagate the streaming attribute to the send/recv nodes.
+  if (from->attr().contains(kStreaming) && !control_input) {
+    if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) {
+      LOG(ERROR)
+          << from->name()
+          << " port index larger than length of _streaming attribute list.";
+    } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) {
+      send_attr[kStreaming].mutable_list()->add_b(true);
+      recv_attr[kStreaming].mutable_list()->add_b(true);
+    }
+  }
+
   // NodeState for _Send op.
   auto& send_node_state = GetNodeStateOrCreateIt(send);
   send_node_state.device_name = send->device();  // Set Channel device.
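To make the `_streaming` convention concrete, a hedged sketch (not from the patch) of tagging a node's output ports, matching what `IsStreamingPort` and the send/recv propagation above read back:

```cpp
#include "tensorflow/core/framework/node_def.pb.h"

// Marks port 1 of a node as streaming. The `_streaming` attribute holds one
// bool per output port; IsStreamingPort(node, port) checks the entry at
// index `port`.
void MarkPortOneStreaming(tensorflow::NodeDef* node) {
  auto* streaming = (*node->mutable_attr())["_streaming"].mutable_list();
  streaming->add_b(false);  // port 0: regular output
  streaming->add_b(true);   // port 1: streaming output
}
```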
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 04f1e571ae5..308659aeac6 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -37,6 +37,7 @@ ABSL_CONST_INIT extern const char kAttrSrcDevice[];
 ABSL_CONST_INIT extern const char kAttrDstDevice[];
 ABSL_CONST_INIT extern const char kAttrTensorName[];
 ABSL_CONST_INIT extern const char kChannelDevice[];
+ABSL_CONST_INIT extern const char kStreaming[];
 
 struct NodeState {
   // A node (i.e., an op) takes a set of input:port pairs and produces
@@ -438,7 +439,7 @@ class SchedulerState {
   // Auxiliary data structures for constructing NodeState and DeviceState.
   std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
   Cluster* cluster_;                                   // Not owned.
-  const GrapplerItem* grappler_item_;  // Not owned.
+  const GrapplerItem* grappler_item_;                  // Not owned.
   bool use_static_shapes_;
   bool initialized_;
   bool track_mem_usage_snapshot_;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 6e502490796..2da15512181 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -2547,7 +2547,7 @@ TEST_F(VirtualSchedulerTest, MemoryUsageForStreamingOps) {
       node.set_device(kCPU1);
     }
     if (node.name() == "z" || node.name() == "w")
-      (*node.mutable_attr())["_streaming"].mutable_list()->add_b(true);
+      (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true);
   }
 
   InitScheduler();
diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index 1288f9695b9..ae315e97ccb 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 28> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
     "Identity",
     "MapAndBatchDataset",
     "MapDataset",
+    "MaxIntraOpParallelismDataset",
     "ModelDataset",
     "OptimizeDataset",
     "PaddedBatchDataset",
     "ParallelMapDataset",
     "ParseExampleDataset",
     "PrefetchDataset",
+    "PrivateThreadPoolDataset",
     "ReduceDataset",
     "RebatchDataset",
     "RepeatDataset",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 36fee01c034..62488211e7d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -83,6 +83,7 @@ package_group(
     packages = [
         "//tensorflow/...",
         "//tensorflow_text/...",
+        "//third_party/google_research/google_research/tf3d/...",
     ],
 )
 
@@ -635,9 +636,12 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/common_runtime:device_mgr",
+        "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
         "//tensorflow/core/kernels/batching_util:batch_resource_base",
         "//tensorflow/core/kernels/batching_util:concat_split_util",
         "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
+        "//tensorflow/core/platform:numbers",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 5f742c37f35..4cbd214e04f 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -14,12 +14,15 @@ limitations under the License.
 ==============================================================================*/
 
 #include "absl/strings/str_cat.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -29,8 +32,15 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/numbers.h"
 
 namespace tensorflow {
+namespace {
+constexpr int64 kMinInflightBatchesLimit = 16;
+constexpr double kInitialInflightBatchesLimit = 64;
+constexpr int64 kBatchesToAverageOver = 10;
+constexpr int64 kMaxInflightBatchesLimit = 128;
+}  // namespace
 
 auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
     "/tensorflow/serving/batching/enable_large_batch_splitting",
@@ -52,6 +62,14 @@ void RecordBatchSplitUsage(
   }
 }
 
+void RecordBatchParamNumBatchThreads(int64 num_batch_threads,
+                                     const string& model_name) {
+  static auto* cell = monitoring::Gauge<int64, 1>::New(
+      "/tensorflow/serving/batching/num_batch_threads",
+      "Tracks the number of batch threads of a model.", "model_name");
+  cell->GetCell(model_name)->Set(num_batch_threads);
+}
+
 const string& GetModelName(OpKernelContext* ctx) {
   static string* kModelNameUnset = new string("model_name_unset");
   if (!ctx->session_metadata()) return *kModelNameUnset;
@@ -86,6 +104,24 @@ class BatchResource : public serving::BatchResourceBase {
     return Status::OK();
   }
 
+  static Status Create(
+      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
+      int32 max_batch_size, int32 batch_timeout_micros,
+      int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
+      FunctionLibraryRuntime::Handle fhandle,
+      std::unique_ptr<BatchResource>* resource) {
+    std::shared_ptr<AdaptiveBatcherT> batcher;
+    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
+        adaptive_shared_batch_scheduler_options, &batcher));
+
+    resource->reset(new BatchResource(
+        fhandle, std::move(batcher),
+        GetAdaptiveBatcherQueueOptions(max_batch_size, batch_timeout_micros,
+                                       max_enqueued_batches, true),
+        allowed_batch_sizes));
+    return Status::OK();
+  }
+
   string DebugString() const final { return "BatchResource"; }
 
  private:
@@ -99,6 +135,16 @@ class BatchResource : public serving::BatchResourceBase {
             std::move(allowed_batch_sizes)),
         fhandle_(fhandle) {}
 
+  BatchResource(FunctionLibraryRuntime::Handle fhandle,
+                std::shared_ptr<AdaptiveBatcherT> batcher,
+                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
+                std::vector<int32> allowed_batch_sizes)
+      : BatchResourceBase(
+            /*has_process_batch_function=*/fhandle != kInvalidHandle,
+            std::move(batcher), batcher_queue_options,
+            std::move(allowed_batch_sizes)),
+        fhandle_(fhandle) {}
+
   void ProcessFuncBatchImpl(
       const BatchTask& last_task, absl::Span<const Tensor> inputs,
       std::vector<Tensor>* combined_outputs,
@@ -134,11 +180,6 @@ class BatchFunctionKernel : public AsyncOpKernel {
   explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
-    // If shared_name is not supplied, use name instead (prevent collisions by
-    // default).
-    if (shared_name_.empty()) {
-      shared_name_ = name();
-    }
     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
@@ -148,12 +189,30 @@ class BatchFunctionKernel : public AsyncOpKernel {
                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
 
-    auto lib = c->function_library();
-    OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
-    NameAttrList func;
-    OP_REQUIRES_OK(c, c->GetAttr("f", &func));
-    OP_REQUIRES_OK(
-        c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
+    if (num_batch_threads_ <= 0) {
+      adaptive_batch_scheduler_options_ =
+          absl::make_optional(AdaptiveBatchSchedulerOptions{
+              kMinInflightBatchesLimit, kInitialInflightBatchesLimit,
+              kBatchesToAverageOver});
+
+      // Use a shared thread pool across all models if the adaptive shared
+      // batch scheduler is used.
+      // `shared_name_` and `container_` are used to look up an instantiated
+      // scheduler instance in `ComputeAsync`.
+      container_ = "__adapative_container";
+      shared_name_ = "__adaptive_global_shared_thread_pool";
+      // Use the op name as the queue name to prevent collisions by default.
+      if (batcher_queue_.empty()) {
+        batcher_queue_ = name();
+      }
+    }
+
+    if (shared_name_.empty()) {
+      // If shared_name is not supplied, use name instead (prevent collisions by
+      // default).
+      shared_name_ = name();
+    }
 
     if (c->HasAttr("enable_large_batch_splitting")) {
       OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
@@ -175,16 +234,63 @@ class BatchFunctionKernel : public AsyncOpKernel {
             ? absl::make_optional(enable_large_batch_splitting_)
             : absl::nullopt,
         GetModelName(c));
+    // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
+    RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));
+
+    std::function<Status(BatchResource**)> creator;
+
+    FunctionLibraryRuntime::Handle handle;
+    OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);
+
+    if (adaptive_batch_scheduler_options_ != absl::nullopt) {
+      creator = [this, handle](BatchResource** r) {
+        serving::AdaptiveSharedBatchScheduler<
+            serving::BatchResourceBase::BatchTask>::Options
+            adaptive_shared_batch_scheduler_options;
+        adaptive_shared_batch_scheduler_options.thread_pool_name =
+            "adaptive_batch_threads";
+        adaptive_shared_batch_scheduler_options.num_batch_threads =
+            kMaxInflightBatchesLimit;
+        // `full_batch_scheduling_boost_micros` is kept at its default value
+        // (zero) intentionally, so tasks are scheduled in a FIFO way.
+        // Two rationales for using the default value (zero):
+        // 1) The task scheduling policy is FIFO. Compared with round robin
+        // (what the shared batch scheduler does), FIFO ensures that models
+        // with low QPS (i.e., models that enqueue fewer tasks in the shared
+        // queue) are processed in a timely manner.
+        // 2) If set, `full_batch_scheduling_boost_micros` should be on the
+        // order of the batch processing latency (which varies per model).
+        // A non-zero value that is not tuned properly harms tail latency.
+        adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
+            adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
+        adaptive_shared_batch_scheduler_options
+            .initial_in_flight_batches_limit =
+            adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
+        adaptive_shared_batch_scheduler_options.batches_to_average_over =
+            adaptive_batch_scheduler_options_->batches_to_average_over;
+        std::unique_ptr<BatchResource> new_resource;
+        TF_RETURN_IF_ERROR(BatchResource::Create(
+            adaptive_shared_batch_scheduler_options, max_batch_size_,
+            batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
+            handle, &new_resource));
+        *r = new_resource.release();
+        return Status::OK();
+      };
+    } else {
+      creator = [this, handle](BatchResource** r) {
+        std::unique_ptr<BatchResource> new_resource;
+        TF_RETURN_IF_ERROR(BatchResource::Create(
+            num_batch_threads_, max_batch_size_, batch_timeout_micros_,
+            max_enqueued_batches_, allowed_batch_sizes_, handle,
+            enable_large_batch_splitting_, &new_resource));
+        *r = new_resource.release();
+        return Status::OK();
+      };
+    }
+
     BatchResource* br;
-    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
-      std::unique_ptr<BatchResource> new_resource;
-      TF_RETURN_IF_ERROR(BatchResource::Create(
-          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-          max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
-          enable_large_batch_splitting_, &new_resource));
-      *r = new_resource.release();
-      return Status::OK();
-    };
     OP_REQUIRES_OK_ASYNC(c,
                          c->resource_manager()->LookupOrCreate(
                              container_, shared_name_, &br, creator),
@@ -196,6 +302,75 @@ class BatchFunctionKernel : public AsyncOpKernel {
     // Assume br calls done, so nothing to do here.
   }
 
+  Status InstantiateFunction(OpKernelContext* c,
+                             FunctionLibraryRuntime::Handle* handle) const {
+    // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
+    FunctionLibraryRuntime* lib = c->function_library();
+    if (!lib) {
+      return errors::Internal("No function library");
+    }
+
+    FunctionLibraryRuntime::InstantiateOptions opts;
+    opts.target = lib->device() == nullptr ? "" : lib->device()->name();
+    opts.is_multi_device_function = true;
+
+    Device* cpu_device;
+    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
+
+    const FunctionDef* fdef =
+        lib->GetFunctionLibraryDefinition()->Find(func_.name());
+    if (!fdef) {
+      return errors::NotFound("Failed to find definition for function \"",
+                              func_.name(), "\"");
+    }
+    OpInputList in_tensors;
+    TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
+    for (int i = 0; i < in_tensors.size(); i++) {
+      if (in_tensors[i].dtype() == DT_RESOURCE) {
+        return errors::InvalidArgument(
+            "BatchFunction cannot take resource inputs but input ", i,
+            " is a resource.");
+      } else {
+        // Currently, inputs are on CPU since they are concatenated on CPU.
+        opts.input_devices.push_back(cpu_device->name());
+      }
+    }
+    OpInputList captured_tensors;
+    TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
+    for (const Tensor& t : captured_tensors) {
+      if (t.dtype() == DT_RESOURCE) {
+        const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
+        opts.input_devices.push_back(rhandle.device());
+      } else {
+        opts.input_devices.push_back(cpu_device->name());
+      }
+    }
+    const OpDef& signature = fdef->signature();
+    for (int i = 0; i < signature.output_arg_size(); i++) {
+      // Currently, outputs must be on CPU since they are split on CPU.
+      opts.output_devices.push_back(cpu_device->name());
+    }
+    if (opts.input_devices.size() != signature.input_arg_size()) {
+      return errors::InvalidArgument(
+          "Function takes ", signature.input_arg_size(), " argument(s) but ",
+          opts.input_devices.size(), " argument(s) were passed");
+    }
+    return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
+                            handle);
+  }
+
+  Status GetOrCreateFunctionHandle(OpKernelContext* c,
+                                   FunctionLibraryRuntime::Handle* handle) {
+    mutex_lock ml(mu_);
+    if (!fhandle_) {
+      TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
+      fhandle_ = *handle;
+    } else {
+      *handle = fhandle_.value();
+    }
+    return Status::OK();
+  }
+
   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
   // and the last one must equal 'max_batch_size_'.
   Status ValidateAllowedBatchSizes() const {
@@ -231,13 +406,34 @@ class BatchFunctionKernel : public AsyncOpKernel {
   int32 batch_timeout_micros_;
   int32 max_enqueued_batches_;
   std::vector<int32> allowed_batch_sizes_;
-  FunctionLibraryRuntime::Handle fhandle_;
+  NameAttrList func_;
+  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
   bool enable_large_batch_splitting_;
   bool has_attribute_enable_large_batch_splitting_;
+  mutex mu_;
+
+  // Parameters for the adaptive batch scheduler only.
+  // Note that 'num_batch_threads_' above is shared by both batch scheduler
+  // implementations.
+  struct AdaptiveBatchSchedulerOptions {
+    int64 min_in_flight_batches_limit;
+    double initial_in_flight_batches_limit;
+    int64 batches_to_average_over;
+  };
+  absl::optional<AdaptiveBatchSchedulerOptions>
+      adaptive_batch_scheduler_options_ = absl::nullopt;
 };
 
 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                         BatchFunctionKernel);
+// Currently all inputs and outputs are on the host.
+// TODO(b/173748277): Accept inputs/outputs on the device.
+REGISTER_KERNEL_BUILDER(Name("BatchFunction")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("in_tensors")
+                            .HostMemory("captured_tensors")
+                            .HostMemory("out_tensors"),
+                        BatchFunctionKernel);
 
 class BatchKernel : public AsyncOpKernel {
  public:
@@ -282,8 +478,8 @@ class BatchKernel : public AsyncOpKernel {
     // Assume br calls done, so nothing to do here.
   }
 
-  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
-  // and the last one must equal 'max_batch_size_'.
+  // Validates 'allowed_batch_sizes_'. The entries must increase
+  // monotonically, and the last one must equal 'max_batch_size_'.
   Status ValidateAllowedBatchSizes() const {
     if (allowed_batch_sizes_.empty()) {
       return Status::OK();
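For orientation, a hedged sketch (not part of the patch; `MakeAdaptiveScheduler` is a hypothetical helper) of constructing the adaptive scheduler standalone, using the same option values the kernel now falls back to when `num_batch_threads <= 0`:

```cpp
#include <memory>

#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

namespace tensorflow {
namespace serving {

using TaskType = BatchResourceBase::BatchTask;

// Builds an adaptive scheduler with the defaults applied by the kernel above
// (constants copied from the diff: 128 / 16 / 64 / 10).
Status MakeAdaptiveScheduler(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  AdaptiveSharedBatchScheduler<TaskType>::Options options;
  options.thread_pool_name = "adaptive_batch_threads";
  options.num_batch_threads = 128;               // kMaxInflightBatchesLimit
  options.min_in_flight_batches_limit = 16;      // kMinInflightBatchesLimit
  options.initial_in_flight_batches_limit = 64;  // kInitialInflightBatchesLimit
  options.batches_to_average_over = 10;          // kBatchesToAverageOver
  // full_batch_scheduling_boost_micros stays 0, so scheduling is FIFO.
  return AdaptiveSharedBatchScheduler<TaskType>::Create(options, scheduler);
}

}  // namespace serving
}  // namespace tensorflow
```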
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 5bbfbddf0d4..75c270ff405 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -241,17 +241,19 @@ cc_library(
     srcs = ["batch_resource_base.cc"],
     hdrs = ["batch_resource_base.h"],
     deps = [
+        ":adaptive_shared_batch_scheduler",
         ":batch_scheduler",
         ":concat_split_util",
         ":shared_batch_scheduler",
+        ":threadsafe_status",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core/kernels/batching_util:threadsafe_status",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:thread_annotations",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/profiler/lib:traceme_encode",
         "//tensorflow/core/util:incremental_barrier",
+        "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index 81a16522c55..e4af643adc4 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/monitoring/gauge.h"
 #include "tensorflow/core/lib/monitoring/percentile_sampler.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/profiler/lib/traceme_encode.h"
@@ -74,6 +75,39 @@ void RecordBatchDelayMs(int64 batch_delay_ms, const string& model_name) {
   cell->GetCell(model_name)->Add(static_cast<double>(batch_delay_ms));
 }
 
+void RecordBatchParamBatchTimeoutMicros(int64 batch_timeout_micros,
+                                        const string& model_name) {
+  static auto* cell = monitoring::Gauge<int64, 1>::New(
+      "/tensorflow/serving/batching/batch_timeout_micros",
+      "Tracks how long a request can wait before being processed by a batch.",
+      "model_name");
+  cell->GetCell(model_name)->Set(batch_timeout_micros);
+}
+
+void RecordBatchParamMaxBatchSize(int64 max_batch_size,
+                                  const string& model_name) {
+  static auto* cell = monitoring::Gauge<int64, 1>::New(
+      "/tensorflow/serving/batching/max_batch_size",
+      "Tracks the maximum size of a batch.", "model_name");
+  cell->GetCell(model_name)->Set(max_batch_size);
+}
+
+void RecordBatchParamMaxEnqueuedBatches(int64 max_enqueued_batches,
+                                        const string& model_name) {
+  static auto* cell = monitoring::Gauge<int64, 1>::New(
+      "/tensorflow/serving/batching/max_enqueued_batches",
+      "Tracks the maximum number of enqueued batches.", "model_name");
+  cell->GetCell(model_name)->Set(max_enqueued_batches);
+}
+
+void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
+                                       const string& model_name) {
+  static auto* cell = monitoring::Gauge<string, 1>::New(
+      "/tensorflow/serving/batching/allowed_batch_sizes",
+      "Tracks the sizes that are allowed to form a batch.", "model_name");
+  cell->GetCell(model_name)->Set(allowed_batch_sizes);
+}
+
 const string& GetModelName(OpKernelContext* ctx) {
   static string* kModelNameUnset = new string("model_name_unset");
   if (!ctx->session_metadata()) return *kModelNameUnset;
@@ -132,6 +166,14 @@ Status BatchResourceBase::RegisterInput(
     batch_components->inputs.push_back(tensor);
   }
   RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context));
+  RecordBatchParamBatchTimeoutMicros(
+      batcher_queue_options_.batch_timeout_micros, GetModelName(context));
+  RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size,
+                               GetModelName(context));
+  RecordBatchParamMaxEnqueuedBatches(
+      batcher_queue_options_.max_enqueued_batches, GetModelName(context));
+  RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
+                                    GetModelName(context));
   OpInputList captured_tensors;
   const auto captured_status =
       context->input_list("captured_tensors", &captured_tensors);
@@ -162,7 +204,6 @@ BatchResourceBase::GetBatcherQueueOptions(
   batcher_queue_options.input_batch_size_limit = max_batch_size;
   batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
   batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
-  // Support for splitting large batch is still in progress.
   batcher_queue_options.enable_large_batch_splitting =
       enable_large_batch_splitting;
   if (enable_large_batch_splitting) {
@@ -185,6 +226,28 @@ BatchResourceBase::GetBatcherQueueOptions(
   return batcher_queue_options;
 }
 
+/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
+BatchResourceBase::GetAdaptiveBatcherQueueOptions(
+    int32 max_batch_size, int32 batch_timeout_micros,
+    int32 max_enqueued_batches, bool enable_large_batch_splitting) {
+  AdaptiveBatcherT::QueueOptions batcher_queue_options;
+  batcher_queue_options.max_batch_size = max_batch_size;
+  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
+  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
+
+  if (enable_large_batch_splitting) {
+    batcher_queue_options.split_input_task_func =
+        [](std::unique_ptr<BatchTask>* input_task,
+           int open_batch_remaining_slot, int max_batch_size,
+           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
+      return SplitInputTask(input_task, open_batch_remaining_slot,
+                            max_batch_size, output_tasks);
+    };
+  }
+
+  return batcher_queue_options;
+}
+
 /*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
   for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
     const BatchResourceBase::BatchTask& task = batch.task(task_idx);
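The four gauges above all follow the same static-cell pattern; a hedged sketch (the metric name below is illustrative, not a real TensorFlow metric) of adding one more per-model batching-parameter gauge:

```cpp
#include <string>

#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/types.h"

// Hypothetical per-model gauge following the RecordBatchParam* pattern above:
// the cell is created once and updated on every call.
void RecordBatchParamExample(tensorflow::int64 value,
                             const std::string& model_name) {
  static auto* cell = tensorflow::monitoring::Gauge<tensorflow::int64, 1>::New(
      "/tensorflow/serving/batching/example_param",
      "Tracks an example batching parameter.", "model_name");
  cell->GetCell(model_name)->Set(value);
}
```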
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index 89391f2defe..ea8c7729c89 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -18,10 +18,12 @@ limitations under the License.
 
 #include <map>
 
+#include "absl/strings/str_join.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
@@ -48,7 +50,7 @@ class BatchResourceBase : public ResourceBase {
                        const string& batcher_queue_name,
                        AsyncOpKernel::DoneCallback done_callback);
 
- protected:
+ public:
   // One task to be batched, corresponds to a `slice` of input from one batch-op
   // invocation.
   //
@@ -106,6 +108,8 @@ class BatchResourceBase : public ResourceBase {
   // tensorflow::serving namespace, because some versions of compiler complain
   // about changing meaning of the symbols.
   using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
+  using AdaptiveBatcherT =
+      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
   using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
   using BatchT = Batch<BatchResourceBase::BatchTask>;
 
@@ -116,6 +120,17 @@ class BatchResourceBase : public ResourceBase {
       : has_process_batch_function_(has_process_batch_function),
         batcher_(std::move(batcher)),
         batcher_queue_options_(batcher_queue_options),
+        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {
+    allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
+  }
+
+  BatchResourceBase(bool has_process_batch_function,
+                    std::shared_ptr<AdaptiveBatcherT> batcher,
+                    const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
+                    std::vector<int32> allowed_batch_sizes)
+      : has_process_batch_function_(has_process_batch_function),
+        adaptive_batcher_(std::move(batcher)),
+        adaptive_batcher_queue_options_(batcher_queue_options),
         allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}
 
   static BatcherT::QueueOptions GetBatcherQueueOptions(
@@ -123,6 +138,10 @@ class BatchResourceBase : public ResourceBase {
       int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
       bool enable_large_batch_splitting);
 
+  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
+      int32 max_batch_size, int32 batch_timeout_micros,
+      int32 max_enqueued_batches, bool enable_large_batch_splitting);
+
  private:
   // Implementation of calling the process batch function.
   virtual void ProcessFuncBatchImpl(
@@ -196,6 +215,10 @@ class BatchResourceBase : public ResourceBase {
   std::shared_ptr<BatcherT> batcher_;
   BatcherT::QueueOptions batcher_queue_options_;
 
+  // A batch scheduler, and options for creating queues.
+  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
+  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;
+
   // A collection of batcher queues, keyed on queue name.
   // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
   // ones (with a time delay?); it's okay if they get recreated later).
@@ -204,6 +227,9 @@ class BatchResourceBase : public ResourceBase {
       TF_GUARDED_BY(batcher_queues_mu_);
 
   std::vector<int32> allowed_batch_sizes_;
+  // A concatenated string of <allowed_batch_sizes_>, separated by ",". This
+  // is used to record the batching parameters.
+  string allowed_batch_sizes_str_;
 };
 
 }  // namespace serving
diff --git a/tensorflow/core/kernels/batching_util/concat_split_util.h b/tensorflow/core/kernels/batching_util/concat_split_util.h
index 77c4463f118..914c793bc89 100644
--- a/tensorflow/core/kernels/batching_util/concat_split_util.h
+++ b/tensorflow/core/kernels/batching_util/concat_split_util.h
@@ -71,8 +71,10 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
 
   TensorShape output_shape(input_shape);
   output_shape.set_dim(0, output_dim0);
-  TF_RETURN_IF_ERROR(
-      context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
+                                            output_shape, output, attr));
   if (output->NumElements() > 0) {
     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
@@ -167,8 +169,10 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input,
     TensorShape output_shape = input.shape();
     output_shape.set_dim(0, size);
     Tensor output;
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
     TF_RETURN_IF_ERROR(
-        context->allocate_temp(input.dtype(), output_shape, &output));
+        context->allocate_temp(input.dtype(), output_shape, &output, attr));
     auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});
 
     Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
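A hedged sketch (hypothetical helper, fixed DT_FLOAT for brevity) of the host-pinned allocation pattern this change adopts in both `Concat` and `SplitCPU`:

```cpp
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

// Allocates a temporary tensor in host memory. Concat/split run on the CPU,
// so with the new DEVICE_GPU BatchFunction registration these temporaries
// must not land in device memory.
tensorflow::Status AllocateHostTemp(tensorflow::OpKernelContext* context,
                                    const tensorflow::TensorShape& shape,
                                    tensorflow::Tensor* output) {
  tensorflow::AllocatorAttributes attr;
  attr.set_on_host(true);  // force host placement even on GPU devices
  return context->allocate_temp(tensorflow::DT_FLOAT, shape, output, attr);
}
```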
diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc
index bda9c19e3c2..fb76ec24df0 100644
--- a/tensorflow/core/kernels/cwise_op_imag.cc
+++ b/tensorflow/core/kernels/cwise_op_imag.cc
@@ -27,9 +27,12 @@ REGISTER_COMPLEX(CPU, float, complex64);
 REGISTER_COMPLEX(CPU, double, complex128);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
 REGISTER_COMPLEX(GPU, float, complex64);
 REGISTER_COMPLEX(GPU, double, complex128);
 #endif
+#endif
 
 #undef REGISTER_COMPLEX
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc
index 453f2801132..cb8484802f0 100644
--- a/tensorflow/core/kernels/cwise_op_real.cc
+++ b/tensorflow/core/kernels/cwise_op_real.cc
@@ -28,9 +28,12 @@ REGISTER_COMPLEX(CPU, float, complex64);
 REGISTER_COMPLEX(CPU, double, complex128);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
 REGISTER_COMPLEX(GPU, float, complex64);
 REGISTER_COMPLEX(GPU, double, complex128);
 #endif
+#endif
 
 #undef REGISTER_COMPLEX
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc
index d3e8f3b605c..3689f8b7399 100644
--- a/tensorflow/core/kernels/cwise_op_sin.cc
+++ b/tensorflow/core/kernels/cwise_op_sin.cc
@@ -20,7 +20,10 @@ REGISTER6(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, bfloat16,
           double, complex64, complex128);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
 REGISTER3(UnaryOp, GPU, "Sin", functor::sin, float, Eigen::half, double);
 #endif
+#endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index 47f0f588991..93c2f57033e 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/snappy.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 
@@ -60,13 +61,8 @@ namespace data {
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
 
 namespace {
-// Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for
-// the current attempt to complete and perform no more retries.
-const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60;  // 60 minutes.
-
 // Default interval between task list refreshes.
 const int64 kDefaultTaskRefreshIntervalMs = 1000;  // 1 second.
-
 }  // namespace
 
 // Dataset for reading data from the tf.data service non-deterministically.
@@ -224,30 +220,23 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
           &deregister_fn_));
       dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
           dataset()->address_, dataset()->protocol_);
-      int64 deadline_micros = ctx->env()->NowMicros() + kRetryTimeoutMicros;
-      if (dataset()->job_name_.empty()) {
-        TF_RETURN_IF_ERROR(grpc_util::Retry(
-            [&]() {
-              return dispatcher_->CreateJob(dataset()->dataset_id_,
-                                            dataset()->processing_mode_,
-                                            job_client_id_);
-            },
-            /*description=*/
-            strings::StrCat("create job with dispatcher at ",
-                            dataset()->address_),
-            deadline_micros));
-      } else {
-        TF_RETURN_IF_ERROR(grpc_util::Retry(
-            [&]() {
-              return dispatcher_->GetOrCreateJob(
-                  dataset()->dataset_id_, dataset()->processing_mode_,
-                  dataset()->job_name_, iterator_index_, job_client_id_);
-            },
-            /*description=*/
-            strings::StrCat("get or create job with dispatcher at ",
-                            dataset()->address_),
-            deadline_micros));
+      int64 deadline_micros = kint64max;
+      absl::optional<JobKey> key;
+      if (!dataset()->job_name_.empty()) {
+        key.emplace();
+        key.value().set_job_name(std::string(dataset()->job_name_));
+        key.value().set_job_name_index(iterator_index_);
       }
+      TF_RETURN_IF_ERROR(grpc_util::Retry(
+          [&]() {
+            return dispatcher_->GetOrCreateJob(dataset()->dataset_id_,
+                                               dataset()->processing_mode_, key,
+                                               job_client_id_);
+          },
+          /*description=*/
+          strings::StrCat("get or create job with dispatcher at ",
+                          dataset()->address_),
+          deadline_micros));
       initialized_ = true;
       VLOG(1) << "Created data service job with id " << job_client_id_;
       return Status::OK();
@@ -496,8 +485,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
           DCHECK(task_to_process != nullptr);
           VLOG(3) << "Processing task " << task_to_process->task_id;
         }
-        int64 deadline_micros =
-            Env::Default()->NowMicros() + kRetryTimeoutMicros;
+        int64 deadline_micros = kint64max;
         Status s = GetElement(task_to_process.get(), deadline_micros);
         if (!s.ok()) {
           mutex_lock l(mu_);
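Since the fixed 60-minute retry timeout is gone, callers now pass `kint64max` so transient dispatcher errors are retried indefinitely. A hedged sketch (hypothetical wrapper) of the `grpc_util::Retry` call shape:

```cpp
#include <functional>

#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/platform/types.h"

// Keeps retrying `rpc` until it succeeds or hits a permanent failure;
// kint64max removes the overall deadline.
tensorflow::Status RetryForever(
    const std::function<tensorflow::Status()>& rpc) {
  return tensorflow::data::grpc_util::Retry(
      rpc, /*description=*/"contact tf.data service dispatcher",
      /*deadline_micros=*/tensorflow::kint64max);
}
```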
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index bfa39a71bd9..53ff074438d 100644
--- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -94,6 +94,14 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
     return kUnknownCardinality;
   }
 
+  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
+    inputs->push_back(selector_input_);
+    for (const auto& data_input : data_inputs_) {
+      inputs->push_back(data_input);
+    }
+    return Status::OK();
+  }
+
   Status CheckExternalState() const override {
     for (const auto& input : data_inputs_) {
       TF_RETURN_IF_ERROR(input->CheckExternalState());
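
The new `InputDatasets` override exposes this dataset's inputs (the selector input first, then the data inputs) so that framework utilities can walk the dataset graph. A hypothetical sketch of such a walker, using a simplified interface in place of `DatasetBase` (the real method also returns a `Status`):

```cpp
#include <vector>

// Simplified stand-in for the DatasetBase interface used in this sketch.
class DatasetLike {
 public:
  virtual ~DatasetLike() = default;
  virtual void InputDatasets(std::vector<const DatasetLike*>* inputs) const = 0;
};

// Collects `root` and every transitive input dataset in pre-order.
void VisitAll(const DatasetLike* root,
              std::vector<const DatasetLike*>* visited) {
  visited->push_back(root);
  std::vector<const DatasetLike*> inputs;
  root->InputDatasets(&inputs);
  for (const DatasetLike* input : inputs) VisitAll(input, visited);
}
```
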
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 0606ca24da0..6c8622274d5 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -85,7 +85,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     // clang-format off
     absl::flat_hash_map<string, uint64> live_experiments = {
         {"enable_gradient_descent", 0},
-        {"map_parallelization", 50}
+        {"map_parallelization", 100}
     };
     // clang-format on
     auto hash_func = [](const string& str) { return Hash64(str); };
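
This bumps the `map_parallelization` experiment from a 50% rollout to 100%. A hedged sketch of how a percentage bucket like this is typically applied; the salt and the use of `std::hash` are illustrative assumptions, not the actual tf.data hashing scheme:

```cpp
#include <cstdint>
#include <functional>
#include <string>

// Maps (experiment, salt) to a stable bucket in [0, 100) and opts in when
// the bucket falls below the rollout percentage.
bool InExperiment(const std::string& experiment, uint64_t rollout_pct,
                  const std::string& job_salt) {
  uint64_t bucket = std::hash<std::string>{}(experiment + job_salt) % 100;
  return bucket < rollout_pct;
}
// With {"map_parallelization", 100}, every bucket 0..99 is below the
// threshold, so the experiment is enabled for all jobs rather than half.
```
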
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index a71f2902559..5607d41b178 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -41,6 +41,8 @@ class FuzzParseTensor : public FuzzSession {
     // remainder of the fuzzer testing. Of course, this duplicates some work
     // but it's better than repeating the investigation whenever Autofuzz
     // detects another similar OOM.
+    // After adding `-fsanitize=null` to ASAN (cl/317376103), the memory
+    // footprint increased, so we lowered the maximum number of elements
+    // to 2^18.
     string as_string = string(reinterpret_cast<const char*>(data), size);
     TensorProto proto;
     if (!ParseProtoUnlimited(&proto, as_string)) {
@@ -53,7 +55,7 @@ class FuzzParseTensor : public FuzzSession {
     }
     TensorShape shape(proto.tensor_shape());
     const int64 num_elements = shape.num_elements();
-    const int64 max_num_elements = 1 << 20;
+    const int64 max_num_elements = 1 << 18;
     if (num_elements > max_num_elements) {
       LOG(WARNING) << "Requiring a tensor with too many elements\n";
       return;
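
For illustration, the guard pattern the harness relies on, with the lowered cap; `ParseNumElements` is a stub standing in for the proto parsing and shape computation above:

```cpp
#include <cstddef>
#include <cstdint>

// Stub: pretend the first 8 bytes of the input encode an element count.
static uint64_t ParseNumElements(const uint8_t* data, size_t size) {
  if (size < 8) return UINT64_MAX;
  uint64_t n = 0;
  for (int i = 0; i < 8; ++i) n = (n << 8) | data[i];
  return n;
}

extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
  constexpr uint64_t kMaxNumElements = uint64_t{1} << 18;  // Lowered cap.
  if (ParseNumElements(data, size) > kMaxNumElements) return 0;
  // ... run the op under test on the parsed tensor ...
  return 0;
}
```
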
diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD
index c9789ea2421..b26716d19f2 100644
--- a/tensorflow/core/kernels/mlir_generated/BUILD
+++ b/tensorflow/core/kernels/mlir_generated/BUILD
@@ -52,10 +52,13 @@ filegroup(
             "unranked_op_gpu_cos.cc",
             "unranked_op_gpu_exp.cc",
             "unranked_op_gpu_floor.cc",
+            "unranked_op_gpu_imag.cc",
             "unranked_op_gpu_log.cc",
             "unranked_op_gpu_logical_not.cc",
+            "unranked_op_gpu_real.cc",
             "unranked_op_gpu_rsqrt.cc",
             "unranked_op_gpu_sign.cc",
+            "unranked_op_gpu_sin.cc",
             "unranked_op_gpu_sqrt.cc",
             "unranked_op_gpu_tanh.cc",
         ],
@@ -104,10 +107,13 @@ tf_kernel_library(
             ":cos_unranked_kernels",
             ":exp_unranked_kernels",
             ":floor_unranked_kernels",
+            ":imag_unranked_kernels",
             ":log_unranked_kernels",
             ":logical_not_unranked_kernels",
+            ":real_unranked_kernels",
             ":rsqrt_unranked_kernels",
             ":sign_unranked_kernels",
+            ":sin_unranked_kernels",
             ":sqrt_unranked_kernels",
             ":tanh_unranked_kernels",
             ":unranked_op_gpu_base",
@@ -116,19 +122,18 @@ tf_kernel_library(
     ),
 )
 
-# TODO(herhut): Uncomment once unranked kernels build again.
-# tf_kernel_library(
-#     name = "cwise_binary_op",
-#     srcs = ["unranked_gpu_add.cc"],
-#     tags = [
-#         "manual",
-#     ],
-#     deps = [
-#         ":add_v2_unranked_kernels",
-#         ":unranked_op_gpu_base",
-#         "//third_party/eigen3",
-#     ],
-# )
+tf_kernel_library(
+    name = "cwise_binary_op",
+    srcs = ["unranked_gpu_add.cc"],
+    tags = [
+        "manual",
+    ],
+    deps = [
+        ":add_v2_unranked_kernels",
+        ":unranked_op_gpu_base",
+        "//third_party/eigen3",
+    ],
+)
 
 tf_kernel_library(
     name = "cwise_op",
@@ -142,8 +147,7 @@ tf_kernel_library(
         ":cwise_unary_op",
     ]) + if_mlir_unranked_kernels_enabled(
         [
-            # TODO(herhut): Uncomment once it builds again.
-            # ":cwise_binary_op",
+            ":cwise_binary_op",
         ],
     ),
 )
@@ -169,27 +173,27 @@ tf_cuda_cc_test(
     ],
 )
 
-# TODO(herhut): Uncomment once unranked kernels build again.
-# tf_cuda_cc_test(
-#     name = "gpu_add_test",
-#     size = "small",
-#     srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_add_test.cc"]),
-#     tags = tf_cuda_tests_tags() + [
-#         "no_cuda_asan",  # b/173033461
-#     ],
-#     deps = [
-#         "//tensorflow/core:framework",
-#         "//tensorflow/core:framework_internal",
-#         "//tensorflow/core:tensorflow",
-#         "//tensorflow/core:test",
-#         "//tensorflow/core:test_main",
-#         "//tensorflow/core:testlib",
-#         "//tensorflow/core/common_runtime:device",
-#         "//tensorflow/core/common_runtime:device_factory",
-#         "//tensorflow/core/kernels:cwise_op",
-#         "//tensorflow/core/kernels:ops_testutil",
-#     ],
-# )
+tf_cuda_cc_test(
+    name = "gpu_add_test",
+    size = "small",
+    srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_add_test.cc"]),
+    tags = tf_cuda_tests_tags() + [
+        "no_cuda_asan",  # b/173033461
+    ],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/common_runtime:device",
+        "//tensorflow/core/common_runtime:device_factory",
+        "//tensorflow/core/kernels:cwise_op",
+        "//tensorflow/core/kernels:ops_testutil",
+    ],
+)
+
 # TODO(b/160731748): Re-enable when it works again.
 # gen_kernel_library(
 #     name = "bias_add",
@@ -238,6 +242,7 @@ gen_kernel_library(
 
 gen_kernel_library(
     name = "imag",
+    generate_unranked = True,
     tile_size = "256",
     types = [
         "f32",
@@ -269,6 +274,7 @@ gen_kernel_library(
 
 gen_kernel_library(
     name = "real",
+    generate_unranked = True,
     tile_size = "256",
     types = [
         "f32",
@@ -292,20 +298,19 @@ gen_kernel_library(
     unroll_factors = "4",
 )
 
-# TODO(herhut): Uncomment once it builds again.
-# gen_kernel_library(
-#     name = "add_v2",
-#     generate_ranked = False,
-#     generate_unranked = True,
-#     tile_size = "256",
-#     types = [
-#         "f16",
-#         "f32",
-#         "f64",
-#         "i64",
-#     ],
-#     unroll_factors = "4",
-# )
+gen_kernel_library(
+    name = "add_v2",
+    generate_ranked = False,
+    generate_unranked = True,
+    tile_size = "256,1,1",
+    types = [
+        "f16",
+        "f32",
+        "f64",
+        "i64",
+    ],
+    unroll_factors = "4",
+)
 
 [
     gen_kernel_library(
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc
index 42806ee807a..97e55971a02 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc
@@ -64,6 +64,36 @@ class GpuAddTest : public OpsTestBase {
     test::ExpectEqual(expected_tensor, *GetOutput(0));
   }
 
+  template <typename T, typename BaselineType = T>
+  void TestBroadcastingExpandAddOp() {
+    auto input_1 = {static_cast<T>(10)};
+    auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3),
+                    static_cast<T>(4), static_cast<T>(5), static_cast<T>(6)};
+    std::vector<T> expected{
+        static_cast<T>(11), static_cast<T>(12), static_cast<T>(13),
+        static_cast<T>(14), static_cast<T>(15), static_cast<T>(16),
+    };
+    auto expected_shape = TensorShape({6});
+    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({1}), input_2,
+                                        TensorShape({6}), expected,
+                                        expected_shape);
+  }
+
+  template <typename T, typename BaselineType = T>
+  void TestBroadcastingInDimAddOp() {
+    auto input_1 = {static_cast<T>(10), static_cast<T>(20), static_cast<T>(30)};
+    auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3),
+                    static_cast<T>(4), static_cast<T>(5), static_cast<T>(6)};
+    std::vector<T> expected{
+        static_cast<T>(11), static_cast<T>(22), static_cast<T>(33),
+        static_cast<T>(14), static_cast<T>(25), static_cast<T>(36),
+    };
+    auto expected_shape = TensorShape({2, 3});
+    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({3}), input_2,
+                                        TensorShape({2, 3}), expected,
+                                        expected_shape);
+  }
+
   template <typename T, typename BaselineType = T>
   void TestBroadcastingAddOp() {
     auto input_1 = {static_cast<T>(10), static_cast<T>(20)};
@@ -104,6 +134,52 @@ class GpuAddTest : public OpsTestBase {
                                         TensorShape{2, 3});
   }
 
+  template <typename T, typename BaselineType = T>
+  void TestEqualShapesAddOp() {
+    auto input_1 = {
+        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
+        static_cast<T>(-0.1),
+        static_cast<T>(-0.0),
+        static_cast<T>(0.0),
+        static_cast<T>(0.1),
+        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
+    auto input_2 = {
+        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
+        static_cast<T>(-0.1),
+        static_cast<T>(-0.0),
+        static_cast<T>(0.0),
+        static_cast<T>(0.1),
+        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
+    std::vector<T> expected;
+    for (const T& inp : input_2) {
+      expected.push_back(static_cast<T>(static_cast<BaselineType>(inp) +
+                                        static_cast<BaselineType>(inp)));
+    }
+    RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape{2, 3}, input_2,
+                                        TensorShape{2, 3}, expected,
+                                        TensorShape{2, 3});
+  }
+
+  template <typename T, typename BaselineType = T>
+  void TestOneIsScalarAddOp() {
+    auto input_1 = static_cast<T>(42);
+    auto input_2 = {
+        static_cast<T>(-std::numeric_limits<BaselineType>::infinity()),
+        static_cast<T>(-0.1),
+        static_cast<T>(-0.0),
+        static_cast<T>(0.0),
+        static_cast<T>(0.1),
+        static_cast<T>(std::numeric_limits<BaselineType>::infinity())};
+    std::vector<T> expected;
+    for (const T& inp : input_2) {
+      expected.push_back(static_cast<T>(static_cast<BaselineType>(input_1) +
+                                        static_cast<BaselineType>(inp)));
+    }
+    RunAndCompareAddOp<T, BaselineType>({input_1}, TensorShape{}, input_2,
+                                        TensorShape{2, 3}, expected,
+                                        TensorShape{2, 3});
+  }
+
   template <typename T, typename RT = T>
   void TestIncompatibleShapes() {
     auto input_1 = {static_cast<T>(-0.1), static_cast<T>(-0.0),
@@ -120,12 +196,51 @@ class GpuAddTest : public OpsTestBase {
 TEST_F(GpuAddTest, AddFloat) { RunAddOp<float>(); }
 TEST_F(GpuAddTest, AddDouble) { RunAddOp<double>(); }
 TEST_F(GpuAddTest, AddHalf) { RunAddOp<Eigen::half, float>(); }
+TEST_F(GpuAddTest, AddInt64) { RunAddOp<int64, int64>(); }
+
+TEST_F(GpuAddTest, AddEqShapesFloat) { TestEqualShapesAddOp<float>(); }
+TEST_F(GpuAddTest, AddEqShapesDouble) { TestEqualShapesAddOp<double>(); }
+TEST_F(GpuAddTest, AddEqShapesHalf) {
+  TestEqualShapesAddOp<Eigen::half, float>();
+}
+TEST_F(GpuAddTest, AddEqShapesInt64) { TestEqualShapesAddOp<int64>(); }
+
+TEST_F(GpuAddTest, AddScalarFloat) { TestOneIsScalarAddOp<float>(); }
+TEST_F(GpuAddTest, AddScalarDouble) { TestOneIsScalarAddOp<double>(); }
+TEST_F(GpuAddTest, AddScalarHalf) {
+  TestOneIsScalarAddOp<Eigen::half, float>();
+}
+TEST_F(GpuAddTest, AddScalarInt64) { TestOneIsScalarAddOp<int64>(); }
+
+TEST_F(GpuAddTest, BCastExpandAddFloat) {
+  TestBroadcastingExpandAddOp<float>();
+}
+TEST_F(GpuAddTest, BCastExpandAddDouble) {
+  TestBroadcastingExpandAddOp<double>();
+}
+TEST_F(GpuAddTest, BCastExpandAddHalf) {
+  TestBroadcastingExpandAddOp<Eigen::half, float>();
+}
+TEST_F(GpuAddTest, BCastExpandAddInt64) {
+  TestBroadcastingExpandAddOp<int64>();
+}
+
+TEST_F(GpuAddTest, BCastInDimAddFloat) { TestBroadcastingInDimAddOp<float>(); }
+TEST_F(GpuAddTest, BCastInDimAddDouble) {
+  TestBroadcastingInDimAddOp<double>();
+}
+TEST_F(GpuAddTest, BCastInDimAddHalf) {
+  TestBroadcastingInDimAddOp<Eigen::half, float>();
+}
+TEST_F(GpuAddTest, BCastInDimAddInt64) { TestBroadcastingInDimAddOp<int64>(); }
+
 TEST_F(GpuAddTest, BCastAddFloat) { TestBroadcastingAddOp<float>(); }
 TEST_F(GpuAddTest, BCastAddDouble) { TestBroadcastingAddOp<double>(); }
 TEST_F(GpuAddTest, BCastAddHalf) {
   TestBroadcastingAddOp<Eigen::half, float>();
 }
 TEST_F(GpuAddTest, BCastAddInt64) { TestBroadcastingAddOp<int64>(); }
+
 TEST_F(GpuAddTest, IncompatibleShapes) { TestIncompatibleShapes<float>(); }
 
 // TEST_F(GpuAddTest, AddV2Half) { RunAddOp<Eigen::half, float>(); }
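
As a reference for the new broadcasting tests, a plain CPU sketch of the row-broadcast case that `TestBroadcastingInDimAddOp` checks, assuming ordinary NumPy-style broadcasting semantics:

```cpp
#include <cassert>
#include <vector>

// Broadcasts `lhs` (shape [cols]) across each row of `rhs` (shape
// [rows, cols]) and adds elementwise.
std::vector<float> BroadcastRowAdd(const std::vector<float>& lhs,
                                   const std::vector<float>& rhs, int rows,
                                   int cols) {
  assert(static_cast<int>(lhs.size()) == cols);
  assert(static_cast<int>(rhs.size()) == rows * cols);
  std::vector<float> out(rhs.size());
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      out[r * cols + c] = lhs[c] + rhs[r * cols + c];
  return out;
}
// For lhs = {10, 20, 30} and rhs = {1, 2, 3, 4, 5, 6} with rows = 2,
// cols = 3 this yields {11, 22, 33, 14, 25, 36}, matching the expected
// values in the test above.
```
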
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
index a7541032ec3..a878a9d2431 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
@@ -14,10 +14,12 @@ limitations under the License.
 ==============================================================================*/
 
 #include <cmath>
+#include <complex>
 #include <functional>
 #include <initializer_list>
 #include <memory>
 #include <numeric>
+#include <type_traits>
 #include <vector>
 
 #include "tensorflow/core/common_runtime/device.h"
@@ -42,32 +44,39 @@ class GpuUnaryOpTest : public OpsTestBase {
     SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
   }
 
-  template <typename T, typename RT = T>
+  // 'T' is the input type, 'RT' is the input type for the callback function,
+  // 'OutT' is the output type, and 'ROutT' is the output type for the
+  // callback function. In most cases it suffices to provide only the input
+  // type, because all four types are the same.
+  template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
   void Run(std::vector<int64> input_shape, std::vector<T> input,
-           const std::string op_name, RT (*expected_callback)(RT),
+           const std::string op_name, ROutT (*expected_callback)(RT),
            bool expect_equal = true) {
     assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
                            std::multiplies<int64>()) == input.size() &&
            "Expected input length to equal to shape's number of elements.");
 
     TensorShape shape(input_shape);
-    TF_ASSERT_OK(NodeDefBuilder("some_name", op_name)
-                     .Input(FakeInput(DataTypeToEnum<T>::v()))
-                     .Attr("T", DataTypeToEnum<T>::v())
-                     .Finalize(node_def()));
+    NodeDefBuilder builder("some_name", op_name);
+    builder.Input(FakeInput(DataTypeToEnum<T>::v()))
+        .Attr("T", DataTypeToEnum<T>::v());
+    if (!std::is_same_v<OutT, T>) {
+      builder.Attr("Tout", DataTypeToEnum<OutT>::v());
+    }
+    TF_ASSERT_OK(builder.Finalize(node_def()));
 
     TF_ASSERT_OK(InitOp());
     AddInputFromArray<T>(shape, input);
     TF_ASSERT_OK(RunOpKernel());
 
-    Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
-    std::vector<T> expected;
+    Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
+    std::vector<OutT> expected;
     expected.reserve(input.size());
     for (const T& inp : input) {
       expected.push_back(
-          static_cast<T>(expected_callback(static_cast<RT>(inp))));
+          static_cast<OutT>(expected_callback(static_cast<RT>(inp))));
     }
-    test::FillValues<T>(&expected_tensor, expected);
+    test::FillValues<OutT>(&expected_tensor, expected);
     if (expect_equal) {
       test::ExpectEqual(expected_tensor, *GetOutput(0));
     } else {
@@ -85,6 +94,16 @@ class GpuUnaryOpTest : public OpsTestBase {
                              0.5, 0.7, 0.9, 9.0, 18.0});
   }
 
+  template <typename T>
+  std::vector<std::complex<T>> DefaultComplexInput() {
+    auto input = DefaultInput<T>();
+    std::vector<std::complex<T>> complex_input;
+    for (T value : input) {
+      complex_input.emplace_back(value, -value);
+    }
+    return complex_input;
+  }
+
   template <typename T>
   std::vector<T> DefaultInputGreaterThanZero() {
     return InputAsVector<T>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
@@ -109,52 +128,6 @@ class GpuUnaryOpTest : public OpsTestBase {
   }
 };
 
-/// Test `tf.Tanh`.
-
-TEST_F(GpuUnaryOpTest, TanhFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Tanh",
-             /*expected_callback=*/std::tanh,
-             /*expect_equal=*/false);
-}
-
-TEST_F(GpuUnaryOpTest, TanhDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Tanh",
-              /*expected_callback=*/std::tanh,
-              /*expect_equal=*/false);
-}
-
-TEST_F(GpuUnaryOpTest, TanhHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Tanh",
-                          /*expected_callback=*/std::tanh,
-                          /*expect_equal=*/false);
-}
-
-/// Test `tf.Ceil`.
-
-TEST_F(GpuUnaryOpTest, CeilFloat) {
-  Run<float>(DefaultInputShape(), DefaultInput<float>(),
-             /*op_name=*/"Ceil",
-             /*expected_callback=*/std::ceil,
-             /*expect_equal=*/true);
-}
-
-TEST_F(GpuUnaryOpTest, CeilDouble) {
-  Run<double>(DefaultInputShape(), DefaultInput<double>(),
-              /*op_name=*/"Ceil",
-              /*expected_callback=*/std::ceil,
-              /*expect_equal=*/true);
-}
-
-TEST_F(GpuUnaryOpTest, CeilHalf) {
-  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
-                          /*op_name=*/"Ceil",
-                          /*expected_callback=*/std::ceil,
-                          /*expect_equal=*/true);
-}
-
 /// Test `tf.Abs`.
 
 TEST_F(GpuUnaryOpTest, AbsFloat) {
@@ -214,6 +187,29 @@ TEST_F(GpuUnaryOpTest, AbsInt64) {
       /*expect_equal=*/true);
 }
 
+/// Test `tf.Ceil`.
+
+TEST_F(GpuUnaryOpTest, CeilFloat) {
+  Run<float>(DefaultInputShape(), DefaultInput<float>(),
+             /*op_name=*/"Ceil",
+             /*expected_callback=*/std::ceil,
+             /*expect_equal=*/true);
+}
+
+TEST_F(GpuUnaryOpTest, CeilDouble) {
+  Run<double>(DefaultInputShape(), DefaultInput<double>(),
+              /*op_name=*/"Ceil",
+              /*expected_callback=*/std::ceil,
+              /*expect_equal=*/true);
+}
+
+TEST_F(GpuUnaryOpTest, CeilHalf) {
+  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
+                          /*op_name=*/"Ceil",
+                          /*expected_callback=*/std::ceil,
+                          /*expect_equal=*/true);
+}
+
 /// Test `tf.Cos`.
 
 TEST_F(GpuUnaryOpTest, CosFloat) {
@@ -283,6 +279,24 @@ TEST_F(GpuUnaryOpTest, FloorHalf) {
                           /*expect_equal=*/true);
 }
 
+/// Test `tf.Imag`.
+
+TEST_F(GpuUnaryOpTest, ImagFloat) {
+  Run<std::complex<float>, const std::complex<float>&, float, float>(
+      DefaultInputShape(), DefaultComplexInput<float>(),
+      /*op_name=*/"Imag",
+      /*expected_callback=*/std::imag,
+      /*expect_equal=*/false);
+}
+
+TEST_F(GpuUnaryOpTest, ImagDouble) {
+  Run<std::complex<double>, const std::complex<double>&, double, double>(
+      DefaultInputShape(), DefaultComplexInput<double>(),
+      /*op_name=*/"Imag",
+      /*expected_callback=*/std::imag,
+      /*expect_equal=*/false);
+}
+
 /// Test `tf.Log`.
 
 TEST_F(GpuUnaryOpTest, LogFloat) {
@@ -338,6 +352,24 @@ TEST_F(GpuUnaryOpTest, NegHalf) {
                           /*expect_equal=*/true);
 }
 
+/// Test `tf.Real`.
+
+TEST_F(GpuUnaryOpTest, RealFloat) {
+  Run<std::complex<float>, const std::complex<float>&, float, float>(
+      DefaultInputShape(), DefaultComplexInput<float>(),
+      /*op_name=*/"Real",
+      /*expected_callback=*/std::real,
+      /*expect_equal=*/false);
+}
+
+TEST_F(GpuUnaryOpTest, RealDouble) {
+  Run<std::complex<double>, const std::complex<double>&, double, double>(
+      DefaultInputShape(), DefaultComplexInput<double>(),
+      /*op_name=*/"Real",
+      /*expected_callback=*/std::real,
+      /*expect_equal=*/false);
+}
+
 /// Test `tf.Rsqrt`.
 
 /// Reference implementation.
@@ -369,6 +401,39 @@ TEST_F(GpuUnaryOpTest, RsqrtHalf) {
                           /*expect_equal=*/false);
 }
 
+/// Test `tf.Sign`.
+
+// Reference implementation
+template <typename T>
+T expected_sign(T x) {
+  if (x == 0) return 0;
+  if (x < 0) return -1;
+  return 1;
+}
+
+// TODO(b/162577610): Enable these tests when our generated kernels handle 0.0
+// and -0.0 correctly.
+TEST_F(GpuUnaryOpTest, DISABLED_SignFloat) {
+  Run<float>(DefaultInputShape(), DefaultInput<float>(),
+             /*op_name=*/"Sign",
+             /*expected_callback=*/expected_sign,
+             /*expect_equal=*/true);
+}
+
+TEST_F(GpuUnaryOpTest, DISABLED_SignDouble) {
+  Run<double>(DefaultInputShape(), DefaultInput<double>(),
+              /*op_name=*/"Sign",
+              /*expected_callback=*/expected_sign,
+              /*expect_equal=*/true);
+}
+
+TEST_F(GpuUnaryOpTest, DISABLED_SignHalf) {
+  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
+                          /*op_name=*/"Sign",
+                          /*expected_callback=*/expected_sign,
+                          /*expect_equal=*/true);
+}
+
 /// Test `tf.Sin`.
 
 TEST_F(GpuUnaryOpTest, SinFloat) {
@@ -416,5 +481,28 @@ TEST_F(GpuUnaryOpTest, SqrtHalf) {
                           /*expect_equal=*/false);
 }
 
+/// Test `tf.Tanh`.
+
+TEST_F(GpuUnaryOpTest, TanhFloat) {
+  Run<float>(DefaultInputShape(), DefaultInput<float>(),
+             /*op_name=*/"Tanh",
+             /*expected_callback=*/std::tanh,
+             /*expect_equal=*/false);
+}
+
+TEST_F(GpuUnaryOpTest, TanhDouble) {
+  Run<double>(DefaultInputShape(), DefaultInput<double>(),
+              /*op_name=*/"Tanh",
+              /*expected_callback=*/std::tanh,
+              /*expect_equal=*/false);
+}
+
+TEST_F(GpuUnaryOpTest, TanhHalf) {
+  Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
+                          /*op_name=*/"Tanh",
+                          /*expected_callback=*/std::tanh,
+                          /*expect_equal=*/false);
+}
+
 }  // namespace
 }  // end namespace tensorflow
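
To make the four new `Run<>` template parameters concrete, a stand-alone analogue of the callback plumbing (`ApplyCallback` is illustrative, not part of the test fixture):

```cpp
#include <complex>
#include <vector>

// T: tensor input element type; RT: callback argument type; OutT: tensor
// output element type; ROutT: callback return type. For a type-changing op
// like Imag, T = std::complex<float> and OutT = float, which is when the
// builder above also sets the "Tout" attr.
template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
std::vector<OutT> ApplyCallback(const std::vector<T>& input,
                                ROutT (*callback)(RT)) {
  std::vector<OutT> out;
  out.reserve(input.size());
  for (const T& v : input) {
    out.push_back(static_cast<OutT>(callback(static_cast<RT>(v))));
  }
  return out;
}

// Usage mirroring the Imag tests:
//   auto imag_parts =
//       ApplyCallback<std::complex<float>, const std::complex<float>&,
//                     float, float>({{1.f, 2.f}, {3.f, -4.f}}, std::imag);
```
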
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl
index 8605741dfea..36ca1d8cf56 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Angle(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
-  %0 = "tf.Angle"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Angle_elem_type(%arg0: tensor<*xcomplex<elem_type>>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Angle"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl
index 741ca1b145c..0b5f35eb985 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Asin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
-  %0 = "tf.Asin"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Asin_elem_type(%arg0: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Asin"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl
index 80e22f38dbe..6ee1349549a 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Atan(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
-  %0 = "tf.Atan"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Atan_elem_type(%arg0: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Atan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl
index 210df66516c..78e5ffe204d 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl
@@ -1,6 +1,6 @@
-func @BiasAdd(%arg0: tensor<?x?xelem_type>, %arg1: tensor<?xelem_type>)
-    -> tensor<?x?xelem_type> {
+func @BiasAdd_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
   %0 = "tf.BiasAdd"(%arg0, %arg1)
-      : (tensor<?x?xelem_type>, tensor<?xelem_type>) -> tensor<?x?xelem_type>
-  return %0 : tensor<?x?xelem_type>
+      : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl
index 963a0740c6f..5608e9b37f3 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Conj(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
-  %0 = "tf.Conj"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Conj_elem_type(%arg0: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Conj"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl
index c68c85a798c..d52487a8d98 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Imag(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
-  %0 = "tf.Imag"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Imag_elem_type(%arg0: tensor<*xcomplex<elem_type>>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Imag"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl
index 600fbe563b8..5ddbfb666f0 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Real(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
-  %0 = "tf.Real"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Real_elem_type(%arg0: tensor<*xcomplex<elem_type>>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Real"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl
index 0e739740f93..de3f1d8505f 100644
--- a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl
@@ -1,4 +1,5 @@
-func @Sin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
-  %0 = "tf.Sin"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
-  return %0 : tensor<?xelem_type>
+func @Sin_elem_type(%arg0: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Sin"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
 }
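
Each template above now carries `llvm.emit_c_interface`, which makes MLIR emit a C wrapper symbol prefixed with `_mlir_ciface_`; the `MLIR_FUNCTION` macro in `unranked_op_gpu_base.h` below reconstructs exactly that name. A sketch of the host-side declaration for the Sin template, with a simplified descriptor struct assumed to stand in for the MLIR C-runner-utils type:

```cpp
#include <cstdint>

namespace tensorflow {
class OpKernelContext;
}

// Simplified assumption of the MLIR unranked-descriptor layout.
template <typename T>
struct UnrankedMemRefType {
  int64_t rank;
  void* descriptor;
};

// The symbol MLIR emits for @Sin_f32, i.e. MLIR_FUNCTION(Sin, f32). The
// definition is produced by the kernel generator, so only the declaration
// appears on the host side.
extern "C" UnrankedMemRefType<float> _mlir_ciface_Sin_f32(
    tensorflow::OpKernelContext* ctx, const UnrankedMemRefType<float>* arg);
```
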
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h
index badb510ad74..3a939293347 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h
+++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h
@@ -57,7 +57,7 @@ template <typename ElemType>
 
 template <typename ElemType>
 Tensor ConvertDescriptorToTensor(
-    ::UnrankedMemRefType<ElemType> unranked_descriptor, DataType tf_data_type,
+    ::UnrankedMemRefType<ElemType> unranked_descriptor, DataType TfDataType,
     Allocator* allocator) {
   void* base_ptr = static_cast<void**>(unranked_descriptor.descriptor)[0];
   TensorShape result_shape;
@@ -69,25 +69,26 @@ Tensor ConvertDescriptorToTensor(
       base_ptr, sizeof(ElemType) * result_shape.num_elements(), allocator);
 
   // Tensor takes ownership of the buffer.
-  Tensor tensor{tf_data_type, result_shape, buffer};
+  Tensor tensor{TfDataType, result_shape, buffer};
   // When Tensor is constructed, its ref-counter is incremented. We need to
   // decrement it back.
   buffer->Unref();
   return tensor;
 }
 
-template <DataType tf_data_type, typename data_type, typename Derived>
+template <DataType TfDataType, typename OutputDataType, typename Kernel,
+          typename InputDataType = OutputDataType>
 class MlirUnrankedOp : public OpKernel {
  public:
   explicit MlirUnrankedOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    llvm::SmallVector<::UnrankedMemRefType<data_type>, 2> input_descs;
+    llvm::SmallVector<::UnrankedMemRefType<InputDataType>, 2> input_descs;
     for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
       input_descs.push_back(
-          std::move(ConvertTensorToDescriptor<data_type>(ctx->input(i))));
+          std::move(ConvertTensorToDescriptor<InputDataType>(ctx->input(i))));
     }
-    auto result_desc = Derived::Invoke(ctx, input_descs);
+    auto result_desc = Kernel::Invoke(ctx, input_descs);
     for (const auto& input_desc : input_descs) {
       free(input_desc.descriptor);
     }
@@ -100,25 +101,38 @@ class MlirUnrankedOp : public OpKernel {
     for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
       const Tensor& input = ctx->input(i);
       if (input.data() == result_data_ptr) {
-        ctx->set_output(0, input);
+        // Run a bitcast in case the output type is different.
+        Tensor output;
+        OP_REQUIRES_OK(ctx,
+                       output.BitcastFrom(input, TfDataType, input.shape()));
+        ctx->set_output(0, output);
         free(result_desc.descriptor);
         return;
       }
     }
     tensorflow::AllocatorAttributes attrs;
     auto* allocator = ctx->get_allocator(attrs);
-    Tensor result_tensor = ConvertDescriptorToTensor<data_type>(
-        result_desc, tf_data_type, allocator);
+    Tensor result_tensor = ConvertDescriptorToTensor<OutputDataType>(
+        result_desc, TfDataType, allocator);
     free(result_desc.descriptor);
     ctx->set_output(0, result_tensor);
   }
 };
 
 #define MLIR_FUNCTION(tf_op, mlir_type) _mlir_ciface_##tf_op##_##mlir_type
+
 #define REGISTER_KERNEL(tf_op, mlir_type, data_type)                  \
   REGISTER_KERNEL_BUILDER(                                            \
       Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
       MlirUnranked##tf_op##mlir_type##Op);
+
+#define REGISTER_COMPLEX_KERNEL(tf_op, mlir_type, data_type, input_data_type) \
+  REGISTER_KERNEL_BUILDER(Name(#tf_op)                                        \
+                              .Device(DEVICE_GPU)                             \
+                              .TypeConstraint<input_data_type>("T")           \
+                              .TypeConstraint<data_type>("Tout"),             \
+                          MlirUnranked##tf_op##mlir_type##Op);
+
 #define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type) \
   REGISTER_KERNEL_BUILDER(Name(#tf_op).Device(DEVICE_GPU),   \
                           MlirUnranked##tf_op##mlir_type##Op);
@@ -158,21 +172,26 @@ class MlirUnrankedOp : public OpKernel {
   GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type)         \
   REGISTER_KERNEL(tf_op, mlir_type, data_type)
 
-#define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type)      \
+#define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
+  GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, data_type)
+
+#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type,     \
+                               input_data_type)                               \
   extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(tf_op, mlir_type)( \
       tensorflow::OpKernelContext * ctx,                                      \
-      const ::UnrankedMemRefType<data_type>* arg);                            \
+      const ::UnrankedMemRefType<input_data_type>* arg);                      \
                                                                               \
   namespace {                                                                 \
   class MlirUnranked##tf_op##mlir_type##Op                                    \
       : public MlirUnrankedOp<tf_data_type, data_type,                        \
-                              MlirUnranked##tf_op##mlir_type##Op> {           \
+                              MlirUnranked##tf_op##mlir_type##Op,             \
+                              input_data_type> {                              \
    public:                                                                    \
     using MlirUnrankedOp::MlirUnrankedOp;                                     \
                                                                               \
     static ::UnrankedMemRefType<data_type> Invoke(                            \
         OpKernelContext* ctx,                                                 \
-        llvm::ArrayRef<::UnrankedMemRefType<data_type>> args) {               \
+        llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) {         \
       return MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0]);                  \
     }                                                                         \
   };                                                                          \
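
The forwarding path added above reuses an input buffer whenever the kernel result aliases it, bitcasting because the output dtype can now differ from the input dtype. A small sketch of that idea using `Tensor::BitcastFrom`; the helper itself is hypothetical:

```cpp
#include "tensorflow/core/framework/tensor.h"

// Re-exposes `input`'s buffer as a tensor of `out_dtype` without copying.
// BitcastFrom shares the underlying buffer and fails if the requested
// dtype/shape are incompatible with the existing allocation.
tensorflow::Status ForwardAsOutputType(const tensorflow::Tensor& input,
                                       tensorflow::DataType out_dtype,
                                       tensorflow::Tensor* output) {
  return output->BitcastFrom(input, out_dtype, input.shape());
}
```
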
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc
new file mode 100644
index 00000000000..dec53c9f5a0
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+
+namespace tensorflow {
+
+GENERATE_UNARY_KERNEL2(Imag, f32, DT_FLOAT, float, std::complex<float>);
+REGISTER_COMPLEX_KERNEL(Imag, f32, float, std::complex<float>);
+GENERATE_UNARY_KERNEL2(Imag, f64, DT_DOUBLE, double, std::complex<double>);
+REGISTER_COMPLEX_KERNEL(Imag, f64, double, std::complex<double>);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc
new file mode 100644
index 00000000000..8567060fd62
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+
+namespace tensorflow {
+
+GENERATE_UNARY_KERNEL2(Real, f32, DT_FLOAT, float, std::complex<float>);
+REGISTER_COMPLEX_KERNEL(Real, f32, float, std::complex<float>);
+GENERATE_UNARY_KERNEL2(Real, f64, DT_DOUBLE, double, std::complex<double>);
+REGISTER_COMPLEX_KERNEL(Real, f64, double, std::complex<double>);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc
index a29c53a2978..e49d640f7e2 100644
--- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc
+++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc
@@ -23,5 +23,6 @@ GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f32, DT_FLOAT, float);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f64, DT_DOUBLE, double);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i32, DT_INT32, int32);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i64, DT_INT64, int64);
+// TODO(b/162577610): Register the kernel for complex types and bfloat.
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
new file mode 100644
index 00000000000..2dd4a8dcda6
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc
@@ -0,0 +1,25 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index dc311ca9d3d..449b0cbf8ad 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -36,7 +36,8 @@ namespace functor {
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 typedef Eigen::GpuDevice GPUDevice;
-// Functor for SegmentSumGPUOp.
+// Functor for SegmentSumGPUOp, SegmentProdGPUOp, SegmentMaxGPUOp, and
+// SegmentMinGPUOp.
 // output_rows: the number of output segments (unique segment ids in
 //                'segment_ids').
 // segment_ids_shape: shape of 'segment_ids' tensor.
@@ -45,8 +46,9 @@ typedef Eigen::GpuDevice GPUDevice;
 // data_size: size of input data tensor.
 // data: input data tensor.
 // output: output reshaped to {output_rows, output.size/output_rows}
-template <typename T, typename Index>
-struct SegmentSumFunctor {
+template <typename T, typename Index, typename InitialValueF,
+          typename ReductionF, typename AtomicReductionF>
+struct SegmentReductionFunctor {
   void operator()(OpKernelContext* ctx, const GPUDevice& d,
                   const Index output_rows, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
@@ -66,9 +68,10 @@ struct UnsortedSegmentFunctor {
 };
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// reduction functors for the gpu
+
+// Atomic reduction functors for the GPU.
 template <typename T>
-struct SumOpGpu {
+struct AtomicSumOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicAdd(dest, value);
@@ -76,7 +79,7 @@ struct SumOpGpu {
 };
 
 template <typename T>
-struct ProdOpGpu {
+struct AtomicProdOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicMul(dest, value);
@@ -84,7 +87,7 @@ struct ProdOpGpu {
 };
 
 template <typename T>
-struct MaxOpGpu {
+struct AtomicMaxOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicMax(dest, value);
@@ -92,16 +95,49 @@ struct MaxOpGpu {
 };
 
 template <typename T>
-struct MinOpGpu {
+struct AtomicMinOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicMin(dest, value);
   }
 };
 
+// Non-atomic reduction functors for the GPU.
+template <typename T>
+struct NonAtomicSumOpGpu {
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
+                                                        const T& value) {
+    *dest += value;
+  }
+};
+
+template <typename T>
+struct NonAtomicProdOpGpu {
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
+                                                        const T& value) {
+    *dest *= value;
+  }
+};
+
+template <typename T>
+struct NonAtomicMaxOpGpu {
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
+                                                        const T& value) {
+    *dest = max(*dest, value);
+  }
+};
+
+template <typename T>
+struct NonAtomicMinOpGpu {
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
+                                                        const T& value) {
+    *dest = min(*dest, value);
+  }
+};
+
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-// initial value functors
+// Initial value functors.
 template <typename T>
 struct Zero {
   EIGEN_STRONG_INLINE T operator()() const { return T(0); }
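
Each sorted segment reduction is now described by a functor triple: an initial value, a non-atomic combine for thread-local accumulation, and an atomic combine for writes that other thread blocks may race with. A CPU-side sketch of how the first two compose, with local stand-ins shaped like the functors in this header:

```cpp
#include <cstddef>
#include <vector>

// Local stand-ins shaped like the initial-value and reduction functors.
template <typename T>
struct ZeroOp {
  T operator()() const { return T(0); }
};
template <typename T>
struct SumOp {
  void operator()(T* dest, const T& value) const { *dest += value; }
};

// Reduces `data` into `num_segments` buckets keyed by sorted segment ids.
// On the GPU, the AtomicReductionF variant is additionally needed where
// several thread blocks combine into the same output element.
template <typename T, typename InitialValueF, typename ReductionF>
std::vector<T> SegmentReduce(const std::vector<T>& data,
                             const std::vector<int>& segment_ids,
                             int num_segments) {
  std::vector<T> out(num_segments, InitialValueF()());
  for (size_t i = 0; i < data.size(); ++i) {
    ReductionF()(&out[segment_ids[i]], data[i]);
  }
  return out;
}
// SegmentReduce<float, ZeroOp<float>, SumOp<float>>({1, 2, 3, 4},
// {0, 0, 1, 1}, 2) yields {3, 7}; pairing Highest<T> with a min combine
// gives segment mins, matching the instantiations later in this diff.
```
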
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index 418af1d6b6d..c6f9d8b9158 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -31,14 +31,14 @@ namespace tensorflow {
 
 using GPUDevice = Eigen::GpuDevice;
 
-// SortedSegmentSumFunctor kernel reduces input data just as
-// UnsortedSegmentSumCustomKernel does except that input data
+// SortedSegmentReductionFunctor kernel reduces input data just as
+// UnsortedSegmentReductionCustomKernel does except that input data
 // is partitioned along the outer reduction dimension. This is
 // because consecutive rows (elements in a row share the same
 // outer dimension index) in the flattened 2D input data likely
 // belong to the same segment in sorted segment sum operation.
 // Therefore such partitioning strategy has two advantages over
-// the UnsortedSegmentSumFunctor kernel:
+// the UnsortedSegmentReductionFunctor kernel:
 // 1. Each thread reduces across multiple rows before writing
 // answers to the global memory, we can therefore
 // write reduction results to global memory less often.
@@ -51,18 +51,19 @@ using GPUDevice = Eigen::GpuDevice;
 // size OuterDimTileSize x 1. This strip runs across multiple
 // rows of input data and all reduction elements share one inner
 // dimension index.
-template <typename T, typename Index, int OuterDimTileSize>
-__global__ void SortedSegmentSumCustomKernel(
+template <typename T, typename Index, int OuterDimTileSize, typename ReductionF,
+          typename AtomicReductionF>
+__global__ void SortedSegmentReductionCustomKernel(
     const Index input_outer_dim_size, const Index inner_dim_size,
     const Index output_outer_dim_size, const Index* __restrict__ segment_ids,
     const T* __restrict__ input, T* __restrict__ output,
-    const Index total_stripe_count) {
+    const Index total_stripe_count, const T initial_value) {
   for (int stripe_index : GpuGridRangeX(total_stripe_count)) {
     const Index segment_offset = stripe_index % inner_dim_size;
     const Index input_outer_dim_index_base =
         stripe_index / inner_dim_size * Index(OuterDimTileSize);
 
-    T sum = T(0);
+    T reduce_res = initial_value;
     Index first_segment_id = segment_ids[input_outer_dim_index_base];
     Index last_output_segment_id = output_outer_dim_size;
 
@@ -72,24 +73,25 @@ __global__ void SortedSegmentSumCustomKernel(
     for (Index j = 0; j < actual_stripe_height; j++) {
       Index current_output_segment_id =
           segment_ids[input_outer_dim_index_base + j];
-      // Decide whether to write result to global memory.
-      // Result is only written to global memory if we move
-      // to another segment. Otherwise we can keep accumulating
-      // locally.
+      // Decide whether to write result to global memory. Result is only written
+      // to global memory if we move to another segment. Otherwise we can keep
+      // accumulating locally.
       if (current_output_segment_id > last_output_segment_id) {
         const Index output_index =
             last_output_segment_id * inner_dim_size + segment_offset;
-        // decide whether to write result to global memory using atomic
-        // operations
+        // Decide whether to write result to global memory using atomic
+        // operations.
         if (last_output_segment_id == first_segment_id) {
-          GpuAtomicAdd(output + output_index, sum);
+          AtomicReductionF()(output + output_index, reduce_res);
         } else {
-          *(output + output_index) = sum;
+          ReductionF()(output + output_index, reduce_res);
         }
-        sum = T(0);
+        reduce_res = initial_value;
       }
-      sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
-                 segment_offset);
+      ReductionF()(
+          &reduce_res,
+          ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
+              segment_offset));
       last_output_segment_id = current_output_segment_id;
     }
     // For the last result in a strip, always write using atomic operations
@@ -97,7 +99,7 @@ __global__ void SortedSegmentSumCustomKernel(
     // the following strip.
     const Index output_index =
         last_output_segment_id * inner_dim_size + segment_offset;
-    GpuAtomicAdd(output + output_index, sum);
+    AtomicReductionF()(output + output_index, reduce_res);
   }
 }
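
To see why only the stripe edges need atomic writes, consider a tile size of 4 and segment ids `0 0 1 2 | 2 3 3 3` split across two stripes. Stripe 0 (rows 0-3) sees segments {0, 1, 2} and stripe 1 (rows 4-7) sees {2, 3}; segment 2 spans the stripe boundary, so both stripes combine partial results into `output[2]` and must use `AtomicReductionF` there. That is why a stripe's first segment and its final write go through the atomic functor, while interior segments such as segment 1 are owned by exactly one stripe and can use the plain `ReductionF` store.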
 
@@ -126,25 +128,30 @@ __global__ void UnsortedSegmentCustomKernel(
 
 namespace functor {
 
-template <typename T, typename Index>
-void SegmentSumFunctor<T, Index>::operator()(
-    OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
-    const TensorShape& segment_ids_shape,
-    typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
-    const T* data, typename TTypes<T, 2>::Tensor output) {
+template <typename T, typename Index, typename InitialValueF,
+          typename ReductionF, typename AtomicReductionF>
+void SegmentReductionFunctor<
+    T, Index, InitialValueF, ReductionF,
+    AtomicReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
+                                  const Index output_rows,
+                                  const TensorShape& segment_ids_shape,
+                                  typename TTypes<Index>::ConstFlat segment_ids,
+                                  const Index data_size, const T* data,
+                                  typename TTypes<T, 2>::Tensor output) {
   if (output.size() == 0) {
     return;
   }
-  // Set 'output' to zeros.
+  // Set 'output' to initial value.
   GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
-  TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
+  const T initial_value = InitialValueF()();
+  TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
                               config.thread_per_block, 0, d.stream(),
-                              output.size(), output.data()));
+                              output.size(), output.data(), initial_value));
   if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
     return;
   }
 
-  // Launch kernel to compute sorted segment sum.
+  // Launch kernel to compute sorted segment reduction.
   // Notes:
   // *) 'input_total_size' is the total number of elements to process.
   // *) 'segment_ids.shape' is a prefix of data's shape.
@@ -163,10 +170,12 @@ void SegmentSumFunctor<T, Index>::operator()(
 
   config = GetGpuLaunchConfig(total_stripe_count, d);
   TF_CHECK_OK(GpuLaunchKernel(
-      SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>,
+      SortedSegmentReductionCustomKernel<T, Index, OuterDimTileSize, ReductionF,
+                                         AtomicReductionF>,
       config.block_count, config.thread_per_block, 0, d.stream(),
       input_outer_dim_size, input_inner_dim_size, output_rows,
-      segment_ids.data(), data, output.data(), total_stripe_count));
+      segment_ids.data(), data, output.data(), total_stripe_count,
+      initial_value));
 }
 
 template <typename T, typename Index, typename InitialValueF,
@@ -207,8 +216,19 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
   }
 };
 
-#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
-  template struct SegmentSumFunctor<T, Index>
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index)                           \
+  template struct SegmentReductionFunctor<T, Index, functor::Zero<T>,     \
+                                          functor::NonAtomicSumOpGpu<T>,  \
+                                          functor::AtomicSumOpGpu<T>>;    \
+  template struct SegmentReductionFunctor<T, Index, functor::One<T>,      \
+                                          functor::NonAtomicProdOpGpu<T>, \
+                                          functor::AtomicProdOpGpu<T>>;   \
+  template struct SegmentReductionFunctor<T, Index, functor::Highest<T>,  \
+                                          functor::NonAtomicMinOpGpu<T>,  \
+                                          functor::AtomicMinOpGpu<T>>;    \
+  template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>,   \
+                                          functor::NonAtomicMaxOpGpu<T>,  \
+                                          functor::AtomicMaxOpGpu<T>>;
 
 #define DEFINE_SORTED_GPU_SPECS(T)         \
   DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
@@ -218,16 +238,16 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
 
 #define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                         \
   template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>;          \
+      GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>;    \
   template struct UnsortedSegmentFunctor<                                      \
-      GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>;         \
+      GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>;   \
   template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
-                                         functor::ProdOpGpu<T>>;
+                                         functor::AtomicProdOpGpu<T>>;
 
-// sum is the only op that supports all input types currently
+// Sum is the only op that supports all input types currently.
 #define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
   template struct UnsortedSegmentFunctor<             \
-      GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>;
+      GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
 
 #define DEFINE_REAL_GPU_SPECS(T)                  \
   DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h
index 7cf15ef5b72..3dd142d75e9 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -206,24 +206,26 @@ class SegmentReductionOp : public OpKernel {
 };
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-//  SegmentSumGPUOp is a segment sum operator implemented for GPU only.
-//  TODO: This implementation of SegmentSumGPUOp is sometimes slower than
+//  SegmentReductionGPUOp is a segment reduction operator implemented for GPU
+//  only.
+//  TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
 //  its unsorted counterpart (mostly when problem size is small).
 //  This is due to the following two main reasons and a cost-effective way
 //  to resolve these problems is desirable.
-//  1. Sorted segment sum requires a memory transfer from device to host in
-//     order to know the size of the output dimension whereas unsorted segment
-//     sum receives the size of the output dimension as an input parameter.
-//  2. Sorted segment sum is essentially a tiled version of unsorted segment
-//     sum and therefore such optimization comes at an inherent cost. However
-//     such cost may not be justified when the problem size is small. When to
-//     use the tiled version or the untiled version depends on many factors
-//     including data alignments, ratio of calculation to memory traffic and
-//     obviously, the problem sizes.
-template <class T, class Index>
-class SegmentSumGPUOp : public AsyncOpKernel {
+//  1. Sorted segment reduction requires a memory transfer from device to host
+//     in order to know the size of the output dimension whereas unsorted
+//     segment reduction receives the size of the output dimension as an input
+//     parameter.
+//  2. Sorted segment reduction is essentially a tiled version of unsorted
+//     segment reduction and therefore such optimization comes at an inherent
+//     cost. However such cost may not be justified when the problem size is
+//     small. When to use the tiled version or the untiled version depends on
+//     many factors including data alignments, ratio of calculation to memory
+//     traffic and obviously, the problem sizes.
+template <class T, class Index, class SegmentReductionFunctor>
+class SegmentReductionGPUOp : public AsyncOpKernel {
  public:
-  explicit SegmentSumGPUOp(OpKernelConstruction* context)
+  explicit SegmentReductionGPUOp(OpKernelConstruction* context)
       : AsyncOpKernel(context) {}
 
   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
@@ -265,11 +267,11 @@ class SegmentSumGPUOp : public AsyncOpKernel {
             ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                          sizeof(Index))
             .ok(),
-        errors::Internal(
-            "SegmentSumGPUOp: failed to copy output_rows from device"),
+        errors::Internal(type_string() +
+                         ": failed to copy output_rows from device"),
         done);
 
-    functor::SegmentSumFunctor<T, Index> functor_;
+    SegmentReductionFunctor functor_;
     auto create_and_check_output = [context, output_rows_host, &input,
                                     &segment_ids, &functor_, done]() {
       // Ensure that within the callback, the proper GPU settings are
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
index f71a8dac462..97c0762c36f 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc
@@ -113,17 +113,39 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
-  REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
-                              .Device(DEVICE_GPU)                      \
-                              .TypeConstraint<type>("T")               \
-                              .TypeConstraint<index_type>("Tindices"), \
-                          SegmentSumGPUOp<type, index_type>)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                   \
+    name, type, index_type, initial_value_functor, reduction_kernel_functor, \
+    atomic_reduction_kernel_functor)                                         \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name(name)                                                             \
+          .Device(DEVICE_GPU)                                                \
+          .TypeConstraint<type>("T")                                         \
+          .TypeConstraint<index_type>("Tindices"),                           \
+      SegmentReductionGPUOp<                                                 \
+          type, index_type,                                                  \
+          functor::SegmentReductionFunctor<                                  \
+              type, index_type, initial_value_functor,                       \
+              reduction_kernel_functor, atomic_reduction_kernel_functor> >)
+
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                     \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentSum", type, index_type, functor::Zero<type>,                \
+      functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>);   \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentProd", type, index_type, functor::One<type>,                \
+      functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentMin", type, index_type, functor::Highest<type>,             \
+      functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>);   \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentMax", type, index_type, functor::Lowest<type>,              \
+      functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
 
 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
   REGISTER_GPU_SORTED_KERNELS(type, int32)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
+#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
 #undef REGISTER_GPU_SORTED_KERNELS
 #undef REGISTER_GPU_SORTED_KERNELS_ALL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
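
The registration macros above are nested two layers deep, which makes them hard to audit by eye. A single invocation expands (modulo preprocessor whitespace) to an ordinary kernel registration; this hand-expanded sketch shows the "SegmentSum" / float / int32 case implied by the macros above, written out purely for illustration:

    REGISTER_KERNEL_BUILDER(
        Name("SegmentSum")
            .Device(DEVICE_GPU)
            .TypeConstraint<float>("T")
            .TypeConstraint<int32>("Tindices"),
        SegmentReductionGPUOp<
            float, int32,
            functor::SegmentReductionFunctor<
                float, int32, functor::Zero<float>,
                functor::NonAtomicSumOpGpu<float>,
                functor::AtomicSumOpGpu<float>>>);
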
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
index f2164260b8f..21b4ddf821d 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc
@@ -63,17 +63,39 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
-  REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
-                              .Device(DEVICE_GPU)                      \
-                              .TypeConstraint<type>("T")               \
-                              .TypeConstraint<index_type>("Tindices"), \
-                          SegmentSumGPUOp<type, index_type>)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                   \
+    name, type, index_type, initial_value_functor, reduction_kernel_functor, \
+    atomic_reduction_kernel_functor)                                         \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name(name)                                                             \
+          .Device(DEVICE_GPU)                                                \
+          .TypeConstraint<type>("T")                                         \
+          .TypeConstraint<index_type>("Tindices"),                           \
+      SegmentReductionGPUOp<                                                 \
+          type, index_type,                                                  \
+          functor::SegmentReductionFunctor<                                  \
+              type, index_type, initial_value_functor,                       \
+              reduction_kernel_functor, atomic_reduction_kernel_functor> >)
+
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                     \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentSum", type, index_type, functor::Zero<type>,                \
+      functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>);   \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentProd", type, index_type, functor::One<type>,                \
+      functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentMin", type, index_type, functor::Highest<type>,             \
+      functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>);   \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                      \
+      "SegmentMax", type, index_type, functor::Lowest<type>,              \
+      functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
 
 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
   REGISTER_GPU_SORTED_KERNELS(type, int64);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
+#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
 #undef REGISTER_GPU_SORTED_KERNELS
 #undef REGISTER_GPU_SORTED_KERNELS_ALL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
index eef5a532b29..c809a5ed1c1 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc
@@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                       functor::Lowest<type>,                   \
-                                      functor::MaxOpGpu<type>);                \
+                                      functor::AtomicMaxOpGpu<type>);          \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                       functor::Highest<type>,                  \
-                                      functor::MinOpGpu<type>);                \
+                                      functor::AtomicMinOpGpu<type>);          \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                       functor::One<type>,                      \
-                                      functor::ProdOpGpu<type>);
+                                      functor::AtomicProdOpGpu<type>);
 
 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                       functor::Zero<type>,                    \
-                                      functor::SumOpGpu<type>);
+                                      functor::AtomicSumOpGpu<type>);
 
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32)
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
index cad6f8a5e08..c47e8d171e5 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc
@@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                       functor::Lowest<type>,                   \
-                                      functor::MaxOpGpu<type>);                \
+                                      functor::AtomicMaxOpGpu<type>);          \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                       functor::Highest<type>,                  \
-                                      functor::MinOpGpu<type>);                \
+                                      functor::AtomicMinOpGpu<type>);          \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                       functor::One<type>,                      \
-                                      functor::ProdOpGpu<type>);
+                                      functor::AtomicProdOpGpu<type>);
 
 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                       functor::Zero<type>,                    \
-                                      functor::SumOpGpu<type>);
+                                      functor::AtomicSumOpGpu<type>);
 
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64)
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index f2f6889c900..3113de527e4 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -376,6 +376,61 @@ inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1,
   }
 }
 
+template <typename T, typename GradTy, typename GradMaybeWithShrinkageTy,
+          typename AccumTy, typename LinearTy, typename VarTy>
+void ComputeFtrl(GradTy grad,
+                 GradMaybeWithShrinkageTy grad_maybe_with_shrinkage,
+                 AccumTy accum, LinearTy linear, VarTy var, T l1_scalar,
+                 T l2_scalar, bool multiply_linear_by_lr, T lr_power_scalar,
+                 T lr_scalar) {
+  auto new_accum = accum + grad.square();
+  if (multiply_linear_by_lr) {
+    if (lr_power_scalar == static_cast<T>(-0.5)) {
+      linear += grad_maybe_with_shrinkage * lr_scalar -
+                (new_accum.sqrt() - accum.sqrt()) * var;
+    } else {
+      linear +=
+          grad_maybe_with_shrinkage * lr_scalar -
+          (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * var;
+    }
+  } else {
+    if (lr_power_scalar == static_cast<T>(-0.5)) {
+      linear += grad_maybe_with_shrinkage -
+                (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;
+    } else {
+      linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) -
+                                             accum.pow(-lr_power_scalar)) /
+                                                lr_scalar * var;
+    }
+  }
+  auto l1_reg_adjust =
+      (multiply_linear_by_lr ? linear.cwiseMin(l1_scalar * lr_scalar)
+                                   .cwiseMax(-l1_scalar * lr_scalar)
+                             : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar));
+  auto x = l1_reg_adjust - linear;
+  if (multiply_linear_by_lr) {
+    if (lr_power_scalar == static_cast<T>(-0.5)) {
+      auto y = new_accum.sqrt() +
+               linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar);
+      var = x / y;
+    } else {
+      auto y = new_accum.pow(-lr_power_scalar) +
+               linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar);
+      var = x / y;
+    }
+  } else {
+    if (lr_power_scalar == static_cast<T>(-0.5)) {
+      auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) +
+               linear.constant(static_cast<T>(2) * l2_scalar);
+      var = x / y;
+    } else {
+      auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) +
+               linear.constant(static_cast<T>(2) * l2_scalar);
+      var = x / y;
+    }
+  }
+  accum += grad.square();
+}
 }  // namespace
 
 template <typename T, typename Tindex, bool has_l2_shrinkage>
@@ -417,70 +472,25 @@ struct SparseApplyFtrl<CPUDevice, T, Tindex, has_l2_shrinkage> {
           auto grad = grad_flat.template chip<0>(i);
           auto var = var_flat.template chip<0>(index);
 
-// TODO(sanjoy): Remove this macro.
-// Use a macro to implement the computation here due to the templating of the
-// eigen tensor library.
-#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage)                          \
-  auto new_accum = accum + grad.square();                                      \
-  if (multiply_linear_by_lr) {                                                 \
-    if (lr_power_scalar == static_cast<T>(-0.5)) {                             \
-      linear += grad_maybe_with_shrinkage * lr_scalar -                        \
-                (new_accum.sqrt() - accum.sqrt()) * var;                       \
-    } else {                                                                   \
-      linear +=                                                                \
-          grad_maybe_with_shrinkage * lr_scalar -                              \
-          (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) *    \
-              var;                                                             \
-    }                                                                          \
-  } else {                                                                     \
-    if (lr_power_scalar == static_cast<T>(-0.5)) {                             \
-      linear += grad_maybe_with_shrinkage -                                    \
-                (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;           \
-    } else {                                                                   \
-      linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
-                                             accum.pow(-lr_power_scalar)) /    \
-                                                lr_scalar * var;               \
-    }                                                                          \
-  }                                                                            \
-  auto l1_reg_adjust =                                                         \
-      (multiply_linear_by_lr                                                   \
-           ? linear.cwiseMin(l1_scalar * lr_scalar)                            \
-                 .cwiseMax(-l1_scalar * lr_scalar)                             \
-           : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar));                 \
-  auto x = l1_reg_adjust - linear;                                             \
-  if (multiply_linear_by_lr) {                                                 \
-    if (lr_power_scalar == static_cast<T>(-0.5)) {                             \
-      auto y = new_accum.sqrt() +                                              \
-               linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar);     \
-      var = x / y;                                                             \
-    } else {                                                                   \
-      auto y = new_accum.pow(-lr_power_scalar) +                               \
-               linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar);     \
-      var = x / y;                                                             \
-    }                                                                          \
-  } else {                                                                     \
-    if (lr_power_scalar == static_cast<T>(-0.5)) {                             \
-      auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) +              \
-               linear.constant(static_cast<T>(2) * l2_scalar);                 \
-      var = x / y;                                                             \
-    } else {                                                                   \
-      auto y =                                                                 \
-          new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) +    \
-          linear.constant(static_cast<T>(2) * l2_scalar);                      \
-      var = x / y;                                                             \
-    }                                                                          \
-  }                                                                            \
-  accum += grad.square();
-
           if (has_l2_shrinkage) {
             auto grad_with_shrinkage =
                 grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
-            COMPUTE_FTRL(grad, grad_with_shrinkage);
+            ComputeFtrl(/*grad=*/grad,
+                        /*grad_maybe_with_shrinkage=*/grad_with_shrinkage,
+                        /*accum=*/accum, /*linear=*/linear, /*var=*/var,
+                        /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar,
+                        /*multiply_linear_by_lr=*/multiply_linear_by_lr,
+                        /*lr_power_scalar=*/lr_power_scalar,
+                        /*lr_scalar=*/lr_scalar);
           } else {
-            COMPUTE_FTRL(grad, grad);
+            ComputeFtrl(/*grad=*/grad, /*grad_maybe_with_shrinkage=*/grad,
+                        /*accum=*/accum, /*linear=*/linear, /*var=*/var,
+                        /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar,
+                        /*multiply_linear_by_lr=*/multiply_linear_by_lr,
+                        /*lr_power_scalar=*/lr_power_scalar,
+                        /*lr_scalar=*/lr_scalar);
           }
         }
-#undef COMPUTE_FTRL
       } else {
         const Tindex first_dim_size = accum_flat.size();
 
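Since ComputeFtrl operates on Eigen tensor expressions, the update it performs can be hard to read in place. The following is a minimal scalar sketch of the same update for the common case (multiply_linear_by_lr = false, lr_power = -0.5); it mirrors the tensor code above and is illustrative only, not part of the kernel:

    #include <algorithm>
    #include <cmath>

    template <typename T>
    void ComputeFtrlScalar(T grad, T grad_maybe_with_shrinkage, T& accum,
                           T& linear, T& var, T l1, T l2, T lr) {
      const T new_accum = accum + grad * grad;
      // Fold the gradient into the linear term, correcting for the change in
      // the adaptive denominator sqrt(accum) -> sqrt(new_accum).
      linear += grad_maybe_with_shrinkage -
                (std::sqrt(new_accum) - std::sqrt(accum)) / lr * var;
      // Soft-threshold the linear term against the L1 penalty.
      const T l1_reg_adjust = std::min(std::max(linear, -l1), l1);
      const T x = l1_reg_adjust - linear;
      const T y = std::sqrt(new_accum) / lr + static_cast<T>(2) * l2;
      var = x / y;
      accum += grad * grad;
    }
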
diff --git a/tensorflow/core/lib/bmp/BUILD b/tensorflow/core/lib/bmp/BUILD
index 186c3a0753f..3e21a027265 100644
--- a/tensorflow/core/lib/bmp/BUILD
+++ b/tensorflow/core/lib/bmp/BUILD
@@ -1,24 +1,12 @@
 # Description:
-# bmp test data packages.
-
-load("//tensorflow:tensorflow.bzl", "filegroup")
+# bmp test data package alias.
 
 package(
     licenses = ["notice"],  # Apache 2.0
 )
 
-filegroup(
+alias(
     name = "bmp_testdata",
-    srcs = [
-        # BMP data
-        "testdata/lena.bmp",
-        "testdata/rgb_small.bmp",
-        "testdata/rgb_small_255.bmp",
-        "testdata/rgba_small.bmp",
-        "testdata/rgba_small_255.bmp",
-        "testdata/grayscale_small.bmp",
-        "testdata/grayscale_small_3channels.bmp",
-        "testdata/grayscale_small_4channels.bmp",
-    ],
+    actual = "//tensorflow/core/lib/bmp/testdata:bmp_testdata",
     visibility = ["//visibility:public"],
 )
diff --git a/tensorflow/core/lib/bmp/testdata/BUILD b/tensorflow/core/lib/bmp/testdata/BUILD
new file mode 100644
index 00000000000..da11c447add
--- /dev/null
+++ b/tensorflow/core/lib/bmp/testdata/BUILD
@@ -0,0 +1,24 @@
+# Description:
+# bmp test data packages.
+
+load("//tensorflow:tensorflow.bzl", "filegroup")
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "bmp_testdata",
+    srcs = [
+        # BMP data
+        "lena.bmp",
+        "rgb_small.bmp",
+        "rgb_small_255.bmp",
+        "rgba_small.bmp",
+        "rgba_small_255.bmp",
+        "grayscale_small.bmp",
+        "grayscale_small_3channels.bmp",
+        "grayscale_small_4channels.bmp",
+    ],
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index c4bd397ee60..e00ebf7325a 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -853,7 +853,18 @@ REGISTER_OP("OptionalFromValue")
     .Input("components: Toutput_types")
     .Output("optional: variant")
     .Attr("Toutput_types: list(type) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      std::vector<DataType> dtypes;
+      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
+      c->set_output(0, c->Scalar());
+      std::vector<shape_inference::ShapeAndType> shapes_and_types;
+      shapes_and_types.reserve(c->num_inputs());
+      for (int i = 0; i < c->num_inputs(); ++i) {
+        shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL);
+      }
+      c->set_output_handle_shapes_and_types(0, shapes_and_types);
+      return Status::OK();
+    });
 
 REGISTER_OP("OptionalNone")
     .Output("optional: variant")
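
With the handle data recorded by OptionalFromValue above, a downstream op (such as OptionalGetValue) can recover the component shapes during shape inference instead of treating the variant as opaque. A hypothetical consumer-side sketch (the function name is illustrative; `input_handle_shapes_and_types` is the existing InferenceContext accessor):

    Status OptionalConsumerShapeFn(shape_inference::InferenceContext* c) {
      const std::vector<shape_inference::ShapeAndType>* handle_data =
          c->input_handle_shapes_and_types(0);
      if (handle_data != nullptr) {
        for (int i = 0; i < static_cast<int>(handle_data->size()); ++i) {
          c->set_output(i, (*handle_data)[i].shape);
        }
      }
      return Status::OK();
    }
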
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index 10df2e12038..e7a060e73cc 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -609,7 +609,7 @@ int CurlHttpRequest::ProgressCallback(void* this_object, curl_off_t dltotal,
 
     double starttransfer_time = -1;
     const auto starttransfer_time_status = that->libcurl_->curl_easy_getinfo(
-        that->curl_, CURLINFO_PRETRANSFER_TIME, &starttransfer_time);
+        that->curl_, CURLINFO_STARTTRANSFER_TIME, &starttransfer_time);
 
     LOG(ERROR) << "The transmission  of request " << this_object
                << " (URI: " << that->uri_ << ") has been stuck at "
diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD
index 74370988d07..01fd9d8bb3d 100644
--- a/tensorflow/core/platform/default/BUILD
+++ b/tensorflow/core/platform/default/BUILD
@@ -202,6 +202,7 @@ cc_library(
         "//tensorflow/core/platform",
         "//tensorflow/core/platform:env_time",
         "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
         "//tensorflow/core/platform:types",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc
index 43d8545e6fd..b19c2630f23 100644
--- a/tensorflow/core/platform/default/logging.cc
+++ b/tensorflow/core/platform/default/logging.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/base/internal/sysinfo.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
 
 #if defined(PLATFORM_POSIX_ANDROID)
 #include <android/log.h>
@@ -31,30 +32,135 @@ limitations under the License.
 #include <string.h>
 #include <time.h>
 
-#include <string>
+#include <algorithm>
+#include <queue>
 #include <unordered_map>
 
 namespace tensorflow {
 
-void TFAddLogSink(TFLogSink* sink) {
-  // LogSink is not implemented.
-  // If necessary, one can add the log sink support as follows.
-  // 1. Define a global vector<TFLogSink> to keep track of all registered
-  //    TFLogSink objects. Protect the global vector with mutex to make it
-  //    thread-safe.
-  // 2. Add/remove elements from the global vector<TFLogSink> in TFAddLogSink
-  //    and TFRemoveLogSink function
-  // 3. Add logic in LogMessage::GenerateLogMessage() below to dispatch log
-  //    messages to all the registered log sinks.
-}
-
-void TFRemoveLogSink(TFLogSink* sink) {
-  // LogSink is not implemented.
-}
-
 namespace internal {
 namespace {
 
+// This is an internal singleton class that manages the log sinks. It allows
+// adding and removing log sinks and dispatches log messages to all registered
+// sinks.
+class TFLogSinks {
+ public:
+  // Gets the TFLogSinks instance. This is the entry point for using this class.
+  static TFLogSinks& Instance();
+
+  // Adds a log sink. The sink argument must not be a nullptr. TFLogSinks
+  // does not take ownership of the pointer; the caller must keep the sink
+  // alive until the application terminates or until TFLogSinks::Remove is
+  // called for the same pointer value.
+  void Add(TFLogSink* sink);
+
+  // Removes a log sink. The sink object itself is not destroyed; it simply
+  // stops receiving log messages after this call.
+  void Remove(TFLogSink* sink);
+
+  // Gets the currently registered log sinks.
+  std::vector<TFLogSink*> GetSinks() const;
+
+  // Sends a log message to all registered log sinks.
+  //
+  // If no log sinks are registered:
+  //
+  // If NO_DEFAULT_LOGGER is defined:
+  // Up to 128 messages will be queued until a log sink is added.
+  // The queue will then be logged to the first added log sink.
+  //
+  // If NO_DEFAULT_LOGGER is not defined:
+  // The messages will be logged using the default logger. The default logger
+  // will log to stderr on all platforms except for Android. On Android the
+  // default Android logger will be used.
+  void Send(const TFLogEntry& entry);
+
+ private:
+  TFLogSinks();
+  void SendToSink(TFLogSink& sink, const TFLogEntry& entry);
+
+  std::queue<TFLogEntry> log_entry_queue_;
+  static const size_t kMaxLogEntryQueueSize = 128;
+
+  mutable tensorflow::mutex mutex_;
+  std::vector<TFLogSink*> sinks_;
+};
+
+TFLogSinks::TFLogSinks() {
+#ifndef NO_DEFAULT_LOGGER
+  static TFDefaultLogSink* default_sink = new TFDefaultLogSink();
+  sinks_.emplace_back(default_sink);
+#endif
+}
+
+TFLogSinks& TFLogSinks::Instance() {
+  static TFLogSinks* instance = new TFLogSinks();
+  return *instance;
+}
+
+void TFLogSinks::Add(TFLogSink* sink) {
+  assert(sink != nullptr && "The sink must not be a nullptr");
+
+  tensorflow::mutex_lock lock(mutex_);
+  sinks_.emplace_back(sink);
+
+  // If this is the only sink, flush all the queued-up messages to it.
+  if (sinks_.size() == 1) {
+    while (!log_entry_queue_.empty()) {
+      for (const auto& sink : sinks_) {
+        SendToSink(*sink, log_entry_queue_.front());
+      }
+      log_entry_queue_.pop();
+    }
+  }
+}
+
+void TFLogSinks::Remove(TFLogSink* sink) {
+  assert(sink != nullptr && "The sink must not be a nullptr");
+
+  tensorflow::mutex_lock lock(mutex_);
+  auto it = std::find(sinks_.begin(), sinks_.end(), sink);
+  if (it != sinks_.end()) sinks_.erase(it);
+}
+
+std::vector<TFLogSink*> TFLogSinks::GetSinks() const {
+  tensorflow::mutex_lock lock(mutex_);
+  return sinks_;
+}
+
+void TFLogSinks::Send(const TFLogEntry& entry) {
+  tensorflow::mutex_lock lock(mutex_);
+
+  // If no sinks are registered yet, queue the entry for later delivery.
+  if (sinks_.empty()) {
+    // If we've exceeded the maximum queue size, drop the oldest entries
+    while (log_entry_queue_.size() >= kMaxLogEntryQueueSize) {
+      log_entry_queue_.pop();
+    }
+    log_entry_queue_.push(entry);
+    return;
+  }
+
+  // If we have items in the queue, push them out first
+  while (!log_entry_queue_.empty()) {
+    for (const auto& sink : sinks_) {
+      SendToSink(*sink, log_entry_queue_.front());
+    }
+    log_entry_queue_.pop();
+  }
+
+  // ... and now we can log the current log entry
+  for (const auto& sink : sinks_) {
+    SendToSink(*sink, entry);
+  }
+}
+
+void TFLogSinks::SendToSink(TFLogSink& sink, const TFLogEntry& entry) {
+  sink.Send(entry);
+  sink.WaitTillSent();
+}
+
 int ParseInteger(const char* str, size_t size) {
   // Ideally we would use env_var / safe_strto64, but it is
   // hard to use here without pulling in a lot of dependencies,
@@ -197,70 +303,10 @@ LogMessage::~LogMessage() {
   }
 }
 
-#if defined(PLATFORM_POSIX_ANDROID)
 void LogMessage::GenerateLogMessage() {
-  int android_log_level;
-  switch (severity_) {
-    case INFO:
-      android_log_level = ANDROID_LOG_INFO;
-      break;
-    case WARNING:
-      android_log_level = ANDROID_LOG_WARN;
-      break;
-    case ERROR:
-      android_log_level = ANDROID_LOG_ERROR;
-      break;
-    case FATAL:
-      android_log_level = ANDROID_LOG_FATAL;
-      break;
-    default:
-      if (severity_ < INFO) {
-        android_log_level = ANDROID_LOG_VERBOSE;
-      } else {
-        android_log_level = ANDROID_LOG_ERROR;
-      }
-      break;
-  }
-
-  std::stringstream ss;
-  const char* const partial_name = strrchr(fname_, '/');
-  ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
-     << " " << str();
-  __android_log_write(android_log_level, "native", ss.str().c_str());
-
-  // Also log to stderr (for standalone Android apps).
-  fprintf(stderr, "native : %s\n", ss.str().c_str());
-
-  // Android logging at level FATAL does not terminate execution, so abort()
-  // is still required to stop the program.
-  if (severity_ == FATAL) {
-    abort();
-  }
+  TFLogSinks::Instance().Send(TFLogEntry(severity_, fname_, line_, str()));
 }
 
-#else
-
-void LogMessage::GenerateLogMessage() {
-  static bool log_thread_id = EmitThreadIdFromEnv();
-  uint64 now_micros = EnvTime::NowMicros();
-  time_t now_seconds = static_cast<time_t>(now_micros / 1000000);
-  int32 micros_remainder = static_cast<int32>(now_micros % 1000000);
-  const size_t time_buffer_size = 30;
-  char time_buffer[time_buffer_size];
-  strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S",
-           localtime(&now_seconds));
-  const size_t tid_buffer_size = 10;
-  char tid_buffer[tid_buffer_size] = "";
-  if (log_thread_id) {
-    snprintf(tid_buffer, sizeof(tid_buffer), " %7u",
-             absl::base_internal::GetTID());
-  }
-  // TODO(jeff,sanjay): Replace this with something that logs through the env.
-  fprintf(stderr, "%s.%06d: %c%s %s:%d] %s\n", time_buffer, micros_remainder,
-          "IWEF"[severity_], tid_buffer, fname_, line_, str().c_str());
-}
-#endif
-
 int64 LogMessage::MinVLogLevel() {
   static int64 min_vlog_level = MinVLogLevelFromEnv();
   return min_vlog_level;
@@ -393,4 +439,104 @@ bool LogEveryNSecState::ShouldLog(double seconds) {
 }
 
 }  // namespace internal
+
+void TFAddLogSink(TFLogSink* sink) {
+  internal::TFLogSinks::Instance().Add(sink);
+}
+
+void TFRemoveLogSink(TFLogSink* sink) {
+  internal::TFLogSinks::Instance().Remove(sink);
+}
+
+std::vector<TFLogSink*> TFGetLogSinks() {
+  return internal::TFLogSinks::Instance().GetSinks();
+}
+
+void TFDefaultLogSink::Send(const TFLogEntry& entry) {
+#ifdef PLATFORM_POSIX_ANDROID
+  int android_log_level;
+  switch (entry.log_severity()) {
+    case absl::LogSeverity::kInfo:
+      android_log_level = ANDROID_LOG_INFO;
+      break;
+    case absl::LogSeverity::kWarning:
+      android_log_level = ANDROID_LOG_WARN;
+      break;
+    case absl::LogSeverity::kError:
+      android_log_level = ANDROID_LOG_ERROR;
+      break;
+    case absl::LogSeverity::kFatal:
+      android_log_level = ANDROID_LOG_FATAL;
+      break;
+    default:
+      if (entry.log_severity() < absl::LogSeverity::kInfo) {
+        android_log_level = ANDROID_LOG_VERBOSE;
+      } else {
+        android_log_level = ANDROID_LOG_ERROR;
+      }
+      break;
+  }
+
+  std::stringstream ss;
+  const auto& fname = entry.FName();
+  auto pos = fname.rfind('/');
+  ss << (pos != std::string::npos ? fname.substr(pos + 1) : fname) << ":"
+     << entry.Line() << " " << entry.ToString();
+  __android_log_write(android_log_level, "native", ss.str().c_str());
+
+  // Also log to stderr (for standalone Android apps).
+  std::cerr << "native : " << ss.str() << std::endl;
+
+  // Android logging at level FATAL does not terminate execution, so abort()
+  // is still required to stop the program.
+  if (entry.log_severity() == absl::LogSeverity::kFatal) {
+    abort();
+  }
+#else   // PLATFORM_POSIX_ANDROID
+  static bool log_thread_id = internal::EmitThreadIdFromEnv();
+  uint64 now_micros = EnvTime::NowMicros();
+  time_t now_seconds = static_cast<time_t>(now_micros / 1000000);
+  int32 micros_remainder = static_cast<int32>(now_micros % 1000000);
+  const size_t time_buffer_size = 30;
+  char time_buffer[time_buffer_size];
+  strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S",
+           localtime(&now_seconds));
+  const size_t tid_buffer_size = 10;
+  char tid_buffer[tid_buffer_size] = "";
+  if (log_thread_id) {
+    snprintf(tid_buffer, sizeof(tid_buffer), " %7u",
+             absl::base_internal::GetTID());
+  }
+
+  char sev;
+  switch (entry.log_severity()) {
+    case absl::LogSeverity::kInfo:
+      sev = 'I';
+      break;
+
+    case absl::LogSeverity::kWarning:
+      sev = 'W';
+      break;
+
+    case absl::LogSeverity::kError:
+      sev = 'E';
+      break;
+
+    case absl::LogSeverity::kFatal:
+      sev = 'F';
+      break;
+
+    default:
+      assert(false && "Unknown logging severity");
+      sev = '?';
+      break;
+  }
+
+  // TODO(jeff,sanjay): Replace this with something that logs through the env.
+  fprintf(stderr, "%s.%06d: %c%s %s:%d] %s\n", time_buffer, micros_remainder,
+          sev, tid_buffer, entry.FName().c_str(), entry.Line(),
+          entry.ToString().c_str());
+#endif  // PLATFORM_POSIX_ANDROID
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index f60deb43683..56dd19c8aa0 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include <limits>
 #include <memory>
 #include <sstream>
+#include <vector>
 
 #include "absl/base/log_severity.h"
 #include "absl/strings/string_view.h"
@@ -477,15 +478,27 @@ class TFLogEntry {
   }
 
  public:
-  explicit TFLogEntry(int severity, absl::string_view log_line)
-      : severity_(AsAbslLogSeverity(severity)), log_line_(log_line) {}
+  explicit TFLogEntry(int severity, absl::string_view message)
+      : severity_(AsAbslLogSeverity(severity)), message_(message) {}
+
+  explicit TFLogEntry(int severity, absl::string_view fname, int line,
+                      absl::string_view message)
+      : severity_(AsAbslLogSeverity(severity)),
+        fname_(fname),
+        line_(line),
+        message_(message) {}
 
   absl::LogSeverity log_severity() const { return severity_; }
-  std::string ToString() const { return std::string(log_line_); }
+  std::string FName() const { return fname_; }
+  int Line() const { return line_; }
+  std::string ToString() const { return message_; }
+  absl::string_view text_message() const { return message_; }
 
  private:
   const absl::LogSeverity severity_;
-  const absl::string_view log_line_;
+  const std::string fname_;
+  int line_ = -1;
+  const std::string message_;
 };
 
 class TFLogSink {
@@ -513,10 +526,23 @@ class TFLogSink {
   virtual void WaitTillSent() {}
 };
 
+// This is the default log sink. This log sink is used if there are no other
+// log sinks registered. To disable the default log sink, set the
+// "no_default_logger" Bazel config setting to true or define a
+// NO_DEFAULT_LOGGER preprocessor symbol. This log sink will always log to
+// stderr.
+class TFDefaultLogSink : public TFLogSink {
+ public:
+  void Send(const TFLogEntry& entry) override;
+};
+
 // Add or remove a `LogSink` as a consumer of logging data.  Thread-safe.
 void TFAddLogSink(TFLogSink* sink);
 void TFRemoveLogSink(TFLogSink* sink);
 
+// Get all the log sinks.  Thread-safe.
+std::vector<TFLogSink*> TFGetLogSinks();
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
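
With TFLogSink now wired into LogMessage, a process can route TensorFlow logs anywhere by registering a sink; the test added below exercises exactly this. A minimal file-backed sketch using only the API declared above (the class name and path handling are illustrative, not part of this change):

    #include <fstream>
    #include <string>

    #include "tensorflow/core/platform/logging.h"

    class FileLogSink : public tensorflow::TFLogSink {
     public:
      explicit FileLogSink(const std::string& path) : out_(path) {}

      void Send(const tensorflow::TFLogEntry& entry) override {
        out_ << entry.FName() << ":" << entry.Line() << " "
             << entry.ToString() << "\n";
      }

      // Flush so a crash right after a FATAL message still leaves the line
      // on disk.
      void WaitTillSent() override { out_.flush(); }

     private:
      std::ofstream out_;
    };

    // Usage sketch: the sink must outlive its registration, since TFLogSinks
    // stores the raw pointer.
    //   FileLogSink sink("/tmp/tf.log");
    //   tensorflow::TFAddLogSink(&sink);
    //   LOG(INFO) << "also goes to /tmp/tf.log";
    //   tensorflow::TFRemoveLogSink(&sink);

When the default sink is compiled out via the no_default_logger setting mentioned above, up to 128 messages emitted before the first TFAddLogSink call are queued and delivered to that first sink.
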
diff --git a/tensorflow/core/platform/errors.h b/tensorflow/core/platform/errors.h
index 55af45a4c24..fbd3b518699 100644
--- a/tensorflow/core/platform/errors.h
+++ b/tensorflow/core/platform/errors.h
@@ -44,7 +44,7 @@ namespace internal {
 // Eventually absl::strings will have native support for this and we will be
 // able to completely remove PrepareForStrCat().
 template <typename T>
-typename std::enable_if<!std::is_constructible<strings::AlphaNum, T>::value,
+typename std::enable_if<!std::is_convertible<T, strings::AlphaNum>::value,
                         std::string>::type
 PrepareForStrCat(const T& t) {
   std::stringstream ss;
diff --git a/tensorflow/core/platform/logging_test.cc b/tensorflow/core/platform/logging_test.cc
index 75da1a4e484..54890bdd2f0 100644
--- a/tensorflow/core/platform/logging_test.cc
+++ b/tensorflow/core/platform/logging_test.cc
@@ -14,6 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/platform/logging.h"
+
+#include <sstream>
+#include <vector>
+
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -96,4 +100,33 @@ TEST(InternalLogString, Basic) {
   internal::LogString(__FILE__, __LINE__, INFO, "Hello there");
 }
 
+class TestSink : public TFLogSink {
+ public:
+  void Send(const TFLogEntry& entry) override {
+    ss_ << entry.text_message() << std::endl;
+  }
+
+  std::string Get() const { return ss_.str(); }
+
+ private:
+  std::stringstream ss_;
+};
+
+TEST(LogSinkTest, testLogSinks) {
+  const int sinks_initial_size = TFGetLogSinks().size();
+  TestSink sink;
+
+  TFAddLogSink(&sink);
+
+  EXPECT_EQ(TFGetLogSinks().size(), sinks_initial_size + 1);
+
+  LOG(INFO) << "Foo";
+  LOG(INFO) << "Bar";
+  EXPECT_EQ(sink.Get(), "Foo\nBar\n");
+
+  TFRemoveLogSink(&sink);
+
+  EXPECT_EQ(TFGetLogSinks().size(), sinks_initial_size);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
index e1dbaed1746..b33cd911afa 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
@@ -210,12 +210,9 @@ OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) {
   return result;
 }
 
-OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
-    const XPlane& device_trace, double peak_tera_flops_per_second,
-    double peak_hbm_bw_giga_bytes_per_second) {
+OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) {
   OpMetricsDb result;
-  DeviceOpMetricsDbBuilder device_op_metrics_db_builder(
-      &result, peak_tera_flops_per_second, peak_hbm_bw_giga_bytes_per_second);
+  DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result);
 
   int64 first_op_offset_ps = kint64max;
   int64 last_op_offset_ps = 0;
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
index f2d7fc702fc..93ab7339a37 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
@@ -50,9 +50,7 @@ void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst);
 
 OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace);
 
-OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
-    const XPlane& device_trace, double peak_tera_flops_per_second,
-    double peak_hbm_bw_giga_bytes_per_second);
+OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace);
 
 }  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
index 7d6f23db041..7be6277ba58 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
@@ -132,9 +132,7 @@ TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDb) {
                        kKernel3DurationNs, /*on_device=*/true, kKernel3,
                        &device_plane, &stream2);
 
-  OpMetricsDb op_metrics = ConvertDeviceTraceXPlaneToOpMetricsDb(
-      *xplane, /*peak_tera_flops_per_second=*/0,
-      /*peak_hbm_bw_giga_bytes_per_second=*/0);
+  OpMetricsDb op_metrics = ConvertDeviceTraceXPlaneToOpMetricsDb(*xplane);
 
   // kernel1, kernel2, kernel3, Idle.
   EXPECT_EQ(4, op_metrics.metrics_db_size());
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
index 6eb67eab216..e6b84b6b873 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
@@ -170,10 +170,8 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
       if (!op_stats.has_perf_env()) {
         *op_stats.mutable_perf_env() = GetPerfEnvFromXPlane(*device_trace);
       }
-      const PerfEnv& perf_env = op_stats.perf_env();
-      OpMetricsDb device_op_metrics_db = ConvertDeviceTraceXPlaneToOpMetricsDb(
-          *device_trace, perf_env.peak_tera_flops_per_second(),
-          perf_env.peak_hbm_bw_giga_bytes_per_second());
+      OpMetricsDb device_op_metrics_db =
+          ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace);
       op_metrics_db_combiner.Combine(device_op_metrics_db);
     }
     if (options.generate_step_db) {
diff --git a/tensorflow/core/profiler/protobuf/xplane.proto b/tensorflow/core/profiler/protobuf/xplane.proto
index f57d7609891..3936c44227b 100644
--- a/tensorflow/core/profiler/protobuf/xplane.proto
+++ b/tensorflow/core/profiler/protobuf/xplane.proto
@@ -118,7 +118,7 @@ message XStat {
 
 // Metadata for an XEvent, corresponds to an event type and is shared by
 // all XEvents with the same metadata_id.
-// Next ID: 6
+// Next ID: 7
 message XEventMetadata {
   // XPlane.event_metadata map key.
   int64 id = 1;
@@ -135,6 +135,9 @@ message XEventMetadata {
   // XStats that are constant for all XEvents with the same metadata_id.
   // Each of these XStats should have a different metadata_id.
   repeated XStat stats = 5;
+
+  // XPlane.event_metadata map key for children events.
+  repeated int64 child_id = 6;
 }
 
 // Metadata for an XStat, corresponds to a stat type and is shared by all
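
Producers populate the new field with the standard proto3 repeated-field accessors; a small sketch (assumed from the generated API for a repeated int64 field, not part of this change):

    // Links a child event type to its parent in the same XPlane.
    void LinkChildMetadata(tensorflow::profiler::XEventMetadata* parent,
                           const tensorflow::profiler::XEventMetadata& child) {
      parent->add_child_id(child.id());
    }
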
diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD
index ea37a8852a0..9bf0179859d 100644
--- a/tensorflow/core/profiler/rpc/client/BUILD
+++ b/tensorflow/core/profiler/rpc/client/BUILD
@@ -117,6 +117,7 @@ cc_library(
 tf_cc_test(
     name = "profiler_client_test",
     srcs = ["profiler_client_test.cc"],
+    tags = ["notap"],  # b/173824689
     deps = [
         ":profiler_client",
         ":profiler_client_impl",  # for oss
diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h
index 0aa606f0144..bb7c3103de5 100644
--- a/tensorflow/core/profiler/utils/op_utils.h
+++ b/tensorflow/core/profiler/utils/op_utils.h
@@ -49,12 +49,7 @@ class HostOpMetricsDbBuilder : public OpMetricsDbBuilder {
 
 class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder {
  public:
-  explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db,
-                                    double peak_tera_flops_per_second,
-                                    double peak_hbm_bw_giga_bytes_per_second)
-      : OpMetricsDbBuilder(db),
-        peak_tera_flops_per_second_(peak_tera_flops_per_second),
-        peak_hbm_bw_giga_bytes_per_second_(peak_hbm_bw_giga_bytes_per_second) {}
+  explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {}
 
   // A function that will be called when the end of an OP is
   // observed on a trace, where:
@@ -78,12 +73,6 @@ class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder {
                uint64 children_time_ps, int64 flops, int64 bytes_accessed,
                const protobuf::RepeatedPtrField<OpMetrics::MemoryAccessed>&
                    memory_accessed_breakdown = {});
-
- protected:
-  // Peak performance of a TensorCore or a GPU in TFLOP/s.
-  double peak_tera_flops_per_second_;
-  // Peak memory bandwidth of a TensorCore or a GPU in GiBs/s.
-  double peak_hbm_bw_giga_bytes_per_second_;
 };
 
 }  // namespace profiler
diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h
index e7ac97f3098..73dd8db577a 100644
--- a/tensorflow/core/profiler/utils/xplane_visitor.h
+++ b/tensorflow/core/profiler/utils/xplane_visitor.h
@@ -83,15 +83,15 @@ class XStatVisitor {
 template <class T>
 class XStatsOwner {
  public:
-  // REQUIRED: metadata and stats_owner cannot be nullptr.
-  XStatsOwner(const XPlaneVisitor* metadata, const T* stats_owner)
-      : stats_owner_(stats_owner), metadata_(metadata) {}
+  // REQUIRED: plane and stats_owner cannot be nullptr.
+  XStatsOwner(const XPlaneVisitor* plane, const T* stats_owner)
+      : plane_(plane), stats_owner_(stats_owner) {}
 
-  // For each plane level stats, call the specified lambda.
+  // For each stat, call the specified lambda.
   template <typename ForEachStatFunc>
   void ForEachStat(ForEachStatFunc&& for_each_stat) const {
     for (const XStat& stat : stats_owner_->stats()) {
-      for_each_stat(XStatVisitor(metadata_, &stat));
+      for_each_stat(XStatVisitor(plane_, &stat));
     }
   }
 
@@ -100,12 +100,35 @@ class XStatsOwner {
   // Prefer ForEachStat above when multiple stat values are necessary.
   absl::optional<XStatVisitor> GetStat(int64 stat_type) const;
 
+ protected:
+  const XPlaneVisitor* plane() const { return plane_; }
+  const T* stats_owner() const { return stats_owner_; }
+
  private:
+  const XPlaneVisitor* plane_;
   const T* stats_owner_;
-  const XPlaneVisitor* metadata_;
 };
 
-using XEventMetadataVisitor = XStatsOwner<XEventMetadata>;
+class XEventMetadataVisitor : public XStatsOwner<XEventMetadata> {
+ public:
+  // REQUIRED: plane and metadata cannot be nullptr.
+  XEventMetadataVisitor(const XPlaneVisitor* plane,
+                        const XEventMetadata* metadata)
+      : XStatsOwner(plane, metadata) {}
+
+  absl::string_view Name() const { return metadata()->name(); }
+
+  bool HasDisplayName() const { return !metadata()->display_name().empty(); }
+
+  absl::string_view DisplayName() const { return metadata()->display_name(); }
+
+  // For each child event metadata, call the specified lambda.
+  template <typename ForEachChildFunc>
+  void ForEachChild(ForEachChildFunc&& for_each_child) const;
+
+ private:
+  const XEventMetadata* metadata() const { return stats_owner(); }
+};
 
 class XEventVisitor : public XStatsOwner<XEvent> {
  public:
@@ -113,10 +136,6 @@ class XEventVisitor : public XStatsOwner<XEvent> {
   XEventVisitor(const XPlaneVisitor* plane, const XLine* line,
                 const XEvent* event);
 
-  XEventMetadataVisitor MetadataStats() const {
-    return XEventMetadataVisitor(plane_, metadata_);
-  }
-
   int64 Id() const { return event_->metadata_id(); }
 
   absl::string_view Name() const { return metadata_->name(); }
@@ -127,8 +146,6 @@ class XEventVisitor : public XStatsOwner<XEvent> {
 
   absl::string_view DisplayName() const { return metadata_->display_name(); }
 
-  absl::string_view Metadata() const { return metadata_->metadata(); }
-
   double OffsetNs() const { return PicosToNanos(event_->offset_ps()); }
 
   int64 OffsetPs() const { return event_->offset_ps(); }
@@ -157,6 +174,10 @@ class XEventVisitor : public XStatsOwner<XEvent> {
 
   const XEventMetadata* metadata() const { return metadata_; }
 
+  XEventMetadataVisitor Metadata() const {
+    return XEventMetadataVisitor(plane_, metadata_);
+  }
+
   Timespan GetTimespan() const { return Timespan(TimestampPs(), DurationPs()); }
 
  private:
@@ -262,17 +283,28 @@ class XPlaneVisitor : public XStatsOwner<XPlane> {
 
 template <class T>
 absl::optional<XStatVisitor> XStatsOwner<T>::GetStat(int64 stat_type) const {
-  const auto* stat_metadata = metadata_->GetStatMetadataByType(stat_type);
+  const auto* stat_metadata = plane_->GetStatMetadataByType(stat_type);
   if (stat_metadata != nullptr) {
     for (const XStat& stat : stats_owner_->stats()) {
       if (stat.metadata_id() == stat_metadata->id()) {
-        return XStatVisitor(metadata_, &stat, stat_metadata, stat_type);
+        return XStatVisitor(plane_, &stat, stat_metadata, stat_type);
       }
     }
   }
   return absl::nullopt;  // type does not exist in this owner.
 }
 
+template <typename ForEachChildFunc>
+void XEventMetadataVisitor::ForEachChild(
+    ForEachChildFunc&& for_each_child) const {
+  for (int64 child_id : metadata()->child_id()) {
+    const auto* event_metadata = plane()->GetEventMetadata(child_id);
+    if (event_metadata != nullptr) {
+      for_each_child(XEventMetadataVisitor(plane(), event_metadata));
+    }
+  }
+}
+
 }  // namespace profiler
 }  // namespace tensorflow
 
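Together, child_id and the new visitor enable walking the event-metadata hierarchy. An illustrative recursive traversal (the helper name is hypothetical; it relies only on Metadata(), Name(), and ForEachChild() introduced above):

    #include <string>

    using tensorflow::profiler::XEventMetadataVisitor;

    void LogMetadataTree(const XEventMetadataVisitor& metadata, int depth) {
      LOG(INFO) << std::string(2 * depth, ' ') << metadata.Name();
      metadata.ForEachChild([depth](XEventMetadataVisitor child) {
        LogMetadataTree(child, depth + 1);
      });
    }

    // Entry point for an XEventVisitor `event`:
    //   LogMetadataTree(event.Metadata(), /*depth=*/0);
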
diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD
index d1cbf872087..ef968563ba8 100644
--- a/tensorflow/core/protobuf/BUILD
+++ b/tensorflow/core/protobuf/BUILD
@@ -141,6 +141,7 @@ exports_files(
         "snapshot.proto",
         "service_config.proto",
         "debug_event.proto",
+        "extension_type_variant.proto",
         "meta_graph.proto",
         "named_tensor.proto",
         "remote_tensor_handle.proto",
@@ -170,6 +171,7 @@ tf_proto_library(
         "snapshot.proto",
         "service_config.proto",
         "debug_event.proto",
+        "extension_type_variant.proto",
         "meta_graph.proto",
         "named_tensor.proto",
         "remote_tensor_handle.proto",
diff --git a/tensorflow/core/protobuf/extension_type_variant.proto b/tensorflow/core/protobuf/extension_type_variant.proto
new file mode 100644
index 00000000000..536db3b2435
--- /dev/null
+++ b/tensorflow/core/protobuf/extension_type_variant.proto
@@ -0,0 +1,14 @@
+syntax = "proto3";
+
+package tensorflow;
+
+import "tensorflow/core/protobuf/struct.proto";
+
+// Metadata for ExtensionTypeVariant, used when serializing as Variant.
+//
+// We define a new message here (rather than directly using TypeSpecProto for
+// the metadata string) to retain flexibility to change the metadata encoding
+// to support additional features.
+message ExtensionTypeVariantMetadata {
+  TypeSpecProto type_spec_proto = 1;
+}
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 20e650de6d1..41478958118 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@ limitations under the License.
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 589  // Updated: 2020/11/18
+#define TF_GRAPH_DEF_VERSION 594  // Updated: 2020/11/23
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index bd20924ff23..ca85e3fb076 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -653,14 +653,75 @@ struct ShardedInputInfo {
   std::vector<NodeOut> sharded_inputs;
 };
 
+// Adds a pad node after the split node to the graph for unevenly sharded
+// tiled inputs. |graph| owns the returned Node* instance.
+xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
+                                   const int split_dim, DataType dtype,
+                                   Node* control_predecessor, Node* split_node,
+                                   const int split_index, Graph* graph) {
+  // Add paddings node.
+  Status s;
+  NodeDef paddings_def;
+  paddings_def.set_name(
+      graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
+  paddings_def.set_op("Const");
+  AddNodeAttr("dtype", DT_INT32, &paddings_def);
+  paddings_def.set_device(split_node->assigned_device_name());
+  TensorProto sizes_tensor_proto;
+  sizes_tensor_proto.set_dtype(DT_INT32);
+  for (int i = 0; i < num_dims; ++i) {
+    sizes_tensor_proto.add_int_val(0);
+    if (i == split_dim) {
+      sizes_tensor_proto.add_int_val(padding);
+    } else {
+      sizes_tensor_proto.add_int_val(0);
+    }
+  }
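+  // Worked example (illustrative): num_dims = 2, split_dim = 1, padding = 3
+  // yields the flattened values {0, 0, 0, 3}, i.e. with shape {2, 2} the
+  // paddings matrix [[0, 0], [0, 3]]: nothing is padded before any dimension
+  // and 3 elements are appended at the end of dim 1.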
+  TensorShape sizes_shape({num_dims, 2});
+  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
+  AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
+  Node* paddings_node = graph->AddNode(paddings_def, &s);
+  TF_RETURN_IF_ERROR(s);
+
+  // Add Pad node.
+  NodeDef pad_def;
+  pad_def.set_name(graph->NewName(
+      absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
+  pad_def.set_op("Pad");
+  pad_def.set_device(split_node->assigned_device_name());
+  AddNodeAttr("T", dtype, &pad_def);
+  AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
+  pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
+  pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
+  Node* pad_node = graph->AddNode(pad_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  pad_node->set_assigned_device_name(split_node->assigned_device_name());
+  // Add edges for pad node.
+  graph->AddEdge(split_node, split_index, pad_node, 0);
+  graph->AddEdge(paddings_node, 0, pad_node, 1);
+  graph->AddControlEdge(control_predecessor, pad_node);
+  return pad_node;
+}
+
 // Adds split node and split dimension node to graph for sharding tiled inputs.
 // |graph| owns the returned Node* instance.
-xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim,
-                                     int orig_src_output, DataType dtype,
+xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
+                                     const int num_dims, const int64 padding,
+                                     const int orig_src_output, DataType dtype,
                                      absl::string_view name_prefix,
                                      Node* control_predecessor, Node* orig_src,
                                      Graph* graph) {
   const std::string input_assigned_device = orig_src->assigned_device_name();
+  Node* to_split_node = orig_src;
+  int to_split_index = orig_src_output;
+  if (padding > 0) {
+    TF_ASSIGN_OR_RETURN(
+        Node * pad_node,
+        CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
+                      orig_src, orig_src_output, graph));
+    to_split_node = pad_node;
+    to_split_index = 0;
+  }
 
   // Add a split dimension node.
   NodeDef split_dim_def;
@@ -686,28 +747,48 @@ xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim,
   AddNodeAttr("num_split", num_splits, &split_def);
   AddNodeAttr("T", dtype, &split_def);
   split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
-  split_def.add_input(absl::StrCat(orig_src->name(), ":", orig_src_output));
+  split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
   Node* split_node = graph->AddNode(split_def, &s);
-  split_node->set_assigned_device_name(input_assigned_device);
   TF_RETURN_IF_ERROR(s);
 
+  split_node->set_assigned_device_name(input_assigned_device);
+
+  // Colocate the newly created split op with the source node of the input to
+  // the TPU computation.
+  split_node->AddAttr(kColocationAttrName,
+                      std::vector<string>{absl::StrCat(kColocationGroupPrefix,
+                                                       orig_src->name())});
+
   graph->AddEdge(split_dim_node, 0, split_node, 0);
-  graph->AddEdge(orig_src, orig_src_output, split_node, 1);
+  graph->AddEdge(to_split_node, to_split_index, split_node, 1);
 
   // Add a control dependency from `control_predecessor` to newly created
   // constant node. This ensures that newly added split/split dim
   // nodes are placed inside correct while loop frames when TPUExecute
   // node is inside a host training loop.
   graph->AddControlEdge(control_predecessor, split_dim_node);
-
   return split_node;
 }
 
+int64 GetPadding(const int split_dim, const int num_splits,
+                 const PartialTensorShape& partial_tensor_shape) {
+  // If the size of the split dimension is not statically known, uneven
+  // sharding is not supported.
+  if (partial_tensor_shape.dim_size(split_dim) <= 0) {
+    return 0;
+  }
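+  // Round the split dimension up to a multiple of num_splits; e.g., a
+  // dimension of size 10 split 4 ways gives a per-split size of 3 and a total
+  // padding of 2.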
+  int64 per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>(
+      partial_tensor_shape.dim_size(split_dim), num_splits);
+  int64 total_padding =
+      per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
+  return total_padding;
+}
+
 // Creates a set of split nodes that shard a tiled input node in the graph.
 xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
     const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
-    int replica_id, int orig_src_output, Node* orig_src,
-    Node* control_predecessor, Graph* graph,
+    const PartialTensorShape& partial_tensor_shape, int replica_id,
+    int orig_src_output, Node* orig_src, Node* control_predecessor,
+    Graph* graph,
     std::map<ShardedInputIndex, ShardedInputInfo>*
         arg_index_to_sharded_input_map) {
   ShardedInputIndex input_index{replica_id, orig_arg_num};
@@ -738,6 +819,7 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
 
   auto sharding_it = split_dimension_map.begin();
   std::queue<Node*> split_nodes_for_dimension;
+  absl::flat_hash_map<Node*, int> node_to_split_dim;
   int split_dimension = sharding_it->first;
   int num_split = sharding_it->second;
 
@@ -747,13 +829,17 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
   // that split the input data at ith dimension.
   TF_ASSIGN_OR_RETURN(
       Node * root_split_node,
-      CreateSplitNode(num_split, split_dimension, orig_src_output, dtype,
-                      absl::StrCat("sharded_input/replica_", replica_id,
-                                   "_dim_", split_dimension),
-                      control_predecessor, orig_src, graph));
+      CreateSplitNode(
+          num_split, split_dimension, partial_tensor_shape.dims(),
+          GetPadding(split_dimension, num_split, partial_tensor_shape),
+          orig_src_output, dtype,
+          absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
+                       split_dimension),
+          control_predecessor, orig_src, graph));
   sharding_it++;
 
   split_nodes_for_dimension.emplace(root_split_node);
+  node_to_split_dim[root_split_node] = split_dimension;
 
   while (sharding_it != split_dimension_map.end()) {
     split_dimension = sharding_it->first;
@@ -767,11 +853,15 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
            ++src_output_index) {
         TF_ASSIGN_OR_RETURN(
             Node * split_node,
-            CreateSplitNode(num_split, split_dimension, src_output_index, dtype,
-                            absl::StrCat("sharded_input/replica_", replica_id,
-                                         "_dim_", split_dimension),
-                            control_predecessor, input_split_node, graph));
+            CreateSplitNode(
+                num_split, split_dimension, partial_tensor_shape.dims(),
+                GetPadding(split_dimension, num_split, partial_tensor_shape),
+                src_output_index, dtype,
+                absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
+                             split_dimension),
+                control_predecessor, input_split_node, graph));
         split_nodes_for_dimension.emplace(split_node);
+        node_to_split_dim[split_node] = split_dimension;
       }
     }
     sharding_it++;
@@ -856,18 +946,82 @@ xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
   return concat_node;
 }
 
+// Adds a Slice node after the Concat node to trim the padding introduced for
+// uneven sharding of tiled inputs.
+xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
+                                     const PartialTensorShape& shape,
+                                     Node* concat_node,
+                                     const int concat_out_index, Graph* graph,
+                                     absl::string_view device) {
+  Status s;
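+  // The slice begins at offset 0 in every dimension and keeps
+  // shape.dim_size(i) elements, dropping any trailing padding.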
+  // Add begin node for slice.
+  NodeDef begin_def;
+  begin_def.set_name(
+      graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
+  begin_def.set_op("Const");
+  AddNodeAttr("dtype", DT_INT32, &begin_def);
+  begin_def.set_device(std::string(device));
+  TensorProto begin_tensor_proto;
+  begin_tensor_proto.set_dtype(DT_INT32);
+  for (int i = 0; i < shape.dims(); ++i) {
+    begin_tensor_proto.add_int_val(0);
+  }
+  TensorShape begin_shape({shape.dims()});
+  begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
+  AddNodeAttr("value", begin_tensor_proto, &begin_def);
+  Node* begin_node = graph->AddNode(begin_def, &s);
+  TF_RETURN_IF_ERROR(s);
+
+  // Add size node.
+  NodeDef size_def;
+  size_def.set_name(
+      graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
+  size_def.set_op("Const");
+  AddNodeAttr("dtype", DT_INT32, &size_def);
+  size_def.set_device(std::string(device));
+  TensorProto sizes_tensor_proto;
+  sizes_tensor_proto.set_dtype(DT_INT32);
+  for (int i = 0; i < shape.dims(); ++i) {
+    sizes_tensor_proto.add_int_val(shape.dim_size(i));
+  }
+  TensorShape sizes_shape({shape.dims()});
+  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
+  AddNodeAttr("value", sizes_tensor_proto, &size_def);
+  Node* size_node = graph->AddNode(size_def, &s);
+  TF_RETURN_IF_ERROR(s);
+
+  // Add Slice node.
+  NodeDef slice_def;
+  slice_def.set_name(
+      graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
+  slice_def.set_op("Slice");
+  slice_def.set_device(std::string(device));
+  AddNodeAttr("T", dtype, &slice_def);
+  AddNodeAttr("Index", DT_INT32, &slice_def);
+  slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
+  slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
+  slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
+  Node* slice_node = graph->AddNode(slice_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  // Add edges for slice node.
+  graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
+  graph->AddEdge(begin_node, 0, slice_node, 1);
+  graph->AddEdge(size_node, 0, slice_node, 2);
+  return slice_node;
+}
+
 // Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
 // nodes into a single output. Sharded outputs are concatenated along row major
 // order. That is, tiled output along 0th dimension will be concatenated last.
 xla::StatusOr<Node*> CreateConcatNodesForRetval(
-    const xla::OpSharding& sharding, DataType dtype, int replica_id,
+    const xla::OpSharding& sharding, DataType dtype,
+    const PartialTensorShape& inferred_shape, int replica_id,
     const std::vector<NodeOut>& orig_inputs, Graph* graph,
     absl::string_view device) {
   std::map<int, int> split_dimension_map;
   TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
       sharding, &split_dimension_map));
-
   std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
+  bool has_paddings = false;
 
   for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
        it++) {
@@ -891,12 +1045,21 @@ xla::StatusOr<Node*> CreateConcatNodesForRetval(
               dim, num_splits, dtype,
               absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
               inputs, graph, device));
+      int64 paddings = GetPadding(dim, num_splits, inferred_shape);
+      has_paddings |= paddings > 0;
       new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
     }
     inputs_to_sharded_retval = new_concat_nodes;
   }
 
   TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
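+  // If any dimension was padded so that it could be split evenly, slice the
+  // concatenated result back down to the inferred (unpadded) shape.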
+  if (has_paddings) {
+    TF_ASSIGN_OR_RETURN(Node * slice_node,
+                        CreateSliceNode(dtype, inferred_shape,
+                                        inputs_to_sharded_retval.at(0).node,
+                                        /*concat_out_index=*/0, graph, device));
+    return slice_node;
+  }
   return inputs_to_sharded_retval.at(0).node;
 }
 
@@ -2759,7 +2922,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
               TF_ASSIGN_OR_RETURN(
                   ShardedInputInfo sharded_input_info,
                   CreateOrGetSplitNodesForInputSharding(
-                      sharding, orig_arg_num, dtype, replica,
+                      sharding, orig_arg_num, dtype,
+                      arg_shapes[orig_arg_num].handle_shape, replica,
                       edge->src_output(), edge->src(), control_predecessor,
                       graph, &input_index_to_sharded_inputs));
               NodeOut split_node_and_index =
@@ -2840,7 +3004,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                   ShardedInputInfo sharded_input_info,
                   CreateOrGetSplitNodesForInputSharding(
                       sharding, orig_arg_num,
-                      arg_shapes[orig_arg_num].handle_type, replica,
+                      arg_shapes[orig_arg_num].handle_type,
+                      arg_shapes[orig_arg_num].handle_shape, replica,
                       var_data.index, var_data.node, control_predecessor, graph,
                       &input_index_to_sharded_inputs));
               NodeOut split_node_and_index =
@@ -2927,8 +3092,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
           DataType dtype = e->src()->output_type(e->src_output());
           TF_ASSIGN_OR_RETURN(
               Node * concat_node,
-              CreateConcatNodesForRetval(sharding, dtype, replica, orig_inputs,
-                                         graph, /*device=*/""));
+              CreateConcatNodesForRetval(
+                  sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
+                  replica, orig_inputs, graph, /*device=*/""));
 
           const Edge* edge = replicate_output_edges[output_num];
           Node* dst = edge->dst();
@@ -3009,8 +3175,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
             TF_ASSIGN_OR_RETURN(
                 Node * concat_node,
                 CreateConcatNodesForRetval(
-                    sharding, arg_shapes[orig_arg_num].handle_type, replica,
-                    orig_inputs, graph, device));
+                    sharding, arg_shapes[orig_arg_num].handle_type,
+                    arg_shapes[orig_arg_num].handle_shape, replica, orig_inputs,
+                    graph, device));
             // Populate VariableWrite.
             VariableWrite& write = variable_writes->at(core_variable_writes[i]);
             write.value = concat_node;
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
index 5ea92d007d9..f1f12a716da 100644
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -329,10 +329,20 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
           server_address_output);
     });
     size_t server_address_output_size;
+
+    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
+    params.struct_size =
+        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
+    params.priv = nullptr;
+    params.tpu_host_config_size = tpu_host_config.size();
+    params.tpu_host_config = tpu_host_config.data();
+    params.server_address_output_size = &server_address_output_size;
+    params.server_address_output = &server_address_output;
+    params.status = status;
+
     tpu::OpsApiFn()
         ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
-            tpu_host_config.size(), tpu_host_config.data(),
-            &server_address_output_size, &server_address_output, status);
+            &params);
     OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
 
     std::string server_address(server_address_output,
diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc
index 72993baf57b..bccfa085d5e 100644
--- a/tensorflow/core/tpu/kernels/tpu_pod_state.cc
+++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc
@@ -78,9 +78,16 @@ Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
   });
   size_t server_address_output_size;
   *serving_port = -1;
-  tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
-      &server_address_output_size, &server_address_output, serving_port,
-      status);
+
+  TpuConfigurationApi_GetServerAddressAndPort_Params params;
+  params.struct_size = TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE;
+  params.priv = nullptr;
+  params.server_address_output_size = &server_address_output_size;
+  params.server_address_output = &server_address_output;
+  params.port_output = serving_port;
+  params.status = status;
+
+  tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(&params);
   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
   *server_address =
       std::string(server_address_output, server_address_output_size);
diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
index fc15d71dfd8..52fa3bfa8f1 100644
--- a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
+++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
@@ -22,58 +22,6 @@ limitations under the License.
 
 namespace tensorflow {
 namespace {
-// TODO(b/32945756): Add a scatter op in XLA and move this to a HLO optimization
-// pass. Optimization for UnsortedSegmentSum on TPU: use k-hot matmul. This
-// optimization requires:
-//     1. data has dtype supported by TPU matmul and has rank of 1 or 2.
-//     2. indices has rank of 1.
-//     3. matmul op count is less than 800 billion.
-//
-// Example of calculating UnsortedSegmentSum by k-hot matmul:
-//     data shape        [A, B]
-//     indices shape     [A]
-//     num_segment        N
-//     output shape      [N, B]
-//     matmul op count    N * A * B
-// Step 1: create k-hot matrix
-//     k-hot matrix has shape of [A, N], where row i is responsible for
-//     collecting the sum of the i-th segment, concretely
-//            k-hot[i][j] = 1 if indices[i] = j
-// Step 2: perform matmul
-//     the final result is obtained by multiplying k-hot matrix with data
-//     matrix, namely
-//             k-hot  *  data   => result
-// shape:      [N, A] *  [A, B] => [N, B]
-xla::XlaOp KHotMatmul(XlaOpKernelContext* ctx, xla::XlaBuilder* builder,
-                      const xla::XlaOp data, const xla::XlaOp indices,
-                      int64 num_segments) {
-  DataType data_dtype = ctx->input_type(0);
-  xla::PrimitiveType indices_type = ctx->input_xla_type(1);
-  TensorShape data_shape = ctx->InputShape(0);
-  TensorShape indices_shape = ctx->InputShape(1);
-  xla::XlaOp linspace = xla::Iota(builder, indices_type, num_segments);
-  xla::XlaOp linspace_col = xla::Reshape(linspace, {num_segments, 1});
-  TensorShape indices_row_shape = indices_shape;
-  indices_row_shape.InsertDim(0, 1);
-  xla::XlaOp indices_row = xla::Reshape(indices, indices_row_shape.dim_sizes());
-  xla::XlaOp k_hot = xla::Eq(indices_row, linspace_col);
-  xla::XlaOp k_hot_with_data_dtype =
-      XlaHelpers::ConvertElementType(k_hot, data_dtype);
-  // F32 version of the KHotMatmul. It splits the F32 data into three
-  // BF16 partial data and run KHotMatmul for each of them. The final result
-  // is the summation of three BF16 results.
-  // Note that this still doesn't fully retain f32 precision.
-  // In particular, values smaller than 2^-111 may see loss of precision.
-  xla::PrecisionConfig precision_config;
-  if (data_dtype == DT_FLOAT) {
-    precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
-  } else {
-    CHECK_EQ(data_dtype, DT_BFLOAT16);
-    precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
-  }
-  precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
-  return xla::Dot(k_hot_with_data_dtype, data, &precision_config);
-}
 
 class UnsortedSegmentSum : public XlaOpKernel {
  public:
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index 45dd2d1e88d..f49438bde85 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -240,14 +240,42 @@ TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                           TF_Status* status);
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
     int64_t* cache_size_in_bytes);
+
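+// The parameter structs below follow the struct_size/priv convention used by
+// other TF C APIs: callers set `struct_size` to the matching *_Params_SIZE
+// constant and `priv` to nullptr (see tpu_configuration_ops.cc), so the
+// structs can be extended over time without breaking the C ABI.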
+typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t tpu_host_config_size;
+  const char* tpu_host_config;
+
+  size_t* server_address_output_size;  // out
+  char** server_address_output;        // out
+  TF_Status* status;                   // out
+} TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params;
+
+#define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \
+  (sizeof(                                                                   \
+      struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params))
+
 TFTPU_CAPI_EXPORT
 void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
-    size_t tpu_host_config_size, const char* tpu_host_config,
-    size_t* server_address_output_size, char** server_address_output,
-    TF_Status* status);
+    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params);
+
+typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t* server_address_output_size;  // out
+  char** server_address_output;        // out
+  int* port_output;                    // out
+  TF_Status* status;                   // out
+} TpuConfigurationApi_GetServerAddressAndPort_Params;
+
+#define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \
+  (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params))
+
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
-    size_t* server_address_output_size, char** server_address_output,
-    int* port_output, TF_Status* status);
+    TpuConfigurationApi_GetServerAddressAndPort_Params* params);
 
 // Creates a new TPU program.
 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index b725be34844..5ac978ff0b8 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -69,6 +69,11 @@ FRAMEWORK_LIB_HDRS = [
     "stderr_reporter.h",
 ]
 
+exports_files(
+    FRAMEWORK_LIB_HDRS,
+    visibility = ["//tensorflow/lite/core/shims:__subpackages__"],
+)
+
 cc_library(
     name = "version",
     hdrs = ["version.h"],
@@ -164,9 +169,7 @@ cc_library(
 
 cc_library(
     name = "builtin_op_data",
-    hdrs = [
-        "builtin_op_data.h",
-    ],
+    hdrs = ["builtin_op_data.h"],
     deps = ["//tensorflow/lite/c:common"],
 )
 
@@ -225,16 +228,10 @@ cc_library(
     ],
 )
 
+# The library that implements the full C++ API.
+# See also 'framework' below, which is the corresponding public target.
 cc_library(
     name = "framework_lib",
-    srcs = [
-        "core/subgraph.cc",
-        "graph_info.cc",
-        "interpreter.cc",
-        "interpreter_builder.cc",
-        "model_builder.cc",
-        "optional_debug_tools.cc",
-    ],
     hdrs = FRAMEWORK_LIB_HDRS,
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
@@ -244,9 +241,99 @@ cc_library(
     deps = [
         ":allocation",
         ":arena_planner",
+        ":cc_api",
         ":external_cpu_backend_context",
         ":graph_info",
         ":kernel_api",
+        ":macros",
+        ":memory_planner",
+        ":minimal_logging",
+        ":mutable_op_resolver",
+        ":optional_debug_tools",
+        ":shared_library",
+        ":simple_memory_arena",
+        ":stderr_reporter",
+        ":string",
+        ":type_to_tflitetype",
+        ":util",
+        ":version",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/api:verifier",
+        "//tensorflow/lite/delegates:status",
+        "//tensorflow/lite/experimental/resource",
+        "//tensorflow/lite/kernels/internal:compatibility",
+        "//tensorflow/lite/profiling:platform_profiler",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/schema:schema_utils",
+    ],
+    alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
+)
+
+# The public target for the full C++ API.
+# The deps listed here, other than ":framework_lib", are the interface dependencies
+# (dependencies required by the header files).
+# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts.
+cc_library(
+    name = "framework",
+    srcs = [],
+    hdrs = FRAMEWORK_LIB_HDRS,
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
+    deps = [
+        ":allocation",
+        ":arena_planner",
+        ":cc_api",
+        ":external_cpu_backend_context",
+        ":framework_lib",
+        ":graph_info",
+        ":memory_planner",
+        ":minimal_logging",
+        ":simple_memory_arena",
+        ":string",
+        ":type_to_tflitetype",
+        ":util",
+        ":version",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/api:verifier",
+        "//tensorflow/lite/experimental/resource",
+        "//tensorflow/lite/schema:schema_fbs",
+    ],
+)
+
+# The key parts of the C++ API.  This target defines the TF Lite classes for
+# loading models and interpreting them.
+cc_library(
+    name = "cc_api",
+    srcs = [
+        "core/subgraph.cc",
+        "graph_info.cc",
+        "interpreter.cc",
+        "interpreter_builder.cc",
+        "model_builder.cc",
+    ],
+    hdrs = [
+        "core/subgraph.h",
+        "graph_info.h",
+        "interpreter.h",
+        "interpreter_builder.h",
+        "model.h",
+        "model_builder.h",
+    ],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
+    visibility = [
+        "//tensorflow/lite/core/shims:__subpackages__",
+        "//tensorflow/lite/kernels:__subpackages__",
+    ],
+    deps = [
+        ":allocation",
+        ":arena_planner",
+        ":external_cpu_backend_context",
+        ":graph_info",
+        ":kernel_api",
+        ":macros",
         ":memory_planner",
         ":minimal_logging",
         ":mutable_op_resolver",
@@ -267,34 +354,24 @@ cc_library(
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
     ],
-    alwayslink = 1,
+    alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
 )
 
-# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts.
 cc_library(
-    name = "framework",
+    name = "optional_debug_tools",
     srcs = [
+        "optional_debug_tools.cc",
     ],
-    hdrs = FRAMEWORK_LIB_HDRS,
+    hdrs = ["optional_debug_tools.h"],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
+    visibility = [
+        "//visibility:public",
+    ],
     deps = [
-        ":allocation",
-        ":arena_planner",
-        ":external_cpu_backend_context",
-        ":framework_lib",
-        ":graph_info",
-        ":memory_planner",
-        ":minimal_logging",
-        ":simple_memory_arena",
-        ":string",
-        ":type_to_tflitetype",
-        ":util",
-        ":version",
+        ":macros",
+        "//tensorflow/lite:cc_api",
         "//tensorflow/lite/c:common",
-        "//tensorflow/lite/core/api",
-        "//tensorflow/lite/core/api:verifier",
-        "//tensorflow/lite/experimental/resource",
         "//tensorflow/lite/schema:schema_fbs",
     ],
 )
@@ -448,9 +525,6 @@ cc_test(
     size = "small",
     srcs = ["string_util_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
-    tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
-    ],
     deps = [
         ":framework",
         ":string_util",
@@ -469,7 +543,7 @@ cc_test(
     ],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
     tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
+        "tflite_not_portable_ios",  # TODO(b/173711739)
         "tflite_smoke_test",
     ],
     deps = [
@@ -497,9 +571,6 @@ cc_test(
     size = "small",
     srcs = ["graph_info_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
-    tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
-    ],
     deps = [
         ":framework",
         "//tensorflow/lite/testing:util",
@@ -513,9 +584,6 @@ cc_test(
     size = "small",
     srcs = ["simple_memory_arena_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
-    tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
-    ],
     deps = [
         ":simple_memory_arena",
         "//tensorflow/core:tflite_portable_logging",
@@ -610,9 +678,6 @@ cc_test(
     size = "small",
     srcs = ["mutable_op_resolver_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
-    tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
-    ],
     deps = [
         ":framework",
         "//tensorflow/lite/testing:util",
@@ -647,9 +712,6 @@ cc_test(
     size = "small",
     srcs = ["util_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
-    tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
-    ],
     deps = [
         ":util",
         "//tensorflow/lite/c:common",
@@ -717,7 +779,7 @@ cc_test(
     size = "small",
     srcs = ["minimal_logging_test.cc"],
     tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
+        "tflite_not_portable_ios",  # TODO(b/173711739)
     ],
     deps = [
         ":minimal_logging",
@@ -735,6 +797,7 @@ cc_library(
 cc_library(
     name = "macros",
     hdrs = ["core/macros.h"],
+    compatible_with = get_compatible_with_portable(),
 )
 
 cc_library(
diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt
index ebf823fa57d..3f2bb836320 100644
--- a/tensorflow/lite/CMakeLists.txt
+++ b/tensorflow/lite/CMakeLists.txt
@@ -27,6 +27,12 @@
 # - Host Tools (i.e conversion / analysis tools etc.)
 
 cmake_minimum_required(VERSION 3.16)
+if(NOT CMAKE_BUILD_TYPE)
+  message(STATUS "Setting build type to Release, for debug builds use"
+    "'-DCMAKE_BUILD_TYPE=Debug'.")
+  set(CMAKE_BUILD_TYPE "Release")
+endif()
+
 # Double colon in target name means ALIAS or IMPORTED target.
 cmake_policy(SET CMP0028 NEW)
 # Enable MACOSX_RPATH (@rpath) for built dynamic libraries.
@@ -197,9 +203,14 @@ endif()
 if(TFLITE_ENABLE_GPU)
   find_package(opencl_headers REQUIRED)
   find_package(vulkan_headers REQUIRED)
+  # The Android NDK already provides the OpenGL and EGL headers.
+  if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
+    find_package(opengl_headers REQUIRED)
+    find_package(egl_headers REQUIRED)
+  endif()
   populate_tflite_source_vars(
     "delegates/gpu/cl" TFLITE_DELEGATES_GPU_CL_SRCS
-    FILTER "(_test|gl_interop|egl_sync)\\.(cc|h)$"
+    FILTER "(_test|gl_interop|gpu_api_delegate|egl_sync)\\.(cc|h)$"
   )
   populate_tflite_source_vars(
     "delegates/gpu/cl/kernels" TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS
diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD
index e8db0dcf440..504655f7b28 100644
--- a/tensorflow/lite/c/BUILD
+++ b/tensorflow/lite/c/BUILD
@@ -68,7 +68,7 @@ cc_library(
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels/internal:compatibility",
     ],
-    alwayslink = 1,
+    alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
 )
 
 cc_library(
@@ -83,7 +83,7 @@ cc_library(
         "//tensorflow/lite:framework",
         "//tensorflow/lite:kernel_api",
     ],
-    alwayslink = 1,
+    alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
 )
 
 cc_test(
@@ -127,8 +127,8 @@ cc_library(
         "common.h",
     ],
     compatible_with = get_compatible_with_portable(),
-    deps = ["//tensorflow/lite:builtin_ops"],
-    alwayslink = 1,
+    copts = tflite_copts(),
+    alwayslink = 1,  # Why?? TODO(b/161243354): eliminate this.
 )
 
 # For use with library targets that can't use relative paths.
diff --git a/tensorflow/lite/core/shims/BUILD b/tensorflow/lite/core/shims/BUILD
new file mode 100644
index 00000000000..05b2566eef3
--- /dev/null
+++ b/tensorflow/lite/core/shims/BUILD
@@ -0,0 +1,178 @@
+# Description: this package contains shim library targets that forward
+# to the TF Lite C and C++ API targets.  See README.md.
+
+load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+load(":build_defs.bzl", "build_test")
+
+package(
+    default_visibility = ["//visibility:private"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+TFLITE_DEFAULT_COPTS = if_not_windows([
+    "-Wall",
+    "-Wno-comment",
+    "-Wno-extern-c-compat",
+])
+
+#------------------------------------------------------------------------------
+# C++ API
+
+FRAMEWORK_LIB_HDRS = [
+    "//tensorflow/lite:allocation.h",
+    "//tensorflow/lite:context.h",
+    "//tensorflow/lite:context_util.h",
+    "//tensorflow/lite:core/macros.h",
+    "//tensorflow/lite:core/subgraph.h",
+    "//tensorflow/lite:error_reporter.h",
+    "//tensorflow/lite:graph_info.h",
+    "//tensorflow/lite:mutable_op_resolver.h",
+    "//tensorflow/lite:op_resolver.h",
+    "//tensorflow/lite:optional_debug_tools.h",
+    "//tensorflow/lite:stderr_reporter.h",
+]
+
+CC_API_HDRS = [
+    "cc/interpreter.h",
+    "cc/interpreter_builder.h",
+    "cc/model.h",
+    "cc/model_builder.h",
+]
+
+cc_library(
+    name = "framework",
+    srcs = [],
+    hdrs = FRAMEWORK_LIB_HDRS + CC_API_HDRS,
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/lite:allocation",
+        "//tensorflow/lite:arena_planner",
+        "//tensorflow/lite:external_cpu_backend_context",
+        "//tensorflow/lite:framework_lib",
+        "//tensorflow/lite:graph_info",
+        "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:memory_planner",
+        "//tensorflow/lite:minimal_logging",
+        "//tensorflow/lite:simple_memory_arena",
+        "//tensorflow/lite:string",
+        "//tensorflow/lite:type_to_tflitetype",
+        "//tensorflow/lite:util",
+        "//tensorflow/lite:version",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/api:verifier",
+        "//tensorflow/lite/experimental/resource",
+        "//tensorflow/lite/schema:schema_fbs",
+    ],
+)
+
+cc_library(
+    name = "cc_api",
+    srcs = [],
+    hdrs = CC_API_HDRS,
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
+    visibility = ["//tensorflow/lite:__pkg__"],
+    deps = [
+        "//tensorflow/lite:allocation",
+        "//tensorflow/lite:arena_planner",
+        "//tensorflow/lite:cc_api",
+        "//tensorflow/lite:external_cpu_backend_context",
+        "//tensorflow/lite:graph_info",
+        "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:macros",
+        "//tensorflow/lite:memory_planner",
+        "//tensorflow/lite:minimal_logging",
+        "//tensorflow/lite:simple_memory_arena",
+        "//tensorflow/lite:string",
+        "//tensorflow/lite:type_to_tflitetype",
+        "//tensorflow/lite:util",
+        "//tensorflow/lite:version",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/api:verifier",
+        "//tensorflow/lite/experimental/resource",
+        "//tensorflow/lite/schema:schema_fbs",
+    ],
+)
+
+cc_library(
+    name = "builtin_ops",
+    hdrs = [
+        "builtin_op_kernels.h",
+        "fully_connected.h",
+        "register.h",
+    ],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        ":builtin_ops_impl",
+        "//tensorflow/lite:framework_lib",
+        "//tensorflow/lite/c:common",
+    ],
+)
+
+build_test(
+    name = "cc_api_build_test",
+    targets = [
+        ":cc_api",
+        ":framework",
+    ],
+)
+
+#------------------------------------------------------------------------------
+# C API
+
+cc_library(
+    name = "c_api",
+    hdrs = ["c/c_api.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = TFLITE_DEFAULT_COPTS,
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/lite/c:c_api"],
+)
+
+cc_library(
+    name = "c_api_experimental",
+    hdrs = ["c/c_api_experimental.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = TFLITE_DEFAULT_COPTS,
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/lite/c:c_api_experimental"],
+)
+
+cc_library(
+    name = "common",
+    hdrs = ["c/common.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = TFLITE_DEFAULT_COPTS,
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/lite/c:common"],
+)
+
+cc_library(
+    name = "builtin_op_data",
+    hdrs = ["c/builtin_op_data.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = TFLITE_DEFAULT_COPTS,
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/lite/c:common"],
+)
+
+build_test(
+    name = "c_api_build_test",
+    targets = [
+        ":builtin_op_data",
+        ":c_api",
+        ":c_api_experimental",
+        ":common",
+    ],
+)
+
+#------------------------------------------------------------------------------
+
+tflite_portable_test_suite()
diff --git a/tensorflow/lite/core/shims/README.md b/tensorflow/lite/core/shims/README.md
new file mode 100644
index 00000000000..2a95d5cb9c6
--- /dev/null
+++ b/tensorflow/lite/core/shims/README.md
@@ -0,0 +1,11 @@
+This directory contains shim header files that forward to the TF Lite
+C API and to the key headers of the TF Lite C++ API.
+
+The intent is that the shims in this directory could be modified to optionally
+redirect to a different implementation of those APIs (for example,
+one built into the underlying operating system platform).
+
+Usage: .cc files that are _implementing_ the shimmed TF Lite APIs should
+#include the regular TF Lite API headers, while files that are _using_ the
+shimmed APIs should #include the shimmed headers.
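+
+For example, with the `interpreter.h` shim added in this directory:
+
+```c++
+// In a .cc file that uses the shimmed API:
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
+
+// In a .cc file that implements the API being shimmed:
+#include "tensorflow/lite/interpreter.h"
+```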
diff --git a/tensorflow/lite/core/shims/build_defs.bzl b/tensorflow/lite/core/shims/build_defs.bzl
new file mode 100644
index 00000000000..330e2e9cc95
--- /dev/null
+++ b/tensorflow/lite/core/shims/build_defs.bzl
@@ -0,0 +1,22 @@
+"""A simple portable implementation of build_test."""
+
+def build_test(name, targets, visibility = None):
+    """Generates a test that just verifies that the specified targets can be built."""
+
+    # Generate an sh_test rule that lists the specified targets as data,
+    # (thus forcing those targets to be built before the test can be run)
+    # and that runs a script which always succeeds.
+    native.sh_test(
+        name = name,
+        srcs = [name + ".sh"],
+        data = targets,
+        visibility = visibility,
+    )
+
+    # Generate the script which always succeeds.  We just generate an empty script.
+    native.genrule(
+        name = name + "_gen_sh",
+        outs = [name + ".sh"],
+        cmd = "> $@",
+        visibility = ["//visibility:private"],
+    )
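+
+# Example usage (with a hypothetical ":foo" target):
+#
+#     build_test(
+#         name = "foo_build_test",
+#         targets = [":foo"],
+#     )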
diff --git a/tensorflow/lite/core/shims/c/builtin_op_data.h b/tensorflow/lite/core/shims/c/builtin_op_data.h
new file mode 100644
index 00000000000..f75b0144be7
--- /dev/null
+++ b/tensorflow/lite/core/shims/c/builtin_op_data.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/lite/core/shims/c/c_api.h b/tensorflow/lite/core/shims/c/c_api.h
new file mode 100644
index 00000000000..90e0147a479
--- /dev/null
+++ b/tensorflow/lite/core/shims/c/c_api.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
+
+#include "tensorflow/lite/c/c_api.h"
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_
diff --git a/tensorflow/lite/micro/xtensa_hifimini_staging/micro_time.cc b/tensorflow/lite/core/shims/c/c_api_experimental.h
similarity index 68%
rename from tensorflow/lite/micro/xtensa_hifimini_staging/micro_time.cc
rename to tensorflow/lite/core/shims/c/c_api_experimental.h
index 6f3844c1fe3..ceb1cb7e5bb 100644
--- a/tensorflow/lite/micro/xtensa_hifimini_staging/micro_time.cc
+++ b/tensorflow/lite/core/shims/c/c_api_experimental.h
@@ -12,17 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
 
-// Xtensa implementation of micro_timer.
-// To include this with make, add TAGS=xtensa-xpg.
-#include "tensorflow/lite/micro/micro_time.h"
+#include "tensorflow/lite/c/c_api_experimental.h"
 
-#include <time.h>
-
-namespace tflite {
-
-int32_t ticks_per_second() { return CLOCKS_PER_SEC; }
-
-int32_t GetCurrentTimeTicks() { return clock(); }
-
-}  // namespace tflite
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_
diff --git a/tensorflow/lite/core/shims/c/common.h b/tensorflow/lite/core/shims/c/common.h
new file mode 100644
index 00000000000..a531546ec9e
--- /dev/null
+++ b/tensorflow/lite/core/shims/c/common.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
+
+#include "tensorflow/lite/c/common.h"
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_
diff --git a/tensorflow/lite/core/shims/cc/interpreter.h b/tensorflow/lite/core/shims/cc/interpreter.h
new file mode 100644
index 00000000000..20e8f10ce94
--- /dev/null
+++ b/tensorflow/lite/core/shims/cc/interpreter.h
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_H_
+
+/// For documentation, see third_party/tensorflow/lite/interpreter.h.
+
+#include "tensorflow/lite/interpreter.h"
+
+namespace tflite_shims {
+using Interpreter = ::tflite::Interpreter;
+}  // namespace tflite_shims
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_H_
diff --git a/tensorflow/lite/core/shims/cc/interpreter_builder.h b/tensorflow/lite/core/shims/cc/interpreter_builder.h
new file mode 100644
index 00000000000..891a5747f7e
--- /dev/null
+++ b/tensorflow/lite/core/shims/cc/interpreter_builder.h
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_BUILDER_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_BUILDER_H_
+
+/// For documentation, see third_party/tensorflow/lite/interpreter_builder.h.
+
+#include "tensorflow/lite/interpreter_builder.h"
+
+namespace tflite_shims {
+using InterpreterBuilder = ::tflite::InterpreterBuilder;
+}  // namespace tflite_shims
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_BUILDER_H_
diff --git a/tensorflow/lite/core/shims/cc/kernels/register.h b/tensorflow/lite/core/shims/cc/kernels/register.h
new file mode 100644
index 00000000000..d5a52034fb7
--- /dev/null
+++ b/tensorflow/lite/core/shims/cc/kernels/register.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_KERNELS_REGISTER_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_CC_KERNELS_REGISTER_H_
+
+#include "tensorflow/lite/kernels/register.h"
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_CC_KERNELS_REGISTER_H_
diff --git a/tensorflow/lite/core/shims/cc/model.h b/tensorflow/lite/core/shims/cc/model.h
new file mode 100644
index 00000000000..8a016319510
--- /dev/null
+++ b/tensorflow/lite/core/shims/cc/model.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_H_
+
+/// For documentation, see third_party/tensorflow/lite/model.h.
+
+#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
+#include "tensorflow/lite/core/shims/cc/model_builder.h"
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_H_
diff --git a/tensorflow/lite/core/shims/cc/model_builder.h b/tensorflow/lite/core/shims/cc/model_builder.h
new file mode 100644
index 00000000000..3ee9c5a3016
--- /dev/null
+++ b/tensorflow/lite/core/shims/cc/model_builder.h
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_BUILDER_H_
+#define TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_BUILDER_H_
+
+/// For documentation, see third_party/tensorflow/lite/model_builder.h.
+
+#include "tensorflow/lite/model_builder.h"
+
+namespace tflite_shims {
+using FlatBufferModel = ::tflite::FlatBufferModel;
+}  // namespace tflite_shims
+
+#endif  // TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_BUILDER_H_
diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD
index 240de7fef94..7c68b25cf14 100644
--- a/tensorflow/lite/delegates/BUILD
+++ b/tensorflow/lite/delegates/BUILD
@@ -71,9 +71,6 @@ cc_test(
     size = "small",
     srcs = ["delegate_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
-    tags = [
-        "tflite_not_portable_ios",  # TODO(b/117786830)
-    ],
     deps = [
         ":interpreter_utils",
         ":utils",
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 58a36ca7a64..cece0d59768 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -99,13 +99,13 @@ cc_library(
     deps = [
         ":buffer",
         ":cl_context",
-        ":device_info",
         ":gpu_object",
         ":linear_storage",
         ":tensor",
         ":texture2d",
         "//tensorflow/lite/delegates/gpu/common:access_type",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
@@ -126,8 +126,8 @@ cc_test(
     deps = [
         ":buffer",
         ":cl_arguments",
-        ":device_info",
         ":gpu_object",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
     ],
@@ -171,9 +171,9 @@ cc_library(
     srcs = ["cl_device.cc"],
     hdrs = ["cl_device.h"],
     deps = [
-        ":device_info",
         ":opencl_wrapper",
         ":util",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "@com_google_absl//absl/strings",
@@ -247,8 +247,8 @@ cc_library(
         ":cl_kernel",
         ":program_cache",
         ":tensor",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
     ],
 )
 
@@ -276,16 +276,6 @@ flatbuffer_cc_library(
     ],
 )
 
-cc_library(
-    name = "device_info",
-    srcs = ["device_info.cc"],
-    hdrs = ["device_info.h"],
-    deps = [
-        "//tensorflow/lite/delegates/gpu/common:data_type",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
 cc_library(
     name = "egl_sync",
     srcs = ["egl_sync.cc"],
@@ -307,11 +297,11 @@ cc_library(
         ":cl_command_queue",
         ":cl_context",
         ":cl_device",
-        ":device_info",
         ":program_cache",
         ":tensor",
         ":util",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:precision",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
@@ -406,7 +396,6 @@ cc_library(
         ":serialization_cc_fbs",
         ":storage_type_util",
         ":tensor",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector",
         "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector",
         "//tensorflow/lite/delegates/gpu/common:data_type",
@@ -424,6 +413,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common/task:arguments",
         "//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
         "//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
@@ -509,8 +499,8 @@ cc_library(
     srcs = ["storage_type_util.cc"],
     hdrs = ["storage_type_util.h"],
     deps = [
-        ":device_info",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
index 673b24f63e2..170c5538e02 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
@@ -21,8 +21,8 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc b/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc
index ddca3d4dc3a..69a490b9493 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc
@@ -21,8 +21,8 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "absl/strings/match.h"
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_context.cc b/tensorflow/lite/delegates/gpu/cl/cl_context.cc
index 32a5e43d799..e1cc54ba574 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_context.cc
@@ -54,29 +54,29 @@ void AddSupportedImageFormats(cl_context context, GpuInfo* info) {
   auto supported_formats =
       GetSupportedImage2DFormats(context, CL_MEM_READ_WRITE);
   for (auto format : supported_formats) {
-    info->supports_r_f16_tex2d =
-        info->supports_r_f16_tex2d ||
+    info->opencl_info.supports_r_f16_tex2d =
+        info->opencl_info.supports_r_f16_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT16, 1);
-    info->supports_rg_f16_tex2d =
-        info->supports_rg_f16_tex2d ||
+    info->opencl_info.supports_rg_f16_tex2d =
+        info->opencl_info.supports_rg_f16_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT16, 2);
-    info->supports_rgb_f16_tex2d =
-        info->supports_rgb_f16_tex2d ||
+    info->opencl_info.supports_rgb_f16_tex2d =
+        info->opencl_info.supports_rgb_f16_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT16, 3);
-    info->supports_rgba_f16_tex2d =
-        info->supports_rgba_f16_tex2d ||
+    info->opencl_info.supports_rgba_f16_tex2d =
+        info->opencl_info.supports_rgba_f16_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT16, 4);
-    info->supports_r_f32_tex2d =
-        info->supports_r_f32_tex2d ||
+    info->opencl_info.supports_r_f32_tex2d =
+        info->opencl_info.supports_r_f32_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT32, 1);
-    info->supports_rg_f32_tex2d =
-        info->supports_rg_f32_tex2d ||
+    info->opencl_info.supports_rg_f32_tex2d =
+        info->opencl_info.supports_rg_f32_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT32, 2);
-    info->supports_rgb_f32_tex2d =
-        info->supports_rgb_f32_tex2d ||
+    info->opencl_info.supports_rgb_f32_tex2d =
+        info->opencl_info.supports_rgb_f32_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT32, 3);
-    info->supports_rgba_f32_tex2d =
-        info->supports_rgba_f32_tex2d ||
+    info->opencl_info.supports_rgba_f32_tex2d =
+        info->opencl_info.supports_rgba_f32_tex2d ||
         IsEqualToImageFormat(format, DataType::FLOAT32, 4);
   }
 }
@@ -148,7 +148,7 @@ absl::Status CreateCLGLContext(const CLDevice& device,
                                cl_context_properties egl_context,
                                cl_context_properties egl_display,
                                CLContext* result) {
-  if (!device.SupportsExtension("cl_khr_gl_sharing")) {
+  if (!device.GetInfo().SupportsExtension("cl_khr_gl_sharing")) {
     return absl::UnavailableError("Device doesn't support CL-GL sharing.");
   }
   cl_context_properties platform =
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.cc b/tensorflow/lite/delegates/gpu/cl/cl_device.cc
index 26206529235..1bd5db7b646 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_device.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_device.cc
@@ -89,34 +89,34 @@ void GetDeviceWorkDimsSizes(cl_device_id id, int3* result) {
   result->z = limits[2];
 }
 
-OpenCLVersion ParseCLVersion(const std::string& version) {
+OpenClVersion ParseCLVersion(const std::string& version) {
   const auto first_dot_pos = version.find_first_of('.');
   if (first_dot_pos == std::string::npos) {
-    return OpenCLVersion::CL_1_0;
+    return OpenClVersion::kCl1_0;
   }
   const int major = version[first_dot_pos - 1] - '0';
   const int minor = version[first_dot_pos + 1] - '0';
 
   if (major == 1) {
     if (minor == 2) {
-      return OpenCLVersion::CL_1_2;
+      return OpenClVersion::kCl1_2;
     } else if (minor == 1) {
-      return OpenCLVersion::CL_1_1;
+      return OpenClVersion::kCl1_1;
     } else {
-      return OpenCLVersion::CL_1_0;
+      return OpenClVersion::kCl1_0;
     }
   } else if (major == 2) {
     if (minor == 2) {
-      return OpenCLVersion::CL_2_2;
+      return OpenClVersion::kCl2_2;
     } else if (minor == 1) {
-      return OpenCLVersion::CL_2_1;
+      return OpenClVersion::kCl2_1;
     } else {
-      return OpenCLVersion::CL_2_0;
+      return OpenClVersion::kCl2_0;
     }
   } else if (major == 3) {
-    return OpenCLVersion::CL_3_0;
+    return OpenClVersion::kCl3_0;
   } else {
-    return OpenCLVersion::CL_1_0;
+    return OpenClVersion::kCl1_0;
   }
 }
 
@@ -162,87 +162,90 @@ GpuInfo GpuInfoFromDeviceID(cl_device_id id) {
   const auto vendor_name = GetDeviceInfo<std::string>(id, CL_DEVICE_VENDOR);
   const auto opencl_c_version =
       GetDeviceInfo<std::string>(id, CL_DEVICE_OPENCL_C_VERSION);
-  info.gpu_vendor = ParseVendor(device_name, vendor_name);
+  info.gpu_api = GpuApi::kOpenCl;
+  info.vendor = ParseVendor(device_name, vendor_name);
   if (info.IsAdreno()) {
     info.adreno_info = AdrenoInfo(opencl_c_version);
   } else if (info.IsMali()) {
     info.mali_info = MaliInfo(device_name);
   }
-  info.cl_version = ParseCLVersion(opencl_c_version);
-  info.extensions =
+  info.opencl_info.cl_version = ParseCLVersion(opencl_c_version);
+  info.opencl_info.extensions =
       absl::StrSplit(GetDeviceInfo<std::string>(id, CL_DEVICE_EXTENSIONS), ' ');
-  info.supports_fp16 = false;
-  info.supports_image3d_writes = false;
-  for (const auto& ext : info.extensions) {
+  info.opencl_info.supports_fp16 = false;
+  info.opencl_info.supports_image3d_writes = false;
+  for (const auto& ext : info.opencl_info.extensions) {
     if (ext == "cl_khr_fp16") {
-      info.supports_fp16 = true;
+      info.opencl_info.supports_fp16 = true;
     }
     if (ext == "cl_khr_3d_image_writes") {
-      info.supports_image3d_writes = true;
+      info.opencl_info.supports_image3d_writes = true;
     }
   }
 
   cl_device_fp_config f32_config =
       GetDeviceInfo<cl_device_fp_config>(id, CL_DEVICE_SINGLE_FP_CONFIG);
-  info.supports_fp32_rtn = f32_config & CL_FP_ROUND_TO_NEAREST;
+  info.opencl_info.supports_fp32_rtn = f32_config & CL_FP_ROUND_TO_NEAREST;
 
-  if (info.supports_fp16) {
+  if (info.opencl_info.supports_fp16) {
     cl_device_fp_config f16_config;
     auto status = GetDeviceInfo<cl_device_fp_config>(
         id, CL_DEVICE_HALF_FP_CONFIG, &f16_config);
     // AMD supports cl_khr_fp16 but CL_DEVICE_HALF_FP_CONFIG is empty.
     if (status.ok() && !info.IsAMD()) {
-      info.supports_fp16_rtn = f16_config & CL_FP_ROUND_TO_NEAREST;
+      info.opencl_info.supports_fp16_rtn = f16_config & CL_FP_ROUND_TO_NEAREST;
     } else {  // happens on PowerVR
       f16_config = f32_config;
-      info.supports_fp16_rtn = info.supports_fp32_rtn;
+      info.opencl_info.supports_fp16_rtn = info.opencl_info.supports_fp32_rtn;
     }
   } else {
-    info.supports_fp16_rtn = false;
+    info.opencl_info.supports_fp16_rtn = false;
   }
 
-  if (info.IsPowerVR() && !info.supports_fp16) {
+  if (info.IsPowerVR() && !info.opencl_info.supports_fp16) {
     // PowerVR doesn't fully support fp16 and so doesn't list this
     // extension. But it can use fp16 in MADs and as buffer/texture types,
     // so we enable it anyway.
-    info.supports_fp16 = true;
-    info.supports_fp16_rtn = info.supports_fp32_rtn;
+    info.opencl_info.supports_fp16 = true;
+    info.opencl_info.supports_fp16_rtn = info.opencl_info.supports_fp32_rtn;
   }
 
-  if (!info.supports_image3d_writes &&
+  if (!info.opencl_info.supports_image3d_writes &&
       ((info.IsAdreno() && info.adreno_info.IsAdreno4xx()) ||
        info.IsNvidia())) {
     // In local tests, Adreno 430 can write to 3D images, at least at small
     // sizes, but it doesn't list cl_khr_3d_image_writes among its available
     // extensions.
     // The same holds for NVIDIA.
-    info.supports_image3d_writes = true;
+    info.opencl_info.supports_image3d_writes = true;
   }
-  info.compute_units_count =
+  info.opencl_info.compute_units_count =
       GetDeviceInfo<cl_uint>(id, CL_DEVICE_MAX_COMPUTE_UNITS);
-  info.image2d_max_width =
+  info.opencl_info.image2d_max_width =
       GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_WIDTH);
-  info.image2d_max_height =
+  info.opencl_info.image2d_max_height =
       GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_HEIGHT);
-  info.buffer_max_size =
+  info.opencl_info.buffer_max_size =
       GetDeviceInfo<cl_ulong>(id, CL_DEVICE_MAX_MEM_ALLOC_SIZE);
-  if (info.cl_version >= OpenCLVersion::CL_1_2) {
-    info.image_buffer_max_size =
+  if (info.opencl_info.cl_version >= OpenClVersion::kCl1_2) {
+    info.opencl_info.image_buffer_max_size =
         GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE);
-    info.image_array_max_layers =
+    info.opencl_info.image_array_max_layers =
         GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_ARRAY_SIZE);
   }
-  info.image3d_max_width =
+  info.opencl_info.image3d_max_width =
       GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_WIDTH);
-  info.image3d_max_height =
+  info.opencl_info.image3d_max_height =
       GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_HEIGHT);
-  info.image3d_max_depth =
+  info.opencl_info.image3d_max_depth =
       GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_DEPTH);
   int3 max_work_group_sizes;
   GetDeviceWorkDimsSizes(id, &max_work_group_sizes);
-  info.max_work_group_size_x = max_work_group_sizes.x;
-  info.max_work_group_size_y = max_work_group_sizes.y;
-  info.max_work_group_size_z = max_work_group_sizes.z;
+  info.opencl_info.max_work_group_size_x = max_work_group_sizes.x;
+  info.opencl_info.max_work_group_size_y = max_work_group_sizes.y;
+  info.opencl_info.max_work_group_size_z = max_work_group_sizes.z;
+  info.opencl_info.max_work_group_total_size =
+      GetDeviceInfo<size_t>(id, CL_DEVICE_MAX_WORK_GROUP_SIZE);
 
   if (info.IsIntel()) {
     if (info.SupportsExtension("cl_intel_required_subgroup_size")) {
@@ -300,48 +303,10 @@ CLDevice& CLDevice::operator=(CLDevice&& device) {
   return *this;
 }
 
-bool CLDevice::SupportsFP16() const { return info_.supports_fp16; }
-
-bool CLDevice::SupportsExtension(const std::string& extension) const {
-  return info_.SupportsExtension(extension);
-}
-
-bool CLDevice::SupportsTextureArray() const {
-  return info_.SupportsTextureArray();
-}
-
-bool CLDevice::SupportsImageBuffer() const {
-  return info_.SupportsImageBuffer();
-}
-
-bool CLDevice::SupportsImage3D() const { return info_.SupportsImage3D(); }
-
-bool CLDevice::SupportsFP32RTN() const { return info_.supports_fp32_rtn; }
-
-bool CLDevice::SupportsFP16RTN() const { return info_.supports_fp16_rtn; }
-
 std::string CLDevice::GetPlatformVersion() const {
   return GetPlatformInfo(platform_id_, CL_PLATFORM_VERSION);
 }
 
-bool CLDevice::IsCL20OrHigher() const { return info_.IsCL20OrHigher(); }
-
-bool CLDevice::SupportsSubGroupWithSize(int sub_group_size) const {
-  return info_.SupportsSubGroupWithSize(sub_group_size);
-}
-
-bool CLDevice::IsAdreno() const { return info_.IsAdreno(); }
-
-bool CLDevice::IsPowerVR() const { return info_.IsPowerVR(); }
-
-bool CLDevice::IsNvidia() const { return info_.IsNvidia(); }
-
-bool CLDevice::IsMali() const { return info_.IsMali(); }
-
-bool CLDevice::IsAMD() const { return info_.IsAMD(); }
-
-bool CLDevice::IsIntel() const { return info_.IsIntel(); }
-
 void CLDevice::DisableOneLayerTextureArray() {
   info_.adreno_info.support_one_layer_texture_array = false;
 }
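
A few assertion-style examples of the mapping ParseCLVersion implements, fed strings of the shape CL_DEVICE_OPENCL_C_VERSION typically returns (the wrapper function is illustrative; it exercises only the logic visible in the hunk above):

#include <cassert>
#include <string>

// The characters around the first '.' decide the version; a string
// without any '.' falls back to the most conservative answer, 1.0.
void ParseClVersionExamples() {
  assert(ParseCLVersion("OpenCL C 1.2") == OpenClVersion::kCl1_2);
  assert(ParseCLVersion("OpenCL C 2.0") == OpenClVersion::kCl2_0);
  assert(ParseCLVersion("OpenCL C 3.0") == OpenClVersion::kCl3_0);
  assert(ParseCLVersion("garbage") == OpenClVersion::kCl1_0);  // no '.'
}
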
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h
index 3614e2211f1..b94704b40d6 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_device.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h
@@ -19,9 +19,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
@@ -46,24 +46,6 @@ class CLDevice {
   cl_platform_id platform() const { return platform_id_; }
   std::string GetPlatformVersion() const;
 
-  GpuVendor vendor() const { return info_.gpu_vendor; }
-  OpenCLVersion cl_version() const { return info_.cl_version; }
-  bool SupportsFP16() const;
-  bool SupportsTextureArray() const;
-  bool SupportsImageBuffer() const;
-  bool SupportsImage3D() const;
-  bool SupportsExtension(const std::string& extension) const;
-  bool SupportsFP32RTN() const;
-  bool SupportsFP16RTN() const;
-  bool IsCL20OrHigher() const;
-  bool SupportsSubGroupWithSize(int sub_group_size) const;
-  bool IsAdreno() const;
-  bool IsPowerVR() const;
-  bool IsNvidia() const;
-  bool IsMali() const;
-  bool IsAMD() const;
-  bool IsIntel() const;
-
   // To track bug on some Adreno. b/131099086
   void DisableOneLayerTextureArray();
 
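
Every convenience accessor removed here still exists on GpuInfo, so call sites migrate mechanically through GetInfo(), as the environment.cc and gl_interop.cc hunks below show. A sketch, assuming GetInfo() returns the device's GpuInfo by const reference (CanUseHalfPrecision and its PowerVR policy are hypothetical):

// Post-refactor call pattern: one GetInfo() call replaces the
// per-question CLDevice wrappers, e.g. device.SupportsFP16() becomes
// device.GetInfo().SupportsFP16().
bool CanUseHalfPrecision(const tflite::gpu::cl::CLDevice& device) {
  const tflite::gpu::GpuInfo& info = device.GetInfo();
  return info.SupportsFP16() && !info.IsPowerVR();  // hypothetical policy
}
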
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.h b/tensorflow/lite/delegates/gpu/cl/cl_operation.h
index 3e4d40a5fa0..b357e2f087b 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_operation.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.h
@@ -24,9 +24,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.cc b/tensorflow/lite/delegates/gpu/cl/device_info.cc
deleted file mode 100644
index 94c3cce4a74..00000000000
--- a/tensorflow/lite/delegates/gpu/cl/device_info.cc
+++ /dev/null
@@ -1,376 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-
-#include <algorithm>
-#include <string>
-#include <vector>
-
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_split.h"
-
-namespace tflite {
-namespace gpu {
-namespace cl {
-namespace {
-AdrenoGpu GetAdrenoGpuVersion(const std::string& device_name) {
-  const std::map<std::string, AdrenoGpu> kMapping = {
-      // Adreno 6xx series
-      {"685", AdrenoGpu::kAdreno685},
-      {"680", AdrenoGpu::kAdreno680},
-      {"675", AdrenoGpu::kAdreno675},
-      {"650", AdrenoGpu::kAdreno650},
-      {"640", AdrenoGpu::kAdreno640},
-      {"630", AdrenoGpu::kAdreno630},
-      {"620", AdrenoGpu::kAdreno620},
-      {"616", AdrenoGpu::kAdreno618},
-      {"616", AdrenoGpu::kAdreno616},
-      {"615", AdrenoGpu::kAdreno615},
-      {"612", AdrenoGpu::kAdreno612},
-      {"610", AdrenoGpu::kAdreno610},
-      {"605", AdrenoGpu::kAdreno605},
-      // Adreno 5xx series
-      {"540", AdrenoGpu::kAdreno540},
-      {"530", AdrenoGpu::kAdreno530},
-      {"512", AdrenoGpu::kAdreno512},
-      {"510", AdrenoGpu::kAdreno510},
-      {"509", AdrenoGpu::kAdreno509},
-      {"508", AdrenoGpu::kAdreno508},
-      {"506", AdrenoGpu::kAdreno506},
-      {"505", AdrenoGpu::kAdreno505},
-      {"504", AdrenoGpu::kAdreno504},
-      // Adreno 4xx series
-      {"430", AdrenoGpu::kAdreno430},
-      {"420", AdrenoGpu::kAdreno420},
-      {"418", AdrenoGpu::kAdreno418},
-      {"405", AdrenoGpu::kAdreno405},
-      // Adreno 3xx series
-      {"330", AdrenoGpu::kAdreno330},
-      {"320", AdrenoGpu::kAdreno320},
-      {"308", AdrenoGpu::kAdreno308},
-      {"306", AdrenoGpu::kAdreno306},
-      {"305", AdrenoGpu::kAdreno305},
-      {"304", AdrenoGpu::kAdreno304},
-      // Adreno 2xx series
-      {"225", AdrenoGpu::kAdreno225},
-      {"220", AdrenoGpu::kAdreno220},
-      {"205", AdrenoGpu::kAdreno205},
-      {"203", AdrenoGpu::kAdreno203},
-      {"200", AdrenoGpu::kAdreno200},
-      // Adreno 1xx series
-      {"130", AdrenoGpu::kAdreno130},
-      {"120", AdrenoGpu::kAdreno120},
-  };
-
-  for (const auto& v : kMapping) {
-    if (device_name.find(v.first) != std::string::npos) {
-      return v.second;
-    }
-  }
-  return AdrenoGpu::kUnknown;
-}
-
-MaliGpu GetMaliGpuVersion(const std::string& gpu_description) {
-  const std::map<std::string, MaliGpu> kMapping = {
-      {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622},
-      {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628},
-      {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678},
-      {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760},
-      {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830},
-      {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880},
-      {"g31", MaliGpu::kG31},   {"g51", MaliGpu::kG51},
-      {"g71", MaliGpu::kG71},   {"g52", MaliGpu::kG52},
-      {"g72", MaliGpu::kG72},   {"g76", MaliGpu::kG76},
-      {"g57", MaliGpu::kG57},   {"g77", MaliGpu::kG77},
-      {"g68", MaliGpu::kG68},   {"g78", MaliGpu::kG78},
-  };
-  for (const auto& v : kMapping) {
-    if (gpu_description.find(v.first) != std::string::npos) {
-      return v.second;
-    }
-  }
-  return MaliGpu::kUnknown;
-}
-
-}  // namespace
-
-std::string GpuVendorToString(GpuVendor v) {
-  switch (v) {
-    case GpuVendor::kApple:
-      return "Apple";
-    case GpuVendor::kQualcomm:
-      return "Qualcomm";
-    case GpuVendor::kMali:
-      return "Mali";
-    case GpuVendor::kPowerVR:
-      return "PowerVR";
-    case GpuVendor::kNvidia:
-      return "NVIDIA";
-    case GpuVendor::kAMD:
-      return "AMD";
-    case GpuVendor::kIntel:
-      return "Intel";
-    case GpuVendor::kUnknown:
-      return "unknown vendor";
-  }
-}
-
-std::string OpenCLVersionToString(OpenCLVersion version) {
-  switch (version) {
-    case OpenCLVersion::CL_1_0:
-      return "1.0";
-    case OpenCLVersion::CL_1_1:
-      return "1.1";
-    case OpenCLVersion::CL_1_2:
-      return "1.2";
-    case OpenCLVersion::CL_2_0:
-      return "2.0";
-    case OpenCLVersion::CL_2_1:
-      return "2.1";
-    case OpenCLVersion::CL_2_2:
-      return "2.2";
-    case OpenCLVersion::CL_3_0:
-      return "3.0";
-  }
-}
-
-AdrenoInfo::AdrenoInfo(const std::string& device_version)
-    : adreno_gpu(GetAdrenoGpuVersion(device_version)) {}
-
-bool AdrenoInfo::IsAdreno1xx() const {
-  return adreno_gpu == AdrenoGpu::kAdreno120 ||
-         adreno_gpu == AdrenoGpu::kAdreno130;
-}
-
-bool AdrenoInfo::IsAdreno2xx() const {
-  return adreno_gpu == AdrenoGpu::kAdreno200 ||
-         adreno_gpu == AdrenoGpu::kAdreno203 ||
-         adreno_gpu == AdrenoGpu::kAdreno205 ||
-         adreno_gpu == AdrenoGpu::kAdreno220 ||
-         adreno_gpu == AdrenoGpu::kAdreno225;
-}
-
-bool AdrenoInfo::IsAdreno3xx() const {
-  return adreno_gpu == AdrenoGpu::kAdreno304 ||
-         adreno_gpu == AdrenoGpu::kAdreno305 ||
-         adreno_gpu == AdrenoGpu::kAdreno306 ||
-         adreno_gpu == AdrenoGpu::kAdreno308 ||
-         adreno_gpu == AdrenoGpu::kAdreno320 ||
-         adreno_gpu == AdrenoGpu::kAdreno330;
-}
-
-bool AdrenoInfo::IsAdreno4xx() const {
-  return adreno_gpu == AdrenoGpu::kAdreno405 ||
-         adreno_gpu == AdrenoGpu::kAdreno418 ||
-         adreno_gpu == AdrenoGpu::kAdreno420 ||
-         adreno_gpu == AdrenoGpu::kAdreno430;
-}
-
-bool AdrenoInfo::IsAdreno5xx() const {
-  return adreno_gpu == AdrenoGpu::kAdreno504 ||
-         adreno_gpu == AdrenoGpu::kAdreno505 ||
-         adreno_gpu == AdrenoGpu::kAdreno506 ||
-         adreno_gpu == AdrenoGpu::kAdreno508 ||
-         adreno_gpu == AdrenoGpu::kAdreno509 ||
-         adreno_gpu == AdrenoGpu::kAdreno510 ||
-         adreno_gpu == AdrenoGpu::kAdreno512 ||
-         adreno_gpu == AdrenoGpu::kAdreno530 ||
-         adreno_gpu == AdrenoGpu::kAdreno540;
-}
-
-bool AdrenoInfo::IsAdreno6xx() const {
-  return adreno_gpu == AdrenoGpu::kAdreno605 ||
-         adreno_gpu == AdrenoGpu::kAdreno610 ||
-         adreno_gpu == AdrenoGpu::kAdreno612 ||
-         adreno_gpu == AdrenoGpu::kAdreno615 ||
-         adreno_gpu == AdrenoGpu::kAdreno616 ||
-         adreno_gpu == AdrenoGpu::kAdreno618 ||
-         adreno_gpu == AdrenoGpu::kAdreno620 ||
-         adreno_gpu == AdrenoGpu::kAdreno630 ||
-         adreno_gpu == AdrenoGpu::kAdreno640 ||
-         adreno_gpu == AdrenoGpu::kAdreno650 ||
-         adreno_gpu == AdrenoGpu::kAdreno675 ||
-         adreno_gpu == AdrenoGpu::kAdreno680 ||
-         adreno_gpu == AdrenoGpu::kAdreno685;
-}
-
-bool AdrenoInfo::IsAdreno6xxOrHigher() const { return IsAdreno6xx(); }
-
-int AdrenoInfo::GetMaximumWavesCount() const {
-  if (IsAdreno6xx()) {
-    if (adreno_gpu == AdrenoGpu::kAdreno640) {
-      return 30;
-    } else {
-      return 16;
-    }
-  } else {
-    // all other versions not supported
-    return 1;
-  }
-}
-
-int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const {
-  if (IsAdreno6xx()) {
-    if (adreno_gpu == AdrenoGpu::kAdreno640) {
-      return 128 * 144 * 16;
-    } else if (adreno_gpu == AdrenoGpu::kAdreno650) {
-      return 128 * 64 * 16;
-    } else {
-      return 128 * 96 * 16;
-    }
-  } else {
-    // all other versions not supported
-    return 1;
-  }
-}
-
-int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread,
-                                     bool full_wave) const {
-  const int register_usage_per_wave =
-      GetWaveSize(full_wave) * register_footprint_per_tread;
-  const int possible_waves_count =
-      GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
-  return std::min(possible_waves_count, GetMaximumWavesCount());
-}
-
-int AdrenoInfo::GetWaveSize(bool full_wave) const {
-  if (IsAdreno6xx()) {
-    return full_wave ? 128 : 64;
-  } else if (IsAdreno5xx() || IsAdreno4xx()) {
-    return full_wave ? 64 : 32;
-  } else {
-    // all other versions not supported
-    return 1;
-  }
-}
-
-MaliInfo::MaliInfo(const std::string& gpu_description)
-    : gpu_version(GetMaliGpuVersion(gpu_description)) {}
-
-bool MaliInfo::IsMaliT6xx() const {
-  return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 ||
-         gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 ||
-         gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678;
-}
-
-bool MaliInfo::IsMaliT7xx() const {
-  return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760;
-}
-
-bool MaliInfo::IsMaliT8xx() const {
-  return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 ||
-         gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880;
-}
-
-bool MaliInfo::IsMidgard() const {
-  return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
-}
-
-bool MaliInfo::IsBifrostGen1() const {
-  return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 ||
-         gpu_version == MaliGpu::kG71;
-}
-
-bool MaliInfo::IsBifrostGen2() const {
-  return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72;
-}
-
-bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; }
-
-bool MaliInfo::IsBifrost() const {
-  return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
-}
-
-bool MaliInfo::IsValhall() const {
-  return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77 ||
-         gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78;
-}
-
-bool GpuInfo::SupportsTextureArray() const {
-  return cl_version >= OpenCLVersion::CL_1_2;
-}
-
-bool GpuInfo::SupportsImageBuffer() const {
-  return cl_version >= OpenCLVersion::CL_1_2;
-}
-
-bool GpuInfo::SupportsImage3D() const {
-  if (IsMali() && mali_info.IsMidgard()) {
-    // On Mali T880 read_imageh doesn't compile with image3d_t
-    return false;
-  }
-  return supports_image3d_writes;
-}
-
-bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
-  if (channels == 1) {
-    return data_type == DataType::FLOAT32 ? supports_r_f32_tex2d
-                                          : supports_r_f16_tex2d;
-  } else if (channels == 2) {
-    return data_type == DataType::FLOAT32 ? supports_rg_f32_tex2d
-                                          : supports_rg_f16_tex2d;
-  } else if (channels == 3) {
-    return data_type == DataType::FLOAT32 ? supports_rgb_f32_tex2d
-                                          : supports_rgb_f16_tex2d;
-  } else if (channels == 4) {
-    return data_type == DataType::FLOAT32 ? supports_rgba_f32_tex2d
-                                          : supports_rgba_f16_tex2d;
-  } else {
-    return false;
-  }
-}
-
-bool GpuInfo::SupportsExtension(const std::string& extension) const {
-  for (const auto& ext : extensions) {
-    if (ext == extension) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool GpuInfo::IsCL20OrHigher() const {
-  return cl_version != OpenCLVersion::CL_1_0 &&
-         cl_version != OpenCLVersion::CL_1_1 &&
-         cl_version != OpenCLVersion::CL_1_2;
-}
-
-bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
-  for (auto subgroup_size : supported_subgroup_sizes) {
-    if (sub_group_size == subgroup_size) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool GpuInfo::IsAdreno() const { return gpu_vendor == GpuVendor::kQualcomm; }
-
-bool GpuInfo::IsApple() const { return gpu_vendor == GpuVendor::kApple; }
-
-bool GpuInfo::IsMali() const { return gpu_vendor == GpuVendor::kMali; }
-
-bool GpuInfo::IsPowerVR() const { return gpu_vendor == GpuVendor::kPowerVR; }
-
-bool GpuInfo::IsNvidia() const { return gpu_vendor == GpuVendor::kNvidia; }
-
-bool GpuInfo::IsAMD() const { return gpu_vendor == GpuVendor::kAMD; }
-
-bool GpuInfo::IsIntel() const { return gpu_vendor == GpuVendor::kIntel; }
-
-}  // namespace cl
-}  // namespace gpu
-}  // namespace tflite
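
The Adreno wave bookkeeping deleted above (and presumably carried over into the common GpuInfo) is easiest to see with numbers. A worked instance of GetMaximumWavesCount(register_footprint_per_tread, full_wave) for Adreno 640, where only the 96-byte footprint is a made-up input and the function name is illustrative:

#include <algorithm>

int Adreno640MaxWavesExample() {
  const int register_mem_per_cu = 128 * 144 * 16;  // 294912 bytes on Adreno 640
  const int wave_size = 128;                       // full wave on Adreno 6xx
  const int footprint_per_thread = 96;             // hypothetical, in bytes
  const int usage_per_wave = wave_size * footprint_per_thread;      // 12288
  const int possible_waves = register_mem_per_cu / usage_per_wave;  // 24
  return std::min(possible_waves, 30);  // Adreno 640 caps at 30 -> 24 waves
}
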
diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.h b/tensorflow/lite/delegates/gpu/cl/device_info.h
deleted file mode 100644
index a2ba4ca2597..00000000000
--- a/tensorflow/lite/delegates/gpu/cl/device_info.h
+++ /dev/null
@@ -1,245 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_DEVICE_INFO_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_DEVICE_INFO_H_
-
-#include <string>
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/data_type.h"
-
-// for use only in device_info.cc, but keep here to make tests
-int GetAdrenoGPUVersion(const std::string& gpu_version);
-
-namespace tflite {
-namespace gpu {
-namespace cl {
-
-enum class GpuVendor {
-  kApple,
-  kQualcomm,
-  kMali,
-  kPowerVR,
-  kNvidia,
-  kAMD,
-  kIntel,
-  kUnknown
-};
-
-std::string GpuVendorToString(GpuVendor v);
-
-enum class OpenCLVersion {
-  CL_1_0,
-  CL_1_1,
-  CL_1_2,
-  CL_2_0,
-  CL_2_1,
-  CL_2_2,
-  CL_3_0
-};
-std::string OpenCLVersionToString(OpenCLVersion version);
-
-enum class AdrenoGpu {
-  // Adreno 6xx series
-  kAdreno685,
-  kAdreno680,
-  kAdreno675,
-  kAdreno650,
-  kAdreno640,
-  kAdreno630,
-  kAdreno620,
-  kAdreno618,
-  kAdreno616,
-  kAdreno615,
-  kAdreno612,
-  kAdreno610,
-  kAdreno605,
-  // Adreno 5xx series
-  kAdreno540,
-  kAdreno530,
-  kAdreno512,
-  kAdreno510,
-  kAdreno509,
-  kAdreno508,
-  kAdreno506,
-  kAdreno505,
-  kAdreno504,
-  // Adreno 4xx series
-  kAdreno430,
-  kAdreno420,
-  kAdreno418,
-  kAdreno405,
-  // Adreno 3xx series
-  kAdreno330,
-  kAdreno320,
-  kAdreno308,
-  kAdreno306,
-  kAdreno305,
-  kAdreno304,
-  // Adreno 2xx series
-  kAdreno225,
-  kAdreno220,
-  kAdreno205,
-  kAdreno203,
-  kAdreno200,
-  // Adreno 1xx series
-  kAdreno130,
-  kAdreno120,
-  kUnknown
-};
-
-struct AdrenoInfo {
-  AdrenoInfo() = default;
-  explicit AdrenoInfo(const std::string& device_version);
-
-  AdrenoGpu adreno_gpu;
-
-  bool IsAdreno1xx() const;
-  bool IsAdreno2xx() const;
-  bool IsAdreno3xx() const;
-  bool IsAdreno4xx() const;
-  bool IsAdreno5xx() const;
-  bool IsAdreno6xx() const;
-  bool IsAdreno6xxOrHigher() const;
-
-  // This function returns some not very documented physical parameter of
-  // Adreno6xx GPU.
-  // We obtained it using Snapdragon Profiler.
-  int GetMaximumWavesCount() const;
-
-  // returns amount of register memory per CU(Compute Unit) in bytes.
-  int GetRegisterMemorySizePerComputeUnit() const;
-
-  // returns maximum possible amount of waves based on register usage.
-  int GetMaximumWavesCount(int register_footprint_per_tread,
-                           bool full_wave = true) const;
-
-  int GetWaveSize(bool full_wave) const;
-
-  // Not supported on some Adreno devices with specific driver version.
-  // b/131099086
-  bool support_one_layer_texture_array = true;
-};
-
-enum class MaliGpu {
-  kUnknown,
-  kT604,
-  kT622,
-  kT624,
-  kT628,
-  kT658,
-  kT678,
-  kT720,
-  kT760,
-  kT820,
-  kT830,
-  kT860,
-  kT880,
-  kG31,
-  kG51,
-  kG71,
-  kG52,
-  kG72,
-  kG76,
-  kG57,
-  kG77,
-  kG68,
-  kG78,
-};
-
-struct MaliInfo {
-  MaliInfo() = default;
-  explicit MaliInfo(const std::string& gpu_description);
-  MaliGpu gpu_version;
-
-  bool IsMaliT6xx() const;
-  bool IsMaliT7xx() const;
-  bool IsMaliT8xx() const;
-  bool IsMidgard() const;
-  bool IsBifrostGen1() const;
-  bool IsBifrostGen2() const;
-  bool IsBifrostGen3() const;
-  bool IsBifrost() const;
-  bool IsValhall() const;
-};
-
-struct GpuInfo {
-  GpuInfo() = default;
-
-  bool IsAdreno() const;
-  bool IsApple() const;
-  bool IsMali() const;
-  bool IsPowerVR() const;
-  bool IsNvidia() const;
-  bool IsAMD() const;
-  bool IsIntel() const;
-
-  bool SupportsTextureArray() const;
-  bool SupportsImageBuffer() const;
-  bool SupportsImage3D() const;
-
-  bool SupportsFloatImage2D(DataType data_type, int channels) const;
-
-  bool SupportsExtension(const std::string& extension) const;
-  bool IsCL20OrHigher() const;
-  bool SupportsSubGroupWithSize(int sub_group_size) const;
-
-  std::vector<std::string> extensions;
-  bool supports_fp16;
-  bool supports_image3d_writes;
-  GpuVendor gpu_vendor;
-  OpenCLVersion cl_version;
-  int compute_units_count;
-  uint64_t buffer_max_size;
-  uint64_t image2d_max_width;
-  uint64_t image2d_max_height;
-  uint64_t image_buffer_max_size;
-  uint64_t image_array_max_layers;
-  uint64_t image3d_max_width;
-  uint64_t image3d_max_height;
-  uint64_t image3d_max_depth;
-  int max_work_group_size_x;
-  int max_work_group_size_y;
-  int max_work_group_size_z;
-  std::vector<int> supported_subgroup_sizes;
-
-  // rtn is ROUND_TO_NEAREST
-  // with rtn precision is much better then with rtz (ROUND_TO_ZERO)
-  // Adreno 3xx supports only rtz, Adreno 4xx and more support rtn
-  // Mali from T6xx supports rtn
-  // PowerVR supports only rtz
-  bool supports_fp32_rtn;
-  bool supports_fp16_rtn;
-
-  bool supports_r_f16_tex2d = false;
-  bool supports_rg_f16_tex2d = false;
-  bool supports_rgb_f16_tex2d = false;
-  bool supports_rgba_f16_tex2d = false;
-
-  bool supports_r_f32_tex2d = false;
-  bool supports_rg_f32_tex2d = false;
-  bool supports_rgb_f32_tex2d = false;
-  bool supports_rgba_f32_tex2d = false;
-
-  AdrenoInfo adreno_info;
-  MaliInfo mali_info;
-};
-
-}  // namespace cl
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_DEVICE_INFO_H_
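
Taken together with the hunks above, this deletion implies a mechanical migration for former users of the header. A sketch using only names that appear elsewhere in this diff (the wrapper function is illustrative, and the tflite::gpu namespace for the common GpuInfo is an assumption):

#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"  // was cl/device_info.h

// Illustrative migration: the enum is restyled (OpenCLVersion::CL_1_2 ->
// OpenClVersion::kCl1_2), OpenCL-only capability bits are nested under
// opencl_info, and vendor checks keep their old spelling.
void MigratedUsage(const tflite::gpu::GpuInfo& info) {
  bool at_least_cl12 =
      info.opencl_info.cl_version >= tflite::gpu::OpenClVersion::kCl1_2;
  bool fp16 = info.opencl_info.supports_fp16;
  bool adreno = info.IsAdreno();
  (void)at_least_cl12; (void)fp16; (void)adreno;
}
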
diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc
index 9b2fef288fe..275ea696e09 100644
--- a/tensorflow/lite/delegates/gpu/cl/environment.cc
+++ b/tensorflow/lite/delegates/gpu/cl/environment.cc
@@ -48,6 +48,39 @@ absl::Status CreateEnvironment(Environment* result, bool shared,
   return result->Init();
 }
 
+bool IsGpuSupportsStorageType(const GpuInfo& gpu_info,
+                              TensorStorageType storage_type) {
+  switch (storage_type) {
+    case TensorStorageType::TEXTURE_2D:
+      return !gpu_info.IsAMD();
+    case TensorStorageType::BUFFER:
+      return true;
+    case TensorStorageType::TEXTURE_ARRAY:
+      return !gpu_info.IsAMD() && gpu_info.SupportsTextureArray();
+    case TensorStorageType::IMAGE_BUFFER:
+      return (gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsNvidia()) &&
+             gpu_info.SupportsImageBuffer();
+    case TensorStorageType::TEXTURE_3D:
+      return !gpu_info.IsAMD() && gpu_info.SupportsImage3D();
+    case TensorStorageType::SINGLE_TEXTURE_2D:
+      return false;
+    case TensorStorageType::UNKNOWN:
+      return false;
+  }
+  return false;
+}
+
+bool IsGpuSupportsPrecision(const GpuInfo& gpu_info,
+                            CalculationsPrecision precision) {
+  switch (precision) {
+    case CalculationsPrecision::F32_F16:
+    case CalculationsPrecision::F16:
+      return gpu_info.SupportsFP16();
+    case CalculationsPrecision::F32:
+      return true;
+  }
+}
+
 }  // namespace
 
 Environment::Environment(CLDevice&& device, CLContext&& context,
@@ -77,7 +110,8 @@ Environment& Environment::operator=(Environment&& environment) {
 }
 
 absl::Status Environment::Init() {
-  if (device().IsAdreno() && device().SupportsTextureArray()) {
+  if (device().GetInfo().IsAdreno() &&
+      device().GetInfo().SupportsTextureArray()) {
     const auto& adreno_info = device().info_.adreno_info;
     // Some Adreno < 600 have a bug with one-layer texture arrays. b/131099086
     // If we have a one-layer texture array and write something from a kernel to this
@@ -117,13 +151,7 @@ std::vector<CalculationsPrecision> Environment::GetSupportedPrecisions() const {
 }
 
 bool Environment::IsSupported(CalculationsPrecision precision) const {
-  switch (precision) {
-    case CalculationsPrecision::F32_F16:
-    case CalculationsPrecision::F16:
-      return device_.SupportsFP16();
-    case CalculationsPrecision::F32:
-      return true;
-  }
+  return IsGpuSupportsPrecision(device_.GetInfo(), precision);
 }
 
 std::vector<TensorStorageType> Environment::GetSupportedStorages() const {
@@ -153,24 +181,7 @@ Environment::GetSupportedStoragesWithHWZeroClampSupport() const {
 }
 
 bool Environment::IsSupported(TensorStorageType storage_type) const {
-  switch (storage_type) {
-    case TensorStorageType::TEXTURE_2D:
-      return !device_.IsAMD();
-    case TensorStorageType::BUFFER:
-      return true;
-    case TensorStorageType::TEXTURE_ARRAY:
-      return !device_.IsAMD() && device_.SupportsTextureArray();
-    case TensorStorageType::IMAGE_BUFFER:
-      return (device_.IsAdreno() || device_.IsAMD() || device_.IsNvidia()) &&
-             device_.SupportsImageBuffer();
-    case TensorStorageType::TEXTURE_3D:
-      return !device_.IsAMD() && device_.SupportsImage3D();
-    case TensorStorageType::SINGLE_TEXTURE_2D:
-      return false;
-    case TensorStorageType::UNKNOWN:
-      return false;
-  }
-  return false;
+  return IsGpuSupportsStorageType(device_.GetInfo(), storage_type);
 }
 
 TensorStorageType GetFastestStorageType(const GpuInfo& gpu_info) {
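
With storage-type support now answered by a single predicate, callers can probe Environment uniformly. A sketch with a hypothetical preference order (PickStorageType is not part of the API, and GetFastestStorageType below keys off GpuInfo directly):

// Hypothetical caller, assuming the tflite::gpu::cl namespace is open:
// walk a preference list and take the first storage type the device
// supports, falling back to plain buffers.
TensorStorageType PickStorageType(const Environment& env) {
  for (auto type : {TensorStorageType::TEXTURE_2D,
                    TensorStorageType::IMAGE_BUFFER,
                    TensorStorageType::BUFFER}) {
    if (env.IsSupported(type)) return type;
  }
  return TensorStorageType::BUFFER;
}
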
diff --git a/tensorflow/lite/delegates/gpu/cl/environment.h b/tensorflow/lite/delegates/gpu/cl/environment.h
index 8917351841f..5138e59ee7e 100644
--- a/tensorflow/lite/delegates/gpu/cl/environment.h
+++ b/tensorflow/lite/delegates/gpu/cl/environment.h
@@ -19,9 +19,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc
index 599e6766301..2d4e6c54b39 100644
--- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc
+++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc
@@ -89,7 +89,7 @@ absl::Status CreateClEventFromEglSync(cl_context context,
 }
 
 bool IsClEventFromEglSyncSupported(const CLDevice& device) {
-  return device.SupportsExtension("cl_khr_egl_event");
+  return device.GetInfo().SupportsExtension("cl_khr_egl_event");
 }
 
 absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id,
@@ -126,7 +126,7 @@ absl::Status CreateClMemoryFromGlTexture(GLenum texture_target,
 
 bool IsGlSharingSupported(const CLDevice& device) {
   return clCreateFromGLBuffer && clCreateFromGLTexture &&
-         device.SupportsExtension("cl_khr_gl_sharing");
+         device.GetInfo().SupportsExtension("cl_khr_gl_sharing");
 }
 
 AcquiredGlObjects::~AcquiredGlObjects() { Release({}, nullptr).IgnoreError(); }
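
The two predicates above gate everything else in this file. A sketch of how a caller might sequence the checks before attempting interop (the wrapper function and its fallback branch are illustrative; the namespace is assumed open):

// Illustrative guard: refuse CL-GL interop when the extension is
// missing, and note where an EGL-sync fallback would branch.
absl::Status CheckGlInteropPreconditions(const CLDevice& device) {
  if (!IsGlSharingSupported(device)) {
    return absl::UnavailableError("cl_khr_gl_sharing not available.");
  }
  if (!IsClEventFromEglSyncSupported(device)) {
    // Without cl_khr_egl_event, synchronization has to fall back to a
    // coarser mechanism (e.g. glFinish) -- hypothetical policy.
  }
  return absl::OkStatus();
}
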
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 332de066bca..3bfc455865d 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -27,7 +27,6 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
@@ -38,6 +37,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
@@ -163,14 +163,14 @@ absl::Status InferenceContext::InitFromGraph(
   ReserveGraphTensors(create_info, creation_context.GetGpuInfo(), graph);
   precision_ = create_info.precision;
   storage_type_ = create_info.storage_type;
-  if (env->device().IsMali()) {
+  if (env->device().GetInfo().IsMali()) {
     need_flush_ = true;
     need_manual_release_ = true;
 
     flush_periodically_ = true;
     flush_period_ = 24;
   }
-  if (env->device().IsPowerVR()) {
+  if (env->device().GetInfo().IsPowerVR()) {
     need_flush_ = true;
   }
   CopyInAndOutIds(graph);
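
The vendor-specific flags set above translate into a simple dispatch-time policy. A sketch of the periodic-flush idea for Mali (the helper and its parameters are hypothetical; only the flags and flush_period_ = 24 come from the hunk):

#include <CL/cl.h>

// Hypothetical helper: flush every `flush_period` enqueued ops so the
// driver starts work early, without blocking the way clFinish would.
void MaybeFlush(cl_command_queue queue, bool flush_periodically,
                int flush_period, int ops_enqueued) {
  if (flush_periodically && ops_enqueued % flush_period == 0) {
    clFlush(queue);
  }
}
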
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
index 402c30c247e..afdadd86133 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
@@ -13,11 +13,10 @@ cc_library(
     srcs = ["add.cc"],
     hdrs = ["add.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -79,13 +78,12 @@ cc_library(
     srcs = ["concat_xy.cc"],
     hdrs = ["concat_xy.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -94,14 +92,13 @@ cc_library(
     srcs = ["concat_z.cc"],
     hdrs = ["concat_z.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -110,10 +107,6 @@ cc_library(
     srcs = ["conv_buffer_1x1.cc"],
     hdrs = ["conv_buffer_1x1.h"],
     deps = [
-        ":conv_common",
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
@@ -127,14 +120,14 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:winograd_util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
-cc_library(
-    name = "conv_common",
-    hdrs = ["conv_common.h"],
-)
-
 cc_test(
     name = "conv_buffer_1x1_test",
     srcs = ["conv_buffer_1x1_test.cc"],
@@ -157,9 +150,6 @@ cc_library(
     srcs = ["conv_constants.cc"],
     hdrs = ["conv_constants.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
@@ -170,6 +160,9 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -195,10 +188,6 @@ cc_library(
     srcs = ["conv_powervr.cc"],
     hdrs = ["conv_powervr.h"],
     deps = [
-        ":conv_common",
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
@@ -212,6 +201,11 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:winograd_util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -238,14 +232,14 @@ cc_library(
     srcs = ["conv_weights_converter.cc"],
     hdrs = ["conv_weights_converter.h"],
     deps = [
-        ":conv_common",
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_command_queue",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -254,8 +248,6 @@ cc_library(
     srcs = ["converter.cc"],
     hdrs = ["converter.h"],
     deps = [
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu:spi",
         "//tensorflow/lite/delegates/gpu/cl:cl_arguments",
         "//tensorflow/lite/delegates/gpu/cl:cl_command_queue",
@@ -267,6 +259,8 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/common/task:arguments",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -275,10 +269,6 @@ cc_library(
     srcs = ["convolution_transposed.cc"],
     hdrs = ["convolution_transposed.h"],
     deps = [
-        ":conv_common",
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
@@ -290,6 +280,10 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -316,9 +310,6 @@ cc_library(
     srcs = ["convolution_transposed_3x3.cc"],
     hdrs = ["convolution_transposed_3x3.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
@@ -329,6 +320,10 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -355,9 +350,6 @@ cc_library(
     srcs = ["convolution_transposed_3x3_thin.cc"],
     hdrs = ["convolution_transposed_3x3_thin.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
@@ -368,6 +360,10 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -393,9 +389,6 @@ cc_library(
     srcs = ["convolution_transposed_4x4.cc"],
     hdrs = ["convolution_transposed_4x4.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
@@ -406,6 +399,10 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -432,9 +429,6 @@ cc_library(
     srcs = ["convolution_transposed_thin.cc"],
     hdrs = ["convolution_transposed_thin.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/cl:texture2d",
@@ -445,6 +439,8 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -470,9 +466,6 @@ cc_library(
     srcs = ["depthwise_conv.cc"],
     hdrs = ["depthwise_conv.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
@@ -485,6 +478,9 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -510,9 +506,6 @@ cc_library(
     srcs = ["depthwise_conv_3x3.cc"],
     hdrs = ["depthwise_conv_3x3.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/cl:texture2d",
@@ -523,6 +516,8 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -548,11 +543,10 @@ cc_library(
     srcs = ["elementwise.cc"],
     hdrs = ["elementwise.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
         "//tensorflow/lite/delegates/gpu/cl:storage_type_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -579,11 +573,8 @@ cc_library(
     srcs = ["fully_connected.cc"],
     hdrs = ["fully_connected.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/cl:texture2d",
@@ -593,6 +584,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -608,52 +600,26 @@ cc_test(
     deps = [
         ":cl_test",
         ":fully_connected",
-        ":gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl:environment",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_googletest//:gtest_main",
     ],
 )
 
-cc_library(
-    name = "gpu_operation",
-    srcs = ["gpu_operation.cc"],
-    hdrs = ["gpu_operation.h"],
-    deps = [
-        ":util",
-        ":work_group_picking",
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
-        "//tensorflow/lite/delegates/gpu/cl:serialization_cc_fbs",
-        "//tensorflow/lite/delegates/gpu/common:access_type",
-        "//tensorflow/lite/delegates/gpu/common:data_type",
-        "//tensorflow/lite/delegates/gpu/common:kernel_info",
-        "//tensorflow/lite/delegates/gpu/common:precision",
-        "//tensorflow/lite/delegates/gpu/common:status",
-        "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/common/task:arguments",
-        "//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
-        "//tensorflow/lite/delegates/gpu/common/task:compiler_options",
-        "//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
-        "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
-        "//tensorflow/lite/delegates/gpu/common/task:tuning_type",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
 cc_library(
     name = "lstm",
     srcs = ["lstm.cc"],
     hdrs = ["lstm.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -698,12 +664,11 @@ cc_library(
     srcs = ["max_unpooling.cc"],
     hdrs = ["max_unpooling.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -729,14 +694,12 @@ cc_library(
     srcs = ["mean_stddev_normalization.cc"],
     hdrs = ["mean_stddev_normalization.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_program",
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -763,12 +726,11 @@ cc_library(
     srcs = ["padding.cc"],
     hdrs = ["padding.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -794,14 +756,14 @@ cc_library(
     srcs = ["pooling.cc"],
     hdrs = ["pooling.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -827,8 +789,6 @@ cc_library(
     srcs = ["prelu.cc"],
     hdrs = ["prelu.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
         "//tensorflow/lite/delegates/gpu/cl:cl_context",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
@@ -837,6 +797,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:variant",
     ],
@@ -864,8 +825,6 @@ cc_library(
     srcs = ["quantize_and_dequantize.cc"],
     hdrs = ["quantize_and_dequantize.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
         "//tensorflow/lite/delegates/gpu/cl:cl_context",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
@@ -873,6 +832,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:variant",
     ],
@@ -901,14 +861,14 @@ cc_library(
     srcs = ["reduce.cc"],
     hdrs = ["reduce.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:kernel_info",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -934,8 +894,8 @@ cc_library(
     srcs = ["relu.cc"],
     hdrs = ["relu.h"],
     deps = [
-        ":gpu_operation",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -962,12 +922,11 @@ cc_library(
     srcs = ["reshape.cc"],
     hdrs = ["reshape.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -993,13 +952,12 @@ cc_library(
     srcs = ["reshapex4.cc"],
     hdrs = ["reshapex4.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_command_queue",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -1025,13 +983,12 @@ cc_library(
     srcs = ["softmax.cc"],
     hdrs = ["softmax.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -1057,11 +1014,11 @@ cc_library(
     srcs = ["softmax1x1.cc"],
     hdrs = ["softmax1x1.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
     ],
 )
 
@@ -1087,13 +1044,12 @@ cc_library(
     srcs = ["space_to_depth.cc"],
     hdrs = ["space_to_depth.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -1119,11 +1075,10 @@ cc_library(
     srcs = ["strided_slice.cc"],
     hdrs = ["strided_slice.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -1149,11 +1104,10 @@ cc_library(
     srcs = ["transpose.cc"],
     hdrs = ["transpose.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -1180,12 +1134,11 @@ cc_library(
     srcs = ["resize.cc"],
     hdrs = ["resize.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -1206,31 +1159,11 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "util",
-    srcs = ["util.cc"],
-    hdrs = ["util.h"],
-    deps = [
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
-        "//tensorflow/lite/delegates/gpu/common:data_type",
-        "//tensorflow/lite/delegates/gpu/common:precision",
-        "//tensorflow/lite/delegates/gpu/common:shape",
-        "//tensorflow/lite/delegates/gpu/common:tensor",
-        "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/common:util",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-    ],
-)
-
 cc_library(
     name = "winograd",
     srcs = ["winograd.cc"],
     hdrs = ["winograd.h"],
     deps = [
-        ":gpu_operation",
-        ":util",
-        ":work_group_picking",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
@@ -1240,6 +1173,8 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:winograd_util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
         "@com_google_absl//absl/strings:str_format",
     ],
 )
@@ -1254,7 +1189,6 @@ cc_test(
     ],
     deps = [
         ":cl_test",
-        ":util",
         ":winograd",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
@@ -1263,21 +1197,6 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "work_group_picking",
-    srcs = ["work_group_picking.cc"],
-    hdrs = ["work_group_picking.h"],
-    deps = [
-        "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
-        "//tensorflow/lite/delegates/gpu/common:kernel_info",
-        "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/common:util",
-        "//tensorflow/lite/delegates/gpu/common:workgroup_selection",
-        "//tensorflow/lite/delegates/gpu/common/task:tuning_type",
-    ],
-)
-
 test_suite(
     name = "all_tests",
     tests = [
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
index 1cb41e79d88..15714f07b23 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include <string>
 
 #include "absl/strings/str_cat.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.h b/tensorflow/lite/delegates/gpu/cl/kernels/add.h
index 0e9d7e0d333..3a10f71f6b4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/add.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.h
@@ -19,9 +19,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
index fa5b933db8a..faf893fa1f0 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
@@ -19,9 +19,8 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
@@ -69,7 +68,7 @@ std::string GetConcatKernelCode(const OperationDef& op_def,
     dst_coord += ", " + dst_coords[i];
   }
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
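
A recurring edit in these kernels replaces `std::string c = GetCommonDefines(op_def.precision);` with a bare `std::string c;`, so the per-precision OpenCL defines are no longer pasted into every generated kernel body; they are presumably injected once by the shared task infrastructure. A minimal sketch of the kind of prelude this removes (the function name and exact macro set here are illustrative assumptions, not repo code):

    #include <string>

    enum class CalculationsPrecision { F32, F16, F32_F16 };

    // Illustrative only: the sort of precision-dependent macro prelude that
    // GetCommonDefines used to prepend to each generated kernel.
    std::string CommonDefines(CalculationsPrecision precision) {
      switch (precision) {
        case CalculationsPrecision::F32:
          return "#define FLT float\n#define FLT4 float4\n";
        case CalculationsPrecision::F16:
          return "#define FLT half\n#define FLT4 half4\n";
        case CalculationsPrecision::F32_F16:
          // Mixed mode: accumulate in float, store in half.
          return "#define FLT float\n#define FLT4 float4\n";
      }
      return "";
    }
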
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h
index 9dd3fcee52a..2ccdc2a26ad 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h
@@ -17,9 +17,9 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONCAT_XY_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
index cefbf557ac1..6a264c6185c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
@@ -17,9 +17,8 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
@@ -43,7 +42,7 @@ std::string GetConcatKernelCode(const OperationDef& op_def,
     tensor_names[i] = "src_tensor_" + std::to_string(i);
   }
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h
index 16341af4187..227ce15d04d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h
@@ -20,9 +20,9 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc
index ed15df0bea7..70d79813829 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc
@@ -20,9 +20,9 @@ limitations under the License.
 #include <utility>
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -139,7 +139,7 @@ ConvBuffer1x1::ConvParams GetBestParams(const GpuInfo& gpu_info,
   conv_params.element_size = 4;
   conv_params.block_size = int3(1, 1, 1);
   if (gpu_info.IsMali() && definition.precision == CalculationsPrecision::F16 &&
-      gpu_info.compute_units_count <= 4) {
+      gpu_info.GetComputeUnitsCount() <= 4) {
     conv_params.block_size.x *= 2;
   }
   return conv_params;
@@ -194,7 +194,7 @@ std::string ConvBuffer1x1::GenerateConvBuffer1x1(
   }
   AddDstTensor("dst_tensor", dst_desc);
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
       c += "#define FLT8 float8\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h
index d93e9a0460c..3687b963078 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h
@@ -18,9 +18,6 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -28,6 +25,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
@@ -52,9 +52,9 @@ class ConvBuffer1x1 : public GPUOperation {
       std::vector<int3>* work_groups) const override;
   int3 GetGridSize() const override;
 
-  ConvWeightsDescription GetConvWeightsDescription() const {
-    ConvWeightsDescription desc;
-    desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
+  WeightsDescription GetWeightsDescription() const {
+    WeightsDescription desc;
+    desc.layout = WeightsLayout::kOHWIOGroupI4O4;
     desc.output_group_size = conv_params_.block_size.z;
     return desc;
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc
index 99b1e041d1c..e272ad64f3a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc
@@ -18,8 +18,8 @@ limitations under the License.
 #include <string>
 #include <utility>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -111,14 +111,13 @@ std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
   }
   op->AddDstTensor("dst_tensor", dst_desc);
 
-  std::string c = GetCommonDefines(op_def.precision);
-
   const int out_z = DivideRoundUp(weights_shape.o, 4);
   const std::string kOutZ = std::to_string(out_z);
   const int src_depth = DivideRoundUp(weights_shape.i, 4);
 
   const std::string postfixes[] = {".x", ".xy", ".xyz", ""};
 
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h
index 5a7cd248999..be75f06befb 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h
@@ -17,7 +17,6 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_CONSTANTS_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -25,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc
index 448a243a830..6d4395b2b18 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc
@@ -20,11 +20,11 @@ limitations under the License.
 #include <utility>
 
 #include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -443,9 +443,9 @@ std::string ConvPowerVR::GenerateConv(const GpuInfo& gpu_info,
   const std::string weights_global_ptr =
       weights_space + " " + weights_data_type + "*";
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   if (use_simd_broadcast) {
-    if (gpu_info.cl_version == OpenCLVersion::CL_2_0) {
+    if (gpu_info.opencl_info.cl_version == OpenClVersion::kCl2_0) {
       c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
     } else if (gpu_info.SupportsExtension("cl_intel_subgroups")) {
       c += "#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n";
@@ -1045,7 +1045,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
     if (dst_shape) {
       int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
       float task_size_per_cu =
-          static_cast<float>(task_size) / gpu_info.compute_units_count;
+          static_cast<float>(task_size) / gpu_info.GetComputeUnitsCount();
       int block_size = conv_params.block_size.x * conv_params.block_size.y *
                        conv_params.block_size.w;
       float threads_per_cu = task_size_per_cu / block_size;
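
The PowerVR parameter guesser above normalizes the task size by the device's compute-unit count (now queried via `GetComputeUnitsCount()`). A worked example of the quantities involved, with assumed illustrative numbers:

    // Illustrative arithmetic for the heuristic above (all values assumed):
    // a 128x128 output with 16 depth slices on a 16-CU GPU, 2x2x2 block size.
    float ExampleThreadsPerCu() {
      const int task_size = 128 * 1 * 128 * 16;          // w * b * h * dst_depth
      const float task_size_per_cu = task_size / 16.0f;  // 16384 per compute unit
      const int block_size = 2 * 2 * 2;                  // block_size.x * .y * .w
      return task_size_per_cu / block_size;              // 2048 threads per CU
    }
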
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h
index 85289f63339..b1862e48a71 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h
@@ -21,9 +21,6 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
@@ -32,6 +29,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
@@ -50,9 +50,9 @@ class ConvPowerVR : public GPUOperation {
   absl::Status BindArguments(ArgumentsBinder* args) override;
   int3 GetGridSize() const override;
 
-  ConvWeightsDescription GetConvWeightsDescription() const {
-    ConvWeightsDescription desc;
-    desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
+  WeightsDescription GetWeightsDescription() const {
+    WeightsDescription desc;
+    desc.layout = WeightsLayout::kOHWIOGroupI4O4;
     desc.output_group_size = conv_params_.block_size.w;
     return desc;
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc
index 521cbefd885..5d3f4e26e95 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc
@@ -17,37 +17,35 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 
 ConverterToConvWeights::ConverterToConvWeights(
-    const OperationDef& definition,
-    const ConvWeightsDescription& conv_weights_desc)
-    : GPUOperation(definition), conv_weights_desc_(conv_weights_desc) {
-  code_ = GetConverterToConvWeightsCode(definition_, conv_weights_desc_);
+    const OperationDef& definition, const WeightsDescription& weights_desc)
+    : GPUOperation(definition), weights_desc_(weights_desc) {
+  code_ = GetConverterToConvWeightsCode(definition_, weights_desc_);
 }
 
 ConverterToConvWeights::ConverterToConvWeights(
     ConverterToConvWeights&& operation)
     : GPUOperation(std::move(operation)),
-      conv_weights_desc_(operation.conv_weights_desc_) {}
+      weights_desc_(std::move(operation.weights_desc_)) {}
 
 ConverterToConvWeights& ConverterToConvWeights::operator=(
     ConverterToConvWeights&& operation) {
   if (this != &operation) {
-    conv_weights_desc_ = operation.conv_weights_desc_;
+    weights_desc_ = std::move(operation.weights_desc_);
     GPUOperation::operator=(std::move(operation));
   }
   return *this;
 }
 
 std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
-    const OperationDef& op_def,
-    const ConvWeightsDescription& conv_weights_desc) {
+    const OperationDef& op_def, const WeightsDescription& conv_weights_desc) {
   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
   args_.AddFloat("mask_x");
@@ -55,7 +53,7 @@ std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
   args_.AddFloat("mask_z");
   args_.AddFloat("mask_w");
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int GROUP_SIZE = " +
@@ -120,16 +118,15 @@ absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
 
 int3 ConverterToConvWeights::GetGridSize() const {
   const int grid_x = DivideRoundUp(
-      AlignByN(src_[0]->Batch(), 4 * conv_weights_desc_.output_group_size), 4);
+      AlignByN(src_[0]->Batch(), 4 * weights_desc_.output_group_size), 4);
   const int grid_y = src_[0]->Slices();
   const int grid_z = src_[0]->Width() * src_[0]->Height();
   return int3(grid_x, grid_y, grid_z);
 }
 
 ConverterToConvWeights CreateConverterToConvWeights(
-    const OperationDef& definition,
-    const ConvWeightsDescription& conv_weights_desc) {
-  return ConverterToConvWeights(definition, conv_weights_desc);
+    const OperationDef& definition, const WeightsDescription& weights_desc) {
+  return ConverterToConvWeights(definition, weights_desc);
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h
index 3c7314ea6c9..3d5306886b3 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h
@@ -18,9 +18,9 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
@@ -30,7 +30,7 @@ namespace cl {
 class ConverterToConvWeights : public GPUOperation {
  public:
   ConverterToConvWeights(const OperationDef& definition,
-                         const ConvWeightsDescription& conv_weights_desc);
+                         const WeightsDescription& weights_desc);
   absl::Status BindArguments(ArgumentsBinder* args) override;
   int3 GetGridSize() const override;
 
@@ -42,10 +42,9 @@ class ConverterToConvWeights : public GPUOperation {
 
  private:
   std::string GetConverterToConvWeightsCode(
-      const OperationDef& op_def,
-      const ConvWeightsDescription& conv_weights_desc);
+      const OperationDef& op_def, const WeightsDescription& weights_desc);
 
-  ConvWeightsDescription conv_weights_desc_;
+  WeightsDescription weights_desc_;
 };
 
 // We expect src BHWC tensor and we assume that B is O, H = H, W = W, C is I
@@ -53,8 +52,7 @@ class ConverterToConvWeights : public GPUOperation {
 // dst.b * dst.h * dst.w * dst.c = AlignByN(src.b, 4) * src.h * src.w
 // AlignByN(src.c, 4)
 ConverterToConvWeights CreateConverterToConvWeights(
-    const OperationDef& definition,
-    const ConvWeightsDescription& conv_weights_desc);
+    const OperationDef& definition, const WeightsDescription& weights_desc);
 
 }  // namespace cl
 }  // namespace gpu
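
With `ConvWeightsDescription` gone, callers build the generic `WeightsDescription` instead, exactly as the conv ops' `GetWeightsDescription()` methods above do. A minimal usage sketch (the wrapper function and the `output_group_size` value are assumptions for illustration):

    #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h"
    #include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"

    namespace tflite {
    namespace gpu {
    namespace cl {

    // Build the description the way ConvPowerVR/ConvBuffer1x1 expose it, then
    // hand it to the converter factory declared in this header.
    ConverterToConvWeights MakeWeightsConverter(const OperationDef& op_def) {
      WeightsDescription desc;
      desc.layout = WeightsLayout::kOHWIOGroupI4O4;
      desc.output_group_size = 4;  // assumed block size for this sketch
      return CreateConverterToConvWeights(op_def, desc);
    }

    }  // namespace cl
    }  // namespace gpu
    }  // namespace tflite
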
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
index ab6f986ea97..fdab1890b69 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
@@ -22,13 +22,13 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/cl_arguments.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_errors.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor_type_util.h"
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
index 99df2689b95..f6c05270e8c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
@@ -20,10 +20,9 @@ limitations under the License.
 #include <vector>
 
 #include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -132,7 +131,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode(
 
   const auto& src_def = op_def.src_tensors[0];
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
 
   for (int s = 0; s < block_size.w; ++s) {
     const std::string f0 =
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
index c25c916cc56..0373c4cc964 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
@@ -20,9 +20,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
@@ -31,6 +28,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
@@ -54,9 +54,9 @@ class ConvolutionTransposed : public GPUOperation {
   ConvolutionTransposed(const ConvolutionTransposed&) = delete;
   ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete;
 
-  ConvWeightsDescription GetConvWeightsDescription() const {
-    ConvWeightsDescription desc;
-    desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
+  WeightsDescription GetWeightsDescription() const {
+    WeightsDescription desc;
+    desc.layout = WeightsLayout::kOHWIOGroupI4O4;
     desc.output_group_size = block_size_.w;
     return desc;
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc
index ca1be8ac5e0..1030891ab38 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc
@@ -19,8 +19,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -80,6 +79,19 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
   }
   AddDstTensor("dst_tensor", dst_desc);
 
+  if (op_def.src_tensors.size() == 2) {
+    // dynamic weights
+    BufferDescriptor desc;
+    desc.element_type = op_def.src_tensors[1].data_type;
+    desc.element_size = 4;
+    desc.memory_type =
+        weights_upload_type ==
+                ConvolutionTransposed3x3::WeightsUploadType::CONSTANT_MEM
+            ? MemoryType::CONSTANT
+            : MemoryType::GLOBAL;
+    AddSrcBuffer("weights", desc);
+  }
+
   args_.AddInt("filter_offset");
   args_.AddInt("padding_x");
   args_.AddInt("padding_y");
@@ -90,7 +102,7 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
       weights_upload_type ==
           ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_ASYNC;
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
     case CalculationsPrecision::F16:
@@ -350,6 +362,22 @@ int3 ConvolutionTransposed3x3::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
+std::vector<int> ConvolutionTransposed3x3::GetSpatialWeightsRemap() const {
+  const int padding_x_rem = abs(padding_.x) % 2;
+  const int padding_y_rem = abs(padding_.y) % 2;
+
+  std::vector<int> remap;
+  if (padding_x_rem == 1 && padding_y_rem == 1) {
+    return std::vector<int>{4, 5, 3, 7, 1, 8, 6, 2, 0};
+  } else if (padding_x_rem == 0 && padding_y_rem == 1) {
+    return std::vector<int>{5, 3, 4, 8, 6, 2, 0, 7, 1};
+  } else if (padding_x_rem == 1 && padding_y_rem == 0) {
+    return std::vector<int>{7, 1, 8, 6, 2, 0, 4, 5, 3};
+  } else {  // padding_x_rem == 0 && padding_y_rem == 0
+    return std::vector<int>{8, 6, 2, 0, 7, 1, 5, 3, 4};
+  }
+}
+
 bool IsConvolutionTransposed3x3Supported(
     const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr) {
@@ -373,6 +401,21 @@ ConvolutionTransposed3x3 CreateConvolutionTransposed3x3(
   return result;
 }
 
+ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr) {
+  const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h);
+  ConvolutionTransposed3x3 result(definition, gpu_info, padding);
+
+  TensorLinearDescriptor desc;
+  desc.storage_type = LinearStorageType::TEXTURE_2D;
+  desc.element_type = definition.GetDataType();
+  desc.UploadLinearData(attr.bias);
+  result.args_.AddObject(
+      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+  return result;
+}
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
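
The removed rearrangement code carried the comment "we are reorganizing weights to read them sequentially in kernel"; only the parity of the padding along each axis selects the tap order. The dispatch from `GetSpatialWeightsRemap()` above, restated as a standalone helper for reference (a sketch, not repo code):

    #include <cstdlib>
    #include <vector>

    std::vector<int> SpatialRemapForPadding(int padding_x, int padding_y) {
      const int px = std::abs(padding_x) % 2;
      const int py = std::abs(padding_y) % 2;
      if (px == 1 && py == 1) return {4, 5, 3, 7, 1, 8, 6, 2, 0};
      if (px == 0 && py == 1) return {5, 3, 4, 8, 6, 2, 0, 7, 1};
      if (px == 1 && py == 0) return {7, 1, 8, 6, 2, 0, 4, 5, 3};
      return {8, 6, 2, 0, 7, 1, 5, 3, 4};  // both paddings even
    }
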
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h
index 89abf70498a..49550a6cd55 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h
@@ -19,7 +19,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -27,6 +26,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
@@ -50,6 +52,13 @@ class ConvolutionTransposed3x3 : public GPUOperation {
   ConvolutionTransposed3x3(const ConvolutionTransposed3x3&) = delete;
   ConvolutionTransposed3x3& operator=(const ConvolutionTransposed3x3&) = delete;
 
+  WeightsDescription GetWeightsDescription() const {
+    WeightsDescription desc;
+    desc.layout = WeightsLayout::kOICustomSpatialI4O4;
+    desc.spatial_remap = GetSpatialWeightsRemap();
+    return desc;
+  }
+
   enum class WeightsUploadType {
     LOCAL_MEM_ASYNC,
     LOCAL_MEM_BY_THREADS,
@@ -63,12 +72,14 @@ class ConvolutionTransposed3x3 : public GPUOperation {
   friend ConvolutionTransposed3x3 CreateConvolutionTransposed3x3(
       const GpuInfo& gpu_info, const OperationDef& definition,
       const ConvolutionTransposedAttributes& attr);
+  friend ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights(
+      const GpuInfo& gpu_info, const OperationDef& definition,
+      const ConvolutionTransposedAttributes& attr);
+
   template <DataType T>
   void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
 
-  template <DataType S, typename T>
-  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
-                            absl::Span<T> dst);
+  std::vector<int> GetSpatialWeightsRemap() const;
 
   std::string GenerateConvolutionTransposedCode(
       const OperationDef& op_def,
@@ -104,71 +115,18 @@ void ConvolutionTransposed3x3::UploadWeights(
 
   if (f32_weights) {
     float4* ptr = reinterpret_cast<float4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count));
+    RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(),
+                                          absl::MakeSpan(ptr, flt4_count));
   } else {
     half4* ptr = reinterpret_cast<half4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count));
+    RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(),
+                                          absl::MakeSpan(ptr, flt4_count));
   }
 
   args_.AddObject("weights",
                   absl::make_unique<BufferDescriptor>(std::move(desc)));
 }
 
-template <DataType S, typename T>
-void ConvolutionTransposed3x3::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
-  const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int dst_depth = DivideRoundUp(weights.shape.o, 4);
-  const int kernel_x = 3;
-  const int kernel_y = 3;
-
-  const int padding_x_rem = abs(padding_.x) % 2;
-  const int padding_y_rem = abs(padding_.y) % 2;
-
-  // we are reorganizing weights to read them sequentially in kernel
-  std::vector<int> remap;
-  if (padding_x_rem == 1 && padding_y_rem == 1) {
-    remap = {4, 5, 3, 7, 1, 8, 6, 2, 0};
-  } else if (padding_x_rem == 0 && padding_y_rem == 1) {
-    remap = {5, 3, 4, 8, 6, 2, 0, 7, 1};
-  } else if (padding_x_rem == 1 && padding_y_rem == 0) {
-    remap = {7, 1, 8, 6, 2, 0, 4, 5, 3};
-  } else {  // padding_x_rem == 0 && padding_y_rem == 0
-    remap = {8, 6, 2, 0, 7, 1, 5, 3, 4};
-  }
-
-  int counter = 0;
-  for (int d = 0; d < dst_depth; ++d) {
-    for (int s = 0; s < src_depth; ++s) {
-      for (int y = 0; y < kernel_y; ++y) {
-        for (int x = 0; x < kernel_x; ++x) {
-          const int kernel_index = remap[y * kernel_x + x];
-          const int kernel_index_x = kernel_index % kernel_x;
-          const int kernel_index_y = kernel_index / kernel_x;
-          T filters[4];
-          for (int j = 0; j < 4; ++j) {
-            for (int i = 0; i < 4; ++i) {
-              const int s_ch = s * 4 + i;
-              const int d_ch = d * 4 + j;
-              if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
-                const int f_index = weights.shape.LinearIndex(
-                    {d_ch, kernel_index_y, kernel_index_x, s_ch});
-                filters[i][j] = weights.data[f_index];
-              } else {
-                filters[i][j] = 0.0f;
-              }
-            }
-          }
-          dst[counter++] = filters[0];
-          dst[counter++] = filters[1];
-          dst[counter++] = filters[2];
-          dst[counter++] = filters[3];
-        }
-      }
-    }
-  }
-}
-
 bool IsConvolutionTransposed3x3Supported(
     const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);
@@ -177,6 +135,10 @@ ConvolutionTransposed3x3 CreateConvolutionTransposed3x3(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);
 
+ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr);
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
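
Both this class and the 3x3-thin variant now call the shared `RearrangeWeightsToOICustomSpatialI4O4` helper in place of the per-class `RearrangeWeightsData` methods deleted above. A self-contained sketch of the rearrangement, reconstructed from the removed code (the `float4` stand-in and the explicit OHWI indexing are spelled out here; note the thin variant iterated src depth outermost, whereas this version iterates dst depth outermost):

    #include <vector>

    // Stand-in for the repo's float4 vector type.
    struct float4 {
      float v[4] = {0.f, 0.f, 0.f, 0.f};
      float& operator[](int i) { return v[i]; }
    };

    // Reconstruction of the removed RearrangeWeightsData: visit the spatial
    // taps in `remap` order and emit one 4x4 input/output channel block
    // (four float4 values) per tap. `weights` is a dense OHWI array; `dst`
    // must hold dst_depth * src_depth * H * W * 4 elements.
    void RearrangeToOICustomSpatialI4O4(const std::vector<float>& weights,
                                        int O, int H, int W, int I,
                                        const std::vector<int>& remap,
                                        std::vector<float4>& dst) {
      const int src_depth = (I + 3) / 4;
      const int dst_depth = (O + 3) / 4;
      int counter = 0;
      for (int d = 0; d < dst_depth; ++d) {
        for (int s = 0; s < src_depth; ++s) {
          for (int k = 0; k < H * W; ++k) {
            const int kernel_index = remap[k];
            const int ky = kernel_index / W;
            const int kx = kernel_index % W;
            float4 filters[4];
            for (int j = 0; j < 4; ++j) {    // output channel within block
              for (int i = 0; i < 4; ++i) {  // input channel within block
                const int s_ch = s * 4 + i;
                const int d_ch = d * 4 + j;
                if (s_ch < I && d_ch < O) {
                  const int f_index = ((d_ch * H + ky) * W + kx) * I + s_ch;
                  filters[i][j] = weights[f_index];  // zero-padded otherwise
                }
              }
            }
            for (int i = 0; i < 4; ++i) dst[counter++] = filters[i];
          }
        }
      }
    }
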
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc
index 4cabbbee376..7a11e617627 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc
@@ -19,8 +19,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -53,9 +52,18 @@ std::string ConvolutionTransposed3x3Thin::GenerateConvolutionTransposedCode(
   AddSrcTensor("src_tensor", src_desc);
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
 
+  if (op_def.src_tensors.size() == 2) {
+    // dynamic weights
+    BufferDescriptor desc;
+    desc.element_type = op_def.src_tensors[1].data_type;
+    desc.element_size = 4;
+    desc.memory_type = MemoryType::CONSTANT;
+    AddSrcBuffer("weights", desc);
+  }
+
   const auto src_tensor_type = op_def.src_tensors[0].storage_type;
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
 
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
@@ -160,8 +168,7 @@ std::string ConvolutionTransposed3x3Thin::GenerateConvolutionTransposedCode(
   for (int d = 0; d < dst_depth; ++d) {
     const std::string layer = std::to_string(d);
     c += "  {\n";
-    c += "  FLT4 bias_val = args.weights.Read(" +
-         std::to_string(36 * filters_index + d) + ");\n";
+    c += "  FLT4 bias_val = args.biases.Read(" + layer + ");\n";
     for (int y = 0; y < 2; ++y) {
       for (int x = 0; x < 2; ++x) {
         const std::string x_coord = "X + " + std::to_string(x);
@@ -188,6 +195,10 @@ int3 ConvolutionTransposed3x3Thin::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
+std::vector<int> ConvolutionTransposed3x3Thin::GetSpatialWeightsRemap() const {
+  return std::vector<int>{4, 5, 3, 7, 1, 8, 6, 2, 0};
+}
+
 bool IsConvolutionTransposed3x3ThinSupported(
     const ConvolutionTransposedAttributes& attr) {
   return attr.weights.shape.o <= 8 && attr.weights.shape.w == 3 &&
@@ -201,7 +212,28 @@ ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr) {
   ConvolutionTransposed3x3Thin result(definition, attr);
-  result.UploadData(attr.weights, attr.bias);
+  result.UploadWeights(attr.weights);
+
+  TensorLinearDescriptor desc;
+  desc.storage_type = LinearStorageType::TEXTURE_2D;
+  desc.element_type = definition.GetDataType();
+  desc.UploadLinearData(attr.bias);
+  result.args_.AddObject(
+      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+  return result;
+}
+
+ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3ThinDynamicWeights(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr) {
+  ConvolutionTransposed3x3Thin result(definition, attr);
+
+  TensorLinearDescriptor desc;
+  desc.storage_type = LinearStorageType::TEXTURE_2D;
+  desc.element_type = definition.GetDataType();
+  desc.UploadLinearData(attr.bias);
+  result.args_.AddObject(
+      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
   return result;
 }
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h
index 8ff50f95f31..8080d757e9c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h
@@ -19,7 +19,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -27,6 +26,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
@@ -47,29 +49,38 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
   ConvolutionTransposed3x3Thin& operator=(const ConvolutionTransposed3x3Thin&) =
       delete;
 
+  WeightsDescription GetWeightsDescription() const {
+    WeightsDescription desc;
+    desc.layout = WeightsLayout::kOICustomSpatialI4O4;
+    desc.spatial_remap = GetSpatialWeightsRemap();
+    return desc;
+  }
+
  private:
-  friend ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin(
-      const GpuInfo& gpu_info, const OperationDef& definition,
-      const ConvolutionTransposedAttributes& attr);
   explicit ConvolutionTransposed3x3Thin(
       const OperationDef& definition,
       const ConvolutionTransposedAttributes& attr);
-  template <DataType T>
-  void UploadData(const tflite::gpu::Tensor<OHWI, T>& weights,
-                  const tflite::gpu::Tensor<Linear, T>& biases);
 
-  template <DataType S, typename T>
-  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
-                            absl::Span<T> dst);
+  friend ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin(
+      const GpuInfo& gpu_info, const OperationDef& definition,
+      const ConvolutionTransposedAttributes& attr);
+  friend ConvolutionTransposed3x3Thin
+  CreateConvolutionTransposed3x3ThinDynamicWeights(
+      const GpuInfo& gpu_info, const OperationDef& definition,
+      const ConvolutionTransposedAttributes& attr);
+
+  template <DataType T>
+  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
+
+  std::vector<int> GetSpatialWeightsRemap() const;
 
   std::string GenerateConvolutionTransposedCode(const OperationDef& op_def,
                                                 int src_depth, int dst_depth);
 };
 
 template <DataType T>
-void ConvolutionTransposed3x3Thin::UploadData(
-    const tflite::gpu::Tensor<OHWI, T>& weights,
-    const tflite::gpu::Tensor<Linear, T>& biases) {
+void ConvolutionTransposed3x3Thin::UploadWeights(
+    const tflite::gpu::Tensor<OHWI, T>& weights) {
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
   const int dst_depth = DivideRoundUp(weights.shape.o, 4);
   const int kernel_x = 3;  // This operation supports only a 3x3 kernel
@@ -83,79 +94,23 @@ void ConvolutionTransposed3x3Thin::UploadData(
   desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
   desc.element_size = 4;
   desc.memory_type = MemoryType::CONSTANT;
-  desc.size = flt4_size * (flt4_count + dst_depth);
+  desc.size = flt4_size * flt4_count;
   desc.data.resize(desc.size);
 
   if (f32_weights) {
     float4* gpu_data = reinterpret_cast<float4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(gpu_data, flt4_count));
-    for (int i = 0; i < dst_depth; ++i) {
-      float4 bias_value(0.0f);
-      for (int c = 0; c < 4; ++c) {
-        int ch = i * 4 + c;
-        bias_value[c] = ch < weights.shape.o ? biases.data[ch] : 0.0f;
-      }
-      gpu_data[flt4_count + i] = bias_value;
-    }
+    RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(),
+                                          absl::MakeSpan(gpu_data, flt4_count));
   } else {
     half4* gpu_data = reinterpret_cast<half4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(gpu_data, flt4_count));
-    for (int i = 0; i < dst_depth; ++i) {
-      half4 bias_value(0.0f);
-      for (int c = 0; c < 4; ++c) {
-        int ch = i * 4 + c;
-        bias_value[c] = ch < weights.shape.o ? biases.data[ch] : 0.0f;
-      }
-      gpu_data[flt4_count + i] = bias_value;
-    }
+    RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(),
+                                          absl::MakeSpan(gpu_data, flt4_count));
   }
 
   args_.AddObject("weights",
                   absl::make_unique<BufferDescriptor>(std::move(desc)));
 }
 
-template <DataType S, typename T>
-void ConvolutionTransposed3x3Thin::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
-  const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int dst_depth = DivideRoundUp(weights.shape.o, 4);
-  const int kernel_x = 3;
-  const int kernel_y = 3;
-
-  const int remap[9] = {4, 5, 3, 7, 1, 8, 6, 2, 0};
-
-  int counter = 0;
-  for (int s = 0; s < src_depth; ++s) {
-    for (int d = 0; d < dst_depth; ++d) {
-      for (int y = 0; y < kernel_y; ++y) {
-        for (int x = 0; x < kernel_x; ++x) {
-          const int kernel_index = remap[y * kernel_x + x];
-          const int kernel_index_x = kernel_index % kernel_x;
-          const int kernel_index_y = kernel_index / kernel_x;
-          T filters[4];
-          for (int j = 0; j < 4; ++j) {
-            for (int i = 0; i < 4; ++i) {
-              const int s_ch = s * 4 + i;
-              const int d_ch = d * 4 + j;
-              if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
-                const int f_index = weights.shape.LinearIndex(
-                    {d_ch, kernel_index_y, kernel_index_x, s_ch});
-                filters[i][j] = weights.data[f_index];
-              } else {
-                filters[i][j] = 0.0f;
-              }
-            }
-          }
-          dst[counter++] = filters[0];
-          dst[counter++] = filters[1];
-          dst[counter++] = filters[2];
-          dst[counter++] = filters[3];
-        }
-      }
-    }
-  }
-}
-
 bool IsConvolutionTransposed3x3ThinSupported(
     const ConvolutionTransposedAttributes& attr);
 
@@ -163,6 +118,10 @@ ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);
 
+ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3ThinDynamicWeights(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr);
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc
index 5db4d03b0f4..61c5f1c32c4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc
@@ -19,30 +19,40 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
+namespace {
+ConvolutionTransposed4x4::WeightsUploadType GetBestWeightsUploadType(
+    const GpuInfo& gpu_info) {
+  ConvolutionTransposed4x4::WeightsUploadType weights_upload_type =
+      ConvolutionTransposed4x4::WeightsUploadType::GLOBAL_MEM;
+  if (gpu_info.IsPowerVR()) {
+    weights_upload_type =
+        ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_ASYNC;
+  } else if (gpu_info.IsNvidia() || gpu_info.IsIntel()) {
+    weights_upload_type =
+        ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_BY_THREADS;
+  } else if (gpu_info.IsAMD()) {
+    weights_upload_type =
+        ConvolutionTransposed4x4::WeightsUploadType::CONSTANT_MEM;
+  } else {
+    weights_upload_type =
+        ConvolutionTransposed4x4::WeightsUploadType::GLOBAL_MEM;
+  }
+  return weights_upload_type;
+}
+}  // namespace
+
 ConvolutionTransposed4x4::ConvolutionTransposed4x4(
-    const OperationDef& definition, const GpuInfo& gpu_info,
-    const ConvolutionTransposedAttributes& attr)
+    const OperationDef& definition, const GpuInfo& gpu_info)
     : GPUOperation(definition) {
   work_group_size_ = int3(8, 4, 1);
-  WeightsUploadType weights_upload_type = WeightsUploadType::GLOBAL_MEM;
-  if (gpu_info.IsPowerVR()) {
-    weights_upload_type = WeightsUploadType::LOCAL_MEM_ASYNC;
-  } else if (gpu_info.IsNvidia() || gpu_info.IsIntel()) {
-    weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
-  } else if (gpu_info.IsAMD()) {
-    weights_upload_type = WeightsUploadType::CONSTANT_MEM;
-  } else {
-    weights_upload_type = WeightsUploadType::GLOBAL_MEM;
-  }
 
-  code_ = GenerateConvolutionTransposedCode(definition_, weights_upload_type);
-  UploadWeights(attr.weights, weights_upload_type);
+  code_ = GenerateConvolutionTransposedCode(definition_,
+                                            GetBestWeightsUploadType(gpu_info));
   if (definition_.precision == CalculationsPrecision::F16 &&
       gpu_info.IsPowerVR()) {
     compiler_options_.push_back(CompilerOptions::kClPowervrFp16);
@@ -76,6 +86,19 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode(
   }
   AddDstTensor("dst_tensor", dst_desc);
 
+  if (op_def.src_tensors.size() == 2) {
+    // dynamic weights
+    BufferDescriptor desc;
+    desc.element_type = op_def.src_tensors[1].data_type;
+    desc.element_size = 4;
+    desc.memory_type =
+        weights_upload_type ==
+                ConvolutionTransposed4x4::WeightsUploadType::CONSTANT_MEM
+            ? MemoryType::CONSTANT
+            : MemoryType::GLOBAL;
+    AddSrcBuffer("weights", desc);
+  }
+
   args_.AddInt("filter_offset");
 
   const bool need_local_mem =
@@ -84,7 +107,7 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode(
       weights_upload_type ==
           ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_ASYNC;
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
     case CalculationsPrecision::F16:
@@ -323,6 +346,10 @@ int3 ConvolutionTransposed4x4::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
+std::vector<int> ConvolutionTransposed4x4::GetSpatialWeightsRemap() const {
+  return std::vector<int>{10, 11, 14, 15, 8, 9, 12, 13, 2, 3, 6, 7, 0, 1, 4, 5};
+}
+
 bool IsConvolutionTransposed4x4Supported(
     const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr) {
@@ -334,7 +361,22 @@ bool IsConvolutionTransposed4x4Supported(
 ConvolutionTransposed4x4 CreateConvolutionTransposed4x4(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr) {
-  ConvolutionTransposed4x4 result(definition, gpu_info, attr);
+  ConvolutionTransposed4x4 result(definition, gpu_info);
+  result.UploadWeights(attr.weights, GetBestWeightsUploadType(gpu_info));
+
+  TensorLinearDescriptor desc;
+  desc.storage_type = LinearStorageType::TEXTURE_2D;
+  desc.element_type = definition.GetDataType();
+  desc.UploadLinearData(attr.bias);
+  result.args_.AddObject(
+      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+  return result;
+}
+
+ConvolutionTransposed4x4 CreateConvolutionTransposed4x4DynamicWeights(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr) {
+  ConvolutionTransposed4x4 result(definition, gpu_info);
 
   TensorLinearDescriptor desc;
   desc.storage_type = LinearStorageType::TEXTURE_2D;
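
The same dynamic-weights hook now appears in the 3x3, 3x3-thin, and 4x4 transposed convolutions: when the `OperationDef` carries a second source tensor, the generated code declares a runtime "weights" buffer instead of baking the weights in at creation time. The shared pattern, condensed into a free function for illustration (in the repo this logic runs inside each op's code generator, where `AddSrcBuffer` is a member):

    // Sketch of the repeated `if (op_def.src_tensors.size() == 2)` blocks.
    // `use_constant_mem` stands in for the per-op weights-upload-type check;
    // assumes GPUOperation::AddSrcBuffer is reachable from here.
    void AddDynamicWeightsBuffer(const OperationDef& op_def,
                                 bool use_constant_mem, GPUOperation* op) {
      if (op_def.src_tensors.size() == 2) {  // second src tensor = weights
        BufferDescriptor desc;
        desc.element_type = op_def.src_tensors[1].data_type;
        desc.element_size = 4;  // packed as float4/half4 quads
        desc.memory_type =
            use_constant_mem ? MemoryType::CONSTANT : MemoryType::GLOBAL;
        op->AddSrcBuffer("weights", desc);
      }
    }
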
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h
index febbc575c33..57304adb01f 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h
@@ -19,7 +19,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -27,6 +26,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
@@ -52,6 +54,13 @@ class ConvolutionTransposed4x4 : public GPUOperation {
   ConvolutionTransposed4x4(const ConvolutionTransposed4x4&) = delete;
   ConvolutionTransposed4x4& operator=(const ConvolutionTransposed4x4&) = delete;
 
+  WeightsDescription GetWeightsDescription() const {
+    WeightsDescription desc;
+    desc.layout = WeightsLayout::kOICustomSpatialI4O4;
+    desc.spatial_remap = GetSpatialWeightsRemap();
+    return desc;
+  }
+
   enum class WeightsUploadType {
     LOCAL_MEM_ASYNC,
     LOCAL_MEM_BY_THREADS,
@@ -61,18 +70,20 @@ class ConvolutionTransposed4x4 : public GPUOperation {
 
  private:
   ConvolutionTransposed4x4(const OperationDef& definition,
-                           const GpuInfo& gpu_info,
-                           const ConvolutionTransposedAttributes& attr);
+                           const GpuInfo& gpu_info);
+
   friend ConvolutionTransposed4x4 CreateConvolutionTransposed4x4(
       const GpuInfo& gpu_info, const OperationDef& definition,
       const ConvolutionTransposedAttributes& attr);
+  friend ConvolutionTransposed4x4 CreateConvolutionTransposed4x4DynamicWeights(
+      const GpuInfo& gpu_info, const OperationDef& definition,
+      const ConvolutionTransposedAttributes& attr);
+
   template <DataType T>
   void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
                      WeightsUploadType weights_upload_type);
 
-  template <DataType S, typename T>
-  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
-                            absl::Span<T> dst);
+  std::vector<int> GetSpatialWeightsRemap() const;
 
   std::string GenerateConvolutionTransposedCode(
       const OperationDef& op_def, WeightsUploadType weights_upload_type);
@@ -104,58 +115,18 @@ void ConvolutionTransposed4x4::UploadWeights(
 
   if (f32_weights) {
     float4* ptr = reinterpret_cast<float4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count));
+    RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(),
+                                          absl::MakeSpan(ptr, flt4_count));
   } else {
     half4* ptr = reinterpret_cast<half4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count));
+    RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(),
+                                          absl::MakeSpan(ptr, flt4_count));
   }
 
   args_.AddObject("weights",
                   absl::make_unique<BufferDescriptor>(std::move(desc)));
 }
 
-template <DataType S, typename T>
-void ConvolutionTransposed4x4::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
-  const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int dst_depth = DivideRoundUp(weights.shape.o, 4);
-  const int kernel_x = 4;
-  const int kernel_y = 4;
-
-  const int remap[16] = {10, 11, 14, 15, 8, 9, 12, 13, 2, 3, 6, 7, 0, 1, 4, 5};
-
-  int counter = 0;
-  for (int d = 0; d < dst_depth; ++d) {
-    for (int s = 0; s < src_depth; ++s) {
-      for (int y = 0; y < kernel_y; ++y) {
-        for (int x = 0; x < kernel_x; ++x) {
-          const int kernel_index = remap[y * kernel_x + x];
-          const int kernel_index_x = kernel_index % kernel_x;
-          const int kernel_index_y = kernel_index / kernel_x;
-          T filters[4];
-          for (int j = 0; j < 4; ++j) {
-            for (int i = 0; i < 4; ++i) {
-              const int s_ch = s * 4 + i;
-              const int d_ch = d * 4 + j;
-              if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
-                const int f_index = weights.shape.LinearIndex(
-                    {d_ch, kernel_index_y, kernel_index_x, s_ch});
-                filters[i][j] = weights.data[f_index];
-              } else {
-                filters[i][j] = 0.0f;
-              }
-            }
-          }
-          dst[counter++] = filters[0];
-          dst[counter++] = filters[1];
-          dst[counter++] = filters[2];
-          dst[counter++] = filters[3];
-        }
-      }
-    }
-  }
-}
-
 bool IsConvolutionTransposed4x4Supported(
     const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);
@@ -164,6 +135,10 @@ ConvolutionTransposed4x4 CreateConvolutionTransposed4x4(
     const GpuInfo& gpu_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);
 
+ConvolutionTransposed4x4 CreateConvolutionTransposed4x4DynamicWeights(
+    const GpuInfo& gpu_info, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr);
+
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
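
Note: the per-operation `RearrangeWeightsData` deleted above is replaced by the shared `RearrangeWeightsToOICustomSpatialI4O4` helper, parameterized by `GetSpatialWeightsRemap()`. A self-contained sketch decoding what that remap table encodes, using the `x = v % 4, y = v / 4` convention from the deleted code:

```cpp
// Decodes the spatial remap used by ConvolutionTransposed4x4: entry k of the
// table names the (y, x) kernel tap written at slot k of the rearranged
// weight stream.
#include <cstdio>

int main() {
  const int remap[16] = {10, 11, 14, 15, 8, 9, 12, 13,
                         2,  3,  6,  7,  0, 1, 4,  5};
  for (int k = 0; k < 16; ++k) {
    const int v = remap[k];
    std::printf("slot %2d <- kernel tap (y=%d, x=%d)\n", k, v / 4, v % 4);
  }
}
```

Slot 0 reads tap (y=2, x=2), so the bottom-right quadrant of the 4x4 kernel streams first; presumably this matches the order in which the generated kernel consumes taps.
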
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc
index c21b0cc0fc1..9e75b0bca23 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc
@@ -19,8 +19,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -76,7 +75,7 @@ std::string ConvolutionTransposedThin::GenerateConvolutionTransposedCode(
       break;
   }
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.IsBatchSupported()) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h
index 8b57ac03e5f..b18ca6814b8 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h
@@ -19,7 +19,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -27,6 +26,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc
index 538183231b0..b1787b5a7ed 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc
@@ -20,9 +20,9 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -86,7 +86,7 @@ std::string GenerateDepthwiseConvolutionCode(
   }
   op->AddDstTensor("dst_tensor", dst_desc);
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
 
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h
index 5c708622ff9..1ca2b3ff8b8 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h
@@ -19,7 +19,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
@@ -28,6 +27,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc
index 6a5bc2c4700..ceaf59fff4d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc
@@ -18,9 +18,8 @@ limitations under the License.
 #include <string>
 #include <utility>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -66,7 +65,7 @@ std::string DepthwiseConv3x3::GenerateDepthwiseConvCode(
   const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
                             src_tensor_type == TensorStorageType::IMAGE_BUFFER;
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   if (local_mem_uploads) {
     c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h
index b4f706a779c..4462ea87411 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h
@@ -20,7 +20,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
@@ -28,6 +27,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
index c3b41c6e786..5167057408d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 
 namespace tflite {
@@ -83,12 +82,7 @@ std::string GetOneInputCode(const OperationType& op_type,
       result = "$0 *= $0;\n";
       break;
     case OperationType::TANH:
-      if (precision != CalculationsPrecision::F32) {
-        result = "float4 t = native_exp(convert_float4($0 * 2.0h));\n";
-        result += "$0 = convert_half4(native_divide(t - 1.0f, t + 1.0f));\n";
-      } else {
-        result = "$0 = tanh($0);\n";
-      }
+      result = "$0 = tanh($0);\n";
       break;
     default:
       return "Unknown operation type;\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h
index 572b731d908..22a1a8d4b92 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h
@@ -18,9 +18,9 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
index 56d8b7d11c0..00f9a5260f9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
@@ -388,7 +388,7 @@ TEST_F(OpenCLOperationTest, Square) {
 TEST_F(OpenCLOperationTest, Tanh) {
   TensorFloat32 src_tensor;
   src_tensor.shape = BHWC(1, 2, 1, 2);
-  src_tensor.data = {1.0f, 2.0f, 3.0f, 4.0f};
+  src_tensor.data = {-50.0f, -0.1f, 0.1f, 50.0f};
 
   for (auto storage : env_.GetSupportedStorages()) {
     for (auto precision : env_.GetSupportedPrecisions()) {
@@ -407,8 +407,8 @@ TEST_F(OpenCLOperationTest, Tanh) {
           BHWC(1, 2, 1, 2), &dst_tensor));
       EXPECT_THAT(
           dst_tensor.data,
-          Pointwise(FloatNear(eps), {std::tanh(1.0f), std::tanh(2.0f),
-                                     std::tanh(3.0f), std::tanh(4.0f)}));
+          Pointwise(FloatNear(eps), {std::tanh(-50.0f), std::tanh(-0.1f),
+                                     std::tanh(0.1f), std::tanh(50.0f)}));
     }
   }
 }
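
Note: the widened test inputs (+/-50 instead of 1..4) target the overflow that motivated dropping the exp-based approximation above. For |x| beyond roughly 44, exp(2x) exceeds FLT_MAX (~3.4e38), so the old (t - 1)/(t + 1) formulation degenerates to inf/inf = NaN, while std::tanh saturates cleanly to +/-1. A standalone illustration of that failure mode:

```cpp
// Demonstrates why the exp-based tanh path was removed: at x = 50 the
// intermediate exp(2x) overflows float and the quotient becomes NaN.
#include <cmath>
#include <cstdio>

int main() {
  const float x = 50.0f;
  const float t = std::exp(2.0f * x);            // exp(100) ~ 2.7e43 -> inf
  const float approx = (t - 1.0f) / (t + 1.0f);  // inf / inf -> NaN
  std::printf("approx = %f, std::tanh = %f\n", approx, std::tanh(x));
}
```
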
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc
index b5caef81b43..50b92381c53 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc
@@ -20,12 +20,10 @@ limitations under the License.
 #include <vector>
 
 #include "absl/memory/memory.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
@@ -83,7 +81,7 @@ std::string FullyConnected::GetFullyConnectedKernelCode(
 
   const bool weights_are_buffer = UseBufferForWeights(gpu_info);
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
       c += "#define FLT16 float16\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h
index a508e741ef9..fda3dc6fa32 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h
@@ -25,12 +25,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc
index 3db7fb70406..87020324f8a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc
@@ -21,10 +21,10 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "tensorflow/lite/delegates/gpu/cl/environment.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 using ::testing::ElementsAreArray;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
index 11d2e209428..d659d8bff02 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
@@ -17,15 +17,14 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
 std::string GetLSTMCode(const OperationDef& op_def, const GpuInfo& gpu_info) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int B = get_global_id(0);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h
index fa0aa270158..bfb3cd25055 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h
@@ -16,9 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
index 4905c5f9284..4692ac0834e 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
@@ -17,8 +17,7 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -44,7 +43,7 @@ std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def,
   }
   op->AddDstTensor("dst_tensor", dst_desc);
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
index c1b6cbf334b..fb05e2dc7a3 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
@@ -16,9 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_MAX_UNPOOLING_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_MAX_UNPOOLING_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc
index 5cdb81e5324..20edf489c2b 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc
@@ -18,9 +18,7 @@ limitations under the License.
 #include <string>
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -95,7 +93,7 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition,
   // For now, fix workgroup size to the biggest supported by the device, but not
   // larger than the number of tensor slices.
   int desired_work_group_size =
-      std::min(tensor_slices, gpu_info.max_work_group_size_x);
+      std::min(tensor_slices, gpu_info.GetMaxWorkGroupSizeForX());
   if (gpu_info.IsMali()) {
     // Don't use more than 64 work items per work group on ARM Mali. They
     // implement local memory using the global memory, larger workgroups have
@@ -136,9 +134,9 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition,
   work_group_size_.y = 1;  // Required
   work_group_size_.z = 1;  // Required
   code_ = GetNormalizationCode();
-  if (gpu_info.cl_version >= OpenCLVersion::CL_3_0) {
+  if (gpu_info.IsCL30OrHigher()) {
     compiler_options_.push_back(CompilerOptions::kCl30);
-  } else if (gpu_info.cl_version >= OpenCLVersion::CL_2_0) {
+  } else if (gpu_info.IsCL20OrHigher()) {
     compiler_options_.push_back(CompilerOptions::kCl20);
   }
 }
@@ -147,7 +145,7 @@ std::string MeanStdDevNormalization::GetNormalizationCode() {
   AddSrcTensor("src_tensor", definition_.src_tensors[0]);
   AddDstTensor("dst_tensor", definition_.dst_tensors[0]);
 
-  std::string c = GetCommonDefines(definition_.precision);
+  std::string c;
   c += GetVectorReduceCode();
   c += GetReduceCode();
   c += GetFilterCode();
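
Note: the constructor above starts from min(tensor_slices, GetMaxWorkGroupSizeForX()) and, per the Mali comment, limits the workgroup to 64 work items on that vendor. A sketch of the heuristic under the assumption that the cap is applied as a simple min; `PickWorkGroupSizeX` is an illustrative name, and the rounding to a supported configuration that likely follows is not shown:

```cpp
// Sketch of the workgroup-width heuristic described in the hunk above.
#include <algorithm>
#include <cstdio>

int PickWorkGroupSizeX(int tensor_slices, int max_wg_size_x, bool is_mali) {
  int desired = std::min(tensor_slices, max_wg_size_x);
  if (is_mali) {
    // Mali emulates local memory in global memory, so large workgroups make
    // the cross-workitem reduction slower; hence the 64-item cap.
    desired = std::min(desired, 64);
  }
  return desired;
}

int main() {
  std::printf("%d\n", PickWorkGroupSizeX(/*tensor_slices=*/96,
                                         /*max_wg_size_x=*/256,
                                         /*is_mali=*/true));  // 64
}
```
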
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h
index 6a4a1848394..defdf9d0c3a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h
@@ -16,10 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_NORMALIZATION_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_NORMALIZATION_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc
index 8012e601c0b..77edcbfef2d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc
@@ -17,9 +17,8 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -36,7 +35,7 @@ std::string GetPaddingCode(const OperationDef& op_def,
 
   const std::string dst_batch =
       op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   const std::string channels[] = {".x", ".y", ".z", ".w"};
 
   if (attr.type == PaddingContentType::REFLECT) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h
index 81047162d20..be534017a28 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h
@@ -16,9 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_PADDING_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_PADDING_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc
index aa2069618c7..d687376e871 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc
@@ -17,8 +17,8 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -72,7 +72,7 @@ std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
       op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER ||
       op_def.src_tensors[0].storage_type == TensorStorageType::IMAGE_BUFFER;
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
@@ -193,7 +193,7 @@ std::string GetMaxPoolingKernelCode(const OperationDef& op_def,
     dst_coord += ", " + dst_coords[i];
   }
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h
index 81a0dfff4de..ade7584b924 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h
@@ -17,10 +17,10 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_POOLING_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc
index 088dc5b027f..97bad10cac9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc
@@ -17,7 +17,6 @@ limitations under the License.
 
 #include "absl/strings/str_cat.h"
 #include "absl/types/variant.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h
index 1e98d043eed..3ec5a7791f4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h
@@ -20,11 +20,11 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc
index 1e08eb0ff52..267a56477c8 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include "absl/strings/str_cat.h"
 #include "absl/types/variant.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h
index 1e37e427af8..c84409d3f7a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h
@@ -20,11 +20,11 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc
index 93345ce1d78..1ff3ec582e9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc
@@ -18,9 +18,9 @@ limitations under the License.
 #include <set>
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 
 namespace tflite {
@@ -196,7 +196,7 @@ std::string Reduce::GetReduceKernelCode(const OperationDef& op_def,
     }
   };
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   const std::string wg_x = std::to_string(work_group_size.x);
   const std::string wg_y = std::to_string(work_group_size.y);
   const std::string wg_z = std::to_string(work_group_size.z);
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h
index 8747b5dc832..a46d094b5ba 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h
@@ -19,9 +19,9 @@ limitations under the License.
 #include <map>
 #include <set>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h
index 1b4e3a81605..ba6676b27a2 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h
@@ -18,8 +18,8 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc
index d965b6f0611..631e140cd36 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc
@@ -17,15 +17,14 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
 std::string GetReshapeCode(const OperationDef& op_def) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h
index 59cc5c1560d..b02e7e1d10f 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h
@@ -17,8 +17,8 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESHAPE_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc
index 78440e3c843..a040f300cba 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc
@@ -17,8 +17,7 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -26,7 +25,7 @@ namespace cl {
 namespace {
 
 std::string GetReshapeCode(const OperationDef& op_def) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h
index 2052d45b3e1..b725cf06267 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h
@@ -18,8 +18,8 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc
index cbf706f9bd2..1c69f84e9f1 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc
@@ -15,9 +15,8 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/kernels/resize.h"
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -56,7 +55,7 @@ std::string Resize::GetResizeCode(const OperationDef& op_def,
   args_.AddFloat("scale_factor_x");
   args_.AddFloat("scale_factor_y");
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int Y = get_global_id(1);\n";
@@ -191,7 +190,7 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
   args_.AddFloat("scale_factor_y");
   args_.AddFloat("scale_factor_z");
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int Y = get_global_id(1);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h
index 859d750b7e0..660243975ef 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h
@@ -16,9 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc
index 03a53d5716b..ac55e6831e0 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc
@@ -17,16 +17,15 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
 std::string GetSoftmaxKernelCode(const OperationDef& op_def) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h
index 4a9b5fd7c18..bfc9ecc3f16 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h
@@ -17,8 +17,8 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SOFTMAX_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
index d4d0442e61d..da04402b99f 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
@@ -17,8 +17,8 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
 
 namespace tflite {
 namespace gpu {
@@ -48,7 +48,7 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   args_.AddFloat("mask_w");
   args_.AddInt("slices_x32");
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.IsBatchSupported()) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h
index 2a50dee2d63..c670bdf3a6c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h
@@ -17,8 +17,8 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SOFTMAX1X1_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc
index f5323b48bae..3bf90092b60 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc
@@ -19,15 +19,14 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
 std::string GetSpaceToDepthCode(const OperationDef& op_def) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.IsBatchSupported()) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h
index 08aca3054d6..0ef263acd38 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h
@@ -17,9 +17,9 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPACE_TO_DEPTH_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD
index 92231338730..c359f0887a6 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD
@@ -12,15 +12,14 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:gpu_object",
         "//tensorflow/lite/delegates/gpu/cl:util",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:util",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
@@ -31,18 +30,16 @@ cc_library(
     deps = [
         "//tensorflow/lite/delegates/gpu/cl:buffer",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
         "//tensorflow/lite/delegates/gpu/cl:tensor",
         "//tensorflow/lite/delegates/gpu/cl:texture2d",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:util",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:tensor",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc
index 7875f95b23e..2fbaf3de080 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc
@@ -20,8 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -128,7 +127,7 @@ std::string GenerateCode(const OperationDef& op_def,
   result->args_.AddInt("padding_y", -dw_attr.padding.prepended.h);
   result->args_.AddInt("dilation_y", dw_attr.dilations.h);
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h
index b87051104b7..4d1cc19f8b2 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h
@@ -20,12 +20,12 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc
index f1c3ddb7045..7accb9276b2 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc
@@ -20,12 +20,10 @@ limitations under the License.
 #include <vector>
 
 #include "absl/memory/memory.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
@@ -82,7 +80,7 @@ std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def,
 
   const bool weights_are_buffer = UseBufferForWeights(gpu_info);
 
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
       c += "#define FLT16 float16\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h
index 09e0548c663..37248e0480c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h
@@ -25,12 +25,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc
index 1f8f985f3ee..3fa2121e986 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc
@@ -17,8 +17,7 @@ limitations under the License.
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -108,7 +107,7 @@ std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def,
 
   const std::string batch_id =
       op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h
index dddff2faf35..c2af72242ec 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_STRIDED_SLICE_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_STRIDED_SLICE_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc
index 0182ec7d90c..a177ade3b12 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc
@@ -18,8 +18,7 @@ limitations under the License.
 #include <string>
 
 #include "absl/strings/str_cat.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
@@ -29,7 +28,7 @@ std::string GetTransposeCode(const OperationDef& op_def,
                              const TransposeAttributes& attr) {
   const std::string batch_id =
       op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h
index 631d5dc08b3..d318e9600c5 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
deleted file mode 100644
index 2530c73571b..00000000000
--- a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-
-#include <cfloat>
-#include <cmath>
-#include <string>
-#include <vector>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/common/data_type.h"
-#include "tensorflow/lite/delegates/gpu/common/precision.h"
-
-namespace tflite {
-namespace gpu {
-namespace cl {
-
-std::string GetCommonDefines(CalculationsPrecision precision) {
-  std::string result;
-
-  switch (precision) {
-    case CalculationsPrecision::F32:
-      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
-      result += "#define ACCUM_FLT4 float4\n";
-      result += "#define FLT float\n";
-      result += "#define FLT2 float2\n";
-      result += "#define FLT3 float3\n";
-      result += "#define FLT4 float4\n";
-      result += "#define TO_FLT4 convert_float4\n";
-      result += "#define TO_ACCUM_TYPE convert_float4\n";
-      result += "#define TO_ACCUM_FLT convert_float\n";
-      break;
-    case CalculationsPrecision::F16:
-      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
-      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
-      result += "#define ACCUM_FLT4 half4\n";
-      result += "#define FLT half\n";
-      result += "#define FLT2 half2\n";
-      result += "#define FLT3 half3\n";
-      result += "#define FLT4 half4\n";
-      result += "#define TO_FLT4 convert_half4\n";
-      result += "#define TO_ACCUM_TYPE convert_half4\n";
-      result += "#define TO_ACCUM_FLT convert_half\n";
-      break;
-    case CalculationsPrecision::F32_F16:
-      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
-      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
-      result += "#define ACCUM_FLT4 float4\n";
-      result += "#define FLT half\n";
-      result += "#define FLT2 half2\n";
-      result += "#define FLT3 half3\n";
-      result += "#define FLT4 half4\n";
-      result += "#define TO_FLT4 convert_half4\n";
-      result += "#define TO_ACCUM_TYPE convert_float4\n";
-      result += "#define TO_ACCUM_FLT convert_float\n";
-      break;
-  }
-  return result;
-}
-
-std::string GetXStrideCorrected(const std::string& src_x,
-                                const std::string& batch_size,
-                                const std::string& stride_x,
-                                const std::string& padding_x) {
-  // TODO(sorokin) check perf and optimize with floor() if needed
-  // int p0 = src_x / batch_size;\n";
-  // int b0 = src_x % batch_size;\n";
-  // return p0 * stride_x * batch_size + b0 + padding_x;\n";
-  return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x,
-                          batch_size, stride_x, padding_x);
-}
-
-std::string GetXStrideCorrectedV2(const std::string& src_x,
-                                  const std::string& batch_size,
-                                  const std::string& stride_x,
-                                  const std::string& padding_x) {
-  // int p0 = src_x / batch_size;\n";
-  // int b0 = src_x % batch_size;\n";
-  // return (p0 * stride_x + padding_x) * batch_size + b0;\n";
-  return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x,
-                          batch_size, stride_x, padding_x);
-}
-
-float4 GetMaskForLastPlane(int channels) {
-  float4 mask = float4(0.0f);
-  const int reminder = channels % 4 == 0 ? 4 : channels % 4;
-  for (int i = 0; i < reminder; ++i) {
-    mask[i] = 1.0f;
-  }
-  return mask;
-}
-
-int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size) {
-  for (const auto& wg : wgs) {
-    const int wg_size = wg.x * wg.y * wg.z;
-    if (wg_size <= max_wg_size) {
-      return wg;
-    }
-  }
-  return {1, 1, 1};
-}
-
-int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
-                                   CalculationsPrecision precision,
-                                   int task_size) {
-  const float task_size_per_cu =
-      task_size / static_cast<float>(gpu_info.compute_units_count);
-  int block_size = 1;
-  float threshold_1 = FLT_MAX;
-  float threshold_2 = FLT_MAX;
-  float threshold_4 = FLT_MAX;
-  if (!gpu_info.IsMali()) {
-    return 1;
-  }
-  MaliInfo mali_info = gpu_info.mali_info;
-  switch (precision) {
-    case CalculationsPrecision::F16:
-      if (mali_info.IsBifrostGen1()) {
-        threshold_1 = 256.0f;
-        threshold_2 = 256.0f * 4.0f;
-        threshold_4 = 256.0f * 8.0f;
-      } else if (mali_info.IsBifrostGen2()) {
-        threshold_1 = 256.0f * 2.0f;
-        threshold_2 = 256.0f * 8.0f;
-        threshold_4 = 256.0f * 16.0f;
-      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
-        threshold_1 = 256.0f;
-        threshold_2 = 256.0f * 6.0f;
-        threshold_4 = 256.0f * 16.0f;
-      } else if (mali_info.IsMidgard()) {
-        threshold_1 = 256.0f * 4.0f;
-        threshold_2 = 256.0f * 16.0f;
-      }
-      break;
-    case CalculationsPrecision::F32_F16:
-      if (mali_info.IsBifrostGen1()) {
-        threshold_1 = 256.0f;
-        threshold_2 = 256.0f * 3.0f;
-        threshold_4 = 256.0f * 32.0f;
-      } else if (mali_info.IsBifrostGen2()) {
-        threshold_1 = 256.0f * 2.0f;
-        threshold_2 = 256.0f * 8.0f;
-      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
-        threshold_1 = 256.0f;
-        threshold_2 = 256.0f * 8.0f;
-      } else if (mali_info.IsMidgard()) {
-        threshold_1 = 256.0f * 4.0f;
-      }
-      break;
-    case CalculationsPrecision::F32:
-      if (mali_info.IsBifrostGen1()) {
-        threshold_1 = 256.0f;
-        threshold_2 = 256.0f * 4.0f;
-      } else if (mali_info.IsBifrostGen2()) {
-        threshold_1 = 128.0f;
-        threshold_2 = 256.0f * 4.0f;
-      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
-        threshold_1 = 256.0f;
-        threshold_2 = 256.0f * 12.0f;
-      } else if (mali_info.IsMidgard()) {
-        threshold_1 = 256.0f * 16.0f;
-      }
-      break;
-  }
-  if (task_size_per_cu <= threshold_1) {
-    block_size = 1;
-  } else if (task_size_per_cu <= threshold_2) {
-    block_size = 2;
-  } else if (task_size_per_cu <= threshold_4) {
-    block_size = 4;
-  } else {
-    block_size = 8;
-  }
-  return block_size;
-}
-
-int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) {
-  int3 work_groups_count;
-  work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
-  work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y);
-  work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z);
-  return work_groups_count;
-}
-
-}  // namespace cl
-}  // namespace gpu
-}  // namespace tflite
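
Note on the util.cc deletion above: the helpers removed here (GetCommonDefines, GetXStrideCorrected, GetFirstSuitableWorkGroup, GetRecommendedBlockSizeForConv, GetWorkGroupsCount, GetMaskForLastPlane) disappear from cl/kernels, while the call sites in this patch switch to headers under common/task, so they appear to be relocated to the backend-neutral layer rather than dropped. As a standalone illustration of one of them, the sketch below re-implements GetMaskForLastPlane, with std::array standing in for the GPU float4 type:

    #include <array>
    #include <cstdio>

    // Standalone sketch of the deleted GetMaskForLastPlane(): channels are
    // packed into float4 planes, so the last plane may be only partially
    // filled; the mask keeps the valid lanes and zeroes the padding lanes.
    // channels % 4 == 0 means the last plane is full.
    std::array<float, 4> GetMaskForLastPlane(int channels) {
      std::array<float, 4> mask = {0.0f, 0.0f, 0.0f, 0.0f};
      const int remainder = channels % 4 == 0 ? 4 : channels % 4;
      for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
      return mask;
    }

    int main() {
      // A tensor with 6 channels occupies two float4 planes; the second
      // plane holds channels 4..5, so only its first two lanes are valid.
      const auto mask = GetMaskForLastPlane(6);
      std::printf("mask = (%g, %g, %g, %g)\n", mask[0], mask[1], mask[2],
                  mask[3]);  // prints mask = (1, 1, 0, 0)
      return 0;
    }

Kernels can multiply the last depth slice by this mask so the padding lanes do not pollute reductions.
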
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc
index 47b66ededdd..dd380b49e32 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc
@@ -20,11 +20,10 @@ limitations under the License.
 
 #include "absl/strings/str_format.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
 
 namespace tflite {
@@ -59,7 +58,7 @@ Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) {
 
 std::string Winograd4x4To36::GetWinograd4x4To36Code(
     const OperationDef& op_def) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
 
   const auto src_tensor_type = op_def.src_tensors[0].storage_type;
   const bool is_image_buffer =
@@ -327,7 +326,7 @@ Winograd36To4x4& Winograd36To4x4::operator=(Winograd36To4x4&& operation) {
 
 std::string Winograd36To4x4::GetWinograd36To4x4Code(
     const OperationDef& op_def) {
-  std::string c = GetCommonDefines(op_def.precision);
+  std::string c;
 
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h
index 11430c99c0b..bdc3caf422d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h
@@ -17,11 +17,11 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc
index 9da73ba9783..c858356d405 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc
@@ -22,7 +22,6 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
@@ -78,9 +77,13 @@ TEST_F(OpenCLOperationTest, Winograd4x4To36) {
     for (auto precision : env_.GetSupportedPrecisions()) {
       float eps;
       if (precision == CalculationsPrecision::F32) {
-        eps = 1e-5f * (env_.device().SupportsFP32RTN() ? 1.0f : 4.0f);
+        eps = 1e-5f * (env_.device().GetInfo().opencl_info.supports_fp32_rtn
+                           ? 1.0f
+                           : 4.0f);
       } else {
-        eps = 1e-2f * (env_.device().SupportsFP16RTN() ? 1.0f : 4.0f);
+        eps = 1e-2f * (env_.device().GetInfo().opencl_info.supports_fp16_rtn
+                           ? 1.0f
+                           : 4.0f);
       }
       OperationDef op_def;
       op_def.precision = precision;
@@ -151,9 +154,13 @@ TEST_F(OpenCLOperationTest, Winograd36To4x4) {
     for (auto precision : env_.GetSupportedPrecisions()) {
       float eps;
       if (precision == CalculationsPrecision::F32) {
-        eps = 1e-5f * (env_.device().SupportsFP32RTN() ? 1.0f : 4.0f);
+        eps = 1e-5f * (env_.device().GetInfo().opencl_info.supports_fp32_rtn
+                           ? 1.0f
+                           : 4.0f);
       } else {
-        eps = 1e-2f * (env_.device().SupportsFP16RTN() ? 1.0f : 4.0f);
+        eps = 1e-2f * (env_.device().GetInfo().opencl_info.supports_fp16_rtn
+                           ? 1.0f
+                           : 4.0f);
       }
       OperationDef op_def;
       op_def.precision = precision;
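
The test change above tracks the device-info refactoring: the convenience accessors SupportsFP32RTN()/SupportsFP16RTN() give way to fields on GetInfo().opencl_info. The tolerance rule itself is unchanged; a minimal sketch of it (WinogradEps is a hypothetical name introduced here, not a function in the tree):

    #include <cstdio>

    // Tolerances are relaxed 4x on devices whose OpenCL float ops do not
    // round to nearest (supports_fp32_rtn / supports_fp16_rtn in the new
    // DeviceInfo layout).
    float WinogradEps(bool fp32, bool supports_rtn) {
      const float base = fp32 ? 1e-5f : 1e-2f;
      return base * (supports_rtn ? 1.0f : 4.0f);
    }

    int main() {
      std::printf("F32, rtn:    %g\n", WinogradEps(true, true));    // 1e-05
      std::printf("F16, no rtn: %g\n", WinogradEps(false, false));  // 0.04
      return 0;
    }
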
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
index f4a2f8f654a..976fa82b851 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
@@ -9,18 +9,18 @@ cc_library(
     hdrs = ["convolution_selector.h"],
     deps = [
         "//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common",
         "//tensorflow/lite/delegates/gpu/cl/kernels:conv_constants",
         "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr",
         "//tensorflow/lite/delegates/gpu/cl/kernels:conv_weights_converter",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:work_group_picking",
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -30,16 +30,16 @@ cc_library(
     srcs = ["convolution_transposed_selector.cc"],
     hdrs = ["convolution_transposed_selector.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common",
         "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed",
         "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_3x3",
         "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_3x3_thin",
         "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_4x4",
         "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_thin",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+        "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -49,12 +49,11 @@ cc_library(
     hdrs = ["default_selector.h"],
     deps = [
         ":subgraph",
-        "//tensorflow/lite/delegates/gpu/cl:device_info",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors/default:default_selector",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
     ],
 )
@@ -67,9 +66,9 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv",
         "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv_3x3",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -82,9 +81,9 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1",
         "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr",
         "//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -102,9 +101,7 @@ cc_library(
         ":subgraph",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
         "//tensorflow/lite/delegates/gpu/cl:storage_type_util",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common",
         "//tensorflow/lite/delegates/gpu/cl/kernels:elementwise",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/kernels:mean_stddev_normalization",
         "//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
         "//tensorflow/lite/delegates/gpu/cl/selectors:default_selector",
@@ -114,6 +111,8 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common:winograd_util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:any",
@@ -130,7 +129,6 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy",
         "//tensorflow/lite/delegates/gpu/cl/kernels:concat_z",
         "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/kernels:lstm",
         "//tensorflow/lite/delegates/gpu/cl/kernels:max_unpooling",
         "//tensorflow/lite/delegates/gpu/cl/kernels:padding",
@@ -151,6 +149,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -162,7 +161,6 @@ cc_library(
     deps = [
         ":subgraph",
         "//tensorflow/lite/delegates/gpu/cl:cl_device",
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv",
         "//tensorflow/lite/delegates/gpu/cl/kernels/special:fc_fc_add",
         "//tensorflow/lite/delegates/gpu/common:data_type",
@@ -171,6 +169,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
         "@com_google_absl//absl/types:any",
     ],
@@ -181,8 +180,8 @@ cc_library(
     srcs = ["subgraph.cc"],
     hdrs = ["subgraph.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
index b6b0131aeb9..5362004a903 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
@@ -20,8 +20,8 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 
 namespace tflite {
@@ -55,10 +55,10 @@ std::unique_ptr<GPUOperation> SelectConvolutionDynamicWeightsAdreno(
     const Convolution2DAttributes& attr, const BHWC& weights_shape,
     const BHWC& dst_shape, const GpuInfo& gpu_info,
     const OperationDef& op_def, ModelHints hints,
-    ConvWeightsDescription* weights_desc) {
+    WeightsDescription* weights_desc) {
   ConvPowerVR conv = CreateConvPowerVRDynamicWeights(
       gpu_info, op_def, attr, weights_shape, &dst_shape);
-  *weights_desc = conv.GetConvWeightsDescription();
+  *weights_desc = conv.GetWeightsDescription();
   return absl::make_unique<ConvPowerVR>(std::move(conv));
 }
 
@@ -113,17 +113,17 @@ std::unique_ptr<GPUOperation> SelectConvolutionDynamicWeightsMali(
     const Convolution2DAttributes& attr, const BHWC& weights_shape,
     const BHWC& dst_shape, const GpuInfo& gpu_info,
     const OperationDef& op_def, ModelHints hints,
-    ConvWeightsDescription* weights_desc) {
+    WeightsDescription* weights_desc) {
   if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER &&
       IsConvBuffer1x1Supported(op_def, weights_shape, attr)) {
     ConvBuffer1x1 conv = CreateConvBuffer1x1DynamicWeights(
         gpu_info, op_def, attr, weights_shape, &dst_shape);
-    *weights_desc = conv.GetConvWeightsDescription();
+    *weights_desc = conv.GetWeightsDescription();
     return absl::make_unique<ConvBuffer1x1>(std::move(conv));
   } else {
     ConvPowerVR conv = CreateConvPowerVRDynamicWeights(
         gpu_info, op_def, attr, weights_shape, &dst_shape);
-    *weights_desc = conv.GetConvWeightsDescription();
+    *weights_desc = conv.GetWeightsDescription();
     return absl::make_unique<ConvPowerVR>(std::move(conv));
   }
 }
@@ -172,7 +172,7 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights(
     const Convolution2DAttributes& attr, const BHWC& weights_shape,
     const BHWC& dst_shape, const GpuInfo& gpu_info,
     const OperationDef& op_def, ModelHints hints,
-    ConvWeightsDescription* weights_desc) {
+    WeightsDescription* weights_desc) {
   if (gpu_info.IsAdreno()) {
     return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape,
                                                  gpu_info, op_def, hints,
@@ -184,13 +184,13 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights(
   } else {
     ConvPowerVR conv = CreateConvPowerVRDynamicWeights(
         gpu_info, op_def, attr, weights_shape, &dst_shape);
-    *weights_desc = conv.GetConvWeightsDescription();
+    *weights_desc = conv.GetWeightsDescription();
     return absl::make_unique<ConvPowerVR>(std::move(conv));
   }
 }
 
 std::unique_ptr<GPUOperation> SelectConverterToConvWeights(
-    const ConvWeightsDescription& weights_desc, const OperationDef& op_def,
+    const WeightsDescription& weights_desc, const OperationDef& op_def,
     ModelHints hints) {
   ConverterToConvWeights converter =
       ConverterToConvWeights(op_def, weights_desc);
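
This file carries the mechanical part of the ConvWeightsDescription to WeightsDescription rename (the type now comes from common/task/weights_layout.h, per the header changes below) together with the matching GetConvWeightsDescription() to GetWeightsDescription() accessor rename. A toy mirror of the calling convention, using hypothetical stand-in types rather than the real TFLite GPU classes:

    #include <cstdio>
    #include <memory>

    // Stand-in for the renamed descriptor type from weights_layout.h.
    struct WeightsDescription { int output_group_size = 1; };

    // Stand-in convolution: reports the weights layout it expects.
    struct ConvOp {
      WeightsDescription GetWeightsDescription() const { return {4}; }
    };

    // The selector returns the operation and fills a WeightsDescription
    // out-parameter, so the caller can pick a matching weights converter.
    std::unique_ptr<ConvOp> SelectConv(WeightsDescription* weights_desc) {
      auto conv = std::make_unique<ConvOp>();
      *weights_desc = conv->GetWeightsDescription();  // was GetConvWeightsDescription()
      return conv;
    }

    int main() {
      WeightsDescription desc;
      auto op = SelectConv(&desc);
      std::printf("output_group_size = %d\n", desc.output_group_size);  // 4
      return 0;
    }
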
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
index b243b5c54ae..cef1e014217 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h
@@ -18,12 +18,12 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 
 namespace tflite {
 namespace gpu {
@@ -40,10 +40,10 @@ std::unique_ptr<GPUOperation> SelectConvolutionForWinograd(
 std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights(
     const Convolution2DAttributes& attr, const BHWC& weights_shape,
     const BHWC& dst_shape, const GpuInfo& gpu_info, const OperationDef& op_def,
-    ModelHints hints, ConvWeightsDescription* weights_desc);
+    ModelHints hints, WeightsDescription* weights_desc);
 
 std::unique_ptr<GPUOperation> SelectConverterToConvWeights(
-    const ConvWeightsDescription& weights_desc, const OperationDef& op_def,
+    const WeightsDescription& weights_desc, const OperationDef& op_def,
     ModelHints hints);
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc
index 6b3c0530c41..2aa56e2bb19 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc
@@ -98,10 +98,10 @@ std::unique_ptr<GPUOperation> SelectConvolutionTransposed(
 
 std::unique_ptr<GPUOperation> SelectConvolutionTransposedWithDynamicWeights(
     const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info,
-    const OperationDef& op_def, ConvWeightsDescription* weights_desc) {
+    const OperationDef& op_def, WeightsDescription* weights_desc) {
   ConvolutionTransposed conv =
       CreateConvolutionTransposedDynamicWeights(gpu_info, op_def, attr);
-  *weights_desc = conv.GetConvWeightsDescription();
+  *weights_desc = conv.GetWeightsDescription();
   return absl::make_unique<ConvolutionTransposed>(std::move(conv));
 }
 
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h
index e00fe020015..4a2a6d9645f 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h
@@ -18,10 +18,10 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
 
 namespace tflite {
 namespace gpu {
@@ -33,7 +33,7 @@ std::unique_ptr<GPUOperation> SelectConvolutionTransposed(
 
 std::unique_ptr<GPUOperation> SelectConvolutionTransposedWithDynamicWeights(
     const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info,
-    const OperationDef& op_def, ConvWeightsDescription* weights_desc);
+    const OperationDef& op_def, WeightsDescription* weights_desc);
 
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD
index 9ab0993b078..71b3f3b68e5 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD
@@ -7,12 +7,12 @@ cc_library(
     name = "default_selector",
     srcs = ["default_selector.cc"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors:subgraph",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_hints",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
index 72eec316da6..a7d94fabf43 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc
@@ -16,12 +16,12 @@ limitations under the License.
 #include <memory>
 
 #include "absl/strings/str_cat.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
index f6cb9a33ada..1efa215e602 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h
@@ -18,12 +18,11 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h
index 647bee97f3d..0c920984bc1 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h
@@ -18,9 +18,9 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h
index 5a2639f26f3..5b1563a9351 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h
@@ -18,9 +18,9 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
index fd12390cfd4..6dc28b4d454 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
@@ -34,28 +34,34 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
-bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
-                                   const GpuInfo& gpu_info,
-                                   const BHWC& dst_shape) {
+bool IsRecommendedForWinograd4x4To6x6(const Convolution2DAttributes& attr,
+                                      const GpuInfo& gpu_info,
+                                      const BHWC& dst_shape) {
   const int tiles_x = DivideRoundUp(dst_shape.w, 4);
   const int tiles_y = DivideRoundUp(dst_shape.h, 4);
+  const int total_tiles = tiles_x * tiles_y;
   const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
   const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
-  const bool suitable_attributes =
-      attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
-      attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
   // Mali among other devices has smaller SIMD line size
-  const int min_depth = gpu_info.IsMali() ? 16 : 32;
-  const int min_hw = gpu_info.IsMali() ? 32 : 128;
+  int min_depth = gpu_info.IsMali() ? 16 : 32;
+  const int min_tiles = gpu_info.IsMali() ? 32 : 128;
+  if (total_tiles >= min_tiles * 8) {
+    min_depth /= 4;
+    min_depth = std::max(min_depth, 8);
+  } else if (total_tiles >= min_tiles * 4) {
+    min_depth /= 2;
+    min_depth = std::max(min_depth, 8);
+  }
   const bool recommended_channels =
       dst_depth % 4 == 0 && src_depth >= min_depth && dst_depth >= min_depth;
-  const bool recommended_hw = tiles_x * tiles_y >= min_hw;
-  return suitable_attributes && recommended_channels && recommended_hw;
+  const bool recommended_hw = total_tiles >= min_tiles;
+  return recommended_channels && recommended_hw;
 }
 
 absl::Status WinogradFromNode(const GpuInfo& gpu_info,
@@ -65,9 +71,12 @@ absl::Status WinogradFromNode(const GpuInfo& gpu_info,
                               const BHWC& input_shape, const BHWC& output_shape,
                               const Convolution2DAttributes& attr,
                               GPUOperationsSubgraph* gpu_subgraph) {
-  if (!IsSuitableForWinograd4x4To6x6(attr, gpu_info, output_shape)) {
+  if (!IsSuitableForWinograd4x4To6x6(attr)) {
     return absl::UnimplementedError("No implementation for this case.");
   }
+  if (!IsRecommendedForWinograd4x4To6x6(attr, gpu_info, output_shape)) {
+    return absl::UnimplementedError("Not recommended for this case.");
+  }
 
   const int tiles_x = DivideRoundUp(output_shape.w, 4);
   const int tiles_y = DivideRoundUp(output_shape.h, 4);
@@ -203,7 +212,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
       conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
       OperationDef conv_def = op_def;
       conv_def.src_tensors[1] = weights_desc;
-      ConvWeightsDescription conv_weights_desc;
+      WeightsDescription conv_weights_desc;
       conv_op.operation = SelectConvolutionWithDynamicWeights(
           attr, weights_shape, dst_shape, gpu_info, conv_def, hints,
           &conv_weights_desc);
@@ -280,7 +289,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
         conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
         OperationDef conv_def = op_def;
         conv_def.src_tensors[1] = weights_desc;
-        ConvWeightsDescription conv_weights_desc;
+        WeightsDescription conv_weights_desc;
         conv_op.operation = SelectConvolutionWithDynamicWeights(
             attr, weights_shape, output_shape, gpu_info, conv_def, hints,
             &conv_weights_desc);
@@ -329,7 +338,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
         conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
         OperationDef conv_def = op_def;
         conv_def.src_tensors[1] = weights_desc;
-        ConvWeightsDescription conv_weights_desc;
+        WeightsDescription conv_weights_desc;
         conv_op.operation = SelectConvolutionTransposedWithDynamicWeights(
             attr, gpu_info, conv_def, &conv_weights_desc);
 
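
The operation_selector.cc hunks above split the old IsSuitableForWinograd4x4To6x6 predicate in two: the hard attribute check (3x3 weights, unit strides and dilations) moves behind IsSuitableForWinograd4x4To6x6(attr) from common/winograd_util.h, while the new IsRecommendedForWinograd4x4To6x6 keeps the performance heuristic and now relaxes the minimum channel depth as the tile count grows, presumably because abundant tile parallelism compensates for shallow tensors. A standalone sketch of the new heuristic (RecommendedForWinograd is a hypothetical stand-in; the real code takes Convolution2DAttributes and GpuInfo):

    #include <algorithm>
    #include <cstdio>

    // Mirrors the depth/tile thresholds of IsRecommendedForWinograd4x4To6x6:
    // the 4x4 output tiling must reach min_tiles, and once it exceeds
    // min_tiles * 4 (or * 8) the required channel depth is halved (or
    // quartered), clamped at 8.
    bool RecommendedForWinograd(int w, int h, int src_ch, int dst_ch,
                                bool is_mali) {
      auto div_up = [](int a, int b) { return (a + b - 1) / b; };
      const int total_tiles = div_up(w, 4) * div_up(h, 4);
      const int src_depth = div_up(src_ch, 4);
      const int dst_depth = div_up(dst_ch, 4);
      int min_depth = is_mali ? 16 : 32;  // Mali has a smaller SIMD line size
      const int min_tiles = is_mali ? 32 : 128;
      if (total_tiles >= min_tiles * 8) {
        min_depth = std::max(min_depth / 4, 8);
      } else if (total_tiles >= min_tiles * 4) {
        min_depth = std::max(min_depth / 2, 8);
      }
      return dst_depth % 4 == 0 && src_depth >= min_depth &&
             dst_depth >= min_depth && total_tiles >= min_tiles;
    }

    int main() {
      // 128x128 output with 64->64 channels on a non-Mali GPU: 1024 tiles
      // clears min_tiles * 8 = 1024, so min_depth drops from 32 to 8 and
      // Winograd is recommended despite the relatively shallow tensors.
      std::printf("%d\n", RecommendedForWinograd(128, 128, 64, 64, false));
      return 0;
    }
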
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
index fa0c5cb1af1..b81bdaa0506 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
@@ -18,11 +18,11 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
index 971209f8f22..52c23102dd9 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
@@ -19,10 +19,10 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
index 09ba31f87b0..aecd0a0a519 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
@@ -20,10 +20,10 @@ limitations under the License.
 #include <set>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
index 3329bc5a83e..cd3c987ccaf 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
@@ -17,8 +17,8 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
index f402358b2c5..f94e0c430a3 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h
@@ -19,8 +19,8 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.cc b/tensorflow/lite/delegates/gpu/cl/serialization.cc
index cc608177e1d..812782c60bc 100644
--- a/tensorflow/lite/delegates/gpu/cl/serialization.cc
+++ b/tensorflow/lite/delegates/gpu/cl/serialization.cc
@@ -19,13 +19,13 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
 #include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
@@ -187,10 +187,6 @@ Layout ToEnum(data::Layout type) {
       return Layout::UNKNOWN;
   }
 }
-}  // namespace
-
-namespace cl {
-namespace {
 
 data::CalculationsPrecision ToFB(CalculationsPrecision type) {
   switch (type) {
@@ -279,7 +275,6 @@ CompilerOptions ToEnum(data::CompilerOptions type) {
 }
 
 }  // namespace
-}  // namespace cl
 
 flatbuffers::Offset<data::Int2> Encode(
     const int2& v, flatbuffers::FlatBufferBuilder* builder) {
@@ -732,7 +727,6 @@ flatbuffers::Offset<data::Arguments> Encode(
   return arguments_builder.Finish();
 }
 
-namespace cl {
 flatbuffers::Offset<data::OperationDef> Encode(
     const OperationDef& def, flatbuffers::FlatBufferBuilder* builder) {
   std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
@@ -773,22 +767,6 @@ void Decode(const data::OperationDef* fb_def, OperationDef* def) {
   def->precision = ToEnum(fb_def->precision());
 }
 
-flatbuffers::Offset<data::TensorDescWithId> Encode(
-    const TensorDescriptor& desc, const ValueId& id,
-    flatbuffers::FlatBufferBuilder* builder) {
-  auto desc_fb = Encode(desc, builder);
-  data::TensorDescWithIdBuilder desc_builder(*builder);
-  desc_builder.add_desc(desc_fb);
-  desc_builder.add_id(id);
-  return desc_builder.Finish();
-}
-
-void Decode(const data::TensorDescWithId* fb_desc, TensorDescriptor* desc,
-            ValueId* id) {
-  Decode(fb_desc->desc(), desc);
-  *id = fb_desc->id();
-}
-
 absl::Status Decode(const data::GPUOperation* fb_op, GPUOperation* op) {
   RETURN_IF_ERROR(Decode(fb_op->arguments(), &op->args_));
   op->code_ = std::string(fb_op->code()->c_str(), fb_op->code()->size());
@@ -881,6 +859,24 @@ flatbuffers::Offset<data::GPUOperation> Encode(
   return op_builder.Finish();
 }
 
+namespace cl {
+
+flatbuffers::Offset<data::TensorDescWithId> Encode(
+    const TensorDescriptor& desc, const ValueId& id,
+    flatbuffers::FlatBufferBuilder* builder) {
+  auto desc_fb = Encode(desc, builder);
+  data::TensorDescWithIdBuilder desc_builder(*builder);
+  desc_builder.add_desc(desc_fb);
+  desc_builder.add_id(id);
+  return desc_builder.Finish();
+}
+
+void Decode(const data::TensorDescWithId* fb_desc, TensorDescriptor* desc,
+            ValueId* id) {
+  Decode(fb_desc->desc(), desc);
+  *id = fb_desc->id();
+}
+
 flatbuffers::Offset<data::CLNode> Encode(
     const CLNode& node, flatbuffers::FlatBufferBuilder* builder) {
   auto op_fb = Encode(node.cl_operation.GetGpuOperation(), builder);
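
The serialization.cc hunks above hoist the backend-neutral codecs (CalculationsPrecision, OperationDef, GPUOperation, and friends) out of namespace cl into plain tflite::gpu, leaving only the CL-specific TensorDescWithId and CLNode codecs inside namespace cl. A minimal, self-contained sketch of that layering, with trivial stand-in types (nothing here is the real flatbuffers API):

    #include <cstdio>

    namespace tflite {
    namespace gpu {

    // Backend-neutral layer: shared encoder for the common operation type.
    struct GPUOperation { int code = 42; };
    int Encode(const GPUOperation& op) { return op.code; }

    namespace cl {

    // CL-only layer: wraps the shared encoder for CL-specific containers.
    struct CLNode { GPUOperation op; };
    int Encode(const CLNode& node) { return tflite::gpu::Encode(node.op); }

    }  // namespace cl
    }  // namespace gpu
    }  // namespace tflite

    int main() {
      tflite::gpu::cl::CLNode node;
      std::printf("%d\n", tflite::gpu::cl::Encode(node));  // prints 42
      return 0;
    }
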
diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.fbs b/tensorflow/lite/delegates/gpu/cl/serialization.fbs
index 9d5cf5ed783..67bd587162e 100644
--- a/tensorflow/lite/delegates/gpu/cl/serialization.fbs
+++ b/tensorflow/lite/delegates/gpu/cl/serialization.fbs
@@ -16,66 +16,13 @@ include "tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs";
 
 namespace tflite.gpu.cl.data;
 
-enum CalculationsPrecision : byte {
-  F32 = 0,
-  F32_F16 = 1,
-  F16 = 2,
-}
-
-enum TensorToGrid : byte {
-  CUSTOM = 0,
-  WB_TO_X_HD_TO_Y_S_TO_Z = 1,
-  WB_TO_X_HD_TO_Y_Z_IS_1 = 2,
-  WB_TO_X_H_TO_Y_D_TO_Z = 3,
-  B_TO_X_Y_IS_1_Z_IS_1 = 4,
-}
-
-enum CompilerOptions : byte {
-  ADRENO_FULL_SIMD_LINE = 0,
-  ADRENO_MORE_WAVES = 1,
-  POWERVR_FP16 = 2,
-  CL_OPT_DISABLE = 3,
-  CL_2_0 = 4,
-  CL_3_0 = 5,
-}
-
-table OperationDef {
-  precision:CalculationsPrecision;
-  src_tensors:[tflite.gpu.data.TensorDescriptor];
-  dst_tensors:[tflite.gpu.data.TensorDescriptor];
-}
-
-table CompilerOption {
-  option:CompilerOptions;
-}
-
-table GPUOperation {
-  arguments:tflite.gpu.data.Arguments;
-  code:string;
-  work_group_size:tflite.gpu.data.Int3;
-  compiler_options:[CompilerOption];
-  tensor_to_grid:TensorToGrid;
-  elementwise:bool;
-  linkable:bool;
-  check_src_channels_size:bool;
-  definition:OperationDef;
-  grid_dimension:int32;
-  work_group_launch_order:tflite.gpu.data.Int3;
-  grid_size:tflite.gpu.data.Int3;
-  src_tensors_names:[string];
-  dst_tensors_names:[string];
-  work_groups_count:tflite.gpu.data.Int3;
-  linkable_count:int32;
-  elementwise_code:string;
-}
-
 table TensorDescWithId {
   desc:tflite.gpu.data.TensorDescriptor;
   id:int32;
 }
 
 table CLNode {
-  gpu_op:GPUOperation;
+  gpu_op:tflite.gpu.data.GPUOperation;
   input_ids:[int32];
   output_ids:[int32];
   name:string;
@@ -91,7 +38,7 @@ table InferenceContext {
   flush_periodically:bool;
   flush_period:int32;
   need_manual_release:bool;
-  precision:CalculationsPrecision;
+  precision:tflite.gpu.data.CalculationsPrecision;
   storage_type:tflite.gpu.data.TensorStorageType;
   nodes:[CLNode];
   tensors:[TensorDescWithId];
diff --git a/tensorflow/lite/delegates/gpu/cl/serialization_generated.h b/tensorflow/lite/delegates/gpu/cl/serialization_generated.h
index a3bc04e12ca..d423b19e60a 100644
--- a/tensorflow/lite/delegates/gpu/cl/serialization_generated.h
+++ b/tensorflow/lite/delegates/gpu/cl/serialization_generated.h
@@ -27,15 +27,6 @@ namespace gpu {
 namespace cl {
 namespace data {
 
-struct OperationDef;
-struct OperationDefBuilder;
-
-struct CompilerOption;
-struct CompilerOptionBuilder;
-
-struct GPUOperation;
-struct GPUOperationBuilder;
-
 struct TensorDescWithId;
 struct TensorDescWithIdBuilder;
 
@@ -48,500 +39,6 @@ struct PairOfValueIdsBuilder;
 struct InferenceContext;
 struct InferenceContextBuilder;
 
-enum class CalculationsPrecision : int8_t {
-  F32 = 0,
-  F32_F16 = 1,
-  F16 = 2,
-  MIN = F32,
-  MAX = F16
-};
-
-inline const CalculationsPrecision (&EnumValuesCalculationsPrecision())[3] {
-  static const CalculationsPrecision values[] = {
-    CalculationsPrecision::F32,
-    CalculationsPrecision::F32_F16,
-    CalculationsPrecision::F16
-  };
-  return values;
-}
-
-inline const char * const *EnumNamesCalculationsPrecision() {
-  static const char * const names[4] = {
-    "F32",
-    "F32_F16",
-    "F16",
-    nullptr
-  };
-  return names;
-}
-
-inline const char *EnumNameCalculationsPrecision(CalculationsPrecision e) {
-  if (flatbuffers::IsOutRange(e, CalculationsPrecision::F32, CalculationsPrecision::F16)) return "";
-  const size_t index = static_cast<size_t>(e);
-  return EnumNamesCalculationsPrecision()[index];
-}
-
-enum class TensorToGrid : int8_t {
-  CUSTOM = 0,
-  WB_TO_X_HD_TO_Y_S_TO_Z = 1,
-  WB_TO_X_HD_TO_Y_Z_IS_1 = 2,
-  WB_TO_X_H_TO_Y_D_TO_Z = 3,
-  B_TO_X_Y_IS_1_Z_IS_1 = 4,
-  MIN = CUSTOM,
-  MAX = B_TO_X_Y_IS_1_Z_IS_1
-};
-
-inline const TensorToGrid (&EnumValuesTensorToGrid())[5] {
-  static const TensorToGrid values[] = {
-    TensorToGrid::CUSTOM,
-    TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z,
-    TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1,
-    TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z,
-    TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1
-  };
-  return values;
-}
-
-inline const char * const *EnumNamesTensorToGrid() {
-  static const char * const names[6] = {
-    "CUSTOM",
-    "WB_TO_X_HD_TO_Y_S_TO_Z",
-    "WB_TO_X_HD_TO_Y_Z_IS_1",
-    "WB_TO_X_H_TO_Y_D_TO_Z",
-    "B_TO_X_Y_IS_1_Z_IS_1",
-    nullptr
-  };
-  return names;
-}
-
-inline const char *EnumNameTensorToGrid(TensorToGrid e) {
-  if (flatbuffers::IsOutRange(e, TensorToGrid::CUSTOM, TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1)) return "";
-  const size_t index = static_cast<size_t>(e);
-  return EnumNamesTensorToGrid()[index];
-}
-
-enum class CompilerOptions : int8_t {
-  ADRENO_FULL_SIMD_LINE = 0,
-  ADRENO_MORE_WAVES = 1,
-  POWERVR_FP16 = 2,
-  CL_OPT_DISABLE = 3,
-  CL_2_0 = 4,
-  CL_3_0 = 5,
-  MIN = ADRENO_FULL_SIMD_LINE,
-  MAX = CL_3_0
-};
-
-inline const CompilerOptions (&EnumValuesCompilerOptions())[6] {
-  static const CompilerOptions values[] = {
-    CompilerOptions::ADRENO_FULL_SIMD_LINE,
-    CompilerOptions::ADRENO_MORE_WAVES,
-    CompilerOptions::POWERVR_FP16,
-    CompilerOptions::CL_OPT_DISABLE,
-    CompilerOptions::CL_2_0,
-    CompilerOptions::CL_3_0
-  };
-  return values;
-}
-
-inline const char * const *EnumNamesCompilerOptions() {
-  static const char * const names[7] = {
-    "ADRENO_FULL_SIMD_LINE",
-    "ADRENO_MORE_WAVES",
-    "POWERVR_FP16",
-    "CL_OPT_DISABLE",
-    "CL_2_0",
-    "CL_3_0",
-    nullptr
-  };
-  return names;
-}
-
-inline const char *EnumNameCompilerOptions(CompilerOptions e) {
-  if (flatbuffers::IsOutRange(e, CompilerOptions::ADRENO_FULL_SIMD_LINE, CompilerOptions::CL_3_0)) return "";
-  const size_t index = static_cast<size_t>(e);
-  return EnumNamesCompilerOptions()[index];
-}
-
-struct OperationDef FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
-  typedef OperationDefBuilder Builder;
-  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
-    VT_PRECISION = 4,
-    VT_SRC_TENSORS = 6,
-    VT_DST_TENSORS = 8
-  };
-  tflite::gpu::cl::data::CalculationsPrecision precision() const {
-    return static_cast<tflite::gpu::cl::data::CalculationsPrecision>(GetField<int8_t>(VT_PRECISION, 0));
-  }
-  const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *src_tensors() const {
-    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>(VT_SRC_TENSORS);
-  }
-  const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *dst_tensors() const {
-    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>(VT_DST_TENSORS);
-  }
-  bool Verify(flatbuffers::Verifier &verifier) const {
-    return VerifyTableStart(verifier) &&
-           VerifyField<int8_t>(verifier, VT_PRECISION) &&
-           VerifyOffset(verifier, VT_SRC_TENSORS) &&
-           verifier.VerifyVector(src_tensors()) &&
-           verifier.VerifyVectorOfTables(src_tensors()) &&
-           VerifyOffset(verifier, VT_DST_TENSORS) &&
-           verifier.VerifyVector(dst_tensors()) &&
-           verifier.VerifyVectorOfTables(dst_tensors()) &&
-           verifier.EndTable();
-  }
-};
-
-struct OperationDefBuilder {
-  typedef OperationDef Table;
-  flatbuffers::FlatBufferBuilder &fbb_;
-  flatbuffers::uoffset_t start_;
-  void add_precision(tflite::gpu::cl::data::CalculationsPrecision precision) {
-    fbb_.AddElement<int8_t>(OperationDef::VT_PRECISION, static_cast<int8_t>(precision), 0);
-  }
-  void add_src_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> src_tensors) {
-    fbb_.AddOffset(OperationDef::VT_SRC_TENSORS, src_tensors);
-  }
-  void add_dst_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> dst_tensors) {
-    fbb_.AddOffset(OperationDef::VT_DST_TENSORS, dst_tensors);
-  }
-  explicit OperationDefBuilder(flatbuffers::FlatBufferBuilder &_fbb)
-        : fbb_(_fbb) {
-    start_ = fbb_.StartTable();
-  }
-  flatbuffers::Offset<OperationDef> Finish() {
-    const auto end = fbb_.EndTable(start_);
-    auto o = flatbuffers::Offset<OperationDef>(end);
-    return o;
-  }
-};
-
-inline flatbuffers::Offset<OperationDef> CreateOperationDef(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    tflite::gpu::cl::data::CalculationsPrecision precision = tflite::gpu::cl::data::CalculationsPrecision::F32,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> src_tensors = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> dst_tensors = 0) {
-  OperationDefBuilder builder_(_fbb);
-  builder_.add_dst_tensors(dst_tensors);
-  builder_.add_src_tensors(src_tensors);
-  builder_.add_precision(precision);
-  return builder_.Finish();
-}
-
-inline flatbuffers::Offset<OperationDef> CreateOperationDefDirect(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    tflite::gpu::cl::data::CalculationsPrecision precision = tflite::gpu::cl::data::CalculationsPrecision::F32,
-    const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *src_tensors = nullptr,
-    const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *dst_tensors = nullptr) {
-  auto src_tensors__ = src_tensors ? _fbb.CreateVector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>(*src_tensors) : 0;
-  auto dst_tensors__ = dst_tensors ? _fbb.CreateVector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>(*dst_tensors) : 0;
-  return tflite::gpu::cl::data::CreateOperationDef(
-      _fbb,
-      precision,
-      src_tensors__,
-      dst_tensors__);
-}
-
-struct CompilerOption FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
-  typedef CompilerOptionBuilder Builder;
-  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
-    VT_OPTION = 4
-  };
-  tflite::gpu::cl::data::CompilerOptions option() const {
-    return static_cast<tflite::gpu::cl::data::CompilerOptions>(GetField<int8_t>(VT_OPTION, 0));
-  }
-  bool Verify(flatbuffers::Verifier &verifier) const {
-    return VerifyTableStart(verifier) &&
-           VerifyField<int8_t>(verifier, VT_OPTION) &&
-           verifier.EndTable();
-  }
-};
-
-struct CompilerOptionBuilder {
-  typedef CompilerOption Table;
-  flatbuffers::FlatBufferBuilder &fbb_;
-  flatbuffers::uoffset_t start_;
-  void add_option(tflite::gpu::cl::data::CompilerOptions option) {
-    fbb_.AddElement<int8_t>(CompilerOption::VT_OPTION, static_cast<int8_t>(option), 0);
-  }
-  explicit CompilerOptionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
-        : fbb_(_fbb) {
-    start_ = fbb_.StartTable();
-  }
-  flatbuffers::Offset<CompilerOption> Finish() {
-    const auto end = fbb_.EndTable(start_);
-    auto o = flatbuffers::Offset<CompilerOption>(end);
-    return o;
-  }
-};
-
-inline flatbuffers::Offset<CompilerOption> CreateCompilerOption(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    tflite::gpu::cl::data::CompilerOptions option = tflite::gpu::cl::data::CompilerOptions::ADRENO_FULL_SIMD_LINE) {
-  CompilerOptionBuilder builder_(_fbb);
-  builder_.add_option(option);
-  return builder_.Finish();
-}
-
-struct GPUOperation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
-  typedef GPUOperationBuilder Builder;
-  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
-    VT_ARGUMENTS = 4,
-    VT_CODE = 6,
-    VT_WORK_GROUP_SIZE = 8,
-    VT_COMPILER_OPTIONS = 10,
-    VT_TENSOR_TO_GRID = 12,
-    VT_ELEMENTWISE = 14,
-    VT_LINKABLE = 16,
-    VT_CHECK_SRC_CHANNELS_SIZE = 18,
-    VT_DEFINITION = 20,
-    VT_GRID_DIMENSION = 22,
-    VT_WORK_GROUP_LAUNCH_ORDER = 24,
-    VT_GRID_SIZE = 26,
-    VT_SRC_TENSORS_NAMES = 28,
-    VT_DST_TENSORS_NAMES = 30,
-    VT_WORK_GROUPS_COUNT = 32,
-    VT_LINKABLE_COUNT = 34,
-    VT_ELEMENTWISE_CODE = 36
-  };
-  const tflite::gpu::data::Arguments *arguments() const {
-    return GetPointer<const tflite::gpu::data::Arguments *>(VT_ARGUMENTS);
-  }
-  const flatbuffers::String *code() const {
-    return GetPointer<const flatbuffers::String *>(VT_CODE);
-  }
-  const tflite::gpu::data::Int3 *work_group_size() const {
-    return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUP_SIZE);
-  }
-  const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>> *compiler_options() const {
-    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>> *>(VT_COMPILER_OPTIONS);
-  }
-  tflite::gpu::cl::data::TensorToGrid tensor_to_grid() const {
-    return static_cast<tflite::gpu::cl::data::TensorToGrid>(GetField<int8_t>(VT_TENSOR_TO_GRID, 0));
-  }
-  bool elementwise() const {
-    return GetField<uint8_t>(VT_ELEMENTWISE, 0) != 0;
-  }
-  bool linkable() const {
-    return GetField<uint8_t>(VT_LINKABLE, 0) != 0;
-  }
-  bool check_src_channels_size() const {
-    return GetField<uint8_t>(VT_CHECK_SRC_CHANNELS_SIZE, 0) != 0;
-  }
-  const tflite::gpu::cl::data::OperationDef *definition() const {
-    return GetPointer<const tflite::gpu::cl::data::OperationDef *>(VT_DEFINITION);
-  }
-  int32_t grid_dimension() const {
-    return GetField<int32_t>(VT_GRID_DIMENSION, 0);
-  }
-  const tflite::gpu::data::Int3 *work_group_launch_order() const {
-    return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUP_LAUNCH_ORDER);
-  }
-  const tflite::gpu::data::Int3 *grid_size() const {
-    return GetPointer<const tflite::gpu::data::Int3 *>(VT_GRID_SIZE);
-  }
-  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *src_tensors_names() const {
-    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_SRC_TENSORS_NAMES);
-  }
-  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *dst_tensors_names() const {
-    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_DST_TENSORS_NAMES);
-  }
-  const tflite::gpu::data::Int3 *work_groups_count() const {
-    return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUPS_COUNT);
-  }
-  int32_t linkable_count() const {
-    return GetField<int32_t>(VT_LINKABLE_COUNT, 0);
-  }
-  const flatbuffers::String *elementwise_code() const {
-    return GetPointer<const flatbuffers::String *>(VT_ELEMENTWISE_CODE);
-  }
-  bool Verify(flatbuffers::Verifier &verifier) const {
-    return VerifyTableStart(verifier) &&
-           VerifyOffset(verifier, VT_ARGUMENTS) &&
-           verifier.VerifyTable(arguments()) &&
-           VerifyOffset(verifier, VT_CODE) &&
-           verifier.VerifyString(code()) &&
-           VerifyOffset(verifier, VT_WORK_GROUP_SIZE) &&
-           verifier.VerifyTable(work_group_size()) &&
-           VerifyOffset(verifier, VT_COMPILER_OPTIONS) &&
-           verifier.VerifyVector(compiler_options()) &&
-           verifier.VerifyVectorOfTables(compiler_options()) &&
-           VerifyField<int8_t>(verifier, VT_TENSOR_TO_GRID) &&
-           VerifyField<uint8_t>(verifier, VT_ELEMENTWISE) &&
-           VerifyField<uint8_t>(verifier, VT_LINKABLE) &&
-           VerifyField<uint8_t>(verifier, VT_CHECK_SRC_CHANNELS_SIZE) &&
-           VerifyOffset(verifier, VT_DEFINITION) &&
-           verifier.VerifyTable(definition()) &&
-           VerifyField<int32_t>(verifier, VT_GRID_DIMENSION) &&
-           VerifyOffset(verifier, VT_WORK_GROUP_LAUNCH_ORDER) &&
-           verifier.VerifyTable(work_group_launch_order()) &&
-           VerifyOffset(verifier, VT_GRID_SIZE) &&
-           verifier.VerifyTable(grid_size()) &&
-           VerifyOffset(verifier, VT_SRC_TENSORS_NAMES) &&
-           verifier.VerifyVector(src_tensors_names()) &&
-           verifier.VerifyVectorOfStrings(src_tensors_names()) &&
-           VerifyOffset(verifier, VT_DST_TENSORS_NAMES) &&
-           verifier.VerifyVector(dst_tensors_names()) &&
-           verifier.VerifyVectorOfStrings(dst_tensors_names()) &&
-           VerifyOffset(verifier, VT_WORK_GROUPS_COUNT) &&
-           verifier.VerifyTable(work_groups_count()) &&
-           VerifyField<int32_t>(verifier, VT_LINKABLE_COUNT) &&
-           VerifyOffset(verifier, VT_ELEMENTWISE_CODE) &&
-           verifier.VerifyString(elementwise_code()) &&
-           verifier.EndTable();
-  }
-};
-
-struct GPUOperationBuilder {
-  typedef GPUOperation Table;
-  flatbuffers::FlatBufferBuilder &fbb_;
-  flatbuffers::uoffset_t start_;
-  void add_arguments(flatbuffers::Offset<tflite::gpu::data::Arguments> arguments) {
-    fbb_.AddOffset(GPUOperation::VT_ARGUMENTS, arguments);
-  }
-  void add_code(flatbuffers::Offset<flatbuffers::String> code) {
-    fbb_.AddOffset(GPUOperation::VT_CODE, code);
-  }
-  void add_work_group_size(flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size) {
-    fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_SIZE, work_group_size);
-  }
-  void add_compiler_options(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>>> compiler_options) {
-    fbb_.AddOffset(GPUOperation::VT_COMPILER_OPTIONS, compiler_options);
-  }
-  void add_tensor_to_grid(tflite::gpu::cl::data::TensorToGrid tensor_to_grid) {
-    fbb_.AddElement<int8_t>(GPUOperation::VT_TENSOR_TO_GRID, static_cast<int8_t>(tensor_to_grid), 0);
-  }
-  void add_elementwise(bool elementwise) {
-    fbb_.AddElement<uint8_t>(GPUOperation::VT_ELEMENTWISE, static_cast<uint8_t>(elementwise), 0);
-  }
-  void add_linkable(bool linkable) {
-    fbb_.AddElement<uint8_t>(GPUOperation::VT_LINKABLE, static_cast<uint8_t>(linkable), 0);
-  }
-  void add_check_src_channels_size(bool check_src_channels_size) {
-    fbb_.AddElement<uint8_t>(GPUOperation::VT_CHECK_SRC_CHANNELS_SIZE, static_cast<uint8_t>(check_src_channels_size), 0);
-  }
-  void add_definition(flatbuffers::Offset<tflite::gpu::cl::data::OperationDef> definition) {
-    fbb_.AddOffset(GPUOperation::VT_DEFINITION, definition);
-  }
-  void add_grid_dimension(int32_t grid_dimension) {
-    fbb_.AddElement<int32_t>(GPUOperation::VT_GRID_DIMENSION, grid_dimension, 0);
-  }
-  void add_work_group_launch_order(flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order) {
-    fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_LAUNCH_ORDER, work_group_launch_order);
-  }
-  void add_grid_size(flatbuffers::Offset<tflite::gpu::data::Int3> grid_size) {
-    fbb_.AddOffset(GPUOperation::VT_GRID_SIZE, grid_size);
-  }
-  void add_src_tensors_names(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> src_tensors_names) {
-    fbb_.AddOffset(GPUOperation::VT_SRC_TENSORS_NAMES, src_tensors_names);
-  }
-  void add_dst_tensors_names(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> dst_tensors_names) {
-    fbb_.AddOffset(GPUOperation::VT_DST_TENSORS_NAMES, dst_tensors_names);
-  }
-  void add_work_groups_count(flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count) {
-    fbb_.AddOffset(GPUOperation::VT_WORK_GROUPS_COUNT, work_groups_count);
-  }
-  void add_linkable_count(int32_t linkable_count) {
-    fbb_.AddElement<int32_t>(GPUOperation::VT_LINKABLE_COUNT, linkable_count, 0);
-  }
-  void add_elementwise_code(flatbuffers::Offset<flatbuffers::String> elementwise_code) {
-    fbb_.AddOffset(GPUOperation::VT_ELEMENTWISE_CODE, elementwise_code);
-  }
-  explicit GPUOperationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
-        : fbb_(_fbb) {
-    start_ = fbb_.StartTable();
-  }
-  flatbuffers::Offset<GPUOperation> Finish() {
-    const auto end = fbb_.EndTable(start_);
-    auto o = flatbuffers::Offset<GPUOperation>(end);
-    return o;
-  }
-};
-
-inline flatbuffers::Offset<GPUOperation> CreateGPUOperation(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0,
-    flatbuffers::Offset<flatbuffers::String> code = 0,
-    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>>> compiler_options = 0,
-    tflite::gpu::cl::data::TensorToGrid tensor_to_grid = tflite::gpu::cl::data::TensorToGrid::CUSTOM,
-    bool elementwise = false,
-    bool linkable = false,
-    bool check_src_channels_size = false,
-    flatbuffers::Offset<tflite::gpu::cl::data::OperationDef> definition = 0,
-    int32_t grid_dimension = 0,
-    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0,
-    flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> src_tensors_names = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> dst_tensors_names = 0,
-    flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0,
-    int32_t linkable_count = 0,
-    flatbuffers::Offset<flatbuffers::String> elementwise_code = 0) {
-  GPUOperationBuilder builder_(_fbb);
-  builder_.add_elementwise_code(elementwise_code);
-  builder_.add_linkable_count(linkable_count);
-  builder_.add_work_groups_count(work_groups_count);
-  builder_.add_dst_tensors_names(dst_tensors_names);
-  builder_.add_src_tensors_names(src_tensors_names);
-  builder_.add_grid_size(grid_size);
-  builder_.add_work_group_launch_order(work_group_launch_order);
-  builder_.add_grid_dimension(grid_dimension);
-  builder_.add_definition(definition);
-  builder_.add_compiler_options(compiler_options);
-  builder_.add_work_group_size(work_group_size);
-  builder_.add_code(code);
-  builder_.add_arguments(arguments);
-  builder_.add_check_src_channels_size(check_src_channels_size);
-  builder_.add_linkable(linkable);
-  builder_.add_elementwise(elementwise);
-  builder_.add_tensor_to_grid(tensor_to_grid);
-  return builder_.Finish();
-}
-
-inline flatbuffers::Offset<GPUOperation> CreateGPUOperationDirect(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0,
-    const char *code = nullptr,
-    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0,
-    const std::vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>> *compiler_options = nullptr,
-    tflite::gpu::cl::data::TensorToGrid tensor_to_grid = tflite::gpu::cl::data::TensorToGrid::CUSTOM,
-    bool elementwise = false,
-    bool linkable = false,
-    bool check_src_channels_size = false,
-    flatbuffers::Offset<tflite::gpu::cl::data::OperationDef> definition = 0,
-    int32_t grid_dimension = 0,
-    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0,
-    flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0,
-    const std::vector<flatbuffers::Offset<flatbuffers::String>> *src_tensors_names = nullptr,
-    const std::vector<flatbuffers::Offset<flatbuffers::String>> *dst_tensors_names = nullptr,
-    flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0,
-    int32_t linkable_count = 0,
-    const char *elementwise_code = nullptr) {
-  auto code__ = code ? _fbb.CreateString(code) : 0;
-  auto compiler_options__ = compiler_options ? _fbb.CreateVector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>>(*compiler_options) : 0;
-  auto src_tensors_names__ = src_tensors_names ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*src_tensors_names) : 0;
-  auto dst_tensors_names__ = dst_tensors_names ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*dst_tensors_names) : 0;
-  auto elementwise_code__ = elementwise_code ? _fbb.CreateString(elementwise_code) : 0;
-  return tflite::gpu::cl::data::CreateGPUOperation(
-      _fbb,
-      arguments,
-      code__,
-      work_group_size,
-      compiler_options__,
-      tensor_to_grid,
-      elementwise,
-      linkable,
-      check_src_channels_size,
-      definition,
-      grid_dimension,
-      work_group_launch_order,
-      grid_size,
-      src_tensors_names__,
-      dst_tensors_names__,
-      work_groups_count,
-      linkable_count,
-      elementwise_code__);
-}
-
 struct TensorDescWithId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   typedef TensorDescWithIdBuilder Builder;
   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@@ -602,8 +99,8 @@ struct CLNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_OUTPUT_IDS = 8,
     VT_NAME = 10
   };
-  const tflite::gpu::cl::data::GPUOperation *gpu_op() const {
-    return GetPointer<const tflite::gpu::cl::data::GPUOperation *>(VT_GPU_OP);
+  const tflite::gpu::data::GPUOperation *gpu_op() const {
+    return GetPointer<const tflite::gpu::data::GPUOperation *>(VT_GPU_OP);
   }
   const flatbuffers::Vector<int32_t> *input_ids() const {
     return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_INPUT_IDS);
@@ -632,7 +129,7 @@ struct CLNodeBuilder {
   typedef CLNode Table;
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_gpu_op(flatbuffers::Offset<tflite::gpu::cl::data::GPUOperation> gpu_op) {
+  void add_gpu_op(flatbuffers::Offset<tflite::gpu::data::GPUOperation> gpu_op) {
     fbb_.AddOffset(CLNode::VT_GPU_OP, gpu_op);
   }
   void add_input_ids(flatbuffers::Offset<flatbuffers::Vector<int32_t>> input_ids) {
@@ -657,7 +154,7 @@ struct CLNodeBuilder {
 
 inline flatbuffers::Offset<CLNode> CreateCLNode(
     flatbuffers::FlatBufferBuilder &_fbb,
-    flatbuffers::Offset<tflite::gpu::cl::data::GPUOperation> gpu_op = 0,
+    flatbuffers::Offset<tflite::gpu::data::GPUOperation> gpu_op = 0,
     flatbuffers::Offset<flatbuffers::Vector<int32_t>> input_ids = 0,
     flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_ids = 0,
     flatbuffers::Offset<flatbuffers::String> name = 0) {
@@ -671,7 +168,7 @@ inline flatbuffers::Offset<CLNode> CreateCLNode(
 
 inline flatbuffers::Offset<CLNode> CreateCLNodeDirect(
     flatbuffers::FlatBufferBuilder &_fbb,
-    flatbuffers::Offset<tflite::gpu::cl::data::GPUOperation> gpu_op = 0,
+    flatbuffers::Offset<tflite::gpu::data::GPUOperation> gpu_op = 0,
     const std::vector<int32_t> *input_ids = nullptr,
     const std::vector<int32_t> *output_ids = nullptr,
     const char *name = nullptr) {
@@ -767,8 +264,9 @@ struct InferenceContext FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   bool need_manual_release() const {
     return GetField<uint8_t>(VT_NEED_MANUAL_RELEASE, 0) != 0;
   }
-  tflite::gpu::cl::data::CalculationsPrecision precision() const {
-    return static_cast<tflite::gpu::cl::data::CalculationsPrecision>(GetField<int8_t>(VT_PRECISION, 0));
+  tflite::gpu::data::CalculationsPrecision precision() const {
+    return static_cast<tflite::gpu::data::CalculationsPrecision>(
+        GetField<int8_t>(VT_PRECISION, 0));
   }
   tflite::gpu::data::TensorStorageType storage_type() const {
     return static_cast<tflite::gpu::data::TensorStorageType>(GetField<int8_t>(VT_STORAGE_TYPE, 0));
@@ -847,7 +345,7 @@ struct InferenceContextBuilder {
   void add_need_manual_release(bool need_manual_release) {
     fbb_.AddElement<uint8_t>(InferenceContext::VT_NEED_MANUAL_RELEASE, static_cast<uint8_t>(need_manual_release), 0);
   }
-  void add_precision(tflite::gpu::cl::data::CalculationsPrecision precision) {
+  void add_precision(tflite::gpu::data::CalculationsPrecision precision) {
     fbb_.AddElement<int8_t>(InferenceContext::VT_PRECISION, static_cast<int8_t>(precision), 0);
   }
   void add_storage_type(tflite::gpu::data::TensorStorageType storage_type) {
@@ -895,8 +393,8 @@ inline flatbuffers::Offset<InferenceContext> CreateInferenceContext(
     flatbuffers::FlatBufferBuilder &_fbb, bool need_flush = false,
     bool flush_periodically = false, int32_t flush_period = 0,
     bool need_manual_release = false,
-    tflite::gpu::cl::data::CalculationsPrecision precision =
-        tflite::gpu::cl::data::CalculationsPrecision::F32,
+    tflite::gpu::data::CalculationsPrecision precision =
+        tflite::gpu::data::CalculationsPrecision::F32,
     tflite::gpu::data::TensorStorageType storage_type =
         tflite::gpu::data::TensorStorageType::UNKNOWN,
     flatbuffers::Offset<
@@ -937,8 +435,8 @@ inline flatbuffers::Offset<InferenceContext> CreateInferenceContextDirect(
     flatbuffers::FlatBufferBuilder &_fbb, bool need_flush = false,
     bool flush_periodically = false, int32_t flush_period = 0,
     bool need_manual_release = false,
-    tflite::gpu::cl::data::CalculationsPrecision precision =
-        tflite::gpu::cl::data::CalculationsPrecision::F32,
+    tflite::gpu::data::CalculationsPrecision precision =
+        tflite::gpu::data::CalculationsPrecision::F32,
     tflite::gpu::data::TensorStorageType storage_type =
         tflite::gpu::data::TensorStorageType::UNKNOWN,
     const std::vector<flatbuffers::Offset<tflite::gpu::cl::data::CLNode>>
diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc
index d579368926e..7ba81d138e2 100644
--- a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc
+++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc
@@ -33,37 +33,38 @@ bool CanCreateTensorWithShape(const GpuInfo& gpu_info, const BHWDC& shape,
           4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2);
       const int buffer_size =
           shape.b * shape.w * shape.h * shape.d * slices * flt4_size;
-      return buffer_size <= gpu_info.buffer_max_size;
+      return buffer_size <= gpu_info.GetMaxBufferSize();
     }
     case TensorStorageType::IMAGE_BUFFER:
       return shape.b * shape.w * shape.h * shape.d * slices <=
-             gpu_info.image_buffer_max_size;
+             gpu_info.GetMaxImageBufferWidth();
     case TensorStorageType::TEXTURE_3D:
-      if (gpu_info.cl_version < OpenCLVersion::CL_1_2 && slices == 1) {
+      if (gpu_info.opencl_info.cl_version < OpenClVersion::kCl1_2 &&
+          slices == 1) {
        // clCreateImage3D (used in CL 1.0/1.1) cannot create an image with
        // depth = 1 per the specification;
         return false;
       }
-      return shape.w * shape.b <= gpu_info.image3d_max_width &&
-             shape.h <= gpu_info.image3d_max_height &&
-             slices * shape.d <= gpu_info.image3d_max_depth;
+      return shape.w * shape.b <= gpu_info.GetMaxImage3DWidth() &&
+             shape.h <= gpu_info.GetMaxImage3DHeight() &&
+             slices * shape.d <= gpu_info.GetMaxImage3DDepth();
     case TensorStorageType::TEXTURE_ARRAY:
       // Bug on some Adreno. b/131099086
       if (slices == 1 && gpu_info.IsAdreno() &&
           !gpu_info.adreno_info.support_one_layer_texture_array) {
         return false;
       }
-      return shape.w * shape.b <= gpu_info.image2d_max_width &&
-             shape.h <= gpu_info.image2d_max_height &&
-             slices * shape.d <= gpu_info.image_array_max_layers;
+      return shape.w * shape.b <= gpu_info.GetMaxImage2DWidth() &&
+             shape.h <= gpu_info.GetMaxImage2DHeight() &&
+             slices * shape.d <= gpu_info.GetMaxImage2DArrayLayers();
     case TensorStorageType::TEXTURE_2D:
-      return shape.w * shape.b * shape.d <= gpu_info.image2d_max_width &&
-             shape.h * slices <= gpu_info.image2d_max_height;
+      return shape.w * shape.b * shape.d <= gpu_info.GetMaxImage2DWidth() &&
+             shape.h * slices <= gpu_info.GetMaxImage2DHeight();
     case TensorStorageType::SINGLE_TEXTURE_2D:
       return shape.c <= 4 &&
              gpu_info.SupportsFloatImage2D(descriptor.data_type, shape.c) &&
-             shape.w * shape.b * shape.d <= gpu_info.image2d_max_width &&
-             shape.h <= gpu_info.image2d_max_height;
+             shape.w * shape.b * shape.d <= gpu_info.GetMaxImage2DWidth() &&
+             shape.h <= gpu_info.GetMaxImage2DHeight();
     default:
       return false;
   }
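
Note on the change above: the raw OpenCL limit fields (`buffer_max_size`, `image2d_max_width`, ...) are replaced by API-dispatching `GpuInfo` getters, so the same checks can later serve OpenGL and Vulkan backends too. A minimal sketch of how a caller might use `CanCreateTensorWithShape` to pick a usable storage type; `SelectFirstSupportedStorageType` and the candidate order are illustrative, and it assumes `TensorDescriptor` carries a `storage_type` field:

```cpp
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"

namespace tflite {
namespace gpu {
namespace cl {

// Probes candidate storage types in preference order and returns the first
// one the device can actually allocate for this shape. This helper is a
// hypothetical illustration, not part of the patch.
TensorStorageType SelectFirstSupportedStorageType(const GpuInfo& gpu_info,
                                                  const BHWDC& shape,
                                                  TensorDescriptor descriptor) {
  const TensorStorageType candidates[] = {
      TensorStorageType::TEXTURE_2D, TensorStorageType::TEXTURE_ARRAY,
      TensorStorageType::IMAGE_BUFFER, TensorStorageType::BUFFER};
  for (TensorStorageType type : candidates) {
    descriptor.storage_type = type;
    if (CanCreateTensorWithShape(gpu_info, shape, descriptor)) {
      return type;
    }
  }
  return TensorStorageType::BUFFER;  // conservative fallback
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite
```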
diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.h b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h
index f30219156b4..b849d7a1087 100644
--- a/tensorflow/lite/delegates/gpu/cl/storage_type_util.h
+++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_
 
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 
diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD
index fe67d396539..ef43c76310b 100644
--- a/tensorflow/lite/delegates/gpu/common/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/BUILD
@@ -40,6 +40,7 @@ cc_library(
     srcs = ["gpu_info.cc"],
     hdrs = ["gpu_info.h"],
     deps = [
+        ":data_type",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -386,11 +387,21 @@ cc_library(
     hdrs = ["winograd_util.h"],
     deps = [
         ":data_type",
+        ":operations",
         ":shape",
         ":tensor",
     ],
 )
 
+cc_test(
+    name = "winograd_util_test",
+    srcs = ["winograd_util_test.cc"],
+    deps = [
+        ":winograd_util",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "workgroup_selection",
     srcs = ["workgroup_selection.cc"],
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index 14f79df7801..2204f4a6448 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -358,6 +358,27 @@ void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
   }
 }
 
+std::string OpenClVersionToString(OpenClVersion version) {
+  switch (version) {
+    case OpenClVersion::kCl1_0:
+      return "1.0";
+    case OpenClVersion::kCl1_1:
+      return "1.1";
+    case OpenClVersion::kCl1_2:
+      return "1.2";
+    case OpenClVersion::kCl2_0:
+      return "2.0";
+    case OpenClVersion::kCl2_1:
+      return "2.1";
+    case OpenClVersion::kCl2_2:
+      return "2.2";
+    case OpenClVersion::kCl3_0:
+      return "3.0";
+    default:
+      return "Unknown OpenCL version";
+  }
+}
+
 bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; }
 
 bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
@@ -373,11 +394,45 @@ bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
 
 bool GpuInfo::IsRoundToNearestSupported() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn;
+  }
   if (IsApple()) {
     return apple_info.IsRoundToNearestSupported();
-  } else {
-    return true;
   }
+  return true;
+}
+
+bool GpuInfo::SupportsFP16() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.supports_fp16;
+  }
+  return true;
+}
+
+bool GpuInfo::SupportsTextureArray() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.cl_version >= OpenClVersion::kCl1_2;
+  }
+  return true;
+}
+
+bool GpuInfo::SupportsImageBuffer() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.cl_version >= OpenClVersion::kCl1_2;
+  }
+  return true;
+}
+
+bool GpuInfo::SupportsImage3D() const {
+  if (IsApiOpenCl()) {
+    if (IsMali() && mali_info.IsMidgard()) {
+      // On Mali T880 read_imageh doesn't compile with image3d_t
+      return false;
+    }
+    return opencl_info.supports_image3d_writes;
+  }
+  return true;
 }
 
 bool GpuInfo::IsWaveSizeEqualTo32() const {
@@ -385,36 +440,212 @@ bool GpuInfo::IsWaveSizeEqualTo32() const {
          supported_subgroup_sizes[0] == 32;
 }
 
+bool GpuInfo::SupportsExtension(const std::string& extension) const {
+  const std::vector<std::string>* extensions = nullptr;
+  if (IsApiOpenGl()) {
+    extensions = &opengl_info.extensions;
+  } else if (IsApiVulkan()) {
+    extensions = &vulkan_info.extensions;
+  } else if (IsApiOpenCl()) {
+    extensions = &opencl_info.extensions;
+  }
+  if (!extensions) {
+    return false;
+  }
+  for (const auto& ext : *extensions) {
+    if (ext == extension) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
+  for (auto subgroup_size : supported_subgroup_sizes) {
+    if (sub_group_size == subgroup_size) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
+  if (IsApiOpenCl()) {
+    if (channels == 1) {
+      return data_type == DataType::FLOAT32 ? opencl_info.supports_r_f32_tex2d
+                                            : opencl_info.supports_r_f16_tex2d;
+    } else if (channels == 2) {
+      return data_type == DataType::FLOAT32 ? opencl_info.supports_rg_f32_tex2d
+                                            : opencl_info.supports_rg_f16_tex2d;
+    } else if (channels == 3) {
+      return data_type == DataType::FLOAT32
+                 ? opencl_info.supports_rgb_f32_tex2d
+                 : opencl_info.supports_rgb_f16_tex2d;
+    } else if (channels == 4) {
+      return data_type == DataType::FLOAT32
+                 ? opencl_info.supports_rgba_f32_tex2d
+                 : opencl_info.supports_rgba_f16_tex2d;
+    } else {
+      return false;
+    }
+  }
+  return false;
+}
+
 int GpuInfo::GetComputeUnitsCount() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.compute_units_count;
+  }
   if (IsApple()) {
     return apple_info.GetComputeUnitsCount();
-  } else {
-    return 1;
   }
+  return 1;
+}
+
+int GpuInfo::GetMaxWorkGroupSizeForX() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_compute_work_group_size_x;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_compute_work_group_size_x;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.max_work_group_size_x;
+  }
+  return 256;
+}
+
+int GpuInfo::GetMaxWorkGroupSizeForY() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_compute_work_group_size_y;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_compute_work_group_size_y;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.max_work_group_size_y;
+  }
+  return 256;
+}
+
+int GpuInfo::GetMaxWorkGroupSizeForZ() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_compute_work_group_size_z;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_compute_work_group_size_z;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.max_work_group_size_z;
+  }
+  return 64;
+}
+
+int GpuInfo::GetMaxWorkGroupTotalSize() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_work_group_invocations;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_compute_work_group_invocations;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.max_work_group_total_size;
+  }
+  return 256;
+}
+
+uint64_t GpuInfo::GetMaxImage2DWidth() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_texture_size;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_image_dimension_2d;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.image2d_max_width;
+  }
+  return 2048;
+}
+
+uint64_t GpuInfo::GetMaxImage2DHeight() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_texture_size;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_image_dimension_2d;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.image2d_max_height;
+  }
+  return 2048;
+}
+
+uint64_t GpuInfo::GetMaxImage2DArrayLayers() const {
+  if (IsApiOpenGl()) {
+    return opengl_info.max_array_texture_layers;
+  }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_image_array_layers;
+  }
+  if (IsApiOpenCl()) {
+    return opencl_info.image_array_max_layers;
+  }
+  return 256;
+}
+
+uint64_t GpuInfo::GetMaxImage3DWidth() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.image3d_max_width;
+  }
+  return 256;
+}
+
+uint64_t GpuInfo::GetMaxImage3DHeight() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.image3d_max_height;
+  }
+  return 256;
+}
+
+uint64_t GpuInfo::GetMaxImage3DDepth() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.image3d_max_depth;
+  }
+  return 256;
+}
+
+uint64_t GpuInfo::GetMaxBufferSize() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.buffer_max_size;
+  }
+  return 128 * 1024 * 1024;
+}
+
+uint64_t GpuInfo::GetMaxImageBufferWidth() const {
+  if (IsApiOpenCl()) {
+    return opencl_info.image_buffer_max_size;
+  }
+  return 64 * 1024;
 }
 
 int GpuInfo::GetMaxImageArguments() const {
   if (IsApiOpenGl()) {
     return opengl_info.max_image_units;
-  } else if (IsApiVulkan()) {
-    return vulkan_info.max_per_stage_descriptor_sampled_images;
-  } else if (IsApiMetal()) {
-    return 32;
-  } else if (IsApiOpenCl()) {
-    return 128;
-  } else {
-    return 1;
   }
+  if (IsApiVulkan()) {
+    return vulkan_info.max_per_stage_descriptor_sampled_images;
+  }
+  if (IsApiMetal()) {
+    return 32;
+  }
+  if (IsApiOpenCl()) {
+    return 128;
+  }
+  return 1;
 }
 
 bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; }
 
-bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
-
-bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
-
-bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
-
 bool GpuInfo::IsApiOpenGl31OrAbove() const {
   if (!IsApiOpenGl()) {
     return false;
@@ -423,5 +654,29 @@ bool GpuInfo::IsApiOpenGl31OrAbove() const {
          opengl_info.major_version > 3;
 }
 
+bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
+
+bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
+
+bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
+
+bool GpuInfo::IsCL20OrHigher() const {
+  if (!IsApiOpenCl()) {
+    return false;
+  }
+  return opencl_info.cl_version != OpenClVersion::kCl1_0 &&
+         opencl_info.cl_version != OpenClVersion::kCl1_1 &&
+         opencl_info.cl_version != OpenClVersion::kCl1_2;
+}
+
+bool GpuInfo::IsCL30OrHigher() const {
+  if (!IsApiOpenCl()) {
+    return false;
+  }
+  return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 &&
+         opencl_info.cl_version != OpenClVersion::kCl2_1 &&
+         opencl_info.cl_version != OpenClVersion::kCl2_2;
+}
+
 }  // namespace gpu
 }  // namespace tflite
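
The getters above all follow the same pattern: dispatch on the active API, and fall back to a conservative default when the API is unknown. A hedged usage sketch (`FitsIn2DTexture` is an illustrative name, not part of the patch):

```cpp
#include <cstdint>

#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"

// Returns true if a width x height 2D image fits on the device. Works
// uniformly for OpenGL, Vulkan and OpenCL; any other API gets the
// conservative 2048 x 2048 default baked into the getters above.
bool FitsIn2DTexture(const tflite::gpu::GpuInfo& gpu_info, uint64_t width,
                     uint64_t height) {
  return width <= gpu_info.GetMaxImage2DWidth() &&
         height <= gpu_info.GetMaxImage2DHeight();
}
```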
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index ae5d3a75c1b..cd61887fa83 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -19,6 +19,8 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+
 namespace tflite {
 namespace gpu {
 
@@ -208,6 +210,14 @@ struct OpenGlInfo {
   int max_image_units = 0;
   int max_ssbo_bindings = 0;
   int max_image_bindings = 0;
+  int max_work_group_invocations = 0;
+  int max_texture_size = 0;
+  int max_array_texture_layers = 0;
+
+  std::vector<std::string> extensions;
+  int max_compute_work_group_size_x;
+  int max_compute_work_group_size_y;
+  int max_compute_work_group_size_z;
 };
 
 struct VulkanInfo {
@@ -218,6 +228,65 @@ struct VulkanInfo {
   uint32_t api_version_patch = -1;
 
   uint32_t max_per_stage_descriptor_sampled_images = 0;
+  uint32_t max_compute_work_group_invocations;
+  uint32_t max_image_dimension_2d;
+  uint32_t max_image_array_layers;
+
+  std::vector<std::string> extensions;
+  int max_compute_work_group_size_x;
+  int max_compute_work_group_size_y;
+  int max_compute_work_group_size_z;
+};
+
+enum class OpenClVersion {
+  kCl1_0,
+  kCl1_1,
+  kCl1_2,
+  kCl2_0,
+  kCl2_1,
+  kCl2_2,
+  kCl3_0,
+  kUnknown,
+};
+std::string OpenClVersionToString(OpenClVersion version);
+
+struct OpenClInfo {
+  OpenClVersion cl_version;
+
+  std::vector<std::string> extensions;
+  bool supports_fp16;
+  bool supports_image3d_writes;
+  int compute_units_count;
+  uint64_t buffer_max_size;
+  uint64_t image2d_max_width;
+  uint64_t image2d_max_height;
+  uint64_t image_buffer_max_size;
+  uint64_t image_array_max_layers;
+  uint64_t image3d_max_width;
+  uint64_t image3d_max_height;
+  uint64_t image3d_max_depth;
+  int max_work_group_size_x;
+  int max_work_group_size_y;
+  int max_work_group_size_z;
+  int max_work_group_total_size;
+
+  // rtn is ROUND_TO_NEAREST
+  // with rtn, precision is much better than with rtz (ROUND_TO_ZERO)
+  // Adreno 3xx supports only rtz; Adreno 4xx and newer support rtn
+  // Mali from T6xx supports rtn
+  // PowerVR supports only rtz
+  bool supports_fp32_rtn;
+  bool supports_fp16_rtn;
+
+  bool supports_r_f16_tex2d = false;
+  bool supports_rg_f16_tex2d = false;
+  bool supports_rgb_f16_tex2d = false;
+  bool supports_rgba_f16_tex2d = false;
+
+  bool supports_r_f32_tex2d = false;
+  bool supports_rg_f32_tex2d = false;
+  bool supports_rgb_f32_tex2d = false;
+  bool supports_rgba_f32_tex2d = false;
 };
 
 struct GpuInfo {
@@ -232,21 +301,42 @@ struct GpuInfo {
   // floating point rounding mode
   bool IsRoundToNearestSupported() const;
 
+  bool SupportsFP16() const;
+
+  bool SupportsTextureArray() const;
+  bool SupportsImageBuffer() const;
+  bool SupportsImage3D() const;
+
   // returns true if device have fixed wave size equal to 32
   bool IsWaveSizeEqualTo32() const;
+  bool SupportsSubGroupWithSize(int sub_group_size) const;
+
+  bool SupportsFloatImage2D(DataType data_type, int channels) const;
+  bool SupportsExtension(const std::string& extension) const;
 
   int GetComputeUnitsCount() const;
 
   int GetMaxImageArguments() const;
 
+  int GetMaxWorkGroupSizeForX() const;
+  int GetMaxWorkGroupSizeForY() const;
+  int GetMaxWorkGroupSizeForZ() const;
+  int GetMaxWorkGroupTotalSize() const;
+
+  uint64_t GetMaxImage2DWidth() const;
+  uint64_t GetMaxImage2DHeight() const;
+  uint64_t GetMaxImage2DArrayLayers() const;
+  uint64_t GetMaxImage3DWidth() const;
+  uint64_t GetMaxImage3DHeight() const;
+  uint64_t GetMaxImage3DDepth() const;
+  uint64_t GetMaxBufferSize() const;
+  uint64_t GetMaxImageBufferWidth() const;
+
   GpuVendor vendor = GpuVendor::kUnknown;
   GpuApi gpu_api = GpuApi::kUnknown;
 
-  std::vector<std::string> extensions;
+  // Temporary; superseded by the per-API work-group getters above
   std::vector<int> max_work_group_size;
-  int max_work_group_invocations;
-  int max_texture_size = 0;
-  int max_array_texture_layers = 0;
 
   std::vector<int> supported_subgroup_sizes;
 
@@ -265,7 +355,10 @@ struct GpuInfo {
 
   bool IsApiMetal() const;
 
+  OpenClInfo opencl_info;
   bool IsApiOpenCl() const;
+  bool IsCL20OrHigher() const;
+  bool IsCL30OrHigher() const;
 };
 
 inline bool IsOpenGl31OrAbove(const GpuInfo& gpu_info) {
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index ec39dbd5c5c..97eb0751328 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -858,7 +858,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
   absl::Status IsSupported(const TfLiteContext* context,
                            const TfLiteNode* tflite_node,
                            const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
+    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
     const TfLiteFullyConnectedParams* tf_options;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (tf_options->weights_format !=
@@ -870,6 +870,15 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
       return absl::UnimplementedError(
           "FullyConnected doesn't support more than 2 runtime inputs.");
     }
+    if (tf_options->keep_num_dims) {
+      const auto* input = context->tensors + tflite_node->inputs->data[0];
+      const auto* output = context->tensors + tflite_node->outputs->data[0];
+      if (input->dims->size != output->dims->size) {
+        return absl::UnimplementedError(
+            "Input and output dimensions different and FullyConnected doesn't "
+            "support keep_num_dims.");
+      }
+    }
     // TODO(eignasheva): check input shape
     return absl::OkStatus();
   }
@@ -1508,10 +1517,6 @@ class ReduceOperationParser : public TFLiteOperationParser {
       return absl::UnimplementedError(
           "Reduce has unsupported tensor for axes.");
     }
-    if (tflite::NumElements(axes) != 1) {
-      return absl::UnimplementedError(
-          "Supported reduce in single dimensions only.");
-    }
     return absl::OkStatus();
   }
 
@@ -1526,13 +1531,15 @@ class ReduceOperationParser : public TFLiteOperationParser {
     const TfLiteReducerParams* tf_options;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
 
-    Tensor<Scalar, DataType::INT32> axes;
-    RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
-    const TfLiteTensor* input = reader->GetInputTensor(0);
     ReduceAttributes attr;
-    Axis axis;
-    RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[0], &axis));
-    attr.dims = {axis};
+    Tensor<Linear, DataType::INT32> axes;
+    RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
+    const TfLiteTensor* output = reader->GetOutputTensor(0);
+    for (int i = 0; i < axes.data.size(); i++) {
+      Axis axis;
+      RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
+      attr.dims.insert(axis);
+    }
     node->operation.attributes = attr;
     return absl::OkStatus();
   }
@@ -1652,7 +1659,6 @@ class Resize2DOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                        /*runtime_inputs=*/1, /*outputs=*/1));
 
-    RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
     bool align_corners;
     RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
     bool half_pixel_centers;
@@ -1724,27 +1730,6 @@ class Resize2DOperationParser : public TFLiteOperationParser {
     return absl::OkStatus();
   }
 
-  absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context,
-                                              const TfLiteNode* tflite_node) {
-    const auto* input = context->tensors + tflite_node->inputs->data[0];
-    const auto* output = context->tensors + tflite_node->outputs->data[0];
-
-    if (!input->dims || input->dims->size != 4) {
-      return absl::InvalidArgumentError("input.dims.size != 4");
-    }
-    if (!output->dims || output->dims->size != 4) {
-      return absl::InvalidArgumentError("output.dims.size != 4");
-    }
-    if (output->dims->data[1] < input->dims->data[1] ||
-        output->dims->data[2] < input->dims->data[2]) {
-      return absl::InvalidArgumentError(absl::StrCat(
-          "Only upsampling is supported, received output h,w = ",
-          output->dims->data[1], ",", output->dims->data[2],
-          " input h,w = ", input->dims->data[1], ",", input->dims->data[2]));
-    }
-    return absl::OkStatus();
-  }
-
   SamplingType sampling_type_ = SamplingType::UNKNOWN;
 };
 
@@ -2612,18 +2597,10 @@ class MeanOperationParser : public TFLiteOperationParser {
                                        /*runtime_inputs=*/1,
                                        /*outputs=*/1));
 
-    // Simple mechanism to check if MEAN is to be performed only on HW plane.
     auto* axes = &context->tensors[tflite_node->inputs->data[1]];
     if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
       return absl::UnimplementedError("Mean has unsupported tensor for axes");
     }
-    auto* axes_data = axes->data.i32;
-    const bool is_hw_mean = tflite::NumElements(axes) == 2 &&
-                            ((axes_data[0] == 1 && axes_data[1] == 2) ||
-                             (axes_data[0] == 2 && axes_data[1] == 1));
-    if (!is_hw_mean) {
-      return absl::UnimplementedError("Mean operation supports only HW plane");
-    }
     return absl::OkStatus();
   }
 
@@ -2636,27 +2613,13 @@ class MeanOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddOutputs(node));
 
     MeanAttributes attr;
-    Tensor<Linear, DataType::INT32> channel;
-    RETURN_IF_ERROR(reader->ReadTensor(1, &channel));
-    for (int i = 0; i < channel.data.size(); i++) {
-      std::string unsupported;
-      switch (channel.data[i]) {
-        case 1:
-          attr.dims.insert(Axis::HEIGHT);
-          break;
-        case 2:
-          attr.dims.insert(Axis::WIDTH);
-          break;
-        case 0:
-          unsupported = unsupported.empty() ? "batch" : unsupported;
-          ABSL_FALLTHROUGH_INTENDED;
-        case 3:
-          unsupported = unsupported.empty() ? "channels" : unsupported;
-          ABSL_FALLTHROUGH_INTENDED;
-        default:
-          return absl::UnimplementedError(
-              absl::StrCat("Unsupported mean dimension: ", unsupported));
-      }
+    Tensor<Linear, DataType::INT32> axes;
+    RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
+    const TfLiteTensor* output = reader->GetOutputTensor(0);
+    for (int i = 0; i < axes.data.size(); i++) {
+      Axis axis;
+      RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
+      attr.dims.insert(axis);
     }
     node->operation.attributes = attr;
     return absl::OkStatus();
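
Both Reduce and Mean now accept an arbitrary set of axes by mapping each TFLite axis index through `ExtractAxisFromIndex`, instead of special-casing a single axis (Reduce) or the HW plane (Mean). A conceptual, standalone sketch of that mapping for a rank-4 BHWC tensor; the real `Axis` enum and helper live in `common/shape.h` and `model_builder.cc`:

```cpp
// Standalone illustration only: assumes a 4D BHWC tensor; the real helper
// also handles other ranks and layouts.
enum class Axis { BATCH, HEIGHT, WIDTH, CHANNELS, UNKNOWN };

Axis AxisFromBhwcIndex(int index, int rank = 4) {
  if (index < 0) index += rank;  // TFLite permits negative axis indices.
  switch (index) {
    case 0: return Axis::BATCH;
    case 1: return Axis::HEIGHT;
    case 2: return Axis::WIDTH;
    case 3: return Axis::CHANNELS;
    default: return Axis::UNKNOWN;
  }
}

// An axes tensor of {1, 2} (or, equivalently, {-3, -2}) yields
// {HEIGHT, WIDTH} -- the "HW-plane mean" that used to be the only
// accepted case.
```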
diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.h b/tensorflow/lite/delegates/gpu/common/object_reader.h
index 3c7d7f6a859..71f4a529b33 100644
--- a/tensorflow/lite/delegates/gpu/common/object_reader.h
+++ b/tensorflow/lite/delegates/gpu/common/object_reader.h
@@ -71,6 +71,9 @@ class ObjectReader {
     }
 
     const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx;
+    if (tflite_tensor->sparsity != nullptr) {
+      return absl::InvalidArgumentError("Sparsity is not supported on GPU.");
+    }
     t->data.resize(NumElements(tflite_tensor));
     RETURN_IF_ERROR(CreateVectorCopyData(*tflite_tensor, &t->data[0]));
 
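
A hedged sketch of the effect at a call site: with the guard above, reading a sparse constant tensor fails before any data is copied, so the parser can cleanly reject the node and leave it on CPU. The function name and tensor index are assumptions for illustration:

```cpp
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

// Illustrative caller: reads a constant bias tensor. For a sparse tensor
// this now propagates InvalidArgumentError("Sparsity is not supported on
// GPU.") instead of copying garbage data.
absl::Status ReadDenseBias(
    tflite::gpu::ObjectReader* reader,
    tflite::gpu::Tensor<tflite::gpu::Linear,
                        tflite::gpu::DataType::FLOAT32>* bias) {
  return reader->ReadTensor(2, bias);
}
```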
diff --git a/tensorflow/lite/delegates/gpu/common/task/BUILD b/tensorflow/lite/delegates/gpu/common/task/BUILD
index b6800701989..04e5082e1f5 100644
--- a/tensorflow/lite/delegates/gpu/common/task/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/task/BUILD
@@ -50,6 +50,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gpu_operation",
+    srcs = ["gpu_operation.cc"],
+    hdrs = ["gpu_operation.h"],
+    deps = [
+        ":serialization_base_cc_fbs",
+        "//tensorflow/lite/delegates/gpu/common:access_type",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:kernel_info",
+        "//tensorflow/lite/delegates/gpu/common:precision",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common/task:arguments",
+        "//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
+        "//tensorflow/lite/delegates/gpu/common/task:compiler_options",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
+        "//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
+        "//tensorflow/lite/delegates/gpu/common/task:tuning_type",
+        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "gpu_tensor",
     hdrs = ["gpu_tensor.h"],
@@ -116,5 +140,43 @@ cc_library(
     hdrs = ["util.h"],
     deps = [
         ":gpu_object_desc",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:precision",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common:util",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "weights_conversion",
+    hdrs = ["weights_conversion.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common:util",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "weights_layout",
+    hdrs = ["weights_layout.h"],
+)
+
+cc_library(
+    name = "work_group_picking",
+    srcs = ["work_group_picking.cc"],
+    hdrs = ["work_group_picking.h"],
+    deps = [
+        ":tuning_type",
+        "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:kernel_info",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common:workgroup_selection",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
similarity index 82%
rename from tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
rename to tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
index e9d5b626fd2..800576d0ff1 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
@@ -13,22 +13,61 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 #include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
 #include "tensorflow/lite/delegates/gpu/common/access_type.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 namespace {
+std::string GetCommonDefines(CalculationsPrecision precision) {
+  std::string result;
+
+  switch (precision) {
+    case CalculationsPrecision::F32:
+      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+      result += "#define ACCUM_FLT4 float4\n";
+      result += "#define FLT float\n";
+      result += "#define FLT2 float2\n";
+      result += "#define FLT3 float3\n";
+      result += "#define FLT4 float4\n";
+      result += "#define TO_FLT4 convert_float4\n";
+      result += "#define TO_ACCUM_TYPE convert_float4\n";
+      result += "#define TO_ACCUM_FLT convert_float\n";
+      break;
+    case CalculationsPrecision::F16:
+      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+      result += "#define ACCUM_FLT4 half4\n";
+      result += "#define FLT half\n";
+      result += "#define FLT2 half2\n";
+      result += "#define FLT3 half3\n";
+      result += "#define FLT4 half4\n";
+      result += "#define TO_FLT4 convert_half4\n";
+      result += "#define TO_ACCUM_TYPE convert_half4\n";
+      result += "#define TO_ACCUM_FLT convert_half\n";
+      break;
+    case CalculationsPrecision::F32_F16:
+      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+      result += "#define ACCUM_FLT4 float4\n";
+      result += "#define FLT half\n";
+      result += "#define FLT2 half2\n";
+      result += "#define FLT3 half3\n";
+      result += "#define FLT4 half4\n";
+      result += "#define TO_FLT4 convert_half4\n";
+      result += "#define TO_ACCUM_TYPE convert_float4\n";
+      result += "#define TO_ACCUM_FLT convert_float\n";
+      break;
+  }
+  return result;
+}
 
 std::string GetElementWiseCode(const OperationDef& op_def,
                                bool check_src_slices) {
-  std::string c = GetCommonDefines(op_def.precision);
-
+  std::string c;
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
@@ -201,6 +240,7 @@ void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
     elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
     code_ = GetElementWiseCode(definition_, check_src_channels_size_);
   }
+  code_ = GetCommonDefines(definition_.precision) + code_;
 }
 
 void GPUOperation::GetPossibleKernelWorkGroups(
@@ -247,6 +287,5 @@ void GPUOperation::AddUniquePostfix(const std::string& unique_postfix) {
   }
 }
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
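
Two things moved here: `GetCommonDefines` migrated from the CL-specific util into this file, and the precision defines are now prepended once in `AssembleCode` after elementwise fusion rather than inside `GetElementWiseCode`, so fused code can neither duplicate nor miss the preamble. A conceptual mirror of the F32_F16 branch (the real function is file-local to `gpu_operation.cc`):

```cpp
#include <string>

// F32_F16 stores tensor data as half (FLT4 = half4) but accumulates in
// float (ACCUM_FLT4 = float4), trading memory bandwidth for accumulator
// precision. Abridged from the switch above; illustrative only.
std::string F32F16Preamble() {
  return "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"
         "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
         "#define ACCUM_FLT4 float4\n"
         "#define FLT half\n"
         "#define FLT4 half4\n"
         "#define TO_FLT4 convert_half4\n"
         "#define TO_ACCUM_TYPE convert_float4\n";
}
```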
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
similarity index 91%
rename from tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h
rename to tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
index c4d73225a27..f35682dad3a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
@@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OPERATION_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OPERATION_H_
 
 #include <string>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
-#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -29,6 +28,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/task/compiler_options.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
@@ -36,6 +36,8 @@ limitations under the License.
 namespace tflite {
 namespace gpu {
 namespace cl {
+class ClOperation;
+}
 
 // kCustom: default value
 //   GPUOperation::GetGridSize must be overloaded
@@ -142,10 +144,11 @@ class GPUOperation {
   bool check_src_channels_size_ = false;
 
  protected:
-  friend class ClOperation;
-  friend flatbuffers::Offset<data::GPUOperation> Encode(
+  friend class cl::ClOperation;
+  friend flatbuffers::Offset<tflite::gpu::data::GPUOperation> Encode(
       const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder);
-  friend absl::Status Decode(const data::GPUOperation* fb_op, GPUOperation* op);
+  friend absl::Status Decode(const tflite::gpu::data::GPUOperation* fb_op,
+                             GPUOperation* op);
 
   virtual absl::Status BindArguments(ArgumentsBinder* args) {
     return absl::OkStatus();
@@ -168,8 +171,7 @@ class GPUOperation {
   std::string elementwise_code_;  // temporary, used during op construction
 };
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OPERATION_H_
diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs
index 469a7dd9443..8d9434786a2 100644
--- a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs
+++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs
@@ -181,3 +181,56 @@ table Arguments {
   tensor_linear_objects:[TensorLinearDescriptorMapValue];
   tensor_objects:[TensorDescriptorMapValue];
 }
+
+enum CalculationsPrecision : byte {
+  F32 = 0,
+  F32_F16 = 1,
+  F16 = 2,
+}
+
+enum TensorToGrid : byte {
+  CUSTOM = 0,
+  WB_TO_X_HD_TO_Y_S_TO_Z = 1,
+  WB_TO_X_HD_TO_Y_Z_IS_1 = 2,
+  WB_TO_X_H_TO_Y_D_TO_Z = 3,
+  B_TO_X_Y_IS_1_Z_IS_1 = 4,
+}
+
+enum CompilerOptions : byte {
+  ADRENO_FULL_SIMD_LINE = 0,
+  ADRENO_MORE_WAVES = 1,
+  POWERVR_FP16 = 2,
+  CL_OPT_DISABLE = 3,
+  CL_2_0 = 4,
+  CL_3_0 = 5,
+}
+
+table OperationDef {
+  precision:CalculationsPrecision;
+  src_tensors:[TensorDescriptor];
+  dst_tensors:[TensorDescriptor];
+}
+
+table CompilerOption {
+  option:CompilerOptions;
+}
+
+table GPUOperation {
+  arguments:Arguments;
+  code:string;
+  work_group_size:Int3;
+  compiler_options:[CompilerOption];
+  tensor_to_grid:TensorToGrid;
+  elementwise:bool;
+  linkable:bool;
+  check_src_channels_size:bool;
+  definition:OperationDef;
+  grid_dimension:int32;
+  work_group_launch_order:Int3;
+  grid_size:Int3;
+  src_tensors_names:[string];
+  dst_tensors_names:[string];
+  work_groups_count:Int3;
+  linkable_count:int32;
+  elementwise_code:string;
+}
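
These tables replicate, under `tflite::gpu::data`, what the removed `cl::data` generated code provided. A hedged serialization sketch, assuming the generated `GPUOperation` builder matches the removed one field-for-field (the kernel source is a placeholder):

```cpp
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"

// Builds a minimal GPUOperation table; only definition and code are set,
// every other field keeps its schema default.
flatbuffers::DetachedBuffer SerializeMinimalOp() {
  flatbuffers::FlatBufferBuilder fbb;
  // Nested objects must be created before the table builder starts.
  auto def = tflite::gpu::data::CreateOperationDefDirect(
      fbb, tflite::gpu::data::CalculationsPrecision::F16);
  auto code = fbb.CreateString("__kernel void main_function($0) {}");
  tflite::gpu::data::GPUOperationBuilder op(fbb);
  op.add_definition(def);
  op.add_code(code);
  fbb.Finish(op.Finish());
  return fbb.Release();
}
```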
diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
index 7f089f95082..e3c3a5c33df 100644
--- a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
+++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
@@ -78,6 +78,15 @@ struct TensorDescriptorMapValueBuilder;
 struct Arguments;
 struct ArgumentsBuilder;
 
+struct OperationDef;
+struct OperationDefBuilder;
+
+struct CompilerOption;
+struct CompilerOptionBuilder;
+
+struct GPUOperation;
+struct GPUOperationBuilder;
+
 enum class AccessType : int8_t {
   READ = 0,
   WRITE = 1,
@@ -291,6 +300,111 @@ inline const char *EnumNameLayout(Layout e) {
   return EnumNamesLayout()[index];
 }
 
+enum class CalculationsPrecision : int8_t {
+  F32 = 0,
+  F32_F16 = 1,
+  F16 = 2,
+  MIN = F32,
+  MAX = F16
+};
+
+inline const CalculationsPrecision (&EnumValuesCalculationsPrecision())[3] {
+  static const CalculationsPrecision values[] = {CalculationsPrecision::F32,
+                                                 CalculationsPrecision::F32_F16,
+                                                 CalculationsPrecision::F16};
+  return values;
+}
+
+inline const char *const *EnumNamesCalculationsPrecision() {
+  static const char *const names[4] = {"F32", "F32_F16", "F16", nullptr};
+  return names;
+}
+
+inline const char *EnumNameCalculationsPrecision(CalculationsPrecision e) {
+  if (flatbuffers::IsOutRange(e, CalculationsPrecision::F32,
+                              CalculationsPrecision::F16))
+    return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesCalculationsPrecision()[index];
+}
+
+enum class TensorToGrid : int8_t {
+  CUSTOM = 0,
+  WB_TO_X_HD_TO_Y_S_TO_Z = 1,
+  WB_TO_X_HD_TO_Y_Z_IS_1 = 2,
+  WB_TO_X_H_TO_Y_D_TO_Z = 3,
+  B_TO_X_Y_IS_1_Z_IS_1 = 4,
+  MIN = CUSTOM,
+  MAX = B_TO_X_Y_IS_1_Z_IS_1
+};
+
+inline const TensorToGrid (&EnumValuesTensorToGrid())[5] {
+  static const TensorToGrid values[] = {
+      TensorToGrid::CUSTOM, TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z,
+      TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1, TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z,
+      TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1};
+  return values;
+}
+
+inline const char *const *EnumNamesTensorToGrid() {
+  static const char *const names[6] = {"CUSTOM",
+                                       "WB_TO_X_HD_TO_Y_S_TO_Z",
+                                       "WB_TO_X_HD_TO_Y_Z_IS_1",
+                                       "WB_TO_X_H_TO_Y_D_TO_Z",
+                                       "B_TO_X_Y_IS_1_Z_IS_1",
+                                       nullptr};
+  return names;
+}
+
+inline const char *EnumNameTensorToGrid(TensorToGrid e) {
+  if (flatbuffers::IsOutRange(e, TensorToGrid::CUSTOM,
+                              TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1))
+    return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesTensorToGrid()[index];
+}
+
+enum class CompilerOptions : int8_t {
+  ADRENO_FULL_SIMD_LINE = 0,
+  ADRENO_MORE_WAVES = 1,
+  POWERVR_FP16 = 2,
+  CL_OPT_DISABLE = 3,
+  CL_2_0 = 4,
+  CL_3_0 = 5,
+  MIN = ADRENO_FULL_SIMD_LINE,
+  MAX = CL_3_0
+};
+
+inline const CompilerOptions (&EnumValuesCompilerOptions())[6] {
+  static const CompilerOptions values[] = {
+      CompilerOptions::ADRENO_FULL_SIMD_LINE,
+      CompilerOptions::ADRENO_MORE_WAVES,
+      CompilerOptions::POWERVR_FP16,
+      CompilerOptions::CL_OPT_DISABLE,
+      CompilerOptions::CL_2_0,
+      CompilerOptions::CL_3_0};
+  return values;
+}
+
+inline const char *const *EnumNamesCompilerOptions() {
+  static const char *const names[7] = {"ADRENO_FULL_SIMD_LINE",
+                                       "ADRENO_MORE_WAVES",
+                                       "POWERVR_FP16",
+                                       "CL_OPT_DISABLE",
+                                       "CL_2_0",
+                                       "CL_3_0",
+                                       nullptr};
+  return names;
+}
+
+inline const char *EnumNameCompilerOptions(CompilerOptions e) {
+  if (flatbuffers::IsOutRange(e, CompilerOptions::ADRENO_FULL_SIMD_LINE,
+                              CompilerOptions::CL_3_0))
+    return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesCompilerOptions()[index];
+}
+
 struct Int4 FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   typedef Int4Builder Builder;
   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@@ -1832,6 +1946,454 @@ inline flatbuffers::Offset<Arguments> CreateArgumentsDirect(
       tensor_objects__);
 }
 
+struct OperationDef FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef OperationDefBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_PRECISION = 4,
+    VT_SRC_TENSORS = 6,
+    VT_DST_TENSORS = 8
+  };
+  tflite::gpu::data::CalculationsPrecision precision() const {
+    return static_cast<tflite::gpu::data::CalculationsPrecision>(
+        GetField<int8_t>(VT_PRECISION, 0));
+  }
+  const flatbuffers::Vector<
+      flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
+      *src_tensors() const {
+    return GetPointer<const flatbuffers::Vector<
+        flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>(
+        VT_SRC_TENSORS);
+  }
+  const flatbuffers::Vector<
+      flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
+      *dst_tensors() const {
+    return GetPointer<const flatbuffers::Vector<
+        flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>(
+        VT_DST_TENSORS);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int8_t>(verifier, VT_PRECISION) &&
+           VerifyOffset(verifier, VT_SRC_TENSORS) &&
+           verifier.VerifyVector(src_tensors()) &&
+           verifier.VerifyVectorOfTables(src_tensors()) &&
+           VerifyOffset(verifier, VT_DST_TENSORS) &&
+           verifier.VerifyVector(dst_tensors()) &&
+           verifier.VerifyVectorOfTables(dst_tensors()) && verifier.EndTable();
+  }
+};
+
+struct OperationDefBuilder {
+  typedef OperationDef Table;
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_precision(tflite::gpu::data::CalculationsPrecision precision) {
+    fbb_.AddElement<int8_t>(OperationDef::VT_PRECISION,
+                            static_cast<int8_t>(precision), 0);
+  }
+  void add_src_tensors(
+      flatbuffers::Offset<flatbuffers::Vector<
+          flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>>
+          src_tensors) {
+    fbb_.AddOffset(OperationDef::VT_SRC_TENSORS, src_tensors);
+  }
+  void add_dst_tensors(
+      flatbuffers::Offset<flatbuffers::Vector<
+          flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>>
+          dst_tensors) {
+    fbb_.AddOffset(OperationDef::VT_DST_TENSORS, dst_tensors);
+  }
+  explicit OperationDefBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  flatbuffers::Offset<OperationDef> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<OperationDef>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<OperationDef> CreateOperationDef(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::gpu::data::CalculationsPrecision precision =
+        tflite::gpu::data::CalculationsPrecision::F32,
+    flatbuffers::Offset<flatbuffers::Vector<
+        flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>>
+        src_tensors = 0,
+    flatbuffers::Offset<flatbuffers::Vector<
+        flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>>
+        dst_tensors = 0) {
+  OperationDefBuilder builder_(_fbb);
+  builder_.add_dst_tensors(dst_tensors);
+  builder_.add_src_tensors(src_tensors);
+  builder_.add_precision(precision);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<OperationDef> CreateOperationDefDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::gpu::data::CalculationsPrecision precision =
+        tflite::gpu::data::CalculationsPrecision::F32,
+    const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
+        *src_tensors = nullptr,
+    const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
+        *dst_tensors = nullptr) {
+  auto src_tensors__ =
+      src_tensors
+          ? _fbb.CreateVector<
+                flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>(
+                *src_tensors)
+          : 0;
+  auto dst_tensors__ =
+      dst_tensors
+          ? _fbb.CreateVector<
+                flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>(
+                *dst_tensors)
+          : 0;
+  return tflite::gpu::data::CreateOperationDef(_fbb, precision, src_tensors__,
+                                               dst_tensors__);
+}
+
+struct CompilerOption FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef CompilerOptionBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_OPTION = 4
+  };
+  tflite::gpu::data::CompilerOptions option() const {
+    return static_cast<tflite::gpu::data::CompilerOptions>(
+        GetField<int8_t>(VT_OPTION, 0));
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int8_t>(verifier, VT_OPTION) && verifier.EndTable();
+  }
+};
+
+struct CompilerOptionBuilder {
+  typedef CompilerOption Table;
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_option(tflite::gpu::data::CompilerOptions option) {
+    fbb_.AddElement<int8_t>(CompilerOption::VT_OPTION,
+                            static_cast<int8_t>(option), 0);
+  }
+  explicit CompilerOptionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  flatbuffers::Offset<CompilerOption> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<CompilerOption>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<CompilerOption> CreateCompilerOption(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    tflite::gpu::data::CompilerOptions option =
+        tflite::gpu::data::CompilerOptions::ADRENO_FULL_SIMD_LINE) {
+  CompilerOptionBuilder builder_(_fbb);
+  builder_.add_option(option);
+  return builder_.Finish();
+}
+
+struct GPUOperation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef GPUOperationBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_ARGUMENTS = 4,
+    VT_CODE = 6,
+    VT_WORK_GROUP_SIZE = 8,
+    VT_COMPILER_OPTIONS = 10,
+    VT_TENSOR_TO_GRID = 12,
+    VT_ELEMENTWISE = 14,
+    VT_LINKABLE = 16,
+    VT_CHECK_SRC_CHANNELS_SIZE = 18,
+    VT_DEFINITION = 20,
+    VT_GRID_DIMENSION = 22,
+    VT_WORK_GROUP_LAUNCH_ORDER = 24,
+    VT_GRID_SIZE = 26,
+    VT_SRC_TENSORS_NAMES = 28,
+    VT_DST_TENSORS_NAMES = 30,
+    VT_WORK_GROUPS_COUNT = 32,
+    VT_LINKABLE_COUNT = 34,
+    VT_ELEMENTWISE_CODE = 36
+  };
+  const tflite::gpu::data::Arguments *arguments() const {
+    return GetPointer<const tflite::gpu::data::Arguments *>(VT_ARGUMENTS);
+  }
+  const flatbuffers::String *code() const {
+    return GetPointer<const flatbuffers::String *>(VT_CODE);
+  }
+  const tflite::gpu::data::Int3 *work_group_size() const {
+    return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUP_SIZE);
+  }
+  const flatbuffers::Vector<
+      flatbuffers::Offset<tflite::gpu::data::CompilerOption>>
+      *compiler_options() const {
+    return GetPointer<const flatbuffers::Vector<
+        flatbuffers::Offset<tflite::gpu::data::CompilerOption>> *>(
+        VT_COMPILER_OPTIONS);
+  }
+  tflite::gpu::data::TensorToGrid tensor_to_grid() const {
+    return static_cast<tflite::gpu::data::TensorToGrid>(
+        GetField<int8_t>(VT_TENSOR_TO_GRID, 0));
+  }
+  bool elementwise() const { return GetField<uint8_t>(VT_ELEMENTWISE, 0) != 0; }
+  bool linkable() const { return GetField<uint8_t>(VT_LINKABLE, 0) != 0; }
+  bool check_src_channels_size() const {
+    return GetField<uint8_t>(VT_CHECK_SRC_CHANNELS_SIZE, 0) != 0;
+  }
+  const tflite::gpu::data::OperationDef *definition() const {
+    return GetPointer<const tflite::gpu::data::OperationDef *>(VT_DEFINITION);
+  }
+  int32_t grid_dimension() const {
+    return GetField<int32_t>(VT_GRID_DIMENSION, 0);
+  }
+  const tflite::gpu::data::Int3 *work_group_launch_order() const {
+    return GetPointer<const tflite::gpu::data::Int3 *>(
+        VT_WORK_GROUP_LAUNCH_ORDER);
+  }
+  const tflite::gpu::data::Int3 *grid_size() const {
+    return GetPointer<const tflite::gpu::data::Int3 *>(VT_GRID_SIZE);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>
+      *src_tensors_names() const {
+    return GetPointer<
+        const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(
+        VT_SRC_TENSORS_NAMES);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>
+      *dst_tensors_names() const {
+    return GetPointer<
+        const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(
+        VT_DST_TENSORS_NAMES);
+  }
+  const tflite::gpu::data::Int3 *work_groups_count() const {
+    return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUPS_COUNT);
+  }
+  int32_t linkable_count() const {
+    return GetField<int32_t>(VT_LINKABLE_COUNT, 0);
+  }
+  const flatbuffers::String *elementwise_code() const {
+    return GetPointer<const flatbuffers::String *>(VT_ELEMENTWISE_CODE);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ARGUMENTS) &&
+           verifier.VerifyTable(arguments()) &&
+           VerifyOffset(verifier, VT_CODE) && verifier.VerifyString(code()) &&
+           VerifyOffset(verifier, VT_WORK_GROUP_SIZE) &&
+           verifier.VerifyTable(work_group_size()) &&
+           VerifyOffset(verifier, VT_COMPILER_OPTIONS) &&
+           verifier.VerifyVector(compiler_options()) &&
+           verifier.VerifyVectorOfTables(compiler_options()) &&
+           VerifyField<int8_t>(verifier, VT_TENSOR_TO_GRID) &&
+           VerifyField<uint8_t>(verifier, VT_ELEMENTWISE) &&
+           VerifyField<uint8_t>(verifier, VT_LINKABLE) &&
+           VerifyField<uint8_t>(verifier, VT_CHECK_SRC_CHANNELS_SIZE) &&
+           VerifyOffset(verifier, VT_DEFINITION) &&
+           verifier.VerifyTable(definition()) &&
+           VerifyField<int32_t>(verifier, VT_GRID_DIMENSION) &&
+           VerifyOffset(verifier, VT_WORK_GROUP_LAUNCH_ORDER) &&
+           verifier.VerifyTable(work_group_launch_order()) &&
+           VerifyOffset(verifier, VT_GRID_SIZE) &&
+           verifier.VerifyTable(grid_size()) &&
+           VerifyOffset(verifier, VT_SRC_TENSORS_NAMES) &&
+           verifier.VerifyVector(src_tensors_names()) &&
+           verifier.VerifyVectorOfStrings(src_tensors_names()) &&
+           VerifyOffset(verifier, VT_DST_TENSORS_NAMES) &&
+           verifier.VerifyVector(dst_tensors_names()) &&
+           verifier.VerifyVectorOfStrings(dst_tensors_names()) &&
+           VerifyOffset(verifier, VT_WORK_GROUPS_COUNT) &&
+           verifier.VerifyTable(work_groups_count()) &&
+           VerifyField<int32_t>(verifier, VT_LINKABLE_COUNT) &&
+           VerifyOffset(verifier, VT_ELEMENTWISE_CODE) &&
+           verifier.VerifyString(elementwise_code()) && verifier.EndTable();
+  }
+};
+
+struct GPUOperationBuilder {
+  typedef GPUOperation Table;
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_arguments(
+      flatbuffers::Offset<tflite::gpu::data::Arguments> arguments) {
+    fbb_.AddOffset(GPUOperation::VT_ARGUMENTS, arguments);
+  }
+  void add_code(flatbuffers::Offset<flatbuffers::String> code) {
+    fbb_.AddOffset(GPUOperation::VT_CODE, code);
+  }
+  void add_work_group_size(
+      flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size) {
+    fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_SIZE, work_group_size);
+  }
+  void add_compiler_options(
+      flatbuffers::Offset<flatbuffers::Vector<
+          flatbuffers::Offset<tflite::gpu::data::CompilerOption>>>
+          compiler_options) {
+    fbb_.AddOffset(GPUOperation::VT_COMPILER_OPTIONS, compiler_options);
+  }
+  void add_tensor_to_grid(tflite::gpu::data::TensorToGrid tensor_to_grid) {
+    fbb_.AddElement<int8_t>(GPUOperation::VT_TENSOR_TO_GRID,
+                            static_cast<int8_t>(tensor_to_grid), 0);
+  }
+  void add_elementwise(bool elementwise) {
+    fbb_.AddElement<uint8_t>(GPUOperation::VT_ELEMENTWISE,
+                             static_cast<uint8_t>(elementwise), 0);
+  }
+  void add_linkable(bool linkable) {
+    fbb_.AddElement<uint8_t>(GPUOperation::VT_LINKABLE,
+                             static_cast<uint8_t>(linkable), 0);
+  }
+  void add_check_src_channels_size(bool check_src_channels_size) {
+    fbb_.AddElement<uint8_t>(GPUOperation::VT_CHECK_SRC_CHANNELS_SIZE,
+                             static_cast<uint8_t>(check_src_channels_size), 0);
+  }
+  void add_definition(
+      flatbuffers::Offset<tflite::gpu::data::OperationDef> definition) {
+    fbb_.AddOffset(GPUOperation::VT_DEFINITION, definition);
+  }
+  void add_grid_dimension(int32_t grid_dimension) {
+    fbb_.AddElement<int32_t>(GPUOperation::VT_GRID_DIMENSION, grid_dimension,
+                             0);
+  }
+  void add_work_group_launch_order(
+      flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order) {
+    fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_LAUNCH_ORDER,
+                   work_group_launch_order);
+  }
+  void add_grid_size(flatbuffers::Offset<tflite::gpu::data::Int3> grid_size) {
+    fbb_.AddOffset(GPUOperation::VT_GRID_SIZE, grid_size);
+  }
+  void add_src_tensors_names(
+      flatbuffers::Offset<
+          flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+          src_tensors_names) {
+    fbb_.AddOffset(GPUOperation::VT_SRC_TENSORS_NAMES, src_tensors_names);
+  }
+  void add_dst_tensors_names(
+      flatbuffers::Offset<
+          flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+          dst_tensors_names) {
+    fbb_.AddOffset(GPUOperation::VT_DST_TENSORS_NAMES, dst_tensors_names);
+  }
+  void add_work_groups_count(
+      flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count) {
+    fbb_.AddOffset(GPUOperation::VT_WORK_GROUPS_COUNT, work_groups_count);
+  }
+  void add_linkable_count(int32_t linkable_count) {
+    fbb_.AddElement<int32_t>(GPUOperation::VT_LINKABLE_COUNT, linkable_count,
+                             0);
+  }
+  void add_elementwise_code(
+      flatbuffers::Offset<flatbuffers::String> elementwise_code) {
+    fbb_.AddOffset(GPUOperation::VT_ELEMENTWISE_CODE, elementwise_code);
+  }
+  explicit GPUOperationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  flatbuffers::Offset<GPUOperation> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<GPUOperation>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<GPUOperation> CreateGPUOperation(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0,
+    flatbuffers::Offset<flatbuffers::String> code = 0,
+    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0,
+    flatbuffers::Offset<flatbuffers::Vector<
+        flatbuffers::Offset<tflite::gpu::data::CompilerOption>>>
+        compiler_options = 0,
+    tflite::gpu::data::TensorToGrid tensor_to_grid =
+        tflite::gpu::data::TensorToGrid::CUSTOM,
+    bool elementwise = false, bool linkable = false,
+    bool check_src_channels_size = false,
+    flatbuffers::Offset<tflite::gpu::data::OperationDef> definition = 0,
+    int32_t grid_dimension = 0,
+    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0,
+    flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0,
+    flatbuffers::Offset<
+        flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+        src_tensors_names = 0,
+    flatbuffers::Offset<
+        flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+        dst_tensors_names = 0,
+    flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0,
+    int32_t linkable_count = 0,
+    flatbuffers::Offset<flatbuffers::String> elementwise_code = 0) {
+  GPUOperationBuilder builder_(_fbb);
+  builder_.add_elementwise_code(elementwise_code);
+  builder_.add_linkable_count(linkable_count);
+  builder_.add_work_groups_count(work_groups_count);
+  builder_.add_dst_tensors_names(dst_tensors_names);
+  builder_.add_src_tensors_names(src_tensors_names);
+  builder_.add_grid_size(grid_size);
+  builder_.add_work_group_launch_order(work_group_launch_order);
+  builder_.add_grid_dimension(grid_dimension);
+  builder_.add_definition(definition);
+  builder_.add_compiler_options(compiler_options);
+  builder_.add_work_group_size(work_group_size);
+  builder_.add_code(code);
+  builder_.add_arguments(arguments);
+  builder_.add_check_src_channels_size(check_src_channels_size);
+  builder_.add_linkable(linkable);
+  builder_.add_elementwise(elementwise);
+  builder_.add_tensor_to_grid(tensor_to_grid);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<GPUOperation> CreateGPUOperationDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0,
+    const char *code = nullptr,
+    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0,
+    const std::vector<flatbuffers::Offset<tflite::gpu::data::CompilerOption>>
+        *compiler_options = nullptr,
+    tflite::gpu::data::TensorToGrid tensor_to_grid =
+        tflite::gpu::data::TensorToGrid::CUSTOM,
+    bool elementwise = false, bool linkable = false,
+    bool check_src_channels_size = false,
+    flatbuffers::Offset<tflite::gpu::data::OperationDef> definition = 0,
+    int32_t grid_dimension = 0,
+    flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0,
+    flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>>
+        *src_tensors_names = nullptr,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>>
+        *dst_tensors_names = nullptr,
+    flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0,
+    int32_t linkable_count = 0, const char *elementwise_code = nullptr) {
+  auto code__ = code ? _fbb.CreateString(code) : 0;
+  auto compiler_options__ =
+      compiler_options
+          ? _fbb.CreateVector<
+                flatbuffers::Offset<tflite::gpu::data::CompilerOption>>(
+                *compiler_options)
+          : 0;
+  auto src_tensors_names__ =
+      src_tensors_names
+          ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(
+                *src_tensors_names)
+          : 0;
+  auto dst_tensors_names__ =
+      dst_tensors_names
+          ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(
+                *dst_tensors_names)
+          : 0;
+  auto elementwise_code__ =
+      elementwise_code ? _fbb.CreateString(elementwise_code) : 0;
+  return tflite::gpu::data::CreateGPUOperation(
+      _fbb, arguments, code__, work_group_size, compiler_options__,
+      tensor_to_grid, elementwise, linkable, check_src_channels_size,
+      definition, grid_dimension, work_group_launch_order, grid_size,
+      src_tensors_names__, dst_tensors_names__, work_groups_count,
+      linkable_count, elementwise_code__);
+}
+
 }  // namespace data
 }  // namespace gpu
 }  // namespace tflite
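
As a quick illustration of the generated builders above, here is a hedged sketch (not part of the patch) that serializes a minimal `GPUOperation` table; the generated header path, kernel source string, and tensor names are assumptions for illustration only.

```c++
#include <vector>

#include "flatbuffers/flatbuffers.h"
// Assumed path of the header generated from the schema above.
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"

flatbuffers::DetachedBuffer SerializeMinimalGPUOperation() {
  flatbuffers::FlatBufferBuilder fbb;
  // Names are illustrative; real operations use the names registered in
  // their Arguments table.
  std::vector<flatbuffers::Offset<flatbuffers::String>> src_names = {
      fbb.CreateString("src_tensor_0")};
  std::vector<flatbuffers::Offset<flatbuffers::String>> dst_names = {
      fbb.CreateString("dst_tensor_0")};
  auto op = tflite::gpu::data::CreateGPUOperationDirect(
      fbb, /*arguments=*/0, /*code=*/"__kernel void main_function() {}",
      /*work_group_size=*/0, /*compiler_options=*/nullptr,
      tflite::gpu::data::TensorToGrid::CUSTOM,
      /*elementwise=*/false, /*linkable=*/false,
      /*check_src_channels_size=*/false, /*definition=*/0,
      /*grid_dimension=*/3, /*work_group_launch_order=*/0, /*grid_size=*/0,
      &src_names, &dst_names);
  fbb.Finish(op);
  return fbb.Release();
}
```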
diff --git a/tensorflow/lite/delegates/gpu/common/task/util.cc b/tensorflow/lite/delegates/gpu/common/task/util.cc
index 4d2892f7e54..137815b9a90 100644
--- a/tensorflow/lite/delegates/gpu/common/task/util.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/util.cc
@@ -15,6 +15,11 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
 
+#include <cfloat>
+
+#include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+
 namespace tflite {
 namespace gpu {
 
@@ -43,5 +48,118 @@ std::string MemoryTypeToMetalType(MemoryType type) {
   return "";
 }
 
+std::string GetXStrideCorrected(const std::string& src_x,
+                                const std::string& batch_size,
+                                const std::string& stride_x,
+                                const std::string& padding_x) {
+  // int p0 = src_x / batch_size;
+  // int b0 = src_x % batch_size;
+  // return p0 * stride_x * batch_size + b0 + padding_x;
+  return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x,
+                          batch_size, stride_x, padding_x);
+}
+
+std::string GetXStrideCorrectedV2(const std::string& src_x,
+                                  const std::string& batch_size,
+                                  const std::string& stride_x,
+                                  const std::string& padding_x) {
+  // int p0 = src_x / batch_size;
+  // int b0 = src_x % batch_size;
+  // return (p0 * stride_x + padding_x) * batch_size + b0;
+  return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x,
+                          batch_size, stride_x, padding_x);
+}
+
+float4 GetMaskForLastPlane(int channels) {
+  float4 mask = float4(0.0f);
+  const int remainder = channels % 4 == 0 ? 4 : channels % 4;
+  for (int i = 0; i < remainder; ++i) {
+    mask[i] = 1.0f;
+  }
+  return mask;
+}
+
+int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
+                                   CalculationsPrecision precision,
+                                   int task_size) {
+  const float task_size_per_cu =
+      task_size / static_cast<float>(gpu_info.GetComputeUnitsCount());
+  int block_size = 1;
+  float threshold_1 = FLT_MAX;
+  float threshold_2 = FLT_MAX;
+  float threshold_4 = FLT_MAX;
+  if (!gpu_info.IsMali()) {
+    return 1;
+  }
+  MaliInfo mali_info = gpu_info.mali_info;
+  switch (precision) {
+    case CalculationsPrecision::F16:
+      if (mali_info.IsBifrostGen1()) {
+        threshold_1 = 256.0f;
+        threshold_2 = 256.0f * 4.0f;
+        threshold_4 = 256.0f * 8.0f;
+      } else if (mali_info.IsBifrostGen2()) {
+        threshold_1 = 256.0f * 2.0f;
+        threshold_2 = 256.0f * 8.0f;
+        threshold_4 = 256.0f * 16.0f;
+      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
+        threshold_1 = 256.0f;
+        threshold_2 = 256.0f * 6.0f;
+        threshold_4 = 256.0f * 16.0f;
+      } else if (mali_info.IsMidgard()) {
+        threshold_1 = 256.0f * 4.0f;
+        threshold_2 = 256.0f * 16.0f;
+      }
+      break;
+    case CalculationsPrecision::F32_F16:
+      if (mali_info.IsBifrostGen1()) {
+        threshold_1 = 256.0f;
+        threshold_2 = 256.0f * 3.0f;
+        threshold_4 = 256.0f * 32.0f;
+      } else if (mali_info.IsBifrostGen2()) {
+        threshold_1 = 256.0f * 2.0f;
+        threshold_2 = 256.0f * 8.0f;
+      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
+        threshold_1 = 256.0f;
+        threshold_2 = 256.0f * 8.0f;
+      } else if (mali_info.IsMidgard()) {
+        threshold_1 = 256.0f * 4.0f;
+      }
+      break;
+    case CalculationsPrecision::F32:
+      if (mali_info.IsBifrostGen1()) {
+        threshold_1 = 256.0f;
+        threshold_2 = 256.0f * 4.0f;
+      } else if (mali_info.IsBifrostGen2()) {
+        threshold_1 = 128.0f;
+        threshold_2 = 256.0f * 4.0f;
+      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
+        threshold_1 = 256.0f;
+        threshold_2 = 256.0f * 12.0f;
+      } else if (mali_info.IsMidgard()) {
+        threshold_1 = 256.0f * 16.0f;
+      }
+      break;
+  }
+  if (task_size_per_cu <= threshold_1) {
+    block_size = 1;
+  } else if (task_size_per_cu <= threshold_2) {
+    block_size = 2;
+  } else if (task_size_per_cu <= threshold_4) {
+    block_size = 4;
+  } else {
+    block_size = 8;
+  }
+  return block_size;
+}
+
+int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) {
+  int3 work_groups_count;
+  work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
+  work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y);
+  work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z);
+  return work_groups_count;
+}
+
 }  // namespace gpu
 }  // namespace tflite
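
To make the relocated helpers concrete, here is a hedged sketch of their expected behavior (not a test from the patch); it assumes the usual `int3`/`float4` types from `common/types.h`, and the values are worked examples.

```c++
#include <cassert>
#include <string>

#include "tensorflow/lite/delegates/gpu/common/task/util.h"

void TaskUtilExamples() {
  // Substituting src_x="x", batch_size="B", stride_x="S", padding_x="P"
  // into the V1 pattern yields the expression below.
  const std::string expr =
      tflite::gpu::GetXStrideCorrected("x", "B", "S", "P");
  assert(expr == "(((x) / B) * S * B + ((x) % B) + P)");

  // 7 channels fill one full plane of 4 plus 3 channels of the last plane,
  // so the last plane is masked as (1, 1, 1, 0).
  tflite::gpu::float4 mask = tflite::gpu::GetMaskForLastPlane(7);
  assert(mask[0] == 1.0f && mask[3] == 0.0f);

  // Work-group counts round the grid up per dimension:
  // ceil(100/8) = 13, ceil(57/4) = 15, ceil(8/1) = 8.
  tflite::gpu::int3 counts = tflite::gpu::GetWorkGroupsCount(
      tflite::gpu::int3(100, 57, 8), tflite::gpu::int3(8, 4, 1));
  assert(counts.x == 13 && counts.y == 15 && counts.z == 8);
}
```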
diff --git a/tensorflow/lite/delegates/gpu/common/task/util.h b/tensorflow/lite/delegates/gpu/common/task/util.h
index 3fb2c5abe84..135126b7d01 100644
--- a/tensorflow/lite/delegates/gpu/common/task/util.h
+++ b/tensorflow/lite/delegates/gpu/common/task/util.h
@@ -17,8 +17,12 @@ limitations under the License.
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_UTIL_H_
 
 #include <string>
+#include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
 namespace gpu {
@@ -27,6 +31,36 @@ std::string MemoryTypeToCLType(MemoryType type);
 
 std::string MemoryTypeToMetalType(MemoryType type);
 
+// Calculates the correct X coordinate when stride != 1 and batch != 1 for
+// layouts with B after W (for example, HWBC4) where W and B are stored in one
+// axis of GPU resources.
+std::string GetXStrideCorrected(const std::string& src_x,
+                                const std::string& batch_size,
+                                const std::string& stride_x,
+                                const std::string& padding_x);
+
+// Calculates the correct X coordinate when stride != 1 and batch != 1 for
+// layouts with B after W (for example, HWBC4) where W and B are stored in one
+// axis of GPU resources.
+std::string GetXStrideCorrectedV2(const std::string& src_x,
+                                  const std::string& batch_size,
+                                  const std::string& stride_x,
+                                  const std::string& padding_x);
+
+// Returns a float4 mask for the last plane (group of 4 channels); assumes the
+// plane size is 4. For example, with 7 channels the data is aligned to 8
+// channels, but the 8th channel is empty, so the last plane gets the mask
+// (1, 1, 1, 0).
+float4 GetMaskForLastPlane(int channels);
+
+// task_size is the number of FLT4 elements to process.
+int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
+                                   CalculationsPrecision precision,
+                                   int task_size);
+
+int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size);
+
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.h b/tensorflow/lite/delegates/gpu/common/task/weights_conversion.h
similarity index 75%
rename from tensorflow/lite/delegates/gpu/cl/kernels/util.h
rename to tensorflow/lite/delegates/gpu/common/task/weights_conversion.h
index 519f1f117b2..adbd8a107cd 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/util.h
+++ b/tensorflow/lite/delegates/gpu/common/task/weights_conversion.h
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,16 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_CONVERSION_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_CONVERSION_H_
 
 #include <string>
 #include <vector>
 
 #include "absl/types/span.h"
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
-#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
@@ -30,25 +28,6 @@ limitations under the License.
 
 namespace tflite {
 namespace gpu {
-namespace cl {
-
-std::string GetCommonDefines(CalculationsPrecision precision);
-
-// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts
-// with B after W (for example HWBC4) and WB stored in one axis of GPU
-// resources.
-std::string GetXStrideCorrected(const std::string& src_x,
-                                const std::string& batch_size,
-                                const std::string& stride_x,
-                                const std::string& padding_x);
-
-// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts
-// with B after W (for example HWBC4) and WB stored in one axis of GPU
-// resources.
-std::string GetXStrideCorrectedV2(const std::string& src_x,
-                                  const std::string& batch_size,
-                                  const std::string& stride_x,
-                                  const std::string& padding_x);
 
 template <DataType S, typename T>
 void RearrangeWeightsToOHWIOGroupI4O4(
@@ -198,25 +177,43 @@ void RearrangeWeightsToI4DHWIOOGroupO4(
   }
 }
 
-// Returns float4 mask for last plane(batch of 4 channels)
-// assumes that plane size is 4;
-// for example we have 7 channels, in our data structures we align it to 8
-// but 8s-channel will be empty, then last plane (batch of 4 channels) will
-// have this mask (1, 1, 1, 0).
-float4 GetMaskForLastPlane(int channels);
+template <DataType S, typename T>
+void RearrangeWeightsToOICustomSpatialI4O4(
+    const tflite::gpu::Tensor<OHWI, S>& weights,
+    const std::vector<int>& spatial_remap, absl::Span<T> dst) {
+  const int dst_slices = DivideRoundUp(weights.shape.o, 4);
+  const int src_slices = DivideRoundUp(weights.shape.i, 4);
 
-// returns first work group from wgs that has size not bigger than max_wg_size
-// if no suitable groups among wgs, returns {1, 1, 1}
-int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size);
+  int counter = 0;
+  for (int d = 0; d < dst_slices; ++d) {
+    for (int s = 0; s < src_slices; ++s) {
+      for (int y = 0; y < weights.shape.h; ++y) {
+        for (int x = 0; x < weights.shape.w; ++x) {
+          const int kernel_index = spatial_remap[y * weights.shape.w + x];
+          const int kernel_index_x = kernel_index % weights.shape.w;
+          const int kernel_index_y = kernel_index / weights.shape.w;
+          for (int i = 0; i < 4; ++i) {
+            T filter;
+            for (int j = 0; j < 4; ++j) {
+              const int s_ch = s * 4 + i;
+              const int d_ch = d * 4 + j;
+              if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
+                const int f_index = weights.shape.LinearIndex(
+                    {d_ch, kernel_index_y, kernel_index_x, s_ch});
+                filter[j] = weights.data[f_index];
+              } else {
+                filter[j] = 0.0f;
+              }
+            }
+            dst[counter++] = filter;
+          }
+        }
+      }
+    }
+  }
+}
 
-// task_size as amount of FLT4 processed elements.
-int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
-                                   CalculationsPrecision precision,
-                                   int task_size);
-
-int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size);
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_CONVERSION_H_
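
For reference, a hedged sketch of driving the new `RearrangeWeightsToOICustomSpatialI4O4` with an identity spatial remap; the destination sizing follows the loop structure above, and the identity remap is purely illustrative (real kernels pass a layout-specific permutation).

```c++
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"

void RepackWithIdentityRemap(
    const tflite::gpu::Tensor<tflite::gpu::OHWI,
                              tflite::gpu::DataType::FLOAT32>& weights) {
  const int dst_slices = tflite::gpu::DivideRoundUp(weights.shape.o, 4);
  const int src_slices = tflite::gpu::DivideRoundUp(weights.shape.i, 4);
  // One float4 per i-lane, per spatial position, per (dst, src) slice pair.
  std::vector<tflite::gpu::float4> dst(dst_slices * src_slices *
                                       weights.shape.h * weights.shape.w * 4);

  // Identity remap keeps plain HW order; a real caller supplies the
  // kernel-defined order via WeightsDescription::spatial_remap.
  std::vector<int> remap(weights.shape.h * weights.shape.w);
  for (size_t i = 0; i < remap.size(); ++i) remap[i] = static_cast<int>(i);

  tflite::gpu::RearrangeWeightsToOICustomSpatialI4O4(weights, remap,
                                                     absl::MakeSpan(dst));
}
```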
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h b/tensorflow/lite/delegates/gpu/common/task/weights_layout.h
similarity index 64%
rename from tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h
rename to tensorflow/lite/delegates/gpu/common/task/weights_layout.h
index f630c9d1f1c..cf22630ee27 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h
+++ b/tensorflow/lite/delegates/gpu/common/task/weights_layout.h
@@ -13,25 +13,29 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_COMMON_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_COMMON_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_LAYOUT_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_LAYOUT_H_
+
+#include <vector>
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
-enum class ConvWeightsLayout {
+enum class WeightsLayout {
   kUnknown,
   kOHWIOGroupI4O4,
+  kOICustomSSpatialI4O4,
 };
 
-struct ConvWeightsDescription {
-  ConvWeightsLayout layout;
+struct WeightsDescription {
+  WeightsLayout layout;
+  // Applicable when layout is kOHWIOGroupI4O4.
   int output_group_size;
+  // Applicable when layout is kOICustomSSpatialI4O4.
+  std::vector<int> spatial_remap;
 };
 
-}  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_COMMON_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_LAYOUT_H_
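
A minimal hedged sketch of filling the renamed descriptor; the remap values are an arbitrary illustration of a kernel-defined ordering for a 3x3 filter.

```c++
#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"

tflite::gpu::WeightsDescription MakeCustomSpatialDescription() {
  tflite::gpu::WeightsDescription desc;
  desc.layout = tflite::gpu::WeightsLayout::kOICustomSSpatialI4O4;
  desc.output_group_size = 0;  // Unused for this layout.
  // Visit the 3x3 filter positions in an arbitrary kernel-defined order.
  desc.spatial_remap = {8, 7, 6, 5, 4, 3, 2, 1, 0};
  return desc;
}
```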
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.cc
similarity index 90%
rename from tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc
rename to tensorflow/lite/delegates/gpu/common/task/work_group_picking.cc
index 0b7ec8ed683..765bf1d1a0d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
 
 #include <algorithm>
 #include <limits>
@@ -24,7 +24,6 @@ limitations under the License.
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 namespace {
 
@@ -52,9 +51,9 @@ std::vector<int3> GenerateWorkGroupSizesXYMultipleOf(
         if (work_group_size_xy * z > kernel_info.max_work_group_size) {
           continue;
         }
-        if (x <= gpu_info.max_work_group_size_x &&
-            y <= gpu_info.max_work_group_size_y &&
-            z <= gpu_info.max_work_group_size_z) {
+        if (x <= gpu_info.GetMaxWorkGroupSizeForX() &&
+            y <= gpu_info.GetMaxWorkGroupSizeForY() &&
+            z <= gpu_info.GetMaxWorkGroupSizeForZ()) {
           work_groups.push_back({x, y, z});
         }
       }
@@ -78,9 +77,9 @@ std::vector<int3> GenerateWorkGroupSizesXMultipleOf(
        x += multiplier) {
     for (auto y : possible_y_sizes) {
       for (auto z : possible_z_sizes) {
-        if (x <= gpu_info.max_work_group_size_x &&
-            y <= gpu_info.max_work_group_size_y &&
-            z <= gpu_info.max_work_group_size_z &&
+        if (x <= gpu_info.GetMaxWorkGroupSizeForX() &&
+            y <= gpu_info.GetMaxWorkGroupSizeForY() &&
+            z <= gpu_info.GetMaxWorkGroupSizeForZ() &&
             x * y * z <= kernel_info.max_work_group_size) {
           work_groups.push_back({x, y, z});
         }
@@ -94,9 +93,9 @@ void GetWorkGroupsAlignedToGrid(const GpuInfo& gpu_info,
                                 const KernelInfo& kernel_info, const int3& grid,
                                 std::vector<int3>* work_groups) {
   int3 max_wg_size;
-  max_wg_size.x = gpu_info.max_work_group_size_x;
-  max_wg_size.y = gpu_info.max_work_group_size_y;
-  max_wg_size.z = gpu_info.max_work_group_size_z;
+  max_wg_size.x = gpu_info.GetMaxWorkGroupSizeForX();
+  max_wg_size.y = gpu_info.GetMaxWorkGroupSizeForY();
+  max_wg_size.z = gpu_info.GetMaxWorkGroupSizeForZ();
   GenerateWorkGroupSizesAlignedToGrid(
       grid, max_wg_size, kernel_info.max_work_group_size, work_groups);
 }
@@ -275,7 +274,7 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info,
       if (gpu_info.IsAdreno()) {
         max_z_size = gpu_info.adreno_info.IsAdreno3xx() ? 16 : 64;
       }
-      max_z_size = std::min(max_z_size, gpu_info.max_work_group_size_z);
+      max_z_size = std::min(max_z_size, gpu_info.GetMaxWorkGroupSizeForZ());
       work_groups->push_back(
           GetWorkGroupConv(grid, kernel_info.max_work_group_size, max_z_size));
       return;
@@ -290,6 +289,15 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info,
   }
 }
 
-}  // namespace cl
+int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size) {
+  for (const auto& wg : wgs) {
+    const int wg_size = wg.x * wg.y * wg.z;
+    if (wg_size <= max_wg_size) {
+      return wg;
+    }
+  }
+  return {1, 1, 1};
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.h
similarity index 81%
rename from tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h
rename to tensorflow/lite/delegates/gpu/common/task/work_group_picking.h
index 8135aec8855..508bffc762d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h
+++ b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.h
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WORK_GROUP_PICKING_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WORK_GROUP_PICKING_H_
 
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
 #include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
@@ -26,7 +26,6 @@ limitations under the License.
 
 namespace tflite {
 namespace gpu {
-namespace cl {
 
 // multiplier can be power of two only
 void GetPossibleWorkGroupsXYMultipleOf(int multiplier, const GpuInfo& gpu_info,
@@ -56,8 +55,11 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info,
                                const KernelInfo& kernel_info, const int3& grid,
                                std::vector<int3>* work_groups);
 
-}  // namespace cl
+// Returns the first work group in wgs whose total size is not bigger than
+// max_wg_size; if wgs contains no suitable group, returns {1, 1, 1}.
+int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size);
+
 }  // namespace gpu
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WORK_GROUP_PICKING_H_
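
As a worked example of the relocated helper: with a 512-invocation limit, the first candidate below (16x8x8 = 1024) is rejected and the second (8x8x4 = 256) is returned. A small sketch:

```c++
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

tflite::gpu::int3 PickFallbackWorkGroup() {
  const std::vector<tflite::gpu::int3> candidates = {
      tflite::gpu::int3(16, 8, 8), tflite::gpu::int3(8, 8, 4),
      tflite::gpu::int3(8, 4, 2)};
  // Returns {8, 8, 4}, the first candidate within the limit.
  return tflite::gpu::GetFirstSuitableWorkGroup(candidates,
                                                /*max_wg_size=*/512);
}
```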
diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.cc b/tensorflow/lite/delegates/gpu/common/winograd_util.cc
index 4b9581d0f39..7d7f96bbb1b 100644
--- a/tensorflow/lite/delegates/gpu/common/winograd_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/winograd_util.cc
@@ -149,5 +149,10 @@ void RearrangeWeightsToWinograd4x4To6x6Weights(
   }
 }
 
+bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr) {
+  return attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
+         attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.h b/tensorflow/lite/delegates/gpu/common/winograd_util.h
index e88ceacb490..55629696b58 100644
--- a/tensorflow/lite/delegates/gpu/common/winograd_util.h
+++ b/tensorflow/lite/delegates/gpu/common/winograd_util.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
@@ -38,6 +39,8 @@ void RearrangeWeightsToWinograd4x4To6x6Weights(
     const Tensor<OHWI, DataType::FLOAT32>& src_weights,
     Tensor<OHWI, DataType::FLOAT32>* dst_weights);
 
+bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr);
+
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc
new file mode 100644
index 00000000000..81fb643d399
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace gpu {
+
+TEST(Winograd, CorrectAttributesFor4x4To6x6) {
+  Convolution2DAttributes attr;
+  attr.padding.prepended = HW(1, 2);
+  attr.padding.appended = HW(0, 1);
+  attr.strides = HW(1, 1);
+  attr.dilations = HW(1, 1);
+  attr.weights.shape = OHWI(1, 3, 3, 1);
+  EXPECT_TRUE(IsSuitableForWinograd4x4To6x6(attr));
+}
+
+TEST(Winograd, IncorrectAttributesFor4x4To6x6) {
+  Convolution2DAttributes attr;
+  attr.padding.prepended = HW(1, 2);
+  attr.padding.appended = HW(0, 1);
+  attr.strides = HW(1, 1);
+  attr.dilations = HW(1, 1);
+  attr.weights.shape = OHWI(1, 2, 3, 1);
+  EXPECT_FALSE(IsSuitableForWinograd4x4To6x6(attr));
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc
index 439eb0ade90..157a8992f71 100644
--- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc
+++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc
@@ -189,15 +189,15 @@ template std::vector<uint3> GenerateWorkGroupSizes(
 template <typename T>
 void GenerateWorkGroupSizesAlignedToGrid(const T& grid,
                                          const T& max_work_group_size,
-                                         const int max_work_group_invocations,
+                                         const int max_work_group_total_size,
                                          std::vector<T>* work_groups) {
   auto alignment = WorkGroupSizeAlignment::PRECISE;
   *work_groups = GenerateWorkGroupSizes<T>(
-      grid, /*min_work_group_total_size = */ 32, max_work_group_invocations,
+      grid, /*min_work_group_total_size = */ 32, max_work_group_total_size,
       max_work_group_size, alignment, alignment, alignment);
   // If the grid parameter too small, method below cannot generate workgroups.
   if (work_groups->empty()) {
-    AddCornerCases(grid, max_work_group_invocations, max_work_group_size,
+    AddCornerCases(grid, max_work_group_total_size, max_work_group_size,
                    alignment, alignment, alignment, work_groups);
   }
 }
@@ -206,11 +206,11 @@ void GenerateWorkGroupSizesAlignedToGrid(const T& grid,
 
 template void GenerateWorkGroupSizesAlignedToGrid(
     const int3& grid, const int3& max_work_group_size,
-    const int max_work_group_invocations, std::vector<int3>* work_groups);
+    const int max_work_group_total_size, std::vector<int3>* work_groups);
 
 template void GenerateWorkGroupSizesAlignedToGrid(
     const uint3& grid, const uint3& max_work_group_size,
-    const int max_work_group_invocations, std::vector<uint3>* work_groups);
+    const int max_work_group_total_size, std::vector<uint3>* work_groups);
 
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h
index 67c51b45177..30e4d8dc33c 100644
--- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h
+++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h
@@ -41,7 +41,7 @@ std::vector<T> GenerateWorkGroupSizes(
 template <typename T>
 void GenerateWorkGroupSizesAlignedToGrid(const T& grid,
                                          const T& max_work_group_size,
-                                         const int max_work_group_invocations,
+                                         const int max_work_group_total_size,
                                          std::vector<T>* work_groups);
 
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.cc b/tensorflow/lite/delegates/gpu/gl/compiler.cc
index 4b19082e4aa..20c93d6216a 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler.cc
+++ b/tensorflow/lite/delegates/gpu/gl/compiler.cc
@@ -43,25 +43,27 @@ namespace gl {
 namespace {
 
 struct ExceedSizeChecker {
-  bool operator()(uint32_t v) const { return v > max_size; }
+  bool operator()(uint32_t v) const { return v > max_size.x; }
 
   bool operator()(const uint2& v) const {
-    return v.x > max_size || v.y > max_size;
+    return v.x > max_size.x || v.y > max_size.y;
   }
 
   bool operator()(const uint3& v) const {
-    return v.x > max_size || v.y > max_size || v.z > max_z_size;
+    return v.x > max_size.x || v.y > max_size.y || v.z > max_z_size;
   }
 
-  int max_size;
+  int2 max_size;
   int max_z_size;
 };
 
 // Returns true if any size variable exceeds the given limit
 bool ExceedsMaxSize(const Object& object, const GpuInfo& gpu_info) {
-  return absl::visit(ExceedSizeChecker{gpu_info.max_texture_size,
-                                       gpu_info.max_array_texture_layers},
-                     object.size);
+  ExceedSizeChecker size_checker;
+  size_checker.max_size =
+      int2(gpu_info.GetMaxImage2DWidth(), gpu_info.GetMaxImage2DHeight());
+  size_checker.max_z_size = gpu_info.GetMaxImage2DArrayLayers();
+  return absl::visit(size_checker, object.size);
 }
 
 ObjectType ChooseFastestObjectType(const GpuInfo& gpu_info) {
diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc
index 4f46a3586f5..598a27c0ea7 100644
--- a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc
@@ -59,27 +59,33 @@ absl::Status RequestGpuInfo(GpuInfo* gpu_info) {
 
   GLint extensions_count;
   glGetIntegerv(GL_NUM_EXTENSIONS, &extensions_count);
-  info.extensions.resize(extensions_count);
+  info.opengl_info.extensions.resize(extensions_count);
   for (int i = 0; i < extensions_count; ++i) {
-    info.extensions[i] = std::string(
+    info.opengl_info.extensions[i] = std::string(
         reinterpret_cast<const char*>(glGetStringi(GL_EXTENSIONS, i)));
   }
   glGetIntegerv(GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS,
                 &info.opengl_info.max_ssbo_bindings);
   glGetIntegerv(GL_MAX_COMPUTE_IMAGE_UNIFORMS,
                 &info.opengl_info.max_image_bindings);
-  info.max_work_group_size.resize(3);
   glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 0,
-                  &info.max_work_group_size[0]);
+                  &info.opengl_info.max_compute_work_group_size_x);
   glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
-                  &info.max_work_group_size[1]);
+                  &info.opengl_info.max_compute_work_group_size_y);
   glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 2,
-                  &info.max_work_group_size[2]);
+                  &info.opengl_info.max_compute_work_group_size_z);
+  info.max_work_group_size.push_back(
+      info.opengl_info.max_compute_work_group_size_x);
+  info.max_work_group_size.push_back(
+      info.opengl_info.max_compute_work_group_size_y);
+  info.max_work_group_size.push_back(
+      info.opengl_info.max_compute_work_group_size_z);
   glGetIntegerv(GL_MAX_COMPUTE_WORK_GROUP_INVOCATIONS,
-                &info.max_work_group_invocations);
-  glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.max_texture_size);
+                &info.opengl_info.max_work_group_invocations);
+  glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.opengl_info.max_texture_size);
   glGetIntegerv(GL_MAX_IMAGE_UNITS, &info.opengl_info.max_image_units);
-  glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers);
+  glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS,
+                &info.opengl_info.max_array_texture_layers);
   RETURN_IF_ERROR(GetOpenGlErrors());
   *gpu_info = info;
   return absl::OkStatus();
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc
index e21538b22a5..54252dc4fc8 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc
@@ -29,26 +29,26 @@ uint64_t CalculateProduct(const uint3& value) {
 }
 
 void MaybeShrinkWorkgroup(const GpuInfo& gpu_info, uint3* wg) {
-  while (wg->x > gpu_info.max_work_group_size[0]) {
+  while (wg->x > gpu_info.GetMaxWorkGroupSizeForX()) {
     wg->x /= 2;
   }
 
-  while (wg->y > gpu_info.max_work_group_size[1]) {
+  while (wg->y > gpu_info.GetMaxWorkGroupSizeForY()) {
     wg->y /= 2;
   }
 
-  while (wg->z > gpu_info.max_work_group_size[2]) {
+  while (wg->z > gpu_info.GetMaxWorkGroupSizeForZ()) {
     wg->z /= 2;
   }
 
   // Code below decreases amount of invocations per workgroup in a balanced way.
   // As example, workgroup size is x=16, y=8, z=8 (16x8x8 = 1024), but
-  // max_work_group_invocations = 512. We need to fit this limit and we can
+  // max_work_group_total_size = 512. We need to fit this limit and we can
   // reduce workgroup size in different ways, but we want to use the most
   // balanced way. So code below will find the maximal of three dimensions and
   // reduce it, so the whole workgroup is kept balanced by all dimensions. And
   // the final reduced workgroup will be x=8, y=8, z=8 for the given example.
-  while (CalculateProduct(*wg) > gpu_info.max_work_group_invocations) {
+  while (CalculateProduct(*wg) > gpu_info.GetMaxWorkGroupTotalSize()) {
     unsigned int* max = &wg->x;
     if (wg->y > *max) max = &wg->y;
     if (wg->z > *max) max = &wg->z;
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index 8062fb16d7c..bd6e9caa342 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -294,10 +294,14 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
           node_id, inputs[0], inputs[1], outputs[0],
           absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes));
       break;
-    case OperationType::MEAN:
-      *tasks = Mean(node_id, inputs[0], outputs[0],
-                    absl::any_cast<MeanAttributes>(node->operation.attributes));
+    case OperationType::MEAN: {
+      auto attr = absl::any_cast<MeanAttributes>(node->operation.attributes);
+      if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
+        return absl::UnimplementedError("Mean supports HW axis only in Metal");
+      }
+      *tasks = Mean(node_id, inputs[0], outputs[0], attr);
       break;
+    }
     case OperationType::MUL:
       if (inputs.size() == 1) {
         if (node->operation.attributes.has_value()) {
diff --git a/tensorflow/lite/experimental/examples/lstm/rnn.py b/tensorflow/lite/experimental/examples/lstm/rnn.py
index 1f55538bda5..aed003982dd 100644
--- a/tensorflow/lite/experimental/examples/lstm/rnn.py
+++ b/tensorflow/lite/experimental/examples/lstm/rnn.py
@@ -34,11 +34,14 @@ from tensorflow.python.ops.rnn import _best_effort_input_batch_size
 from tensorflow.python.ops.rnn import _dynamic_rnn_loop
 from tensorflow.python.ops.rnn import _should_cache
 from tensorflow.python.ops.rnn import _transpose_batch_time
+from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
 @tf_export(v1=["lite.experimental.nn.dynamic_rnn"])
+@deprecation.deprecated(
+    None, "Use `keras.layers.LSTM` instead.")
 def dynamic_rnn(cell,
                 inputs,
                 sequence_length=None,
diff --git a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py
index 9736719c997..3c609af9b04 100644
--- a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py
+++ b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py
@@ -32,10 +32,13 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
 @tf_export(v1=["lite.experimental.nn.TfLiteRNNCell"])
+@deprecation.deprecated(
+    None, "Use `keras.layers.RNN` instead for TF2.x.")
 class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
   """The most basic RNN cell.
 
@@ -159,6 +162,8 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
 
 
 @tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])
+@deprecation.deprecated(
+    None, "Use `keras.layers.LSTM` instead.")
 class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
   """Long short-term memory unit (LSTM) recurrent network cell.
 
diff --git a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD b/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD
deleted file mode 100644
index bb64be61599..00000000000
--- a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD
+++ /dev/null
@@ -1,24 +0,0 @@
-load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
-
-package(
-    default_visibility = ["//tensorflow:internal"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-cc_library(
-    name = "tflite_api_dispatcher",
-    hdrs = ["tflite_api_dispatcher.h"],
-    compatible_with = get_compatible_with_portable(),
-    deps = [
-        "//tensorflow/lite:framework_lib",
-    ],
-)
-
-cc_library(
-    name = "tflite_api_dispatcher_with_kernels",
-    hdrs = ["tflite_api_dispatcher.h"],
-    deps = [
-        ":tflite_api_dispatcher",
-        "//tensorflow/lite:framework_lib",
-    ],
-)
diff --git a/tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h b/tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h
deleted file mode 100644
index 55771ed9673..00000000000
--- a/tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// The purpose of this file is to indirect how implementations of the TensorFlow
-// Lite API are selected by providing a single namespace tflite_api_dispatcher.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_TFLITE_API_DISPATCHER_TFLITE_API_DISPATCHER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_TFLITE_API_DISPATCHER_TFLITE_API_DISPATCHER_H_
-
-// Import the relevant interpreter and model files.
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/model.h"
-
-namespace tflite_api_dispatcher {
-
-// Use the correct interpreter.
-using tflite::Interpreter;
-using tflite::InterpreterBuilder;
-using TfLiteModel = tflite::FlatBufferModel;
-using TfLiteVerifier = tflite::TfLiteVerifier;
-
-}  // namespace tflite_api_dispatcher
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_TFLITE_API_DISPATCHER_TFLITE_API_DISPATCHER_H_
diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml
index bf52cb3c23d..84a09824613 100644
--- a/tensorflow/lite/g3doc/_book.yaml
+++ b/tensorflow/lite/g3doc/_book.yaml
@@ -42,7 +42,7 @@ upper_tabs:
       - heading: "Text"
       - title: "Text classification with Model Maker"
         path: /lite/tutorials/model_maker_text_classification
-      - title: "Question Answer with Model Maker"
+      - title: "BERT Question Answer with Model Maker"
         path: /lite/tutorials/model_maker_question_answer
 
       - heading: "Microcontrollers"
@@ -110,7 +110,7 @@ upper_tabs:
       - heading: "Run Inference with metadata"
       - title: "Overview"
         path: /lite/inference_with_metadata/overview
-      - title: "Generate model interfaces with codegen"
+      - title: "Generate model interfaces using metadata"
         path: /lite/inference_with_metadata/codegen
       - title: "Integrate models with Task Library"
         path: /lite/inference_with_metadata/task_library/overview
@@ -215,7 +215,7 @@ upper_tabs:
       - title: "Style transfer"
         path: /lite/models/style_transfer/overview
       - heading: "Text"
-      - title: "Question and answer"
+      - title: "BERT Question Answer"
         path: /lite/models/bert_qa/overview
       - title: "Smart reply"
         path: /lite/models/smart_reply/overview
diff --git a/tensorflow/lite/g3doc/convert/metadata.md b/tensorflow/lite/g3doc/convert/metadata.md
index a6670f10aba..41a92b4b1ca 100644
--- a/tensorflow/lite/g3doc/convert/metadata.md
+++ b/tensorflow/lite/g3doc/convert/metadata.md
@@ -472,15 +472,39 @@ public QuantizationParams getoutputTensorQuantizationParams(int inputIndex);
 public int[] getoutputTensorShape(int inputIndex);
 ```
 
-You can also read associated files through their names with the
-`getAssociatedFile` method:
-
-```java
-public InputStream getAssociatedFile(String fileName);
-```
-
 Though the
 [TensorFlow Lite model schema](https://github.com/tensorflow/tensorflow/blob/aa7ff6aa28977826e7acae379e82da22482b2bf2/tensorflow/lite/schema/schema.fbs#L1075)
 supports multiple subgraphs, the TFLite Interpreter currently only supports a
 single subgraph. Therefore, `MetadataExtractor` omits subgraph index as an input
 argument in its methods.
+
+## Read the associated files from models
+
+A TensorFlow Lite model with metadata and associated files is essentially a zip
+file that can be unpacked with common zip tools to retrieve the associated
+files. For example, you can unzip
+[mobilenet_v1_0.75_160_quantized](https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_0.75_160_quantized/1/metadata/1)
+and extract the label file in the model as follows:
+
+```sh
+$ unzip mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite
+Archive:  mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite
+ extracting: labels.txt
+```
+
+You can also read associated files through the Metadata Extractor library.
+
+In Java, pass the file name into the `MetadataExtractor.getAssociatedFile`
+method:
+
+```java
+public InputStream getAssociatedFile(String fileName);
+```
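+
+As a minimal, hypothetical sketch (assuming the model with metadata has already
+been loaded into a `ByteBuffer` named `modelBuffer`, and omitting exception
+handling), the label file packed into the model could be read like this:
+
+```java
+import java.io.BufferedReader;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import org.tensorflow.lite.support.metadata.MetadataExtractor;
+
+// modelBuffer is assumed to hold a TFLite model with metadata.
+MetadataExtractor extractor = new MetadataExtractor(modelBuffer);
+
+// Read the associated label file by its name.
+InputStream labelStream = extractor.getAssociatedFile("labels.txt");
+BufferedReader reader = new BufferedReader(new InputStreamReader(labelStream));
+String line;
+while ((line = reader.readLine()) != null) {
+  System.out.println(line);  // Each line is one label.
+}
+reader.close();
+```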
+
+Similarly, in C++, this can be done with the
+`ModelMetadataExtractor::GetAssociatedFile` method:
+
+```c++
+tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
+      const std::string& filename) const;
+```
diff --git a/tensorflow/lite/g3doc/guide/android.md b/tensorflow/lite/g3doc/guide/android.md
index 420269de941..e36e8b1bcc5 100644
--- a/tensorflow/lite/g3doc/guide/android.md
+++ b/tensorflow/lite/g3doc/guide/android.md
@@ -41,6 +41,34 @@ as a starting point.
 The following sections contain some useful information for working with
 TensorFlow Lite on Android.
 
+### Use Android Studio ML Model Binding
+
+Note: Requires [Android Studio 4.1](https://developer.android.com/studio) or
+above.
+
+To import a TensorFlow Lite (TFLite) model:
+
+1.  Right-click on the module you would like to use the TFLite model in, or
+    click on `File`, then `New` > `Other` > `TensorFlow Lite Model`.
+    ![Right-click menus to access the TensorFlow Lite import functionality](../images/android/right_click_menu.png)
+
+1.  Select the location of your TFLite file. Note that the tooling will
+    configure the module's dependency on your behalf, with ML Model Binding and
+    all required dependencies automatically inserted into your Android module's
+    `build.gradle` file.
+
+    Optional: Select the second checkbox for importing TensorFlow GPU if you
+    want to use [GPU acceleration](../performance/gpu).
+    ![Import dialog for TFLite model](../images/android/import_dialog.png)
+
+1.  Click `Finish`.
+
+1.  The following screen will appear after the import is successful. To start
+    using the model, select Kotlin or Java, then copy and paste the code under
+    the `Sample Code` section (a rough sketch of such usage follows these
+    steps). You can get back to this screen by double-clicking the TFLite model
+    under the `ml` directory in Android Studio.
+    ![Model details page in Android Studio](../images/android/model_details.png)
+
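+The exact sample code depends on your model's metadata. As a rough,
+hypothetical sketch for an image model (the `MyModel` class name, the `bitmap`
+input, and the output accessors are placeholders derived from your own model,
+not a fixed API):
+
+```java
+import org.tensorflow.lite.support.image.TensorImage;
+
+// Hypothetical wrapper class generated for a model imported as my_model.tflite;
+// the real class name and output accessors come from your model's metadata.
+MyModel model = MyModel.newInstance(context);
+
+// Wrap a Bitmap in a typed object instead of a raw ByteBuffer.
+TensorImage image = TensorImage.fromBitmap(bitmap);
+
+// Run inference; the Outputs type exposes typed getters per output tensor.
+MyModel.Outputs outputs = model.process(image);
+
+// Release native resources when done.
+model.close();
+```
+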
 ### Use the TensorFlow Lite Task Library
 
 TensorFlow Lite Task Library contains a set of powerful and easy-to-use
diff --git a/tensorflow/lite/g3doc/guide/model_maker.md b/tensorflow/lite/g3doc/guide/model_maker.md
index 956bd127bcf..7db650d97a0 100644
--- a/tensorflow/lite/g3doc/guide/model_maker.md
+++ b/tensorflow/lite/g3doc/guide/model_maker.md
@@ -15,7 +15,7 @@ Supported Tasks
 -------------------------------------------------------------------------------------------------------- | ------------
 Image Classification [guide](https://www.tensorflow.org/lite/tutorials/model_maker_image_classification) | Classify images into predefined categories.
 Text Classification [guide](https://www.tensorflow.org/lite/tutorials/model_maker_text_classification)   | Classify text into predefined categories.
-Question Answer [guide](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer)           | Find the answer in a certain context for a given question.
+BERT Question Answer [guide](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer)      | Find the answer in a certain context for a given question with BERT.
 
 ## End-to-End Example
 
diff --git a/tensorflow/lite/g3doc/images/android/import_dialog.png b/tensorflow/lite/g3doc/images/android/import_dialog.png
new file mode 100644
index 00000000000..adaaa596a52
Binary files /dev/null and b/tensorflow/lite/g3doc/images/android/import_dialog.png differ
diff --git a/tensorflow/lite/g3doc/images/android/model_details.png b/tensorflow/lite/g3doc/images/android/model_details.png
new file mode 100644
index 00000000000..952d8489e94
Binary files /dev/null and b/tensorflow/lite/g3doc/images/android/model_details.png differ
diff --git a/tensorflow/lite/g3doc/images/android/right_click_menu.png b/tensorflow/lite/g3doc/images/android/right_click_menu.png
new file mode 100644
index 00000000000..64ff7c46d19
Binary files /dev/null and b/tensorflow/lite/g3doc/images/android/right_click_menu.png differ
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/codegen.md b/tensorflow/lite/g3doc/inference_with_metadata/codegen.md
index b447573da41..2123737c76a 100644
--- a/tensorflow/lite/g3doc/inference_with_metadata/codegen.md
+++ b/tensorflow/lite/g3doc/inference_with_metadata/codegen.md
@@ -1,4 +1,121 @@
-# Generate model interfaces with TensorFlow Lite code generator
+# Generate model interfaces using metadata
+
+Using [TensorFlow Lite Metadata](../convert/metadata), developers can generate
+wrapper code to enable integration on Android. For most developers, the
+graphical interface of [Android Studio ML Model Binding](#mlbinding) is the
+easiest to use. If you require more customization or are using command-line
+tooling, the [TensorFlow Lite Codegen](#codegen) is also available.
+
+## Use Android Studio ML Model Binding {:#mlbinding}
+
+For TensorFlow Lite models enhanced with [metadata](../convert/metadata.md),
+developers can use Android Studio ML Model Binding to automatically configure
+settings for the project and generate wrapper classes based on the model
+metadata. The wrapper code removes the need to interact directly with
+`ByteBuffer`. Instead, developers can interact with the TensorFlow Lite model
+with typed objects such as `Bitmap` and `Rect`.
+
+Note: Requires [Android Studio 4.1](https://developer.android.com/studio) or
+above.
+
+### Import a TensorFlow Lite model in Android Studio
+
+1.  Right-click on the module you would like to use the TFLite model in, or
+    click on `File`, then `New` > `Other` > `TensorFlow Lite Model`.
+    ![Right-click menus to access the TensorFlow Lite import functionality](../images/android/right_click_menu.png)
+
+1.  Select the location of your TFLite file. Note that the tooling will
+    configure the module's dependency on your behalf, with ML Model Binding and
+    all required dependencies automatically inserted into your Android module's
+    `build.gradle` file.
+
+    Optional: Select the second checkbox for importing TensorFlow GPU if you
+    want to use GPU acceleration.
+    ![Import dialog for TFLite model](../images/android/import_dialog.png)
+
+1.  Click `Finish`.
+
+1.  The following screen will appear after the import is successful. To start
+    using the model, select Kotlin or Java, then copy and paste the code under
+    the `Sample Code` section. You can get back to this screen by
+    double-clicking the TFLite model under the `ml` directory in Android Studio.
+    ![Model details page in Android Studio](../images/android/model_details.png)
+
+### Accelerating model inference {:#acceleration}
+
+ML Model Binding provides a way for developers to accelerate their code through
+the use of delegates and by adjusting the number of threads.
+
+Note: The TensorFlow Lite Interpreter must be created on the same thread as
+where it is run. Otherwise, `TfLiteGpuDelegate Invoke: GpuDelegate must run on
+the same thread where it was initialized.` may occur.
+
+Step 1. Check that the module's `build.gradle` file contains the following
+dependency:
+
+```java
+    dependencies {
+        ...
+        // TFLite GPU delegate 2.3.0 or above is required.
+        implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
+    }
+```
+
+Step 2. Detect whether the GPU running on the device is compatible with the
+TensorFlow GPU delegate; if it is not, run the model using multiple CPU
+threads:
+
+<div>
+    <devsite-selector>
+    <section>
+      <h3>Kotlin</h3>
+      <p><pre class="prettyprint lang-kotlin">
+    import org.tensorflow.lite.gpu.CompatibilityList
+    import org.tensorflow.lite.gpu.GpuDelegate
+
+    val compatList = CompatibilityList()
+
+    val options = if(compatList.isDelegateSupportedOnThisDevice) {
+        // if the device has a supported GPU, add the GPU delegate
+        Model.Options.Builder().setDevice(Model.Device.GPU).build()
+    } else {
+        // if the GPU is not supported, run on 4 threads
+        Model.Options.Builder().setNumThreads(4).build()
+    }
+
+    // Initialize the model as usual feeding in the options object
+    val myModel = MyModel.newInstance(context, options)
+
+    // Run inference per sample code
+      </pre></p>
+    </section>
+    <section>
+      <h3>Java</h3>
+      <p><pre class="prettyprint lang-java">
+    import org.tensorflow.lite.support.model.Model;
+    import org.tensorflow.lite.gpu.CompatibilityList;
+    import org.tensorflow.lite.gpu.GpuDelegate;
+
+    // Initialize model options, using the GPU delegate when supported
+    Model.Options options;
+    CompatibilityList compatList = new CompatibilityList();
+
+    if(compatList.isDelegateSupportedOnThisDevice()){
+        // if the device has a supported GPU, add the GPU delegate
+        options = new Model.Options.Builder().setDevice(Model.Device.GPU).build();
+    } else {
+        // if the GPU is not supported, run on 4 threads
+        options = new Model.Options.Builder().setNumThreads(4).build();
+    }
+
+    MyModel myModel = MyModel.newInstance(context, options);
+
+    // Run inference per sample code
+      </pre></p>
+    </section>
+    </devsite-selector>
+</div>
+
+## Generate model interfaces with TensorFlow Lite code generator {:#codegen}
 
 Note: TensorFlow Lite wrapper code generator currently only supports Android.
 
@@ -14,7 +131,7 @@ under relevant fields in
 [metadata_schema.fbs](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs),
 to see how the codegen tool parses each field.
 
-## Generate wrapper Code
+### Generate wrapper code
 
 You will need to install the following tooling in your terminal:
 
@@ -45,9 +162,9 @@ from google.colab import files
 files.download('classify_wrapper.zip')
 ```
 
-## Using the generated code
+### Using the generated code
 
-### Step 1: Import the generated code
+#### Step 1: Import the generated code
 
 Unzip the generated code if necessary into a directory structure. The root of
 the generated code is assumed to be `SRC_ROOT`.
@@ -59,7 +176,7 @@ select `SRC_ROOT`
 Using the above example, the directory and the module imported would be called
 `classify_wrapper`.
 
-### Step 2: Update the app's `build.gradle` file
+#### Step 2: Update the app's `build.gradle` file
 
 In the app module that will be consuming the generated library module:
 
@@ -77,7 +194,7 @@ Under the dependencies section, add the following:
 implementation project(":classify_wrapper")
 ```
 
-### Step 3: Using the model
+#### Step 3: Using the model
 
 ```java
 // 1. Initialize the model
@@ -103,7 +220,7 @@ if(null != myImageClassifier) {
 }
 ```
 
-## Accelerating model inference
+### Accelerating model inference
 
 The generated code provides a way for developers to accelerate their code
 through the use of [delegates](../performance/delegates.md) and the number of
@@ -127,7 +244,7 @@ try {
 }
 ```
 
-## Troubleshooting
+### Troubleshooting
 
 If you get a 'java.io.FileNotFoundException: This file can not be opened as a
 file descriptor; it is probably compressed' error, insert the following lines
@@ -138,16 +255,3 @@ aaptOptions {
    noCompress "tflite"
 }
 ```
-
-## Generate code with Android Studio ML Model Binding
-
-[Android Studio ML Model Binding](https://developer.android.com/studio/preview/features#tensor-flow-lite-models)
-allows you to directly import TensorFlow Lite models and use them in your
-Android Studio projects. It generates easy-to-use classes so you can run your
-model with less code and better type safety. See the
-[introduction](https://developer.android.com/studio/preview/features#tensor-flow-lite-models)
-for more details.
-
-Note: Code generated by the TensorFlow Lite Android code generator may include
-some latest API or experimental features, which can be a super set of the one
-generated by the Android Studio ML Model Binding.
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/overview.md b/tensorflow/lite/g3doc/inference_with_metadata/overview.md
index 2bb628bddb9..736f1540c85 100644
--- a/tensorflow/lite/g3doc/inference_with_metadata/overview.md
+++ b/tensorflow/lite/g3doc/inference_with_metadata/overview.md
@@ -4,25 +4,32 @@ Inferencing [models with metadata](../convert/metadata.md) can be as easy as
 just a few lines of code. TensorFlow Lite metadata contains a rich description
 of what the model does and how to use the model. It can empower code generators
 to automatically generate the inference code for you, such as using the
-[TensorFlow Lite Android code generator](codegen.md#generate-code-with-tensorflow-lite-android-code-generator)
-and the
-[Android Studio ML Binding feature](codegen.md#generate-code-with-android-studio-ml-model-binding).
-It can also be used to configure your custom inference pipeline.
+[Android Studio ML Binding feature](codegen.md#mlbinding) or
+[TensorFlow Lite Android code generator](codegen.md#codegen). It can also be
+used to configure your custom inference pipeline.
 
 ## Tools and libraries
 
 TensorFlow Lite provides varieties of tools and libraries to serve different
 tiers of deployment requirements as follows:
 
-### Generate model interface with the TensorFlow Lite Code Generator
+### Generate model interface with Android code generators
 
-[TensorFlow Lite Code Generator](codegen.md) is an executable that generates
-model interface automatically based on the metadata. It currently supports
-Android with Java. The wrapper code removes the need to interact directly with
-`ByteBuffer`. Instead, developers can interact with the TensorFlow Lite model
-with typed objects such as `Bitmap` and `Rect`. Android Studio users can also
-get access to the codegen feature through
-[Android Studio ML Binding](codegen.md#generate-code-with-android-studio-ml-model-binding).
+There are two ways to automatically generate the necessary Android wrapper code
+for a TensorFlow Lite model with metadata:
+
+1.  [Android Studio ML Model Binding](codegen.md#mlbinding) is tooling available
+    within Android Studio to import TensorFlow Lite models through a graphical
+    interface. Android Studio will automatically configure settings for the
+    project and generate wrapper classes based on the model metadata.
+
+2.  [TensorFlow Lite Code Generator](codegen.md#codegen) is an executable that
+    generates the model interface automatically based on the metadata. It currently
+    supports Android with Java. The wrapper code removes the need to interact
+    directly with `ByteBuffer`. Instead, developers can interact with the
+    TensorFlow Lite model with typed objects such as `Bitmap` and `Rect`.
+    Android Studio users can also access the codegen feature through
+    [Android Studio ML Binding](codegen.md#mlbinding).
 
 ### Leverage out-of-box APIs with the TensorFlow Lite Task Library
 
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md
index 995c1cf7478..168574b45bd 100644
--- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md
@@ -18,7 +18,7 @@ documentation for the Question-Answer model
 The following models are compatible with the `BertNLClassifier` API.
 
 *   Models created by
-    [TensorFlow Lite Model Maker for Question Answer](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer).
+    [TensorFlow Lite Model Maker for BERT Question Answer](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer).
 
 *   The
     [pretrained BERT models on TensorFlow Hub](https://tfhub.dev/tensorflow/collections/lite/task-library/bert-question-answerer/1).
diff --git a/tensorflow/lite/g3doc/models/bert_qa/overview.md b/tensorflow/lite/g3doc/models/bert_qa/overview.md
index fc75cb7f33d..af042e44d07 100644
--- a/tensorflow/lite/g3doc/models/bert_qa/overview.md
+++ b/tensorflow/lite/g3doc/models/bert_qa/overview.md
@@ -1,4 +1,4 @@
-# Question and answer
+# BERT Question and Answer
 
 Use a pre-trained model to answer questions based on the content of a given
 passage.
diff --git a/tensorflow/lite/g3doc/performance/gpu.md b/tensorflow/lite/g3doc/performance/gpu.md
index e992518baf1..581f2f2d843 100644
--- a/tensorflow/lite/g3doc/performance/gpu.md
+++ b/tensorflow/lite/g3doc/performance/gpu.md
@@ -158,6 +158,12 @@ Note: The TensorFlow Lite Interpreter must be created on the same thread as
 where it is run. Otherwise, `TfLiteGpuDelegate Invoke: GpuDelegate must run on
 the same thread where it was initialized.` may occur.
 
+There are two ways to invoke model acceleration, depending on whether you are
+using
+[Android Studio ML Model Binding](../inference_with_metadata/codegen#acceleration)
+or the TensorFlow Lite Interpreter.
+
+#### TensorFlow Lite Interpreter
+
 Look at the demo to see how to add the delegate. In your application, add the
 AAR as above, import `org.tensorflow.lite.gpu.GpuDelegate` module, and use
 the`addDelegate` function to register the GPU delegate to the interpreter:
diff --git a/tensorflow/lite/g3doc/performance/gpu_advanced.md b/tensorflow/lite/g3doc/performance/gpu_advanced.md
index d23c87c8288..45fbaf4c9c8 100644
--- a/tensorflow/lite/g3doc/performance/gpu_advanced.md
+++ b/tensorflow/lite/g3doc/performance/gpu_advanced.md
@@ -65,7 +65,12 @@ allows the appropriate versions; for example, ADD v2.
 
 ## Basic usage
 
-### Android (Kotlin / Java)
+There are two ways to invoke model acceleration in Android, depending on
+whether you are using
+[Android Studio ML Model Binding](../inference_with_metadata/codegen#acceleration)
+or the TensorFlow Lite Interpreter.
+
+### Android via TensorFlow Lite Interpreter
 
 Add the `tensorflow-lite-gpu` package alongside the existing `tensorflow-lite`
 package in the existing `dependencies` block.
diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb
index 328f9d0cb70..04e351785b0 100644
--- a/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb
+++ b/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb
@@ -37,7 +37,7 @@
         "id": "Gb7qyhNL1yWt"
       },
       "source": [
-        "# Question Answer with TensorFlow Lite Model Maker"
+        "# BERT Question Answer with TensorFlow Lite Model Maker"
       ]
     },
     {
@@ -79,7 +79,7 @@
         "id": "UxEHFTk755qw"
       },
       "source": [
-        "# Introduction to Question Answer Task"
+        "# Introduction to BERT Question Answer Task"
       ]
     },
     {
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index cea4a4ae2c2..e14c38d1b08 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -178,6 +178,11 @@ public final class Interpreter implements AutoCloseable {
     Boolean allowFp16PrecisionForFp32;
     Boolean allowBufferHandleOutput;
     Boolean allowCancellation;
+
+    // TODO(b/171856982): update the comment when applying XNNPACK delegate by default is
+    // enabled for C++ TfLite library on Android platform.
+    // Note: the initial "null" value indicates default behavior, which may mean that the
+    // XNNPACK delegate will be applied by default.
     Boolean useXNNPACK;
     final List<Delegate> delegates = new ArrayList<>();
   }
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 3d439e4470c..1eaaafdff88 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -83,9 +83,17 @@ final class NativeInterpreterWrapper implements AutoCloseable {
       allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
     }
     applyDelegates(options);
+
+    // Simply use "-1" to represent the default mode.
+    int applyXNNPACKMode = -1;
     if (options.useXNNPACK != null) {
-      useXNNPACK(
-          interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
+      applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
+    }
+
+    // TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
+    // enabled for C++ TfLite library on Android platform.
+    if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
+      useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
     }
     allocateTensors(interpreterHandle, errorHandle);
     this.isMemoryAllocated = true;
@@ -459,7 +467,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
   private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
 
   private static native void useXNNPACK(
-      long interpreterHandle, long errorHandle, boolean state, int numThreads);
+      long interpreterHandle, long errorHandle, int state, int numThreads);
 
   private static native long createErrorReporter(int size);
 
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index 75e32dfc2a9..8b6bad53206 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -32,7 +32,6 @@ cc_library(
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
-        "//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels",
         "//tensorflow/lite/java/jni",
     ],
     alwayslink = 1,
diff --git a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc
index aa654046ef4..eb17fdcf2a5 100644
--- a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc
+++ b/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc
@@ -21,12 +21,14 @@ limitations under the License.
 namespace tflite {
 
 // The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
-// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
-// builtin ops. For smaller binary sizes users should avoid linking this in, and
-// should provide a custom make CreateOpResolver() instead.
+// the tflite namespace. This one instantiates a
+// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but
+// without applying any TfLite delegates by default (like the XNNPACK delegate).
+// For smaller binary sizes users should avoid linking this in, and should
+// provide a custom CreateOpResolver() instead.
 std::unique_ptr<OpResolver> CreateOpResolver() {  // NOLINT
   return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
-      new tflite::ops::builtin::BuiltinOpResolver());
+      new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
 }
 
 }  // namespace tflite
diff --git a/tensorflow/lite/java/src/main/native/jni_utils.cc b/tensorflow/lite/java/src/main/native/jni_utils.cc
index 0187d489ee8..1dc7d076fad 100644
--- a/tensorflow/lite/java/src/main/native/jni_utils.cc
+++ b/tensorflow/lite/java/src/main/native/jni_utils.cc
@@ -19,16 +19,15 @@ limitations under the License.
 #include <stdio.h>
 #include <stdlib.h>
 
+namespace tflite {
+namespace jni {
+
 const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
 const char kIllegalStateException[] = "java/lang/IllegalStateException";
 const char kNullPointerException[] = "java/lang/NullPointerException";
-const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException";
 const char kUnsupportedOperationException[] =
     "java/lang/UnsupportedOperationException";
 
-namespace tflite {
-namespace jni {
-
 void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
   va_list args;
   va_start(args, fmt);
diff --git a/tensorflow/lite/java/src/main/native/jni_utils.h b/tensorflow/lite/java/src/main/native/jni_utils.h
index cb0cdf5b49f..62d1e43a961 100644
--- a/tensorflow/lite/java/src/main/native/jni_utils.h
+++ b/tensorflow/lite/java/src/main/native/jni_utils.h
@@ -20,23 +20,23 @@ limitations under the License.
 
 #include "tensorflow/lite/error_reporter.h"
 
+namespace tflite {
+namespace jni {
+
 extern const char kIllegalArgumentException[];
 extern const char kIllegalStateException[];
 extern const char kNullPointerException[];
-extern const char kIndexOutOfBoundsException[];
 extern const char kUnsupportedOperationException[];
 
-namespace tflite {
-namespace jni {
-
 void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
 
 class BufferErrorReporter : public ErrorReporter {
  public:
   BufferErrorReporter(JNIEnv* env, int limit);
-  virtual ~BufferErrorReporter();
+  ~BufferErrorReporter() override;
   int Report(const char* format, va_list args) override;
   const char* CachedErrorMessage();
+  using ErrorReporter::Report;
 
  private:
   char* buffer_;
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 959acfb205e..3551286b966 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -23,8 +23,10 @@ limitations under the License.
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
-#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/interpreter_builder.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
+#include "tensorflow/lite/model_builder.h"
 #include "tensorflow/lite/util.h"
 
 namespace tflite {
@@ -37,29 +39,27 @@ using tflite::jni::ThrowException;
 
 namespace {
 
-tflite_api_dispatcher::Interpreter* convertLongToInterpreter(JNIEnv* env,
-                                                             jlong handle) {
+tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
   if (handle == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to Interpreter.");
     return nullptr;
   }
-  return reinterpret_cast<tflite_api_dispatcher::Interpreter*>(handle);
+  return reinterpret_cast<tflite::Interpreter*>(handle);
 }
 
-tflite_api_dispatcher::TfLiteModel* convertLongToModel(JNIEnv* env,
-                                                       jlong handle) {
+tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
   if (handle == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to model.");
     return nullptr;
   }
-  return reinterpret_cast<tflite_api_dispatcher::TfLiteModel*>(handle);
+  return reinterpret_cast<tflite::FlatBufferModel*>(handle);
 }
 
 BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
   if (handle == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to ErrorReporter.");
     return nullptr;
   }
@@ -68,7 +68,7 @@ BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
 
 TfLiteDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
   if (handle == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to delegate.");
     return nullptr;
   }
@@ -80,7 +80,7 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
   std::vector<int> outputs(size, 0);
   jint* ptr = env->GetIntArrayElements(inputs, nullptr);
   if (ptr == nullptr) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Array has empty dimensions.");
     return {};
   }
@@ -130,7 +130,7 @@ bool AreDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
   int num_dims = static_cast<int>(env->GetArrayLength(dims));
   jint* ptr = env->GetIntArrayElements(dims, nullptr);
   if (ptr == nullptr) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Empty dimensions of input array.");
     return true;
   }
@@ -165,12 +165,11 @@ JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
-    ThrowException(env, kUnsupportedOperationException,
+    ThrowException(env, tflite::jni::kUnsupportedOperationException,
                    "Internal error: Can not find java/lang/String class to get "
                    "input names.");
     return nullptr;
@@ -188,8 +187,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
     JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -197,7 +195,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
 
   if (interpreter->AllocateTensors() != kTfLiteOk) {
     ThrowException(
-        env, kIllegalStateException,
+        env, tflite::jni::kIllegalStateException,
         "Internal error: Unexpected failure when preparing tensor allocations:"
         " %s",
         error_reporter->CachedErrorMessage());
@@ -207,8 +205,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
 JNIEXPORT jboolean JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return JNI_FALSE;
 
   // TODO(b/132995737): Remove this logic by caching whether an unresolved
@@ -231,8 +228,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->inputs()[input_index];
 }
@@ -240,8 +236,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->outputs()[output_index];
 }
@@ -249,8 +244,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->execution_plan().size());
 }
@@ -259,8 +253,7 @@ JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->inputs().size());
 }
@@ -269,8 +262,7 @@ JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->outputs().size());
 }
@@ -279,12 +271,11 @@ JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
-    ThrowException(env, kUnsupportedOperationException,
+    ThrowException(env, tflite::jni::kUnsupportedOperationException,
                    "Internal error: Can not find java/lang/String class to get "
                    "output names.");
     return nullptr;
@@ -302,8 +293,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
 }
@@ -311,23 +301,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowBufferHandleOutput(allow);
 }
 
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
-    JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
+    JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state,
     jint num_threads) {
   // If not using xnnpack, simply don't apply the delegate.
-  if (!state) {
+  if (state == 0) {
     return;
   }
 
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) {
     return;
   }
@@ -355,8 +343,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
     if (num_threads > 0) {
       options.num_threads = num_threads;
     }
-    tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
-        xnnpack_create(&options), xnnpack_delete);
+    tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
+                                                    xnnpack_delete);
     auto delegation_status =
         interpreter->ModifyGraphWithDelegate(std::move(delegate));
     // kTfLiteApplicationError occurs in cases where delegation fails but
@@ -365,12 +353,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
     // TODO(b/166483905): Add support for multiple delegates when model allows.
     if (delegation_status != kTfLiteOk &&
         delegation_status != kTfLiteApplicationError) {
-      ThrowException(env, kIllegalArgumentException,
+      ThrowException(env, tflite::jni::kIllegalArgumentException,
                      "Internal error: Failed to apply XNNPACK delegate: %s",
                      error_reporter->CachedErrorMessage());
     }
+  } else if (state == -1) {
+    // Instead of throwing an exception, we tolerate missing dependencies here
+    // because we try to apply the XNNPACK delegate by default.
+    TF_LITE_REPORT_ERROR(
+        error_reporter,
+        "WARNING: Missing necessary XNNPACK delegate dependencies to apply it "
+        "by default.\n");
   } else {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Failed to load XNNPACK delegate from current runtime. "
                    "Have you added the necessary dependencies?");
   }
@@ -381,8 +376,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
                                                              jclass clazz,
                                                              jlong handle,
                                                              jint num_threads) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetNumThreads(static_cast<int>(num_threads));
 }
@@ -396,7 +390,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
 }
 
 // Verifies whether the model is a flatbuffer file.
-class JNIFlatBufferVerifier : public tflite_api_dispatcher::TfLiteVerifier {
+class JNIFlatBufferVerifier : public tflite::TfLiteVerifier {
  public:
   bool Verify(const char* data, int length,
               tflite::ErrorReporter* reporter) override {
@@ -416,13 +410,13 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
   if (error_reporter == nullptr) return 0;
   const char* path = env->GetStringUTFChars(model_file, nullptr);
 
-  std::unique_ptr<tflite_api_dispatcher::TfLiteVerifier> verifier;
+  std::unique_ptr<tflite::TfLiteVerifier> verifier;
   verifier.reset(new JNIFlatBufferVerifier());
 
-  auto model = tflite_api_dispatcher::TfLiteModel::VerifyAndBuildFromFile(
+  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
       path, verifier.get(), error_reporter);
   if (!model) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Contents of %s does not encode a valid "
                    "TensorFlow Lite model: %s",
                    path, error_reporter->CachedErrorMessage());
@@ -443,15 +437,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
       static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
   jlong capacity = env->GetDirectBufferCapacity(model_buffer);
   if (!VerifyModel(buf, capacity)) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "ByteBuffer is not a valid flatbuffer model");
     return 0;
   }
 
-  auto model = tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(
+  auto model = tflite::FlatBufferModel::BuildFromBuffer(
       buf, static_cast<size_t>(capacity), error_reporter);
   if (!model) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "ByteBuffer does not encode a valid model: %s",
                    error_reporter->CachedErrorMessage());
     return 0;
@@ -463,18 +457,17 @@ JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
     JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
     jint num_threads) {
-  tflite_api_dispatcher::TfLiteModel* model =
-      convertLongToModel(env, model_handle);
+  tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
   if (model == nullptr) return 0;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return 0;
   auto resolver = ::tflite::CreateOpResolver();
-  std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter;
-  TfLiteStatus status = tflite_api_dispatcher::InterpreterBuilder(
-      *model, *(resolver.get()))(&interpreter, static_cast<int>(num_threads));
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+      &interpreter, static_cast<int>(num_threads));
   if (status != kTfLiteOk) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Cannot create interpreter: %s",
                    error_reporter->CachedErrorMessage());
     return 0;
@@ -487,7 +480,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
 // Sets inputs, runs inference, and returns outputs as long handles.
 JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
+  tflite::Interpreter* interpreter =
       convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
@@ -496,7 +489,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
 
   if (interpreter->Invoke() != kTfLiteOk) {
     // TODO(b/168266570): Return InterruptedException.
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Failed to run on the given Interpreter: %s",
                    error_reporter->CachedErrorMessage());
     return;
@@ -506,12 +499,11 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
     JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      convertLongToInterpreter(env, handle);
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return -1;
   const int idx = static_cast<int>(output_idx);
   if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Failed to get %d-th output out of %d outputs", output_idx,
                    interpreter->outputs().size());
     return -1;
@@ -528,11 +520,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return JNI_FALSE;
-  tflite_api_dispatcher::Interpreter* interpreter =
+  tflite::Interpreter* interpreter =
       convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return JNI_FALSE;
   if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Input error: Can not resize %d-th input for a model having "
                    "%d inputs.",
                    input_idx, interpreter->inputs().size());
@@ -552,7 +544,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
           tensor_idx, convertJIntArrayToVector(env, dims));
     }
     if (status != kTfLiteOk) {
-      ThrowException(env, kIllegalArgumentException,
+      ThrowException(env, tflite::jni::kIllegalArgumentException,
                      "Internal error: Failed to resize %d-th input: %s",
                      input_idx, error_reporter->CachedErrorMessage());
       return JNI_FALSE;
@@ -565,7 +557,7 @@ JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
     jlong delegate_handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
+  tflite::Interpreter* interpreter =
       convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
@@ -578,7 +570,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
 
   TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate);
   if (status != kTfLiteOk) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Failed to apply delegate: %s",
                    error_reporter->CachedErrorMessage());
   }
@@ -587,7 +579,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
+  tflite::Interpreter* interpreter =
       convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
@@ -597,7 +589,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
 
   TfLiteStatus status = interpreter->ResetVariableTensors();
   if (status != kTfLiteOk) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Failed to reset variable tensors: %s",
                    error_reporter->CachedErrorMessage());
   }
@@ -606,10 +598,10 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
     JNIEnv* env, jclass clazz, jlong interpreter_handle) {
-  tflite_api_dispatcher::Interpreter* interpreter =
+  tflite::Interpreter* interpreter =
       convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to interpreter.");
   }
   std::atomic_bool* cancellation_flag = new std::atomic_bool(false);
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 5dfca9ebe6c..24a13bbc3f6 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -20,7 +20,7 @@ limitations under the License.
 #include <string>
 
 #include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
+#include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
 #include "tensorflow/lite/string_util.h"
 
@@ -39,21 +39,20 @@ static const char* kStringClassPath = "java/lang/String";
 // invalidate all TfLiteTensor* handles during inference or allocation.
 class TensorHandle {
  public:
-  TensorHandle(tflite_api_dispatcher::Interpreter* interpreter,
-               int tensor_index)
+  TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
       : interpreter_(interpreter), tensor_index_(tensor_index) {}
 
   TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
   int index() const { return tensor_index_; }
 
  private:
-  tflite_api_dispatcher::Interpreter* const interpreter_;
+  tflite::Interpreter* const interpreter_;
   const int tensor_index_;
 };
 
 TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) {
   if (handle == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to TfLiteTensor.");
     return nullptr;
   }
@@ -62,7 +61,7 @@ TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) {
 
 int GetTensorIndexFromHandle(JNIEnv* env, jlong handle) {
   if (handle == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to TfLiteTensor.");
     return -1;
   }
@@ -110,7 +109,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
   const int num_elements = env->GetArrayLength(array);
   size_t to_copy = num_elements * ElementByteSize(type);
   if (to_copy > dst_size) {
-    ThrowException(env, kIllegalStateException,
+    ThrowException(env, tflite::jni::kIllegalStateException,
                    "Internal error: cannot write Java array of %d bytes to "
                    "Tensor of %d bytes",
                    to_copy, dst_size);
@@ -150,7 +149,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
     }
     default: {
       ThrowException(
-          env, kUnsupportedOperationException,
+          env, tflite::jni::kUnsupportedOperationException,
           "DataType error: TensorFlowLite currently supports float "
           "(32 bits), int (32 bits), byte (8 bits), bool (8 bits), and long "
           "(64 bits), support for other types (DataType %d in this "
@@ -167,7 +166,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type,
   const size_t size = len * ElementByteSize(data_type);
   if (size > src_size) {
     ThrowException(
-        env, kIllegalStateException,
+        env, tflite::jni::kIllegalStateException,
         "Internal error: cannot fill a Java array of %d bytes with a Tensor of "
         "%d bytes",
         size, src_size);
@@ -205,7 +204,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type,
       return size;
     }
     default: {
-      ThrowException(env, kIllegalStateException,
+      ThrowException(env, tflite::jni::kIllegalStateException,
                      "DataType error: invalid DataType(%d)", data_type);
     }
   }
@@ -345,7 +344,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
   size_t src_size = ElementByteSize(type);
   if (src_size != dst_size) {
     ThrowException(
-        env, kIllegalStateException,
+        env, tflite::jni::kIllegalStateException,
         "Scalar (%d bytes) not compatible with allocated tensor (%d bytes)",
         src_size, dst_size);
     return;
@@ -377,7 +376,8 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
       return;
     }
     default:
-      ThrowException(env, kIllegalStateException, "Invalid DataType(%d)", type);
+      ThrowException(env, tflite::jni::kIllegalStateException,
+                     "Invalid DataType(%d)", type);
       return;
   }
 }
@@ -398,8 +398,8 @@ extern "C" {
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
-  tflite_api_dispatcher::Interpreter* interpreter =
-      reinterpret_cast<tflite_api_dispatcher::Interpreter*>(interpreter_handle);
+  tflite::Interpreter* interpreter =
+      reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
   return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
 }
 
@@ -415,7 +415,7 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
   TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
   if (tensor == nullptr) return nullptr;
   if (tensor->data.raw == nullptr) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Tensor hasn't been allocated.");
     return nullptr;
   }
@@ -430,13 +430,13 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
 
   void* src_data_raw = env->GetDirectBufferAddress(src);
   if (!src_data_raw) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Input ByteBuffer is not a direct buffer");
     return;
   }
 
   if (!tensor->data.data) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Tensor hasn't been allocated.");
     return;
   }
@@ -459,7 +459,7 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
   if (tensor == nullptr) return;
   int num_dims = tensor->dims->size;
   if (num_dims == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Cannot copy empty/scalar Tensors.");
     return;
   }
@@ -481,12 +481,12 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
   TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
   if (tensor == nullptr) return;
   if (tensor->type != kTfLiteString && tensor->data.raw == nullptr) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Target Tensor hasn't been allocated.");
     return;
   }
   if (tensor->dims->size == 0) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Cannot copy empty/scalar Tensors.");
     return;
   }
@@ -503,12 +503,12 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeScalar(
   TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
   if (tensor == nullptr) return;
   if ((tensor->type != kTfLiteString) && (tensor->data.raw == nullptr)) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Target Tensor hasn't been allocated.");
     return;
   }
   if ((tensor->dims->size != 0) && (tensor->dims->data[0] != 1)) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Cannot write Java scalar to non-scalar "
                    "Tensor.");
     return;
@@ -533,7 +533,7 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_lite_Tensor_name(JNIEnv* env,
                                                                jlong handle) {
   TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
   if (tensor == nullptr) {
-    ThrowException(env, kIllegalArgumentException,
+    ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Target Tensor doesn't exist.");
     return nullptr;
   }
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 7009f069bd8..f5b02173033 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -484,6 +484,10 @@ public final class InterpreterTest {
       fail();
     } catch (IllegalStateException e) {
       // Expected failure.
+    } catch (IllegalArgumentException e) {
+      // Because a TfLite delegate may be applied by default, flex op preparation can fail if
+      // the flex delegate isn't applied first, in which case this type of exception is thrown.
+      // Expected failure.
     }
   }
 
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 19ea16e94eb..113805e83c5 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -104,7 +104,15 @@ public final class NativeInterpreterWrapperTest {
       NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH);
       fail();
     } catch (IllegalStateException e) {
-      assertThat(e).hasMessageThat().contains("Encountered unresolved custom op: Assign");
+      assertThat(e)
+          .hasMessageThat()
+          .contains("preparing tensor allocations: Encountered unresolved custom op: Assign");
+    } catch (IllegalArgumentException e) {
+      // Because a TfLite delegate may be applied by default, preparation of this unresolved
+      // custom op can fail during delegate application, throwing this type of exception.
+      assertThat(e)
+          .hasMessageThat()
+          .containsMatch("Failed to apply .* delegate: Encountered unresolved custom op: Assign");
     }
   }
 
@@ -201,8 +209,20 @@ public final class NativeInterpreterWrapperTest {
       outputs.put(0, parsedOutputs);
       wrapper.run(inputs, outputs);
       long[] outputOneD = parsedOutputs[0][0][0];
-      long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-          -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
+      long[] expected = {
+        -892834092L,
+        923423L,
+        2123918239018L,
+        -892834092L,
+        923423L,
+        2123918239018L,
+        -892834092L,
+        923423L,
+        2123918239018L,
+        -892834092L,
+        923423L,
+        2123918239018L
+      };
       assertThat(outputOneD).isEqualTo(expected);
     }
   }
@@ -222,8 +242,20 @@ public final class NativeInterpreterWrapperTest {
       outputs.put(0, parsedOutputs);
       wrapper.run(inputs, outputs);
       byte[] outputOneD = parsedOutputs[0][0][0];
-      byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
-          (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
+      byte[] expected = {
+        (byte) 0xe0,
+        0x4f,
+        (byte) 0xd0,
+        (byte) 0xe0,
+        0x4f,
+        (byte) 0xd0,
+        (byte) 0xe0,
+        0x4f,
+        (byte) 0xd0,
+        (byte) 0xe0,
+        0x4f,
+        (byte) 0xd0
+      };
       assertThat(outputOneD).isEqualTo(expected);
     }
   }
@@ -242,7 +274,7 @@ public final class NativeInterpreterWrapperTest {
       wrapper.run(inputs, outputs);
       String[] outputOneD = parsedOutputs[0][0][0];
       String[] expected = {
-          "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
+        "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
       };
       assertThat(outputOneD).isEqualTo(expected);
     }
@@ -276,8 +308,8 @@ public final class NativeInterpreterWrapperTest {
       wrapper.run(inputs, outputs);
       String[] outputOneD = parsedOutputs[0][0][0];
       String[] expected = {
-          "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
-          "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
+        "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
+        "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
       };
       assertThat(outputOneD).isEqualTo(expected);
     }
@@ -332,8 +364,8 @@ public final class NativeInterpreterWrapperTest {
       wrapper.run(inputs, outputs);
       byte[] outputOneD = parsedOutputs[0][0][0];
       byte[] expected = {
-          (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
-          (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
+        (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
+        (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
       };
       assertThat(outputOneD).isEqualTo(expected);
     }
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 1965914f9bc..81f474879e9 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -220,6 +220,7 @@ cc_library(
         ],
     }),
     alwayslink = 1,
+    # TODO(b/161243354): add testonly=1?
 )
 
 cc_library(
@@ -787,7 +788,8 @@ cc_library(
     compatible_with = get_compatible_with_portable(),
     deps = [
         ":builtin_op_kernels",
-        "//tensorflow/lite:framework_lib",
+        "//tensorflow/lite:cc_api",
+        "//tensorflow/lite:mutable_op_resolver",
         "//tensorflow/lite:tflite_with_xnnpack_optional",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
@@ -2366,4 +2368,13 @@ cc_test(
     ],
 )
 
+exports_files(
+    [
+        "register.h",
+        "builtin_op_kernels.h",
+        "fully_connected.h",
+    ],
+    visibility = ["//tensorflow/lite/core/shims:__subpackages__"],
+)
+
 tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h
index 662a1864025..0c42f11ec05 100644
--- a/tensorflow/lite/kernels/internal/common.h
+++ b/tensorflow/lite/kernels/internal/common.h
@@ -186,6 +186,42 @@ inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
   return result;
 }
 
+#ifdef USE_NEON
+// Rounding uses ARM's rounding shift right (vrshlq) instruction.
+inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
+    int32x4x4_t input_val, int32 quantized_multiplier, int shift) {
+  const int left_shift = std::max(shift, 0);
+  const int right_shift = std::min(shift, 0);
+  int32x4x4_t result;
+
+  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
+  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
+  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
+
+  result.val[0] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  result.val[1] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  result.val[2] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  result.val[3] =
+      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
+                               multiplier_dup),
+                 right_shift_dup);
+
+  return result;
+}
+#endif
+
 template <typename T>
 int CountLeadingZeros(T integer_input) {
   static_assert(std::is_unsigned<T>::value,
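For readers skimming the NEON hunk above: MultiplyByQuantizedMultiplier4Rows applies the standard TFLite fixed-point rescaling to four rows of lanes at once. Below is a minimal scalar sketch of the per-lane arithmetic, following the gemmlowp-style semantics TFLite uses for its scalar MultiplyByQuantizedMultiplier; the function names here are illustrative and not part of this diff.

#include <cstdint>
#include <limits>

// Saturating rounding doubling high multiply: returns the high 32 bits of
// 2*a*b with rounding, saturating on the single overflow case
// a == b == INT32_MIN. This mirrors what the vqrdmulhq_s32 lanes compute.
inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  const bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
  const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  const int32_t high = static_cast<int32_t>((ab + nudge) / (INT64_C(1) << 31));
  return overflow ? std::numeric_limits<int32_t>::max() : high;
}

// Rounding arithmetic right shift with ties away from zero. ARM's vrshlq_s32
// rounds ties upward instead, so rare tie cases can differ by one.
inline int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  const int32_t mask = (INT32_C(1) << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

inline int32_t MultiplyByQuantizedMultiplierScalar(int32_t x,
                                                   int32_t quantized_multiplier,
                                                   int shift) {
  // A positive shift is applied before the multiply, a negative one after.
  const int left_shift = shift > 0 ? shift : 0;
  const int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                        quantized_multiplier),
      right_shift);
}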
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
index b2ccef7e942..2535552444a 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
@@ -89,8 +89,8 @@ inline void MeanImpl(const tflite::MeanParams& op_params,
         }
       }
 
-      temp_sum = optimized_ops::MultiplyByQuantizedMultiplier4Rows(
-          temp_sum, multiplier, shift);
+      temp_sum =
+          MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
 
       temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
       temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 07ecdd1208b..2c2edcfbb39 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -127,69 +127,6 @@ inline int32_t AccumulateNeonLane(const int32x4_t lane) {
 #endif
 }
 
-// TODO(jaesung): Merge duplicated implementations in optimized_ops.h and
-// neon_tensor_utils.cc.
-inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
-    int32x4x4_t input_val, int32 quantized_multiplier, int shift) {
-  using gemmlowp::RoundingDivideByPOT;
-  using gemmlowp::SaturatingRoundingDoublingHighMul;
-  const int left_shift = shift > 0 ? shift : 0;
-  const int right_shift = shift > 0 ? 0 : -shift;
-  int32x4x4_t result;
-  // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp
-  // is limited to NEON.
-#ifdef GEMMLOWP_NEON
-  const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift);
-  result.val[0] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[0], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-  result.val[1] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[1], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-  result.val[2] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[2], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-  result.val[3] =
-      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                              vmulq_s32(input_val.val[3], left_shifted_one_dup),
-                              quantized_multiplier),
-                          right_shift);
-#else
-  for (int i = 0; i < 4; ++i) {
-    int32_t vals[4];
-    vals[0] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-    vals[1] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-    vals[2] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-    vals[3] = RoundingDivideByPOT(
-        SaturatingRoundingDoublingHighMul(
-            vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift),
-            quantized_multiplier),
-        right_shift);
-
-    result.val[i] = vld1q_s32(reinterpret_cast<int32_t*>(&vals));
-  }
-#endif
-  return result;
-}
-
 inline int32x4x2_t MultiplyByQuantizedMultiplier2Rows(
     int32x4x2_t input_val, int32 quantized_multiplier, int shift) {
   using gemmlowp::RoundingDivideByPOT;
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index bfa5eb3075d..3638a4e5874 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -15,9 +15,6 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
 
-// TODO(ghodrat): Remove this header file and the dependency to internal data
-// structure.
-#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 6cabea11ac4..cbe62516a52 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -201,43 +201,6 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
   return MatrixMap<Scalar>(data, rows, cols);
 }
 
-// TODO(renjieliu): Refactor this to merge with other
-// MultiplyByQuantizedMultipler.
-#ifdef USE_NEON
-inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
-    int32x4x4_t input_val, int32 quantized_multiplier, int32 shift) {
-  const int left_shift = std::max(shift, 0);
-  const int right_shift = std::min(shift, 0);
-  int32x4x4_t result;
-
-  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
-  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
-  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
-
-  result.val[0] =
-      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
-                               multiplier_dup),
-                 right_shift_dup);
-
-  result.val[1] =
-      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
-                               multiplier_dup),
-                 right_shift_dup);
-
-  result.val[2] =
-      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
-                               multiplier_dup),
-                 right_shift_dup);
-
-  result.val[3] =
-      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
-                               multiplier_dup),
-                 right_shift_dup);
-
-  return result;
-}
-#endif
-
 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
 inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
                                     const RuntimeShape& unswitched_input1_shape,
@@ -418,7 +381,7 @@ inline void FullyConnected(
   const int32 output_activation_max = params.quantized_activation_max;
   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
@@ -484,7 +447,7 @@ inline void FullyConnected(
   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
 
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
@@ -862,7 +825,7 @@ inline void ShuffledFullyConnected(
   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
-  // TODO(benoitjacob): This really should be:
+  // TODO(b/62193649): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
index 77b4fab0e42..649e525cbf9 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
@@ -24,9 +24,6 @@ limitations under the License.
 // NEON_2_SSE translator library. If a native SSE version of a function is
 // implemented, replace the appropriate one to SSE_OR_PORTABLE.
 
-// TODO(ghodrat): Remove this header file and the dependency to internal data
-// structure.
-#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
 #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
diff --git a/tensorflow/lite/kernels/internal/reference/add.h b/tensorflow/lite/kernels/internal/reference/add.h
index 5be7ab4dc0c..3da76d88b97 100644
--- a/tensorflow/lite/kernels/internal/reference/add.h
+++ b/tensorflow/lite/kernels/internal/reference/add.h
@@ -202,14 +202,6 @@ inline void Add(const ArithmeticParams& params,
   }
 }
 
-// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
 inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
                                const RuntimeShape& input1_shape,
                                const float* input1_data,
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index 338adf8c2ee..8f0f1e8543e 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -428,7 +428,7 @@ void PortableApplyLayerNorm(const int16_t* input,
     }
     int32_t mean =
         static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
-    // TODO(jianlijianli): Avoids overflow but only works for POT n_input.
+    // TODO(b/173994730): Avoids overflow but only works for POT n_input.
     int32_t temp = kTwoToPower20 / n_input;
     int64_t variance =
         sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index bbc156cc3be..fe7fde50aef 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -179,7 +179,7 @@ inline void Elu(const RuntimeShape& input_shape, const float* input_data,
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     const float val = input_data[i];
-    output_data[i] = val < 0.0 ? std::exp(val) - 1 : val;
+    output_data[i] = val < 0.0f ? std::expm1(val) : val;
   }
 }
 
@@ -319,10 +319,6 @@ inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
 // dimensionality if the runtime code does a single loop over one dimension
 // that handles broadcasting as the base case. The code generator would then
 // generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                  const RuntimeShape& unswitched_input1_shape,
                                  const uint8* unswitched_input1_data,
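One substantive change in the reference_ops.h hunk above: Elu now computes std::expm1(val) rather than std::exp(val) - 1. For inputs near zero, exp(x) is near 1.0 and the explicit subtraction cancels most significant bits, quantizing the result to the float spacing around 1.0, while expm1 evaluates e^x - 1 directly at full precision. A tiny standalone illustration (an assumption-free sketch, not code from this diff):

#include <cmath>
#include <cstdio>

int main() {
  const float x = -1e-7f;
  // The subtraction loses accuracy: exp(x) first rounds to a float near 1.0f,
  // so the difference is a multiple of the float spacing at 1.0.
  std::printf("exp(x) - 1 = %.9g\n", std::exp(x) - 1.0f);
  // expm1 computes e^x - 1 directly, staying close to the true value -1e-7.
  std::printf("expm1(x)   = %.9g\n", std::expm1(x));
  return 0;
}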
diff --git a/tensorflow/lite/kernels/internal/reference/sub.h b/tensorflow/lite/kernels/internal/reference/sub.h
index b27f251de6c..5f4fd19e37d 100644
--- a/tensorflow/lite/kernels/internal/reference/sub.h
+++ b/tensorflow/lite/kernels/internal/reference/sub.h
@@ -65,10 +65,6 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
 // dimensionality if the runtime code does a single loop over one dimension
 // that handles broadcasting as the base case. The code generator would then
 // generate max(D1, D2) nested for loops.
-// TODO(b/151345101): BroadcastSub is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
 template <int N = 5>
 inline void BroadcastSubSlow(const ArithmeticParams& params,
                              const RuntimeShape& input1_shape,
diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
index 1d0a4d50eb2..3e10548a106 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
@@ -30,6 +30,29 @@ limitations under the License.
 namespace tflite {
 namespace tensor_utils {
 
+// Normally we would require bit-for-bit exact results. Unfortunately, a bug
+// in the Intel arm_neon_sse.h translation header that we use for x86 tests
+// causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which leads to
+// off-by-one errors. For now we have to tolerate a few off-by-one errors,
+// while still ensuring that no more than a small minority of values are
+// wrong.
+// This utility compares rounding results for integer outputs.
+template <typename T>
+void CompareRoundingResults(int flat_size, const T* expected_result,
+                            const T* real_result, int max_element_tolerance = 1,
+                            int max_total_tolerance = 5) {
+  int max_diff = 0;
+  int64_t total_diff = 0;
+  for (int i = 0; i < flat_size; i++) {
+    int diff = static_cast<int>(std::abs(expected_result[i] - real_result[i]));
+    total_diff += diff;
+    max_diff = std::max(max_diff, diff);
+  }
+
+  EXPECT_LE(max_diff, max_element_tolerance);
+  EXPECT_LE(total_diff, max_total_tolerance);
+}
+
 TEST(uKernels, FloorLog2Test) {
   for (int i = 1; i < 257; ++i) {
     EXPECT_EQ(::tflite::FloorLog2(i),
@@ -1758,7 +1781,7 @@ TEST(uKernels, VectorBatchVectorCwiseProductAccumulateInteger) {
 
   const std::vector<int16_t> expected_output = {
       /* batch 0 */
-      -35, 34, 32, 30, 27, 24, 20, 16, 11, -2, 10, 13, 16, 18, 19, 20, 21, 21,
+      -35, 34, 32, 30, 27, 24, 20, 16, 11, -1, 10, 13, 16, 18, 19, 20, 21, 21,
       20, 0, 4, 8, 12, 17, 23, 29, 35, 42, 50,
       /* batch 1 */
       27, 24, 20, 18, 15, 14, 12, 12, 1, 2, 2, 6, 10, 15, 20, 26, 32, 39, 26, 9,
@@ -1769,7 +1792,9 @@ TEST(uKernels, VectorBatchVectorCwiseProductAccumulateInteger) {
       /* batch 3 */
       17, 21, 14, 17, 18, 20, 20, 21, 20, 20, 18, -7, 13, 14, 13, 13, 11, 10, 7,
       5, 26, 31, 37, 56, 63, 72, 80, 90, 99};
-  EXPECT_THAT(batch_output, testing::ElementsAreArray(expected_output));
+  // Only allow 1 element difference for the rounding result.
+  CompareRoundingResults<int16_t>(4 * 29, expected_output.data(),
+                                  batch_output.data(), 1, 1);
 }
 
 TEST(uKernels, VectorBatchVectorCwiseProductAccumulateFloat) {
diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc
index 2a21fa730bc..d764e1f81b2 100644
--- a/tensorflow/lite/kernels/reshape.cc
+++ b/tensorflow/lite/kernels/reshape.cc
@@ -47,7 +47,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
   // Tensorflow's Reshape allows one of the shape components to have the
   // special -1 value, meaning it will be calculated automatically based on the
   // input. Here we calculate what that dimension should be so that the number
-  // of output elements in the same as the number of input elements.
+  // of output elements is the same as the number of input elements.
   int num_input_elements = NumElements(input);
 
   int num_output_elements = 1;
diff --git a/tensorflow/lite/micro/README.md b/tensorflow/lite/micro/README.md
index ebde7dbc81f..613636cf29b 100644
--- a/tensorflow/lite/micro/README.md
+++ b/tensorflow/lite/micro/README.md
@@ -24,7 +24,7 @@ kilobytes of memory.
 To learn how to use the framework, visit the developer documentation at
 [tensorflow.org/lite/microcontrollers](https://www.tensorflow.org/lite/microcontrollers).
 
-# Continuous Buils Status
+# Continuous Build Status
 
 Build Type | Status                                                                                                                                                                       | Artifacts
 ---------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index 22a0d035471..ee6d55a801c 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -14,11 +14,6 @@ config_setting(
     define_values = {"tflm_build": "xtensa_hifimini"},
 )
 
-config_setting(
-    name = "xtensa_hifimini_staging",
-    define_values = {"tflm_build": "xtensa_hifimini_staging"},
-)
-
 package_group(
     name = "micro",
     packages = ["//tensorflow/lite/micro/..."],
@@ -35,10 +30,7 @@ cc_library(
         "//conditions:default": [
         ],
         ":xtensa_hifimini": [
-            "xtensa_hifimini/fixedpoint_utils.h",
-        ],
-        ":xtensa_hifimini_staging": [
-            "xtensa_hifimini/fixedpoint_utils.h",
+            "xtensa/fixedpoint_utils.h",
         ],
     }),
     copts = micro_copts(),
@@ -58,10 +50,7 @@ cc_library(
             "fully_connected.cc",
         ],
         ":xtensa_hifimini": [
-            "xtensa_hifimini/fully_connected.cc",
-        ],
-        ":xtensa_hifimini_staging": [
-            "xtensa_hifimini_staging/fully_connected.cc",
+            "xtensa/fully_connected.cc",
         ],
     }),
     hdrs = ["fully_connected.h"],
@@ -71,22 +60,14 @@ cc_library(
         ":micro",
     ],
     deps = [
-        ":activation_utils",
         ":fixedpoint_utils",
         ":kernel_util",
-        ":micro_utils",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:kernel_util",
-        "//tensorflow/lite/kernels:op_macros",
-        "//tensorflow/lite/kernels:padding",
         "//tensorflow/lite/kernels/internal:common",
-        "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/kernels/internal:quantization_util",
         "//tensorflow/lite/kernels/internal:reference_base",
         "//tensorflow/lite/kernels/internal:tensor",
-        "//tensorflow/lite/kernels/internal:types",
-        "//tensorflow/lite/micro:memory_helpers",
-        "//tensorflow/lite/micro:micro_utils",
     ] + select({
         "//conditions:default": [],
         ":xtensa_hifimini": [
@@ -141,23 +122,11 @@ cc_library(
             "svdf.cc",
         ],
         ":xtensa_hifimini": [
-            "xtensa_hifimini/conv.cc",
-            "xtensa_hifimini/depthwise_conv.cc",
-            "xtensa_hifimini/quantize.cc",
-            "xtensa_hifimini/softmax.cc",
-            "xtensa_hifimini/svdf.cc",
-        ],
-        ":xtensa_hifimini_staging": [
-            # TODO(b/144176795): finer granularity would help reduce the
-            # duplication of srcs in the BUILD rules (in this case conv.cc and
-            # depthwise_conv.cc). We are falling back to reference kernels in
-            # case the optimized kernels are not implemented to match the
-            # behavior that we get with the Makefiles.
-            "conv.cc",
-            "depthwise_conv.cc",
-            "xtensa_hifimini_staging/quantize.cc",
-            "xtensa_hifimini_staging/softmax.cc",
-            "xtensa_hifimini_staging/svdf.cc",
+            "xtensa/conv.cc",
+            "xtensa/depthwise_conv.cc",
+            "xtensa/quantize.cc",
+            "xtensa/softmax.cc",
+            "xtensa/svdf.cc",
         ],
     }),
     hdrs = ["micro_ops.h"],
diff --git a/tensorflow/lite/micro/kernels/circular_buffer.cc b/tensorflow/lite/micro/kernels/circular_buffer.cc
index f70203062a4..5ce8dbe14c8 100644
--- a/tensorflow/lite/micro/kernels/circular_buffer.cc
+++ b/tensorflow/lite/micro/kernels/circular_buffer.cc
@@ -65,45 +65,49 @@ struct OpData {
   int cycles_max;
 };
 
-// These constants represent constants specific to the music detect model.
-// They exist until (b/132070898) is fixed.
-constexpr int kMaxOpDataSize = 7;
-int op_data_counter = 0;
-OpData op_data_array[kMaxOpDataSize];
-
 }  // namespace
 
-void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  OpData* op_data = static_cast<OpData*>(node->user_data);
 
   TF_LITE_ENSURE(context, input != nullptr);
   TF_LITE_ENSURE(context, output != nullptr);
-  TF_LITE_ENSURE_EQ(context, 1, output->dims->data[0]);
-  TF_LITE_ENSURE_EQ(context, 1, input->dims->data[0]);
+  TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
   TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
-  TF_LITE_ENSURE_EQ(context, 1, output->dims->data[2]);
-  TF_LITE_ENSURE_EQ(context, 1, input->dims->data[2]);
+  TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
   TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
 
-  // The circular buffer custom operator currently only supports int8_t.
+  // The circular buffer custom operator currently only supports int8.
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
 
-  // TODO(b/132070898): Use statically slotted OpData structures until a
-  // scratch memory API is ready.
-  TFLITE_DCHECK_LE(op_data_counter, kMaxOpDataSize);
-  OpData* op_data = &op_data_array[op_data_counter++];
-  // The last circular buffer layer (length 5) simply accumulates outputs, and
-  // does not run periodically.
+  // The last circular buffer layer simply accumulates outputs, and does not run
+  // periodically.
   // TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
-  if (output->dims->data[1] == 5) {
+  static int cb_prepare_count = 0;
+  cb_prepare_count++;
+  // These checks specifically work for the only two streaming models supported
+  // by TFLM. They use the shape of the output tensor along with the layer
+  // number to determine whether the circular buffer period should be 1 or 2.
+
+  // These models are outlined in the following documents:
+  // https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
+  // https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
+  if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
+      (cb_prepare_count == 5 && output->dims->data[2] == 2 &&
+       output->dims->data[3] == 96)) {
     op_data->cycles_max = 1;
+    cb_prepare_count = 0;
   } else {
     op_data->cycles_max = 2;
   }
@@ -127,10 +131,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
 
+  TFLITE_DCHECK(node->user_data != nullptr);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   int num_slots = output->dims->data[1];
-  int depth = output->dims->data[3];
+  int depth = output->dims->data[2] * output->dims->data[3];
 
   if (input->type == kTfLiteInt8) {
     EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
@@ -148,12 +153,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     return static_cast<TfLiteStatus>(kTfLiteAbort);
   }
 
-  // If prepare is ever called more than one time (for example, when testing the
-  // ambient model, the interpreter is created a few times), this op data
-  // counter needs to be reset so that future instances do not overrun this op
-  // data array.
-  op_data_counter = 0;
-
   data->cycles_until_run = data->cycles_max;
 
   return kTfLiteOk;
@@ -162,8 +161,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace circular_buffer
 
 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
-  static TfLiteRegistration r = {/*init=*/nullptr,
-                                 /*free=*/circular_buffer::Free,
+  static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
+                                 /*free=*/nullptr,
                                  /*prepare=*/circular_buffer::Prepare,
                                  /*invoke=*/circular_buffer::Eval,
                                  /*profiling_string=*/nullptr,
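The circular_buffer.cc rewrite above drops the statically slotted op_data_array (and its kMaxOpDataSize limit) in favor of per-node allocation from the interpreter arena. The Init/AllocatePersistentBuffer pairing is the standard TFLM kernel lifecycle pattern; here is a minimal sketch against the TfLiteContext API, where MyOpData/MyInit/MyPrepare are illustrative names rather than code from this diff.

#include "tensorflow/lite/c/common.h"

struct MyOpData {
  int cycles_max;
};

// Called once per node instance. The memory lives for the lifetime of the
// arena, so no Free callback is needed (the registration passes
// /*free=*/nullptr).
void* MyInit(TfLiteContext* context, const char* buffer, size_t length) {
  return context->AllocatePersistentBuffer(context, sizeof(MyOpData));
}

// Later callbacks retrieve the same allocation through node->user_data.
TfLiteStatus MyPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* data = static_cast<MyOpData*>(node->user_data);
  data->cycles_max = 2;
  return kTfLiteOk;
}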
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
index de9820b82d9..0af54c13bf6 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
 
 namespace tflite {
 namespace {
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
index 12410a94456..b0ecedcb8ea 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
 
 namespace tflite {
 namespace {
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index 30a5b6a602a..165e243a6d7 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
 
 namespace tflite {
 namespace {
diff --git a/tensorflow/lite/micro/kernels/xtensa/quantize.cc b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
index b867e70d98b..05646a32ad7 100644
--- a/tensorflow/lite/micro/kernels/xtensa/quantize.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
@@ -19,10 +19,12 @@ limitations under the License.
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/requantize.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
 namespace {
@@ -30,6 +32,12 @@ namespace {
 struct OpData {
   int32_t zero_point = 0;
   int scale_multiplier = 0;
+
+  // Use a 32-bit multiplier and shift for the requantize version of this
+  // operator to preserve compatibility with the reference op.
+  int32_t requantize_output_multiplier;
+  int requantize_output_shift;
+  int32_t input_zero_point = 0;
 };
 
 void AffineQuantize(int scale_multiplier, const int32_t zero_point,
@@ -116,6 +124,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       CreateQConstantForInt24(0, input->params.scale / output->params.scale);
 
   op_data->zero_point = output->params.zero_point;
+  op_data->input_zero_point = input->params.zero_point;
+
+  double effective_scale = static_cast<double>(input->params.scale) /
+                           static_cast<double>(output->params.scale);
+  QuantizeMultiplier(effective_scale, &op_data->requantize_output_multiplier,
+                     &op_data->requantize_output_shift);
 
   return kTfLiteOk;
 }
@@ -127,21 +141,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
 
-  tflite::QuantizationParams op_params;
-  op_params.zero_point = op_data->zero_point;
-
-  if (input->type != kTfLiteInt16 && output->type != kTfLiteInt8) {
+  if (output->type == kTfLiteInt8 && input->type == kTfLiteInt16) {
+    AffineQuantize(op_data->scale_multiplier, op_data->zero_point,
+                   tflite::micro::GetTensorShape(input),
+                   tflite::micro::GetTensorData<int16_t>(input),
+                   tflite::micro::GetTensorShape(output),
+                   tflite::micro::GetTensorData<int8_t>(output));
+  } else if (output->type == kTfLiteInt32 && input->type == kTfLiteInt16) {
+    int size = ElementCount(*input->dims);
+    reference_ops::Requantize(tflite::micro::GetTensorData<int16_t>(input),
+                              size, op_data->requantize_output_multiplier,
+                              op_data->requantize_output_shift,
+                              op_data->input_zero_point, op_data->zero_point,
+                              tflite::micro::GetTensorData<int32_t>(output));
+  } else {
     TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                        TfLiteTypeGetName(input->type),
                        TfLiteTypeGetName(output->type));
     return kTfLiteError;
   }
-
-  AffineQuantize(op_data->scale_multiplier, op_data->zero_point,
-                 tflite::micro::GetTensorShape(input),
-                 tflite::micro::GetTensorData<int16_t>(input),
-                 tflite::micro::GetTensorShape(output),
-                 tflite::micro::GetTensorData<int8_t>(output));
   return kTfLiteOk;
 }
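The new requantize path above folds the input and output scales into a single effective_scale and then calls QuantizeMultiplier to split it into a Q31 fixed-point multiplier plus a power-of-two shift for reference_ops::Requantize. A hedged sketch of that decomposition, mirroring the usual quantization_util contract (this is an illustration, not the implementation in this diff):

#include <cmath>
#include <cstdint>

// Decomposes real_multiplier as q * 2^shift with q in [0.5, 1), then stores
// q as a Q31 fixed-point value, which is what Requantize consumes.
void QuantizeMultiplierSketch(double real_multiplier,
                              int32_t* quantized_multiplier, int* shift) {
  if (real_multiplier == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(real_multiplier, shift);
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // Rounding can push q up to exactly 1.0.
    q_fixed /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}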
 
diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
index 28f8f1e1af0..5392e50245e 100644
--- a/tensorflow/lite/micro/kernels/xtensa/svdf.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/kernels/activation_utils.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
 
 namespace tflite {
 namespace {
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc
deleted file mode 100644
index de9820b82d9..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc
+++ /dev/null
@@ -1,456 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-
-namespace tflite {
-namespace {
-
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-struct OpData {
-  TfLitePaddingValues padding;
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t output_zero_point;
-
-  // Per channel output multiplier and shift.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
-void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
-                    const int32_t* output_shift,
-                    const RuntimeShape& input_shape, const int8_t* input_data,
-                    const RuntimeShape& filter_shape, const int8_t* filter_data,
-                    const RuntimeShape& bias_shape, const int32_t* bias_data,
-                    const RuntimeShape& output_shape, int8_t* output_data) {
-  const int stride_width = params.stride_width;
-  const int stride_height = params.stride_height;
-  const int dilation_width_factor = params.dilation_width_factor;
-  const int dilation_height_factor = params.dilation_height_factor;
-  const int pad_width = params.padding_values.width;
-  const int pad_height = params.padding_values.height;
-  const int32_t input_offset = params.input_offset;
-  const int32_t output_offset = params.output_offset;
-  const int32_t output_activation_min = params.quantized_activation_min;
-  const int32_t output_activation_max = params.quantized_activation_max;
-
-  const int batches = input_shape.Dims(0);
-
-  const int input_height = input_shape.Dims(1);
-  const int input_width = input_shape.Dims(2);
-  const int input_depth = input_shape.Dims(3);
-
-  const int filter_height = filter_shape.Dims(1);
-  const int filter_width = filter_shape.Dims(2);
-  const int filter_depth = filter_shape.Dims(3);
-
-  const int output_height = output_shape.Dims(1);
-  const int output_width = output_shape.Dims(2);
-  const int output_depth = output_shape.Dims(3);
-
-  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
-  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
-  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
-  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
-
-  for (int batch = 0; batch < batches; ++batch) {
-    for (int out_y = 0; out_y < output_height; ++out_y) {
-      const int in_y_origin = (out_y * stride_height) - pad_height;
-      for (int out_x = 0; out_x < output_width; ++out_x) {
-        const int in_x_origin = (out_x * stride_width) - pad_width;
-        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          ae_q56s acc_56 = AE_ZEROQ56();
-
-          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
-            for (int filter_x = 0; filter_x < filter_width; filter_x += 2) {
-              const int in_x = in_x_origin + dilation_width_factor * filter_x;
-              const int in_y = in_y_origin + dilation_height_factor * filter_y;
-              const bool is_point_inside_image =
-                  (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                  (in_y < input_height);
-              if (is_point_inside_image) {
-                // Find current input index, minus 2 for Xtensa load
-                // alignments:
-                // TODO(b/147322595): Consider doing these offset calculations
-                // with intrinsics:
-                int input_idx =
-                    ((batch * input_height + in_y) * input_width + in_x) *
-                        input_depth * 2 -
-                    2;
-                const int8_t* input_vals_offset_ptr = input_data + input_idx;
-                for (int i = 0; i < input_depth; i += 2) {
-                  // Load signed 2x 8bit values and right shift into 24bit
-                  // alignment:
-                  ae_p24x2s input_vals_24x2;
-                  AE_LP8X2F_IU(input_vals_24x2, input_vals_offset_ptr, 2);
-                  input_vals_24x2 = AE_P24X2S_SRAI(input_vals_24x2, 16);
-
-                  // Add input offset (24bit aligned):
-                  input_vals_24x2 =
-                      AE_P24S_ADDS_P24X2S(input_vals_24x2, input_offset_24x2);
-
-                  // Find current filter index, minus 2 for Xtensa load
-                  // alignments:
-                  int filter_idx =
-                      ((out_channel * filter_height + filter_y) * filter_width +
-                       filter_x) *
-                          filter_depth +
-                      i - 2;
-                  const int8_t* filter_vals_offset_ptr =
-                      filter_data + filter_idx;
-
-                  // Load signed 2x 8bit values and right shift into 24bit
-                  // alignment:
-                  ae_p24x2s filter_vals_24x2;
-                  AE_LP8X2F_IU(filter_vals_24x2, filter_vals_offset_ptr, 2);
-                  filter_vals_24x2 = AE_P24X2S_SRAI(filter_vals_24x2, 16);
-
-                  // Multiply and accumulate into 48bit bit space:
-                  AE_MULAAP24S_HH_LL(acc_56, filter_vals_24x2, input_vals_24x2);
-                }
-              }
-            }
-          }
-
-          // Left shift from 48bit alignment to 32bit:
-          acc_56 = AE_Q56S_SLAI(acc_56, 16);
-
-          if (bias_data) {
-            // Load and add bias at 32bit alignment:
-            ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[out_channel]);
-            acc_56 = AE_ADDQ56(acc_56, bias_56);
-          }
-
-          // Shift from 32bit alignment to 24bit alignment and place back on
-          // the PR register:
-          acc_56 = AE_Q56S_SLAI(acc_56, 8);
-          ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
-
-          // Apply quantized multiplier and accumulate result at 48bit
-          // alignment. Convert the (unsigned) 32-bit multiplier down to a
-          // 24-bit multiplier.
-          acc_56 = MultiplyByQuantizedMultiplier(
-              acc_24x2, output_multiplier[out_channel] >> 8,
-              output_shift[out_channel]);
-
-          // Add output offset, cap activation, and assign to the output:
-          acc_56 = AE_ADDQ56(acc_56, output_offset_56);
-          acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
-          acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56);
-
-          int output_idx =
-              ((batch * output_height + out_y) * output_width + out_x) *
-                  output_depth +
-              out_channel;
-          output_data[output_idx] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
-        }
-      }
-    }
-  }
-}
-
-// TODO(b/154240772): Move shared code into common methods.
-inline void Conv1x32Input32x32Filter(
-    const int input_offset, const int output_offset,
-    const int quantized_activation_min, const int quantized_activation_max,
-    const int32_t* output_multiplier, const int32_t* output_shift,
-    const RuntimeShape& input_shape, const int8_t* input_data,
-    const RuntimeShape& filter_shape, const int8_t* filter_data,
-    const RuntimeShape& bias_shape, const int32_t* bias_data,
-    const RuntimeShape& output_shape, int8_t* output_data) {
-  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
-  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
-  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max);
-  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min);
-
-  constexpr int kChannels = 32;
-  constexpr int kFilterDepth = 32;
-  for (int ch = 0; ch < kChannels; ch++) {
-    ae_q56s acc_56 = AE_ZEROQ56();
-    const int8_t* input_vals_ptr = input_data - 2;
-    for (int i = 0; i < kFilterDepth; i += 2) {
-      // Load signed 2x 8bit values and right shift into 24bit
-      // alignment:
-      ae_p24x2s input_vals_24x2;
-      AE_LP8X2F_IU(input_vals_24x2, input_vals_ptr, 2);
-      input_vals_24x2 = AE_P24X2S_SRAI(input_vals_24x2, 16);
-
-      // Add input offset (24bit aligned):
-      input_vals_24x2 = AE_P24S_ADDS_P24X2S(input_vals_24x2, input_offset_24x2);
-      // Find current filter index, minus 2 for Xtensa load
-      // alignments:
-      const int filter_idx = ch * kFilterDepth + i - 2;
-      const int8_t* filter_vals_offset_ptr = filter_data + filter_idx;
-
-      // Load signed 2x 8bit values and right shift into 24bit
-      // alignment:
-      ae_p24x2s filter_vals_24x2;
-      AE_LP8X2F_IU(filter_vals_24x2, filter_vals_offset_ptr, 2);
-      filter_vals_24x2 = AE_P24X2S_SRAI(filter_vals_24x2, 16);
-
-      // Multiply and accumulate into 48bit bit space:
-      AE_MULAAP24S_HH_LL(acc_56, filter_vals_24x2, input_vals_24x2);
-    }
-    // Left shift from 48bit alignment to 32bit:
-    acc_56 = AE_Q56S_SLAI(acc_56, 16);
-    if (bias_data) {
-      // Load and add bias at 32bit alignment:
-      ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[ch]);
-      acc_56 = AE_ADDQ56(acc_56, bias_56);
-    }
-
-    // Shift from 32bit alignment to 24bit alignment and place back on
-    // the PR register:
-    acc_56 = AE_Q56S_SLAI(acc_56, 8);
-    ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
-
-    // Apply quantized multiplier and accumulate result at 48bit alignment.
-    // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier.
-    acc_56 = MultiplyByQuantizedMultiplier(acc_24x2, output_multiplier[ch] >> 8,
-                                           output_shift[ch]);
-
-    // Add output offset, cap activation, and assign to the output:
-    acc_56 = AE_ADDQ56(acc_56, output_offset_56);
-    acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
-    acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56);
-
-    output_data[ch] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
-  }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteConvParams* params, int width, int height,
-                             int filter_width, int filter_height, int out_width,
-                             int out_height, const TfLiteType data_type,
-                             OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  // Matching GetWindowedOutputSize in TensorFlow.
-  auto padding = params->padding;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width,
-      params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    int output_channels = filter->dims->data[kConvQuantizedDimension];
-
-    return tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift),
-        output_channels);
-  }
-  return kTfLiteOk;
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-
-  int input_width = input->dims->data[2];
-  int input_height = input->dims->data[1];
-  int filter_width = filter->dims->data[2];
-  int filter_height = filter->dims->data[1];
-  int output_width = output->dims->data[2];
-  int output_height = output->dims->data[1];
-
-  // Per channel quantization is only needed for int8_t inference. For other
-  // quantized types, only a single scale and zero point is needed.
-  const int num_channels = filter->dims->data[kConvQuantizedDimension];
-  // Dynamically allocate per-channel quantization parameters.
-  op_data->per_channel_output_multiplier =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  op_data->per_channel_output_shift =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  op_data->input_zero_point = input->params.zero_point;
-  op_data->output_zero_point = output->params.zero_point;
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        reinterpret_cast<TfLiteAffineQuantization*>(
-            filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
-    TF_LITE_ENSURE(context,
-                   affine_quantization->scale->size == 1 ||
-                       affine_quantization->scale->size ==
-                           filter->dims->data[kConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  return CalculateOpData(context, node, params, input_width, input_height,
-                         filter_width, filter_height, output_width,
-                         output_height, input->type, op_data);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteConvParams* params, OpData* data,
-                             const TfLiteEvalTensor* input,
-                             const TfLiteEvalTensor* filter,
-                             const TfLiteEvalTensor* bias,
-                             TfLiteEvalTensor* output,
-                             TfLiteEvalTensor* im2col) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.input_offset = -data->input_zero_point;
-  op_params.output_offset = data->output_zero_point;
-  op_params.stride_height = params->stride_height;
-  op_params.stride_width = params->stride_width;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.padding_values.height = data->padding.height;
-  op_params.padding_values.width = data->padding.width;
-  op_params.quantized_activation_min = data->output_activation_min;
-  op_params.quantized_activation_max = data->output_activation_max;
-
-  ConvPerChannel(op_params, data->per_channel_output_multiplier,
-                 data->per_channel_output_shift,
-                 tflite::micro::GetTensorShape(input),
-                 tflite::micro::GetTensorData<int8_t>(input),
-                 tflite::micro::GetTensorShape(filter),
-                 tflite::micro::GetTensorData<int8_t>(filter),
-                 tflite::micro::GetTensorShape(bias),
-                 tflite::micro::GetTensorData<int32_t>(bias),
-                 tflite::micro::GetTensorShape(output),
-                 tflite::micro::GetTensorData<int8_t>(output));
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-
-  TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
-  const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
-  const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
-  const TfLiteEvalTensor* bias =
-      (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
-          : nullptr;
-
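-  // Handle the special case for the streaming model.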
-  int* input_dims = input->dims->data;
-  int* filter_dims = filter->dims->data;
-  if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 &&
-      input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 &&
-      filter_dims[2] == 1 && filter_dims[3] == 32) {
-    Conv1x32Input32x32Filter(
-        -op_data->input_zero_point, op_data->output_zero_point,
-        op_data->output_activation_min, op_data->output_activation_max,
-        op_data->per_channel_output_multiplier,
-        op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
-        tflite::micro::GetTensorData<int8_t>(input),
-        tflite::micro::GetTensorShape(filter),
-        tflite::micro::GetTensorData<int8_t>(filter),
-        tflite::micro::GetTensorShape(bias),
-        tflite::micro::GetTensorData<int32_t>(bias),
-        tflite::micro::GetTensorShape(output),
-        tflite::micro::GetTensorData<int8_t>(output));
-    return kTfLiteOk;
-  }
-
-  switch (input->type) {
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
-                              bias, output, nullptr);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(input->type), input->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-}  // namespace
-
-TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace tflite
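Stripped of the HiFi Mini intrinsics, the AE_* sequence in the deleted conv kernel is an ordinary fixed-point requantization. Here is a minimal portable sketch of the requantize/offset/clamp step, assuming the common 32-bit TFLite multiplier convention (Q1.31 multiplier plus left shift) rather than the 24-bit variant used above; `RequantizeAndClamp` is an illustrative name, not a TFLM API, and its half-up rounding is not bit-exact with `AE_ROUNDSQ32SYM`:

```c++
#include <algorithm>
#include <cstdint>

int8_t RequantizeAndClamp(int64_t acc, int32_t multiplier, int left_shift,
                          int32_t output_offset, int32_t act_min,
                          int32_t act_max) {
  // Fixed-point multiply by a Q1.31 multiplier, then shift back down.
  const int right_shift = 31 - left_shift;
  const int64_t rounding = int64_t{1} << (right_shift - 1);
  int64_t v = (acc * multiplier + rounding) >> right_shift;
  // Add the output zero point and clamp to the fused activation range.
  v += output_offset;
  v = std::min<int64_t>(v, act_max);
  v = std::max<int64_t>(v, act_min);
  return static_cast<int8_t>(v);
}
```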
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc
deleted file mode 100644
index 12410a94456..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc
+++ /dev/null
@@ -1,503 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
-#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-
-namespace tflite {
-namespace {
-
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Depthwise conv is quantized along dimension 3:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kDepthwiseConvQuantizedDimension = 3;
-
-struct OpData {
-  TfLitePaddingValues padding;
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t output_zero_point;
-
-  // Per channel output multiplier and shift.
-  // TODO(b/141139247): Allocate these dynamically when possible.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
-inline void DepthwiseConvPerChannel(
-    const DepthwiseParams& params, const int32_t* output_multiplier,
-    const int32_t* output_shift, const RuntimeShape& input_shape,
-    const int8_t* input_data, const RuntimeShape& filter_shape,
-    const int8_t* filter_data, const RuntimeShape& bias_shape,
-    const int32_t* bias_data, const RuntimeShape& output_shape,
-    int8_t* output_data) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  const int stride_width = params.stride_width;
-  const int stride_height = params.stride_height;
-  const int dilation_width_factor = params.dilation_width_factor;
-  const int dilation_height_factor = params.dilation_height_factor;
-  const int pad_width = params.padding_values.width;
-  const int pad_height = params.padding_values.height;
-  const int depth_multiplier = params.depth_multiplier;
-  const int32_t input_offset = params.input_offset;
-  const int32_t output_offset = params.output_offset;
-  const int32_t output_activation_min = params.quantized_activation_min;
-  const int32_t output_activation_max = params.quantized_activation_max;
-
-  const int batches = input_shape.Dims(0);
-
-  const int input_height = input_shape.Dims(1);
-  const int input_width = input_shape.Dims(2);
-  const int input_depth = input_shape.Dims(3);
-
-  const int filter_height = filter_shape.Dims(1);
-  const int filter_width = filter_shape.Dims(2);
-  const int filter_depth = filter_shape.Dims(3);
-
-  const int output_height = output_shape.Dims(1);
-  const int output_width = output_shape.Dims(2);
-  const int output_depth = output_shape.Dims(3);
-
-  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
-  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
-  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
-  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
-
-  for (int batch = 0; batch < batches; ++batch) {
-    for (int out_y = 0; out_y < output_height; ++out_y) {
-      const int in_y_origin = (out_y * stride_height) - pad_height;
-      for (int out_x = 0; out_x < output_width; ++out_x) {
-        const int in_x_origin = (out_x * stride_width) - pad_width;
-        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
-          for (int m = 0; m < depth_multiplier; ++m) {
-            const int output_channel = m + in_channel * depth_multiplier;
-            ae_q56s acc_56 = AE_ZEROQ56();
-            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
-              const int in_y = in_y_origin + dilation_height_factor * filter_y;
-              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-                const int in_x = in_x_origin + dilation_width_factor * filter_x;
-                // Zero padding by omitting the areas outside the image.
-                const bool is_point_inside_image =
-                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                    (in_y < input_height);
-
-                if (is_point_inside_image) {
-                  // Find the current input index:
-                  // TODO(b/147322595): Consider doing these offset calculations
-                  // with intrinsics:
-                  int input_idx =
-                      ((batch * input_height + in_y) * input_width + in_x) *
-                          input_depth +
-                      (in_channel);
-                  int32_t input_val = input_data[input_idx];
-
-                  // Find the current filter index:
-                  int filter_idx =
-                      ((filter_y)*filter_width + filter_x) * filter_depth +
-                      (output_channel);
-                  int32_t filter_val = filter_data[filter_idx];
-
-                  // Load the 8bit value (held as an int32_t) into a 24x2
-                  // register. Note: the value is duplicated in the HH and LL
-                  // lanes, but all calculations are done on the HH side.
-                  ae_p24x2s input_val_24x2 = AE_MOVPA24(input_val);
-
-                  // Add input offset (24bit aligned):
-                  input_val_24x2 =
-                      AE_P24S_ADDS_P24X2S(input_val_24x2, input_offset_24x2);
-
-                  // Load filter 8bit value into 24bit alignment:
-                  ae_p24x2s filter_val_24x2 = AE_MOVPA24(filter_val);
-
-                  // Multiply and accumulate the HH side of each 24x24 PR
-                  // register:
-                  AE_MULAS56P24S_HH(acc_56, filter_val_24x2, input_val_24x2);
-                }
-              }
-            }
-
-            // Left shift from 48bit alignment to 32bit:
-            acc_56 = AE_Q56S_SLAI(acc_56, 16);
-
-            if (bias_data) {
-              // Load and add bias at 32bit alignment:
-              ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[output_channel]);
-              acc_56 = AE_ADDQ56(acc_56, bias_56);
-            }
-
-            // Shift from 32bit alignment to 24bit alignment and place back on
-            // the PR register:
-            acc_56 = AE_Q56S_SLAI(acc_56, 8);
-            ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
-
-            // Apply quantized multiplier and accumulate result at 48bit
-            // alignment:
-            acc_56 = MultiplyByQuantizedMultiplier(
-                acc_24x2, output_multiplier[output_channel],
-                output_shift[output_channel]);
-
-            // Add output offset, cap activation, and assign to the output:
-            acc_56 = AE_ADDQ56(acc_56, output_offset_56);
-            acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
-            acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56);
-
-            int output_idx =
-                ((batch * output_height + out_y) * output_width + out_x) *
-                    output_depth +
-                output_channel;
-            output_data[output_idx] =
-                static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
-          }
-        }
-      }
-    }
-  }
-}
-
-constexpr int kConvolutionalKernelWidth = 4;
-constexpr int kConvolutionalKernelDepth = 32;
-inline void DepthwiseConv4x32MatchingInputAndFilter(
-    const int input_offset, const int output_offset,
-    const int quantized_activation_min, const int quantized_activation_max,
-    const int32_t* output_multiplier, const int32_t* output_shift,
-    const RuntimeShape& input_shape, const int8_t* input_data,
-    const RuntimeShape& filter_shape, const int8_t* filter_data,
-    const RuntimeShape& bias_shape, const int32_t* bias_data,
-    const RuntimeShape& output_shape, int8_t* output_data) {
-  // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier.
-  const int32_t mult = output_multiplier[0] >> 8;
-  const int32_t shift = output_shift[0];
-  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
-  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
-  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min);
-  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max);
-
-  const int num_blocks =
-      kConvolutionalKernelDepth / 2;  // Based on the 24x2 register size.
-  const int stride_elements =
-      (kConvolutionalKernelDepth / kConvolutionalKernelWidth);
-
-  const int8_t* input_0_ptr = (const int8_t*)(input_data - 2);
-  const int8_t* weight_0_ptr = (const int8_t*)(filter_data - 2);
-  // Apply the kernels in blocks of 4 for all the channels.
-  const int8_t* input_1_ptr = input_0_ptr + stride_elements * 4;
-  const int8_t* input_2_ptr = input_1_ptr + stride_elements * 4;
-  const int8_t* input_3_ptr = input_2_ptr + stride_elements * 4;
-
-  const int8_t* weight_1_ptr = weight_0_ptr + stride_elements * 4;
-  const int8_t* weight_2_ptr = weight_1_ptr + stride_elements * 4;
-  const int8_t* weight_3_ptr = weight_2_ptr + stride_elements * 4;
-
-  for (int i = 0; i < num_blocks; ++i) {
-    ae_q56s block_0_acc = AE_ZEROQ56();
-    ae_q56s block_1_acc = AE_ZEROQ56();
-
-    // Load all the weights.
-    ae_p24x2s weight_0, weight_1, weight_2, weight_3;
-    AE_LP8X2F_IU(weight_0, weight_0_ptr, 2);
-    AE_LP8X2F_IU(weight_1, weight_1_ptr, 2);
-    AE_LP8X2F_IU(weight_2, weight_2_ptr, 2);
-    AE_LP8X2F_IU(weight_3, weight_3_ptr, 2);
-
-    // Load all the inputs.
-    ae_p24x2s input_0, input_1, input_2, input_3;
-    AE_LP8X2F_IU(input_0, input_0_ptr, 2);
-    AE_LP8X2F_IU(input_1, input_1_ptr, 2);
-    AE_LP8X2F_IU(input_2, input_2_ptr, 2);
-    AE_LP8X2F_IU(input_3, input_3_ptr, 2);
-
-    // Shift inputs to 8 bit alignment and add offsets.
-    input_0 = AE_P24X2S_SRAI(input_0, 16);
-    input_1 = AE_P24X2S_SRAI(input_1, 16);
-    input_2 = AE_P24X2S_SRAI(input_2, 16);
-    input_3 = AE_P24X2S_SRAI(input_3, 16);
-
-    input_0 = AE_P24S_ADDS_P24X2S(input_0, input_offset_24x2);
-    input_1 = AE_P24S_ADDS_P24X2S(input_1, input_offset_24x2);
-    input_2 = AE_P24S_ADDS_P24X2S(input_2, input_offset_24x2);
-    input_3 = AE_P24S_ADDS_P24X2S(input_3, input_offset_24x2);
-
-    // Do the multiplies across all channels.  Resulting accumulators are 32bit
-    // aligned (24 bit aligned weights * 8 bit aligned inputs).
-    AE_MULAS56P24S_HH(block_0_acc, input_0, weight_0);
-    AE_MULAS56P24S_HH(block_0_acc, input_1, weight_1);
-    AE_MULAS56P24S_HH(block_0_acc, input_2, weight_2);
-    AE_MULAS56P24S_HH(block_0_acc, input_3, weight_3);
-
-    AE_MULAS56P24S_LL(block_1_acc, input_0, weight_0);
-    AE_MULAS56P24S_LL(block_1_acc, input_1, weight_1);
-    AE_MULAS56P24S_LL(block_1_acc, input_2, weight_2);
-    AE_MULAS56P24S_LL(block_1_acc, input_3, weight_3);
-
-    int ch_0 = i * 2;
-    int ch_1 = i * 2 + 1;
-
-    // Load and add bias at 32bit alignment:
-    ae_q56s bias_56_0 = AE_CVTQ48A32S(bias_data[ch_0]);
-    ae_q56s bias_56_1 = AE_CVTQ48A32S(bias_data[ch_1]);
-    block_0_acc = AE_ADDQ56(block_0_acc, bias_56_0);
-    block_1_acc = AE_ADDQ56(block_1_acc, bias_56_1);
-
-    // Shift from 32bit alignment to 24bit alignment and place back on
-    // the PR register:
-    block_0_acc = AE_Q56S_SLAI(block_0_acc, 8);
-    block_1_acc = AE_Q56S_SLAI(block_1_acc, 8);
-    ae_p24x2s acc_24x2_0 = AE_TRUNCP24Q48(block_0_acc);
-    ae_p24x2s acc_24x2_1 = AE_TRUNCP24Q48(block_1_acc);
-
-    // Apply the quantized multiplier to both accumulators; the results are
-    // at 48bit alignment:
-    block_0_acc = MultiplyByQuantizedMultiplier(acc_24x2_0, mult, shift);
-    block_1_acc = MultiplyByQuantizedMultiplier(acc_24x2_1, mult, shift);
-
-    // Add output offset, cap activation, and assign to the output:
-    block_0_acc = AE_ADDQ56(block_0_acc, output_offset_56);
-    block_1_acc = AE_ADDQ56(block_1_acc, output_offset_56);
-    block_0_acc = AE_MINQ56S(block_0_acc, output_activation_max_56);
-    block_1_acc = AE_MINQ56S(block_1_acc, output_activation_max_56);
-    block_0_acc = AE_MAXQ56S(block_0_acc, output_activation_min_56);
-    block_1_acc = AE_MAXQ56S(block_1_acc, output_activation_min_56);
-
-    output_data[ch_0] = static_cast<int8_t>(AE_TRUNCA32Q48(block_0_acc));
-    output_data[ch_1] = static_cast<int8_t>(AE_TRUNCA32Q48(block_1_acc));
-  }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteDepthwiseConvParams* params, int width,
-                             int height, int filter_width, int filter_height,
-                             const TfLiteType data_type, OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  int unused_output_height, unused_output_width;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width, 1, 1, height, width,
-      filter_height, filter_width, params->padding, &unused_output_height,
-      &unused_output_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
-
-    // TODO(b/148610881): Consider calculating quantized params at int24
-    // calculations:
-    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
-  }
-  return kTfLiteOk;
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  auto* params =
-      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
-
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-
-  const TfLiteType data_type = input->type;
-  int width = SizeOfDimension(input, 2);
-  int height = SizeOfDimension(input, 1);
-  int filter_width = SizeOfDimension(filter, 2);
-  int filter_height = SizeOfDimension(filter, 1);
-
-  // Per channel quantization is only needed for int8_t inference. For other
-  // quantized types, only a single scale and zero point is needed.
-  const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
-  // Dynamically allocate per-channel quantization parameters.
-  op_data->per_channel_output_multiplier =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  op_data->per_channel_output_shift =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-
-  op_data->input_zero_point = input->params.zero_point;
-  op_data->output_zero_point = output->params.zero_point;
-
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        reinterpret_cast<TfLiteAffineQuantization*>(
-            filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-    TF_LITE_ENSURE(
-        context, affine_quantization->scale->size == 1 ||
-                     affine_quantization->scale->size ==
-                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  return CalculateOpData(context, node, params, width, height, filter_width,
-                         filter_height, data_type, op_data);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteDepthwiseConvParams* params, OpData* data,
-                             const TfLiteEvalTensor* input,
-                             const TfLiteEvalTensor* filter,
-                             const TfLiteEvalTensor* bias,
-                             TfLiteEvalTensor* output) {
-  DepthwiseParams op_params;
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data->padding.width;
-  op_params.padding_values.height = data->padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.input_offset = -data->input_zero_point;
-  op_params.weights_offset = 0;
-  op_params.output_offset = data->output_zero_point;
-  // TODO(b/130439627): Use calculated value for clamping.
-  op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
-  op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
-
-  DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier,
-                          data->per_channel_output_shift,
-                          tflite::micro::GetTensorShape(input),
-                          tflite::micro::GetTensorData<int8_t>(input),
-                          tflite::micro::GetTensorShape(filter),
-                          tflite::micro::GetTensorData<int8_t>(filter),
-                          tflite::micro::GetTensorShape(bias),
-                          tflite::micro::GetTensorData<int32_t>(bias),
-                          tflite::micro::GetTensorShape(output),
-                          tflite::micro::GetTensorData<int8_t>(output));
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  auto* params =
-      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-
-  TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
-  const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
-  const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
-  const TfLiteEvalTensor* bias =
-      (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
-          : nullptr;
-
-  // Handle special case for streaming model.
-  int* input_dims = input->dims->data;
-  int* filter_dims = filter->dims->data;
-  if (input_dims[0] == 1 && input_dims[1] == 4 && input_dims[2] == 1 &&
-      input_dims[3] == 32 && filter_dims[0] == 1 && filter_dims[1] == 4 &&
-      filter_dims[2] == 1 && filter_dims[3] == 32) {
-    DepthwiseConv4x32MatchingInputAndFilter(
-        -op_data->input_zero_point, op_data->output_zero_point,
-        std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
-        op_data->per_channel_output_multiplier,
-        op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
-        tflite::micro::GetTensorData<int8_t>(input),
-        tflite::micro::GetTensorShape(filter),
-        tflite::micro::GetTensorData<int8_t>(filter),
-        tflite::micro::GetTensorShape(bias),
-        tflite::micro::GetTensorData<int32_t>(bias),
-        tflite::micro::GetTensorShape(output),
-        tflite::micro::GetTensorData<int8_t>(output));
-    return kTfLiteOk;
-  }
-  switch (input->type) {  // Already know in/out types are same.
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
-                              bias, output);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(input->type), input->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-
-}  // namespace
-
-TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace tflite
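The nested loops in the deleted depthwise kernel reduce to NHWC flat indexing plus the channel mapping `output_channel = m + in_channel * depth_multiplier`. A self-contained sketch of that arithmetic (helper names are illustrative, not TFLM APIs):

```c++
// NHWC layout: channels are innermost, batch is outermost.
inline int FlatIndexNHWC(int batch, int y, int x, int channel, int height,
                         int width, int depth) {
  return ((batch * height + y) * width + x) * depth + channel;
}

// Each input channel feeds depth_multiplier consecutive output channels.
inline int DepthwiseOutputChannel(int in_channel, int m, int depth_multiplier) {
  return in_channel * depth_multiplier + m;
}
```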
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
deleted file mode 100644
index a1d14df1352..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
-#define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include <algorithm>
-#include <cmath>
-#include <cstdint>
-
-#include "tensorflow/lite/kernels/internal/compatibility.h"
-
-namespace tflite {
-
-// INT24 MIN/MAX
-#define INT24_MIN (-8388608)
-#define INT24_MAX (8388607)
-
-// Multiplies a 24bit value by a quantized multiplier (w/ shift) and returns a
-// 48bit-aligned value in the QR register.
-inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2,
-                                             int32_t quantized_multiplier,
-                                             int shift) {
-  // A value with 1 sign bit, N integer bits and M fractional bits is
-  // represented as QN+1.M since the sign bit is included in the integer bits.
-  //
-  // The Q notation in this method explains the values represented in each
-  // variable, along with an implicit division since the quantized_multiplier
-  // represents a value between 0.5 and 1.0 (Q1.X-1 where X is the bit precision
-  // of the type).
-  //
-  // Load the quantized multiplier into the PR register.
-  // NOTE: This method assumes that this param has been calculated for 24bit
-  // space - not 32bits.
-  // Q32.0 / 2^23 -> Q24.0 / 2^23 representing a Q1.23 multiplier.
-  ae_p24x2s quantized_multiplier_24x2 = AE_MOVPA24(quantized_multiplier);
-  // Shift right by 23 - 16 bits minus the specified shift.  This is because we
-  // keep 16 fractional bits until the end to perform rounding.  Subtract shift
-  // since shift is a left shift, and the 23-16 is a right shift.
-  int shift_amount = 7 - shift;
-
-  // Find the product of x and the quantized_multiplier.
-  // Q24.0 / 2^23 * Q24.0 = Q48.0 / 2^23
-  // Q48.0 / 2^23 >> 7 = Q48.0 / 2^16
-  ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2);
-
-  // Shift right if shift amount is positive, left if shift amount is negative.
-  if (shift_amount >= 0) {
-    result_56 = AE_Q56S_SRA(result_56, shift_amount);
-  } else {
-    result_56 = AE_Q56S_SLA(result_56, -shift_amount);
-  }
-
-  // Round off the bottom 16 bits.
-  // Q48.0 / 2^16 -> Q32.0 aligned to 48 bits.
-  result_56 = AE_ROUNDSQ32SYM(result_56);
-  return result_56;
-}
-
-// Multiplies a 32bit value by a quantized multiplier (w/ shift) and returns a
-// 48bit-aligned value in the QR register.
-inline ae_q56s MultiplyByQuantizedMultiplierResult48Bit(
-    int32_t x, int32_t quantized_multiplier, int shift) {
-  // Convert x into a 2x24bit PR register file. If x is outside the numerical
-  // limits of a 24bit integer, the "fractional" or lower 8bits are discarded.
-  // If x is within the range of a 24 bit integer, the "signed" or upper 8bits
-  // are discarded.
-  ae_p24x2s x_24x2;
-  if (x > INT24_MIN && x < INT24_MAX) {
-    x_24x2 = AE_MOVPA24(x);
-  } else {
-    x_24x2 = static_cast<ae_p24s>(*reinterpret_cast<ae_p24f*>(&x));
-    shift += 8;
-  }
-
-  return MultiplyByQuantizedMultiplier(x_24x2, quantized_multiplier, shift);
-}
-
-// Calculate quantization params for 24bit runtimes.
-inline void QuantizeMultiplierForInt24(float multiplier,
-                                       int32_t* quantized_multiplier,
-                                       int* shift) {
-  if (multiplier == 0.0f) {
-    *quantized_multiplier = 0;
-    *shift = 0;
-    return;
-  }
-
-  // Special cased to 24bit:
-  const float q = std::frexp(multiplier, shift);
-  auto q_fixed = static_cast<int64_t>(std::round(q * (1 << 23)));
-
-  TFLITE_CHECK(q_fixed <= (1 << 23));
-  if (q_fixed == (1 << 23)) {
-    q_fixed /= 2;
-    ++*shift;
-  }
-  TFLITE_CHECK_LE(q_fixed, INT24_MAX);
-
-  // Ensure shift does not exceed 24-bit range.
-  TFLITE_CHECK_LE(*shift, 23);
-  if (*shift < -23) {
-    *shift = 0;
-    q_fixed = 0;
-  }
-  *quantized_multiplier = static_cast<int32_t>(q_fixed);
-}
-
-// Convert a floating point number to a Q representation for 24 bit integers.
-inline int CreateQConstantForInt24(int integer_bits, float f) {
-  const float min_bounds = static_cast<float>(INT24_MIN);
-  const float max_bounds = static_cast<float>(INT24_MAX);
-
-  int fractional_bits = 23 - integer_bits;
-  float raw = std::round(f * static_cast<float>(1 << fractional_bits));
-  raw = std::max(raw, min_bounds);
-  raw = std::min(raw, max_bounds);
-  return static_cast<int>(raw);
-}
-
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
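To sanity-check `QuantizeMultiplierForInt24` values on a host machine, the same frexp-based math can be run without the Xtensa toolchain. A sketch of that math (our reconstruction, omitting the TFLITE_CHECK guards of the deleted header); for example, 0.75 quantizes to 6291456 (= 0.75 * 2^23) with shift 0:

```c++
#include <cmath>
#include <cstdint>
#include <cstdio>

void QuantizeForInt24(float multiplier, int32_t* quantized, int* shift) {
  if (multiplier == 0.0f) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const float q = std::frexp(multiplier, shift);  // q is in [0.5, 1).
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1 << 23)));
  if (q_fixed == (1 << 23)) {  // Rounding can push q up to exactly 1.0.
    q_fixed /= 2;
    ++*shift;
  }
  *quantized = static_cast<int32_t>(q_fixed);
}

int main() {
  int32_t quantized = 0;
  int shift = 0;
  QuantizeForInt24(0.75f, &quantized, &shift);
  std::printf("0.75 -> %d * 2^%d / 2^23\n", quantized, shift);
}
```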
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc
deleted file mode 100644
index 30a5b6a602a..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc
+++ /dev/null
@@ -1,252 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-
-namespace tflite {
-namespace {
-
-struct OpData {
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-  // The index of the temporary tensor where the quantized inputs are cached.
-  int input_quantized_index;
-};
-
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-void FullyConnected(const FullyConnectedParams& params,
-                    const RuntimeShape& input_shape, const int8_t* input_data,
-                    const RuntimeShape& filter_shape, const int8_t* filter_data,
-                    const RuntimeShape& bias_shape, const int32_t* bias_data,
-                    const RuntimeShape& output_shape, int8_t* output_data) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  const int32_t input_offset = params.input_offset;
-  const int32_t filter_offset = params.weights_offset;
-  const int32_t output_offset = params.output_offset;
-  const int32_t output_multiplier = params.output_multiplier;
-  const int output_shift = params.output_shift;
-  const int32_t output_activation_min = params.quantized_activation_min;
-  const int32_t output_activation_max = params.quantized_activation_max;
-
-  const int filter_dim_count = filter_shape.DimensionsCount();
-  const int batches = output_shape.Dims(0);
-  const int output_depth = output_shape.Dims(1);
-  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
-  const int accum_depth_iters = accum_depth / 2;
-
-  ae_p24x2s offsets_input_24x2 = AE_MOVPA24(input_offset);
-  ae_p24x2s offsets_filter_24x2 = AE_MOVPA24(filter_offset);
-  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
-  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
-  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
-
-  for (int b = 0; b < batches; ++b) {
-    for (int out_c = 0; out_c < output_depth; ++out_c) {
-      // The load intrinsics advance the pointer before loading, so back the
-      // data pointers off by two before the loop:
-      const int8_t* input_ptr = (input_data + b * accum_depth) - 2;
-      const int8_t* filter_ptr = (filter_data + out_c * accum_depth) - 2;
-
-      // Main accumulator register entry for loop:
-      ae_q56s sum_56 = AE_ZEROQ56();
-
-      for (int d = 0; d < accum_depth_iters; d++) {
-        // Load the signed 8bit values into the PR register:
-        ae_p24x2s input_24x2;
-        ae_p24x2s filter_24x2;
-        AE_LP8X2F_IU(input_24x2, input_ptr, 2);
-        AE_LP8X2F_IU(filter_24x2, filter_ptr, 2);
-
-        // Right shift the signed 8bit values to expand to signed 24bit values:
-        input_24x2 = AE_P24X2S_SRAI(input_24x2, 16);
-        filter_24x2 = AE_P24X2S_SRAI(filter_24x2, 16);
-
-        // Add offsets to data values (24 bit aligned):
-        input_24x2 = AE_P24S_ADDS_P24X2S(offsets_input_24x2, input_24x2);
-        filter_24x2 = AE_P24S_ADDS_P24X2S(offsets_filter_24x2, filter_24x2);
-
-        // 24x2 signed integer dual MAC w/ addition into 56bit accumulator (48
-        // bit aligned):
-        AE_MULAAP24S_HH_LL(sum_56, input_24x2, filter_24x2);
-      }
-
-      // Left shift to get back into 32bit space (right padded to 48bit):
-      sum_56 = AE_Q56S_SLAI(sum_56, 16);
-
-      // Add bias data if needed:
-      if (bias_data) {
-        ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[out_c]);
-        sum_56 = AE_ADDQ56(sum_56, bias_56);
-      }
-
-      // Shift left into 24bit space and place back on PR register:
-      sum_56 = AE_Q56S_SLAI(sum_56, 8);
-      ae_p24x2s sum_24x2 = AE_TRUNCP24Q48(sum_56);
-
-      // MultiplyByQuantizedMultiplier returns a 48bit aligned value
-      sum_56 = MultiplyByQuantizedMultiplier(sum_24x2, output_multiplier,
-                                             output_shift);
-
-      // Add output_offset and cap min/max values:
-      sum_56 = AE_ADDQ56(sum_56, output_offset_56);
-      sum_56 = AE_MINQ56S(sum_56, output_activation_max_56);
-      sum_56 = AE_MAXQ56S(sum_56, output_activation_min_56);
-
-      output_data[out_c + output_depth * b] =
-          static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
-    }
-  }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFusedActivation activation,
-                             TfLiteType data_type, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
-  double real_multiplier = 0.0;
-  TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-      context, input, filter, bias, output, &real_multiplier));
-  QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier,
-                             &data->output_shift);
-  return CalculateActivationRangeQuantized(context, activation, output,
-                                           &data->output_activation_min,
-                                           &data->output_activation_max);
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-
-  OpData* data = static_cast<OpData*>(node->user_data);
-  const auto* params =
-      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
-
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
-  if (input->type != kTfLiteInt8) {
-    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                       TfLiteTypeGetName(input->type), input->type);
-    return kTfLiteError;
-  }
-
-  data->input_zero_point = input->params.zero_point;
-  data->filter_zero_point = filter->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
-
-  return CalculateOpData(context, params->activation, input->type, input,
-                         filter, bias, output, data);
-}
-
-TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               const OpData& data,
-                               const TfLiteEvalTensor* input,
-                               const TfLiteEvalTensor* filter,
-                               const TfLiteEvalTensor* bias,
-                               TfLiteEvalTensor* output) {
-  // TODO(b/154032858): Investigate removing extra copies, and also passing by
-  // value.
-  // TODO(b/155656675): Consider passing OpData by value once it is also passed
-  // to the FullyConnected function. Until it is copied to a local op_param
-  // variable, we do not get any latency improvements from passing by value.
-  FullyConnectedParams op_params;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.weights_offset = -data.filter_zero_point;
-  op_params.output_offset = data.output_zero_point;
-  op_params.output_multiplier = data.output_multiplier;
-  op_params.output_shift = data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  FullyConnected(op_params, tflite::micro::GetTensorShape(input),
-                 tflite::micro::GetTensorData<int8_t>(input),
-                 tflite::micro::GetTensorShape(filter),
-                 tflite::micro::GetTensorData<int8_t>(filter),
-                 tflite::micro::GetTensorShape(bias),
-                 tflite::micro::GetTensorData<int32_t>(bias),
-                 tflite::micro::GetTensorShape(output),
-                 tflite::micro::GetTensorData<int8_t>(output));
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
-  const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
-  const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
-  const TfLiteEvalTensor* bias =
-      (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
-          : nullptr;
-
-  TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
-
-  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
-}
-
-}  // namespace
-
-TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace tflite
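Without the dual-MAC intrinsics, the inner loop of the deleted FullyConnected is an offset-adjusted int8 dot product; `AE_MULAAP24S_HH_LL` evaluates two lanes of it per iteration. A scalar reference sketch (illustrative naming, not a TFLM API):

```c++
#include <cstdint>

int32_t DotProductWithOffsets(const int8_t* input, const int8_t* filter,
                              int accum_depth, int32_t input_offset,
                              int32_t filter_offset) {
  int32_t acc = 0;
  for (int d = 0; d < accum_depth; ++d) {
    // Widen via the offsets before multiplying, as the kernel does in
    // 24-bit registers.
    acc += (input[d] + input_offset) * (filter[d] + filter_offset);
  }
  return acc;
}
```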
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc
deleted file mode 100644
index b867e70d98b..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc
+++ /dev/null
@@ -1,161 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/quantize.h"
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-
-namespace tflite {
-namespace {
-
-struct OpData {
-  int32_t zero_point = 0;
-  int scale_multiplier = 0;
-};
-
-void AffineQuantize(int scale_multiplier, const int32_t zero_point,
-                    const RuntimeShape& input_shape, const int16_t* input_data,
-                    const RuntimeShape& output_shape, int8_t* output_data) {
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  ae_q56s min_val_56 = AE_CVTQ48A32S(INT16_MIN);
-  ae_q56s max_val_56 = AE_CVTQ48A32S(INT16_MAX);
-  ae_q56s zero_point_56 = AE_CVTQ48A32S(zero_point);
-
-  const ae_p16x2s* input_data_ptr = (const ae_p16x2s*)(input_data - 2);
-
-  ae_p24x2s scale_multiplier_24x2 = AE_MOVPA24(scale_multiplier);
-
-  int iters = flat_size / 2;
-  for (int i = 0; i < iters; i++) {
-    // Load a pair of 16bit values into the 24x2 register PR:
-    // Values need to be right shifted 8 bits to align from upper 16bits to a
-    // 24bit value:
-    ae_p24x2s inputs_24x2;
-    AE_LP16X2F_IU(inputs_24x2, input_data_ptr, 4);
-    inputs_24x2 = AE_P24X2S_SRAI(inputs_24x2, 8);
-
-    // Q0.23 * Q16.0 == Q16.23
-    {
-      ae_q56s sum_56 = AE_MULP24S_HH(scale_multiplier_24x2, inputs_24x2);
-
-      // Q16.23 -> Q16.0
-      // Shift right only 7 bits (23 - 16). This truncated shift aligns the
-      // 16bit value at the truncation line for 32bit in the QR register. The
-      // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM.
-      sum_56 = AE_Q56S_SRAI(sum_56, 7);
-
-      // Round and truncate 32 bits
-      sum_56 = AE_ROUNDSQ32SYM(sum_56);
-
-      // Add the offset (zero_point_56 is already aligned at 32bits):
-      sum_56 = AE_ADDQ56(sum_56, zero_point_56);
-
-      // Saturate:
-      sum_56 = AE_MINQ56S(sum_56, max_val_56);
-      sum_56 = AE_MAXQ56S(sum_56, min_val_56);
-
-      output_data[i * 2] = static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
-    }
-    {
-      ae_q56s sum_56 = AE_MULP24S_LL(scale_multiplier_24x2, inputs_24x2);
-
-      // Q16.23 -> Q16.0
-      // Shift right only 7 bits (23 - 16). This truncated shift aligns the
-      // 16bit value at the truncation line for 32bit in the QR register. The
-      // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM.
-      sum_56 = AE_Q56S_SRAI(sum_56, 23 - 16);
-
-      // Round and truncate 32 bits
-      sum_56 = AE_ROUNDSQ32SYM(sum_56);
-
-      // Add the offset (zero_point_56 is already aligned at 32bits):
-      sum_56 = AE_ADDQ56(sum_56, zero_point_56);
-
-      // Saturate:
-      sum_56 = AE_MINQ56S(sum_56, max_val_56);
-      sum_56 = AE_MAXQ56S(sum_56, min_val_56);
-
-      output_data[i * 2 + 1] = static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
-    }
-  }
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-
-  // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions.
-  op_data->scale_multiplier =
-      CreateQConstantForInt24(0, input->params.scale / output->params.scale);
-
-  op_data->zero_point = output->params.zero_point;
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
-  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
-
-  if (input->type != kTfLiteInt16 || output->type != kTfLiteInt8) {
-    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                       TfLiteTypeGetName(input->type),
-                       TfLiteTypeGetName(output->type));
-    return kTfLiteError;
-  }
-
-  AffineQuantize(op_data->scale_multiplier, op_data->zero_point,
-                 tflite::micro::GetTensorShape(input),
-                 tflite::micro::GetTensorData<int16_t>(input),
-                 tflite::micro::GetTensorShape(output),
-                 tflite::micro::GetTensorData<int8_t>(output));
-  return kTfLiteOk;
-}
-
-}  // namespace
-
-TfLiteRegistration Register_QUANTIZE() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace tflite
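The deleted quantize kernel folds `input_scale / output_scale` into a single Q1.23 multiplier via `CreateQConstantForInt24`. In plain floating point, the per-element operation is the following sketch (our reference, not TFLM code; note it saturates to the int8_t range, whereas the kernel above clamps at int16_t bounds before the narrowing cast):

```c++
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

int8_t RequantizeInt16ToInt8(int16_t in, float input_scale, float output_scale,
                             int32_t zero_point) {
  // Rescale from the input quantization grid to the output grid.
  const float scale_ratio = input_scale / output_scale;
  int32_t v = static_cast<int32_t>(std::lround(in * scale_ratio)) + zero_point;
  // Saturate to the int8 output range.
  v = std::min<int32_t>(v, std::numeric_limits<int8_t>::max());
  v = std::max<int32_t>(v, std::numeric_limits<int8_t>::min());
  return static_cast<int8_t>(v);
}
```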
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc
deleted file mode 100644
index 75eb2838034..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc
+++ /dev/null
@@ -1,208 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/softmax.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
-
-namespace tflite {
-namespace {
-
-struct OpData {
-  uint16_t* exp_lut;
-};
-
-// Number of unique int8_t and int16_t values.  Used in exponent lookup table
-// computation.
-constexpr int kInt8Range =
-    std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min() + 1;
-constexpr int kInt16Range = std::numeric_limits<int16_t>::max() -
-                            std::numeric_limits<int16_t>::min() + 1;
-// Each 16-bit precalculated exponent is expressed as a Q0.16 fixedpoint
-// value. We special-case e^0 since 1.0 requires 1 integer bit to
-// express.
-constexpr int kExpFractionalBits = 16;
-// e^0 == 1.0 in Q0.16 is (1 << 16), which exceeds the uint16_t range, so it
-// must be handled specially.
-constexpr int kMaxExponentValue = (1 << kExpFractionalBits);
-
-// Quantized softmax with int8_t input and int16_t output.
-// Passing OpData by value does not save much in this op, but we follow that
-// practice for consistency, at least for the xtensa kernels. See b/155656675
-// for more details.
-TfLiteStatus Softmax(OpData op_data, const RuntimeShape& input_shape,
-                     const int8_t* input_data, const RuntimeShape& output_shape,
-                     int16_t* output_data) {
-  // The last dimension is depth.  Outer size is the total input size
-  // divided by depth.
-  const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-  const int depth =
-      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
-  for (int i = 0; i < outer_size; ++i) {
-    int8_t max_in_row = std::numeric_limits<int8_t>::min();
-    for (int c = 0; c < depth; ++c) {
-      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
-    }
-
-    uint32_t sum_of_exps = 0;
-    for (int c = 0; c < depth; ++c) {
-      TFLITE_DCHECK(max_in_row >= input_data[i * depth + c]);
-      uint8_t input_diff = max_in_row - input_data[i * depth + c];
-
-      sum_of_exps +=
-          input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff];
-    }
-
-    // Ensure we cannot overflow the full_range_output value.  We need
-    // kInt16Range * max_exponent / sum_of_exps <= kInt16Range, which holds
-    // because sum_of_exps >= kMaxExponentValue >= every LUT entry.
-    TFLITE_DCHECK(sum_of_exps >= kMaxExponentValue);
-
-    for (int c = 0; c < depth; ++c) {
-      uint8_t input_diff = max_in_row - input_data[i * depth + c];
-      // Special case for diff == 0
-      uint32_t unscaled_output =
-          input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff];
-      int64_t scaled_output = static_cast<int64_t>(unscaled_output) *
-                              static_cast<int64_t>(kInt16Range);
-      int32_t full_range_output =
-          scaled_output / sum_of_exps + std::numeric_limits<int16_t>::min();
-      // Round up if remainder exceeds half of the divider value.
-      uint32_t remainder = scaled_output % sum_of_exps;
-      if (remainder * 2 >= sum_of_exps) {
-        full_range_output++;
-      }
-      output_data[i * depth + c] = static_cast<int16_t>(std::max(
-          std::min(full_range_output,
-                   static_cast<int32_t>(std::numeric_limits<int16_t>::max())),
-          static_cast<int32_t>(std::numeric_limits<int16_t>::min())));
-    }
-  }
-  return kTfLiteOk;
-}
-
-TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
-                                    const TfLiteTensor* input,
-                                    TfLiteTensor* output,
-                                    const TfLiteSoftmaxParams* params,
-                                    OpData* op_data) {
-  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
-    if (input->type == kTfLiteUInt8) {
-      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
-    } else {
-      if (output->type == kTfLiteInt16) {
-        TF_LITE_ENSURE_EQ(context, output->params.zero_point,
-                          std::numeric_limits<int16_t>::min());
-        // NOTE: Current int16_t softmax output does not require symmetric
-        // scaling, so there is no need to verify the scale here.
-      } else {
-        TF_LITE_ENSURE_EQ(context, output->params.zero_point,
-                          std::numeric_limits<int8_t>::min());
-        TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
-      }
-    }
-
-    // Precompute e^(-x * input_scale * beta) for every possible int8_t input.
-    // This computation is used for every iteration of Softmax.  We must compute
-    // using pre-scaled inputs to avoid introducing additional error, while
-    // restricting our input range to the int8_t range. This is valid since beta
-    // and input scale are constant for a given op in the graph. Skip index 0
-    // since that is a special case which requires 1 integer bit instead of 0.
-    // Note: valid diffs span [1, kInt8Range - 1], matching the uint8_t
-    // input_diff used at lookup time; i <= kInt8Range would write one past
-    // the end of the kInt8Range-entry table.
-    for (int i = 1; i < kInt8Range; i++) {
-      float scaled_input = i * input->params.scale;
-      float exp_value =
-          std::exp((-scaled_input) * static_cast<float>(params->beta));
-
-      float exponent_scaled =
-          std::round(exp_value * static_cast<float>(1 << kExpFractionalBits));
-      op_data->exp_lut[i] = static_cast<uint16_t>(exponent_scaled);
-    }
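-    // Illustrative numbers: with input scale 0.1 and beta 1.0, entry i = 10
-    // holds round(e^(-10 * 0.1) * 2^16) = round(0.36788 * 65536) = 24109,
-    // the Q0.16 encoding of e^-1.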
-  }
-  return kTfLiteOk;
-}
-
-void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-
-  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* op_data = static_cast<OpData*>(node->user_data);
-
-  // Allocate an array to precompute exponents over all int8_t inputs, applying
-  // the scale and beta before calculating exp. It is mandatory to apply beta
-  // and scale here, since each softmax op may have different beta and scale
-  // values. Beta and scale will remain constant for a given softmax op.
-  op_data->exp_lut = static_cast<uint16_t*>(context->AllocatePersistentBuffer(
-      context, kInt8Range * sizeof(uint16_t)));
-  TF_LITE_ENSURE(context, op_data->exp_lut != nullptr);
-
-  TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, op_data));
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
-  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
-
-  if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) {
-    return Softmax(*op_data, tflite::micro::GetTensorShape(input),
-                   tflite::micro::GetTensorData<int8_t>(input),
-                   tflite::micro::GetTensorShape(output),
-                   tflite::micro::GetTensorData<int16_t>(output));
-  } else {
-    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                       TfLiteTypeGetName(input->type), input->type);
-    return kTfLiteError;
-  }
-}
-
-}  // namespace
-
-TfLiteRegistration Register_SOFTMAX() {
-  return {/*init=*/SoftmaxInit,
-          /*free=*/nullptr,
-          /*prepare=*/SoftmaxPrepare,
-          /*invoke=*/SoftmaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc
deleted file mode 100644
index 28f8f1e1af0..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc
+++ /dev/null
@@ -1,420 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <math.h>
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/activation_utils.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-
-namespace tflite {
-namespace {
-
-struct OpData {
-  int32_t effective_scale_1_a;
-  int32_t effective_scale_2_a;
-  // b versions of each scale are kept at int since the numbers are just the
-  // shift value - typically between [-32, 32].
-  int effective_scale_1_b;
-  int effective_scale_2_b;
-  int scratch_tensor_index;
-  int scratch_output_tensor_index;
-
-  // Cached tensor zero point values for quantized operations.
-  int input_zero_point;
-  int output_zero_point;
-};
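-// Each effective scale decomposes as real_scale ~= multiplier * 2^shift,
-// with the multiplier in the Q-format produced by QuantizeMultiplierForInt24
-// (assumed Q0.23 here). For illustration, a hypothetical real scale of 0.75
-// could be stored as round(0.75 * 2^23) = 6291456 with shift 0.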
-
-// Input tensors.
-constexpr int kInputTensor = 0;
-constexpr int kWeightsFeatureTensor = 1;
-constexpr int kWeightsTimeTensor = 2;
-constexpr int kBiasTensor = 3;
-// This is a variable tensor, and will be modified by this op.
-constexpr int kInputActivationStateTensor = 4;
-
-// Output tensor.
-constexpr int kOutputTensor = 0;
-
-/**
- * This version of SVDF is specific to TFLite Micro. It contains only a
- * full-integer recipe with optimizations for the Xtensa HiFiMini platform.
- *
- * Note: passing OpData by value might seem like an oversight, but it helps
- * reduce latency. See b/155656675 for more details.
- */
-void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
-                     const TfLiteEvalTensor* input_tensor,
-                     const TfLiteEvalTensor* weights_feature_tensor,
-                     const TfLiteEvalTensor* weights_time_tensor,
-                     const TfLiteEvalTensor* bias_tensor,
-                     const TfLiteSVDFParams* params,
-                     TfLiteEvalTensor* activation_state_tensor,
-                     TfLiteEvalTensor* output_tensor, OpData data) {
-  const int n_rank = params->rank;
-  const int n_batch = input_tensor->dims->data[0];
-  const int n_input = input_tensor->dims->data[1];
-  const int n_filter = weights_feature_tensor->dims->data[0];
-  const int n_unit = n_filter / n_rank;
-  const int n_memory = weights_time_tensor->dims->data[1];
-
-  TFLITE_DCHECK(context != nullptr);
-  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
-
-  int32_t* scratch_tensor = static_cast<int32_t*>(
-      context->GetScratchBuffer(context, data.scratch_tensor_index));
-  TFLITE_DCHECK(scratch_tensor != nullptr);
-  int32_t* scratch_output_tensor = static_cast<int32_t*>(
-      context->GetScratchBuffer(context, data.scratch_output_tensor_index));
-  TFLITE_DCHECK(scratch_output_tensor != nullptr);
-
-  // Shift states.
-  int16_t* const state_ptr =
-      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
-
-  // Left shift the activation_state.
-  {
-    int16_t* new_state_start = state_ptr;
-    const int16_t* old_state_start = state_ptr + 1;
-    const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
-    while (old_state_start != old_state_end) {
-      *new_state_start++ = *old_state_start++;
-    }
-  }
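-  // For illustration: with n_memory = 4 and a single filter, the state
-  // [s0, s1, s2, s3] becomes [s1, s2, s3, s3]; the stale last slot is then
-  // overwritten by the feature matmul below.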
-
-  // Note: no need to clear the latest activation; the matmul below overwrites
-  // it rather than accumulating into it.
-
-  // Feature matmul.
-  {
-    const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
-    const int8_t* weight_feature =
-        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
-    int16_t* result_in_batch = state_ptr + (n_memory - 1);
-
-    ae_q56s output_int16_max_56 = AE_CVTQ48A32S(INT16_MAX);
-    ae_q56s output_int16_min_56 = AE_CVTQ48A32S(INT16_MIN);
-    ae_p24x2s input_zp_24x2 = AE_MOVPA24(data.input_zero_point);
-
-    for (int b = 0; b < n_batch; b++) {
-      const int8_t* weight_feature_ptr = weight_feature - 2;
-
-      for (int r = 0; r < n_filter; r++) {
-        ae_q56s dot_prod_56 = AE_ZEROQ56();
-
-        const int8_t* input_batch_ptr = input + b * n_input;
-        const int8_t* offset_input_batch_ptr = input_batch_ptr - 2;
-
-        int num_iters = n_input / 2;
-        for (int c = 0; c < num_iters; c++) {
-          // Load 2 sets of values:
-          ae_p24x2s weight_feature_ptr_24x2;
-          ae_p24x2s input_batch_ptr_24x2;
-          AE_LP8X2F_IU(weight_feature_ptr_24x2, weight_feature_ptr, 2);
-          AE_LP8X2F_IU(input_batch_ptr_24x2, offset_input_batch_ptr, 2);
-
-          // The 8-bit values were loaded into the top byte of each 24-bit
-          // lane; arithmetic right shift by 16 sign-extends them to proper
-          // 24-bit values:
-          weight_feature_ptr_24x2 = AE_P24X2S_SRAI(weight_feature_ptr_24x2, 16);
-          input_batch_ptr_24x2 = AE_P24X2S_SRAI(input_batch_ptr_24x2, 16);
-
-          // First subtract input_zp from input_batch_ptr_24x2:
-          input_batch_ptr_24x2 =
-              AE_SUBSP24S(input_batch_ptr_24x2, input_zp_24x2);
-
-          // Multiply accum:
-          AE_MULAAP24S_HH_LL(dot_prod_56, weight_feature_ptr_24x2,
-                             input_batch_ptr_24x2);
-        }
-
-        // Left shift 48bit value into 24bit space and place on the PR register:
-        dot_prod_56 = AE_Q56S_SLAI(dot_prod_56, 24);
-        ae_p24x2s dot_prod_24x2 = AE_TRUNCP24Q48(dot_prod_56);
-
-        dot_prod_56 = MultiplyByQuantizedMultiplier(
-            dot_prod_24x2, data.effective_scale_1_a, data.effective_scale_1_b);
-
-        // Cap min/max and convert to int32_t:
-        dot_prod_56 = AE_MAXQ56S(dot_prod_56, output_int16_min_56);
-        dot_prod_56 = AE_MINQ56S(dot_prod_56, output_int16_max_56);
-        // Truncate immediately since the QR register is already 32 bit aligned:
-        // This assumes state is symmetrically quantized. Otherwise last bit of
-        // state should be initialized to its zero point and accumulate the
-        // dot_prod.
-        // Equivalent as the following:
-        //     result_in_batch = zero point, which happens to be zero.
-        //     result_in_batch += dot_prod_56.
-        *result_in_batch = AE_TRUNCA32Q48(dot_prod_56);
-        result_in_batch += n_memory;
-      }
-    }
-  }
-
-  // Time.
-  {
-    for (int b = 0; b < n_batch; ++b) {
-      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
-
-      // Perform batched vector dot product:
-      const int16_t* vector1_ptr =
-          tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
-      const int16_t* vector2_ptr = state_ptr + b * n_memory * n_filter;
-
-      const ae_p16x2s* offset_vector1 =
-          reinterpret_cast<const ae_p16x2s*>(vector1_ptr - 2);
-      const ae_p16x2s* offset_vector2 =
-          reinterpret_cast<const ae_p16x2s*>(vector2_ptr - 2);
-
-      for (int i = 0; i < n_filter; i++) {
-        *scratch_ptr_batch = 0;
-
-        ae_q56s sum_56 = AE_ZEROQ56();
-        int num_iters = n_memory / 2;
-        for (int j = 0; j < num_iters; j++) {
-          ae_p24x2s vector1_24x2;
-          ae_p24x2s vector2_24x2;
-          AE_LP16X2F_IU(vector1_24x2, offset_vector1, 4);
-          AE_LP16X2F_IU(vector2_24x2, offset_vector2, 4);
-          AE_MULAAP24S_HH_LL(sum_56, vector1_24x2, vector2_24x2);
-        }
-        // Truncate directly since values are already 32bit aligned:
-        *scratch_ptr_batch = AE_TRUNCA32Q48(sum_56);
-        scratch_ptr_batch++;
-      }
-    }
-  }
-
-  // Reduce, add bias, rescale, activation.
-  {
-    // Add bias.
-    if (bias_tensor) {
-      // Vector batch assign:
-      const int32_t* bias_data =
-          tflite::micro::GetTensorData<int32_t>(bias_tensor);
-      for (int i = 0; i < n_batch; ++i) {
-        int32_t* output_ptr = scratch_output_tensor + i * n_unit;
-        const int32_t* bias_ptr = bias_data;
-        for (int j = 0; j < n_unit; ++j) {
-          *output_ptr++ = *bias_ptr++;
-        }
-      }
-    } else {
-      int32_t* output_ptr = scratch_output_tensor;
-      for (int i = 0; i < n_batch * n_unit; ++i) {
-        *output_ptr++ = 0;
-      }
-    }
-
-    // Reduce.
-    for (int b = 0; b < n_batch; ++b) {
-      int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
-      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
-
-      // Reduction sum vector
-      for (int i = 0; i < n_unit; ++i) {
-        for (int j = 0; j < n_rank; ++j) {
-          output_temp_ptr[i] += *scratch_ptr_batch++;
-        }
-      }
-    }
-
-    // Rescale.
-    ae_q56s output_int8_max_56 = AE_CVTQ48A32S(INT8_MAX);
-    ae_q56s output_int8_min_56 = AE_CVTQ48A32S(INT8_MIN);
-    ae_q56s output_zp_56 = AE_CVTQ48A32S(data.output_zero_point);
-    for (int i = 0; i < n_batch * n_unit; ++i) {
-      ae_q56s x_56 = MultiplyByQuantizedMultiplierResult48Bit(
-          scratch_output_tensor[i], data.effective_scale_2_a,
-          data.effective_scale_2_b);
-      // Add output adjustment:
-      x_56 = AE_ADDQ56(x_56, output_zp_56);
-      // Cap min/max and convert to int32_t (already aligned to 32bit):
-      x_56 = AE_MAXQ56S(x_56, output_int8_min_56);
-      x_56 = AE_MINQ56S(x_56, output_int8_max_56);
-      tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
-          static_cast<int8_t>(AE_TRUNCA32Q48(x_56));
-    }
-  }
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
-
-  // Validate Tensor Inputs (dtype depends on quantization):
-  // [0] = Input, {2, batch_size, input_size}
-  // [1] = Weights Feature, {2, num_filters, input_size}
-  // [2] = Weights Time, {2, num_filters, memory_size}
-  // [3] = Bias (optional), {1, num_units}
-  // [4] = Activation State (variable),
-  //         {2, batch_size, memory_size * num_filters}
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* weights_feature =
-      GetInput(context, node, kWeightsFeatureTensor);
-  const TfLiteTensor* weights_time =
-      GetInput(context, node, kWeightsTimeTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  const TfLiteTensor* activation_state =
-      GetInput(context, node, kInputActivationStateTensor);
-
-  // Define input constants based on input tensor definition above:
-  const int rank = params->rank;
-  const int input_size = input->dims->data[1];
-  const int batch_size = input->dims->data[0];
-  // Ensure the input size is a multiple of two.  This is necessary since
-  // optimized kernels access the memory in chunks of two, and all accesses
-  // must be aligned to 16 bits.
-  // TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
-  TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
-
-  const int num_filters = weights_feature->dims->data[0];
-  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
-  const int num_units = num_filters / rank;
-  const int memory_size = weights_time->dims->data[1];
-
-  if (input->type != kTfLiteInt8) {
-    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                       TfLiteTypeGetName(input->type), input->type);
-    return kTfLiteError;
-  }
-
-  // Validate Input Tensor:
-  TF_LITE_ENSURE(context, input->type == kTfLiteInt8);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
-
-  // Validate Tensor Output:
-  // [0] = float/int8_t, {2, batch_size, num_units}
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
-  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
-  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
-
-  // Validate Weights Feature Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
-  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
-
-  // Validate Weights Time Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
-  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
-  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
-
-  // Validate Optional Bias Input Tensor:
-  if (bias != nullptr) {
-    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
-    TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
-  }
-
-  // Validate Activation State Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
-  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
-  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
-                    memory_size * num_filters);
-
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
-  TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
-  TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
-  TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
-
-  // Validate output tensor:
-  TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
-
-  const double effective_scale_1 =
-      static_cast<double>(input->params.scale * weights_feature->params.scale /
-                          activation_state->params.scale);
-  const double effective_scale_2 =
-      static_cast<double>(activation_state->params.scale *
-                          weights_time->params.scale / output->params.scale);
-
-  // Guard the scale check: bias is an optional input and may be null.
-  if (bias != nullptr) {
-    TF_LITE_ENSURE_EQ(context, static_cast<double>(bias->params.scale),
-                      static_cast<double>(activation_state->params.scale *
-                                          weights_time->params.scale));
-  }
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* data = static_cast<OpData*>(node->user_data);
-
-  QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a,
-                             &data->effective_scale_1_b);
-  QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a,
-                             &data->effective_scale_2_b);
-
-  data->input_zero_point = input->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
-
-  const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-      context, batch_size * num_filters * sizeof(int32_t),
-      &(data->scratch_tensor_index));
-  TF_LITE_ENSURE_OK(context, scratch_status);
-  const TfLiteStatus scratch_output_status =
-      context->RequestScratchBufferInArena(
-          context, batch_size * num_units * sizeof(int32_t),
-          &(data->scratch_output_tensor_index));
-  TF_LITE_ENSURE_OK(context, scratch_output_status);
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = static_cast<TfLiteSVDFParams*>(node->builtin_data);
-
-  const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
-  const TfLiteEvalTensor* weights_feature =
-      tflite::micro::GetEvalInput(context, node, kWeightsFeatureTensor);
-  const TfLiteEvalTensor* weights_time =
-      tflite::micro::GetEvalInput(context, node, kWeightsTimeTensor);
-  const TfLiteEvalTensor* bias =
-      (NumInputs(node) == 5)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
-          : nullptr;
-  TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput(
-      context, node, kInputActivationStateTensor);
-  TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
-  EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
-                  params, activation_state, output, data);
-  return kTfLiteOk;
-}
-
-}  // namespace
-
-TfLiteRegistration Register_SVDF() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/fully_connected.cc
deleted file mode 100644
index f9b49a2f1ae..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/fully_connected.cc
+++ /dev/null
@@ -1,197 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h"
-namespace tflite {
-namespace ops {
-namespace micro {
-
-namespace fully_connected {
-namespace {
-
-struct OpData {
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-  // The index of the temporary tensor where the quantized inputs are cached.
-  int input_quantized_index;
-};
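-// Under the usual TFLite convention this decomposition reads
-// real_multiplier ~= (output_multiplier / 2^31) * 2^output_shift (the
-// HiFiMini helpers use a 24-bit analogue). For illustration, a hypothetical
-// real_multiplier of 0.25 is a significand of 0.5 with output_shift = -1.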
-
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFusedActivation activation,
-                             TfLiteType data_type, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
-  if (data_type != kTfLiteInt8) {
-    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                       TfLiteTypeGetName(data_type), data_type);
-    return kTfLiteError;
-  }
-
-  double real_multiplier = 0.0;
-  TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-      context, input, filter, bias, output, &real_multiplier));
-  xtensa::hifimini::QuantizeMultiplier(
-      real_multiplier, &data->output_multiplier, &data->output_shift);
-  return CalculateActivationRangeQuantized(context, activation, output,
-                                           &data->output_activation_min,
-                                           &data->output_activation_max);
-}
-
-}  // namespace
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-
-  OpData* data = static_cast<OpData*>(node->user_data);
-  const auto* params =
-      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
-
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
-  return CalculateOpData(context, params->activation, input->type, input,
-                         filter, bias, output, data);
-}
-
-TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               const OpData& data, const TfLiteTensor* input,
-                               const TfLiteTensor* filter,
-                               const TfLiteTensor* bias, TfLiteTensor* output) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  FullyConnectedParams op_params;
-  op_params.input_offset = -input->params.zero_point;
-  op_params.weights_offset = -filter->params.zero_point;
-  op_params.output_offset = output->params.zero_point;
-  op_params.output_multiplier = data.output_multiplier;
-  op_params.output_shift = data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  {
-    int ret, b, weight_depth, out_depth, batches;
-    int8_t* p_out = GetTensorData<int8_t>(output);
-    weight_depth = GetTensorShape(filter).Dims(
-        GetTensorShape(filter).DimensionsCount() - 1);
-    out_depth = GetTensorShape(output).Dims(
-        GetTensorShape(output).DimensionsCount() - 1);
-    batches = FlatSizeSkipDim(GetTensorShape(output),
-                              GetTensorShape(output).DimensionsCount() - 1);
-
-    // TODO: Use xa_nn_fully_connected_sym8xasym8s_asym8s? The kernel tests
-    // fail with it.
-    for (b = 0; b < batches; b++) {
-      ret = xa_nn_fully_connected_asym8sxasym8s_asym8s(
-          (GetTensorData<int8_t>(output) + b * out_depth),
-          GetTensorData<int8_t>(filter),
-          (GetTensorData<int8_t>(input) + b * weight_depth),
-          GetTensorData<int32_t>(bias), weight_depth, out_depth,
-          op_params.weights_offset, op_params.input_offset,
-          (op_params.output_multiplier << 8), op_params.output_shift,
-          op_params.output_offset);
-      CHECK_ERR_HIFI_NNLIB_KER(
-          ret, "xa_nn_fully_connected_asym8sxasym8s_asym8s failed");
-    }
-    ret = xa_nn_vec_activation_min_max_asym8s_asym8s(
-        p_out, p_out, data.output_activation_min, data.output_activation_max,
-        batches * out_depth);
-    CHECK_ERR_HIFI_NNLIB_KER(
-        ret,
-        "fully_connected: xa_nn_vec_activation_min_max_asym8s_asym8s failed");
-  }
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
-  TFLITE_DCHECK(filter->type == kTfLiteInt8);
-  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
-}
-
-}  // namespace fully_connected
-
-TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/fully_connected::Init,
-          /*free=*/nullptr,
-          /*prepare=*/fully_connected::Prepare,
-          /*invoke=*/fully_connected::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/quantize.cc
deleted file mode 100644
index 13c19cc6f34..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/quantize.cc
+++ /dev/null
@@ -1,172 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/quantize.h"
-
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-
-namespace xtensa {
-namespace hifimini {
-
-void AffineQuantize(int scale_multiplier,
-                    const tflite::QuantizationParams& op_params,
-                    const RuntimeShape& input_shape, const int16_t* input_data,
-                    const RuntimeShape& output_shape, int8_t* output_data) {
-  const int32_t zero_point = op_params.zero_point;
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  ae_q56s min_val_56 = AE_CVTQ48A32S(INT16_MIN);
-  ae_q56s max_val_56 = AE_CVTQ48A32S(INT16_MAX);
-  ae_q56s zero_point_56 = AE_CVTQ48A32S(zero_point);
-
-  const ae_p16x2s* input_data_ptr = (const ae_p16x2s*)(input_data - 2);
-
-  ae_p24x2s scale_multiplier_24x2 = AE_MOVPA24(scale_multiplier);
-
-  int iters = flat_size / 2;
-  for (int i = 0; i < iters; i++) {
-    // Load a pair of 16-bit values into the 2x24-bit register file PR.
-    // They land in the upper 16 bits of each 24-bit lane, so shift right by
-    // 8 bits to align them as 24-bit values:
-    ae_p24x2s inputs_24x2;
-    AE_LP16X2F_IU(inputs_24x2, input_data_ptr, 4);
-    inputs_24x2 = AE_P24X2S_SRAI(inputs_24x2, 8);
-
-    // Q0.23 * Q16.0 == Q16.23
-    {
-      ae_q56s sum_56 = AE_MULP24S_HH(scale_multiplier_24x2, inputs_24x2);
-
-      // Q16.23 -> Q16.0
-      // Shift right only 7 bits (23 - 16). This truncated shift aligns the
-      // 16bit value at the truncation line for 32bit in the QR register. The
-      // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM.
-      sum_56 = AE_Q56S_SRAI(sum_56, 7);
-
-      // Round and truncate 32 bits
-      sum_56 = AE_ROUNDSQ32SYM(sum_56);
-
-      // Add offset (zero_point_56 is already aligned at 32 bits).
-      sum_56 = AE_ADDQ56(sum_56, zero_point_56);
-
-      // Saturate:
-      sum_56 = AE_MINQ56S(sum_56, max_val_56);
-      sum_56 = AE_MAXQ56S(sum_56, min_val_56);
-
-      output_data[i * 2] = static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
-    }
-    {
-      ae_q56s sum_56 = AE_MULP24S_LL(scale_multiplier_24x2, inputs_24x2);
-
-      // Q16.23 -> Q16.0
-      // Shift right only 7 bits (23 - 16). This truncated shift aligns the
-      // 16bit value at the truncation line for 32bit in the QR register. The
-      // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM.
-      sum_56 = AE_Q56S_SRAI(sum_56, 23 - 16);
-
-      // Round and truncate 32 bits
-      sum_56 = AE_ROUNDSQ32SYM(sum_56);
-
-      // Add offset (zero_point_56 is already aligned at 32 bits).
-      sum_56 = AE_ADDQ56(sum_56, zero_point_56);
-
-      // Saturate:
-      sum_56 = AE_MINQ56S(sum_56, max_val_56);
-      sum_56 = AE_MAXQ56S(sum_56, min_val_56);
-
-      output_data[i * 2 + 1] = static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
-    }
-  }
-}
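-// Illustrative Q-format walkthrough of the code above (hypothetical values):
-// scale_multiplier holds input_scale / output_scale as Q0.23, so a ratio of
-// 0.5 is stored as round(0.5 * 2^23) = 4194304. Multiplying by a Q16.0 input
-// gives Q16.23; the 7-bit right shift then positions the result for
-// AE_ROUNDSQ32SYM's symmetric rounding.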
-
-}  // namespace hifimini
-}  // namespace xtensa
-
-namespace quantize {
-
-struct OpData {
-  int scale_multiplier = 0;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-
-  // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions.
-  op_data->scale_multiplier = xtensa::hifimini::CreateQConstantForInt24(
-      0, input->params.scale / output->params.scale);
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-
-  tflite::QuantizationParams op_params;
-  op_params.zero_point = output->params.zero_point;
-
-  if (input->type != kTfLiteInt16 || output->type != kTfLiteInt8) {
-    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
-                       TfLiteTypeGetName(input->type),
-                       TfLiteTypeGetName(output->type));
-    return kTfLiteError;
-  }
-
-  xtensa::hifimini::AffineQuantize(
-      op_data->scale_multiplier, op_params, GetTensorShape(input),
-      GetTensorData<int16_t>(input), GetTensorShape(output),
-      GetTensorData<int8_t>(output));
-  return kTfLiteOk;
-}
-
-}  // namespace quantize
-
-// This Op (QUANTIZE) requantizes the input and produces quantized output.
-// AffineQuantize takes a scale and zero point and requantizes the int16_t
-// input to int8_t output.
-TfLiteRegistration Register_QUANTIZE() {
-  return {/*init=*/quantize::Init,
-          /*free=*/nullptr,
-          /*prepare=*/quantize::Prepare,
-          /*invoke=*/quantize::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/softmax.cc
deleted file mode 100644
index 3e5ef198928..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/softmax.cc
+++ /dev/null
@@ -1,189 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/softmax.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace activations {
-namespace {
-
-struct OpData {
-  int32_t input_multiplier;
-  int32_t input_left_shift;
-  int32_t diff_min;
-  int scratch_tensor_index;
-};
-
-}  // namespace
-
-TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
-                                    const TfLiteTensor* input,
-                                    TfLiteTensor* output,
-                                    const TfLiteSoftmaxParams* params,
-                                    OpData* op_data) {
-  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
-    if (input->type == kTfLiteUInt8) {
-      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
-    } else {
-      if (output->type == kTfLiteInt16) {
-        TF_LITE_ENSURE_EQ(context, output->params.zero_point,
-                          std::numeric_limits<int16_t>::min());
-        // NOTE: Current int16_t softmax output does not require symmetric
-        // scaling, so there is no need to verify the scale here.
-      } else {
-        TF_LITE_ENSURE_EQ(context, output->params.zero_point,
-                          std::numeric_limits<int8_t>::min());
-        TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
-      }
-    }
-
-    static const int kScaledDiffIntegerBits = 5;
-
-    int input_left_shift;
-    tflite::PreprocessSoftmaxScaling(
-        static_cast<double>(params->beta),
-        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
-        &op_data->input_multiplier, &input_left_shift);
-    op_data->input_left_shift = input_left_shift;
-    op_data->diff_min =
-        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
-                                            op_data->input_left_shift);
-  }
-  return kTfLiteOk;
-}
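-// PreprocessSoftmaxScaling folds beta and the input scale into the
-// input_multiplier/input_left_shift pair with kScaledDiffIntegerBits integer
-// bits, while diff_min bounds how negative (input - max) can get before its
-// exponent underflows to zero and can be skipped.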
-
-void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-
-  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* op_data = static_cast<OpData*>(node->user_data);
-
-  const RuntimeShape& input_shape = GetTensorShape(input);
-  const RuntimeShape& output_shape = GetTensorShape(output);
-  const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int depth =
-      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-  int scratch_size =
-      xa_nn_get_softmax_scratch_size(PREC_SYM8S, PREC_SYM8S, depth);
-
-  const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-      context, scratch_size, &(op_data->scratch_tensor_index));
-  TF_LITE_ENSURE_OK(context, scratch_status);
-  // Fold beta and the input scale into fixed-point softmax parameters. They
-  // must be applied here since each softmax op may have different beta and
-  // scale values; both remain constant for a given softmax op.
-
-  TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, op_data));
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* op_data = static_cast<OpData*>(node->user_data);
-
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-
-  if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) {
-    const RuntimeShape& input_shape = GetTensorShape(input);
-    const int8_t* input_data = GetTensorData<int8_t>(input);
-    const RuntimeShape& output_shape = GetTensorShape(output);
-    int16_t* output_data = GetTensorData<int16_t>(output);
-    const int trailing_dim = input_shape.DimensionsCount() - 1;
-    const int outer_size =
-        MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-    const int depth =
-        MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
-    void* p_scratch = static_cast<void*>(
-        context->GetScratchBuffer(context, op_data->scratch_tensor_index));
-    TFLITE_DCHECK(p_scratch != nullptr);
-
-    for (int i = 0; i < outer_size; ++i) {
-      int err = xa_nn_vec_softmax_asym8s_16(
-          &output_data[i * depth], &input_data[i * depth], op_data->diff_min,
-          op_data->input_left_shift, op_data->input_multiplier, depth,
-          p_scratch);
-      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_asym8s_16 failed");
-    }
-    return kTfLiteOk;
-  } else {
-    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                       TfLiteTypeGetName(input->type), input->type);
-    return kTfLiteError;
-  }
-}
-}  // namespace activations
-
-TfLiteRegistration Register_SOFTMAX() {
-  return {/*init=*/activations::SoftmaxInit,
-          /*free=*/nullptr,
-          /*prepare=*/activations::SoftmaxPrepare,
-          /*invoke=*/activations::SoftmaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/svdf.cc
deleted file mode 100644
index 05256f33306..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/svdf.cc
+++ /dev/null
@@ -1,356 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <math.h>
-#include <xtensa/tie/xt_hifi2.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/activation_utils.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace svdf {
-namespace {
-
-struct OpData {
-  int32_t effective_scale_1_a;
-  int32_t effective_scale_2_a;
-  // b versions of each scale are kept at int since the numbers are just the
-  // shift value - typically between [-32, 32].
-  int effective_scale_1_b;
-  int effective_scale_2_b;
-  int scratch_tensor_index;
-  int scratch_output_tensor_index;
-};
-
-// Input tensors.
-constexpr int kInputTensor = 0;
-constexpr int kWeightsFeatureTensor = 1;
-constexpr int kWeightsTimeTensor = 2;
-constexpr int kBiasTensor = 3;
-// This is a variable tensor, and will be modified by this op.
-constexpr int kInputActivationStateTensor = 4;
-
-// Output tensor.
-constexpr int kOutputTensor = 0;
-
-/**
- * This version of SVDF is specific to TFLite Micro. It contains only a
- * full-integer recipe with optimizations for the Xtensa HiFiMini platform.
- *
- * Note: passing OpData by value might seem like an oversight, but it helps
- * reduce latency. See b/155656675 for more details.
- */
-TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
-                             const TfLiteTensor* input_tensor,
-                             const TfLiteTensor* weights_feature_tensor,
-                             const TfLiteTensor* weights_time_tensor,
-                             const TfLiteTensor* bias_tensor,
-                             const TfLiteSVDFParams* params,
-                             TfLiteTensor* activation_state_tensor,
-                             TfLiteTensor* output_tensor, OpData data,
-                             int32_t input_zp, int32_t output_zp) {
-  const int n_rank = params->rank;
-  const int n_batch = input_tensor->dims->data[0];
-  const int n_input = input_tensor->dims->data[1];
-  const int n_filter = weights_feature_tensor->dims->data[0];
-  const int n_unit = n_filter / n_rank;
-  const int n_memory = weights_time_tensor->dims->data[1];
-
-  TFLITE_DCHECK(context != nullptr);
-  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
-
-  int32_t* scratch_tensor = static_cast<int32_t*>(
-      context->GetScratchBuffer(context, data.scratch_tensor_index));
-  TFLITE_DCHECK(scratch_tensor != nullptr);
-  int32_t* scratch_output_tensor = static_cast<int32_t*>(
-      context->GetScratchBuffer(context, data.scratch_output_tensor_index));
-  TFLITE_DCHECK(scratch_output_tensor != nullptr);
-
-  // Shift states.
-  int16_t* const state_ptr = GetTensorData<int16_t>(activation_state_tensor);
-
-  // Left shift the activation_state.
-
-  // 4-byte alignment check for state_ptr
-  if (((reinterpret_cast<int>(state_ptr)) & 0x3) == 0) {
-    // 4-bytes aligned processing
-    ae_p16x2s* new_state_start = (ae_p16x2s*)(state_ptr - 2);
-    const ae_p16x2s* old_state_start = (ae_p16x2s*)(state_ptr - 2);
-    int loopcnt = (n_batch * n_filter * n_memory) - 1;
-    ae_p24x2s dstate, dtmp, dout;
-
-    AE_LP16X2F_IU(dtmp, old_state_start, 4);
-    AE_LP16X2F_IU(dstate, old_state_start, 4);
-    for (int i = 0; i < (loopcnt >> 1); i++) {
-      dout = AE_SELP24_LH(dtmp, dstate);
-      dtmp = dstate;
-      AE_LP16X2F_IU(dstate, old_state_start, 4);
-      AE_SP16X2F_IU(dout, new_state_start, 4);
-    }
-    if (loopcnt & 0x1) {
-      AE_SP16F_L_I(dtmp, (ae_p16s*)new_state_start, 4);
-    }
-  } else {
-    // 2-bytes aligned processing
-    ae_p16s* new_state_start = (ae_p16s*)(state_ptr - 1);
-    const ae_p16s* old_state_start = (ae_p16s*)(state_ptr);
-    int loopcnt = (n_batch * n_filter * n_memory) - 1;
-    ae_p24x2s dstate;
-    for (int i = 0; i < loopcnt; i++) {
-      AE_LP16F_IU(dstate, old_state_start, 2);
-      AE_SP16F_L_IU(dstate, new_state_start, 2);
-    }
-  }
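-  // Both branches above perform the same one-element left shift of the state
-  // buffer: the 4-byte-aligned path moves two int16_t values per iteration
-  // (AE_LP16X2F_IU / AE_SP16X2F_IU advance by 4 bytes), while the fallback
-  // path moves one value per iteration at 2-byte alignment.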
-  // Note: no need to clear the latest activation; the matmul below overwrites
-  // it rather than accumulating into it.
-
-  // Feature matmul.
-  {
-    int16_t* state = GetTensorData<int16_t>(activation_state_tensor);
-    const int8_t* input = GetTensorData<int8_t>(input_tensor);
-    const int8_t* weight_feature =
-        GetTensorData<int8_t>(weights_feature_tensor);
-    int16_t* result_in_batch = state + (n_memory - 1);
-    int err = 0;
-
-    for (int b = 0; b < n_batch; b++) {
-      err = xa_nn_matXvec_out_stride_sym8sxasym8s_16(
-          &result_in_batch[b * n_filter * n_memory], weight_feature,
-          &input[b * n_input], NULL, n_filter, n_input, n_input, n_memory,
-          -input_zp, (data.effective_scale_1_a << 8), data.effective_scale_1_b);
-      CHECK_ERR_HIFI_NNLIB_KER(
-          err, "xa_nn_matXvec_out_stride_sym8sxasym8s_16 failed");
-    }
-  }
-
-  // Time.
-  {
-    for (int b = 0; b < n_batch; ++b) {
-      int8_t* output_ptr = GetTensorData<int8_t>(output_tensor) + b * n_unit;
-
-      const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor);
-      const int16_t* vector2_ptr =
-          GetTensorData<int16_t>(activation_state_tensor) +
-          b * n_memory * n_filter;
-      int err = 0;
-      const int32_t* bias_ptr = GetTensorData<int32_t>(bias_tensor);
-      err = xa_nn_dot_prod_16x16_asym8s(
-          output_ptr, vector1_ptr, vector2_ptr, bias_ptr, n_memory * n_rank,
-          (data.effective_scale_2_a << 8), data.effective_scale_2_b, output_zp,
-          n_unit);
-      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_dot_prod_16x16_asym8s failed");
-    }
-  }
-  return kTfLiteOk;
-}
-
-}  // namespace
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  TFLITE_DCHECK(context != nullptr);
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
-
-  // Validate Tensor Inputs (dtype depends on quantization):
-  // [0] = Input, {2, batch_size, input_size}
-  // [1] = Weights Feature, {2, num_filters, input_size}
-  // [2] = Weights Time, {2, num_filters, memory_size}
-  // [3] = Bias (optional), {1, num_units}
-  // [4] = Activation State (variable),
-  //         {2, batch_size, memory_size * num_filters}
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* weights_feature =
-      GetInput(context, node, kWeightsFeatureTensor);
-  const TfLiteTensor* weights_time =
-      GetInput(context, node, kWeightsTimeTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  const TfLiteTensor* activation_state =
-      GetInput(context, node, kInputActivationStateTensor);
-
-  // Define input constants based on input tensor definition above:
-  const int rank = params->rank;
-  const int input_size = input->dims->data[1];
-  const int batch_size = input->dims->data[0];
-  // Ensure the input size is a multiple of two.  This is necessary since
-  // optimized kernels access the memory in chunks of two, and all accesses
-  // must be aligned to 16 bits.
-  // TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
-  TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
-
-  const int num_filters = weights_feature->dims->data[0];
-  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
-  const int num_units = num_filters / rank;
-  const int memory_size = weights_time->dims->data[1];
-
-  if (input->type != kTfLiteInt8) {
-    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                       TfLiteTypeGetName(input->type), input->type);
-    return kTfLiteError;
-  }
-
-  // Validate Input Tensor:
-  TF_LITE_ENSURE(context, input->type == kTfLiteInt8);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
-
-  // Validate Tensor Output:
-  // [0] = float/int8_t, {2, batch_size, num_units}
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
-  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
-  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
-
-  // Validate Weights Feature Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
-  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
-
-  // Validate Weights Time Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
-  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
-  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
-
-  // Validate Optional Bias Input Tensor:
-  if (bias != nullptr) {
-    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
-    TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
-  }
-
-  // Validate Activation State Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
-  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
-  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
-                    memory_size * num_filters);
-
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
-  TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
-  TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
-  TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
-
-  // Validate output tensor:
-  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
-
-  // Calculate effective scales.
-  auto* input_params =
-      static_cast<TfLiteAffineQuantization*>(input->quantization.params);
-  auto* weights_feature_params = static_cast<TfLiteAffineQuantization*>(
-      weights_feature->quantization.params);
-  auto* state_params = static_cast<TfLiteAffineQuantization*>(
-      activation_state->quantization.params);
-  auto* weight_time_params =
-      static_cast<TfLiteAffineQuantization*>(weights_time->quantization.params);
-  auto* output_params =
-      static_cast<TfLiteAffineQuantization*>(output->quantization.params);
-  const float effective_scale_1 = input_params->scale->data[0] *
-                                  weights_feature_params->scale->data[0] /
-                                  state_params->scale->data[0];
-  const float effective_scale_2 = state_params->scale->data[0] *
-                                  weight_time_params->scale->data[0] /
-                                  output_params->scale->data[0];
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* data = static_cast<OpData*>(node->user_data);
-
-  xtensa::hifimini::QuantizeMultiplier(effective_scale_1,
-                                       &data->effective_scale_1_a,
-                                       &data->effective_scale_1_b);
-  xtensa::hifimini::QuantizeMultiplier(effective_scale_2,
-                                       &data->effective_scale_2_a,
-                                       &data->effective_scale_2_b);
-
-  const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-      context, batch_size * num_filters * sizeof(int32_t),
-      &(data->scratch_tensor_index));
-  TF_LITE_ENSURE_OK(context, scratch_status);
-  const TfLiteStatus scratch_output_status =
-      context->RequestScratchBufferInArena(
-          context, batch_size * num_units * sizeof(int32_t),
-          &(data->scratch_output_tensor_index));
-  TF_LITE_ENSURE_OK(context, scratch_output_status);
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = static_cast<TfLiteSVDFParams*>(node->builtin_data);
-
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* weights_feature =
-      GetInput(context, node, kWeightsFeatureTensor);
-  const TfLiteTensor* weights_time =
-      GetInput(context, node, kWeightsTimeTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* activation_state =
-      GetVariableInput(context, node, kInputActivationStateTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
-  return EvalIntegerSVDF(context, node, input, weights_feature, weights_time,
-                         bias, params, activation_state, output, data,
-                         input->params.zero_point, output->params.zero_point);
-}
-
-}  // namespace svdf
-
-TfLiteRegistration Register_SVDF() {
-  return {/*init=*/svdf::Init,
-          /*free=*/nullptr,
-          /*prepare=*/svdf::Prepare,
-          /*invoke=*/svdf::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
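Aside on the (multiplier, shift) decomposition used in `Prepare` above: the two effective scales (`input_scale * feature_scale / state_scale` and `state_scale * time_scale / output_scale`) are real numbers that have to be folded into integer arithmetic. A minimal sketch of the generic TFLite-style decomposition follows; the deleted hifimini `QuantizeMultiplier` may differ in detail (its results feed the 24-bit P-register intrinsics, note the `<< 8` applied to `effective_scale_2_a` above), so treat this as the general scheme rather than the exact deleted implementation.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

/* Decompose a real multiplier m so that m ~= quantized * 2^(shift - 31),
 * with quantized a signed Q31 value in [2^30, 2^31). */
static void QuantizeMultiplierSketch(double m, int32_t* quantized, int* shift) {
  if (m == 0.0) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const double q = frexp(m, shift);  /* q in [0.5, 1), m == q * 2^shift */
  int64_t q_fixed = (int64_t)llround(q * (1ll << 31));
  assert(q_fixed <= (1ll << 31));
  if (q_fixed == (1ll << 31)) {  /* Rounding pushed q up to exactly 1.0. */
    q_fixed /= 2;
    ++(*shift);
  }
  *quantized = (int32_t)q_fixed;
}
```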
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h
deleted file mode 100644
index a3eac676bbe..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_API_DEFS_H__
-#define __XA_API_DEFS_H__
-
-/*****************************************************************************/
-/* Constant hash defines                                                     */
-/*****************************************************************************/
-/* A constant to let the API copy small strings to external buffers */
-#define XA_API_STR_LEN 30
-#define XA_APIVERSION_MAJOR 1
-#define XA_APIVERSION_MINOR 0
-
-/* last compatible version */
-/* Sometimes a new API version is just for a bugfix or an added feature; in */
-/* this case it is better to use a newer version even though a library was  */
-/* made for an older version. The library API can then be upgraded to the   */
-/* newer API version after checking for compatibility or by adding features */
-#define XA_LASTCOMP_APIVERSION_MAJOR 1
-#define XA_LASTCOMP_APIVERSION_MINOR 0
-
-#define XA_STR(str) #str
-#define XA_MAKE_VERSION_STR(maj, min) XA_STR(maj) "." XA_STR(min)
-#define XA_APIVERSION \
-  XA_MAKE_VERSION_STR(XA_APIVERSION_MAJOR, XA_APIVERSION_MINOR)
-
-#define XA_LAST_COMP_APIVERSION                     \
-  XA_MAKE_VERSION_STR(XA_LASTCOMP_APIVERSION_MAJOR, \
-                      XA_LASTCOMP_APIVERSION_MINOR)
-
-#endif /* __XA_API_DEFS_H__ */
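The two-level stringization above (`XA_STR` inside `XA_MAKE_VERSION_STR`) is what lets the numeric version macros expand before being turned into a string literal. A small self-contained illustration, reusing the deleted macros verbatim:

```c
#include <stdio.h>

#define XA_STR(str) #str
#define XA_MAKE_VERSION_STR(maj, min) XA_STR(maj) "." XA_STR(min)
#define XA_APIVERSION_MAJOR 1
#define XA_APIVERSION_MINOR 0

int main(void) {
  /* The arguments expand to 1 and 0 before XA_STR stringizes them, and
   * adjacent string literals concatenate: "1" "." "0" becomes "1.0". */
  printf("%s\n", XA_MAKE_VERSION_STR(XA_APIVERSION_MAJOR, XA_APIVERSION_MINOR));
  return 0;
}
```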
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common.h
deleted file mode 100644
index 71e668299e6..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_NNLIB_COMMON_H__
-#define __XA_NNLIB_COMMON_H__
-
-#include <inttypes.h>
-#include <stddef.h>
-#include <xtensa/config/core-isa.h>
-#include <xtensa/tie/xt_core.h>
-#include <xtensa/tie/xt_hifi2.h>
-#include <xtensa/tie/xt_misc.h>
-#if XCHAL_HAVE_HIFI4_VFPU
-#include <xtensa/tie/xt_FP.h>
-#endif
-
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h"
-
-#endif /* __XA_NNLIB_COMMON_H__ */
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h
deleted file mode 100644
index d04752b3a12..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h
+++ /dev/null
@@ -1,921 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_NNLIB_COMMON_MACROS_H__
-#define __XA_NNLIB_COMMON_MACROS_H__
-
-#ifndef NULL
-#define NULL (void *)0
-#endif /* NULL */
-
-#define ALIGNMENT 8
-
-/* Macro for zero value */
-#define ZERO64 AE_MOVINT64_FROMINT32X2(AE_MOVDA32(0))
-#define ZERO16X4 AE_MOVDA16(0)
-#define ZERO16 (0)
-#define ZERO32 (0)
-
-/* Macro for 1 */
-#define ONE16X4 AE_MOVDA16(1)
-
-/* Values of ROW_UNROLL currently supported are 1, 2, 4, 8 only */
-#ifndef ROW_UNROLL
-#define ROW_UNROLL 8
-#endif
-#define VEC_UNROLL 2
-
-#define ACC_LSH_AFTER_FIRST_MATXVEC 0
-
-/* Increment in bytes required for particular load
- * instructions. */
-#define INCREMENT_IN_BYTES_FOR_WORD8 1
-#define INCREMENT_IN_BYTES_FOR_INT16 2
-#define INCREMENT_IN_BYTES_FOR_INT32 (INCREMENT_IN_BYTES_FOR_INT16 * 2)
-#define INCREMENT_IN_BYTES_FOR_WORD8X4 (INCREMENT_IN_BYTES_FOR_WORD8 * 4)
-#define INCREMENT_IN_BYTES_FOR_INT16X4 (INCREMENT_IN_BYTES_FOR_INT16 * 4)
-#define INCREMENT_IN_BYTES_FOR_INT64 INCREMENT_IN_BYTES_FOR_INT16X4
-#define INCREMENT_IN_BYTES_FOR_FLOAT32 4
-#define INCREMENT_IN_BYTES_FOR_FLOAT32x2 (INCREMENT_IN_BYTES_FOR_FLOAT32 * 2)
-
-#define HF2_AE_ADDCIRC16X4_XC(ptr, offset) \
-  ptr = ptr + offset;                      \
-  if (ptr >= p_end) ptr = ptr - size;
-
-#define MULTIPLY_BY_QUANTIZED_MULTIPLIER(q_out, inp, out_multiplier, \
-                                         left_shift, right_shift)    \
-  {                                                                  \
-    ae_q56s d1;                                                      \
-    ae_p24x2s d_mul;                                                 \
-    d_mul = AE_CVTP24A16X2_HL(out_multiplier, out_multiplier);       \
-    d1 = AE_CVTQ48A32S(inp);                                         \
-    d1 = AE_SLLAQ56(d1, left_shift);                                 \
-    q_out = AE_MULFQ32SP16U_L(d1, d_mul);                            \
-    q_out = AE_SRAIQ56(q_out, 16);                                   \
-    AE_MULAFQ32SP16S_H(q_out, d1, d_mul);                            \
-    q_out = AE_SRAAQ56(q_out, right_shift);                          \
-    q_out = AE_ROUNDSQ32SYM(q_out);                                  \
-  }
-
-/* Limit effective bias_shift and acc_shift to [-63 ... 63] */
-#define LIMIT_VARIABLE(_var, _left_limit, _right_limit) \
-  _var = _var > _right_limit ? _right_limit             \
-                             : _var < _left_limit ? _left_limit : _var;
-
-#define LIMIT_ACC_LSH LIMIT_VARIABLE(acc_shift, -63, 63);
-
-#define LIMIT_BIAS_LSH LIMIT_VARIABLE(bias_shift, -63, 63);
-
-#define BW(_datatype) sizeof(_datatype)
-
-#define ADJUST_VAR_AxB(A, B) (((8 * (4 - (BW(A) + BW(B))))))
-
-#define ADJUST_VAR_C(C) (((64 - (8 * BW(C)))))
-
-#define ADJUST_ACC_LSH_AxB_C(A, B, C) \
-  acc_shift = acc_shift + 32;         \
-  LIMIT_ACC_LSH;
-
-#define ADJUST_BIAS_LSH_AxB(A, B) LIMIT_BIAS_LSH;
-
-#define ADJUST_ACC_LSH_AND_BIAS_LSH_AxB_C(A, B, C) \
-  ADJUST_ACC_LSH_AxB_C(A, B, C);                   \
-  ADJUST_BIAS_LSH_AxB(A, B);
-
-/* ====================================================================================================
- */
-#define SETUP_BIAS_f32                   \
-  xtfloat _xtfloat_bias = (xtfloat)0.0f; \
-  xtfloat *_xtfloat_p_bias = (xtfloat *)p_bias;
-
-#define SETUP_BIAS_ASYM8b               \
-  WORD32 _WORD32_bias;                  \
-  ae_int64 _ae_int64_sat_bias = ZERO64; \
-  WORD32 *_WORD32_p_bias = (WORD32 *)p_bias;
-
-#define SETUP_BIAS_8b                   \
-  WORD8 _WORD8_bias;                    \
-  UWORD32 _UWORD32_bias;                \
-  ae_int64 _ae_int64_bias = ZERO64;     \
-  ae_int64 _ae_int64_sat_bias = ZERO64; \
-  WORD8 *_WORD8_p_bias = (WORD8 *)p_bias;
-
-#define SETUP_BIAS_8b_BATCH                     \
-  WORD8 _WORD8_bias;                            \
-  WORD16 _WORD16_bias;                          \
-  ae_int16 _ae_int16_bias = ZERO16;             \
-  ae_int16 *_ae_int16_p_bias = &_ae_int16_bias; \
-  ae_int64 _ae_int64_sat_bias = ZERO64;         \
-  WORD8 *_WORD8_p_bias = (WORD8 *)p_bias;
-
-#define SETUP_BIAS_32b                  \
-  ae_int32 _ae_int32_bias = ZERO32;     \
-  ae_int64 _ae_int64_sat_bias = ZERO64; \
-  ae_int32 *_ae_int32_p_bias = (ae_int32 *)p_bias;
-
-#define SETUP_BIAS_16b                  \
-  ae_int16 _ae_int16_bias = ZERO16;     \
-  ae_int64 _ae_int64_sat_bias = ZERO64; \
-  ae_int16 *_ae_int16_p_bias = (ae_int16 *)p_bias;
-
-#define SETUP_BIAS_64b                  \
-  ae_int64 _ae_int64_bias = ZERO64;     \
-  ae_int64 _ae_int64_sat_bias = ZERO64; \
-  ae_int64 *_ae_int64_p_bias = (ae_int64 *)p_bias;
-
-#define SETUP_ACC_FOR_8bx8b(idx) SETUP_ACC_64b(idx)
-#define SETUP_ACC_FOR_8bx16b(idx) SETUP_ACC_64b(idx)
-#define SETUP_ACC_FOR_16bx8b(idx) SETUP_ACC_64b(idx)
-#define SETUP_ACC_FOR_16bx16b(idx) SETUP_ACC_64b(idx)
-#define SETUP_ACC_FOR_ASYM8bxASYM8b(idx) SETUP_ACC_64b(idx)
-
-/*------------------ time batching macros ----------------- */
-
-#define SETUP_ACC_BATCH_ROW_FOR_16bx8b SETUP_ACC_BATCH_ROW_FOR_16bx16b
-#define SETUP_ACC_BATCH_ROW_FOR_8bx16b SETUP_ACC_BATCH_ROW_FOR_16bx16b
-#define SETUP_ACC_BATCH_ROW_FOR_8bx8b SETUP_ACC_BATCH_ROW_FOR_16bx16b
-#define SETUP_ACC_BATCH_ROW_FOR_ASYM8bxASYM8b SETUP_ACC_BATCH_ROW_FOR_16bx16b
-
-#define SETUP_ACC_BATCH_FOR_16bx8b SETUP_ACC_BATCH_FOR_16bx16b
-#define SETUP_ACC_BATCH_FOR_8bx16b SETUP_ACC_BATCH_FOR_16bx16b
-#define SETUP_ACC_BATCH_FOR_8bx8b SETUP_ACC_BATCH_FOR_16bx16b
-#define SETUP_ACC_BATCH_FOR_ASYM8bxASYM8b SETUP_ACC_BATCH_FOR_16bx16b
-
-#define SETUP_ACC_BATCH_ROW_FOR_16bx16b(idx_row) \
-  SETUP_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define SETUP_ACC_BATCH_FOR_16bx16b(idx_row, idx_vec) \
-  ae_int64 _ae_int64_acc_##idx_row##_##idx_vec = ZERO64;
-
-#define SETUP_ACC_BATCH_ROW_FOR_f32(idx_row) \
-  SETUP_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define SETUP_ACC_BATCH_FOR_f32(idx_row, idx_vec)                   \
-  xtfloatx2 _xtfloatx2_acc_##idx_row##_##idx_vec = (xtfloatx2)0.0f; \
-  xtfloat _xtfloat_acc_##idx_row##_##idx_vec = (xtfloat)0.0f;       \
-  /*---------------------------------------------------------*/
-
-#define SETUP_ACC_64b(idx) ae_int64 _ae_int64_acc_##idx = ZERO64;
-
-#define SETUP_VEC1_8b                     \
-  ae_int16x4 _ae_int16x4_vec1 = ZERO16X4; \
-  WORD8 *_WORD8_p_vec1 = (WORD8 *)p_vec1;
-
-#define SETUP_VEC2_8b                     \
-  ae_int16x4 _ae_int16x4_vec2 = ZERO16X4; \
-  WORD8 *_WORD8_p_vec2 = (WORD8 *)p_vec2;
-
-#define SETUP_VEC1_16b                    \
-  ae_int16x4 _ae_int16x4_vec1 = ZERO16X4; \
-  ae_int16x4 *_ae_int16x4_p_vec1 = (ae_int16x4 *)p_vec1;
-
-#define SETUP_VEC2_16b                    \
-  ae_int16x4 _ae_int16x4_vec2 = ZERO16X4; \
-  ae_int16x4 *_ae_int16x4_p_vec2 = (ae_int16x4 *)p_vec2;
-
-#define SETUP_VEC1_ASYM8b SETUP_VEC1_8b
-#define SETUP_VEC2_ASYM8b SETUP_VEC2_8b
-/*------------------ time batching macros ----------------- */
-
-#define SETUP_VEC_BATCH_8b(idx_vec)                      \
-  ae_int16x4 _ae_int16x4_vec_batch_##idx_vec = ZERO16X4; \
-  WORD8 *_WORD8_p_vec_batch_##idx_vec = (WORD8 *)(p_vec1[vec_itr + idx_vec]);
-
-#define SETUP_VEC_BATCH_16b(idx_vec)                     \
-  ae_int16x4 _ae_int16x4_vec_batch_##idx_vec = ZERO16X4; \
-  ae_int16x4 *_ae_int16x4_p_vec_batch_##idx_vec =        \
-      (ae_int16x4 *)(p_vec1[vec_itr + idx_vec]);
-
-#define SETUP_VEC_OFFSET_BATCH_16b(idx_vec)              \
-  ae_int16x4 _ae_int16x4_vec_batch_##idx_vec = ZERO16X4; \
-  ae_int16x4 *_ae_int16x4_p_vec_batch_##idx_vec =        \
-      (ae_int16x4 *)(p_vec1 + (vec_itr + idx_vec) * vec_offset);
-
-#define SETUP_VEC_BATCH_f32(idx_vec)                          \
-  xtfloatx2 _xtfloatx2_vec_batch_##idx_vec = (xtfloatx2)0.0f; \
-  xtfloatx2 *_xtfloatx2_p_vec_batch_##idx_vec =               \
-      (xtfloatx2 *)(p_vec1[vec_itr + idx_vec]);
-
-#define SETUP_VEC_BATCH_ASYM8b SETUP_VEC_BATCH_8b
-/*---------------------------------------------------------*/
-
-#define SETUP_MAT1_8b(idx)                      \
-  ae_int16x4 _ae_int16x4_mat1_##idx = ZERO16X4; \
-  WORD8 *_WORD8_p_mat1_##idx = (WORD8 *)&p_mat1[(m_itr + idx) * row_stride1];
-
-#define SETUP_MAT2_8b(idx)                      \
-  ae_int16x4 _ae_int16x4_mat2_##idx = ZERO16X4; \
-  WORD8 *_WORD8_p_mat2_##idx = (WORD8 *)&p_mat2[(m_itr + idx) * row_stride2];
-
-#define SETUP_MAT1_16b(idx)                     \
-  ae_int16x4 _ae_int16x4_mat1_##idx = ZERO16X4; \
-  ae_int16x4 *_ae_int16x4_p_mat1_##idx =        \
-      (ae_int16x4 *)&p_mat1[(m_itr + idx) * row_stride1];
-
-#define SETUP_MAT2_16b(idx)                     \
-  ae_int16x4 _ae_int16x4_mat2_##idx = ZERO16X4; \
-  ae_int16x4 *_ae_int16x4_p_mat2_##idx =        \
-      (ae_int16x4 *)&p_mat2[(m_itr + idx) * row_stride2];
-
-#define SETUP_MAT1_f32(idx)                          \
-  xtfloatx2 _xtfloatx2_mat1_##idx = (xtfloatx2)0.0f; \
-  xtfloatx2 *_xtfloatx2_p_mat1_##idx =               \
-      (xtfloatx2 *)&p_mat1[(m_itr + idx) * row_stride1];
-
-#define SETUP_MAT1_ASYM8b SETUP_MAT1_8b
-#define SETUP_MAT2_ASYM8b SETUP_MAT2_8b
-/* ====================================================================== */
-
-#define LOAD_VEC1_8b \
-  AE_L8X4F_IP(_ae_int16x4_vec1, _WORD8_p_vec1, INCREMENT_IN_BYTES_FOR_WORD8X4);
-
-#define LOAD_VEC2_8b \
-  AE_L8X4F_IP(_ae_int16x4_vec2, _WORD8_p_vec2, INCREMENT_IN_BYTES_FOR_WORD8X4);
-
-#define LOAD_VEC1_16b                               \
-  AE_L16X4_IP(_ae_int16x4_vec1, _ae_int16x4_p_vec1, \
-              INCREMENT_IN_BYTES_FOR_INT16X4);
-
-#define LOAD_VEC2_16b                               \
-  AE_L16X4_IP(_ae_int16x4_vec2, _ae_int16x4_p_vec2, \
-              INCREMENT_IN_BYTES_FOR_INT16X4);
-
-#define LOAD_VEC1_ASYM8b                                    \
-  AE_L8X4F_IP(_ae_int16x4_vec1, _WORD8_p_vec1,              \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);              \
-  _ae_int16x4_vec1 = AE_MOVF16X4_FROMF64(                   \
-      AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_vec1), 8)); \
-  _ae_int16x4_vec1 = AE_ADD16(_ae_int16x4_vec1, AE_MOVDA16(vec1_zero_bias));
-
-#define LOAD_VEC2_ASYM8b                                                     \
-  AE_L8X4F_IP(_ae_int16x4_vec2, _WORD8_p_vec2,                               \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);                               \
-  _ae_int16x4_vec2 = AE_MOVF16X4_FROMF64(                                    \
-      AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_vec2), 8));                  \
-  _ae_int16x4_vec2 = AE_ADD16(_ae_int16x4_vec2, AE_MOVDA16(vec2_zero_bias)); \
-/*------------------ time batching macros ----------------- */
-#define LOAD_VEC_BATCH_f32(idx_vec)                                           \
-  XT_LSX2IP(_xtfloatx2_vec_batch_##idx_vec, _xtfloatx2_p_vec_batch_##idx_vec, \
-            INCREMENT_IN_BYTES_FOR_FLOAT32x2);
-
-#define LOAD_VEC_BATCH_8b(idx_vec)                                           \
-  AE_L8X4F_IP(_ae_int16x4_vec_batch_##idx_vec, _WORD8_p_vec_batch_##idx_vec, \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);
-
-#define LOAD_VEC_BATCH_16b(idx_vec)              \
-  AE_L16X4_IP(_ae_int16x4_vec_batch_##idx_vec,   \
-              _ae_int16x4_p_vec_batch_##idx_vec, \
-              INCREMENT_IN_BYTES_FOR_INT16X4);
-
-#define LOAD_VEC_BATCH_ASYM8b(idx_vec)                                       \
-  AE_L8X4F_IP(_ae_int16x4_vec_batch_##idx_vec, _WORD8_p_vec_batch_##idx_vec, \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);                               \
-  _ae_int16x4_vec_batch_##idx_vec = AE_MOVF16X4_FROMF64(                     \
-      AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_vec_batch_##idx_vec), 8));   \
-  _ae_int16x4_vec_batch_##idx_vec =                                          \
-      AE_ADD16(_ae_int16x4_vec_batch_##idx_vec, AE_MOVDA16(vec1_zero_bias));
-
-#define LOAD_BIAS_8b_FOR_8bx8b                  \
-  _WORD8_bias = *_WORD8_p_bias++;               \
-  _WORD16_bias = _WORD8_bias;                   \
-  *((WORD16 *)_ae_int16_p_bias) = _WORD16_bias; \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift);
-
-#define LOAD_BIAS_16b_FOR_8bx16b                    \
-  ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias, \
-                  INCREMENT_IN_BYTES_FOR_INT16);    \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift);
-
-#define LOAD_BIAS_16b_FOR_16bx8b LOAD_BIAS_16b_FOR_8bx16b
-
-#define LOAD_BIAS_16b_FOR_16bx16b                   \
-  ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias, \
-                  INCREMENT_IN_BYTES_FOR_INT16);    \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift);
-
-#define LOAD_BIAS_f32 \
-  XT_LSIP(_xtfloat_bias, _xtfloat_p_bias, INCREMENT_IN_BYTES_FOR_FLOAT32);
-
-#define LOAD_BIAS_ASYM8b                                                \
-  _WORD32_bias = *_WORD32_p_bias++;                                     \
-  _ae_int64_sat_bias =                                                  \
-      AE_SRAI64(AE_MOVINT64_FROMINT32X2(AE_MOVDA32(_WORD32_bias)), 32); \
-/*---------------------------------------------------------*/
-#define LOAD_ROW_MAT1_8b(idx)                              \
-  AE_L8X4F_IP(_ae_int16x4_mat1_##idx, _WORD8_p_mat1_##idx, \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);
-
-#define LOAD_ROW_MAT2_8b(idx)                              \
-  AE_L8X4F_IP(_ae_int16x4_mat2_##idx, _WORD8_p_mat2_##idx, \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);
-
-#define LOAD_ROW_MAT1_16b(idx)                                  \
-  AE_L16X4_IP(_ae_int16x4_mat1_##idx, _ae_int16x4_p_mat1_##idx, \
-              INCREMENT_IN_BYTES_FOR_INT16X4);
-
-#define LOAD_ROW_MAT2_16b(idx)                                  \
-  AE_L16X4_IP(_ae_int16x4_mat2_##idx, _ae_int16x4_p_mat2_##idx, \
-              INCREMENT_IN_BYTES_FOR_INT16X4);
-
-#define LOAD_ROW_MAT1_f32(idx)                              \
-  XT_LSX2IP(_xtfloatx2_mat1_##idx, _xtfloatx2_p_mat1_##idx, \
-            INCREMENT_IN_BYTES_FOR_FLOAT32x2);
-
-#define LOAD_ROW_MAT1_ASYM8b(idx)                                 \
-  AE_L8X4F_IP(_ae_int16x4_mat1_##idx, _WORD8_p_mat1_##idx,        \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);                    \
-  _ae_int16x4_mat1_##idx = AE_MOVF16X4_FROMF64(                   \
-      AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_mat1_##idx), 8)); \
-  _ae_int16x4_mat1_##idx =                                        \
-      AE_ADD16(_ae_int16x4_mat1_##idx, AE_MOVDA16(mat1_zero_bias));
-
-#define LOAD_ROW_MAT2_ASYM8b(idx)                                 \
-  AE_L8X4F_IP(_ae_int16x4_mat2_##idx, _WORD8_p_mat2_##idx,        \
-              INCREMENT_IN_BYTES_FOR_WORD8X4);                    \
-  _ae_int16x4_mat2_##idx = AE_MOVF16X4_FROMF64(                   \
-      AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_mat2_##idx), 8)); \
-  _ae_int16x4_mat2_##idx =                                        \
-      AE_ADD16(_ae_int16x4_mat2_##idx, AE_MOVDA16(mat2_zero_bias));
-
-#define KERNEL_MAT1_VEC1_8b_8b(idx) \
-  LOAD_ROW_MAT1_8b(idx);            \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx);
-
-#define KERNEL_MAT2_VEC2_8b_8b(idx) \
-  LOAD_ROW_MAT2_8b(idx);            \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx);
-
-#define KERNEL_MAT1_VEC1_16b_8b(idx) \
-  LOAD_ROW_MAT1_16b(idx);            \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx);
-
-#define KERNEL_MAT2_VEC2_16b_8b(idx) \
-  LOAD_ROW_MAT2_16b(idx);            \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx);
-
-#define KERNEL_MAT1_VEC1_8b_16b(idx) \
-  LOAD_ROW_MAT1_8b(idx);             \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx);
-
-#define KERNEL_MAT2_VEC2_8b_16b(idx) \
-  LOAD_ROW_MAT2_8b(idx);             \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx);
-
-#define KERNEL_MAT1_VEC1_16b_16b(idx) \
-  LOAD_ROW_MAT1_16b(idx);             \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx);
-
-#define KERNEL_MAT2_VEC2_16b_16b(idx) \
-  LOAD_ROW_MAT2_16b(idx);             \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx);
-
-#define KERNEL_MAT1_VEC1_ASYM8b_ASYM8b(idx) \
-  LOAD_ROW_MAT1_ASYM8b(idx);                \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx);
-
-#define KERNEL_MAT2_VEC2_ASYM8b_ASYM8b(idx) \
-  LOAD_ROW_MAT2_ASYM8b(idx);                \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx);
-
-/*------------------ time batching macros ----------------- */
-
-#define KERNEL_MAT1_VEC_BATCH_ROW_8b_8b KERNEL_MAT1_VEC_BATCH_ROW_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_ROW_16b_8b KERNEL_MAT1_VEC_BATCH_ROW_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_ROW_8b_16b KERNEL_MAT1_VEC_BATCH_ROW_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_ROW_ASYM8b_ASYM8b \
-  KERNEL_MAT1_VEC_BATCH_ROW_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_8b_8b KERNEL_MAT1_VEC_BATCH_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_16b_8b KERNEL_MAT1_VEC_BATCH_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_8b_16b KERNEL_MAT1_VEC_BATCH_16b_16b
-#define KERNEL_MAT1_VEC_BATCH_ASYM8b_ASYM8b KERNEL_MAT1_VEC_BATCH_16b_16b
-
-#define KERNEL_MAT1_VEC_BATCH_ROW_16b_16b(idx_row) \
-  KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row);
-
-#define KERNEL_MAT1_VEC_BATCH_16b_16b(idx_row, idx_vec) \
-  AE_MULAAAAQ16(_ae_int64_acc_##idx_row##_##idx_vec,    \
-                _ae_int16x4_vec_batch_##idx_vec, _ae_int16x4_mat1_##idx_row);
-
-#define KERNEL_MAT1_VEC_BATCH_ROW_f32(idx_row) \
-  KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row);
-
-#define KERNEL_MAT1_VEC_BATCH_f32(idx_row, idx_vec) \
-  XT_MADD_SX2(_xtfloatx2_acc_##idx_row##_##idx_vec, \
-              _xtfloatx2_vec_batch_##idx_vec, _xtfloatx2_mat1_##idx_row);
-
-/*---------------------------------------------------------*/
-#define ADD_BIAS_8b_ACC_FOR_8bx8b(idx)                                        \
-  /* Load 8b bias */                                                          \
-  _WORD8_bias = *_WORD8_p_bias++;                                             \
-  /* Copy 8 bits to unsigned 32 bits */                                       \
-  _UWORD32_bias = _WORD8_bias;                                                \
-  /* Move unsigned 32-bit value to DR register */                            \
-  _ae_int64_bias = AE_MOVINT64_FROMINT32X2((AE_MOVDA32X2(_UWORD32_bias, 0))); \
-  _ae_int64_bias = AE_SRAA64(_ae_int64_bias, 32);                             \
-  _ae_int64_sat_bias = AE_SLAA64S(_ae_int64_bias, bias_shift);                \
-  _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 16);                   \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-#define ADD_BIAS_32b_ACC_FOR_8bx8b(idx)                                    \
-  ae_int32_loadip(_ae_int32_bias, _ae_int32_p_bias,                        \
-                  INCREMENT_IN_BYTES_FOR_INT32);                           \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int32_bias), bias_shift); \
-  _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 16);                \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-#define ADD_BIAS_16b_ACC_FOR_8bx16b(idx)                                   \
-  ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias,                        \
-                  INCREMENT_IN_BYTES_FOR_INT16);                           \
-  /* Saturate 16b bias after shift to 64b */                               \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); \
-  _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 8);                 \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-#define ADD_BIAS_16b_ACC_FOR_16bx8b ADD_BIAS_16b_ACC_FOR_8bx16b
-
-#define ADD_BIAS_64b_ACC_FOR_8bx16b(idx)                                   \
-  ae_int64_loadip(_ae_int64_bias, _ae_int64_p_bias,                        \
-                  INCREMENT_IN_BYTES_FOR_INT64);                           \
-  /* Saturate 64b bias after shift to 64b */                               \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int64_bias), bias_shift); \
-  _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 8);                 \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-#define ADD_BIAS_16b_ACC_FOR_16bx16b(idx)                                  \
-  ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias,                        \
-                  INCREMENT_IN_BYTES_FOR_INT16);                           \
-  /* Saturate 16b bias after shift to 64b */                               \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-#define ADD_BIAS_64b_ACC_FOR_16bx16b(idx)                                  \
-  ae_int64_loadip(_ae_int64_bias, _ae_int64_p_bias,                        \
-                  INCREMENT_IN_BYTES_FOR_INT64);                           \
-  /* Saturate 64b bias after shift to 64b */                               \
-  _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int64_bias), bias_shift); \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-#define ADD_BIAS_ASYM8b_ACC_FOR_ASYM8bxASYM8b(idx)                      \
-  /* Load 32b bias */                                                   \
-  _WORD32_bias = *_WORD32_p_bias++;                                     \
-  _ae_int64_sat_bias =                                                  \
-      AE_SRAI64(AE_MOVINT64_FROMINT32X2(AE_MOVDA32(_WORD32_bias)), 32); \
-  _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias);
-
-/*------------------ time batching macros ----------------- */
-#define ADD_BIAS_BATCH_ROW_8b_ACC_FOR_8bx8b(idx_row) \
-  LOAD_BIAS_8b_FOR_8bx8b;                            \
-  ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row);
-
-#define ADD_BIAS_BATCH_ROW_16b_ACC_FOR_8bx16b(idx_row) \
-  LOAD_BIAS_16b_FOR_8bx16b;                            \
-  ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row);
-
-#define ADD_BIAS_BATCH_ROW_16b_ACC_FOR_16bx8b(idx_row) \
-  LOAD_BIAS_16b_FOR_16bx8b;                            \
-  ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row);
-
-#define ADD_BIAS_BATCH_ROW_16b_ACC_FOR_16bx16b(idx_row) \
-  LOAD_BIAS_16b_FOR_16bx16b;                            \
-  ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row);
-
-#define ADD_BIAS_BATCH_ROW_ASYM8b_ACC_FOR_ASYM8bxASYM8b(idx_row) \
-  LOAD_BIAS_ASYM8b ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row);
-
-#define ADD_BIAS_BATCH_8b_ACC_FOR_8bx8b(idx_row, idx_vec) \
-  _ae_int64_acc_##idx_row##_##idx_vec =                   \
-      AE_SRAA64(_ae_int64_acc_##idx_row##_##idx_vec, 16); \
-  _ae_int64_acc_##idx_row##_##idx_vec =                   \
-      AE_ADD64S(_ae_int64_acc_##idx_row##_##idx_vec, _ae_int64_sat_bias);
-
-#define ADD_BIAS_BATCH_16b_ACC_FOR_8bx16b(idx_row, idx_vec) \
-  _ae_int64_acc_##idx_row##_##idx_vec =                     \
-      AE_SRAA64(_ae_int64_acc_##idx_row##_##idx_vec, 8);    \
-  _ae_int64_acc_##idx_row##_##idx_vec =                     \
-      AE_ADD64S(_ae_int64_acc_##idx_row##_##idx_vec, _ae_int64_sat_bias);
-
-#define ADD_BIAS_BATCH_16b_ACC_FOR_16bx16b(idx_row, idx_vec) \
-  _ae_int64_acc_##idx_row##_##idx_vec =                      \
-      AE_ADD64S(_ae_int64_acc_##idx_row##_##idx_vec, _ae_int64_sat_bias);
-
-#define ADD_BIAS_BATCH_16b_ACC_FOR_16bx8b ADD_BIAS_BATCH_16b_ACC_FOR_8bx16b
-#define ADD_BIAS_BATCH_ASYM8b_ACC_FOR_ASYM8bxASYM8b \
-  ADD_BIAS_BATCH_16b_ACC_FOR_16bx16b
-
-#define ADD_BIAS_BATCH_ROW_ACC_FOR_f32(idx_row) \
-  LOAD_BIAS_f32;                                \
-  ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row);
-
-#define ADD_BIAS_BATCH_ACC_FOR_f32(idx_row, idx_vec)     \
-  _xtfloat_acc_##idx_row##_##idx_vec =                   \
-      XT_RADD_SX2(_xtfloatx2_acc_##idx_row##_##idx_vec); \
-  _xtfloat_acc_##idx_row##_##idx_vec =                   \
-      XT_ADD_S(_xtfloat_acc_##idx_row##_##idx_vec, _xtfloat_bias);
-
-#define STORE_ACC_8bx8b_AT_SCRATCH_32b(idx)  \
-  (*((ae_int32 *)p_scratch + m_itr + idx)) = \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift));
-
-#define STORE_ACC_8bx8b_AT_OUT_8b(idx)                                    \
-  ae_int32 _ae_int32_tmp_var_##idx;                                       \
-  ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S(                          \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 24); \
-  _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -24);     \
-  (*((WORD8 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx));
-
-#define STORE_ACC_8bx8b_AT_OUT_16b(idx)                                   \
-  ae_int32 _ae_int32_tmp_var_##idx;                                       \
-  ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S(                          \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 16); \
-  _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -16);     \
-  (*((WORD16 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx));
-
-#define STORE_ACC_8bx8b_AT_OUT_32b(idx)  \
-  (*((ae_int32 *)p_out + m_itr + idx)) = \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift));
-
-#define STORE_ACC_ASYM8bxASYM8b_AT_OUT_ASYM8b(idx)                      \
-  _ae_int32x2_acc_##idx = AE_MIN32(                                     \
-      AE_MAX32(_ae_int32x2_acc_##idx, AE_MOVDA32(0)), AE_MOVDA32(255)); \
-  (*((UWORD8 *)p_out + m_itr + idx)) =                                  \
-      (UWORD8)AE_MOVAD32_L(_ae_int32x2_acc_##idx);
-
-/* ====================================================================================================
- */
-#define STORE_ACC_8bx16b_AT_SCRATCH_32b(idx) \
-  (*((ae_int32 *)p_scratch + m_itr + idx)) = \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift));
-
-#define STORE_ACC_8bx16b_AT_OUT_16b(idx)                                  \
-  ae_int32 _ae_int32_tmp_var_##idx;                                       \
-  ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S(                          \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 16); \
-  _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -16);     \
-  (*((WORD16 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx));
-
-#define STORE_ACC_16bx8b_AT_OUT_16b STORE_ACC_8bx16b_AT_OUT_16b
-
-#define STORE_ACC_8bx16b_AT_OUT_32b(idx) \
-  (*((ae_int32 *)p_out + m_itr + idx)) = \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift));
-
-#define STORE_ACC_8bx16b_AT_OUT_64b(idx) \
-  (*((ae_int64 *)p_out + m_itr + idx)) = \
-      AE_SLAA64S(_ae_int64_acc_##idx, acc_shift);
-
-/* ====================================================================================================
- */
-#define STORE_ACC_16bx16b_AT_SCRATCH_32b(idx) \
-  (*((ae_int32 *)p_scratch + m_itr + idx)) =  \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift));
-
-#define STORE_ACC_16bx16b_AT_OUT_16b(idx)                                 \
-  ae_int32 _ae_int32_tmp_var_##idx;                                       \
-  ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S(                          \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 16); \
-  _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -16);     \
-  (*((WORD16 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx));
-
-#define STORE_ACC_16bx16b_AT_OUT_32b(idx) \
-  (*((ae_int32 *)p_out + m_itr + idx)) =  \
-      AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift));
-
-#define STORE_ACC_16bx16b_AT_OUT_64b(idx) \
-  (*((ae_int64 *)p_out + m_itr + idx)) =  \
-      AE_SLAA64S(_ae_int64_acc_##idx, acc_shift);
-
-/*------------------ time batching macros ----------------- */
-#define STORE_ACC_BATCH_ROW_8bx8b_AT_OUT_32b(idx_row) \
-  STORE_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define STORE_ACC_BATCH_ROW_8bx8b_AT_OUT_8b(idx_row) \
-  STORE_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define STORE_ACC_BATCH_8bx8b_AT_OUT_32b(idx_row, idx_vec)      \
-  (*((ae_int32 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \
-      AE_ROUND32F64SSYM(                                        \
-          AE_SLAA64S(_ae_int64_acc_##idx_row##_##idx_vec, acc_shift));
-
-#define STORE_ACC_BATCH_8bx8b_AT_OUT_8b(idx_row, idx_vec)              \
-  ae_int32 _ae_int32_tmp_var_##idx_row##_##idx_vec;                    \
-  ae_f32x2 _ae_f32x2_tmp_var_##idx_row##_##idx_vec =                   \
-      AE_SLAA32S(AE_ROUND32F64SSYM(AE_SLAA64S(                         \
-                     _ae_int64_acc_##idx_row##_##idx_vec, acc_shift)), \
-                 24);                                                  \
-  _ae_int32_tmp_var_##idx_row##_##idx_vec =                            \
-      AE_SLAA32S(_ae_f32x2_tmp_var_##idx_row##_##idx_vec, -24);        \
-  (*((WORD8 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) =           \
-      (*((UWORD32 *)&_ae_int32_tmp_var_##idx_row##_##idx_vec));
-
-#define STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_64b(idx_row) \
-  STORE_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define STORE_ACC_BATCH_ROW_16bx8b_AT_OUT_16b \
-  STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_64b
-
-#define STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_16b \
-  STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_64b
-
-#define STORE_ACC_BATCH_8bx16b_AT_OUT_64b(idx_row, idx_vec)     \
-  (*((ae_int64 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \
-      AE_SLAA64S(_ae_int64_acc_##idx_row##_##idx_vec, acc_shift);
-
-#define STORE_ACC_BATCH_8bx16b_AT_OUT_16b(idx_row, idx_vec) \
-  STORE_ACC_BATCH_16bx16b_AT_OUT_16b(idx_row, idx_vec);
-
-#define STORE_ACC_BATCH_ROW_16bx16b_AT_OUT_64b(idx_row) \
-  STORE_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define STORE_ACC_BATCH_ROW_16bx16b_AT_OUT_16b \
-  STORE_ACC_BATCH_ROW_16bx16b_AT_OUT_64b
-
-#define STORE_ACC_BATCH_16bx16b_AT_OUT_64b(idx_row, idx_vec)    \
-  (*((ae_int64 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \
-      AE_SLAA64S(_ae_int64_acc_##idx_row##_##idx_vec, acc_shift);
-
-#define STORE_STRIDE_ACC_BATCH_16bx16b_AT_OUT_16b(idx_row, idx_vec)    \
-  ae_int32 _ae_int32_tmp_var_##idx_row##_##idx_vec;                    \
-  ae_f32x2 _ae_f32x2_tmp_var_##idx_row##_##idx_vec =                   \
-      AE_SLAA32S(AE_ROUND32F64SSYM(AE_SLAA64S(                         \
-                     _ae_int64_acc_##idx_row##_##idx_vec, acc_shift)), \
-                 16);                                                  \
-  _ae_int32_tmp_var_##idx_row##_##idx_vec =                            \
-      AE_SLAA32S(_ae_f32x2_tmp_var_##idx_row##_##idx_vec, -16);        \
-  (*((WORD16 *)p_out + (vec_itr + idx_vec) * out_offset +              \
-     (m_itr + idx_row) * out_stride)) =                                \
-      (*((UWORD32 *)&_ae_int32_tmp_var_##idx_row##_##idx_vec));
-
-#define STORE_ACC_BATCH_ROW_AT_OUT_f32(idx_row) \
-  STORE_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define STORE_ACC_BATCH_AT_OUT_f32(idx_row, idx_vec)                \
-  /* p_out value staged in a tmp pointer so XT_SSIP can use it as in/out */ \
-  p_out_tmp = (p_out[vec_itr + idx_vec] + m_itr + idx_row);         \
-  XT_SSIP(_xtfloat_acc_##idx_row##_##idx_vec, p_out_tmp, 0);
-
-#define STORE_ACC_BATCH_ROW_ASYM8bxASYM8b_AT_OUT_ASYM8b(idx_row) \
-  STORE_ACC_BATCH_VEC_UNROLL(idx_row);
-
-#define STORE_ACC_BATCH_ASYM8bxASYM8b_AT_OUT_ASYM8b(idx_row, idx_vec)          \
-  _ae_int32x2_acc_##idx_row##_##idx_vec =                                      \
-      AE_MIN32(AE_MAX32(_ae_int32x2_acc_##idx_row##_##idx_vec, AE_MOVDA32(0)), \
-               AE_MOVDA32(255));                                               \
-  (*((UWORD8 *)(p_out[vec_itr + idx_vec] + m_itr + idx_row))) =                \
-      (UWORD8)AE_MOVAD32_L(_ae_int32x2_acc_##idx_row##_##idx_vec);
-
-/*---------------------------------------------------------*/
-/* Specific macros needed for the extra calculations
-   required for ASYM8b */
-
-/* This is written to match TensorFlow */
-#define ADJUST_ACC_ASYM8b(idx)                                             \
-  /* Multiply accumulator by 'out_multiplier', same as TensorFlow */       \
-  ae_int32x2 _ae_int32x2_acc_##idx =                                       \
-      AE_SLAA32(AE_MOVINT32X2_FROMINT64(_ae_int64_acc_##idx), left_shift); \
-  _ae_int32x2_acc_##idx =                                                  \
-      AE_MULFP32X2RAS(_ae_int32x2_acc_##idx, AE_MOVDA32(out_multiplier));  \
-  /* Shift by out_shift, same as TensorFlow */                             \
-  _ae_int64_acc_##idx =                                                    \
-      AE_SLAI64(AE_MOVINT64_FROMINT32X2(_ae_int32x2_acc_##idx), 32);       \
-  _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, right_shift);       \
-  _ae_int32x2_acc_##idx = AE_ROUND32F64SSYM(_ae_int64_acc_##idx);          \
-  /* Add output zero point */                                              \
-  (_ae_int32x2_acc_##idx) =                                                \
-      AE_ADD32S(_ae_int32x2_acc_##idx, AE_MOVDA32(out_zero_bias));
-
-/* For time batching */
-#define ADJUST_ACC_BATCH_ROW_ASYM8b(idx_row) \
-  ADJUST_ACC_BATCH_VEC_UNROLL(idx_row);
-
-/* For time batching */
-#define ADJUST_ACC_BATCH_ASYM8b(idx_row, idx_vec)                             \
-  /* Multiply accumulator by 'out_multiplier', same as TensorFlow */         \
-  ae_int32x2 _ae_int32x2_acc_##idx_row##_##idx_vec =                          \
-      AE_SLAA32(AE_MOVINT32X2_FROMINT64(_ae_int64_acc_##idx_row##_##idx_vec), \
-                left_shift);                                                  \
-  _ae_int32x2_acc_##idx_row##_##idx_vec = AE_MULFP32X2RAS(                    \
-      _ae_int32x2_acc_##idx_row##_##idx_vec, AE_MOVDA32(out_multiplier));     \
-  /* Shift by out_shift, same as TensorFlow */                                \
-  _ae_int64_acc_##idx_row##_##idx_vec = AE_SLAI64(                            \
-      AE_MOVINT64_FROMINT32X2(_ae_int32x2_acc_##idx_row##_##idx_vec), 32);    \
-  _ae_int64_acc_##idx_row##_##idx_vec =                                       \
-      AE_SRAA64(_ae_int64_acc_##idx_row##_##idx_vec, right_shift);            \
-  _ae_int32x2_acc_##idx_row##_##idx_vec =                                     \
-      AE_ROUND32F64SSYM(_ae_int64_acc_##idx_row##_##idx_vec);                 \
-  /* Add output zero point */                                                 \
-  (_ae_int32x2_acc_##idx_row##_##idx_vec) = AE_ADD32S(                        \
-      _ae_int32x2_acc_##idx_row##_##idx_vec, AE_MOVDA32(out_zero_bias));
-
-/*---------------------------------------------------------*/
-/* ====================================================================================================
- */
-#if (ROW_UNROLL == 1)
-#define SETUP_ACC UNROLL_SETUP_ACC(0)
-#define SETUP_MAT1 UNROLL_SETUP_MAT1(0)
-#define SETUP_MAT2 UNROLL_SETUP_MAT2(0)
-#define KERNEL_MAT1_VEC1 UNROLL_KERNEL_MAT1_VEC1(0)
-#define KERNEL_MAT2_VEC2 UNROLL_KERNEL_MAT2_VEC2(0)
-#define ADD_BIAS_ACC UNROLL_ADD_BIAS_ACC(0)
-#define ADJUST_ACC UNROLL_ADJUST_ACC(0)
-#define STORE_ACC UNROLL_STORE_ACC(0)
-
-#elif (ROW_UNROLL == 2)
-#define SETUP_ACC UNROLL_SETUP_ACC(0) UNROLL_SETUP_ACC(1)
-#define SETUP_MAT1 UNROLL_SETUP_MAT1(0) UNROLL_SETUP_MAT1(1)
-#define SETUP_MAT2 UNROLL_SETUP_MAT2(0) UNROLL_SETUP_MAT2(1)
-#define KERNEL_MAT1_VEC1 UNROLL_KERNEL_MAT1_VEC1(0) UNROLL_KERNEL_MAT1_VEC1(1)
-#define KERNEL_MAT2_VEC2 UNROLL_KERNEL_MAT2_VEC2(0) UNROLL_KERNEL_MAT2_VEC2(1)
-#define ADD_BIAS_ACC UNROLL_ADD_BIAS_ACC(0) UNROLL_ADD_BIAS_ACC(1)
-#define ADJUST_ACC UNROLL_ADJUST_ACC(0) UNROLL_ADJUST_ACC(1)
-#define STORE_ACC UNROLL_STORE_ACC(0) UNROLL_STORE_ACC(1)
-
-#elif (ROW_UNROLL == 4)
-#define SETUP_ACC     \
-  UNROLL_SETUP_ACC(0) \
-  UNROLL_SETUP_ACC(1) UNROLL_SETUP_ACC(2) UNROLL_SETUP_ACC(3)
-#define SETUP_MAT1     \
-  UNROLL_SETUP_MAT1(0) \
-  UNROLL_SETUP_MAT1(1) UNROLL_SETUP_MAT1(2) UNROLL_SETUP_MAT1(3)
-#define SETUP_MAT2     \
-  UNROLL_SETUP_MAT2(0) \
-  UNROLL_SETUP_MAT2(1) UNROLL_SETUP_MAT2(2) UNROLL_SETUP_MAT2(3)
-#define KERNEL_MAT1_VEC1     \
-  UNROLL_KERNEL_MAT1_VEC1(0) \
-  UNROLL_KERNEL_MAT1_VEC1(1) \
-  UNROLL_KERNEL_MAT1_VEC1(2) UNROLL_KERNEL_MAT1_VEC1(3)
-#define KERNEL_MAT2_VEC2     \
-  UNROLL_KERNEL_MAT2_VEC2(0) \
-  UNROLL_KERNEL_MAT2_VEC2(1) \
-  UNROLL_KERNEL_MAT2_VEC2(2) UNROLL_KERNEL_MAT2_VEC2(3)
-#define ADD_BIAS_ACC     \
-  UNROLL_ADD_BIAS_ACC(0) \
-  UNROLL_ADD_BIAS_ACC(1) UNROLL_ADD_BIAS_ACC(2) UNROLL_ADD_BIAS_ACC(3)
-#define ADJUST_ACC     \
-  UNROLL_ADJUST_ACC(0) \
-  UNROLL_ADJUST_ACC(1) UNROLL_ADJUST_ACC(2) UNROLL_ADJUST_ACC(3)
-#define STORE_ACC     \
-  UNROLL_STORE_ACC(0) \
-  UNROLL_STORE_ACC(1) UNROLL_STORE_ACC(2) UNROLL_STORE_ACC(3)
-
-#elif (ROW_UNROLL == 8)
-#define SETUP_ACC     \
-  UNROLL_SETUP_ACC(0) \
-  UNROLL_SETUP_ACC(1) \
-  UNROLL_SETUP_ACC(2) \
-  UNROLL_SETUP_ACC(3) \
-  UNROLL_SETUP_ACC(4) \
-  UNROLL_SETUP_ACC(5) UNROLL_SETUP_ACC(6) UNROLL_SETUP_ACC(7)
-#define SETUP_MAT1     \
-  UNROLL_SETUP_MAT1(0) \
-  UNROLL_SETUP_MAT1(1) \
-  UNROLL_SETUP_MAT1(2) \
-  UNROLL_SETUP_MAT1(3) \
-  UNROLL_SETUP_MAT1(4) \
-  UNROLL_SETUP_MAT1(5) UNROLL_SETUP_MAT1(6) UNROLL_SETUP_MAT1(7)
-#define SETUP_MAT2     \
-  UNROLL_SETUP_MAT2(0) \
-  UNROLL_SETUP_MAT2(1) \
-  UNROLL_SETUP_MAT2(2) \
-  UNROLL_SETUP_MAT2(3) \
-  UNROLL_SETUP_MAT2(4) \
-  UNROLL_SETUP_MAT2(5) UNROLL_SETUP_MAT2(6) UNROLL_SETUP_MAT2(7)
-#define KERNEL_MAT1_VEC1     \
-  UNROLL_KERNEL_MAT1_VEC1(0) \
-  UNROLL_KERNEL_MAT1_VEC1(1) \
-  UNROLL_KERNEL_MAT1_VEC1(2) \
-  UNROLL_KERNEL_MAT1_VEC1(3) \
-  UNROLL_KERNEL_MAT1_VEC1(4) \
-  UNROLL_KERNEL_MAT1_VEC1(5) \
-  UNROLL_KERNEL_MAT1_VEC1(6) UNROLL_KERNEL_MAT1_VEC1(7)
-#define KERNEL_MAT2_VEC2     \
-  UNROLL_KERNEL_MAT2_VEC2(0) \
-  UNROLL_KERNEL_MAT2_VEC2(1) \
-  UNROLL_KERNEL_MAT2_VEC2(2) \
-  UNROLL_KERNEL_MAT2_VEC2(3) \
-  UNROLL_KERNEL_MAT2_VEC2(4) \
-  UNROLL_KERNEL_MAT2_VEC2(5) \
-  UNROLL_KERNEL_MAT2_VEC2(6) UNROLL_KERNEL_MAT2_VEC2(7)
-#define ADD_BIAS_ACC     \
-  UNROLL_ADD_BIAS_ACC(0) \
-  UNROLL_ADD_BIAS_ACC(1) \
-  UNROLL_ADD_BIAS_ACC(2) \
-  UNROLL_ADD_BIAS_ACC(3) \
-  UNROLL_ADD_BIAS_ACC(4) \
-  UNROLL_ADD_BIAS_ACC(5) UNROLL_ADD_BIAS_ACC(6) UNROLL_ADD_BIAS_ACC(7)
-#define ADJUST_ACC     \
-  UNROLL_ADJUST_ACC(0) \
-  UNROLL_ADJUST_ACC(1) \
-  UNROLL_ADJUST_ACC(2) \
-  UNROLL_ADJUST_ACC(3) \
-  UNROLL_ADJUST_ACC(4) \
-  UNROLL_ADJUST_ACC(5) UNROLL_ADJUST_ACC(6) UNROLL_ADJUST_ACC(7)
-#define STORE_ACC     \
-  UNROLL_STORE_ACC(0) \
-  UNROLL_STORE_ACC(1) \
-  UNROLL_STORE_ACC(2) \
-  UNROLL_STORE_ACC(3) \
-  UNROLL_STORE_ACC(4) \
-  UNROLL_STORE_ACC(5) UNROLL_STORE_ACC(6) UNROLL_STORE_ACC(7)
-
-#endif /* (ROW_UNROLL == 1) */
-
-#if (ROW_UNROLL == 4 && VEC_UNROLL == 2)
-
-#define SETUP_VEC_BATCH UNROLL_SETUP_VEC_BATCH(0) UNROLL_SETUP_VEC_BATCH(1)
-
-#define SETUP_ACC_BATCH         \
-  UNROLL_ROW_SETUP_ACC_BATCH(0) \
-  UNROLL_ROW_SETUP_ACC_BATCH(1) \
-  UNROLL_ROW_SETUP_ACC_BATCH(2) UNROLL_ROW_SETUP_ACC_BATCH(3)
-#define SETUP_ACC_BATCH_VEC_UNROLL(idx_row) \
-  UNROLL_SETUP_ACC_BATCH(idx_row, 0) UNROLL_SETUP_ACC_BATCH(idx_row, 1)
-#define SETUP_ACC_BATCH_TAIL   \
-  UNROLL_SETUP_ACC_BATCH(0, 0) \
-  UNROLL_SETUP_ACC_BATCH(1, 0) \
-  UNROLL_SETUP_ACC_BATCH(2, 0) UNROLL_SETUP_ACC_BATCH(3, 0)
-
-#define LOAD_VEC_BATCH UNROLL_LOAD_VEC_BATCH(0) UNROLL_LOAD_VEC_BATCH(1)
-#define LOAD_MAT1         \
-  UNROLL_LOAD_ROW_MAT1(0) \
-  UNROLL_LOAD_ROW_MAT1(1) UNROLL_LOAD_ROW_MAT1(2) UNROLL_LOAD_ROW_MAT1(3)
-
-#define KERNEL_MAT1_VEC_BATCH         \
-  UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(0) \
-  UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(1) \
-  UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(2) UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(3)
-#define KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row) \
-  UNROLL_KERNEL_MAT1_VEC_BATCH(idx_row, 0)        \
-  UNROLL_KERNEL_MAT1_VEC_BATCH(idx_row, 1)
-#define KERNEL_MAT1_VEC_BATCH_TAIL   \
-  UNROLL_KERNEL_MAT1_VEC_BATCH(0, 0) \
-  UNROLL_KERNEL_MAT1_VEC_BATCH(1, 0) \
-  UNROLL_KERNEL_MAT1_VEC_BATCH(2, 0) UNROLL_KERNEL_MAT1_VEC_BATCH(3, 0)
-
-#define ADD_BIAS_ACC_BATCH   \
-  UNROLL_ROW_ADD_BIAS_ACC(0) \
-  UNROLL_ROW_ADD_BIAS_ACC(1) \
-  UNROLL_ROW_ADD_BIAS_ACC(2) UNROLL_ROW_ADD_BIAS_ACC(3)
-#define ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row) \
-  UNROLL_ADD_BIAS_ACC_BATCH(idx_row, 0) UNROLL_ADD_BIAS_ACC_BATCH(idx_row, 1)
-#define ADD_BIAS_ACC_BATCH_TAIL                     \
-  LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(0, 0)         \
-      LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(1, 0)     \
-          LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(2, 0) \
-              LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(3, 0)
-
-#define STORE_ACC_BATCH   \
-  UNROLL_ROW_STORE_ACC(0) \
-  UNROLL_ROW_STORE_ACC(1) UNROLL_ROW_STORE_ACC(2) UNROLL_ROW_STORE_ACC(3)
-#define STORE_ACC_BATCH_VEC_UNROLL(idx_row) \
-  UNROLL_STORE_ACC_BATCH(idx_row, 0) UNROLL_STORE_ACC_BATCH(idx_row, 1)
-#define STORE_ACC_BATCH_TAIL   \
-  UNROLL_STORE_ACC_BATCH(0, 0) \
-  UNROLL_STORE_ACC_BATCH(1, 0) \
-  UNROLL_STORE_ACC_BATCH(2, 0) UNROLL_STORE_ACC_BATCH(3, 0)
-
-#define ADJUST_ACC_BATCH_TAIL   \
-  UNROLL_ADJUST_ACC_BATCH(0, 0) \
-  UNROLL_ADJUST_ACC_BATCH(1, 0) \
-  UNROLL_ADJUST_ACC_BATCH(2, 0) UNROLL_ADJUST_ACC_BATCH(3, 0)
-#define ADJUST_ACC_BATCH   \
-  UNROLL_ROW_ADJUST_ACC(0) \
-  UNROLL_ROW_ADJUST_ACC(1) UNROLL_ROW_ADJUST_ACC(2) UNROLL_ROW_ADJUST_ACC(3)
-#define ADJUST_ACC_BATCH_VEC_UNROLL(idx_row) \
-  UNROLL_ADJUST_ACC_BATCH(idx_row, 0) UNROLL_ADJUST_ACC_BATCH(idx_row, 1)
-
-#endif /* (ROW_UNROLL == 4 && VEC_UNROLL == 2)*/
-
-#endif /* __XA_NNLIB_COMMON_MACROS_H__ */
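For orientation, the ASYM8b store path above (`ADJUST_ACC_ASYM8b` followed by the clamp in `STORE_ACC_ASYM8bxASYM8b_AT_OUT_ASYM8b`) is a fixed-point requantization: scale the accumulator by a Q31 `out_multiplier` with a left/right shift split, add `out_zero_bias`, and clamp to the uint8 range. A plain-C sketch of that arithmetic, illustrative only and not bit-exact with the `AE_ROUND32F64SSYM` rounding intrinsic:

```c
#include <stdint.h>

/* Illustrative requantization matching the shape of ADJUST_ACC_ASYM8b plus
 * the final clamp. Assumes left_shift is small enough that the product
 * fits in 64 bits; not bit-exact with the Xtensa rounding intrinsics. */
static uint8_t RequantizeAsym8(int32_t acc, int32_t out_multiplier,
                               int left_shift, int right_shift,
                               int32_t out_zero_bias) {
  /* Pre-shift, then Q31 multiply with round-to-nearest. */
  int64_t scaled = ((int64_t)acc << left_shift) * (int64_t)out_multiplier;
  scaled = (scaled + (1ll << 30)) >> 31;
  if (right_shift > 0) {
    /* Post-shift; the deleted macros round via AE_ROUND32F64SSYM, which
     * this nearest-rounding shift only approximates. */
    scaled = (scaled + (1ll << (right_shift - 1))) >> right_shift;
  }
  int32_t result = (int32_t)scaled + out_zero_bias;  /* Add zero point. */
  if (result < 0) result = 0;    /* Clamp to the ASYM8b range [0, 255]. */
  if (result > 255) result = 255;
  return (uint8_t)result;
}
```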
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_definitions.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_definitions.h
deleted file mode 100644
index 7199887f501..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_definitions.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_OPUS_CODEC_DEFINITIONS_H__
-#define __XA_OPUS_CODEC_DEFINITIONS_H__
-
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h"
-
-/* Identification Strings */
-#define LIBNAME "HiFi Mini Neural Network Library"
-#define LIBVERSION "0.6.0"
-
-#define LIB_APIVERSION_MAJOR 1
-#define LIB_APIVERSION_MINOR 0
-
-#if LIB_APIVERSION_MAJOR != XA_APIVERSION_MAJOR || \
-    LIB_APIVERSION_MINOR != XA_APIVERSION_MINOR
-// #error "Version Mismatch"
-#endif
-
-#define LIB_APIVERSION \
-  XA_MAKE_VERSION_STR(LIB_APIVERSION_MAJOR, LIB_APIVERSION_MINOR)
-
-#endif /* __XA_OPUS_CODEC_DEFINITIONS_H__ */
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h
deleted file mode 100644
index 8508e54e515..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_NNLIB_ERR_CHK_H__
-#define __XA_NNLIB_ERR_CHK_H__
-
-#ifndef NULL
-#define NULL (void *)0
-#endif /* NULL */
-
-#ifndef DISABLE_ARG_CHK
-
-#define XA_NNLIB_ARG_CHK_PTR(_ptr, _err) \
-  do {                                   \
-    if ((_ptr) == NULL) return (_err);   \
-  } while (0)
-
-#define XA_NNLIB_ARG_CHK_ALIGN(_ptr, _align, _err)                 \
-  do {                                                             \
-    if (((unsigned int)(_ptr) & ((_align)-1)) != 0) return (_err); \
-  } while (0)
-
-#define XA_NNLIB_ARG_CHK_COND(_cond, _err) \
-  do {                                     \
-    if ((_cond)) return (_err);            \
-  } while (0)
-
-#else /* DISABLE_ARG_CHK */
-
-#define XA_NNLIB_ARG_CHK_PTR(_ptr, _err)
-#define XA_NNLIB_ARG_CHK_ALIGN(_ptr, _align, _err)
-#define XA_NNLIB_ARG_CHK_COND(_cond, _err)
-
-#endif /* DISABLE_ARG_CHK */
-
-#define XA_NNLIB_CHK_PTR(_ptr, _err)   \
-  do {                                 \
-    if ((_ptr) == NULL) return (_err); \
-  } while (0)
-
-#define XA_NNLIB_CHK_ALIGN(_ptr, _align, _err)                     \
-  do {                                                             \
-    if (((unsigned int)(_ptr) & ((_align)-1)) != 0) return (_err); \
-  } while (0)
-
-#define XA_NNLIB_CHK_COND(_cond, _err) \
-  do {                                 \
-    if ((_cond)) return (_err);        \
-  } while (0)
-
-#endif /* __XA_NNLIB_ERR_CHK_H__ */
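The do { ... } while (0) wrapper above is the standard idiom for making a multi-statement macro behave as a single statement, so the checks nest safely under unbraced if/else. A minimal usage sketch (example_kernel is hypothetical; real call sites appear in the kernels deleted later in this patch):

/* Hypothetical kernel entry point exercising the checks above: each macro
 * returns the given error code from the enclosing function when its
 * condition fails, and compiles to nothing under DISABLE_ARG_CHK. */
int example_kernel(const short *p_inp, short *p_out, int length) {
  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);                   /* NULL checks      */
  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(short), -1);  /* 2-byte alignment */
  XA_NNLIB_ARG_CHK_COND(length <= 0, -1);            /* sanity check     */
  /* ... kernel body ... */
  return 0;
}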
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c
deleted file mode 100644
index 060b70696e0..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c
+++ /dev/null
@@ -1,176 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xa_nnlib_common.h"
-
-#define ALIGNMENT 8 /* 8-byte alignment */
-
-#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
-
-#define LIMIT(out, inp, min, max) \
-  {                               \
-    out = min;                    \
-    out = AE_MAXP24S(inp, min);   \
-    out = AE_MINP24S(out, max);   \
-  }
-
-#define STORE_8X2_FROM_24X2(out_ptr, val) \
-  {                                       \
-    int o1, o2;                           \
-    o1 = AE_MOVAP24S_H(val);              \
-    o2 = AE_MOVAP24S_L(val);              \
-    *out_ptr++ = (WORD8)o1;               \
-    *out_ptr++ = (WORD8)o2;               \
-  }
-
-/*
- * inp: p_vec: 4-byte aligned input pointer
- * out: p_out: no alignment needed for the output pointer */
-WORD32 xa_nn_vec_activation_min_max_asym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_vec,
-    int activation_min, int activation_max, WORD32 vec_length) {
-  int i;
-  ae_p24x2s x, y, min, max;
-
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_vec, -1);
-
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1);
-
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((activation_max < activation_min), -1);
-
-  WORD8 *p_o = p_out;
-  WORD8 *p_v = (WORD8 *)p_vec;
-
-  min = AE_SRAIP24(AE_CVTP24A16(activation_min), 8);
-  max = AE_SRAIP24(AE_CVTP24A16(activation_max), 8);
-
-  int pre_loop_count = 0;
-  // pre-loop, active when the input pointer is not 4-byte aligned
-  pre_loop_count = (int)((unsigned)ALIGN_PTR(p_v, 4) - (unsigned)p_v);
-  pre_loop_count = (pre_loop_count < vec_length) ? pre_loop_count : vec_length;
-
-  vec_length = vec_length - pre_loop_count;
-  vec_length = (vec_length < 0) ? 0 : vec_length;
-
-  for (i = 0; i < pre_loop_count; i++) {
-    int i1;
-    i1 = ((WORD8)*p_v++);
-    x = AE_MOVPA24(i1);
-    LIMIT(y, x, min, max)
-    i1 = AE_MOVAP24S_H(y);
-    *p_o++ = (WORD8)i1;
-  }
-
-  if ((activation_max >= (int)127) && (activation_min <= (int)-128)) {
-    p_v = p_v - 2;
-    for (i = 0; i < (vec_length >> 1); i++) {
-      AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8));
-      y = AE_SRAIP24(x, 16);
-
-      STORE_8X2_FROM_24X2(p_o, y)
-    }
-    if (vec_length & 1) {
-      p_v = p_v + 2;
-      int i1;
-      i1 = (WORD8)p_v[0];
-      *p_o++ = (WORD8)i1;
-    }
-  } else if ((activation_max < (int)127) && (activation_min <= (int)-128)) {
-    p_v = p_v - 2;
-    for (i = 0; i < (vec_length >> 1); i++) {
-      AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8));
-      y = AE_SRAIP24(x, 16);
-
-      y = AE_MINP24S(y, max);
-
-      STORE_8X2_FROM_24X2(p_o, y)
-    }
-    if (vec_length & 1) {
-      p_v = p_v + 2;
-      int i1;
-      i1 = (WORD8)p_v[0];
-      y = AE_MOVPA24(i1);
-
-      y = AE_MINP24S(y, max);
-
-      i1 = AE_MOVAP24S_H(y);
-      *p_o++ = (WORD8)i1;
-    }
-  } else if ((activation_max >= (int)127) && (activation_min > (int)-128)) {
-    p_v = p_v - 2;
-    for (i = 0; i < (vec_length >> 1); i++) {
-      AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8));
-      y = AE_SRAIP24(x, 16);
-
-      y = AE_MAXP24S(y, min);
-
-      STORE_8X2_FROM_24X2(p_o, y)
-    }
-    if (vec_length & 1) {
-      p_v = p_v + 2;
-      int i1;
-      i1 = (WORD8)p_v[0];
-      y = AE_MOVPA24(i1);
-
-      y = AE_MAXP24S(y, min);
-
-      i1 = AE_MOVAP24S_H(y);
-      *p_o++ = (WORD8)i1;
-    }
-  } else {
-    p_v = p_v - 2;
-    for (i = 0; i < (vec_length >> 1); i++) {
-      AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8));
-      x = AE_SRAIP24(x, 16);
-      LIMIT(y, x, min, max)
-      STORE_8X2_FROM_24X2(p_o, y)
-    }
-    if (vec_length & 1) {
-      p_v = p_v + 2;
-      int i1;
-      i1 = (WORD8)p_v[0];
-      x = AE_MOVPA24(i1);
-      LIMIT(y, x, min, max)
-      i1 = AE_MOVAP24S_H(y);
-      *p_o++ = (WORD8)i1;
-    }
-  }
-  return 0;
-}
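Stripped of the Xtensa P-register intrinsics and the four specializations (no-op range, max-only, min-only, both bounds), the kernel above is an elementwise clamp. A plain-C reference of the computation it performs (illustrative only, not part of the library):

#include <stdint.h>

/* Scalar reference for xa_nn_vec_activation_min_max_asym8s_asym8s:
 * clamp each int8 element to [activation_min, activation_max]. */
static void activation_min_max_ref(int8_t *p_out, const int8_t *p_vec,
                                   int activation_min, int activation_max,
                                   int vec_length) {
  for (int i = 0; i < vec_length; i++) {
    int v = p_vec[i];
    if (v < activation_min) v = activation_min;
    if (v > activation_max) v = activation_max;
    p_out[i] = (int8_t)v;
  }
}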
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c
deleted file mode 100644
index 4f7dce839d3..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c
+++ /dev/null
@@ -1,1005 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xa_nnlib_common.h"
-
-#define ALIGNMENT 8 /* 8-byte alignment */
-#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1)))
-#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
-
-#ifndef AE_LP8X2F_IU
-#define AE_LP8X2F_IU(p_x, p_in, x)                           \
-  AE_LP16F_IU(p_x, (ae_p16s *)p_in, x);                      \
-  ae_p24x2s p_tmp1 = AE_SLLIP24(p_x, 8);                     \
-  ae_p24x2s p_tmp2 = AE_ANDP48(p_x, AE_MOVPA24(0xFFFF0000)); \
-  p_x = AE_SELP24_LL(p_tmp2, p_tmp1);
-
-#endif
-
-#define NSA64_T(y, x)               \
-  {                                 \
-    ae_q56s q_tmp = *(ae_q56s *)&x; \
-    y = AE_NSAQ56S(q_tmp) + 8;      \
-  }
-
-#define MULFP32X2RAS_T(result, a, b)             \
-  {                                              \
-    ae_q56s q_a = AE_CVTQ48A32S(a);              \
-    ae_p24x2s p_b = AE_CVTP24A16X2_HL(b, b);     \
-    ae_q56s q_out = AE_MULFQ32SP16U_L(q_a, p_b); \
-    q_out = AE_SRAIQ56(q_out, 16);               \
-    AE_MULAFQ32SP16S_H(q_out, q_a, p_b);         \
-    q_out = AE_ROUNDSQ32ASYM(q_out);             \
-    *(ae_q32s *)&result = q_out;                 \
-  }
-
-#define MULFP32X2RS_T(result, a, b)              \
-  {                                              \
-    ae_q56s q_a = AE_CVTQ48A32S(a);              \
-    ae_p24x2s p_b = AE_CVTP24A16X2_HL(b, b);     \
-    ae_q56s q_out = AE_MULFQ32SP16U_L(q_a, p_b); \
-    q_out = AE_SRAIQ56(q_out, 16);               \
-    AE_MULAFQ32SP16S_H(q_out, q_a, p_b);         \
-    q_out = AE_ROUNDSQ32SYM(q_out);              \
-    *(ae_q32s *)&result = q_out;                 \
-  }
-#define ADD32S_T(result, a, b)             \
-  {                                        \
-    ae_q56s q_a = AE_CVTQ48A32S(a);        \
-    ae_q56s q_b = AE_CVTQ48A32S(b);        \
-    ae_q56s q_out = AE_ADDSQ56S(q_a, q_b); \
-    q_out = AE_SATQ48S(q_out);             \
-    *(ae_q32s *)&result = q_out;           \
-  }
-
-#define SUB32S_T(result, a, b)             \
-  {                                        \
-    ae_q56s q_a = AE_CVTQ48A32S(a);        \
-    ae_q56s q_b = AE_CVTQ48A32S(b);        \
-    ae_q56s q_out = AE_SUBSQ56S(q_a, q_b); \
-    q_out = AE_SATQ48S(q_out);             \
-    *(ae_q32s *)&result = q_out;           \
-  }
-
-#define SLAI32S_T(result, a, b)         \
-  {                                     \
-    ae_q56s q_a = AE_CVTQ48A32S(a);     \
-    ae_q56s q_out = AE_SLLIQ56(q_a, b); \
-    q_out = AE_SATQ48S(q_out);          \
-    *(ae_q32s *)&result = q_out;        \
-  }
-
-#define SRAA32RS_T(result, a, b)             \
-  {                                          \
-    ae_q56s q_a = AE_CVTQ48A32S(a);          \
-    ae_q56s q_out = AE_SLAASQ56S(q_a, (-b)); \
-    q_out = AE_ROUNDSQ32ASYM(q_out);         \
-    *(ae_q32s *)&result = q_out;             \
-  }
-
-#define SRAI32R_T(result, a, b)         \
-  {                                     \
-    ae_q56s q_a = AE_CVTQ48A32S(a);     \
-    ae_q56s q_out = AE_SRAIQ56(q_a, b); \
-    q_out = AE_ROUNDSQ32ASYM(q_out);    \
-    *(ae_q32s *)&result = q_out;        \
-  }
-
-static const int CONSTANT_TERM = (0x70f5a894);
-static const int CONSTANT_1_OVER_3 = (0x2aaaaaab);
-static const int CONSTANT_1_OVER_8 = (0x10000000);
-static const int ONE_QUATER_Q26 = (0x1000000);  // Q6.26
-static const int MASK = (0xffffff);
-static const int Q31 = 0x7fffffff;
-static const int constant_48_over_17 = 1515870810;
-static const int constant_neg_32_over_17 = -1010580540;  // Q29
-static const int F2_ONE = 0x20000000;
-
-static const int constant_neg_32_over_17_Q21 = -3947580;  // Q21
-static const int constant_48_over_17_Q21 = 5921370;       // Q21
-
-static ae_p24x2s GetReciprocal(ae_q56s q_x, int x_integerbits, int *lsh) {
-  int headroom_plus_one;
-  ae_p24x2s p_x;
-  ae_q56s q_tmp;
-  ae_p24x2s p_half_den;
-  int i;
-
-  headroom_plus_one = AE_NSAQ56S(q_x) + 8;
-  headroom_plus_one = headroom_plus_one - 31;
-  *lsh = x_integerbits - headroom_plus_one;
-
-  q_x = (q_x << (headroom_plus_one + 15));
-  p_half_den = AE_ROUNDSP24Q48SYM(q_x);
-
-  q_tmp = AE_CVTQ48A32S(constant_48_over_17);
-  AE_MULAFP24S_LL(q_tmp, p_half_den, AE_MOVPA24(constant_neg_32_over_17_Q21));
-  p_x = AE_ROUNDSP24Q48SYM(q_tmp);
-
-  for (i = 0; i < 3; i++) {
-    q_tmp = AE_CVTQ48A32S(F2_ONE);
-    AE_MULSFP24S_LL(q_tmp, p_x, p_half_den);
-    ae_p24x2s p_one_minus_half_denominator_times_x = AE_ROUNDSP24Q48SYM(q_tmp);
-
-    q_tmp = AE_MULFP24S_LL(p_x, p_one_minus_half_denominator_times_x);
-    ae_p24x2s p_m = AE_ROUNDSP24Q48SYM(q_tmp);
-    p_m = AE_SLLISP24S(p_m, 2);
-    p_x = AE_ADDSP24S(p_x, p_m);
-  }
-
-  p_x = AE_SLLISP24S(p_x, 1);
-
-  return p_x;
-}
-
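GetReciprocal is a fixed-point Newton-Raphson reciprocal in the gemmlowp style: the headroom count normalizes the denominator (the residual scale is returned through *lsh), the 48/17 and -32/17 constants form the standard affine seed, and the loop runs three refinement steps x <- x + x*(1 - d*x); the shifts by 2 and 1 account for the Q formats involved. A floating-point model of the same iteration, for illustration only:

#include <assert.h>

/* Float model of the iteration inside GetReciprocal: for a denominator
 * normalized into [0.5, 1), three Newton-Raphson steps converge to roughly
 * full single precision. */
static float reciprocal_ref(float d) {
  assert(d >= 0.5f && d < 1.0f);                 /* caller pre-normalizes */
  float x = 48.0f / 17.0f - (32.0f / 17.0f) * d; /* affine initial guess  */
  for (int i = 0; i < 3; i++) {
    x = x + x * (1.0f - d * x);                  /* Newton-Raphson step   */
  }
  return x;                                      /* approximately 1 / d   */
}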
-static const int MASK_16BITS = (0xffff);
-static const int ONE_QUATER_Q18 = (0x10000);          // Q18
-static const int CONSTANT_1_OVER_8_Q23 = (0x100000);  // Q23
-static const int CONSTANT_1_OVER_3_Q23 = (0x2aaaaa);  // Q23
-static const int CONSTANT_TERM_Q23 = (0x70f5a8);      // Q23
-static const int Q23 = 0x7fffff;
-
-#define GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_in_out, exponent,                \
-                                           FixedPointMultiplier, p_remainder) \
-  {                                                                           \
-    ae_p24x2s p_out;                                                          \
-                                                                              \
-    ae_p24x2s p_zero = AE_ZEROP48();                                          \
-                                                                              \
-    ae_p24x2s p_scale = AE_MOVPA24(1 << (18 + exponent));                     \
-    ae_p24x2s p_mask = p_remainder & p_scale;                                 \
-                                                                              \
-    ae_p24x2s p_FixedPointMultiplier = AE_MOVPA24(FixedPointMultiplier >> 8); \
-                                                                              \
-    ae_q56s q_tmp1 = AE_MULFP24S_HH(p_in_out, p_FixedPointMultiplier);        \
-    ae_q56s q_tmp2 = AE_MULFP24S_LL(p_in_out, p_FixedPointMultiplier);        \
-    ae_p24x2s p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1);                              \
-    ae_p24x2s p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2);                              \
-    p_out = AE_SELP24_LL(p_t1, p_t2);                                         \
-                                                                              \
-    xtbool2 flag_le = AE_LTP24S(p_zero, p_mask);                              \
-    AE_MOVTP24X2(p_in_out, p_out, flag_le);                                   \
-  }
-
-#define EXP_Q26_II(p_exp_y, p_inp_t)                                        \
-  {                                                                         \
-    ae_p24x2s p_x1_in, p_x2, p_x3, p_x4, p_x4_by_4, p_y1, p_y2, p_y3, p_y4, \
-        p_y5, p_y6, p_y;                                                    \
-                                                                            \
-    p_x2 = p_inp_t & AE_MOVPA24(MASK_16BITS);                               \
-    ae_p24x2s p_a_mod_quater_minus_q_1_by_4 =                               \
-        p_x2 - AE_MOVPA24(ONE_QUATER_Q18);                                  \
-    ae_p24x2s p_x_in = p_a_mod_quater_minus_q_1_by_4 << 5;                  \
-    ae_p24x2s p_remainder = p_a_mod_quater_minus_q_1_by_4 - p_inp_t;        \
-                                                                            \
-    p_x1_in = AE_ADDSP24S(p_x_in, AE_MOVPA24(CONSTANT_1_OVER_8_Q23));       \
-                                                                            \
-    ae_q56s q_tmp1 = AE_MULFP24S_HH(p_x1_in, p_x1_in);                      \
-    ae_q56s q_tmp2 = AE_MULFP24S_LL(p_x1_in, p_x1_in);                      \
-    ae_p24x2s p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1);                            \
-    ae_p24x2s p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2);                            \
-    p_x2 = AE_SELP24_LL(p_t1, p_t2);                                        \
-                                                                            \
-    q_tmp1 = AE_MULFP24S_HH(p_t1, p_x1_in);                                 \
-    q_tmp2 = AE_MULFP24S_LL(p_t2, p_x1_in);                                 \
-    p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-    p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2);                                      \
-    p_x3 = AE_SELP24_LL(p_t1, p_t2);                                        \
-                                                                            \
-    q_tmp1 = AE_MULFP24S_HH(p_x2, p_x2);                                    \
-    q_tmp2 = AE_MULFP24S_LL(p_x2, p_x2);                                    \
-    p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-    p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2);                                      \
-    p_x4 = AE_SELP24_LL(p_t1, p_t2);                                        \
-    p_x4_by_4 = p_x4 >> 2;                                                  \
-                                                                            \
-    p_y1 = AE_ADDSP24S(p_x4_by_4, p_x3);                                    \
-                                                                            \
-    ae_p24x2s p_const = AE_MOVPA24(CONSTANT_1_OVER_3_Q23);                  \
-    q_tmp1 = AE_MULFP24S_HH(p_y1, p_const);                                 \
-    q_tmp2 = AE_MULFP24S_LL(p_y1, p_const);                                 \
-    p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-    p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2);                                      \
-    p_y2 = AE_SELP24_LL(p_t1, p_t2);                                        \
-                                                                            \
-    p_y3 = AE_ADDSP24S(p_y2, p_x2);                                         \
-    p_y4 = p_y3 >> 1;                                                       \
-                                                                            \
-    p_y5 = AE_ADDSP24S(p_x1_in, p_y4); /* ADD32S_T(y5, x1_in, y4);  */      \
-                                                                            \
-    p_const = AE_MOVPA24(CONSTANT_TERM_Q23);                                \
-    q_tmp1 = AE_MULFP24S_HH(p_y5, p_const);                                 \
-    q_tmp2 = AE_MULFP24S_LL(p_y5, p_const);                                 \
-    p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-    p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2);                                      \
-    p_y6 = AE_SELP24_LL(p_t1, p_t2);                                        \
-    p_y = AE_ADDSP24S(p_y6, p_const);                                       \
-                                                                            \
-    {                                                                       \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, -2, 1672461947, p_remainder); \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, -1, 1302514674, p_remainder); \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 0, 790015084, p_remainder);   \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 1, 290630308, p_remainder);   \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 2, 39332535, p_remainder);    \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 3, 720401, p_remainder);      \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 4, 242, p_remainder);         \
-    }                                                                       \
-    p_exp_y = p_y;                                                          \
-    p_const = AE_MOVPA24(Q23);                                              \
-    xtbool2 flag_eq = AE_EQP24(p_inp_t, AE_ZEROP48());                      \
-    AE_MOVTP24X2(p_exp_y, p_const, flag_eq);                                \
-  }
-
-#define GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_in_out, exponent,                 \
-                                          FixedPointMultiplier, p_remainder)  \
-  {                                                                           \
-    ae_p24x2s p_out;                                                          \
-                                                                              \
-    ae_p24x2s p_zero = AE_ZEROP48();                                          \
-                                                                              \
-    ae_p24x2s p_scale = AE_MOVPA24(1 << (18 + exponent));                     \
-    ae_p24x2s p_mask = p_remainder & p_scale;                                 \
-                                                                              \
-    ae_p24x2s p_FixedPointMultiplier = AE_MOVPA24(FixedPointMultiplier >> 8); \
-                                                                              \
-    ae_q56s q_tmp1 = AE_MULFP24S_HH(p_in_out, p_FixedPointMultiplier);        \
-    p_out = AE_ROUNDSP24Q48SYM(q_tmp1);                                       \
-                                                                              \
-    xtbool2 flag_le = AE_LTP24S(p_zero, p_mask);                              \
-    AE_MOVTP24X2(p_in_out, p_out, flag_le);                                   \
-  }
-
-#define EXP_Q26_I(p_exp_y, p_inp_t)                                         \
-  {                                                                         \
-    ae_p24x2s p_x1_in, p_x2, p_x3, p_x4, p_x4_by_4, p_y1, p_y2, p_y3, p_y4, \
-        p_y5, p_y6, p_y;                                                    \
-                                                                            \
-    p_x2 = p_inp_t & AE_MOVPA24(MASK_16BITS);                               \
-    ae_p24x2s p_a_mod_quater_minus_q_1_by_4 =                               \
-        p_x2 - AE_MOVPA24(ONE_QUATER_Q18);                                  \
-    ae_p24x2s p_x_in = p_a_mod_quater_minus_q_1_by_4 << 5;                  \
-    ae_p24x2s p_remainder = p_a_mod_quater_minus_q_1_by_4 - p_inp_t;        \
-                                                                            \
-    p_x1_in = AE_ADDSP24S(p_x_in, AE_MOVPA24(CONSTANT_1_OVER_8_Q23));       \
-                                                                            \
-    ae_q56s q_tmp1 = AE_MULFP24S_HH(p_x1_in, p_x1_in);                      \
-    p_x2 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-                                                                            \
-    q_tmp1 = AE_MULFP24S_HH(p_x2, p_x1_in);                                 \
-    p_x3 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-                                                                            \
-    q_tmp1 = AE_MULFP24S_HH(p_x2, p_x2);                                    \
-    p_x4 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-    p_x4_by_4 = p_x4 >> 2;                                                  \
-                                                                            \
-    p_y1 = AE_ADDSP24S(p_x4_by_4, p_x3);                                    \
-                                                                            \
-    ae_p24x2s p_const = AE_MOVPA24(CONSTANT_1_OVER_3_Q23);                  \
-    q_tmp1 = AE_MULFP24S_HH(p_y1, p_const);                                 \
-    p_y2 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-                                                                            \
-    p_y3 = AE_ADDSP24S(p_y2, p_x2);                                         \
-    p_y4 = p_y3 >> 1;                                                       \
-                                                                            \
-    p_y5 = AE_ADDSP24S(p_x1_in, p_y4); /* ADD32S_T(y5, x1_in, y4);  */      \
-                                                                            \
-    p_const = AE_MOVPA24(CONSTANT_TERM_Q23);                                \
-    q_tmp1 = AE_MULFP24S_HH(p_y5, p_const);                                 \
-    p_y6 = AE_ROUNDSP24Q48SYM(q_tmp1);                                      \
-    p_y = AE_ADDSP24S(p_y6, p_const);                                       \
-                                                                            \
-    {                                                                       \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, -2, 1672461947, p_remainder);  \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, -1, 1302514674, p_remainder);  \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 0, 790015084, p_remainder);    \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 1, 290630308, p_remainder);    \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 2, 39332535, p_remainder);     \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 3, 720401, p_remainder);       \
-      GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 4, 242, p_remainder);          \
-    }                                                                       \
-    p_exp_y = p_y;                                                          \
-    p_const = AE_MOVPA24(Q23);                                              \
-    xtbool2 flag_eq = AE_EQP24(p_inp_t, AE_ZEROP48());                      \
-    AE_MOVTP24X2(p_exp_y, p_const, flag_eq);                                \
-  }
-
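EXP_Q26_I and EXP_Q26_II implement gemmlowp's exp-on-negative-values scheme: a degree-4 Taylor polynomial (the 1/8, 1/3, and CONSTANT_TERM_Q23 ~ e^(-1/8) constants) covers the fractional part in (-1/4, 0], and the seven barrel-shifter steps multiply in Q31 constants for exp(-1/4), exp(-1/2), ..., exp(-16), one per bit of the integer remainder. A rough float model of the structure (illustrative; the real macros work in saturating 24-bit Q23/Q26 arithmetic):

#include <math.h>

/* Float model of the EXP_Q26 macros: split a non-positive argument into a
 * fragment in (-1/4, 0] handled by a Taylor polynomial around -1/8, plus a
 * multiple of 1/4 applied as a product of precomputed factors. */
static float exp_neg_ref(float a) {                /* requires a <= 0      */
  float frag = fmodf(a, 0.25f);                    /* in (-0.25, 0]        */
  int quarters = (int)((frag - a) * 4.0f + 0.5f);  /* a = frag - q / 4     */
  float x = frag + 0.125f;                         /* recenter around -1/8 */
  float poly = 1.0f + x + x * x / 2.0f + x * x * x / 6.0f +
               x * x * x * x / 24.0f;              /* Taylor for exp(x)    */
  float y = expf(-0.125f) * poly;                  /* ~= exp(frag)         */
  for (int k = 0; quarters != 0; k++, quarters >>= 1) {
    if (quarters & 1) y *= expf(-0.25f * (float)(1 << k));  /* barrel step */
  }
  return y;                                        /* ~= exp(a)            */
}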
-WORD32 xa_nn_vec_softmax_asym8u_8(UWORD8 *__restrict__ pOut,
-                                  const UWORD8 *__restrict__ pVec,
-                                  WORD32 diffmin, WORD32 input_beta_left_shift,
-                                  WORD32 input_beta_multiplier,
-                                  WORD32 vec_length, pVOID pScratch) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(pOut, -1);
-  XA_NNLIB_ARG_CHK_PTR(pVec, -1);
-  XA_NNLIB_ARG_CHK_PTR(pScratch, -1);
-  /* Pointer alignment checks */
-  /* No alignment (1-byte) needed for any pointer */
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND(
-      ((input_beta_left_shift < -31) || (input_beta_left_shift > 31)), -1);
-  XA_NNLIB_ARG_CHK_COND((input_beta_multiplier < 0), -1);
-
-  int i;
-  int shift_bits_reciprocal;
-  UWORD8 *p_in;
-  WORD32 *__restrict pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT);
-  ae_p24f *__restrict pTmpScratch = (ae_p24f *)pExp;
-  int max;
-  ae_p24x2s p_x;
-  ae_p24x2s p_max = AE_MOVPA24(0xFF800000);
-  ae_p24x2s p_recip_sum_exp;
-  int pre_loop_count;
-  int main_loop_count;
-  int post_loop_count;
-
-  if (vec_length > 1) {
-    pre_loop_count = (int)pVec & 0x1;
-    main_loop_count = vec_length - pre_loop_count;
-    post_loop_count = (main_loop_count & 1);
-    main_loop_count = main_loop_count >> 1;
-  } else {
-    pre_loop_count = 0;
-    main_loop_count = 0;
-    post_loop_count = vec_length;
-  }
-
-  /* Calculating Max */
-  {
-    p_in = (UWORD8 *)pVec;
-
-    if (pre_loop_count) {
-      p_x = AE_MOVPA24(*p_in++);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-
-    p_in -= 2;
-    for (i = 0; i < main_loop_count; i++) {
-      AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8));
-      p_x = AE_SRLIP24(p_x, 16);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-
-    if (post_loop_count) {
-      p_in += 2;
-      p_x = AE_MOVPA24(*p_in);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-    p_max = AE_MAXP24S(p_max, AE_SELP24_LH(p_max, p_max));
-    max = AE_MOVAP24S_L(p_max);
-  }
-
-  /* Calculate exponents */
-  {
-    ae_q56s q_sum_exp = AE_ZEROQ56();
-    ae_p24x2s p_rem_x, p_y, p_exp_y;
-    ae_p24x2s p_zero = AE_ZEROP48();
-    ae_p24x2s p_input_beta_multiplier =
-        AE_MOVPA24((input_beta_multiplier >> 8));
-    ae_p24x2s p_diffmin = AE_MOVPA24(diffmin);
-    int input_beta_left_shift_for_24bit = input_beta_left_shift - 8;
-
-    p_in = (UWORD8 *)pVec;
-    WUR_AE_SAR(input_beta_left_shift_for_24bit);
-
-    if (pre_loop_count) {
-      p_x = AE_MOVPA24(*p_in++);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-
-      EXP_Q26_I(p_exp_y, p_dequantized_y1)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *pTmpScratch++ = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    p_in -= 2;
-    for (i = 0; i < main_loop_count; i++) {
-      AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8));
-      p_x = AE_SRLIP24(p_x, 16);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_HH(p_y, p_input_beta_multiplier);
-      ae_q56s q_dequantized_y2 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-      ae_p24x2s p_dequantized_y2 = AE_ROUNDSP24Q48ASYM(q_dequantized_y2);
-
-      ae_p24x2s p_dequantized =
-          AE_SELP24_LL(p_dequantized_y1, p_dequantized_y2);
-
-      EXP_Q26_II(p_exp_y, p_dequantized)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *pTmpScratch++ = AE_SELP24_HH(p_exp_y, p_exp_y);
-      *pTmpScratch++ = p_exp_y; /* store lower element */
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAAP24S_HH_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-    if (post_loop_count) {
-      p_in += 2;
-
-      p_x = AE_MOVPA24(*p_in);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-
-      EXP_Q26_I(p_exp_y, p_dequantized_y1)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *pTmpScratch = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-    p_recip_sum_exp = GetReciprocal(q_sum_exp, 12, &shift_bits_reciprocal);
-  }
-
-  /* Calculate output */
-  {
-    ae_p24x2s p_exp;
-
-    int shift_val = -(shift_bits_reciprocal + 31 - 8 - 8);
-
-    ae_p24x2s p_min = AE_ZEROP48();
-    ae_p24x2s p_max = AE_MOVPA24(255);
-
-    for (i = 0; i < (vec_length >> 1); i++) {
-      int out;
-
-      p_exp = *(ae_p24x2f *)&pExp[2 * i];
-
-      ae_q56s q_tmp1 = AE_MULFP24S_HH(p_exp, p_recip_sum_exp);
-      ae_q56s q_tmp2 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp);
-
-      q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val);
-      q_tmp2 = AE_SLAASQ56S(q_tmp2, shift_val);
-
-      ae_p24x2s p_out1 = AE_ROUNDSP24Q48ASYM(q_tmp1);
-      ae_p24x2s p_out2 = AE_ROUNDSP24Q48ASYM(q_tmp2);
-
-      ae_p24x2s p_out = AE_SELP24_LL(p_out1, p_out2);
-
-      p_out = AE_MAXP24S(p_out, p_min);
-      p_out = AE_MINP24S(p_out, p_max);
-
-      out = AE_MOVAP24S_H(p_out);
-      *pOut++ = (UWORD8)out;
-
-      out = AE_MOVAP24S_L(p_out);
-      *pOut++ = (UWORD8)out;
-    }
-
-    if (vec_length & 0x1) {
-      int out;
-
-      p_exp = *(ae_p24f *)&pExp[vec_length - 1];
-
-      ae_q56s q_tmp1 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp);
-
-      q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val);
-
-      ae_p24x2s p_out = AE_ROUNDSP24Q48ASYM(q_tmp1);
-
-      p_out = AE_MAXP24S(p_out, p_min);
-      p_out = AE_MINP24S(p_out, p_max);
-
-      out = AE_MOVAP24S_L(p_out);
-      *pOut++ = (UWORD8)out;
-    }
-  }
-
-  return 0;
-}
-
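Underneath the fixed-point machinery, all three softmax variants in this file follow the same numerically stable recipe: find the max, shift inputs so they are non-positive, exponentiate with EXP_Q26, accumulate the sum, invert it once with GetReciprocal, then rescale and clamp each stored exponent to the output range. A float reference of the asym8u pipeline, ignoring the beta scaling (input_beta_multiplier / input_beta_left_shift), the diffmin cutoff, and fixed-point rounding (illustrative only):

#include <math.h>
#include <stdint.h>

/* Float reference for xa_nn_vec_softmax_asym8u_8; 'scratch' plays the role
 * of pScratch (one float per element, cf. xa_nn_get_softmax_scratch_size). */
static void softmax_asym8u_ref(uint8_t *p_out, const uint8_t *p_vec,
                               int vec_length, float *scratch) {
  float max = p_vec[0];
  for (int i = 1; i < vec_length; i++)
    if (p_vec[i] > max) max = (float)p_vec[i];

  float sum = 0.0f;
  for (int i = 0; i < vec_length; i++) {
    scratch[i] = expf((float)p_vec[i] - max);  /* arguments are <= 0   */
    sum += scratch[i];
  }

  float recip = 1.0f / sum;                    /* GetReciprocal analog */
  for (int i = 0; i < vec_length; i++) {
    long q = lrintf(scratch[i] * recip * 255.0f);
    p_out[i] = (uint8_t)(q < 0 ? 0 : q > 255 ? 255 : q);
  }
}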
-WORD32 xa_nn_vec_softmax_asym8s_8(WORD8 *__restrict__ pOut,
-                                  const WORD8 *__restrict__ pVec,
-                                  WORD32 diffmin, WORD32 input_beta_left_shift,
-                                  WORD32 input_beta_multiplier,
-                                  WORD32 vec_length, pVOID pScratch) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(pOut, -1);
-  XA_NNLIB_ARG_CHK_PTR(pVec, -1);
-  XA_NNLIB_ARG_CHK_PTR(pScratch, -1);
-  /* Pointer alignment checks */
-  /* No alignment (1-byte) needed for any pointer */
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND(
-      ((input_beta_left_shift < -31) || (input_beta_left_shift > 31)), -1);
-  XA_NNLIB_ARG_CHK_COND((input_beta_multiplier < 0), -1);
-
-  int i;
-  int shift_bits_reciprocal;
-  WORD8 *p_in;
-  WORD32 *__restrict pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT);
-  ae_p24x2s p_recip_sum_exp;
-  ae_p24x2s p_x;
-  ae_p24x2s p_max = AE_MOVPA24(0xFF800000);
-
-  int pre_loop_count;
-  int main_loop_count;
-  int post_loop_count;
-
-  if (vec_length > 1) {
-    pre_loop_count = (int)pVec & 0x1;
-    main_loop_count = vec_length - pre_loop_count;
-    post_loop_count = (main_loop_count & 1);
-    main_loop_count = main_loop_count >> 1;
-  } else {
-    pre_loop_count = 0;
-    main_loop_count = 0;
-    post_loop_count = vec_length;
-  }
-
-  /* Calculating Max */
-  {
-    p_in = (WORD8 *)pVec;
-
-    if (pre_loop_count) {
-      p_x = AE_MOVPA24(*p_in++);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-
-    p_in -= 2;
-    for (i = 0; i < main_loop_count; i++) {
-      AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8));
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-    p_max = AE_SRAIP24(p_max, 16);
-
-    if (post_loop_count) {
-      p_in += 2;
-      p_x = AE_MOVPA24(*p_in);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-    p_max = AE_MAXP24S(p_max, AE_SELP24_LH(p_max, p_max));
-  }
-
-  /* Calculate exponents */
-  {
-    ae_q56s q_sum_exp = AE_ZEROQ56();
-    ae_p24x2s p_rem_x, p_y, p_exp_y;
-    ae_p24x2s p_zero = AE_ZEROP48();
-    ae_p24x2s p_input_beta_multiplier =
-        AE_MOVPA24((input_beta_multiplier >> 8));
-    ae_p24x2s p_diffmin = AE_MOVPA24(diffmin);
-    int input_beta_left_shift_for_24bit = input_beta_left_shift - 8;
-
-    p_in = (WORD8 *)pVec;
-    WUR_AE_SAR(input_beta_left_shift_for_24bit);
-
-    if (pre_loop_count) {
-      p_x = AE_MOVPA24(*p_in++);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-
-      EXP_Q26_I(p_exp_y, p_dequantized_y1)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *(ae_p24f *)&pExp[0] = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    p_in -= 2;
-    for (i = 0; i < main_loop_count; i++) {
-      AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8));
-      p_x = AE_SRAIP24(p_x, 16);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_HH(p_y, p_input_beta_multiplier);
-      ae_q56s q_dequantized_y2 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-      ae_p24x2s p_dequantized_y2 = AE_ROUNDSP24Q48ASYM(q_dequantized_y2);
-
-      ae_p24x2s p_dequantized =
-          AE_SELP24_LL(p_dequantized_y1, p_dequantized_y2);
-
-      EXP_Q26_II(p_exp_y, p_dequantized)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      //*(ae_p24x2f *)&pExp[pre_loop_count + 2*i] = p_exp_y;
-      *(ae_p24f *)&pExp[pre_loop_count + 2 * i] =
-          AE_SELP24_HH(p_exp_y, p_exp_y);
-      *(ae_p24f *)&pExp[pre_loop_count + 2 * i + 1] =
-          AE_SELP24_LL(p_exp_y, p_exp_y);
-      //*(ae_p24f *)&pExp[0] = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAAP24S_HH_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    if (post_loop_count) {
-      p_in += 2;
-
-      p_x = AE_MOVPA24(*p_in);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-
-      EXP_Q26_I(p_exp_y, p_dequantized_y1)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *(ae_p24f *)&pExp[vec_length - 1] = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    p_recip_sum_exp = GetReciprocal(q_sum_exp, 12, &shift_bits_reciprocal);
-  }
-
-  /* Calculate output */
-  pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT);
-  {
-    ae_p24x2s p_exp;
-
-    int shift_val = -(shift_bits_reciprocal + 31 - 8 - 8);
-
-    ae_p24x2s p_min = AE_MOVPA24(-128);
-    ae_p24x2s p_max = AE_MOVPA24(127);
-
-    for (i = 0; i < (vec_length >> 1); i++) {
-      int out;
-
-      p_exp = *(ae_p24x2f *)&pExp[2 * i];
-
-      ae_q56s q_tmp1 = AE_MULFP24S_HH(p_exp, p_recip_sum_exp);
-      ae_q56s q_tmp2 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp);
-
-      q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val);
-      q_tmp2 = AE_SLAASQ56S(q_tmp2, shift_val);
-
-      ae_p24x2s p_out1 = AE_ROUNDSP24Q48ASYM(q_tmp1);
-      ae_p24x2s p_out2 = AE_ROUNDSP24Q48ASYM(q_tmp2);
-
-      ae_p24x2s p_out = AE_SELP24_LL(p_out1, p_out2);
-
-      p_out = AE_SUBSP24S(p_out, AE_MOVPA24(128));
-      p_out = AE_MAXP24S(p_out, p_min);
-      p_out = AE_MINP24S(p_out, p_max);
-
-      out = AE_MOVAP24S_H(p_out);
-      *pOut++ = (WORD8)out;
-
-      out = AE_MOVAP24S_L(p_out);
-      *pOut++ = (WORD8)out;
-    }
-
-    if (vec_length & 0x1) {
-      int out;
-
-      p_exp = *(ae_p24f *)&pExp[vec_length - 1];
-
-      ae_q56s q_tmp1 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp);
-
-      q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val);
-
-      ae_p24x2s p_out = AE_ROUNDSP24Q48ASYM(q_tmp1);
-
-      p_out = AE_SUBSP24S(p_out, AE_MOVPA24(128));
-      p_out = AE_MAXP24S(p_out, p_min);
-      p_out = AE_MINP24S(p_out, p_max);
-
-      out = AE_MOVAP24S_L(p_out);
-      *pOut++ = (WORD8)out;
-    }
-  }
-
-  return 0;
-}
-
-WORD32 xa_nn_vec_softmax_asym8s_16(WORD16 *__restrict__ pOut,
-                                   const WORD8 *__restrict__ pVec,
-                                   WORD32 diffmin, WORD32 input_beta_left_shift,
-                                   WORD32 input_beta_multiplier,
-                                   WORD32 vec_length, pVOID pScratch) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(pOut, -1);
-  XA_NNLIB_ARG_CHK_PTR(pVec, -1);
-  XA_NNLIB_ARG_CHK_PTR(pScratch, -1);
-  /* Pointer alignment checks */
-  /* No alignment (1-byte) needed for any pointer */
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND(
-      ((input_beta_left_shift < -31) || (input_beta_left_shift > 31)), -1);
-  XA_NNLIB_ARG_CHK_COND((input_beta_multiplier < 0), -1);
-
-  int i;
-  int shift_bits_reciprocal;
-  WORD8 *p_in;
-  WORD32 *__restrict pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT);
-  ae_p24x2s p_recip_sum_exp;
-  ae_p24x2s p_x;
-  ae_p24x2s p_max = AE_MOVPA24(0xFF800000);
-
-  int pre_loop_count;
-  int main_loop_count;
-  int post_loop_count;
-
-  if (vec_length > 1) {
-    pre_loop_count = (int)pVec & 0x1;
-    main_loop_count = vec_length - pre_loop_count;
-    post_loop_count = (main_loop_count & 1);
-    main_loop_count = main_loop_count >> 1;
-  } else {
-    pre_loop_count = 0;
-    main_loop_count = 0;
-    post_loop_count = vec_length;
-  }
-
-  /* Calculating Max */
-  {
-    p_in = (WORD8 *)pVec;
-
-    if (pre_loop_count) {
-      p_x = AE_MOVPA24(*p_in++);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-
-    p_in -= 2;
-    for (i = 0; i < main_loop_count; i++) {
-      AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8));
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-    p_max = AE_SRAIP24(p_max, 16);
-
-    if (post_loop_count) {
-      p_in += 2;
-      p_x = AE_MOVPA24(*p_in);
-      p_max = AE_MAXP24S(p_max, p_x);
-    }
-    p_max = AE_MAXP24S(p_max, AE_SELP24_LH(p_max, p_max));
-  }
-
-  /* Calculate exponents */
-  {
-    ae_q56s q_sum_exp = AE_ZEROQ56();
-    ae_p24x2s p_rem_x, p_y, p_exp_y;
-    ae_p24x2s p_zero = AE_ZEROP48();
-    ae_p24x2s p_input_beta_multiplier =
-        AE_MOVPA24((input_beta_multiplier >> 8));
-    ae_p24x2s p_diffmin = AE_MOVPA24(diffmin);
-    int input_beta_left_shift_for_24bit = input_beta_left_shift - 8;
-
-    p_in = (WORD8 *)pVec;
-    WUR_AE_SAR(input_beta_left_shift_for_24bit);
-
-    if (pre_loop_count) {
-      p_x = AE_MOVPA24(*p_in++);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-
-      EXP_Q26_I(p_exp_y, p_dequantized_y1)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *(ae_p24f *)&pExp[0] = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    p_in -= 2;
-    for (i = 0; i < main_loop_count; i++) {
-      AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8));
-      p_x = AE_SRAIP24(p_x, 16);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_HH(p_y, p_input_beta_multiplier);
-      ae_q56s q_dequantized_y2 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-      ae_p24x2s p_dequantized_y2 = AE_ROUNDSP24Q48ASYM(q_dequantized_y2);
-
-      ae_p24x2s p_dequantized =
-          AE_SELP24_LL(p_dequantized_y1, p_dequantized_y2);
-
-      EXP_Q26_II(p_exp_y, p_dequantized)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *(ae_p24f *)&pExp[pre_loop_count + 2 * i] =
-          AE_SELP24_HH(p_exp_y, p_exp_y);
-      *(ae_p24f *)&pExp[pre_loop_count + 2 * i + 1] =
-          AE_SELP24_LL(p_exp_y, p_exp_y);
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAAP24S_HH_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    if (post_loop_count) {
-      p_in += 2;
-
-      p_x = AE_MOVPA24(*p_in);
-      p_rem_x = p_x - p_max;
-      p_y = AE_SLLSSP24S(p_rem_x);
-
-      ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier);
-
-      ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1);
-
-      EXP_Q26_I(p_exp_y, p_dequantized_y1)
-
-      xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin);
-      AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp);
-
-      *(ae_p24f *)&pExp[vec_length - 1] = p_exp_y;
-
-      p_exp_y = p_exp_y >> 4;
-
-      AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1));
-    }
-
-    p_recip_sum_exp = GetReciprocal(q_sum_exp, 12, &shift_bits_reciprocal);
-  }
-
-  /* Calculate output */
-  pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT);
-  {
-    ae_p24x2s p_exp;
-
-    int shift_val = -(shift_bits_reciprocal + 31 - 8 - 16);
-
-    ae_p24x2s p_min = AE_MOVPA24(-32768);
-    ae_p24x2s p_max = AE_MOVPA24(32767);
-
-    for (i = 0; i < (vec_length >> 1); i++) {
-      int out;
-
-      p_exp = *(ae_p24x2f *)&pExp[2 * i];
-
-      ae_q56s q_tmp1 = AE_MULFP24S_HH(p_exp, p_recip_sum_exp);
-      ae_q56s q_tmp2 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp);
-
-      q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val);
-      q_tmp2 = AE_SLAASQ56S(q_tmp2, shift_val);
-
-      ae_p24x2s p_out1 = AE_ROUNDSP24Q48ASYM(q_tmp1);
-      ae_p24x2s p_out2 = AE_ROUNDSP24Q48ASYM(q_tmp2);
-
-      ae_p24x2s p_out = AE_SELP24_LL(p_out1, p_out2);
-
-      p_out = AE_SUBSP24S(p_out, AE_MOVPA24(32768));
-      p_out = AE_MAXP24S(p_out, p_min);
-      p_out = AE_MINP24S(p_out, p_max);
-
-      out = AE_MOVAP24S_H(p_out);
-      *pOut++ = (WORD16)out;
-
-      out = AE_MOVAP24S_L(p_out);
-      *pOut++ = (WORD16)out;
-    }
-
-    if (vec_length & 0x1) {
-      int out;
-
-      p_exp = *(ae_p24f *)&pExp[vec_length - 1];
-
-      ae_q56s q_tmp1 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp);
-
-      q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val);
-
-      ae_p24x2s p_out = AE_ROUNDSP24Q48ASYM(q_tmp1);
-
-      p_out = AE_SUBSP24S(p_out, AE_MOVPA24(32768));
-      p_out = AE_MAXP24S(p_out, p_min);
-      p_out = AE_MINP24S(p_out, p_max);
-
-      out = AE_MOVAP24S_L(p_out);
-      *pOut++ = (WORD16)out;
-    }
-  }
-
-  return 0;
-}
-
-int xa_nn_get_softmax_scratch_size(int inp_precision, int out_precision,
-                                   int length) {
-  int size_of_one_elm_in_bytes, total_bytes;
-  (void)out_precision;
-
-  /* This function returns the scratch size in bytes required by the softmax
-     implementation. Scratch memory is needed to save the exponents of the
-     inputs; each exponent is currently computed as a 32-bit (4-byte) number. */
-  switch (inp_precision) {
-    case PREC_ASYM8U:
-      size_of_one_elm_in_bytes = 4;
-      break;
-    case PREC_SYM8S:
-      size_of_one_elm_in_bytes = 4;
-      break;
-    default:
-      size_of_one_elm_in_bytes = 4;
-      break;
-  }
-
-  total_bytes = size_of_one_elm_in_bytes * length;
-  total_bytes = ALIGNED_SIZE(total_bytes, ALIGNMENT);
-
-  return total_bytes;
-}
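Every branch of the switch above yields 4 bytes per element, so the scratch requirement reduces to 4 * length rounded up to the 8-byte ALIGNMENT: length = 10 gives 40 bytes (already aligned), while length = 9 gives 36, rounded up to 40. The rounding is the usual power-of-two trick, sketched here for reference:

/* The rounding performed by ALIGNED_SIZE(x, 8): add align-1, then clear
 * the low bits.  Valid for any power-of-two alignment. */
static int aligned_size(int bytes, int align) {
  return (bytes + align - 1) & ~(align - 1);   /* e.g. 36 -> 40, 40 -> 40 */
}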
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c
deleted file mode 100644
index 80697ca7068..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c
+++ /dev/null
@@ -1,175 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xa_nnlib_common.h"
-#include "xa_nnlib_common_macros.h"
-
-/*----------------------------Main function---------------------------------*/
-WORD32 xa_nn_dot_prod_16x16_asym8s(
-    WORD8 *__restrict__ p_out,               /* pointer to output */
-    const WORD16 *__restrict__ p_inp1_start, /* pointer to input1 */
-    const WORD16 *__restrict__ p_inp2_start, /* pointer to input2 */
-    const WORD32 *bias_ptr, WORD32 vec_length, WORD32 out_multiplier,
-    WORD32 out_shift, WORD32 out_zero_bias, WORD32 vec_count) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_inp1_start, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_inp2_start, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_inp1_start, sizeof(WORD16), -1);
-  XA_NNLIB_ARG_CHK_ALIGN(p_inp2_start, sizeof(WORD16), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-  XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1);
-  int left_shift, right_shift;
-  int loopcnt;
-  const WORD32 bias_buffer[2] = {0, 0};
-  const WORD32 *p_bias_load;
-  WORD32 bias_address_increment = sizeof(WORD32);
-
-  if (bias_ptr == NULL) {
-    p_bias_load = bias_buffer - 1;
-    bias_address_increment = 0;
-  } else {
-    p_bias_load = bias_ptr - 1;
-  }
-
-  left_shift = out_shift < 0 ? 0 : out_shift;
-  right_shift = out_shift > 0 ? 0 : -out_shift;
-  /* inp1 4-byte aligned, inp2 4-byte aligned and vec_length is a multiple
-   * of 2 */
-  if (((((unsigned)p_inp1_start) & 0x3) == 0) &&
-      ((((unsigned)p_inp2_start) & 0x3) == 0) && ((vec_length & 0x1) == 0)) {
-    const ae_p16x2s *pt_inp1, *pt_inp2;
-    pt_inp1 = (const ae_p16x2s *)&p_inp1_start[-2];
-    pt_inp2 = (const ae_p16x2s *)&p_inp2_start[-2];
-
-    ae_q56s output_int8_max_56 = AE_CVTQ48A32S(127);
-    ae_q56s output_int8_min_56 = AE_CVTQ48A32S(-128);
-    for (loopcnt = 0; loopcnt < vec_count; loopcnt++) {
-      ae_p24x2s dp_inp1, dp_inp2;
-      ae_q32s dq_out32;
-      ae_q56s dq_out;
-      int i;
-
-      AE_LQ32F_XU(dq_out, (ae_q32s *)p_bias_load, bias_address_increment);
-
-      for (i = 0; i < (vec_length >> 1); i++) {
-        AE_LP16X2F_IU(dp_inp1, pt_inp1, 4);
-        AE_LP16X2F_IU(dp_inp2, pt_inp2, 4);
-        AE_MULAAP24S_HH_LL(dq_out, dp_inp1, dp_inp2);
-      }
-
-      dq_out32 = AE_SATQ48S(dq_out);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      dq_out = AE_ADDSQ56S(dq_out, AE_CVTQ48A32S(out_zero_bias));
-
-      dq_out = AE_MAXQ56S(dq_out, output_int8_min_56);
-      dq_out = AE_MINQ56S(dq_out, output_int8_max_56);
-      *p_out++ = (WORD8)AE_TRUNCA32Q48(dq_out);
-    }
-  } else {
-#ifndef DISABLE_NNLIB_UNALIGNED_SUPPORT
-    for (loopcnt = 0; loopcnt < vec_count; loopcnt++) {
-      ae_p24x2s dp_inp1, dp_inp2;
-      ae_q32s dq_out32;
-      ae_q56s dq_out;
-      int i;
-      const WORD16 *p_inp1 = (WORD16 *)&p_inp1_start[loopcnt * vec_length];
-      const WORD16 *p_inp2 = (WORD16 *)&p_inp2_start[loopcnt * vec_length];
-
-      AE_LQ32F_XU(dq_out, (ae_q32s *)p_bias_load, bias_address_increment);
-
-      if (((((unsigned)p_inp1) & 3) != 0 && (((unsigned)p_inp2) & 3) != 0) ||
-          ((((unsigned)p_inp1) & 3) == 0 && (((unsigned)p_inp2) & 3) == 0)) {
-        int pre_loop_count = ((int)(((unsigned)p_inp1) & 3)) >> 1;
-        if (pre_loop_count != 0) {
-          dp_inp1 = AE_CVTP24A16X2_LL(*p_inp1++, *p_inp2++);
-          AE_MULAP24S_HL(dq_out, dp_inp1, dp_inp1);
-        }
-        const ae_p16x2s *pt_inp1, *pt_inp2;
-        pt_inp1 = (const ae_p16x2s *)(p_inp1 - 2);
-        pt_inp2 = (const ae_p16x2s *)(p_inp2 - 2);
-        for (i = 0; i < (vec_length - pre_loop_count - 1); i += 2) {
-          AE_LP16X2F_IU(dp_inp1, pt_inp1, 4);
-          AE_LP16X2F_IU(dp_inp2, pt_inp2, 4);
-          AE_MULAAP24S_HH_LL(dq_out, dp_inp1, dp_inp2);
-        }
-        if ((vec_length - pre_loop_count) & 1) {
-          dp_inp1 = AE_CVTP24A16X2_LL(p_inp1[i], p_inp2[i]);
-          AE_MULAP24S_HL(dq_out, dp_inp1, dp_inp1);
-        }
-      } else {
-        /* One of the pointers is not aligned to 4 bytes; if it is p_inp1, swap
-         * them */
-        if ((((unsigned)p_inp1) & 3) != 0) {
-          const WORD16 *p_tmp;
-          p_tmp = p_inp1;
-          p_inp1 = p_inp2;
-          p_inp2 = p_tmp;
-        }
-        const ae_p16x2s *pt_inp1 = (const ae_p16x2s *)(p_inp1 - 2);
-        const ae_p16s *pt_inp2 = (const ae_p16s *)(p_inp2 - 1);
-        for (i = 0; i < (vec_length - 1); i += 2) {
-          ae_p24x2s dp_t0, dp_t1;
-          AE_LP16X2F_IU(dp_inp1, pt_inp1, 4);
-          AE_LP16F_IU(dp_t0, pt_inp2, 2);
-          AE_LP16F_IU(dp_t1, pt_inp2, 2);
-          dp_inp2 = AE_SELP24_LL(dp_t0, dp_t1);
-          AE_MULAAP24S_HH_LL(dq_out, dp_inp1, dp_inp2);
-        }
-        if (vec_length & 1) {
-          dp_inp1 = AE_CVTP24A16X2_LL(p_inp1[i], p_inp2[i]);
-          AE_MULAP24S_HL(dq_out, dp_inp1, dp_inp1);
-        }
-      }
-      dq_out32 = AE_SATQ48S(dq_out);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      dq_out = AE_ADDSQ56S(dq_out, AE_CVTQ48A32S(out_zero_bias));
-      WORD32 out_i32 = AE_TRUNCA32Q48(AE_SATQ48S(dq_out));
-      out_i32 = out_i32 < -128 ? -128 : out_i32;
-      out_i32 = out_i32 > 127 ? 127 : out_i32;
-      *p_out++ = (WORD8)out_i32;
-    }
-#else
-    return 1;
-#endif
-  }
-  return 0;
-}
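For readers without the Xtensa intrinsics, the requantization step shared by the kernels in the file above (split `out_shift` into `left_shift`/`right_shift`, apply the quantized multiplier, add the output zero point, saturate to int8) can be sketched in portable C. This is a model, not the library's code: `MULTIPLY_BY_QUANTIZED_MULTIPLIER` is defined in `xa_nnlib_common_macros.h`, and the rounding below assumes the usual TFLite round-to-nearest scheme, which may differ from the intrinsics in edge cases.

```c
#include <stdint.h>

/* Hypothetical plain-C model of the requantize-and-store path above. */
static inline int8_t requantize_to_int8(int32_t acc, int32_t out_multiplier,
                                        int32_t out_shift,
                                        int32_t out_zero_bias) {
  /* The kernels split the shift like this and pass both halves to
   * MULTIPLY_BY_QUANTIZED_MULTIPLIER. */
  int left_shift = out_shift < 0 ? 0 : out_shift;
  int right_shift = out_shift > 0 ? 0 : -out_shift;

  /* Model: acc * out_multiplier * 2^(left_shift - right_shift - 31),
   * rounded to nearest. Folding the two shifts into one keeps the
   * 64-bit product in range. */
  int64_t prod = (int64_t)acc * (int64_t)out_multiplier;
  int rs = 31 + right_shift - left_shift; /* 0..62 for |out_shift| <= 31 */
  int32_t high =
      rs > 0 ? (int32_t)((prod + (1ll << (rs - 1))) >> rs) : (int32_t)prod;

  int32_t out = high + out_zero_bias; /* add output zero point */
  if (out < -128) out = -128;         /* saturate to int8, as in */
  if (out > 127) out = 127;           /* ADD_OUT_OFFSET_STORE_INT8 */
  return (int8_t)out;
}
```

The left/right split exists because the hardware macro wants a pre-multiply left shift and a post-multiply rounding right shift as separate operands; only one of the two is ever nonzero.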
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c
deleted file mode 100644
index 0a9325e81bf..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c
+++ /dev/null
@@ -1,142 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xa_nnlib_err_chk.h"
-#include "xa_nnlib_kernels_api.h"
-#include "xa_type_def.h"
-
-WORD32 xa_nn_fully_connected_asym8uxasym8u_asym8u(
-    UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_weight,
-    const UWORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias,
-    WORD32 weight_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_weight, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_bias, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((input_zero_bias < -255 || input_zero_bias > 0), -1);
-  XA_NNLIB_ARG_CHK_COND((weight_zero_bias < -255 || weight_zero_bias > 0), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-  XA_NNLIB_ARG_CHK_COND((out_zero_bias < 0 || out_zero_bias > 255), -1);
-
-  WORD32 ret = 0;
-  ret = xa_nn_matXvec_out_stride_asym8uxasym8u_asym8u(
-      p_out, p_weight, p_inp, p_bias, out_depth /* rows */
-      ,
-      weight_depth /* cols */
-      ,
-      weight_depth /* row_stride */
-      ,
-      1 /* out_stride */
-      ,
-      weight_zero_bias, input_zero_bias, out_multiplier, out_shift,
-      out_zero_bias);
-  return ret;
-}
-
-WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_weight,
-    const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias,
-    WORD32 out_multiplier, WORD32 out_shift, WORD32 out_zero_bias) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_weight, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_bias, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((input_zero_bias < -127 || input_zero_bias > 128), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-  XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1);
-
-  WORD32 ret = 0;
-  ret = xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s(
-      p_out, p_weight, p_inp, p_bias, out_depth /* rows */
-      ,
-      weight_depth /* cols */
-      ,
-      weight_depth /* row_stride */
-      ,
-      1 /* out_stride */
-      ,
-      input_zero_bias, out_multiplier, out_shift, out_zero_bias);
-  return ret;
-}
-
-WORD32 xa_nn_fully_connected_asym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_weight,
-    const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 weight_depth, WORD32 out_depth, WORD32 weight_zero_bias,
-    WORD32 input_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_weight, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_bias, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((weight_zero_bias < -127 || weight_zero_bias > 128),
-                        -1);
-  XA_NNLIB_ARG_CHK_COND((input_zero_bias < -127 || input_zero_bias > 128), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-  XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1);
-
-  WORD32 ret = 0;
-  ret = xa_nn_matXvec_out_stride_asym8sxasym8s_asym8s(
-      p_out, p_weight, p_inp, p_bias, out_depth /* rows */
-      ,
-      weight_depth /* cols */
-      ,
-      weight_depth /* row_stride */
-      ,
-      1 /* out_stride */
-      ,
-      weight_zero_bias, input_zero_bias, out_multiplier, out_shift,
-      out_zero_bias);
-  return ret;
-}
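The three wrappers in the deleted file above contain no arithmetic of their own: each forwards to the matching `xa_nn_matXvec_out_stride_*` kernel with `rows = out_depth`, `cols1 = row_stride1 = weight_depth`, and `out_stride = 1`. A minimal plain-C sketch of the sym8sxasym8s case under those settings; `requantize_to_int8` is the hypothetical helper sketched earlier, not a library function.

```c
#include <stddef.h>
#include <stdint.h>

int8_t requantize_to_int8(int32_t acc, int32_t out_multiplier,
                          int32_t out_shift, int32_t out_zero_bias);

/* Sketch: fully connected == dense matrix x vector. Each output row r is
 * dot(weight_row_r, input + input_zero_bias) + bias[r], requantized. */
void fully_connected_ref(int8_t *out, const int8_t *weight, const int8_t *inp,
                         const int32_t *bias, int32_t weight_depth,
                         int32_t out_depth, int32_t input_zero_bias,
                         int32_t out_multiplier, int32_t out_shift,
                         int32_t out_zero_bias) {
  for (int32_t r = 0; r < out_depth; r++) {
    int64_t acc = (bias != NULL) ? bias[r] : 0;
    const int8_t *w = weight + (size_t)r * weight_depth; /* row_stride */
    for (int32_t c = 0; c < weight_depth; c++) {
      acc += (int64_t)w[c] * (int64_t)(inp[c] + input_zero_bias);
    }
    out[r] = requantize_to_int8((int32_t)acc, out_multiplier, out_shift,
                                out_zero_bias); /* out_stride == 1 */
  }
}
```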
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c
deleted file mode 100644
index 71af822e68b..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c
+++ /dev/null
@@ -1,1053 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xa_nnlib_common.h"
-#include "xa_nnlib_common_macros.h"
-
-#define ADD_OUT_OFFSET_STORE_INT8(ptr, data, out_offset) \
-  {                                                      \
-    data = AE_ADDSQ56S(data, AE_CVTQ48A32S(out_offset)); \
-    int out_i32 = AE_TRUNCA32Q48(AE_SATQ48S(data));      \
-    out_i32 = out_i32 < -128 ? -128 : out_i32;           \
-    out_i32 = out_i32 > 127 ? 127 : out_i32;             \
-    *(ptr) = (WORD8)out_i32;                             \
-  }
-
-WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1,
-    const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_mat1, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_vec1, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((rows <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((row_stride1 < cols1), -1);
-  XA_NNLIB_ARG_CHK_COND((vec1_zero_bias < -127 || vec1_zero_bias > 128), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-  XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1);
-
-  /* Iterators used in for loops */
-  int m_itr, c_itr, i;
-  /* Initialize here so the value carries into the trailing loop */
-  m_itr = 0;
-  /* Shifts chosen to match TensorFlow's requantization */
-  int left_shift, right_shift;
-
-  left_shift = out_shift < 0 ? 0 : out_shift;
-  right_shift = out_shift > 0 ? 0 : -out_shift;
-
-  const WORD8 *p_mat1_0, *p_mat1_1, *p_mat1_2, *p_mat1_3;
-  const WORD8 *p_vec1_0;
-  ae_p24x2s dp_mat1_0, dp_mat1_1, dp_mat1_2, dp_mat1_3, dp_vec1_0;
-  ae_p24x2s dp_vec1_zb;
-  ae_q56s dq_acc[4];
-  ae_q56s dq_out32, dq_out;
-
-  dp_vec1_zb = AE_MOVPA24(vec1_zero_bias);
-  if (((((unsigned)p_mat1) & 1) == 0) && ((((unsigned)p_vec1) & 1) == 0) &&
-      ((row_stride1 & 1) == 0)) {
-    for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-      p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2];
-      p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1 - 2];
-      p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2];
-      p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1 - 2];
-      p_vec1_0 = p_vec1 - 2;
-
-      dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-      /* AE_LP8X2F* instructions load into the upper 8 bits of each P-register
-      lane, so the vector is shifted right by 16 to place the multiplication
-      result in the middle 32 bits of the Q register (lower 16 bits are 0) */
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2);
-        AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-        AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-      }
-      /* Pointers are aligned, so we can do 8X2 loads and ignore the L parts
-       * of the registers */
-      if (cols1 & 1) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2);
-        AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-        AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc[1], dp_mat1_1, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc[2], dp_mat1_2, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc[3], dp_mat1_3, dp_vec1_0);
-      }
-
-      if (p_bias != NULL) {
-        for (i = 0; i < 4; i++)
-          dq_acc[i] = AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-      }
-
-      for (i = 0; i < 4; i++) {
-        dq_out32 = AE_SATQ48S(dq_acc[i]);
-        MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                         out_multiplier, left_shift,
-                                         right_shift);
-        ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out,
-                                  out_zero_bias);
-      }
-    }
-    for (; m_itr < rows; m_itr++) {
-      p_mat1_0 = &p_mat1[m_itr * row_stride1 - 2];
-      p_vec1_0 = p_vec1 - 2;
-
-      dq_acc[0] = AE_ZEROQ56();
-
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-      /* Pointers are aligned, so we can do 8X2 loads and ignore the L parts
-       * of the registers */
-      if (cols1 & 1) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-
-      if (p_bias != NULL)
-        dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr]));
-
-      dq_out32 = AE_SATQ48S(dq_acc[0]);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out,
-                                out_zero_bias);
-    }
-  } else {
-    if ((((unsigned)p_mat1) & 1) == 0) {
-      for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-        p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2];
-        p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1];
-        p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2];
-        p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1];
-        p_vec1_0 = p_vec1;
-
-        dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-        /* Matrix elements are kept in the upper 8 bits of P registers, vector
-        elements in the lower 8 bits; the casts to UWORD8 avoid extra extui
-        instructions, since a signed 8-bit load is not available on HiFiMini */
-        for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-          AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-          dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr],
-                                        (UWORD8)p_mat1_1[c_itr + 1]);
-          AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-          dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr],
-                                        (UWORD8)p_mat1_3[c_itr + 1]);
-          dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                        (UWORD8)p_vec1_0[c_itr + 1]);
-          dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8);
-          dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8);
-          dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-          dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-        }
-        if (cols1 & 1) {
-          ae_p24x2s dp_mat1_01, dp_mat1_23;
-          dp_mat1_01 =
-              AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[2], (UWORD8)p_mat1_1[c_itr]);
-          dp_mat1_23 =
-              AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[2], (UWORD8)p_mat1_3[c_itr]);
-          dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]);
-          dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8);
-          dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0);
-        }
-
-        if (p_bias != NULL) {
-          for (i = 0; i < 4; i++)
-            dq_acc[i] =
-                AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-        }
-
-        for (i = 0; i < 4; i++) {
-          dq_out32 = AE_SATQ48S(dq_acc[i]);
-          MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                           out_multiplier, left_shift,
-                                           right_shift);
-          ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out,
-                                    out_zero_bias);
-        }
-      }
-    } else {
-      for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-        p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1];
-        p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1];
-        p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1];
-        p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1];
-        p_vec1_0 = p_vec1;
-
-        dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-        /* Matrix elements are kept in the upper 8 bits of P registers, vector
-        elements in the lower 8 bits; the casts to UWORD8 avoid extra extui
-        instructions, since a signed 8-bit load is not available on HiFiMini */
-        for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-          dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                        (UWORD8)p_mat1_0[c_itr + 1]);
-          dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr],
-                                        (UWORD8)p_mat1_1[c_itr + 1]);
-          dp_mat1_2 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr],
-                                        (UWORD8)p_mat1_2[c_itr + 1]);
-          dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr],
-                                        (UWORD8)p_mat1_3[c_itr + 1]);
-          dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                        (UWORD8)p_vec1_0[c_itr + 1]);
-          dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8);
-          dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8);
-          dp_mat1_2 = AE_SLLIP24(dp_mat1_2, 8);
-          dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8);
-          dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-          dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-        }
-        if (cols1 & 1) {
-          ae_p24x2s dp_mat1_01, dp_mat1_23;
-          dp_mat1_01 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                         (UWORD8)p_mat1_1[c_itr]);
-          dp_mat1_23 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr],
-                                         (UWORD8)p_mat1_3[c_itr]);
-          dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]);
-          dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8);
-          dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0);
-        }
-
-        if (p_bias != NULL) {
-          for (i = 0; i < 4; i++)
-            dq_acc[i] =
-                AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-        }
-
-        for (i = 0; i < 4; i++) {
-          dq_out32 = AE_SATQ48S(dq_acc[i]);
-          MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                           out_multiplier, left_shift,
-                                           right_shift);
-          ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out,
-                                    out_zero_bias);
-        }
-      }
-    }
-    for (; m_itr < rows; m_itr++) {
-      p_mat1_0 = &p_mat1[m_itr * row_stride1];
-      p_vec1_0 = p_vec1;
-
-      dq_acc[0] = AE_ZEROQ56();
-
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                      (UWORD8)p_mat1_0[c_itr + 1]);
-        dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                      (UWORD8)p_vec1_0[c_itr + 1]);
-        dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8);
-        dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-      if (cols1 & 1) {
-        dp_mat1_0 = AE_CVTP24A16(p_mat1_0[c_itr]);
-        dp_vec1_0 = AE_CVTP24A16(p_vec1_0[c_itr]);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, AE_CVTP24A16(vec1_zero_bias));
-        AE_MULAP24S_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-
-      if (p_bias != NULL)
-        dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr]));
-
-      dq_out32 = AE_SATQ48S(dq_acc[0]);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out,
-                                out_zero_bias);
-    }
-  }
-
-  return 0;
-}
-
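One piece of bookkeeping in the kernel above is easy to miss: AE_LP8X2F places each 8-bit element in the top 8 bits of a 24-bit P-register lane, i.e. scaled by 2^16. Only the vector operand is shifted back down (AE_SRAIP24 by 16), so each product carries exactly one 2^16 factor and the accumulated sum lands in bits 16..47 of the Q register, which is what AE_TRUNCA32Q48 extracts. A toy check of that arithmetic, with illustrative values:

```c
#include <assert.h>
#include <stdint.h>

int main(void) {
  int64_t m = -7, v = 19;           /* example int8 values */
  int64_t m_lane = m << 16;         /* matrix element as loaded by AE_LP8X2F */
  int64_t v_lane = (v << 16) >> 16; /* vector element after AE_SRAIP24 by 16 */
  int64_t q = m_lane * v_lane;      /* product sits in bits 16..47 */
  assert((q >> 16) == m * v);       /* what AE_TRUNCA32Q48 extracts */
  return 0;
}
```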
-WORD32 xa_nn_matXvec_out_stride_asym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1,
-    const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 mat1_zero_bias, WORD32 vec1_zero_bias, WORD32 out_multiplier,
-    WORD32 out_shift, WORD32 out_zero_bias) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_mat1, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_vec1, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((rows <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((row_stride1 < cols1), -1);
-  XA_NNLIB_ARG_CHK_COND((mat1_zero_bias < -127 || mat1_zero_bias > 128), -1);
-  XA_NNLIB_ARG_CHK_COND((vec1_zero_bias < -127 || vec1_zero_bias > 128), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-  XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1);
-
-  /* Iterators used in for loops */
-  int m_itr, c_itr, i;
-  /* Initialize here so the value carries into the trailing loop */
-  m_itr = 0;
-  /* Shifts chosen to match TensorFlow's requantization */
-  int left_shift, right_shift;
-
-  left_shift = out_shift < 0 ? 0 : out_shift;
-  right_shift = out_shift > 0 ? 0 : -out_shift;
-
-  const WORD8 *p_mat1_0, *p_mat1_1, *p_mat1_2, *p_mat1_3;
-  const WORD8 *p_vec1_0;
-  ae_p24x2s dp_mat1_0, dp_mat1_1, dp_mat1_2, dp_mat1_3, dp_vec1_0;
-  ae_p24x2s dp_vec1_zb, dp_mat1_zb;
-  ae_q56s dq_acc_0, dq_acc_1, dq_acc_2, dq_acc_3;
-  ae_q56s dq_out32, dq_out;
-
-  const WORD32 bias_buffer[1] = {0};
-  const WORD32 *p_bias_load;
-  WORD32 bias_address_increment = sizeof(WORD32);
-
-  dp_mat1_zb = AE_MOVPA24(mat1_zero_bias);
-  dp_vec1_zb = AE_MOVPA24(vec1_zero_bias);
-
-  /* Check for alignment conditions */
-  if (((((unsigned)p_mat1) & 1) == 0) && ((((unsigned)p_vec1) & 1) == 0) &&
-      ((row_stride1 & 1) == 0) && ((cols1 & 1) == 0)) {
-    /* Calculate partial zero offset adjustment outside the loop */
-    WORD32 zero_offset_adjustment;
-
-    /* Constant part of the total zero-point correction */
-    ae_q56s dq_zero_bias_sum =
-        AE_CVTQ48A32S(vec1_zero_bias * cols1 * mat1_zero_bias);
-
-    WORD8 *p_inp = (WORD8 *)p_vec1 - 2;
-    for (i = 0; i < (cols1 >> 1); i++) {
-      /* Input vector is in MSB 8 bits, matrix zero bias in LSB 8 bits */
-      AE_LP8X2F_IU(dp_vec1_0, p_inp, 2);
-      AE_MULAAP24S_HH_LL(dq_zero_bias_sum, dp_vec1_0, dp_mat1_zb);
-    }
-    /* Product is already aligned to bits 16 to 47 in QR register. */
-    zero_offset_adjustment = AE_TRUNCA32Q48(dq_zero_bias_sum);
-
-    /* If bias is not provided, use a dummy zero value from bias_buffer. */
-    if (p_bias == NULL) {
-      p_bias_load = bias_buffer - 1;
-      bias_address_increment = 0;
-    } else {
-      p_bias_load = p_bias - 1;
-    }
-
-    for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-      p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2];
-      p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1 - 2];
-      p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2];
-      p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1 - 2];
-      p_vec1_0 = p_vec1 - 2;
-
-      AE_LQ32F_XU(dq_acc_0, (ae_q32s *)p_bias_load, bias_address_increment);
-      AE_LQ32F_XU(dq_acc_1, (ae_q32s *)p_bias_load, bias_address_increment);
-      AE_LQ32F_XU(dq_acc_2, (ae_q32s *)p_bias_load, bias_address_increment);
-      AE_LQ32F_XU(dq_acc_3, (ae_q32s *)p_bias_load, bias_address_increment);
-
-      dq_acc_0 = AE_ADDQ56(dq_acc_0, AE_CVTQ48A32S(zero_offset_adjustment));
-      dq_acc_1 = AE_ADDQ56(dq_acc_1, AE_CVTQ48A32S(zero_offset_adjustment));
-      dq_acc_2 = AE_ADDQ56(dq_acc_2, AE_CVTQ48A32S(zero_offset_adjustment));
-      dq_acc_3 = AE_ADDQ56(dq_acc_3, AE_CVTQ48A32S(zero_offset_adjustment));
-
-      /* AE_LP8X2F* instructions load into the upper 8 bits of each P-register
-      lane, so the vector is shifted right by 16 to place the multiplication
-      result in the middle 32 bits of the Q register (lower 16 bits are 0) */
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2);
-        AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-        AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-        AE_MULAAP24S_HH_LL(dq_acc_0, dp_mat1_0, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc_1, dp_mat1_1, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc_2, dp_mat1_2, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc_3, dp_mat1_3, dp_vec1_0);
-      }
-
-      /* Pointers are aligned, so we can do 8X2 loads and ignore the L parts
-       * of the registers */
-      if (cols1 & 1) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2);
-        AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-        AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-        AE_MULAP24S_HH(dq_acc_0, dp_mat1_0, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc_1, dp_mat1_1, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc_2, dp_mat1_2, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc_3, dp_mat1_3, dp_vec1_0);
-      }
-
-      /* Requantize each row accumulator and store it at its own row offset */
-      dq_out32 = AE_SATQ48S(dq_acc_0);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + 0) * out_stride], dq_out,
-                                out_zero_bias);
-
-      dq_out32 = AE_SATQ48S(dq_acc_1);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + 1) * out_stride], dq_out,
-                                out_zero_bias);
-
-      dq_out32 = AE_SATQ48S(dq_acc_2);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + 2) * out_stride], dq_out,
-                                out_zero_bias);
-
-      dq_out32 = AE_SATQ48S(dq_acc_3);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + 3) * out_stride], dq_out,
-                                out_zero_bias);
-    }
-    for (; m_itr < rows; m_itr++) {
-      p_mat1_0 = &p_mat1[m_itr * row_stride1 - 2];
-      p_vec1_0 = p_vec1 - 2;
-
-      AE_LQ32F_XU(dq_acc_0, (ae_q32s *)p_bias_load, bias_address_increment);
-      dq_acc_0 = AE_ADDQ56(dq_acc_0, AE_CVTQ48A32S(zero_offset_adjustment));
-
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-        AE_MULAAP24S_HH_LL(dq_acc_0, dp_mat1_0, dp_vec1_0);
-      }
-
-      dq_out32 = AE_SATQ48S(dq_acc_0);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out,
-                                out_zero_bias);
-    }
-  } else {
-#ifndef DISABLE_NNLIB_UNALIGNED_SUPPORT
-    ae_q56s dq_acc[4];
-
-    if ((((unsigned)p_mat1) & 1) == 0) {
-      for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-        p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2];
-        p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1];
-        p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2];
-        p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1];
-        p_vec1_0 = p_vec1;
-
-        dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-        /* Matrix elements are kept in the upper 8 bits of P registers, vector
-        elements in the lower 8 bits; the casts to UWORD8 avoid extra extui
-        instructions, since a signed 8-bit load is not available on HiFiMini */
-        for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-          AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-          dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr],
-                                        (UWORD8)p_mat1_1[c_itr + 1]);
-          AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-          dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr],
-                                        (UWORD8)p_mat1_3[c_itr + 1]);
-          dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                        (UWORD8)p_vec1_0[c_itr + 1]);
-          dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8);
-          dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8);
-          dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-          dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-          dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16);
-          dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb);
-          dp_mat1_1 = AE_SRAIP24(dp_mat1_1, 16);
-          dp_mat1_1 = AE_ADDSP24S(dp_mat1_1, dp_mat1_zb);
-          dp_mat1_2 = AE_SRAIP24(dp_mat1_2, 16);
-          dp_mat1_2 = AE_ADDSP24S(dp_mat1_2, dp_mat1_zb);
-          dp_mat1_3 = AE_SRAIP24(dp_mat1_3, 16);
-          dp_mat1_3 = AE_ADDSP24S(dp_mat1_3, dp_mat1_zb);
-
-          AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-        }
-        if (cols1 & 1) {
-          ae_p24x2s dp_mat1_01, dp_mat1_23;
-          dp_mat1_01 =
-              AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[2], (UWORD8)p_mat1_1[c_itr]);
-          dp_mat1_23 =
-              AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[2], (UWORD8)p_mat1_3[c_itr]);
-          dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]);
-          dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8);
-          dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-          dp_mat1_01 = AE_SRAIP24(dp_mat1_01, 16);
-          dp_mat1_01 = AE_ADDSP24S(dp_mat1_01, dp_mat1_zb);
-          dp_mat1_23 = AE_SRAIP24(dp_mat1_23, 16);
-          dp_mat1_23 = AE_ADDSP24S(dp_mat1_23, dp_mat1_zb);
-
-          AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0);
-        }
-
-        dq_acc[0] = AE_SLLISQ56S(dq_acc[0], 16);
-        dq_acc[1] = AE_SLLISQ56S(dq_acc[1], 16);
-        dq_acc[2] = AE_SLLISQ56S(dq_acc[2], 16);
-        dq_acc[3] = AE_SLLISQ56S(dq_acc[3], 16);
-
-        if (p_bias != NULL) {
-          for (i = 0; i < 4; i++)
-            dq_acc[i] =
-                AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-        }
-
-        for (i = 0; i < 4; i++) {
-          dq_out32 = AE_SATQ48S(dq_acc[i]);
-          MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                           out_multiplier, left_shift,
-                                           right_shift);
-          ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out,
-                                    out_zero_bias);
-        }
-      }
-    } else {
-      for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-        p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1];
-        p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1];
-        p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1];
-        p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1];
-        p_vec1_0 = p_vec1;
-
-        dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-        /* Matrix elements are kept in the upper 8 bits of P registers, vector
-        elements in the lower 8 bits; the casts to UWORD8 avoid extra extui
-        instructions, since a signed 8-bit load is not available on HiFiMini */
-        for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-          dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                        (UWORD8)p_mat1_0[c_itr + 1]);
-          dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr],
-                                        (UWORD8)p_mat1_1[c_itr + 1]);
-          dp_mat1_2 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr],
-                                        (UWORD8)p_mat1_2[c_itr + 1]);
-          dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr],
-                                        (UWORD8)p_mat1_3[c_itr + 1]);
-          dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                        (UWORD8)p_vec1_0[c_itr + 1]);
-          dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8);
-          dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8);
-          dp_mat1_2 = AE_SLLIP24(dp_mat1_2, 8);
-          dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8);
-          dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-          dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-          dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16);
-          dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb);
-          dp_mat1_1 = AE_SRAIP24(dp_mat1_1, 16);
-          dp_mat1_1 = AE_ADDSP24S(dp_mat1_1, dp_mat1_zb);
-          dp_mat1_2 = AE_SRAIP24(dp_mat1_2, 16);
-          dp_mat1_2 = AE_ADDSP24S(dp_mat1_2, dp_mat1_zb);
-          dp_mat1_3 = AE_SRAIP24(dp_mat1_3, 16);
-          dp_mat1_3 = AE_ADDSP24S(dp_mat1_3, dp_mat1_zb);
-
-          AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-        }
-        if (cols1 & 1) {
-          ae_p24x2s dp_mat1_01, dp_mat1_23;
-          dp_mat1_01 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                         (UWORD8)p_mat1_1[c_itr]);
-          dp_mat1_23 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr],
-                                         (UWORD8)p_mat1_3[c_itr]);
-          dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]);
-          dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8);
-          dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-          dp_mat1_01 = AE_SRAIP24(dp_mat1_01, 16);
-          dp_mat1_01 = AE_ADDSP24S(dp_mat1_01, dp_mat1_zb);
-          dp_mat1_23 = AE_SRAIP24(dp_mat1_23, 16);
-          dp_mat1_23 = AE_ADDSP24S(dp_mat1_23, dp_mat1_zb);
-
-          AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0);
-        }
-
-        dq_acc[0] = AE_SLLISQ56S(dq_acc[0], 16);
-        dq_acc[1] = AE_SLLISQ56S(dq_acc[1], 16);
-        dq_acc[2] = AE_SLLISQ56S(dq_acc[2], 16);
-        dq_acc[3] = AE_SLLISQ56S(dq_acc[3], 16);
-
-        if (p_bias != NULL) {
-          for (i = 0; i < 4; i++)
-            dq_acc[i] =
-                AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-        }
-
-        for (i = 0; i < 4; i++) {
-          dq_out32 = AE_SATQ48S(dq_acc[i]);
-          MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                           out_multiplier, left_shift,
-                                           right_shift);
-          ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out,
-                                    out_zero_bias);
-        }
-      }
-    }
-    for (; m_itr < rows; m_itr++) {
-      p_mat1_0 = &p_mat1[m_itr * row_stride1];
-      p_vec1_0 = p_vec1;
-
-      dq_acc[0] = AE_ZEROQ56();
-
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                      (UWORD8)p_mat1_0[c_itr + 1]);
-        dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                      (UWORD8)p_vec1_0[c_itr + 1]);
-        dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8);
-        dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-
-        dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16);
-        dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb);
-
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-      if (cols1 & 1) {
-        dp_mat1_0 = AE_CVTP24A16(p_mat1_0[c_itr]);
-        dp_vec1_0 = AE_CVTP24A16(p_vec1_0[c_itr]);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, AE_CVTP24A16(vec1_zero_bias));
-
-        dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16);
-        dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb);
-
-        AE_MULAP24S_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-
-      dq_acc[0] = AE_SLLISQ56S(dq_acc[0], 16);
-
-      if (p_bias != NULL)
-        dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr]));
-
-      dq_out32 = AE_SATQ48S(dq_acc[0]);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out,
-                                out_zero_bias);
-    }
-#else
-    return 1;
-#endif
-  }
-
-  return 0;
-}
-
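The kernel above pulls the matrix zero point out of the inner loop: expanding sum_i (m_i + mat_zb)(v_i + vec_zb) gives sum_i m_i*(v_i + vec_zb) plus the row-independent correction mat_zb*sum_i v_i + cols1*mat_zb*vec_zb, which is what `zero_offset_adjustment` accumulates once per call. A small self-contained check of that identity (values are illustrative):

```c
#include <assert.h>
#include <stdint.h>

int main(void) {
  enum { COLS = 5 };
  int32_t m[COLS] = {1, -2, 3, -4, 5};
  int32_t v[COLS] = {-9, 8, -7, 6, -5};
  int32_t mzb = 3, vzb = -2; /* matrix / vector zero biases */

  int64_t full = 0, inner = 0, vsum = 0;
  for (int i = 0; i < COLS; i++) {
    full += (int64_t)(m[i] + mzb) * (v[i] + vzb); /* textbook form */
    inner += (int64_t)m[i] * (v[i] + vzb);        /* per-row inner loop */
    vsum += v[i];
  }
  /* zero_offset_adjustment, computed once before the row loop */
  int64_t adjust = (int64_t)mzb * vsum + (int64_t)COLS * mzb * vzb;
  assert(full == inner + adjust);
  return 0;
}
```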
-#define STORE_INT16(ptr, data)                                         \
-  {                                                                    \
-    int out_i32 = AE_TRUNCA32Q48(AE_SATQ48S(data));                    \
-    out_i32 = out_i32 < (int)0xffff8000L ? (int)0xffff8000L : out_i32; \
-    out_i32 = out_i32 > (int)0x7fff ? (int)0x7fff : out_i32;           \
-    *(ptr) = (WORD16)out_i32;                                          \
-  }
-
-WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_16(
-    WORD16 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1,
-    const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift) {
-  /* NULL pointer checks */
-  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_mat1, -1);
-  XA_NNLIB_ARG_CHK_PTR(p_vec1, -1);
-  /* Pointer alignment checks */
-  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD16), -1);
-  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-  /* Basic Parameter checks */
-  XA_NNLIB_ARG_CHK_COND((rows <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1);
-  XA_NNLIB_ARG_CHK_COND((row_stride1 < cols1), -1);
-  XA_NNLIB_ARG_CHK_COND((vec1_zero_bias < -127 || vec1_zero_bias > 128), -1);
-  XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1);
-
-  /* Iterators used in for loops */
-  int m_itr, c_itr, i;
-  /* Initialize here so the value carries into the trailing loop */
-  m_itr = 0;
-  /* Shifts chosen to match TensorFlow's requantization */
-  int left_shift, right_shift;
-
-  left_shift = out_shift < 0 ? 0 : out_shift;
-  right_shift = out_shift > 0 ? 0 : -out_shift;
-
-  const WORD8 *p_mat1_0, *p_mat1_1, *p_mat1_2, *p_mat1_3;
-  const WORD8 *p_vec1_0;
-  ae_p24x2s dp_mat1_0, dp_mat1_1, dp_mat1_2, dp_mat1_3, dp_vec1_0;
-  ae_p24x2s dp_vec1_zb;
-  ae_q56s dq_acc[4];
-  ae_q56s dq_out32, dq_out;
-
-  dp_vec1_zb = AE_MOVPA24(vec1_zero_bias);
-  if (((((unsigned)p_mat1) & 1) == 0) && ((((unsigned)p_vec1) & 1) == 0) &&
-      ((row_stride1 & 1) == 0)) {
-    for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-      p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2];
-      p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1 - 2];
-      p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2];
-      p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1 - 2];
-      p_vec1_0 = p_vec1 - 2;
-
-      dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-      /* AE_LP8X2F* instructions load into the upper 8 bits of each P-register
-      lane, so the vector is shifted right by 16 to place the multiplication
-      result in the middle 32 bits of the Q register (lower 16 bits are 0) */
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2);
-        AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-        AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-        AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-      }
-      /* Pointers are aligned, so we can do 8X2 loads and ignore the L parts
-       * of the registers */
-      if (cols1 & 1) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2);
-        AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-        AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc[1], dp_mat1_1, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc[2], dp_mat1_2, dp_vec1_0);
-        AE_MULAP24S_HH(dq_acc[3], dp_mat1_3, dp_vec1_0);
-      }
-
-      if (p_bias != NULL) {
-        for (i = 0; i < 4; i++)
-          dq_acc[i] = AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-      }
-
-      for (i = 0; i < 4; i++) {
-        dq_out32 = AE_SATQ48S(dq_acc[i]);
-        MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                         out_multiplier, left_shift,
-                                         right_shift);
-        STORE_INT16(&p_out[(m_itr + i) * out_stride], dq_out);
-      }
-    }
-    for (; m_itr < rows; m_itr++) {
-      p_mat1_0 = &p_mat1[m_itr * row_stride1 - 2];
-      p_vec1_0 = p_vec1 - 2;
-
-      dq_acc[0] = AE_ZEROQ56();
-
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-      /* Pointers are aligned, so we can do 8X2 loads and ignore the L parts
-       * of the registers */
-      if (cols1 & 1) {
-        AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-        AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-
-      if (p_bias != NULL)
-        dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr]));
-
-      dq_out32 = AE_SATQ48S(dq_acc[0]);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      STORE_INT16(&p_out[m_itr * out_stride], dq_out);
-    }
-  } else {
-#ifndef DISABLE_NNLIB_UNALIGNED_SUPPORT
-    if ((((unsigned)p_mat1) & 1) == 0) {
-      for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-        p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2];
-        p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1];
-        p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2];
-        p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1];
-        p_vec1_0 = p_vec1;
-
-        dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-        /* Matrix elements are kept in the upper 8 bits of P registers, vector
-        elements in the lower 8 bits; the casts to UWORD8 avoid extra extui
-        instructions, since a signed 8-bit load is not available on HiFiMini */
-        for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-          AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2);
-          dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr],
-                                        (UWORD8)p_mat1_1[c_itr + 1]);
-          AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2);
-          dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr],
-                                        (UWORD8)p_mat1_3[c_itr + 1]);
-          dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                        (UWORD8)p_vec1_0[c_itr + 1]);
-          dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8);
-          dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8);
-          dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-          dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-        }
-        if (cols1 & 1) {
-          ae_p24x2s dp_mat1_01, dp_mat1_23;
-          dp_mat1_01 =
-              AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[2], (UWORD8)p_mat1_1[c_itr]);
-          dp_mat1_23 =
-              AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[2], (UWORD8)p_mat1_3[c_itr]);
-          dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]);
-          dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8);
-          dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0);
-        }
-
-        if (p_bias != NULL) {
-          for (i = 0; i < 4; i++)
-            dq_acc[i] =
-                AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-        }
-
-        for (i = 0; i < 4; i++) {
-          dq_out32 = AE_SATQ48S(dq_acc[i]);
-          MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                           out_multiplier, left_shift,
-                                           right_shift);
-          STORE_INT16(&p_out[(m_itr + i) * out_stride], dq_out);
-        }
-      }
-    } else {
-      for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) {
-        p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1];
-        p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1];
-        p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1];
-        p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1];
-        p_vec1_0 = p_vec1;
-
-        dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56();
-
-        /* Matrix elements are kept in the upper 8 bits of P registers, vector
-        elements in the lower 8 bits; the casts to UWORD8 avoid extra extui
-        instructions, since a signed 8-bit load is not available on HiFiMini */
-        for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-          dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                        (UWORD8)p_mat1_0[c_itr + 1]);
-          dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr],
-                                        (UWORD8)p_mat1_1[c_itr + 1]);
-          dp_mat1_2 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr],
-                                        (UWORD8)p_mat1_2[c_itr + 1]);
-          dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr],
-                                        (UWORD8)p_mat1_3[c_itr + 1]);
-          dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                        (UWORD8)p_vec1_0[c_itr + 1]);
-          dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8);
-          dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8);
-          dp_mat1_2 = AE_SLLIP24(dp_mat1_2, 8);
-          dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8);
-          dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-          dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0);
-          AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0);
-        }
-        if (cols1 & 1) {
-          ae_p24x2s dp_mat1_01, dp_mat1_23;
-          dp_mat1_01 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                         (UWORD8)p_mat1_1[c_itr]);
-          dp_mat1_23 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr],
-                                         (UWORD8)p_mat1_3[c_itr]);
-          dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]);
-          dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8);
-          dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8);
-          dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-          AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0);
-          AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0);
-          AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0);
-        }
-
-        if (p_bias != NULL) {
-          for (i = 0; i < 4; i++)
-            dq_acc[i] =
-                AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i]));
-        }
-
-        for (i = 0; i < 4; i++) {
-          dq_out32 = AE_SATQ48S(dq_acc[i]);
-          MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                           out_multiplier, left_shift,
-                                           right_shift);
-          STORE_INT16(&p_out[(m_itr + i) * out_stride], dq_out);
-        }
-      }
-    }
-    for (; m_itr < rows; m_itr++) {
-      p_mat1_0 = &p_mat1[m_itr * row_stride1];
-      p_vec1_0 = p_vec1;
-
-      dq_acc[0] = AE_ZEROQ56();
-
-      for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) {
-        dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr],
-                                      (UWORD8)p_mat1_0[c_itr + 1]);
-        dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr],
-                                      (UWORD8)p_vec1_0[c_itr + 1]);
-        dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8);
-        dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8);
-        dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb);
-        AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-      if (cols1 & 1) {
-        dp_mat1_0 = AE_CVTP24A16(p_mat1_0[c_itr]);
-        dp_vec1_0 = AE_CVTP24A16(p_vec1_0[c_itr]);
-        dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, AE_CVTP24A16(vec1_zero_bias));
-        AE_MULAP24S_LL(dq_acc[0], dp_mat1_0, dp_vec1_0);
-      }
-
-      if (p_bias != NULL)
-        dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr]));
-
-      dq_out32 = AE_SATQ48S(dq_acc[0]);
-      MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32),
-                                       out_multiplier, left_shift, right_shift);
-      STORE_INT16(&p_out[m_itr * out_stride], dq_out);
-    }
-#else
-    return 1;
-#endif
-  }
-
-  return 0;
-}
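The 16-bit variant above reuses the same accumulation and requantization but stores through STORE_INT16, which clamps to the int16 range and, unlike the int8 kernels, adds no output zero point (the int16 output is symmetric). A plain-C spelling of that macro's clamp, with the hex constants written as decimals:

```c
#include <stdint.h>

/* Plain-C equivalent of the STORE_INT16 clamp above: the accumulator has
 * already been saturated to 32 bits and requantized; this just clamps to
 * the int16 range. No output zero point is added on the 16-bit path. */
static inline int16_t clamp_to_int16(int32_t requantized) {
  if (requantized < -32768) requantized = -32768; /* 0xffff8000L */
  if (requantized > 32767) requantized = 32767;   /* 0x7fff */
  return (int16_t)requantized;
}
```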
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h
deleted file mode 100644
index e499e1eb980..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_NNLIB_API_H__
-#define __XA_NNLIB_API_H__
-
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h"
-
-#endif /* __XA_NNLIB_API_H__ */
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h
deleted file mode 100644
index d3a5e2990c0..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h
+++ /dev/null
@@ -1,300 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_NNLIB_KERNELS_API_H__
-#define __XA_NNLIB_KERNELS_API_H__
-
-/**
- * @file xa_nnlib_kernels_api.h
- * @brief This file gives the API definition for the HiFi NNLIB
- *
- * matXvec KERNELS API NAMING CONVENTION <br>
- * <br>
- * xa_nn_matXvec_<batch>_[m]x[n]_[p]_<activation>, where
- * - <batch>: Optional 'batch' tag to indicate time batching routine
- * - [m]: Matrix precision in bits
- * - [n]: Vector (and bias for non-activation routines) precision in bits
- * - [p]: Output precision in bits
- * - <activation>: optional activation tag 'sigmoid' / 'tanh'
- *
- * This set of kernels performs a dual matXvec followed by an optional
- * activation function. There are several variants based on the input,
- * output precision and use of activation functions.
- *
- * Restriction,
- * - All pointers (p_out, p_mat1, p_mat2, p_vec1, p_vec2, p_bias, p_scratch)
- * must be SIMD (64-bit) aligned and should not overlap.
- * - p_mat2, p_vec2 can be 'NULL', but other pointers cannot be 'NULL'
- * - Variables cols1, cols2, row_stride1, row_stride2 must be multiple of 4
- *
- * Usage of a few critical variables,
- * - acc_shift:
- *   -# In case of valid activation tag i.e. <activation>: shift to be
- *   applied on accumulator to match accumulator's Q format with activation
- *   function's input's Q format
- *   -# In case of bypass i.e. no activation tag: shift to be applied on
- *   accumulator.
- *   -# Positive value denotes left shift, and negative value denotes right
- * shift.
- * - bias_shift: shift which is to be applied on bias to match bias's
- *   Q format with accumulator's Q format. Positive value denotes left shift,
- *   and negative value denotes right shift.
- * - bias_precision: This represents bias precision
- *   -# For 16x16, and 8x16 apis, valid values are '16' and '64'
- *   -# For 8x8 apis, valid values are '8' and '32'
- *
- * Output 8b, 16b, 32b of fixed point apis (only for bypass variants) is
- * extracted from 64b accumulator with symmetric rounding. Output 64b of fixed
- * point apis (only for bypass variants) is extracted from 64b accumulator.
- * Output 8b, 16b of fixed point apis (only for activation variants) is
- * symmetrically rounded.
- *
- * matXvec 16x16 Kernels,
- * - Bypass kernels with 16, 32, 64 bit output: 3
- * - Fused kernel with 2 activation variants:   2
- * - Time batching kernel:                      1 (Not implemented)
- * - Total:                                     6
- *
- * matXvec 8x16 Kernels,
- * - Bypass kernels with 16, 32, 64 bit output: 3
- * - Fused kernel with 2 activation variants:   2
- * - Time batching kernel:                      1 (Not implemented)
- * - Total:                                     6
- *
- * matXvec 8x8 Kernels,
- * - Bypass kernels with 8, 16, 32 bit output: 3
- * - Fused kernel with 2 activation variants:  2
- * - Time batching kernel:                     1 (Not implemented)
- * - Total:                                    6
- *
- * matXvec float32 x float32 Kernels,
- * - Bypass kernels 32 bit output:            1
- * - Fused kernel with 2 activation variants: 2
- * - Time batching kernel:                    1 (Not implemented)
- * - Total:                                   4
- *
- * ACTIVATION KERNELS API NAMING CONVENTION <br>
- * <br>
- * xa_nn_vec_[activation]_[n]_[p] for fixed point <br>
- * xa_nn_vec_[activation]_f32_f32 for floating point, where
- * - [activation]: One of activations - sigmoid/tanh/relu/relu1/relu6/softmax
- * - [n]:          Input precision in bits
- * - [p]:          Output precision in bits
- *
- * Possible values,
- * - 'n' takes value '32', and expects input in Q6.25 format.
- * - 'p' takes values '32' and '16', gives output in Q16.15 and Q0.15 formats
- * respectively.
- *
- * There is a WORD32 datatype variable 'threshold' for 'relu' related apis,
- * which expects a value in Q16.15 format.
- *
- * Restriction,
- * - All pointers (p_out, p_vec) must be 32-bit aligned and should not overlap.
- *
- * activation 32_32 kernels,
- * - Vector activation kernels: 6
- * - Total:                     6
- *
- * activation f32_f32 kernels,
- * - Vector activation kernels: 6
- * - Total:                     6
- *
- * activation 32_16 kernels,
- * - Vector activation kernels: 2
- * - Total:                     2
- */
-
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h"
-
-#if defined(__cplusplus)
-extern "C" {
-#endif
-
-WORD32 xa_nn_conv2d_depthwise_getsize(
-    WORD32 input_height, WORD32 input_width, WORD32 input_channels,
-    WORD32 kernel_height, WORD32 kernel_width, WORD32 channels_multiplier,
-    WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding,
-    WORD32 output_height, WORD32 output_width, WORD32 circ_buf_precision,
-    WORD32 inp_data_format);
-
-WORD32 xa_nn_vec_activation_min_max_asym8u_asym8u(
-    UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_vec,
-    int activation_min, int activation_max, WORD32 vec_length);
-
-WORD32 xa_nn_vec_activation_min_max_asym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_vec,
-    int activation_min, int activation_max, WORD32 vec_length);
-
-WORD32 xa_nn_conv2d_std_getsize(WORD32 input_height, WORD32 input_channels,
-                                WORD32 kernel_height, WORD32 kernel_width,
-                                WORD32 y_stride, WORD32 y_padding,
-                                WORD32 out_height, WORD32 input_precision);
-
-WORD32 xa_nn_conv2d_std_asym8uxasym8u(
-    UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_inp,
-    const UWORD8 *__restrict__ p_kernel, const WORD32 *__restrict__ p_bias,
-    WORD32 input_height, WORD32 input_width, WORD32 input_channels,
-    WORD32 kernel_height, WORD32 kernel_width, WORD32 out_channels,
-    WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding,
-    WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias,
-    WORD32 kernel_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias, WORD32 out_data_format, VOID *p_scratch);
-
-WORD32 xa_nn_conv2d_std_per_chan_sym8sxasym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_inp,
-    const WORD8 *__restrict__ p_kernel, const WORD32 *__restrict__ p_bias,
-    WORD32 input_height, WORD32 input_width, WORD32 input_channels,
-    WORD32 kernel_height, WORD32 kernel_width, WORD32 out_channels,
-    WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding,
-    WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias,
-    WORD32 *p_out_multiplier, WORD32 *p_out_shift, WORD32 out_zero_bias,
-    WORD32 out_data_format, VOID *p_scratch);
-
-WORD32 xa_nn_conv2d_depthwise_asym8uxasym8u(
-    pUWORD8 __restrict__ p_out, const UWORD8 *__restrict__ p_kernel,
-    const UWORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 input_height, WORD32 input_width, WORD32 input_channels,
-    WORD32 kernel_height, WORD32 kernel_width, WORD32 channels_multiplier,
-    WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding,
-    WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias,
-    WORD32 kernel_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias, WORD32 inp_data_format, WORD32 out_data_format,
-    pVOID p_scratch);
-
-WORD32 xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_kernel,
-    const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 input_height, WORD32 input_width, WORD32 input_channels,
-    WORD32 kernel_height, WORD32 kernel_width, WORD32 channels_multiplier,
-    WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding,
-    WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias,
-    const WORD32 *p_out_multiplier, const WORD32 *p_out_shift,
-    WORD32 out_zero_bias, WORD32 inp_data_format, WORD32 out_data_format,
-    pVOID p_scratch);
-
-WORD32 xa_nn_fully_connected_asym8uxasym8u_asym8u(
-    pUWORD8 __restrict__ p_out, const UWORD8 *__restrict__ p_weight,
-    const UWORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias,
-    WORD32 weight_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias);
-
-WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s(
-    pWORD8 __restrict__ p_out, const WORD8 *__restrict__ p_weight,
-    const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias,
-    WORD32 out_multiplier, WORD32 out_shift, WORD32 out_zero_bias);
-
-WORD32 xa_nn_fully_connected_asym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_weight,
-    const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias,
-    WORD32 weight_depth, WORD32 out_depth, WORD32 weight_zero_bias,
-    WORD32 input_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias);
-
-WORD32 xa_nn_vec_softmax_asym8u_8(UWORD8 *__restrict__ p_out,
-                                  const UWORD8 *__restrict__ p_vec,
-                                  WORD32 diffmin, WORD32 input_left_shift,
-                                  WORD32 input_multiplier, WORD32 vec_length,
-                                  pVOID p_scratch);
-
-WORD32 xa_nn_vec_softmax_asym8s_16(WORD16 *__restrict__ p_out,
-                                   const WORD8 *__restrict__ p_vec,
-                                   WORD32 diffmin, WORD32 input_left_shift,
-                                   WORD32 input_multiplier, WORD32 vec_length,
-                                   pVOID p_scratch);
-
-WORD32 xa_nn_vec_softmax_asym8s_8(WORD8 *__restrict__ p_out,
-                                  const WORD8 *__restrict__ p_vec,
-                                  WORD32 diffmin, WORD32 input_left_shift,
-                                  WORD32 input_multiplier, WORD32 vec_length,
-                                  pVOID p_scratch);
-
-int xa_nn_get_softmax_scratch_size(int inp_precision, int out_precision,
-                                   int length);
-
-WORD32 xa_nn_matXvec_out_stride_asym8uxasym8u_asym8u(
-    UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_mat1,
-    const UWORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 mat1_zero_bias, WORD32 vec1_zero_bias, WORD32 out_multiplier,
-    WORD32 out_shift, WORD32 out_zero_bias);
-
-WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1,
-    const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift,
-    WORD32 out_zero_bias);
-
-WORD32 xa_nn_matXvec_out_stride_asym8sxasym8s_asym8s(
-    WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1,
-    const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 mat1_zero_bias, WORD32 vec1_zero_bias, WORD32 out_multiplier,
-    WORD32 out_shift, WORD32 out_zero_bias);
-
-WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_16(
-    WORD16 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1,
-    const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias,
-    WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride,
-    WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift);
-
-WORD32 xa_nn_dot_prod_16x16_asym8s(
-    WORD8 *__restrict__ p_out,               /* pointer to output */
-    const WORD16 *__restrict__ p_inp1_start, /* pointer to input1 */
-    const WORD16 *__restrict__ p_inp2_start, /* pointer to input2 */
-    const WORD32 *bias_ptr, WORD32 vec_length, WORD32 out_multiplier,
-    WORD32 out_shift, WORD32 out_zero_bias, WORD32 vec_count);
-
-/* Mapping the function names from the previous naming convention for backward
- * compatibility */
-#define xa_nn_vec_activation_min_max_asym8_asym8 \
-  xa_nn_vec_activation_min_max_asym8u_asym8u
-#define xa_nn_conv2d_std_asym8xasym8 xa_nn_conv2d_std_asym8uxasym8u
-#define xa_nn_conv2d_depthwise_asym8xasym8 xa_nn_conv2d_depthwise_asym8uxasym8u
-#define xa_nn_fully_connected_asym8xasym8_asym8 \
-  xa_nn_fully_connected_asym8uxasym8u_asym8u
-#define xa_nn_vec_softmax_asym8_asym8 xa_nn_vec_softmax_asym8u_asym8u
-#define xa_nn_dot_prod_asym8xasym8_asym8 xa_nn_dot_prod_asym8uxasym8u_asym8u
-#define xa_nn_matXvec_out_stride_asym8xasym8_asym8 \
-  xa_nn_matXvec_out_stride_asym8uxasym8u_asym8u
-
-#if defined(__cplusplus)
-}
-#endif
-#endif /* __XA_NNLIB_KERNELS_API_H__ */
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h
deleted file mode 100644
index 36ea75d1e25..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h
+++ /dev/null
@@ -1,170 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __STANDARDS_H__
-#define __STANDARDS_H__
-
-#if defined(__cplusplus)
-extern "C" {
-#endif
-
-typedef double flt64;
-typedef char Int4;
-typedef char Int8;
-typedef int16_t Int16;
-typedef int Int32;
-typedef int Int24;
-typedef int64_t Int64;
-typedef int Bool;
-typedef float Flt32;
-
-#ifdef MODEL_FLT64
-typedef double vect_t;
-typedef double coeff_t;
-typedef double accu_t;
-
-#elif MODEL_INT16
-typedef int16_t vect_t;
-typedef int16_t coeff_t;
-typedef signed char coeff8_t;
-typedef int64_t accu_t;
-typedef float coefff32_t;
-#endif
-
-typedef struct xa_nnlib_opaque {
-  Int32 _;
-} * xa_nnlib_handle_t;
-
-typedef enum _xa_nnlib_prec_t {
-  PREC_8 = 8,
-  PREC_16 = 16,
-  PREC_32 = 32,
-  PREC_F32 = -1,
-  PREC_F16 = -2,
-  PREC_ASYM8U = -3,
-  PREC_ASYM8S = -4,
-  PREC_SYM8S = -5
-} xa_nnlib_prec_t;
-
-typedef enum _xa_nnlib_shape_type_t {
-  SHAPE_UNKNOWN_T = 0,
-  SHAPE_VECTOR_T = 1,
-  SHAPE_MATRIX_T = 2,
-  SHAPE_CUBE_DWH_T = 3,
-  SHAPE_CUBE_WHD_T = 4
-} xa_nnlib_shape_type_t;
-
-typedef struct _xa_nnlib_shape_t {
-  xa_nnlib_shape_type_t shape_type;
-  Int32 n_shapes;
-  Int32 shape_offset;  // Offset between current shape and next shape
-  union {
-    struct {
-      Int32 height;
-      Int32 height_offset;
-      Int32 width;
-      Int32 width_offset;
-      Int32 depth;
-      Int32 depth_offset;
-    } cube;
-
-    struct {
-      Int32 length;
-    } vector;
-    struct {
-      Int32 rows;
-      Int32 row_offset;  // Offset between current row and next row
-      Int32 cols;
-    } matrix;
-  } dim;
-} xa_nnlib_shape_t;
-
-/*****************************************************************************/
-/* Constant hash defines                                                     */
-/*****************************************************************************/
-#define XA_NNLIB_NO_ERROR 0
-/* error handling 'AND' definition */
-#define XA_FATAL_ERROR 0x80000000
-
-enum xa_error_severity {
-  xa_severity_nonfatal = 0,
-  xa_severity_fatal = (int)0xffffffff
-};
-
-enum xa_error_class {
-  xa_class_nnlib = 0,
-  xa_class_config = 1,
-  xa_class_execute = 2
-};
-
-#define XA_NNLIB_GENERIC 0
-
-#define XA_ERROR_CODE(severity, class, codec, index) \
-  ((severity << 31) | (class << 12) | (codec << 7) | index)
-#define XA_ERROR_SEVERITY(code) (((code)&XA_FATAL_ERROR) != 0)
-#define XA_ERROR_CLASS(code) (((code) >> 12) & 0x0f)
-#define XA_ERROR_CODEC(code) (((code) >> 7) & 0x1f)
-#define XA_ERROR_SUBCODE(code) (((code) >> 0) & 0x3f)
-
-/* Our convention is that only nnlib-class errors can be generic ones. */
-
-/*****************************************************************************/
-/* Class 0: NNLib Errors                                                     */
-/*****************************************************************************/
-/* Non Fatal Errors */
-/* (none) */
-/* Fatal Errors */
-enum xa_error_fatal_nnlib_generic {
-  XA_NNLIB_FATAL_MEM_ALLOC =
-      XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 0),
-  XA_NNLIB_FATAL_MEM_ALIGN =
-      XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 1),
-  XA_NNLIB_FATAL_INVALID_SHAPE =
-      XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 3)
-};
-
-/*****************************************************************************/
-/* NNLib Startup Functions                                                   */
-/*****************************************************************************/
-const Int8* xa_nnlib_get_lib_name_string(void);
-const Int8* xa_nnlib_get_lib_version_string(void);
-const Int8* xa_nnlib_get_lib_api_version_string(void);
-
-#if defined(__cplusplus)
-}
-#endif
-
-#endif /* __STANDARDS_H__ */
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h
deleted file mode 100644
index 13a7469bbf7..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h
+++ /dev/null
@@ -1,108 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019-2020 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XA_TYPE_DEF_H__
-#define __XA_TYPE_DEF_H__
-
-#include <stdint.h>
-
-/****************************************************************************/
-/*     types               type define    prefix        examples      bytes */
-/************************  ***********    ******    ****************  ***** */
-typedef signed char WORD8;      /* b       WORD8    b_name     1   */
-typedef signed char* pWORD8;    /* pb      pWORD8   pb_name    1   */
-typedef unsigned char UWORD8;   /* ub      UWORD8   ub_count   1   */
-typedef unsigned char* pUWORD8; /* pub     pUWORD8  pub_count  1   */
-
-typedef int16_t WORD16;     /* s       WORD16   s_count    2   */
-typedef int16_t* pWORD16;   /* ps      pWORD16  ps_count   2   */
-typedef uint16_t UWORD16;   /* us      UWORD16  us_count   2   */
-typedef uint16_t* pUWORD16; /* pus     pUWORD16 pus_count  2   */
-
-typedef signed int WORD24;      /* k       WORD24   k_count    3   */
-typedef signed int* pWORD24;    /* pk      pWORD24  pk_count   3   */
-typedef unsigned int UWORD24;   /* uk      UWORD24  uk_count   3   */
-typedef unsigned int* pUWORD24; /* puk     pUWORD24 puk_count  3   */
-
-typedef signed int WORD32;      /* i       WORD32   i_count    4   */
-typedef signed int* pWORD32;    /* pi      pWORD32  pi_count   4   */
-typedef unsigned int UWORD32;   /* ui      UWORD32  ui_count   4   */
-typedef unsigned int* pUWORD32; /* pui     pUWORD32 pui_count  4   */
-
-typedef int64_t WORD40;     /* m       WORD40   m_count    5   */
-typedef int64_t* pWORD40;   /* pm      pWORD40  pm_count   5   */
-typedef uint64_t UWORD40;   /* um      UWORD40  um_count   5   */
-typedef uint64_t* pUWORD40; /* pum     pUWORD40 pum_count  5   */
-
-typedef int64_t WORD64;     /* h       WORD64   h_count    8   */
-typedef int64_t* pWORD64;   /* ph      pWORD64  ph_count   8   */
-typedef uint64_t UWORD64;   /* uh      UWORD64  uh_count   8   */
-typedef uint64_t* pUWORD64; /* puh     pUWORD64 puh_count  8   */
-
-typedef float FLOAT32;    /* f       FLOAT32  f_count    4   */
-typedef float* pFLOAT32;  /* pf      pFLOAT32 pf_count   4   */
-typedef double FLOAT64;   /* d       UFLOAT64 d_count    8   */
-typedef double* pFlOAT64; /* pd      pFLOAT64 pd_count   8   */
-
-typedef void VOID;   /* v       VOID     v_flag     4   */
-typedef void* pVOID; /* pv      pVOID    pv_flag    4   */
-
-/* variable size types: platform optimized implementation */
-typedef signed int BOOL;       /* bool    BOOL     bool_true      */
-typedef unsigned int UBOOL;    /* ubool   BOOL     ubool_true     */
-typedef signed int FLAG;       /* flag    FLAG     flag_false     */
-typedef unsigned int UFLAG;    /* uflag   FLAG     uflag_false    */
-typedef signed int LOOPIDX;    /* lp      LOOPIDX  lp_index       */
-typedef unsigned int ULOOPIDX; /* ulp     SLOOPIDX ulp_index      */
-typedef signed int WORD;       /* lp      LOOPIDX  lp_index       */
-typedef unsigned int UWORD;    /* ulp     SLOOPIDX ulp_index      */
-
-typedef LOOPIDX LOOPINDEX;   /* lp    LOOPIDX  lp_index       */
-typedef ULOOPIDX ULOOPINDEX; /* ulp   SLOOPIDX ulp_index      */
-
-#define PLATFORM_INLINE __inline
-
-typedef struct xa_codec_opaque {
-  WORD32 _;
-} * xa_codec_handle_t;
-
-typedef int XA_ERRORCODE;
-
-typedef XA_ERRORCODE xa_codec_func_t(xa_codec_handle_t p_xa_module_obj,
-                                     WORD32 i_cmd, WORD32 i_idx,
-                                     pVOID pv_value);
-
-#endif /* __XA_TYPE_DEF_H__ */
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h
deleted file mode 100644
index 81847b60444..00000000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h
+++ /dev/null
@@ -1,88 +0,0 @@
-/******************************************************************************
- * Copyright (C) 2019 Cadence Design Systems, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files (the
- * "Software"), to use this Software with Cadence processor cores only and
- * not with any other processors and platforms, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ******************************************************************************/
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XTENSA_TF_MICRO_COMMON__
-#define __XTENSA_TF_MICRO_COMMON__
-
-#if defined HIFI_NNLIB_OPT || defined HIFI_MINI_NNLIB_OPT
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h"
-
-#define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \
-  if (ret != 0) {                              \
-    TF_LITE_KERNEL_LOG(context, err_msg);      \
-    return kTfLiteError;                       \
-  }
-
-#ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE
-#define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024)
-#endif
-
-#define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \
-  uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE];
-
-#define MIN(a, b) (a) < (b) ? (a) : (b);
-#define MAX(a, b) (a) > (b) ? (a) : (b);
-
-#define ACTIVATION_MIN_MAX(data_type, out, inp, min, max) \
-  {                                                       \
-    data_type temp = MAX(inp, min);                       \
-    out = MIN(temp, max);                                 \
-  }
-
-#define ACTIVATION_MIN_MAX_F32(out, inp, min, max) \
-  {                                                \
-    float temp = MAX(inp, min);                    \
-    out = MIN(temp, max);                          \
-  }
-
-#define ACTIVATION_MIN_MAX_ASYM8(out, inp, min, max) \
-  {                                                  \
-    int32_t temp = MAX((int32_t)inp, min);           \
-    out = (uint8_t)MIN(temp, max);                   \
-  }
-
-#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1)))
-#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
-
-#define PRINT_VAR(var)            \
-  printf("%s = %d\n", #var, var); \
-  fflush(stdout);                 \
-  fflush(stderr);
-
-#endif /* HIFI_NNLIB_OPT */
-
-#endif /* __XTENSA_TF_MICRO_COMMON__ */
diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc
index 2dbfa8b56ac..3f4bb813d2d 100644
--- a/tensorflow/lite/micro/micro_interpreter_test.cc
+++ b/tensorflow/lite/micro/micro_interpreter_test.cc
@@ -494,4 +494,68 @@ TF_LITE_MICRO_TEST(TestInterpreterDoesNotAllocateUntilInvoke) {
       static_cast<size_t>(0));
 }
 
+TF_LITE_MICRO_TEST(TestInterpreterMultipleInputs) {
+  const tflite::Model* model = tflite::testing::GetSimpleMultipleInputsModel();
+  TF_LITE_MICRO_EXPECT_NE(nullptr, model);
+
+  tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver();
+
+  constexpr size_t allocator_buffer_size = 2000;
+  uint8_t allocator_buffer[allocator_buffer_size];
+
+  // Create a new scope so that we can test the destructor.
+  {
+    tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
+                                         allocator_buffer_size,
+                                         micro_test::reporter);
+
+    TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
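+    // Allow ~100 bytes of slack over the expected arena usage so that small
+    // allocator changes do not break the test.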
+    TF_LITE_MICRO_EXPECT_LE(interpreter.arena_used_bytes(), 928 + 100);
+
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(3), interpreter.inputs_size());
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.outputs_size());
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), interpreter.tensors_size());
+
+    TfLiteTensor* input = interpreter.input(0);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, input);
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type);
+    TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
+    TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), input->bytes);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32);
+    input->data.i32[0] = 21;
+
+    TfLiteTensor* input1 = interpreter.input(1);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, input1);
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input1->type);
+    TF_LITE_MICRO_EXPECT_EQ(1, input1->dims->size);
+    TF_LITE_MICRO_EXPECT_EQ(1, input1->dims->data[0]);
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), input1->bytes);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, input1->data.i32);
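+    // The model declares this tensor as int8, but the test writes through
+    // data.i32 to mirror how MultipleInputs::Invoke reads its inputs.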
+    input1->data.i32[0] = 21;
+
+    TfLiteTensor* input2 = interpreter.input(2);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, input2);
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input2->type);
+    TF_LITE_MICRO_EXPECT_EQ(1, input2->dims->size);
+    TF_LITE_MICRO_EXPECT_EQ(1, input2->dims->data[0]);
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), input2->bytes);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, input2->data.i32);
+    input2->data.i32[0] = 24;
+
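+    // The custom op sums its inputs: 21 + 21 + 24 == 66.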
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
+
+    TfLiteTensor* output = interpreter.output(0);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, output);
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type);
+    TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size);
+    TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), output->bytes);
+    TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
+    TF_LITE_MICRO_EXPECT_EQ(66, output->data.i32[0]);
+  }
+
+  TF_LITE_MICRO_EXPECT_EQ(tflite::testing::MultipleInputs::freed_, true);
+}
+
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc
index 897f3110036..f73073f63da 100644
--- a/tensorflow/lite/micro/test_helpers.cc
+++ b/tensorflow/lite/micro/test_helpers.cc
@@ -570,6 +570,74 @@ const Model* BuildComplexMockModel() {
   return model;
 }
 
+const Model* BuildSimpleMultipleInputsModel() {
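+  // Builds a single-operator model: three scalar input tensors (int32, int8,
+  // int32) feeding the "multiple_inputs_op" custom op, plus one int32 output.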
+  using flatbuffers::Offset;
+  flatbuffers::FlatBufferBuilder* builder = BuilderInstance();
+
+  constexpr size_t buffers_size = 1;
+  const Offset<Buffer> buffers[buffers_size] = {
+      CreateBuffer(*builder),
+  };
+  constexpr size_t tensor_shape_size = 1;
+  const int32_t tensor_shape[tensor_shape_size] = {1};
+  constexpr size_t tensors_size = 4;
+  const Offset<Tensor> tensors[tensors_size] = {
+      CreateTensor(*builder,
+                   builder->CreateVector(tensor_shape, tensor_shape_size),
+                   TensorType_INT32, 0,
+                   builder->CreateString("test_input_tensor1"), 0, false),
+      CreateTensor(*builder,
+                   builder->CreateVector(tensor_shape, tensor_shape_size),
+                   TensorType_INT8, 0,
+                   builder->CreateString("test_input_tensor2"), 0, false),
+      CreateTensor(*builder,
+                   builder->CreateVector(tensor_shape, tensor_shape_size),
+                   TensorType_INT32, 0,
+                   builder->CreateString("test_input_tensor3"), 0, false),
+      CreateTensor(*builder,
+                   builder->CreateVector(tensor_shape, tensor_shape_size),
+                   TensorType_INT32, 0,
+                   builder->CreateString("test_output_tensor"), 0, false),
+  };
+  constexpr size_t inputs_size = 3;
+  const int32_t inputs[inputs_size] = {0, 1, 2};
+  constexpr size_t outputs_size = 1;
+  const int32_t outputs[outputs_size] = {3};
+  constexpr size_t operator_inputs_size = 3;
+  const int32_t operator_inputs[operator_inputs_size] = {0, 1, 2};
+  constexpr size_t operator_outputs_size = 1;
+  const int32_t operator_outputs[operator_outputs_size] = {3};
+  constexpr size_t operators_size = 1;
+  const Offset<Operator> operators[operators_size] = {
+      CreateOperator(
+          *builder, 0,
+          builder->CreateVector(operator_inputs, operator_inputs_size),
+          builder->CreateVector(operator_outputs, operator_outputs_size),
+          BuiltinOptions_NONE),
+  };
+  constexpr size_t subgraphs_size = 1;
+  const Offset<SubGraph> subgraphs[subgraphs_size] = {
+      CreateSubGraph(*builder, builder->CreateVector(tensors, tensors_size),
+                     builder->CreateVector(inputs, inputs_size),
+                     builder->CreateVector(outputs, outputs_size),
+                     builder->CreateVector(operators, operators_size),
+                     builder->CreateString("test_subgraph"))};
+  constexpr size_t operator_codes_size = 1;
+  const Offset<OperatorCode> operator_codes[operator_codes_size] = {
+      CreateOperatorCodeDirect(*builder, /*deprecated_builtin_code=*/0,
+                               "multiple_inputs_op",
+                               /*version=*/0, BuiltinOperator_CUSTOM)};
+  const Offset<Model> model_offset = CreateModel(
+      *builder, 0, builder->CreateVector(operator_codes, operator_codes_size),
+      builder->CreateVector(subgraphs, subgraphs_size),
+      builder->CreateString("test_model"),
+      builder->CreateVector(buffers, buffers_size));
+  FinishModelBuffer(*builder, model_offset);
+  void* model_pointer = builder->GetBufferPointer();
+  const Model* model = flatbuffers::GetRoot<Model>(model_pointer);
+  return model;
+}
+
 }  // namespace
 
 const TfLiteRegistration* SimpleStatefulOp::getRegistration() {
@@ -704,12 +772,66 @@ TfLiteStatus MockCustom::Invoke(TfLiteContext* context, TfLiteNode* node) {
 
 bool MockCustom::freed_ = false;
 
+const TfLiteRegistration* MultipleInputs::getRegistration() {
+  return GetMutableRegistration();
+}
+
+TfLiteRegistration* MultipleInputs::GetMutableRegistration() {
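+  // A function-local static keeps the registration alive for as long as the
+  // resolver may reference it.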
+  static TfLiteRegistration r;
+  r.init = Init;
+  r.prepare = Prepare;
+  r.invoke = Invoke;
+  r.free = Free;
+  return &r;
+}
+
+void* MultipleInputs::Init(TfLiteContext* context, const char* buffer,
+                           size_t length) {
+  // We don't support delegates in TFL micro. This is a weak check to verify
+  // that the context struct is zero-initialized.
+  TFLITE_DCHECK(context->ReplaceNodeSubsetsWithDelegateKernels == nullptr);
+  freed_ = false;
+  // Do nothing.
+  return nullptr;
+}
+
+void MultipleInputs::Free(TfLiteContext* context, void* buffer) {
+  freed_ = true;
+}
+
+TfLiteStatus MultipleInputs::Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus MultipleInputs::Invoke(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  const int32_t* input_data = input->data.i32;
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input1));
+  const int32_t* input_data1 = input1->data.i32;
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &input2));
+  const int32_t* input_data2 = input2->data.i32;
+
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  int32_t* output_data = output->data.i32;
+  // Zero the output first to catch the output tensor sharing memory with an
+  // input tensor.
+  output_data[0] = 0;
+  output_data[0] = input_data[0] + input_data1[0] + input_data2[0];
+  return kTfLiteOk;
+}
+
+bool MultipleInputs::freed_ = false;
+
 AllOpsResolver GetOpResolver() {
   AllOpsResolver op_resolver;
   op_resolver.AddCustom("mock_custom", MockCustom::GetMutableRegistration());
   op_resolver.AddCustom("simple_stateful_op",
                         SimpleStatefulOp::GetMutableRegistration());
-
+  op_resolver.AddCustom("multiple_inputs_op",
+                        MultipleInputs::GetMutableRegistration());
   return op_resolver;
 }
 
@@ -721,6 +843,14 @@ const Model* GetSimpleMockModel() {
   return model;
 }
 
+const Model* GetSimpleMultipleInputsModel() {
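+  // Build the flatbuffer once and cache it for reuse across tests.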
+  static Model* model = nullptr;
+  if (!model) {
+    model = const_cast<Model*>(BuildSimpleMultipleInputsModel());
+  }
+  return model;
+}
+
 const Model* GetComplexMockModel() {
   static Model* model = nullptr;
   if (!model) {
diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h
index 1db0d81facc..4c8b7c20aa0 100644
--- a/tensorflow/lite/micro/test_helpers.h
+++ b/tensorflow/lite/micro/test_helpers.h
@@ -76,6 +76,20 @@ class MockCustom {
   static bool freed_;
 };
 
+// A simple operator for testing graphs with multiple inputs. It returns the
+// sum of its inputs.
+class MultipleInputs {
+ public:
+  static const TfLiteRegistration* getRegistration();
+  static TfLiteRegistration* GetMutableRegistration();
+  static void* Init(TfLiteContext* context, const char* buffer, size_t length);
+  static void Free(TfLiteContext* context, void* buffer);
+  static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node);
+  static TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node);
+
+  static bool freed_;
+};
+
 // Returns an Op Resolver that can be used in the testing code.
 AllOpsResolver GetOpResolver();
 
@@ -90,6 +104,10 @@ const Model* GetComplexMockModel();
 // Returns a simple flatbuffer model with two branches.
 const Model* GetSimpleModelWithBranch();
 
+// Returns a simple example flatbuffer TensorFlow Lite model with 3 input
+// tensors, 1 output tensor, and 1 operator.
+const Model* GetSimpleMultipleInputsModel();
+
 // Returns a simple flatbuffer model with offline planned tensors
 // @param[in]       num_tensors           Number of tensors in the model.
 // @param[in]       metadata_buffer       Metadata for offline planner.
diff --git a/tensorflow/lite/micro/tools/ci_build/test_esp32.sh b/tensorflow/lite/micro/tools/ci_build/test_esp32.sh
new file mode 100755
index 00000000000..8341e90924e
--- /dev/null
+++ b/tensorflow/lite/micro/tools/ci_build/test_esp32.sh
@@ -0,0 +1,58 @@
+#!/usr/bin/env bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests the microcontroller code for esp32 platform
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR=${SCRIPT_DIR}/../../../../..
+cd "${ROOT_DIR}"
+pwd
+
+source tensorflow/lite/micro/tools/ci_build/helper_functions.sh
+
+TARGET=esp
+
+# setup esp-idf and toolchains
+readable_run git clone --recursive --single-branch --branch release/v4.2 https://github.com/espressif/esp-idf.git
+readable_run export IDF_PATH="${ROOT_DIR}"/esp-idf
+cd "${IDF_PATH}"
+readable_run ./install.sh
+readable_run . ./export.sh
+cd "${ROOT_DIR}"
+
+# clean all
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
+
+# generate examples
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} generate_hello_world_esp_project
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} generate_person_detection_esp_project
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} generate_micro_speech_esp_project
+
+# build examples
+cd "${ROOT_DIR}"/tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/hello_world/esp-idf
+readable_run idf.py build
+
+cd "${ROOT_DIR}"/tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/person_detection/esp-idf
+readable_run git clone https://github.com/espressif/esp32-camera.git components/esp32-camera
+cd components/esp32-camera/
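+# Pin esp32-camera to a fixed commit for reproducible CI builds.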
+readable_run git checkout eacd640b8d379883bff1251a1005ebf3cf1ed95c
+cd ../../
+readable_run idf.py build
+
+cd "${ROOT_DIR}"/tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/micro_speech/esp-idf
+readable_run idf.py build
diff --git a/tensorflow/lite/micro/tools/ci_build/test_x86.sh b/tensorflow/lite/micro/tools/ci_build/test_x86.sh
index 844dccbafb7..05d79802dcd 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_x86.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_x86.sh
@@ -29,9 +29,12 @@ readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
 # TODO(b/143715361): downloading first to allow for parallel builds.
 readable_run make -f tensorflow/lite/micro/tools/make/Makefile third_party_downloads
 
-# Next, build w/o TF_LITE_STATIC_MEMORY to catch additional build errors.
+# Next, build w/o TF_LITE_STATIC_MEMORY to catch additional errors.
+# TODO(b/160955687): We run the tests w/o TF_LITE_STATIC_MEMORY to make the
+# internal and open source CI consistent. See b/160955687#comment7 for more
+# details.
 readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
-readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile BUILD_TYPE=no_tf_lite_static_memory build
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile BUILD_TYPE=no_tf_lite_static_memory test
 
 # Next, make sure that the release build succeeds.
 readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini.inc
deleted file mode 100644
index c972870cf02..00000000000
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini.inc
+++ /dev/null
@@ -1,4 +0,0 @@
-# Every optimized kernel implementation directory (i.e.
-# micro/kernels/<optimized_kernel_dir>/) must have a corresponding
-# micro/tools/make/ext_libs/<optimized_kernel_dir>.inc
-
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini_staging_nn_library.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini_staging_nn_library.inc
deleted file mode 100644
index df7d3089c30..00000000000
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini_staging_nn_library.inc
+++ /dev/null
@@ -1,30 +0,0 @@
-ifneq ($(filter xtensa_hifimini_staging, $(ALL_TAGS)),)
-
-    XTENSA_PATH = $(MAKEFILE_DIR)/../../kernels/xtensa_hifimini_staging
-
-    ifneq (,$(filter xtensa_hifimini%, $(ALL_TAGS)))
-
-        CCFLAGS += -DHIFI_MINI_NNLIB_OPT \
-                   -DDISABLE_NNLIB_UNALIGNED_SUPPORT \
-                   -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=1024
-
-        CXXFLAGS += -DHIFI_MINI_NNLIB_OPT \
-                   -DDISABLE_NNLIB_UNALIGNED_SUPPORT \
-                   -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=1024
-
-        MICROLITE_CC_SRCS += \
-                    $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c \
-                    $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c \
-                    $(XTENSA_PATH)/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c \
-                    $(XTENSA_PATH)/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c \
-                    $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c \
-
-
-        INCLUDES += -I$(XTENSA_PATH)/xa_nnlib/algo/kernels/ \
-                    -I$(XTENSA_PATH)/xa_nnlib/include/nnlib/ \
-                    -I$(XTENSA_PATH)/xa_nnlib/include/ \
-                    -I$(XTENSA_PATH)/xa_nnlib/algo/common/include/ \
-
-    endif
-
-endif
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc
deleted file mode 100644
index 1587ebcd034..00000000000
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc
+++ /dev/null
@@ -1,58 +0,0 @@
-# Settings for Xtensa toolchain for the hifimini kernels.
-# REQUIRED:
-#  Environment variables:
-#   - XTENSA_BASE  must be set to location of
-#     the Xtensa developer tools installation directory.
-#  Command line arguments:
-#   - XTENSA_TOOLS_VERSION: For example: RI-2019.2-linux
-#   - XTENSA_CORE: The name of the Xtensa core to use
-#      For example: hifimini
-
-TARGET_ARCH := xtensa_hifimini
-
-ifndef XTENSA_BASE
-  $(error XTENSA_BASE is undefined)
-endif
-
-ifndef XTENSA_TOOLS_VERSION
-  $(error XTENSA_TOOLS_VERSION is undefined)
-endif
-
-ifndef XTENSA_CORE
-  $(error XTENSA_CORE is undefined)
-endif
-
-PLATFORM_FLAGS = \
-  -DTF_LITE_MCU_DEBUG_LOG \
-  -DTF_LITE_USE_CTIME \
-  --xtensa-core=$(XTENSA_CORE) \
-  -mcoproc \
-  -DXTENSA \
-  -DMAX_RFFT_PWR=9 \
-  -DMIN_RFFT_PWR=MAX_RFFT_PWR
-
-
-export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH)
-TARGET_TOOLCHAIN_PREFIX := xt-
-CXX_TOOL := clang++
-CC_TOOL := clang
-
-CXXFLAGS += $(PLATFORM_FLAGS)
-CCFLAGS += $(PLATFORM_FLAGS)
-
-# TODO(b/150240249): Do not remove -fno-rtti once that works for the Xtensa toolchain.
-CXXFLAGS := $(filter-out -fno-rtti, $(CXXFLAGS))
-
-TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifimini_binary.sh
-
-# TODO(b/156962140): This manually maintained list of excluded examples is
-# quite error prone.
-EXCLUDED_EXAMPLE_TESTS := \
-  tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc \
-  tensorflow/lite/micro/examples/magic_wand/Makefile.inc \
-  tensorflow/lite/micro/examples/micro_speech/Makefile.inc \
-  tensorflow/lite/micro/examples/network_tester/Makefile.inc \
-  tensorflow/lite/micro/examples/person_detection/Makefile.inc \
-  tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc
-MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS))
-
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_staging_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_staging_makefile.inc
deleted file mode 100644
index 557b8f6e9e6..00000000000
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_staging_makefile.inc
+++ /dev/null
@@ -1,62 +0,0 @@
-# Settings for Xtensa toolchain for the hifimini kernels.
-# REQUIRED:
-#  Environment variables:
-#   - XTENSA_BASE  must be set to location of
-#     the Xtensa developer tools installation directory.
-#  Command line arguments:
-#   - XTENSA_TOOLS_VERSION: For example: RI-2019.2-linux
-#   - XTENSA_CORE: The name of the Xtensa core to use
-#      For example: hifimini
-
-ifeq ($(TARGET), xtensa_hifimini_staging)
-  TARGET_ARCH := xtensa_hifimini_staging
-
-  ifndef XTENSA_BASE
-    $(error XTENSA_BASE is undefined)
-  endif
-
-  ifndef XTENSA_TOOLS_VERSION
-    $(error XTENSA_TOOLS_VERSION is undefined)
-  endif
-
-  ifndef XTENSA_CORE
-    $(error XTENSA_CORE is undefined)
-  endif
-
-  PLATFORM_ARGS = \
-    -DTF_LITE_MCU_DEBUG_LOG \
-    --xtensa-core=$(XTENSA_CORE) \
-    -mcoproc \
-    -DXTENSA -DMAX_RFFT_PWR=9 -DMIN_RFFT_PWR=MAX_RFFT_PWR \
-    -fdata-sections \
-    -ffunction-sections \
-    -fno-exceptions \
-    -fno-unwind-tables \
-    -fno-use-cxa-atexit \
-    -fmessage-length=0 \
-    -fno-threadsafe-statics
-
-  export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH)
-  TARGET_TOOLCHAIN_PREFIX := xt-
-  CXX_TOOL := clang++
-  CC_TOOL := clang
-
-  CXXFLAGS += $(PLATFORM_ARGS)
-  CCFLAGS += $(PLATFORM_ARGS)
-
-  LDFLAGS += -Wl,-gc-sections
-
-  TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifimini_staging_binary.sh
-
-  # TODO(b/156962140): This manually maintained list of excluded examples is
-  # quite error prone.
-  EXCLUDED_EXAMPLE_TESTS := \
-    tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc \
-    tensorflow/lite/micro/examples/magic_wand/Makefile.inc \
-    tensorflow/lite/micro/examples/micro_speech/Makefile.inc \
-    tensorflow/lite/micro/examples/network_tester/Makefile.inc \
-    tensorflow/lite/micro/examples/person_detection/Makefile.inc \
-    tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc
-  MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS))
-
-endif
diff --git a/tensorflow/lite/micro/xtensa_hifimini_staging/debug_log.cc b/tensorflow/lite/micro/xtensa_hifimini_staging/debug_log.cc
deleted file mode 100644
index 45d9317478a..00000000000
--- a/tensorflow/lite/micro/xtensa_hifimini_staging/debug_log.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Reference implementation of the DebugLog() function that's required for a
-// platform to support the TensorFlow Lite for Microcontrollers library. This is
-// the only function that's absolutely required to be available on a target
-// device, since it's used for communicating test results back to the host so
-// that we can verify the implementation is working correctly.
-// It's designed to be as easy as possible to supply an implementation though.
-// On platforms that have a POSIX stack or C library, it can be written as a
-// single call to `fprintf(stderr, "%s", s)` to output a string to the error
-// stream of the console, but if there's no OS or C library available, there's
-// almost always an equivalent way to write out a string to some serial
-// interface that can be used instead. For example on Arm M-series MCUs, calling
-// the `bkpt #0xAB` assembler instruction will output the string in r1 to
-// whatever debug serial connection is available. If you're running mbed, you
-// can do the same by creating `Serial pc(USBTX, USBRX)` and then calling
-// `pc.printf("%s", s)`.
-// To add an equivalent function for your own platform, create your own
-// implementation file, and place it in a subfolder named after the OS
-// you're targeting. For example, see the Cortex M bare metal version in
-// tensorflow/lite/micro/bluepill/debug_log.cc or the mbed one on
-// tensorflow/lite/micro/mbed/debug_log.cc.
-
-#include "tensorflow/lite/micro/debug_log.h"
-
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-#include <cstdio>
-#endif
-
-extern "C" void DebugLog(const char* s) {
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-  // Reusing TF_LITE_STRIP_ERROR_STRINGS to disable DebugLog completely to get
-  // maximum reduction in binary size. This is because we have DebugLog calls
-  // via TF_LITE_CHECK that are not stubbed out by TF_LITE_REPORT_ERROR.
-  fprintf(stderr, "%s", s);
-#endif
-}
diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py
index cd5a237b0ef..18ba58b27a5 100644
--- a/tensorflow/lite/python/interpreter.py
+++ b/tensorflow/lite/python/interpreter.py
@@ -460,7 +460,7 @@ class Interpreter(object):
     ]
 
   def get_tensor(self, tensor_index):
-    """Gets the value of the input tensor (get a copy).
+    """Gets the value of the output tensor (get a copy).
 
     If you wish to avoid the copy, use `tensor()`. This function cannot be used
     to read intermediate results.
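
Editor's note on the corrected docstring: `get_tensor()` returns a copy of a tensor's value (typically an output), while `tensor()` returns a zero-copy accessor. A minimal sketch of the intended usage, assuming a placeholder `model.tflite` file with a float32 input:

```python
import numpy as np
import tensorflow as tf

# "model.tflite" is a placeholder; any valid flatbuffer model works.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed zeros of the right shape (assuming a float32 input) and run inference.
interpreter.set_tensor(input_details[0]["index"],
                       np.zeros(input_details[0]["shape"], dtype=np.float32))
interpreter.invoke()

# get_tensor() copies the output tensor's value, matching the fixed docstring.
result = interpreter.get_tensor(output_details[0]["index"])
print(result.shape)
```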
diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD
index 3055d37fa07..5e7ecd7f5fd 100644
--- a/tensorflow/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/lite/python/interpreter_wrapper/BUILD
@@ -38,7 +38,6 @@ cc_library(
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
-        "//tensorflow/lite/experimental/tflite_api_dispatcher",
         "//tensorflow/lite/kernels:builtin_ops",
         "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/memory",
@@ -90,7 +89,6 @@ pybind_extension(
         "@pybind11",
         "//third_party/python_runtime:headers",
         "//tensorflow/lite:framework_lib",
-        "//tensorflow/lite/experimental/tflite_api_dispatcher",
         "//tensorflow/python:pybind11_lib",
     ] + select({
         ":tflite_pip_with_flex": ["//tensorflow/lite/delegates/flex:delegate"],
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index adfa760f147..9f97c79dd43 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -66,8 +66,8 @@ namespace {
 
 using python_utils::PyDecrefDeleter;
 
-std::unique_ptr<tflite_api_dispatcher::Interpreter> CreateInterpreter(
-    const tflite_api_dispatcher::TfLiteModel* model,
+std::unique_ptr<Interpreter> CreateInterpreter(
+    const InterpreterWrapper::Model* model,
     const tflite::ops::builtin::BuiltinOpResolver& resolver) {
   if (!model) {
     return nullptr;
@@ -75,9 +75,8 @@ std::unique_ptr<tflite_api_dispatcher::Interpreter> CreateInterpreter(
 
   ::tflite::python::ImportNumpy();
 
-  std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter;
-  if (tflite_api_dispatcher::InterpreterBuilder(
-          *model, resolver)(&interpreter) != kTfLiteOk) {
+  std::unique_ptr<Interpreter> interpreter;
+  if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
     return nullptr;
   }
   return interpreter;
@@ -167,7 +166,7 @@ bool RegisterCustomOpByName(const char* registerer_name,
 }  // namespace
 
 InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
-    std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
+    std::unique_ptr<InterpreterWrapper::Model> model,
     std::unique_ptr<PythonErrorReporter> error_reporter,
     const std::vector<std::string>& registerers_by_name,
     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
@@ -198,10 +197,10 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
 }
 
 InterpreterWrapper::InterpreterWrapper(
-    std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
+    std::unique_ptr<InterpreterWrapper::Model> model,
     std::unique_ptr<PythonErrorReporter> error_reporter,
     std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
-    std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter)
+    std::unique_ptr<Interpreter> interpreter)
     : model_(std::move(model)),
       error_reporter_(std::move(error_reporter)),
       resolver_(std::move(resolver)),
@@ -537,9 +536,8 @@ namespace {
 
 // Checks to see if a tensor access can succeed (returns nullptr on error).
 // Otherwise returns Py_None.
-PyObject* CheckGetTensorArgs(tflite_api_dispatcher::Interpreter* interpreter_,
-                             int tensor_index, TfLiteTensor** tensor,
-                             int* type_num) {
+PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
+                             TfLiteTensor** tensor, int* type_num) {
   TFLITE_PY_ENSURE_VALID_INTERPRETER();
   TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
 
@@ -665,9 +663,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
     std::string* error_msg) {
   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
-  std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model =
-      tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path,
-                                                        error_reporter.get());
+  std::unique_ptr<InterpreterWrapper::Model> model =
+      Model::BuildFromFile(model_path, error_reporter.get());
   return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
                                   registerers_by_name, registerers_by_func,
                                   error_msg);
@@ -690,9 +687,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
     return nullptr;
   }
-  std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model =
-      tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length,
-                                                          error_reporter.get());
+  std::unique_ptr<InterpreterWrapper::Model> model =
+      Model::BuildFromBuffer(buf, length, error_reporter.get());
   return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
                                   registerers_by_name, registerers_by_func,
                                   error_msg);
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 6b83d2d06db..8f03d0915b1 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -27,7 +27,6 @@ limitations under the License.
 // automatically move <Python.h> before <locale>.
 #include <Python.h>
 
-#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
 #include "tensorflow/lite/interpreter.h"
 
 struct TfLiteDelegate;
@@ -48,6 +47,8 @@ class PythonErrorReporter;
 
 class InterpreterWrapper {
  public:
+  using Model = FlatBufferModel;
+
   // SWIG caller takes ownership of pointer.
   static InterpreterWrapper* CreateWrapperCPPFromFile(
       const char* model_path, const std::vector<std::string>& registerers,
@@ -105,26 +106,24 @@ class InterpreterWrapper {
   // Experimental and subject to change.
   //
   // Returns a pointer to the underlying interpreter.
-  tflite_api_dispatcher::Interpreter* interpreter() {
-    return interpreter_.get();
-  }
+  Interpreter* interpreter() { return interpreter_.get(); }
 
  private:
   // Helper function to construct an `InterpreterWrapper` object.
   // It only returns InterpreterWrapper if it can construct an `Interpreter`.
   // Otherwise it returns `nullptr`.
   static InterpreterWrapper* CreateInterpreterWrapper(
-      std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
+      std::unique_ptr<Model> model,
       std::unique_ptr<PythonErrorReporter> error_reporter,
       const std::vector<std::string>& registerers_by_name,
       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
       std::string* error_msg);
 
   InterpreterWrapper(
-      std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
+      std::unique_ptr<Model> model,
       std::unique_ptr<PythonErrorReporter> error_reporter,
       std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
-      std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter);
+      std::unique_ptr<Interpreter> interpreter);
 
   // InterpreterWrapper is not copyable or assignable. We avoid the use of
   // InterpreterWrapper() = delete here for SWIG compatibility.
@@ -137,10 +136,10 @@ class InterpreterWrapper {
   // The public functions which creates `InterpreterWrapper` should ensure all
   // these member variables are initialized successfully. Otherwise it should
   // report the error and return `nullptr`.
-  const std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model_;
+  const std::unique_ptr<Model> model_;
   const std::unique_ptr<PythonErrorReporter> error_reporter_;
   const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
-  const std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter_;
+  const std::unique_ptr<Interpreter> interpreter_;
 };
 
 }  // namespace interpreter_wrapper
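
For reference, the two wrapper factories above (`CreateWrapperCPPFromFile` and `CreateWrapperCPPFromBuffer`) back the two construction paths of the public Python `Interpreter`. A sketch of both, with the model path as a placeholder:

```python
import tensorflow as tf

# File path: goes through CreateWrapperCPPFromFile, which now builds the
# model via Model::BuildFromFile (Model being an alias for FlatBufferModel).
interpreter_a = tf.lite.Interpreter(model_path="model.tflite")

# In-memory bytes: goes through CreateWrapperCPPFromBuffer and
# Model::BuildFromBuffer.
with open("model.tflite", "rb") as f:
    interpreter_b = tf.lite.Interpreter(model_content=f.read())
```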
diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py
index 9d62c1b8a97..e8aff7b82da 100644
--- a/tensorflow/lite/python/op_hint.py
+++ b/tensorflow/lite/python/op_hint.py
@@ -89,11 +89,18 @@ from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
 from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
 from tensorflow.python.ops import array_ops as _array_ops
 from tensorflow.python.util import compat as _compat
+from tensorflow.python.util import deprecation as _deprecation
 from tensorflow.python.util.all_util import remove_undocumented
 from tensorflow.python.util.tf_export import tf_export as _tf_export
 
 
 @_tf_export(v1=["lite.OpHint"])
+@_deprecation.deprecated(
+    None,
+    "Please follow instructions under "
+    "https://www.tensorflow.org/lite/convert/operation_fusion for operation"
+    "fusion in tflite."
+)
 class OpHint(object):
   """A class that helps build tflite function invocations.
 
@@ -1302,6 +1309,12 @@ def is_ophint_converted(graph_def):
 
 
 @_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
+@_deprecation.deprecated(
+    None,
+    "Please follow instructions under "
+    "https://www.tensorflow.org/lite/convert/operation_fusion for operation"
+    "fusion in tflite."
+)
 def convert_op_hints_to_stubs(session=None,
                               graph_def=None,
                               write_callback=lambda graph_def, comments: None):
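
Both decorated symbols keep working but warn on use. A minimal sketch of the decorator pattern applied above, assuming the same internal `deprecation` module (not a public API); the decorated function here is hypothetical:

```python
from tensorflow.python.util import deprecation


@deprecation.deprecated(
    None,  # No fixed removal date; only the instructions are printed.
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite.")
def hypothetical_legacy_api():
  """Stand-in for OpHint / convert_op_hints_to_stubs."""
  return 42


hypothetical_legacy_api()  # Logs a deprecation warning, then runs normally.
```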
diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index 74380bd9d4c..dbbb69b49f4 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -312,7 +312,7 @@ cc_library(
     hdrs = ["util.h"],
     deps = [
         "//tensorflow/core/platform:logging",
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite:error_reporter",
         "//tensorflow/lite:string",
         "//tensorflow/lite/core/api",
     ],
diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD
index e2146910d52..aa852093794 100644
--- a/tensorflow/lite/tools/BUILD
+++ b/tensorflow/lite/tools/BUILD
@@ -258,7 +258,7 @@ cc_library(
     srcs = ["command_line_flags.cc"],
     hdrs = ["command_line_flags.h"],
     copts = tflite_copts(),
-    deps = [":logging"],
+    deps = ["//tensorflow/lite/tools:logging"],
 )
 
 cc_test(
@@ -296,8 +296,8 @@ tf_cc_binary(
     srcs = ["list_flex_ops_main.cc"],
     visibility = ["//visibility:public"],
     deps = [
-        ":command_line_flags",
         ":list_flex_ops",
+        "//tensorflow/lite/tools:command_line_flags",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -307,8 +307,8 @@ cc_library(
     srcs = ["list_flex_ops_main.cc"],
     visibility = ["//visibility:public"],
     deps = [
-        ":command_line_flags",
         ":list_flex_ops",
+        "//tensorflow/lite/tools:command_line_flags",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -329,8 +329,8 @@ tf_cc_binary(
     srcs = ["list_flex_ops_main.cc"],
     visibility = ["//visibility:public"],
     deps = [
-        ":command_line_flags",
         ":list_flex_ops_no_kernel",
+        "//tensorflow/lite/tools:command_line_flags",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD
index 430bcf32766..eb3f37aef58 100644
--- a/tensorflow/lite/tools/benchmark/BUILD
+++ b/tensorflow/lite/tools/benchmark/BUILD
@@ -213,6 +213,7 @@ cc_library(
         ":benchmark_params",
         ":benchmark_utils",
         "//tensorflow/core/util:stats_calculator_portable",
+        "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/profiling:memory_info",
         "//tensorflow/lite/profiling:time",
diff --git a/tensorflow/lite/tools/benchmark/android/BUILD b/tensorflow/lite/tools/benchmark/android/BUILD
index 09cf57e5baf..be9c379af50 100644
--- a/tensorflow/lite/tools/benchmark/android/BUILD
+++ b/tensorflow/lite/tools/benchmark/android/BUILD
@@ -22,6 +22,10 @@ android_binary(
     # can't be built. We need to prevent the build system from trying to
     # use the target in that case.
     tags = ["manual"],
+    deps = [
+        ":hexagon_libs",
+        ":tensorflowlite_benchmark_native",
+    ],
 )
 
 tflite_jni_binary(
diff --git a/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD b/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD
index 5b25cba0487..98591063962 100644
--- a/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD
+++ b/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD
@@ -23,6 +23,10 @@ android_binary(
     # can't be built. We need to prevent the build system from trying to
     # use the target in that case.
     tags = ["manual"],
+    deps = [
+        ":hexagon_libs",
+        ":tensorflowlite_benchmark_firebase_native",
+    ],
 )
 
 tflite_jni_binary(
diff --git a/tensorflow/lite/tools/cmake/README.md b/tensorflow/lite/tools/cmake/README.md
index c48685a8c1e..b49162f0eba 100644
--- a/tensorflow/lite/tools/cmake/README.md
+++ b/tensorflow/lite/tools/cmake/README.md
@@ -41,6 +41,13 @@ cd tflite_build
 cmake ../tensorflow_src/tensorflow/lite
 ```
 
+This generates a release binary by default. If you need a debug build, provide
+the '-DCMAKE_BUILD_TYPE=Debug' option.
+
+```sh
+cmake ../tensorflow_src/tensorflow/lite -DCMAKE_BUILD_TYPE=Debug
+```
+
 If you want to configure Android build with GPU delegate support,
 
 ```sh
diff --git a/tensorflow/lite/tools/cmake/modules/Findegl_headers.cmake b/tensorflow/lite/tools/cmake/modules/Findegl_headers.cmake
new file mode 100644
index 00000000000..02cf736faea
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findegl_headers.cmake
@@ -0,0 +1,16 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(egl_headers)
diff --git a/tensorflow/lite/tools/cmake/modules/Findopengl_headers.cmake b/tensorflow/lite/tools/cmake/modules/Findopengl_headers.cmake
new file mode 100644
index 00000000000..7651549eb1c
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findopengl_headers.cmake
@@ -0,0 +1,16 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(opengl_headers)
diff --git a/tensorflow/lite/tools/cmake/modules/egl_headers.cmake b/tensorflow/lite/tools/cmake/modules/egl_headers.cmake
new file mode 100644
index 00000000000..f6a23dbee06
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/egl_headers.cmake
@@ -0,0 +1,39 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET egl_headers OR egl_headers_POPULATED)
+  return()
+endif()
+
+include(FetchContent)
+
+OverridableFetchContent_Declare(
+  egl_headers
+  GIT_REPOSITORY https://github.com/KhronosGroup/EGL-Registry.git
+  GIT_TAG 649981109e263b737e7735933c90626c29a306f2
+  GIT_PROGRESS TRUE
+  PREFIX "${CMAKE_BINARY_DIR}"
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/egl_headers"
+)
+
+OverridableFetchContent_GetProperties(egl_headers)
+if(NOT egl_headers)
+  OverridableFetchContent_Populate(egl_headers)
+endif()
+
+include_directories(
+  AFTER
+   "${egl_headers_SOURCE_DIR}/api"
+)
diff --git a/tensorflow/lite/tools/cmake/modules/opengl_headers.cmake b/tensorflow/lite/tools/cmake/modules/opengl_headers.cmake
new file mode 100644
index 00000000000..c9db6b48306
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/opengl_headers.cmake
@@ -0,0 +1,39 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET opengl_headers OR opengl_headers_POPULATED)
+  return()
+endif()
+
+include(FetchContent)
+
+OverridableFetchContent_Declare(
+  opengl_headers
+  GIT_REPOSITORY https://github.com/KhronosGroup/OpenGL-Registry.git
+  GIT_TAG 0cb0880d91581d34f96899c86fc1bf35627b4b81
+  GIT_PROGRESS TRUE
+  PREFIX "${CMAKE_BINARY_DIR}"
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/opengl_headers"
+)
+
+OverridableFetchContent_GetProperties(opengl_headers)
+if(NOT opengl_headers)
+  OverridableFetchContent_Populate(opengl_headers)
+endif()
+
+include_directories(
+  AFTER
+   "${opengl_headers_SOURCE_DIR}/api"
+)
diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD
index 1e621956ea8..4363aff01a9 100644
--- a/tensorflow/lite/tools/evaluation/stages/BUILD
+++ b/tensorflow/lite/tools/evaluation/stages/BUILD
@@ -180,6 +180,7 @@ cc_test(
     deps = [
         ":inference_profiler_stage",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
         "@com_google_googletest//:gtest_main",
diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD
index 34a2f9ce68c..cfb49a20a78 100644
--- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD
@@ -29,6 +29,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/tools:command_line_flags",
         "//tensorflow/lite/tools:logging",
+        "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation:utils",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD
index c315e5ca4a6..941bbc0ff69 100644
--- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD
@@ -17,6 +17,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/tools:command_line_flags",
         "//tensorflow/lite/tools:logging",
+        "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation:utils",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD
index 8befc356d41..36606722caf 100644
--- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD
@@ -17,6 +17,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/tools:command_line_flags",
         "//tensorflow/lite/tools:logging",
+        "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
         "//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage",
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index 715afa642a3..c77c08eeb6b 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -43,13 +43,17 @@ tf_cc_test(
         "//tensorflow/lite/schema:schema_utils",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
 cc_binary(
     name = "modify_model_interface_main",
     srcs = ["modify_model_interface_main.cc"],
-    deps = [":modify_model_interface"],
+    deps = [
+        ":modify_model_interface",
+        ":quantize_model",
+    ],
 )
 
 cc_library(
@@ -61,6 +65,7 @@ cc_library(
         "//tensorflow/lite:framework",
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/schema:schema_fbs",
+        "@flatbuffers",
     ],
 )
 
@@ -77,7 +82,9 @@ tf_cc_test(
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
@@ -102,7 +109,9 @@ tf_cc_test(
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
@@ -112,7 +121,11 @@ cc_library(
     hdrs = ["quantization_wrapper.h"],
     deps = [
         ":quantization_wrapper_utils",
-        ":quantize_model",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/tools/optimize:quantize_model",
+        "@flatbuffers",
     ],
 )
 
@@ -134,6 +147,7 @@ cc_library(
         "//tensorflow/lite/schema:schema_fbs",
         "//third_party/eigen3",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -151,6 +165,7 @@ cc_library(
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -169,6 +184,7 @@ tf_cc_test(
         "//tensorflow/lite:framework",
         "//tensorflow/lite/schema:schema_fbs",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
@@ -179,6 +195,7 @@ cc_library(
     compatible_with = get_compatible_with_cloud(),
     deps = [
         "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
     ],
@@ -206,8 +223,11 @@ tf_cc_test(
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
         "//tensorflow/lite/testing:util",
+        "//third_party/eigen3",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
@@ -220,6 +240,7 @@ cc_library(
         ":quantization_utils",
         ":model_utils",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/container:flat_hash_map",
         "@flatbuffers",
         "//tensorflow/lite:framework",
@@ -238,10 +259,10 @@ tf_cc_test(
         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        ":testdata/custom_op.bin",
-        ":testdata/quantized_with_gather.bin",
-        ":testdata/single_conv_weights_min_0_max_plus_10.bin",
-        ":testdata/weight_shared_between_convs.bin",
+        "//tensorflow/lite/tools/optimize:testdata/custom_op.bin",
+        "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
+        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
     ],
     tags = [
         "tflite_not_portable_android",
@@ -266,8 +287,10 @@ cc_library(
     srcs = ["test_util.cc"],
     hdrs = ["test_util.h"],
     deps = [
+        "//tensorflow/lite:framework",
         "//tensorflow/lite/core/api",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
@@ -295,32 +318,32 @@ tf_cc_test(
         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        ":testdata/add_with_const_input.bin",
-        ":testdata/argmax.bin",
-        ":testdata/concat.bin",
-        ":testdata/fc.bin",
-        ":testdata/lstm_calibrated.bin",
-        ":testdata/lstm_calibrated2.bin",
-        ":testdata/lstm_quantized.bin",
-        ":testdata/lstm_quantized2.bin",
-        ":testdata/maximum.bin",
-        ":testdata/minimum.bin",
-        ":testdata/mixed.bin",
-        ":testdata/mixed16x8.bin",
-        ":testdata/multi_input_add_reshape.bin",
-        ":testdata/pack.bin",
-        ":testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
-        ":testdata/single_conv_no_bias.bin",
-        ":testdata/single_conv_weights_min_0_max_plus_10.bin",
-        ":testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
-        ":testdata/single_softmax_min_minus_5_max_plus_5.bin",
-        ":testdata/split.bin",
-        ":testdata/svdf_calibrated.bin",
-        ":testdata/svdf_quantized.bin",
-        ":testdata/transpose.bin",
-        ":testdata/unidirectional_sequence_lstm_calibrated.bin",
-        ":testdata/unidirectional_sequence_lstm_quantized.bin",
-        ":testdata/unpack.bin",
+        "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
+        "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
+        "//tensorflow/lite/tools/optimize:testdata/concat.bin",
+        "//tensorflow/lite/tools/optimize:testdata/fc.bin",
+        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
+        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
+        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
+        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin",
+        "//tensorflow/lite/tools/optimize:testdata/maximum.bin",
+        "//tensorflow/lite/tools/optimize:testdata/minimum.bin",
+        "//tensorflow/lite/tools/optimize:testdata/mixed.bin",
+        "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin",
+        "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
+        "//tensorflow/lite/tools/optimize:testdata/pack.bin",
+        "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
+        "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin",
+        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
+        "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
+        "//tensorflow/lite/tools/optimize:testdata/split.bin",
+        "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin",
+        "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin",
+        "//tensorflow/lite/tools/optimize:testdata/transpose.bin",
+        "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin",
+        "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin",
+        "//tensorflow/lite/tools/optimize:testdata/unpack.bin",
     ],
     tags = [
         "tflite_not_portable_android",
diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD
index 64f15857bdb..bf4a9b86233 100644
--- a/tensorflow/lite/tools/optimize/calibration/BUILD
+++ b/tensorflow/lite/tools/optimize/calibration/BUILD
@@ -70,7 +70,9 @@ cc_library(
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@flatbuffers",
     ],
 )
@@ -99,6 +101,7 @@ tf_cc_test(
         "//tensorflow/lite/kernels:builtin_ops",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -125,6 +128,7 @@ cc_test(
     deps = [
         ":logging_op_resolver",
         "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels:builtin_ops",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -139,6 +143,7 @@ cc_library(
         "//tensorflow/lite:framework",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -148,6 +153,7 @@ cc_library(
     hdrs = ["calibration_logger.h"],
     copts = tflite_copts(),
     deps = [
+        "//tensorflow/lite:framework",
         "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index b46393b964f..2c993945ea1 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -67,6 +67,9 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
 }
 }  // namespace
 
+// Update the operation definitions in the TensorFlow Lite dialect accordingly
+// whenever the kernel support level changes.
+// LINT.IfChange
 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
                                      int op_index) {
   OpVariant op_variant = GetOperatorVariant(model, subgraph_index, op_index);
@@ -933,7 +936,6 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
-      property.quantizable_int16 = false;
       break;
     case BuiltinOperator_TANH: {
       property.inputs = {{0, {}}};
@@ -1002,7 +1004,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.quantizable_int16 = false;
   }
   return property;
-}
+}  // NOLINT(readability/fn_size)
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_ops.td)
 
 }  // namespace operator_property
 }  // namespace optimize
diff --git a/tensorflow/lite/tools/optimize/python/BUILD b/tensorflow/lite/tools/optimize/python/BUILD
index b25f3c710fe..34f57cbecaf 100644
--- a/tensorflow/lite/tools/optimize/python/BUILD
+++ b/tensorflow/lite/tools/optimize/python/BUILD
@@ -26,6 +26,7 @@ py_library(
     deps = [
         ":_pywrap_modify_model_interface",
         ":modify_model_interface_constants",
+        "//tensorflow:tensorflow_py",
         "//tensorflow/lite/python:schema_py",
     ],
 )
diff --git a/tensorflow/lite/experimental/writer/BUILD b/tensorflow/lite/tools/serialization/BUILD
similarity index 100%
rename from tensorflow/lite/experimental/writer/BUILD
rename to tensorflow/lite/tools/serialization/BUILD
diff --git a/tensorflow/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/tools/serialization/enum_mapping.h
similarity index 96%
rename from tensorflow/lite/experimental/writer/enum_mapping.h
rename to tensorflow/lite/tools/serialization/enum_mapping.h
index 688ee406125..a79e25d844e 100644
--- a/tensorflow/lite/experimental/writer/enum_mapping.h
+++ b/tensorflow/lite/tools/serialization/enum_mapping.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_ENUM_MAPPING_H_
+#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_ENUM_MAPPING_H_
 
 #include "tensorflow/lite/builtin_op_data.h"
 #include "tensorflow/lite/schema/reflection/schema_generated.h"
@@ -147,4 +147,4 @@ inline CombinerType CombinerTypeToSchema(TfLiteCombinerType type) {
 // int
 
 }  // namespace tflite
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#endif  // TENSORFLOW_LITE_TOOLS_SERIALIZATION_ENUM_MAPPING_H_
diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/tools/serialization/option_writer_generator.cc
similarity index 100%
rename from tensorflow/lite/experimental/writer/option_writer_generator.cc
rename to tensorflow/lite/tools/serialization/option_writer_generator.cc
diff --git a/tensorflow/lite/experimental/writer/writer.cc b/tensorflow/lite/tools/serialization/writer.cc
similarity index 96%
rename from tensorflow/lite/experimental/writer/writer.cc
rename to tensorflow/lite/tools/serialization/writer.cc
index 3977c8e1003..fb816792b6a 100644
--- a/tensorflow/lite/experimental/writer/writer.cc
+++ b/tensorflow/lite/tools/serialization/writer.cc
@@ -20,9 +20,9 @@ limitations under the License.
 
 #include <iostream>
 
-#include "tensorflow/lite/experimental/writer/writer_lib.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
+#include "tensorflow/lite/tools/serialization/writer_lib.h"
 
 int main(int argc, char* argv[]) {
   if (argc != 3) {
diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/tools/serialization/writer_lib.cc
similarity index 98%
rename from tensorflow/lite/experimental/writer/writer_lib.cc
rename to tensorflow/lite/tools/serialization/writer_lib.cc
index 9f18fff76d5..0d831f5f9a0 100644
--- a/tensorflow/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/lite/tools/serialization/writer_lib.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/lite/tools/serialization/writer_lib.h"
 
 #include <cstdlib>
 #include <cstring>
@@ -23,9 +23,9 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/core/subgraph.h"
-#include "tensorflow/lite/experimental/writer/enum_mapping.h"
 #include "tensorflow/lite/schema/reflection/schema_generated.h"
 #include "tensorflow/lite/schema/schema_utils.h"
+#include "tensorflow/lite/tools/serialization/enum_mapping.h"
 #include "tensorflow/lite/version.h"
 
 namespace tflite {
@@ -34,7 +34,7 @@ std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
     flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
     void* builtin_op_data, const TfLiteNode& node) {
   switch (op) {
-#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
+#include "tensorflow/lite/tools/serialization/option_writer_generated.h"
   }
   return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
 }
diff --git a/tensorflow/lite/experimental/writer/writer_lib.h b/tensorflow/lite/tools/serialization/writer_lib.h
similarity index 96%
rename from tensorflow/lite/experimental/writer/writer_lib.h
rename to tensorflow/lite/tools/serialization/writer_lib.h
index f7816dcc33e..a18a3dd0958 100644
--- a/tensorflow/lite/experimental/writer/writer_lib.h
+++ b/tensorflow/lite/tools/serialization/writer_lib.h
@@ -24,8 +24,8 @@ limitations under the License.
 //   // Build Interpreter however
 //   // ... <omitted>
 //   SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite");
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
+#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
 #include <iostream>
 #include <unordered_map>
 
@@ -33,8 +33,8 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/core/subgraph.h"
-#include "tensorflow/lite/experimental/writer/enum_mapping.h"
 #include "tensorflow/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/lite/tools/serialization/enum_mapping.h"
 #include "tensorflow/lite/version.h"
 
 namespace tflite {
@@ -149,4 +149,4 @@ class SubgraphWriter {
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#endif  // TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
diff --git a/tensorflow/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/tools/serialization/writer_lib_test.cc
similarity index 99%
rename from tensorflow/lite/experimental/writer/writer_lib_test.cc
rename to tensorflow/lite/tools/serialization/writer_lib_test.cc
index bf50d4944f1..189b4bc106f 100644
--- a/tensorflow/lite/experimental/writer/writer_lib_test.cc
+++ b/tensorflow/lite/tools/serialization/writer_lib_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/lite/tools/serialization/writer_lib.h"
 
 #include <numeric>
 #include <sstream>
diff --git a/tensorflow/lite/experimental/writer/writer_test.cc b/tensorflow/lite/tools/serialization/writer_test.cc
similarity index 97%
rename from tensorflow/lite/experimental/writer/writer_test.cc
rename to tensorflow/lite/tools/serialization/writer_test.cc
index ac89b74291f..ccaab76776b 100644
--- a/tensorflow/lite/experimental/writer/writer_test.cc
+++ b/tensorflow/lite/tools/serialization/writer_test.cc
@@ -21,9 +21,9 @@ limitations under the License.
 
 #include <iostream>
 
-#include "tensorflow/lite/experimental/writer/writer_lib.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
+#include "tensorflow/lite/tools/serialization/writer_lib.h"
 
 int main(int argc, char* argv[]) {
   if (argc != 2) {
diff --git a/tensorflow/lite/tools/signature/BUILD b/tensorflow/lite/tools/signature/BUILD
index 7fd83562258..05fc106d759 100644
--- a/tensorflow/lite/tools/signature/BUILD
+++ b/tensorflow/lite/tools/signature/BUILD
@@ -81,7 +81,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":_pywrap_signature_def_util_wrapper",
-        "//tensorflow/core/protobuf:meta_graph_pyclif_pb2",
+        "//tensorflow/core:protos_all_py",
     ],
 )
 
@@ -98,7 +98,7 @@ py_test(
     deps = [
         ":signature_def_utils",
         "//tensorflow:tensorflow_py",
-        "//tensorflow/core/protobuf:meta_graph_pyclif_pb2",
+        "//tensorflow/core:protos_all_py",
     ],
 )
 
diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD
index 31fb903ce8d..06ac1968f52 100644
--- a/tensorflow/lite/tools/versioning/BUILD
+++ b/tensorflow/lite/tools/versioning/BUILD
@@ -43,6 +43,7 @@ tf_cc_test(
         ":versioning",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/schema:schema_fbs_with_mutable",
         "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index 97ab4635cec..bf8bd7ab6bb 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -5,6 +5,16 @@ tensorflow/compat_template.__init__.py
 tensorflow/compat_template_v1.__init__.py
 tensorflow/compiler/mlir/glob_lit_test.bzl
 tensorflow/go/op/wrappers.go
+tensorflow/lite/core/shims/BUILD
+tensorflow/lite/core/shims/c/builtin_op_data.h
+tensorflow/lite/core/shims/c/c_api.h
+tensorflow/lite/core/shims/c/c_api_experimental.h
+tensorflow/lite/core/shims/c/common.h
+tensorflow/lite/core/shims/cc/interpreter.h
+tensorflow/lite/core/shims/cc/interpreter_builder.h
+tensorflow/lite/core/shims/cc/kernels/register.h
+tensorflow/lite/core/shims/cc/model.h
+tensorflow/lite/core/shims/cc/model_builder.h
 tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h
 tensorflow/lite/delegates/gpu/cl/serialization_generated.h
 tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e3b016e0201..31c205c638e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3246,7 +3246,7 @@ py_library(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "batch_ops_test",
     size = "small",
     srcs = ["ops/batch_ops_test.py"],
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index fa894acf9bd..86252637df1 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 11, 18)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 11, 23)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
diff --git a/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py b/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py
index 7eadb001708..8ef97107e33 100644
--- a/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py
@@ -59,19 +59,11 @@ class MaxBatchSizesTestBase(trt_test.TfTrtIntegrationTestBase):
     # engines.
     return (not run_params.dynamic_engine, 'test static engine only.')
 
-  def GetConversionParams(self, run_params):
-    """Returns a ConversionParams for test."""
-    conversion_params = super(MaxBatchSizesTestBase,
-                              self).GetConversionParams(run_params)
-    conversion_params._replace(
-        max_batch_size=min(self.max_batch_sizes), maximum_cached_engines=1)
-    rewrite_config_with_trt = self.GetTrtRewriterConfig(
-        run_params=run_params,
-        conversion_params=conversion_params,
-        use_implicit_batch=True,
-        disable_non_trt_optimizers=True)
-    return conversion_params._replace(
-        rewriter_config_template=rewrite_config_with_trt)
+  def GetMaxBatchSize(self, run_params):
+    """Returns the max_batch_size that the converter should use for tests."""
+    if run_params.dynamic_engine:
+      return None
+    return min(self.max_batch_sizes)
 
   def ExpectedEnginesToBuild(self, run_params):
     """Checks that the expected engine is built.
diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py
index 9d2d3abd4fb..b43749fd305 100644
--- a/tensorflow/python/compiler/tensorrt/test/base_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/base_test.py
@@ -117,18 +117,11 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
         "TRTEngineOp_1": ["c2", "conv", "div", "weights"]
     }
 
-  def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
-    conversion_params = super(SimpleMultiEnginesTest,
-                              self).GetConversionParams(run_params)
-    rewrite_config_with_trt = self.GetTrtRewriterConfig(
-        run_params=run_params,
-        conversion_params=conversion_params,
-        # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimentional input.
-        disable_non_trt_optimizers=True)
-    return conversion_params._replace(
-        rewriter_config_template=rewrite_config_with_trt)
+  def setUp(self):
+    super(trt_test.TfTrtIntegrationTestBase, self).setUp()
+    # Disable layout optimizer, since it will convert BiasAdd with NHWC
+    # format to NCHW format under four dimensional input.
+    self.DisableNonTrtOptimizers()
 
 
 class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase):
diff --git a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py
index 04445cc99aa..b229cff47dc 100644
--- a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py
@@ -106,19 +106,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
     return self.BuildParams(self.GraphFn, dtypes.float32, [[4, 144]],
                             [[4, 6680]])
 
-  def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
-    conversion_params = super(BiasaddMatMulTest,
-                              self).GetConversionParams(run_params)
-    conversion_params._replace(max_batch_size=4, maximum_cached_engines=1)
-    rewrite_config_with_trt = self.GetTrtRewriterConfig(
-        run_params=run_params,
-        conversion_params=conversion_params,
-        # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimensional input.
-        disable_non_trt_optimizers=True)
-    return conversion_params._replace(
-        rewriter_config_template=rewrite_config_with_trt)
+  def setUp(self):
+    super(trt_test.TfTrtIntegrationTestBase, self).setUp()
+    # Disable layout optimizer, since it will convert BiasAdd with NHWC
+    # format to NCHW format under four dimensional input.
+    self.DisableNonTrtOptimizers()
+
+  def GetMaxBatchSize(self, run_params):
+    """Returns the max_batch_size that the converter should use for tests."""
+    if run_params.dynamic_engine:
+      return None
+
+    return 4
 
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
diff --git a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
index 23397d76cc3..6d45d358b82 100644
--- a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
@@ -91,13 +91,10 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
     }
 
   def ShouldRunTest(self, run_params):
-    # There is no CombinedNonMaxSuppression op for GPU at the moment, so
-    # calibration will fail.
-    # TODO(laigd): fix this.
+    should_run, reason = super().ShouldRunTest(run_params)
     # Only run for TRT 5.1 and above.
-    return trt_test.IsTensorRTVersionGreaterEqual(
-        5, 1) and not trt_test.IsQuantizationMode(
-            run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
+    return should_run and trt_test.IsTensorRTVersionGreaterEqual(
+        5, 1), reason + ' and >=TRT5.1'
 
 
 class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
@@ -110,15 +107,17 @@ class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
     super().tearDown()
     os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'False'
 
-  def GetConversionParams(self, run_params):
-    conversion_param = super().GetConversionParams(run_params)
+  def GetMaxBatchSize(self, run_params):
+    """Returns the max_batch_size that the converter should use for tests."""
+    if run_params.dynamic_engine:
+      return None
+
     # Build the engine with the allowed max_batch_size less than the actual
     # max_batch_size, to force the runtime to execute the native segment. This
     # is to test that combined_non_max_suppression, which doesn't have a TF GPU
     # implementation, can be executed natively even though it is in the graph
     # for the TRTEngineOp with a GPU as the default device.
-    return conversion_param._replace(
-        max_batch_size=conversion_param.max_batch_size - 1)
+    return super().GetMaxBatchSize(run_params) - 1
 
   def ShouldRunTest(self, run_params):
     should_run, reason = super().ShouldRunTest(run_params)
diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
index 95dbe727ac3..2cf9abe8455 100644
--- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
@@ -80,19 +80,11 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase):
         input_dims=input_dims,
         expected_output_dims=expected_output_dims)
 
-  def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
-    conversion_params = super(DynamicInputShapesTest,
-                              self).GetConversionParams(run_params)
-    conversion_params._replace(maximum_cached_engines=10)
-    rewrite_config_with_trt = self.GetTrtRewriterConfig(
-        run_params=run_params,
-        conversion_params=conversion_params,
-        # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimensional input.
-        disable_non_trt_optimizers=True)
-    return conversion_params._replace(
-        rewriter_config_template=rewrite_config_with_trt)
+  def setUp(self):
+    super(trt_test.TfTrtIntegrationTestBase, self).setUp()
+    # Disable layout optimizer, since it will convert BiasAdd with NHWC
+    # format to NCHW format under four dimensional input.
+    self.DisableNonTrtOptimizers()
 
   def ExpectedEnginesToBuild(self, run_params):
     return ["TRTEngineOp_0"]
diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py
index ecc68656a60..638dbb5727a 100644
--- a/tensorflow/python/compiler/tensorrt/test/int32_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py
@@ -46,19 +46,17 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
   def GetParams(self):
     return self.BuildParams(self.GraphFn, dtypes.int32, [[100, 4]], [[100, 10]])
 
-  def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
-    conversion_params = super(ExcludeUnsupportedInt32Test,
-                              self).GetConversionParams(run_params)
-    conversion_params._replace(max_batch_size=100, maximum_cached_engines=1)
-    rewrite_config_with_trt = self.GetTrtRewriterConfig(
-        run_params=run_params,
-        conversion_params=conversion_params,
-        # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimensional input.
-        disable_non_trt_optimizers=True)
-    return conversion_params._replace(
-        rewriter_config_template=rewrite_config_with_trt)
+  def setUp(self):
+    super(trt_test.TfTrtIntegrationTestBase, self).setUp()
+    # Disable layout optimizer, since it will convert BiasAdd with NHWC
+    # format to NCHW format under four dimensional input.
+    self.DisableNonTrtOptimizers()
+
+  def GetMaxBatchSize(self, run_params):
+    """Returns the max_batch_size that the converter should use for tests."""
+    if run_params.dynamic_engine:
+      return None
+    return 100
 
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
index 2265a19cf62..fc32e7e3b3d 100644
--- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
@@ -159,6 +159,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
   def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
     super(TfTrtIntegrationTestBase, self).__init__(methodName)
     self._trt_test_params = None
+    self._disable_non_trt_optimizers = False
+    self._use_implicit_batch = True
 
   def setUp(self):
     """Setup method."""
@@ -257,12 +259,33 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         input_dims=[input_shapes] + extra_inputs,
         expected_output_dims=[output_shapes] + extra_outputs)
 
+  def DisableNonTrtOptimizers(self):
+    self._disable_non_trt_optimizers = True
+
+  def DisableImplicitBatchMode(self):
+    self._use_implicit_batch = False
+
   def GetParams(self):
-    """Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
+    """Returns a TfTrtIntegrationTestParams for the test."""
     raise NotImplementedError()
 
   def GetConversionParams(self, run_params):
-    """Return a TrtConversionParams for test."""
+    """Returns a TrtConversionParams for test."""
+    conversion_params = trt_convert.TrtConversionParams(
+        # We use the minimum of all the batch sizes, so when multiple different
+        # input shapes are provided it'll always create new engines in the
+        # cache, and we can therefore test the cache behavior.
+        max_workspace_size_bytes=1 << 25,
+        precision_mode=run_params.precision_mode,
+        minimum_segment_size=2,
+        maximum_cached_engines=1,
+        use_calibration=run_params.use_calibration)
+    return conversion_params
+
+  def GetMaxBatchSize(self, run_params):
+    """Returns the max_batch_size that the converter should use for tests."""
+    if run_params.dynamic_engine:
+      return None
     batch_list = []
     for dims_list in self._GetParamsCached().input_dims:
       assert dims_list
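
Outside the test harness, the same trimmed-down `TrtConversionParams` feeds `TrtGraphConverterV2` directly, as `_CreateConverter` below does. A sketch with placeholder path and precision:

```python
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.TrtConversionParams(
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP16",  # illustrative choice
    minimum_segment_size=2,
    maximum_cached_engines=1,
    use_calibration=False)

converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/saved_model",  # placeholder directory
    conversion_params=params)
converter.convert()
```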
@@ -270,33 +293,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       input_batches = [dims[0] for dims in dims_list]
       assert max(input_batches) == min(input_batches)
       batch_list.append(input_batches[0])
-    conversion_params = trt_convert.TrtConversionParams(
-        # We use the minimum of all the batch sizes, so when multiple different
-        # input shapes are provided it'll always create new engines in the
-        # cache, and we can therefore test the cache behavior.
-        rewriter_config_template=None,
-        max_workspace_size_bytes=1 << 25,
-        precision_mode=run_params.precision_mode,
-        minimum_segment_size=2,
-        is_dynamic_op=run_params.dynamic_engine,
-        maximum_cached_engines=1,
-        use_calibration=run_params.use_calibration,
-        max_batch_size=max(batch_list))
-    return conversion_params
-
-  def GetTrtRewriterConfig(self,
-                           run_params,
-                           conversion_params,
-                           disable_non_trt_optimizers=False,
-                           use_implicit_batch=True):
-    rewriter_config = trt_convert.get_tensorrt_rewriter_config(
-        conversion_params=conversion_params,
-        is_v2=run_params.is_v2,
-        disable_non_trt_optimizers=disable_non_trt_optimizers)
-    for optimizer in rewriter_config.custom_optimizers:
-      if optimizer.name == "TensorRTOptimizer":
-        optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
-    return rewriter_config
+    return max(batch_list)
 
   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -333,14 +330,15 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
   def _GetConfigProto(self, run_params, graph_state):
     """Get config proto based on specific settings."""
     conversion_params = self.GetConversionParams(run_params)
+    max_batch_size = self.GetMaxBatchSize(run_params)
     if graph_state == GraphState.INFERENCE and run_params.convert_online:
-      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params)
+      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
+          conversion_params,
+          is_dynamic_op=run_params.dynamic_engine,
+          max_batch_size=max_batch_size)
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
       graph_options = config_pb2.GraphOptions()
-      if conversion_params.rewriter_config_template is not None:
-        graph_options.rewrite_options.CopyFrom(
-            conversion_params.rewriter_config_template)
 
     config = config_pb2.ConfigProto(
         gpu_options=self._GetGPUOptions(), graph_options=graph_options)
@@ -444,30 +442,37 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       config = self._GetConfigProto(run_params, GraphState.INFERENCE)
     return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs)
 
-  def _CreateConverter(self, run_params, saved_model_dir, session_config,
-                       conversion_params):
-    """Return a TrtGraphConverter."""
+  def _CreateConverter(self, run_params, saved_model_dir, conversion_params):
+    """Returns a TrtGraphConverter."""
     if run_params.is_v2:
-      return trt_convert.TrtGraphConverterV2(
+      converter_v2 = trt_convert.TrtGraphConverterV2(
           input_saved_model_dir=saved_model_dir,
           conversion_params=conversion_params)
-    return trt_convert.TrtGraphConverter(
+      if self._disable_non_trt_optimizers:
+        converter_v2._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
+      if not self._use_implicit_batch:
+        converter_v2._test_only_use_implicit_batch = False  # pylint: disable=protected-access
+      return converter_v2
+
+    converter_v1 = trt_convert.TrtGraphConverter(
         input_saved_model_dir=saved_model_dir,
-        session_config=session_config,
-        max_batch_size=conversion_params.max_batch_size,
+        max_batch_size=self.GetMaxBatchSize(run_params),
         max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
         precision_mode=conversion_params.precision_mode,
         minimum_segment_size=conversion_params.minimum_segment_size,
-        is_dynamic_op=conversion_params.is_dynamic_op,
+        is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=conversion_params.maximum_cached_engines,
         use_calibration=conversion_params.use_calibration)
+    if self._disable_non_trt_optimizers:
+      converter_v1._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
+    return converter_v1
 
   def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
     """Return trt converted graphdef in INT8 mode."""
     conversion_params = self.GetConversionParams(run_params)
     logging.info(conversion_params)
     assert conversion_params.precision_mode == "INT8"
-    assert conversion_params.is_dynamic_op
+    assert run_params.dynamic_engine
     assert conversion_params.maximum_cached_engines == 1
     assert conversion_params.use_calibration
 
@@ -475,11 +480,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     # TODO(aaroey): fix this.
     assert len(inputs_data) == 1
 
-    session_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
-    logging.info("Running calibration graph, config:\n%s", str(session_config))
-
     converter = self._CreateConverter(run_params, saved_model_dir,
-                                      session_config, conversion_params)
+                                      conversion_params)
     int8_gdef = converter.convert()
     self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef,
                          GraphState.CALIBRATE)
@@ -498,15 +500,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     conversion_params = self.GetConversionParams(run_params)
     logging.info(conversion_params)
 
-    session_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
-    logging.info("Creating TRT graph for inference, config\n%s",
-                 str(session_config))
     converter = self._CreateConverter(run_params, saved_model_dir,
-                                      session_config, conversion_params)
+                                      conversion_params)
     converter.convert()
 
-    if trt_convert.is_explicit_batch_mode_enabled(
-        conversion_params.rewriter_config_template):
+    if not self._use_implicit_batch:
       logging.info("Using build mode")
 
       def _BuildInputFn():
@@ -684,7 +682,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       original_gdef: GraphDef. The graph def before TensorRT conversion.
       converted_gdef: GraphDef. The graph def after TensorRT conversion.
       default_max_batch_size: The default maximum batch size to use if no node
-        inside a segment is annoted with a customized max batch size.
+        inside a segment is annotated with a customized max batch size. This
+        value is None when the graph is converted to TF-TRT with dynamic engines.
       expected_max_batch_sizes: Optional. A sequence of max batch sizes for all
         the engines. `None` if does not check enforce max batch sizes.
     """
@@ -780,7 +779,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
                           node_max_batch_size is None)
         logging.info("'{%s}'s max batch size is %d.", engine_name,
                      engine_max_batch_size)
-        self.assertTrue(engine_max_batch_size == default_max_batch_size or
+        self.assertTrue(default_max_batch_size is None or
+                        engine_max_batch_size == default_max_batch_size or
                         not node_max_batch_size_all_none)
 
     self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys()))
@@ -851,9 +851,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
           original_gdef=original_gdef,
           converted_gdef=gdef_to_verify,
           expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
-          default_max_batch_size=self.GetConversionParams(
-              run_params).max_batch_size,
-      )
+          default_max_batch_size=self.GetMaxBatchSize(run_params))
 
   def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify,
                         graph_state):
@@ -886,14 +884,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       expected_engines = set(expected_engines.keys())
 
     self.assertEqual(set(expected_engines), trt_op_names)
-    self._VerifyMaxBatchSizeAnnotations(
-        expected_engines=expected_engines,
-        original_gdef=original_gdef,
-        converted_gdef=gdef_to_verify,
-        expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
-        default_max_batch_size=self.GetConversionParams(
-            run_params).max_batch_size,
-    )
 
   def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir,
                       gdef_or_saved_model_dir_to_verify, graph_state):
diff --git a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py
index 5c7f261fa98..f5eb3c75653 100644
--- a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py
@@ -59,25 +59,6 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase):
     return self.BuildParams(self.GraphFn, dtypes.float32, [[1, 12, 5]],
                             [[12, 5]])
 
-  def GetConversionParams(self,
-                          run_params,
-                          max_batch_size=0,
-                          implicit_batch=False):
-    """Return a TrtConversionParams for test."""
-
-    conversion_params = super(TrtModeTestBase,
-                              self).GetConversionParams(run_params)
-    # If max_batch_size!=0, use the value for conversion_params.
-    if max_batch_size and implicit_batch:
-      conversion_params = conversion_params._replace(
-          max_batch_size=max_batch_size)
-
-    rewriter_config = self.GetTrtRewriterConfig(
-        run_params=run_params,
-        conversion_params=conversion_params,
-        use_implicit_batch=implicit_batch)
-    return conversion_params._replace(rewriter_config_template=rewriter_config)
-
   @classmethod
   def setUpClass(cls):
     if cls is TrtModeTestBase:
@@ -87,12 +68,13 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase):
 
 class ImplicitBatchTest(TrtModeTestBase):
 
-  def GetConversionParams(self, run_params):
-    """Return a TrtConversionParams for test using implicit batch mdoe."""
+  def GetMaxBatchSize(self, run_params):
+    if run_params.dynamic_engine:
+      return None
+
     # The first dimension of the input is squeezed and the batch size for the
     # rest OPs is 12.
-    return super(ImplicitBatchTest,
-                 self).GetConversionParams(run_params, 12, True)
+    return 12
 
   def ExpectedEnginesToBuild(self, run_params):
     """Check that the expected engine is built.
@@ -123,11 +105,6 @@ class ExplicitBatchTest(TrtModeTestBase):
         extra_inputs=[],
         extra_outputs=[])
 
-  def GetConversionParams(self, run_params):
-    """Return a TrtConversionParams for test that enables explicit batch."""
-    return super(ExplicitBatchTest, self).GetConversionParams(
-        run_params, implicit_batch=False)
-
   def ExpectedEnginesToBuild(self, run_params):
     """Check that the expected engine is built.
 
@@ -146,6 +123,11 @@ class ExplicitBatchTest(TrtModeTestBase):
     return run_params.is_v2 and trt_test.IsTensorRTVersionGreaterEqual(6) and (
         not run_params.use_calibration), "test v2, >=TRT6 and non-calibration"
 
+  def setUp(self):
+    super().setUp()
+    # Disable implicit batch mode for testing explicit batch mode.
+    self.DisableImplicitBatchMode()
+
 
 class DynamicShapesTest(TrtModeTestBase):
   """Test with dynamic input shapes.
@@ -169,10 +151,6 @@ class DynamicShapesTest(TrtModeTestBase):
         input_mask=[[False, False, False]],
         output_mask=[[False, False]])
 
-  def GetConversionParams(self, run_params):
-    """Return a TrtConversionParams for test that enables explicit batch."""
-    return super(DynamicShapesTest, self).GetConversionParams(run_params, False)
-
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
     return ["TRTEngineOp_0"]
@@ -182,6 +160,9 @@ class DynamicShapesTest(TrtModeTestBase):
     return run_params.is_v2 and trt_test.IsTensorRTVersionGreaterEqual(6) and (
         not run_params.use_calibration), "test v2 >=TRT6 and non-calibration"
 
+  def setUp(self):
+    super().setUp()
+    self.DisableImplicitBatchMode()
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index 315da90598d..734b13aad1e 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -117,16 +117,12 @@ DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
 @tf_export("experimental.tensorrt.ConversionParams", v1=[])
 class TrtConversionParams(
     collections.namedtuple("TrtConversionParams", [
-        "rewriter_config_template", "max_workspace_size_bytes",
-        "precision_mode", "minimum_segment_size", "is_dynamic_op",
-        "maximum_cached_engines", "use_calibration", "max_batch_size",
-        "allow_build_at_runtime"
+        "max_workspace_size_bytes", "precision_mode", "minimum_segment_size",
+        "maximum_cached_engines", "use_calibration", "allow_build_at_runtime"
     ])):
   """Parameters that are used for TF-TRT conversion.
 
   Fields:
-    rewriter_config_template: a template RewriterConfig proto used to create a
-      TRT-enabled RewriterConfig. If None, it will use a default one.
     max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
       engine can use at execution time. This corresponds to the
       'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
@@ -134,11 +130,6 @@ class TrtConversionParams(
       TrtPrecisionMode.supported_precision_modes().
     minimum_segment_size: the minimum number of nodes required for a subgraph
       to be replaced by TRTEngineOp.
-    is_dynamic_op: whether to generate dynamic TRT ops which will build the
-      TRT network and engine at run time. i.e. Since TensorRT version < 6.0
-      does not support dynamic dimensions other than the batch dimension, when
-      the TensorFlow graph has a non-batch dimension of dynamic size, we would
-      need to enable this option. This option should be set to True in TF 2.0.
     maximum_cached_engines: max number of cached TRT engines for dynamic TRT
       ops. Created TRT engines for a dynamic dimension are cached. This is the
       maximum number of engines that can be cached. If the number of cached
@@ -154,31 +145,23 @@ class TrtConversionParams(
       will occur. Please note that accuracy may be negatively affected if
       there is a mismatch between which tensors TRT quantizes and which
       tensors were trained with fake quantization.
-    max_batch_size: max size for the input batch. This parameter is only
-      effective when use_implicit_batch is true.
     allow_build_at_runtime: whether to build TensorRT engines during runtime.
       If no TensorRT engine can be found in cache that can handle the given
       inputs during runtime, then a new TensorRT engine is built at runtime if
-      allow_build_at_runtime=True, and otherwise native TF is used. This
-      argument is only effective if is_dynamic_op=True.
+      allow_build_at_runtime=True, and otherwise native TF is used.
   """
 
   def __new__(cls,
-              rewriter_config_template=None,
               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
               precision_mode=TrtPrecisionMode.FP32,
               minimum_segment_size=3,
-              is_dynamic_op=True,
               maximum_cached_engines=1,
               use_calibration=True,
-              max_batch_size=1,
               allow_build_at_runtime=True):
     return super(TrtConversionParams,
-                 cls).__new__(cls, rewriter_config_template,
-                              max_workspace_size_bytes, precision_mode,
-                              minimum_segment_size, is_dynamic_op,
-                              maximum_cached_engines, use_calibration,
-                              max_batch_size, allow_build_at_runtime)
+                 cls).__new__(cls, max_workspace_size_bytes, precision_mode,
+                              minimum_segment_size, maximum_cached_engines,
+                              use_calibration, allow_build_at_runtime)
 
 
 DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
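With three fields removed, `TrtConversionParams` keeps plain namedtuple semantics over the six surviving fields; engine-mode settings now travel through the converters instead. A minimal usage sketch with illustrative values:

```python
from tensorflow.python.compiler.tensorrt import trt_convert

# Only the six remaining fields are accepted; passing is_dynamic_op,
# max_batch_size, or rewriter_config_template now raises a TypeError.
params = trt_convert.TrtConversionParams(
    max_workspace_size_bytes=1 << 30,
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    minimum_segment_size=3,
    maximum_cached_engines=4,
    use_calibration=False,
    allow_build_at_runtime=True)

# Standard namedtuple updates still work:
params = params._replace(minimum_segment_size=5)
```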
@@ -203,51 +186,6 @@ def _check_conversion_params(conversion_params, is_v2=False):
         ("precision mode '{}' is not supported."
          "It should be one of {}").format(conversion_params.precision_mode,
                                           supported_precision_modes))
-  if is_v2:
-    # Static mode (building TRT engine without executing the op) is deprecated
-    # in TF 2.0. See TrtGraphConverterV2 for more details.
-    if not conversion_params.is_dynamic_op:
-      raise ValueError("Option is_dynamic_op=False is not supported in TF 2.0, "
-                       "please set it to True instead.")
-
-  if conversion_params.rewriter_config_template:
-    rewriter_cfg = conversion_params.rewriter_config_template
-    trt_optimizer = None
-    for optimizer in rewriter_cfg.custom_optimizers:
-      if optimizer.name == "TensorRTOptimizer":
-        if trt_optimizer:
-          raise ValueError(
-              "Found more than one TensorRTOptimizer in "
-              "rewriter_config_template while only one is allowed.")
-        trt_optimizer = optimizer
-    # If rewriter_config_template is set, it should include TensorRTOptimizer.
-    # It is possible to remove this requirement if needed.
-    if not trt_optimizer:
-      raise ValueError(
-          "Found no TensorRTOptimizer in rewriter_config_template.")
-    if not trt_optimizer.parameter_map:
-      raise ValueError("Found no parameter_map in TensorRTOptimizer.")
-    if ("precision_mode" in trt_optimizer.parameter_map.keys() and
-        trt_optimizer.parameter_map["precision_mode"].s not in map(
-            _to_bytes, supported_precision_modes)):
-      raise ValueError(("precision_mode '{}' is not supported. "
-                        "It should be one of {}").format(
-                            trt_optimizer.parameter_map["precision_mode"],
-                            supported_precision_modes))
-    if is_v2:
-      # Static mode (building TRT engine without executing the op) is not
-      # supported in TF 2.0. See TrtGraphConverterV2 for more details.
-      if ("is_dynamic_op" in trt_optimizer.parameter_map.keys() and
-          not trt_optimizer.parameter_map["is_dynamic_op"]):
-        raise ValueError("Option is_dynamic_op=False is not supported "
-                         "in TF 2.0, please set it to True instead.")
-  if (conversion_params.allow_build_at_runtime and
-      not conversion_params.is_dynamic_op):
-    tf_logging.warn(
-        ("Building TensorRT engines at runtime is not supported "
-         "if is_dynamic_op=False, therefore assuming "
-         "allow_build_at_runtime=False. If building TensorRT engines "
-         "at runtime is desired, set is_dynamic_op=True."))
 
 
 def _check_trt_version_compatibility():
@@ -290,15 +228,21 @@ def _check_trt_version_compatibility():
         " minor/patch upgrades are backward compatible")
 
 
-def get_tensorrt_rewriter_config(conversion_params,
-                                 is_v2=False,
-                                 disable_non_trt_optimizers=False):
+def _get_tensorrt_rewriter_config(conversion_params,
+                                  is_dynamic_op=None,
+                                  max_batch_size=None,
+                                  is_v2=False,
+                                  disable_non_trt_optimizers=False,
+                                  use_implicit_batch=True):
   """Returns a RewriterConfig proto for TRT transformation.
 
   Args:
     conversion_params: a TrtConversionParams instance.
+    is_dynamic_op: whether to use dynamic engines.
+    max_batch_size: maximum batch size for static engines.
     is_v2: whether we're getting a RewriterConfig for TF 2.0.
     disable_non_trt_optimizers: Turn off all default Grappler optimizers.
+    use_implicit_batch: Whether to use implicit batch or explicit batch.
 
   Returns:
     A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
@@ -307,46 +251,49 @@ def get_tensorrt_rewriter_config(conversion_params,
     TypeError: if any of the parameters are of unexpected type.
     ValueError: if any of the parameters are of unexpected value.
   """
-  if conversion_params.rewriter_config_template is not None and not isinstance(
-      conversion_params.rewriter_config_template,
-      rewriter_config_pb2.RewriterConfig):
-    raise TypeError(
-        "rewriter_config_template should be a RewriterConfig proto.")
   _check_conversion_params(conversion_params, is_v2=is_v2)
+  if is_v2 and is_dynamic_op is not None and not is_dynamic_op:
+    raise ValueError("is_dynamic_op must be None or True for TF2")
+  if not is_v2 and is_dynamic_op is None:
+    raise ValueError("is_dynamic_op can't be None for TF1")
 
+  if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None:
+    raise ValueError("max_batch_size must be None for TF2"
+                     " or when is_dynamic_op == True in TF1")
+  if is_dynamic_op is not None and not is_dynamic_op and not isinstance(
+      max_batch_size, int):
+    raise ValueError(
+        "max_batch_size must be an integer when is_dynamic_op == False in TF1")
   rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
 
-  if conversion_params.rewriter_config_template is None:
-    if not disable_non_trt_optimizers:
-      # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
-      # need to run constant folding again.
-      rewriter_config_with_trt.optimizers.extend(
-          ["constfold", "layout", "constfold"])
-    rewriter_config_with_trt.meta_optimizer_iterations = (
-        rewriter_config_pb2.RewriterConfig.ONE)
-    optimizer = rewriter_config_with_trt.custom_optimizers.add()
-    # Add a constfold optimizer to cleanup the unused Const nodes.
-    rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
+  if not disable_non_trt_optimizers:
+    # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
+    # need to run constant folding again.
+    rewriter_config_with_trt.optimizers.extend(
+        ["constfold", "layout", "constfold"])
+  rewriter_config_with_trt.meta_optimizer_iterations = (
+      rewriter_config_pb2.RewriterConfig.ONE)
+  optimizer = rewriter_config_with_trt.custom_optimizers.add()
+  # Add a constfold optimizer to cleanup the unused Const nodes.
+  rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
 
-    optimizer.name = "TensorRTOptimizer"
-    optimizer.parameter_map[
-        "minimum_segment_size"].i = conversion_params.minimum_segment_size
-    optimizer.parameter_map["max_workspace_size_bytes"].i = (
-        conversion_params.max_workspace_size_bytes)
-    optimizer.parameter_map["precision_mode"].s = _to_bytes(
-        conversion_params.precision_mode)
-    optimizer.parameter_map[
-        "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
-    optimizer.parameter_map[
-        "use_calibration"].b = conversion_params.use_calibration
-    optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
-    optimizer.parameter_map[
-        "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
-    optimizer.parameter_map[
-        "max_batch_size"].i = conversion_params.max_batch_size
-  else:
-    rewriter_config_with_trt.CopyFrom(
-        conversion_params.rewriter_config_template)
+  optimizer.name = "TensorRTOptimizer"
+  optimizer.parameter_map[
+      "minimum_segment_size"].i = conversion_params.minimum_segment_size
+  optimizer.parameter_map["max_workspace_size_bytes"].i = (
+      conversion_params.max_workspace_size_bytes)
+  optimizer.parameter_map["precision_mode"].s = _to_bytes(
+      conversion_params.precision_mode)
+  optimizer.parameter_map[
+      "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
+  optimizer.parameter_map[
+      "use_calibration"].b = conversion_params.use_calibration
+  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+  optimizer.parameter_map[
+      "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
+  if max_batch_size is not None:
+    optimizer.parameter_map["max_batch_size"].i = max_batch_size
+  optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
 
   # Disabling optimizers should happen after CopyFrom the template
   # otherwise the template can overwrite the disablement.
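The checks above define the full TF1/TF2 contract for the two new arguments. The following calls, sketched against the private helper introduced in this hunk, illustrate which combinations are accepted:

```python
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.TrtConversionParams()

# TF1, static engines: an integer max_batch_size is required.
trt_convert._get_tensorrt_rewriter_config(
    params, is_dynamic_op=False, max_batch_size=8, is_v2=False)

# TF1, dynamic engines: max_batch_size must stay None.
trt_convert._get_tensorrt_rewriter_config(
    params, is_dynamic_op=True, max_batch_size=None, is_v2=False)

# TF2: is_dynamic_op may only be None or True, and max_batch_size stays None.
trt_convert._get_tensorrt_rewriter_config(
    params, is_dynamic_op=True, max_batch_size=None, is_v2=True)

# ValueError under the checks above:
#   is_dynamic_op=None with is_v2=False
#   is_dynamic_op=False with is_v2=True
#   max_batch_size set while is_dynamic_op is None or True
```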
@@ -371,6 +318,18 @@ def get_tensorrt_rewriter_config(conversion_params,
   return rewriter_config_with_trt
 
 
+@deprecation.deprecated(
+    None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.")
+def get_tensorrt_rewriter_config(conversion_params,
+                                 is_dynamic_op=None,
+                                 max_batch_size=None,
+                                 is_v2=False,
+                                 disable_non_trt_optimizers=False):
+  return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op,
+                                       max_batch_size, is_v2,
+                                       disable_non_trt_optimizers)
+
+
 # Remove all scope prefixes in the node name. In TF 2.0, the same concrete
 # function can be initialized multiple times with different prefixes, and
 # this will result in the same TRTEngineOp being initialized multiple times
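For context, `deprecation.deprecated(date, instructions)` only logs a warning (once by default) and otherwise leaves the wrapped function's behavior intact, so existing callers of the public wrapper keep working. A small sketch with a hypothetical helper:

```python
from tensorflow.python.util import deprecation


@deprecation.deprecated(
    None, "Use the converter APIs instead of building a rewriter config.")
def legacy_helper(x):
  """Hypothetical function kept only for backward compatibility."""
  return x * 2


# The first call logs a deprecation warning (date=None means "in a future
# version"); the return value is unaffected.
assert legacy_helper(2) == 4
```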
@@ -384,17 +343,6 @@ def _get_canonical_engine_name(name):
   return name.split("/")[-1]
 
 
-def is_explicit_batch_mode_enabled(rewriter_config):
-  """Checks whether explicit batch is enabled by the rewriter config."""
-  if rewriter_config is None:
-    return False
-  for optimizer in rewriter_config.custom_optimizers:
-    if optimizer.name == "TensorRTOptimizer":
-      if "use_implicit_batch" in optimizer.parameter_map:
-        return not optimizer.parameter_map["use_implicit_batch"].b
-  return False
-
-
 class TrtGraphConverter(object):
   """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels.
 
@@ -427,15 +375,12 @@ class TrtGraphConverter(object):
   ```
   """
 
-  @deprecation.deprecated_args(None, "Remove the use of this argument",
-                               "session_config")
   def __init__(self,
                input_saved_model_dir=None,
                input_saved_model_tags=None,
                input_saved_model_signature_key=None,
                input_graph_def=None,
                nodes_denylist=None,
-               session_config=None,
                max_batch_size=1,
                max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
                precision_mode=TrtPrecisionMode.FP32,
@@ -443,7 +388,7 @@ class TrtGraphConverter(object):
                is_dynamic_op=False,
                maximum_cached_engines=1,
                use_calibration=True):
-    """Initialize the converter.
+    """Initializes the converter.
 
     Args:
       input_saved_model_dir: the directory to load the SavedModel which contains
@@ -454,11 +399,7 @@ class TrtGraphConverter(object):
       input_graph_def: a GraphDef object containing a model to be transformed.
         If set to None, the graph will be read from the SavedModel loaded from
         input_saved_model_dir.
-      nodes_denylist: list of node names to prevent the converter from
-        touching.
-      session_config: the ConfigProto used to create a Session. It's also used
-        as a template to create a TRT-enabled ConfigProto for conversion. If not
-        specified, a default ConfigProto will be used.
+      nodes_denylist: list of node names to prevent the converter from touching.
       max_batch_size: max size for the input batch.
       max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
         engine can use at execution time. This corresponds to the
@@ -510,7 +451,6 @@ class TrtGraphConverter(object):
     self._input_saved_model_signature_key = (
         input_saved_model_signature_key or
         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
-    self._session_config = session_config or config_pb2.ConfigProto()
 
     # For calibration usage.
     self._calibration_graph = None
@@ -523,33 +463,39 @@ class TrtGraphConverter(object):
           "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
       is_dynamic_op = True
 
-    # TODO(laigd):
-    # - Verify in int8 mode that maximum_cached_engines is set properly.
-    # - If it fails to build the int8 engine it should return error.
-    rewriter_config_template = None
-    if (session_config and session_config.HasField("graph_options") and
-        session_config.graph_options.HasField("rewrite_options")):
-      rewriter_config_template = session_config.graph_options.rewrite_options
+    self._is_dynamic_op = is_dynamic_op
+    if is_dynamic_op:
+      self._max_batch_size = None
+      if max_batch_size is not None:
+        tf_logging.warn("When is_dynamic_op==True, max_batch_size should be "
+                        "None")
+    else:
+      if not isinstance(max_batch_size, int):
+        raise ValueError("When is_dynamic_op==False, max_batch_size must be "
+                         "an integer")
+      self._max_batch_size = max_batch_size
 
     self._conversion_params = TrtConversionParams(
-        rewriter_config_template=rewriter_config_template,
         max_workspace_size_bytes=max_workspace_size_bytes,
         precision_mode=precision_mode,
         minimum_segment_size=minimum_segment_size,
-        is_dynamic_op=is_dynamic_op,
         maximum_cached_engines=maximum_cached_engines,
         use_calibration=use_calibration,
-        max_batch_size=max_batch_size,
         allow_build_at_runtime=True)
     _check_conversion_params(self._conversion_params)
 
+    self._test_only_disable_non_trt_optimizers = False
+
   def _run_conversion(self):
     """Run Grappler's OptimizeGraph() tool to convert the graph."""
     # Create custom ConfigProto for Grappler.
     grappler_session_config = config_pb2.ConfigProto()
-    grappler_session_config.CopyFrom(self._session_config)
-    custom_rewriter_config = get_tensorrt_rewriter_config(
-        conversion_params=self._conversion_params)
+    custom_rewriter_config = _get_tensorrt_rewriter_config(
+        conversion_params=self._conversion_params,
+        is_dynamic_op=self._is_dynamic_op,
+        max_batch_size=self._max_batch_size,
+        disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
+        use_implicit_batch=True)
     grappler_session_config.graph_options.rewrite_options.CopyFrom(
         custom_rewriter_config)
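Under the new constructor checks, `max_batch_size` on the TF1 converter is meaningful only for static engines. A hedged usage sketch (paths and sizes are illustrative):

```python
from tensorflow.python.compiler.tensorrt import trt_convert

# Static engines (the TF1 default): an integer max_batch_size is mandatory.
static_converter = trt_convert.TrtGraphConverter(
    input_saved_model_dir="/tmp/my_saved_model",  # illustrative path
    max_batch_size=16,
    is_dynamic_op=False)

# Dynamic engines: pass max_batch_size=None; any other value only triggers a
# warning and is ignored.
dynamic_converter = trt_convert.TrtGraphConverter(
    input_saved_model_dir="/tmp/my_saved_model",
    max_batch_size=None,
    is_dynamic_op=True)
```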
 
@@ -596,7 +542,7 @@ class TrtGraphConverter(object):
   def _convert_saved_model(self):
     """Convert the input SavedModel."""
     graph = ops.Graph()
-    with session.Session(graph=graph, config=self._session_config) as sess:
+    with session.Session(graph=graph) as sess:
       input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                          self._input_saved_model_dir)
       input_signature_def = input_meta_graph_def.signature_def[
@@ -706,9 +652,13 @@ class TrtGraphConverter(object):
           return_elements=fetch_names,
           name="")
 
+    # Set allow_soft_placement=True when running the calibration graph so that
+    # ops that are supported by TensorRT but lack a GPU implementation are
+    # allowed to execute on the CPU.
+    calibrate_config = config_pb2.ConfigProto(allow_soft_placement=True)
     with session.Session(
         graph=self._calibration_graph,
-        config=self._session_config) as calibration_sess:
+        config=calibrate_config) as calibration_sess:
       for _ in range(num_runs):
         calibration_sess.run(
             fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)
@@ -823,7 +773,7 @@ class TrtGraphConverter(object):
           self._collections_to_keep(
               self._grappler_meta_graph_def.collection_def))
       # We don't use any specific converter here.
-      with session.Session(config=self._session_config) as sess:
+      with session.Session() as sess:
         saved_model_builder.add_meta_graph_and_variables(
             sess,
             self._input_saved_model_tags,
@@ -884,11 +834,6 @@ class TrtGraphConverterV2(object):
 
   Currently this is not available on Windows platform.
 
-  Note that in V2, is_dynamic_op=False is not supported, meaning TRT engines
-  will be built only when the corresponding TRTEngineOp is executed. But we
-  still provide a way to avoid the cost of building TRT engines during inference
-  (see more below).
-
   There are several ways to run the conversion:
 
   1. FP32/FP16 precision
@@ -997,8 +942,6 @@ class TrtGraphConverterV2(object):
     assert context.executing_eagerly()
     if conversion_params is None:
       conversion_params = TrtConversionParams()
-    elif conversion_params.rewriter_config_template is not None:
-      tf_logging.warn("the rewrite_config_template field will be deprecated.")
 
     _check_trt_version_compatibility()
     _check_conversion_params(conversion_params, is_v2=True)
@@ -1010,22 +953,21 @@ class TrtGraphConverterV2(object):
     self._input_saved_model_signature_key = (
         input_saved_model_signature_key or
         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
-    self._rewriter_config = get_tensorrt_rewriter_config(
-        conversion_params=self._conversion_params, is_v2=True)
 
     self._need_calibration = (
         conversion_params.precision_mode == TrtPrecisionMode.INT8 and
         conversion_params.use_calibration)
-    if (self._need_calibration and not conversion_params.is_dynamic_op):
-      raise ValueError("INT8 precision mode with calibration is not supported "
-                       "with static TensorRT ops. Set is_dynamic_op to True.")
 
-    # rewriter_config is already validated
-    self._need_trt_profiles = is_explicit_batch_mode_enabled(
-        self._rewriter_config)
     self._converted = False
     self._build_called_once = False
 
+    # Test-only fields to support TF-TRT testing; do not use them otherwise.
+    self._test_only_disable_non_trt_optimizers = False
+    self._test_only_use_implicit_batch = True
+
+  def _need_trt_profiles(self):
+    return not self._test_only_use_implicit_batch
+
   def _run_conversion(self, meta_graph_def):
     """Run Grappler's OptimizeGraph() tool to convert the graph.
 
@@ -1036,8 +978,14 @@ class TrtGraphConverterV2(object):
       The optimized GraphDef.
     """
     grappler_session_config = config_pb2.ConfigProto()
+    custom_rewriter_config = _get_tensorrt_rewriter_config(
+        conversion_params=self._conversion_params,
+        is_dynamic_op=True,
+        max_batch_size=None,
+        disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
+        use_implicit_batch=self._test_only_use_implicit_batch)
     grappler_session_config.graph_options.rewrite_options.CopyFrom(
-        self._rewriter_config)
+        custom_rewriter_config)
     return tf_optimizer.OptimizeGraph(
         grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
 
@@ -1167,7 +1115,7 @@ class TrtGraphConverterV2(object):
     def _set_profile_generation_mode(value, node):
       node.attr["_profile_generation_mode"].b = value
 
-    if self._need_trt_profiles:
+    if self._need_trt_profiles():
       # Enable profile generation.
       self._for_each_trt_node(self._converted_graph_def,
                               partial(_set_profile_generation_mode, True))
@@ -1187,7 +1135,7 @@ class TrtGraphConverterV2(object):
         first_input = inp
       func(*map(ops.convert_to_tensor, inp))
 
-    if self._need_trt_profiles:
+    if self._need_trt_profiles():
       # Disable profile generation.
       self._for_each_trt_node(self._converted_graph_def,
                               partial(_set_profile_generation_mode, False))
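Explicit batch mode, now the only case in which `_need_trt_profiles()` returns True, still requires a `build()` call with representative inputs before saving so that optimization profiles can be generated. A hedged sketch of that flow (paths and shapes are illustrative):

```python
import numpy as np

from tensorflow.python.compiler.tensorrt import trt_convert

converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/my_saved_model")
converter.convert()


def input_fn():
  # Yield one tuple per signature input; each distinct shape drives one
  # TensorRT optimization profile.
  yield (np.random.normal(size=(1, 28, 28, 3)).astype(np.float32),)


converter.build(input_fn)  # required when use_implicit_batch=False
converter.save("/tmp/my_trt_saved_model")
```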
@@ -1208,7 +1156,7 @@ class TrtGraphConverterV2(object):
     """
     assert self._converted
 
-    if self._need_trt_profiles and not self._build_called_once:
+    if self._need_trt_profiles() and not self._build_called_once:
       raise NotImplementedError(
           "build() is not called . Explicit batch mode "
           "(use_implicit_batch=False) requires generating TensorRT optimization"
@@ -1295,8 +1243,7 @@ def create_inference_graph(
     input_saved_model_dir=None,
     input_saved_model_tags=None,
     input_saved_model_signature_key=None,
-    output_saved_model_dir=None,
-    session_config=None):
+    output_saved_model_dir=None):
   """Python wrapper for the TRT transformation.
 
   Args:
@@ -1327,9 +1274,6 @@ def create_inference_graph(
       returned GraphDef and save it to the specified directory. This option only
       works when the input graph is loaded from a SavedModel, i.e. when
       input_saved_model_dir is specified and input_graph_def is None.
-    session_config: the ConfigProto used to create a Session. It's also used as
-      a template to create a TRT-enabled ConfigProto for conversion. If not
-      specified, a default ConfigProto will be used.
 
   Returns:
     A GraphDef transformed from input_graph_def (or the SavedModel graph def
@@ -1359,7 +1303,6 @@ def create_inference_graph(
       input_saved_model_signature_key=input_saved_model_signature_key,
       input_graph_def=input_graph_def,
       nodes_denylist=outputs,
-      session_config=session_config,
       max_batch_size=max_batch_size,
       max_workspace_size_bytes=max_workspace_size_bytes,
       precision_mode=precision_mode,
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index 0baae0bd3bf..d19f2d03e30 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -30,7 +30,6 @@ from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
 from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.compiler.tensorrt import trt_convert
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
@@ -79,98 +78,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     # test if we can access the TRTEngineInstance protobuf
     assert hasattr(TRTEngineInstance(), "serialized_engine")
 
-  def testGetTensorrtRewriterConfig(self):
-    """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
-    if not is_tensorrt_enabled():
-      return
-    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-        max_batch_size=128,
-        max_workspace_size_bytes=1234,
-        precision_mode="INT8",
-        minimum_segment_size=10,
-        is_dynamic_op=True,
-        maximum_cached_engines=2)
-    rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
-        conversion_params=conversion_params, is_v2=True)
-    self.assertEqual(["constfold", "layout", "constfold"],
-                     rewriter_cfg.optimizers)
-    self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
-                     rewriter_cfg.meta_optimizer_iterations)
-    trt_optimizer = None
-    for optimizer in rewriter_cfg.custom_optimizers:
-      if optimizer.name == "TensorRTOptimizer":
-        self.assertTrue(trt_optimizer is None)
-        trt_optimizer = optimizer
-    self.assertTrue(trt_optimizer is not None)
-    for key in [
-        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
-        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines"
-    ]:
-      self.assertTrue(key in trt_optimizer.parameter_map)
-    self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
-    self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
-    self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
-    self.assertEqual(1234,
-                     trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
-    self.assertEqual(
-        trt_convert._to_bytes("INT8"),
-        trt_optimizer.parameter_map["precision_mode"].s)
-    self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
-
-  def testGetTensorrtRewriterConfigTemplate(self):
-    """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
-    if not is_tensorrt_enabled():
-      return
-
-    rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
-    rewriter_config_with_trt.optimizers.extend(
-        ["constfold", "layout", "constfold"])
-    rewriter_config_with_trt.meta_optimizer_iterations = (
-        rewriter_config_pb2.RewriterConfig.ONE)
-    optimizer = rewriter_config_with_trt.custom_optimizers.add()
-    rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
-    optimizer.name = "TensorRTOptimizer"
-    optimizer.parameter_map["minimum_segment_size"].i = 10
-    optimizer.parameter_map["max_batch_size"].i = 128
-    optimizer.parameter_map["is_dynamic_op"].b = True
-    optimizer.parameter_map["max_workspace_size_bytes"].i = 1234
-    optimizer.parameter_map["precision_mode"].s = trt_convert._to_bytes(
-        trt_convert.TrtPrecisionMode.INT8)
-    optimizer.parameter_map["maximum_cached_engines"].i = 2
-    optimizer.parameter_map["use_calibration"].b = False
-    optimizer.parameter_map["use_implicit_batch"].b = True
-
-    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-        rewriter_config_template=rewriter_config_with_trt)
-    rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
-        conversion_params=conversion_params)
-    self.assertEqual(["constfold", "layout", "constfold"],
-                     rewriter_cfg.optimizers)
-    self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
-                     rewriter_cfg.meta_optimizer_iterations)
-    trt_optimizer = None
-    for optimizer in rewriter_cfg.custom_optimizers:
-      if optimizer.name == "TensorRTOptimizer":
-        self.assertIsNone(trt_optimizer)
-        trt_optimizer = optimizer
-    self.assertIsNotNone(trt_optimizer)
-    for key in [
-        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
-        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines"
-    ]:
-      self.assertIn(key, trt_optimizer.parameter_map)
-    self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
-    self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
-    self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
-    self.assertEqual(1234,
-                     trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
-    self.assertEqual(
-        trt_convert._to_bytes("INT8"),
-        trt_optimizer.parameter_map["precision_mode"].s)
-    self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
-    self.assertEqual(False, trt_optimizer.parameter_map["use_calibration"].b)
-    self.assertEqual(True, trt_optimizer.parameter_map["use_implicit_batch"].b)
-
   def _GetConfigProto(self, rewriter_config=None):
     """Get ConfigProto for session creation."""
     config = config_pb2.ConfigProto(
@@ -280,13 +187,20 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       input_saved_model_dir = self.mkdtemp()
       self._WriteInputSavedModelForV1(input_saved_model_dir, device)
 
+    # Calibration requires dynamic_op.
+    if need_calibration:
+      is_dynamic_op = True
+
+    # For dynamic_op, max_batch_size is unused and must be None.
+    if is_dynamic_op:
+      max_batch_size = None
+
     converter = trt_convert.TrtGraphConverter(
         input_saved_model_dir=input_saved_model_dir,
         input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
         input_graph_def=None
         if input_saved_model_dir else self._GetGraphDefForV1(device),
         nodes_denylist=None if input_saved_model_dir else ["output"],
-        session_config=self._GetConfigProto(),
         max_batch_size=max_batch_size,
         max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
         precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
@@ -437,10 +351,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       return
 
     conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-        precision_mode=trt_convert.TrtPrecisionMode.FP32, is_dynamic_op=True)
+        precision_mode=trt_convert.TrtPrecisionMode.FP32)
     config = self._GetConfigProto(
         rewriter_config=trt_convert.get_tensorrt_rewriter_config(
-            conversion_params, is_v2=False))
+            conversion_params,
+            is_dynamic_op=False,
+            max_batch_size=1,
+            is_v2=False))
 
     with ops.Graph().as_default():
       # Online conversion requires a frozen graph, so we reuse inp1 as the var
@@ -463,7 +380,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
       max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
       precision_mode=trt_convert.TrtPrecisionMode.FP32,
-      is_dynamic_op=True,
       maximum_cached_engines=2):
     return trt_convert.TrtGraphConverterV2(
         input_saved_model_dir=input_saved_model_dir,
@@ -471,7 +387,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
             max_workspace_size_bytes=max_workspace_size_bytes,
             precision_mode=precision_mode,
-            is_dynamic_op=is_dynamic_op,
             maximum_cached_engines=maximum_cached_engines))
 
   def _CheckTrtOps(self, concrete_func, check_fn=None):
@@ -572,24 +487,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     del root_with_trt
     gc.collect()  # Force GC to destroy the TRT engine cache.
 
-  @test_util.run_v2_only
-  def testTrtGraphConverter_StaticConversionNotSupportedInV2(self):
-    """Test case for trt_convert.TrtGraphConverter() using static mode."""
-    if not is_tensorrt_enabled():
-      return
-
-    # Create a model and save it.
-    input_saved_model_dir = self.mkdtemp()
-    root = self._GetModelForV2()
-    save.save(root, input_saved_model_dir,
-              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
-
-    # Run TRT conversion.
-    with self.assertRaisesRegex(
-        ValueError, r"Option is_dynamic_op=False is not supported in TF 2.0, "
-        "please set it to True instead."):
-      self._CreateConverterV2(input_saved_model_dir, is_dynamic_op=False)
-
   @test_util.run_v2_only
   def testTrtGraphConverter_Int8Conversion_v2(self):
     if not is_tensorrt_enabled():
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_windows.py b/tensorflow/python/compiler/tensorrt/trt_convert_windows.py
index 782d22d3721..0180e389f3b 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_windows.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_windows.py
@@ -40,10 +40,15 @@ DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
 
 
 @tf_export("experimental.tensorrt.ConversionParams", v1=[])
-class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
-    "rewriter_config_template", "max_workspace_size_bytes", "precision_mode",
-    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "use_calibration", "max_batch_size"])):
+class TrtConversionParams(
+    collections.namedtuple("TrtConversionParams", [
+        "rewriter_config_template",
+        "max_workspace_size_bytes",
+        "precision_mode",
+        "minimum_segment_size",
+        "maximum_cached_engines",
+        "use_calibration",
+    ])):
   """Parameters that are used for TF-TRT conversion.
 
   Fields:
@@ -56,11 +61,6 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
       TrtPrecisionMode.supported_precision_modes().
     minimum_segment_size: the minimum number of nodes required for a subgraph
       to be replaced by TRTEngineOp.
-    is_dynamic_op: whether to generate dynamic TRT ops which will build the
-      TRT network and engine at run time. i.e. Since TensorRT version < 6.0
-      does not support dynamic dimensions other than the batch dimension, when
-      the TensorFlow graph has a non-batch dimension of dynamic size, we would
-      need to enable this option. This option should be set to True in TF 2.0.
     maximum_cached_engines: max number of cached TRT engines for dynamic TRT
       ops. Created TRT engines for a dynamic dimension are cached. This is the
       maximum number of engines that can be cached. If the number of cached
@@ -76,8 +76,6 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
       will occur. Please note that accuracy may be negatively affected if
       there is a mismatch between which tensors TRT quantizes and which
       tensors were trained with fake quantization.
-    max_batch_size: max size for the input batch. This parameter is only
-      effective when is_dynamic_op=False which is not supported in TF 2.0.
   """
 
   def __new__(cls,
@@ -87,8 +85,7 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
               minimum_segment_size=3,
               is_dynamic_op=True,
               maximum_cached_engines=1,
-              use_calibration=True,
-              max_batch_size=1):
+              use_calibration=True):
     raise NotImplementedError(
         "TensorRT integration is not available on Windows.")
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
index d428baca9c0..86088488e05 100644
--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMaxIntraOpParallelism(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrivateThreadpool(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
   @combinations.generate(test_base.default_test_combinations())
   def testMakeBatchedFeaturesDataset(self):
     files = 2
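Both tests drive the private `_MaxIntraOpParallelismDataset` and `_PrivateThreadPoolDataset` wrappers directly to check that auto-sharding hops over them. In user code the same knobs are normally set through `tf.data.Options`, roughly as below (option names as of TF 2.x at this time; the file path is illustrative):

```python
import tensorflow as tf

dataset = tf.data.TFRecordDataset(["/tmp/data.tfrecord"])

options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
options.experimental_threading.private_threadpool_size = 1
dataset = dataset.with_options(options)
```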
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 20d77bd40ed..b2918076879 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -246,9 +246,6 @@ py_test(
     srcs = ["distribute_coordinator_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    tags = [
-        "no_oss_py2",  # b/138443278
-    ],
     deps = [
         ":distribute_coordinator",
         "//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index ca330b88ed0..e26b829fa85 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -176,14 +176,15 @@ class RemoteValueImpl(RemoteValue):
     """
     self._closure = closure
     self._type_spec = type_spec
-    self._value = None
+    self._tensors = None
+    self._fetched_numpys = None
     self._error = None
     self._status_available_event = threading.Event()
     self._status = _RemoteValueStatus.NOT_READY
 
   def _set_aborted(self):
     self._status = _RemoteValueStatus.ABORTED
-    self._value = None
+    self._tensors = None
     self._error = None
 
     # Wake up any waiting thread and clear the event.
@@ -194,21 +195,21 @@ class RemoteValueImpl(RemoteValue):
     # TODO(yuefengz): we may need to rebuild its inputs as well.
     self._closure.execute_on(worker)
 
-  def _set_value(self, value):
+  def _set_tensors(self, tensors):
     self._status = _RemoteValueStatus.READY
-    self._value = value
+    self._tensors = tensors
     self._error = None
     self._status_available_event.set()
 
   def _set_error(self, exception):
     self._status = _RemoteValueStatus.READY
-    self._value = None
+    self._tensors = None
     self._error = exception
     self._status_available_event.set()
 
-  def _get_value(self):
+  def _get_tensors(self):
     self._status_available_event.wait()
-    return self._value
+    return self._tensors
 
   def _get_error(self):
     self._status_available_event.wait()
@@ -222,10 +223,11 @@ class RemoteValueImpl(RemoteValue):
           "The corresponding function is aborted. Please reschedule the "
           "function.")
     if self._error is not None:
-      raise self._error  # pylint: disable=raising-bad-type
-    else:
-      return nest.map_structure(
-          lambda x: x.numpy() if hasattr(x, "numpy") else x, self._value)
+      raise self._error
+    if self._fetched_numpys is None:
+      self._fetched_numpys = nest.map_structure(
+          lambda x: x.numpy() if hasattr(x, "numpy") else x, self._tensors)
+    return self._fetched_numpys
 
 
 class InputError(Exception):
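Splitting `_tensors` from `_fetched_numpys` makes `fetch()` convert the remote tensors to numpy once and serve the cached copy afterwards, which is what lets a value fetched before a worker failure stay readable after it (see `test_numpy_fetched_after_worker_failure` later in this patch). A sketch of the user-visible behavior, assuming a `ClusterCoordinator` named `coordinator` and a `tf.function` `worker_fn` already exist:

```python
remote_value = coordinator.schedule(worker_fn)
first = remote_value.fetch()   # materializes numpy values and caches them
# ... the worker holding the original tensors is terminated here ...
second = remote_value.fetch()  # served from the cache, no remote read needed
assert first == second
```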
@@ -271,7 +273,7 @@ def _maybe_get_remote_value(val):
       raise AssertionError(
           "RemoteValue doesn't have a value because it has errors.")
     else:
-      return val._get_value()  # pylint: disable=protected-access
+      return val._get_tensors()  # pylint: disable=protected-access
   else:
     return val
 
@@ -406,10 +408,10 @@ class Closure(object):
     with ops.device(worker.device_name):
       with context.executor_scope(worker.executor):
         with metric_utils.monitored_timer("closure_execution"):
-          output_value = self._function(
+          output_tensors = self._function(
               *nest.map_structure(_maybe_get_remote_value, replica_args),
               **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
-    self.output_remote_value._set_value(output_value)  # pylint: disable=protected-access
+    self.output_remote_value._set_tensors(output_tensors)  # pylint: disable=protected-access
 
 
 class _CoordinatedClosureQueue(object):
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py
index 8b3e95f1fea..4c5a8e20ca1 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py
@@ -209,6 +209,50 @@ class ClusterCoordinatorMprTest(test.TestCase):
         any("_test_translate_ps_failure_error ends properly" in msg
             for msg in mpr.join().stdout))
 
+  def test_numpy_fetched_after_worker_failure(self):
+
+    def fn(first_fetch_occurred_event, worker_terminated_event):
+      os.environ["GRPC_FAIL_FAST"] = "use_caller"
+
+      cluster_resolver = TFConfigClusterResolver()
+      if cluster_resolver.task_type != "chief":
+        utils.start_server(cluster_resolver, "grpc")
+      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+          cluster_resolver)
+      ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)
+
+      with strategy.scope():
+        v = variables.Variable(initial_value=0, dtype=dtypes.int32)
+
+      @def_function.function
+      def worker_fn():
+        return v + 1, v - 1
+
+      remote_value = ps_coordinator.schedule(worker_fn)
+      logging.info("result (1st fetch): %r", remote_value.fetch())
+      first_fetch_occurred_event.set()
+      worker_terminated_event.wait()
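+      # The second fetch should be served from the cached numpy values, even
+      # though the worker that produced the tensors has been terminated.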
+      logging.info("result (2nd fetch): %r", remote_value.fetch())
+
+    manager = multi_process_runner.manager()
+    first_fetch_occurred_event = manager.Event()
+    worker_terminated_event = manager.Event()
+    mpr = multi_process_runner.MultiProcessRunner(
+        fn,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=True, num_workers=1, num_ps=1, has_eval=False),
+        args=(first_fetch_occurred_event, worker_terminated_event),
+        rpc_layer="grpc",
+        return_output=True,
+        use_dill_for_args=False)
+
+    mpr.start()
+    first_fetch_occurred_event.wait()
+    mpr.terminate("worker", 0)
+    worker_terminated_event.set()
+    self.assertTrue(
+        any("result (2nd fetch)" in msg for msg in mpr.join().stdout))
+
 
 if __name__ == "__main__":
   v2_compat.enable_v2_behavior()
diff --git a/tensorflow/python/distribute/coordinator/metric_utils_test.py b/tensorflow/python/distribute/coordinator/metric_utils_test.py
index abd4221df4d..db223e3aeb8 100644
--- a/tensorflow/python/distribute/coordinator/metric_utils_test.py
+++ b/tensorflow/python/distribute/coordinator/metric_utils_test.py
@@ -58,7 +58,7 @@ class MetricUtilsTest(test.TestCase):
     result = cluster.schedule(func, args=None, kwargs=None)
     result = cluster.schedule(func, args=None, kwargs=None)
     cluster.join()
-    self.assertEqual(result._get_value().numpy(), 3)
+    self.assertEqual(result.fetch(), 3)
 
     # Tracing, closure execution, and remote_value fetching should be executed
     # exactly once for running this function.
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 47af4e426df..ca17550626a 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -109,7 +109,7 @@ def enable_collective_ops():
   # Recover default flag values.
   cross_device_ops_lib.CollectiveAllReduce._limited_nccl = True
   cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = True
-  cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = False
+  cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
   cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = False
 
 
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index 0993b054963..15bc9d0da7d 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -258,7 +258,7 @@ class CollectiveReplicaLauncher(object):
   """Launch collectives on one replica."""
 
   _prefer_scoped_allocator = True
-  _prefer_collective_v2 = False
+  _prefer_collective_v2 = True
   _prefer_ordering_token = False
 
   def __init__(self,
diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py
index 0cb88214cf2..14fe8a43d69 100644
--- a/tensorflow/python/distribute/multi_process_lib.py
+++ b/tensorflow/python/distribute/multi_process_lib.py
@@ -108,9 +108,10 @@ def _set_spawn_exe_path():
       # /.../tensorflow/python/distribute/input_lib_test.py
       # and the binary is
       # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
-      org_tensorflow_path = sys.argv[0][:sys.argv[0].rfind('/tensorflow')]
+      org_tensorflow_base = sys.argv[0][:sys.argv[0].rfind('/org_tensorflow')]
       binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
-      possible_path = os.path.join(org_tensorflow_path, binary)
+      possible_path = os.path.join(org_tensorflow_base, 'org_tensorflow',
+                                   binary)
       logging.info('Guessed test binary path: %s', possible_path)
       if os.access(possible_path, os.X_OK):
         path = possible_path
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index cb732d37089..89fdc0c82ad 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -375,6 +375,7 @@ cuda_py_test(
     srcs = ["backprop_test.py"],
     python_version = "PY3",
     tags = [
+        "no_cuda_asan",  # b/173825938
         "no_windows",  #TODO(b/139745667)
         "notsan",  #TODO(b/139745667)
     ],
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index edb02b072b1..0063b7f155e 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1844,6 +1844,24 @@ class JacobianTest(test.TestCase):
     self.assertAllClose(compute_jacobian(use_pfor=True),
                         compute_jacobian(use_pfor=False))
 
+  def test_cond_func_grad_jacobian(self):
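+    # For x > 0, f(x) = x**3, so the gradient at x=1 is 3*1**2 = 3 and the
+    # second derivative (the jacobian of the gradient) is 6*1 = 6.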
+
+    @def_function.function
+    def f(x):
+      y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.)
+      return y
+
+    with backprop.GradientTape(persistent=True) as tape:
+      x = constant_op.constant(1.)
+      tape.watch(x)
+      y = f(x)
+      grad = tape.gradient(y, x)
+    self.assertAllClose(3., grad)
+    jacobian = tape.jacobian(grad, x, experimental_use_pfor=False)
+    self.assertAllClose(6., jacobian)
+    jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True)
+    self.assertAllClose(6., jacobian_pfor)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class BatchJacobianTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py
index e1c719d4a9d..96d709fe9c5 100644
--- a/tensorflow/python/eager/backprop_util.py
+++ b/tensorflow/python/eager/backprop_util.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -38,8 +39,12 @@ def _DTypeFromTensor(tensor):
         and handle_data.is_set
         and handle_data.shape_and_type):
       first_type = handle_data.shape_and_type[0].dtype
-      if all(shape_and_type.dtype == first_type
-             for shape_and_type in handle_data.shape_and_type):
+      # Some variants have statically unknown dtypes; we can't make inferences
+      # about trainability, so we conservatively assume they're trainable
+      # (which may waste memory passing zeros around, but will be correct).
+      if (first_type != types_pb2.DT_INVALID
+          and all(shape_and_type.dtype == first_type
+                  for shape_and_type in handle_data.shape_and_type)):
         return first_type
   return dtype
 
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 8c226f6c681..d41f5d736c8 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -243,7 +243,7 @@ struct FastPathOpExecInfo {
 
 #if PY_MAJOR_VERSION >= 3
 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
-PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
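+// PyLong_AsLongLong is required here: on platforms where long is only 32
+// bits wide, PyLong_AsLong overflows for int64 attribute values >= 2**31.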
+PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
 #else
 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
 #endif
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index 324a314a540..af0f27cef09 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -368,5 +368,19 @@ class Tests(test.TestCase):
 
     self.assertNotRegex(full_exception_text, "_FallbackException")
 
+  def testIntAttrThatDoesNotFitIn32Bits(self):
+    # Tests bug where int attributes >= 2**31 raised an exception on platforms
+    # where `long` is only 32 bits wide.
+    ctx = context.context()
+    ctx.ensure_initialized()
+    shape = constant_op.constant([10])
+    minval = constant_op.constant(0)
+    maxval = constant_op.constant(10)
+    seed = 2**50
+    pywrap_tfe.TFE_Py_FastPathExecute(ctx, "RandomUniformInt", None,
+                                      shape, minval, maxval,
+                                      "seed", seed)
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py
index b5275372447..19947121cd5 100644
--- a/tensorflow/python/framework/tensor_spec.py
+++ b/tensorflow/python/framework/tensor_spec.py
@@ -56,10 +56,6 @@ class DenseSpec(type_spec.TypeSpec):
     self._dtype = dtypes.as_dtype(dtype)
     self._name = name
 
-  @classmethod
-  def from_spec(cls, spec, name=None):
-    return cls(spec.shape, spec.dtype, name or spec.name)
-
   @property
   def shape(self):
     """Returns the `TensorShape` that represents the shape of the tensor."""
@@ -141,8 +137,34 @@ class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
     """
     return super(TensorSpec, self).is_compatible_with(spec_or_tensor)
 
+  @classmethod
+  def from_spec(cls, spec, name=None):
+    """Returns a `TensorSpec` with the same shape and dtype as `spec`.
+
+    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
+    >>> tf.TensorSpec.from_spec(spec, "NewName")
+    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')
+
+    Args:
+      spec: The `TypeSpec` used to create the new `TensorSpec`.
+      name: The name for the new `TensorSpec`.  Defaults to `spec.name`.
+    """
+    return cls(spec.shape, spec.dtype, name or spec.name)
+
   @classmethod
   def from_tensor(cls, tensor, name=None):
+    """Returns a `TensorSpec` that describes `tensor`.
+
+    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
+    TensorSpec(shape=(3,), dtype=tf.int32, name=None)
+
+    Args:
+      tensor: The `tf.Tensor` that should be described.
+      name: A name for the `TensorSpec`.  Defaults to `tensor.op.name`.
+
+    Returns:
+      A `TensorSpec` that describes `tensor`.
+    """
     if isinstance(tensor, ops.EagerTensor):
       return TensorSpec(tensor.shape, tensor.dtype, name)
     elif isinstance(tensor, ops.Tensor):
@@ -150,7 +172,10 @@ class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
     else:
       raise ValueError("`tensor` should be a tf.Tensor")
 
-  value_type = property(lambda self: ops.Tensor)
+  @property
+  def value_type(self):
+    """The Python type for values that are compatible with this TypeSpec."""
+    return ops.Tensor
 
   def _to_components(self, value):
     try:
@@ -263,6 +288,21 @@ class BoundedTensorSpec(TensorSpec):
 
   @classmethod
   def from_spec(cls, spec):
+    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.
+
+    If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to
+    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
+    `spec.dtype.min` and `spec.dtype.max`.
+
+    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
+    >>> BoundedTensorSpec.from_spec(spec)
+    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
+        minimum=array(-2147483648, dtype=int32),
+        maximum=array(2147483647, dtype=int32))
+
+    Args:
+      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
+    """
     dtype = dtypes.as_dtype(spec.dtype)
     minimum = getattr(spec, "minimum", dtype.min)
     maximum = getattr(spec, "maximum", dtype.max)
diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py
index ebfce25d6db..fa48ca8f952 100644
--- a/tensorflow/python/framework/type_spec.py
+++ b/tensorflow/python/framework/type_spec.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import abc
 import collections
 
+import re
 import numpy as np
 import six
 
@@ -56,9 +57,18 @@ class TypeSpec(object):
   For example, `tf.function`'s `input_signature` argument accepts a list
   (or nested structure) of `TypeSpec`s.
 
-  Creating new subclasses of TypeSpec (outside of TensorFlow core) is not
+  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
   currently supported.  In particular, we may make breaking changes to the
   private methods and properties defined by this base class.
+
+  Example:
+
+  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
+  >>> @tf.function(input_signature=[spec])
+  ... def double(x):
+  ...   return x * 2
+  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
+  <tf.RaggedTensor [[2, 4], [6]]>
   """
   # === Subclassing ===
   #
@@ -611,3 +621,73 @@ def register_type_spec_from_value_converter(type_object, converter_fn,
 
 
 _pywrap_utils.RegisterType("TypeSpec", TypeSpec)
+
+
+_TYPE_SPEC_TO_NAME = {}
+_NAME_TO_TYPE_SPEC = {}
+
+
+# Regular expression for valid TypeSpec names.
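+# Matches dotted names such as "my_project.MyTypeSpec"; rejects names without
+# a dot-separated prefix (e.g. "foo" or "hello world").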
+_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$")
+
+
+# TODO(b/173744905) tf_export this as "tf.register_type_spec".  (And add a
+# usage example to the docstring, once the API is public.)
+#
+# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than
+# TypeSpec (once we do refactoring to move to_components/from_components from
+# TypeSpec to ExtensionType).
+def register(name):
+  """Decorator used to register a globally unique name for a TypeSpec subclass.
+
+  Args:
+    name: The name of the type spec.  Must be globally unique.  Must have
+      the form `"{project_name}.{type_name}"`.  E.g. `"my_project.MyTypeSpec"`.
+
+  Returns:
+    A class decorator that registers the decorated class with the given name.
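+
+  Example (an illustrative sketch; `MaskedTensorSpec` is a hypothetical
+  user-defined subclass):
+
+  >>> @register("my_project.MaskedTensorSpec")  # doctest: +SKIP
+  ... class MaskedTensorSpec(TypeSpec):
+  ...   pass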
+  """
+  if not isinstance(name, str):
+    raise TypeError("Expected `name` to be a string; got %r" % (name,))
+  if not _REGISTERED_NAME_RE.match(name):
+    raise ValueError(
+        "Registered name must have the form '{project_name}.{type_name}' "
+        "(e.g. 'my_project.MyTypeSpec'); got %r." % name)
+
+  def decorator_fn(cls):
+    if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
+      raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
+    if cls in _TYPE_SPEC_TO_NAME:
+      raise ValueError("Class %s.%s has already been registered with name %s."
+                       % (cls.__module__, cls.__name__,
+                          _TYPE_SPEC_TO_NAME[cls]))
+    if name in _NAME_TO_TYPE_SPEC:
+      raise ValueError("Name %s has already been registered for class %s.%s."
+                       % (name, _NAME_TO_TYPE_SPEC[name].__module__,
+                          _NAME_TO_TYPE_SPEC[name].__name__))
+    _TYPE_SPEC_TO_NAME[cls] = name
+    _NAME_TO_TYPE_SPEC[name] = cls
+    return cls
+
+  return decorator_fn
+
+
+# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name)
+def get_name(cls):
+  """Returns the registered name for TypeSpec `cls`."""
+  if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
+    raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
+  if cls not in _TYPE_SPEC_TO_NAME:
+    raise ValueError("TypeSpec %s.%s has not been registered." %
+                     (cls.__module__, cls.__name__))
+  return _TYPE_SPEC_TO_NAME[cls]
+
+
+# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name)
+def lookup(name):
+  """Returns the TypeSpec that has been registered with name `name`."""
+  if not isinstance(name, str):
+    raise TypeError("Expected `name` to be a string; got %r" % (name,))
+  if name not in _NAME_TO_TYPE_SPEC:
+    raise ValueError("No TypeSpec has been registered with name %r" % (name,))
+  return _NAME_TO_TYPE_SPEC[name]
diff --git a/tensorflow/python/framework/type_spec_test.py b/tensorflow/python/framework/type_spec_test.py
index bcffd43ee6a..8007bf62dd2 100644
--- a/tensorflow/python/framework/type_spec_test.py
+++ b/tensorflow/python/framework/type_spec_test.py
@@ -47,6 +47,7 @@ class TwoTensors(object):
     self.color = color
 
 
+@type_spec.register("tf.TwoTensorsSpec")
 class TwoTensorsSpec(type_spec.TypeSpec):
   """A TypeSpec for the TwoTensors value type."""
 
@@ -97,6 +98,7 @@ class TwoComposites(object):
     self.color = color
 
 
+@type_spec.register("tf.TwoCompositesSpec")
 class TwoCompositesSpec(type_spec.TypeSpec):
   """A TypeSpec for the TwoTensors value type."""
 
@@ -349,5 +351,64 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertTrue(spec1.is_compatible_with(spec2))
     self.assertFalse(spec1.is_compatible_with(spec3))
 
+  def testRegistry(self):
+    self.assertEqual("tf.TwoCompositesSpec",
+                     type_spec.get_name(TwoCompositesSpec))
+    self.assertEqual("tf.TwoTensorsSpec", type_spec.get_name(TwoTensorsSpec))
+    self.assertEqual(TwoCompositesSpec,
+                     type_spec.lookup("tf.TwoCompositesSpec"))
+    self.assertEqual(TwoTensorsSpec, type_spec.lookup("tf.TwoTensorsSpec"))
+
+  def testRegistryTypeErrors(self):
+    with self.assertRaisesRegex(TypeError, "Expected `name` to be a string"):
+      type_spec.register(None)
+
+    with self.assertRaisesRegex(TypeError, "Expected `name` to be a string"):
+      type_spec.register(TwoTensorsSpec)
+
+    with self.assertRaisesRegex(TypeError, "Expected `cls` to be a TypeSpec"):
+      type_spec.register("tf.foo")(None)
+
+    with self.assertRaisesRegex(TypeError, "Expected `cls` to be a TypeSpec"):
+      type_spec.register("tf.foo")(ragged_tensor.RaggedTensor)
+
+  def testRegistryDuplicateErrors(self):
+    with self.assertRaisesRegex(
+        ValueError, "Name tf.TwoCompositesSpec has already been registered "
+        "for class __main__.TwoCompositesSpec."):
+
+      @type_spec.register("tf.TwoCompositesSpec")  # pylint: disable=unused-variable
+      class NewTypeSpec(TwoCompositesSpec):
+        pass
+
+    with self.assertRaisesRegex(
+        ValueError, "Class __main__.TwoCompositesSpec has already been "
+        "registered with name tf.TwoCompositesSpec"):
+      type_spec.register("tf.NewName")(TwoCompositesSpec)
+
+  def testRegistryNameErrors(self):
+    for bad_name in ["foo", "", "hello world"]:
+      with self.assertRaises(ValueError):
+        type_spec.register(bad_name)
+
+  def testRegistryLookupErrors(self):
+    with self.assertRaises(TypeError):
+      type_spec.lookup(None)
+    with self.assertRaisesRegex(
+        ValueError, "No TypeSpec has been registered with name 'foo.bar'"):
+      type_spec.lookup("foo.bar")
+
+  def testRegistryGetNameErrors(self):
+    with self.assertRaises(TypeError):
+      type_spec.get_name(None)
+
+    class Foo(TwoCompositesSpec):
+      pass
+
+    with self.assertRaisesRegex(
+        ValueError, "TypeSpec __main__.Foo has not been registered."):
+      type_spec.get_name(Foo)
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index c1da720c8ff..62dff46c44a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -1678,14 +1678,13 @@ def zeros_like(x, dtype=None, name=None):
 
   Example:
 
-
+  ```python
+  import numpy as np
   from tensorflow.keras import backend as K
   kvar = K.variable(np.random.random((2,3)))
   kvar_zeros = K.zeros_like(kvar)
   K.eval(kvar_zeros)
   # array([[ 0.,  0.,  0.], [ 0.,  0.,  0.]], dtype=float32)
-
-
+  ```
   """
   return array_ops.zeros_like(x, dtype=dtype, name=name)
 
@@ -4944,6 +4943,11 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   # activations cache logits on the `output` Tensor.
   if hasattr(output, '_keras_logits'):
     output = output._keras_logits  # pylint: disable=protected-access
+    if from_logits:
+      warnings.warn(
+          '`categorical_crossentropy` received `from_logits=True`, but '
+          'the `output` argument was produced by a sigmoid or softmax '
+          'activation and thus does not represent logits. Was this intended?')
     from_logits = True
 
   if from_logits:
@@ -4999,6 +5003,11 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   # activations cache logits on the `output` Tensor.
   if hasattr(output, '_keras_logits'):
     output = output._keras_logits  # pylint: disable=protected-access
+    if from_logits:
+      warnings.warn(
+          '`sparse_categorical_crossentropy` received `from_logits=True`, but '
+          'the `output` argument was produced by a sigmoid or softmax '
+          'activation and thus does not represent logits. Was this intended?')
     from_logits = True
   elif (not from_logits and
         not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
@@ -5081,6 +5090,11 @@ def binary_crossentropy(target, output, from_logits=False):
   # activations cache logits on the `output` Tensor.
   if hasattr(output, '_keras_logits'):
     output = output._keras_logits  # pylint: disable=protected-access
+    if from_logits:
+      warnings.warn(
+          '`binary_crossentropy` received `from_logits=True`, but the '
+          '`output` argument was produced by a sigmoid or softmax activation '
+          'and thus does not represent logits. Was this intended?')
     from_logits = True
 
   if from_logits:
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 8943afaf993..85131056de5 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import gc
+import warnings
 
 from absl.testing import parameterized
 import numpy as np
@@ -32,6 +33,7 @@ from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras.engine import input_layer
@@ -1735,6 +1737,45 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase):
     result = self.evaluate(backend.sparse_categorical_crossentropy(t, p))
     self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)
 
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_binary_crossentropy_from_logits_no_warnings(self):
+    t = backend.constant([[0, 1, 0]])
+    logits = backend.constant([[8., 1., 1.]])
+    with warnings.catch_warnings(record=True) as w:
+      self.evaluate(backend.binary_crossentropy(t, logits, from_logits=True))
+      self.assertEmpty(w)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_binary_crossentropy_from_logits_with_sigmoid(self):
+    t = backend.constant([[0, 1, 0]])
+    logits = backend.constant([[8., 1., 1.]])
+    p = activations.sigmoid(logits)
+    with warnings.catch_warnings(record=True) as w:
+      self.evaluate(backend.binary_crossentropy(t, p, from_logits=True))
+      self.assertLen(w, 1)
+      self.assertIn('received `from_logits=True`', str(w[0].message))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_categorical_crossentropy_from_logits_with_softmax(self):
+    t = backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p = activations.softmax(logits)
+    with warnings.catch_warnings(record=True) as w:
+      self.evaluate(backend.categorical_crossentropy(t, p, from_logits=True))
+      self.assertLen(w, 1)
+      self.assertIn('received `from_logits=True`', str(w[0].message))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_sparse_categorical_crossentropy_from_logits_with_softmax(self):
+    t = backend.constant([0, 1, 2])
+    logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p = activations.softmax(logits)
+    with warnings.catch_warnings(record=True) as w:
+      self.evaluate(
+          backend.sparse_categorical_crossentropy(t, p, from_logits=True))
+      self.assertLen(w, 1)
+      self.assertIn('received `from_logits=True`', str(w[0].message))
+
 
 @test_util.with_control_flow_v2
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
index ebc343158c0..0fc90150127 100644
--- a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
@@ -30,6 +30,23 @@ def _get_benchmark_name(name):
   return name.split("__")[-1].split("_")
 
 
+def _get_metadata(name):
+  return {
+      "model_name": "ideal_layers",
+      "parameters": name[1] + "_shape",
+  }
+
+
+def _generate_benchmark_params(*params_list):
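+  """Expands each benchmark param into _CPU- and _GPU-suffixed variants."""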
+  benchmark_params = []
+  for params in params_list:
+    benchmark_params.extend(
+        [((param[0] + "_CPU",) + param[1:]) for param in params])
+    benchmark_params.extend(
+        [((param[0] + "_GPU",) + param[1:]) for param in params])
+  return benchmark_params
+
+
 def _layer_call_backward(layer, x):
   with tf.GradientTape() as tape:
     y = layer(x)
@@ -46,7 +63,7 @@ class KerasLayerBenchmarks(six.with_metaclass(
   # the benchmark name. It must follow the convention of
   # "{layer_name}_{small|normal|large}_shape" to make it compatible with
   # `self.report_benchmark()` method.
-  _benchmark_parameters = [
+  _benchmark_parameters = _generate_benchmark_params([
       ("Conv2D_small_shape", tf.keras.layers.Conv2D,
        {"filters": 1, "kernel_size": 1, "activation": "relu"},
        (1, 1, 1, 1), 10000),
@@ -57,7 +74,7 @@ class KerasLayerBenchmarks(six.with_metaclass(
        {"units": 1}, (1, 1, 1), 10000),
       ("LSTM_normal_shape", tf.keras.layers.LSTM,
        {"units": 4}, (32, 10, 8), 10000),
-  ]
+  ])
 
   def benchmark_layer_call(self, layer_cls, layer_args, input_shape, num_iters):
     layer = layer_cls(**layer_args)
@@ -65,11 +82,8 @@ class KerasLayerBenchmarks(six.with_metaclass(
 
     fn = functools.partial(layer, x)
     name = _get_benchmark_name(self._get_name())
-    metadata = {
-        "model_name": "ideal_layers",
-        "implementation": name[0] + ".layer.call",
-        "parameters": name[1] + "_shape"
-    }
+    metadata = {"implementation": name[0] + ".layer.call"}
+    metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_with_function(
@@ -80,11 +94,8 @@ class KerasLayerBenchmarks(six.with_metaclass(
 
     fn = functools.partial(layer, x)
     name = _get_benchmark_name(self._get_name())
-    metadata = {
-        "model_name": "ideal_layers",
-        "implementation": name[0] + ".layer.call.function",
-        "parameters": name[1] + "_shape"
-    }
+    metadata = {"implementation": name[0] + ".layer.call.function"}
+    metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_with_xla(
@@ -96,11 +107,8 @@ class KerasLayerBenchmarks(six.with_metaclass(
 
     fn = functools.partial(layer, x)
     name = _get_benchmark_name(self._get_name())
-    metadata = {
-        "model_name": "ideal_layers",
-        "implementation": name[0] + ".layer.call.xla",
-        "parameters": name[1] + "_shape"
-    }
+    metadata = {"implementation": name[0] + ".layer.call.xla"}
+    metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_backward(
@@ -110,11 +118,8 @@ class KerasLayerBenchmarks(six.with_metaclass(
 
     fn = functools.partial(_layer_call_backward, layer, x)
     name = _get_benchmark_name(self._get_name())
-    metadata = {
-        "model_name": "ideal_layers",
-        "implementation": name[0] + ".layer.call.backward",
-        "parameters": name[1] + "_shape"
-    }
+    metadata = {"implementation": name[0] + ".layer.call.backward"}
+    metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_backward_with_function(
@@ -125,11 +130,8 @@ class KerasLayerBenchmarks(six.with_metaclass(
 
     fn = functools.partial(_layer_call_backward, layer, x)
     name = _get_benchmark_name(self._get_name())
-    metadata = {
-        "model_name": "ideal_layers",
-        "implementation": name[0] + ".layer.call.backward.function",
-        "parameters": name[1] + "_shape"
-    }
+    metadata = {"implementation": name[0] + ".layer.call.backward.function"}
+    metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
 
@@ -137,7 +139,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
     benchmark.ParameterizedBenchmark,
     layer_benchmarks_test_base.LayerBenchmarksBase)):
 
-  _benchmark_parameters = [
+  _benchmark_parameters = _generate_benchmark_params([
       ("Conv2D_small_shape", tf.keras.layers.Conv2D,
        {"filters": 1, "kernel_size": 1, "activation": "relu"},
        (1, 1, 1, 1), 10000),
@@ -149,7 +151,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
       #  {"units": 1}, (1, 1, 1), 10000),
       # ("LSTM_normal_shape", tf.keras.layers.LSTM,
       #  {"units": 4}, (32, 10, 8), 10000),
-  ]
+  ])
 
   def benchmark_layer_call_backward_with_xla(
       self, layer_cls, layer_args, input_shape, num_iters):
@@ -160,11 +162,8 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
 
     fn = functools.partial(_layer_call_backward, layer, x)
     name = _get_benchmark_name(self._get_name())
-    metadata = {
-        "model_name": "ideal_layers",
-        "implementation": name[0] + ".layer.call.backward.xla",
-        "parameters": name[1] + "_shape"
-    }
+    metadata = {"implementation": name[0] + ".layer.call.backward.xla"}
+    metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
 
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 43b5cf63677..34c0fc202db 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -328,6 +328,7 @@ py_library(
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/estimator:estimator_py",
diff --git a/tensorflow/python/keras/distribute/ctl_correctness_test.py b/tensorflow/python/keras/distribute/ctl_correctness_test.py
index 629bc2ac632..b51022f860e 100644
--- a/tensorflow/python/keras/distribute/ctl_correctness_test.py
+++ b/tensorflow/python/keras/distribute/ctl_correctness_test.py
@@ -33,7 +33,6 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import test_combinations as combinations
-from tensorflow.python.framework import test_util
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.keras.distribute import strategy_combinations
 from tensorflow.python.ops import math_ops
@@ -251,9 +250,6 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
     # TODO(anjs): Identify why this particular V1 optimizer needs a higher tol.
     if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__:
       self.skipTest('Reduced tolerance of the order of 1e-1 required.')
-    if ('CollectiveAllReduce' in type(distribution).__name__ and
-        test_util.is_xla_enabled()):
-      self.skipTest('XLA tests fail with MWMS.')
     self.dnn_correctness(distribution, optimizer_fn, iteration_type,
                          inside_func, sync_batchnorm)
 
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 69aa1c7e9eb..972eab4b999 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.distribute import values as ds_values_lib
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -2677,6 +2678,43 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
             loss=keras.losses.MeanSquaredError(),
             metrics=[keras.metrics.BinaryAccuracy()])
 
+  @ds_combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.mirrored_strategy_with_one_cpu,
+          mode=['eager']))
+  def test_optimizer(self, distribution):
+    temp_dir = os.path.join(self.get_temp_dir(), 'ckpt')
+
+    def create_model():
+      model = keras.models.Sequential([
+          keras.layers.Dense(1),
+      ])
+      model.compile(optimizer='adam', loss='mse')
+      model.build([None, 1])  # create weights.
+      self.assertEmpty(model.optimizer.weights)
+      return model
+
+    model = create_model()
+    x = y = array_ops.ones(shape=(1, 1))
+    model.fit(x=x, y=y, batch_size=1)
+    model.save_weights(temp_dir)
+
+    with distribution.scope():
+      model = create_model()
+      model.load_weights(temp_dir)
+      self.assertNotEmpty(model.optimizer.weights)
+      self.assertIsInstance(model.optimizer.weights[0],
+                            ds_values_lib.DistributedVariable)
+
+    with distribution.scope():
+      model = create_model()
+    # Creating/restoring slot variables outside of scope is fine.
+    model.load_weights(temp_dir)
+    self.assertNotEmpty(model.optimizer.weights)
+    self.assertIsInstance(model.optimizer.weights[0],
+                          ds_values_lib.DistributedVariable)
+
+
 if __name__ == '__main__':
   base_layer_utils.enable_v2_dtype_behavior()
   multi_process_runner.test_main()
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index 8427517f235..fea3ee16da7 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -2473,5 +2473,75 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase):
       model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
 
 
+class FunctionalSubclassModel(training_lib.Model):
+
+  def __init__(self, *args, **kwargs):
+    my_input = input_layer_lib.Input(shape=(16,))
+    dense = layers.Dense(32, activation='relu')
+    output = dense(my_input)
+    outputs = {'output': output}
+    super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs)
+
+
+class MixinClass(object):
+
+  def __init__(self, foo, **kwargs):
+    self._foo = foo
+    super().__init__(**kwargs)
+
+  def get_foo(self):
+    return self._foo
+
+
+class SubclassedModel(training_lib.Model):
+
+  def __init__(self, bar, **kwargs):
+    self._bar = bar
+    super().__init__(**kwargs)
+
+  def get_bar(self):
+    return self._bar
+
+
+class MultipleInheritanceModelTest(keras_parameterized.TestCase):
+
+  def testFunctionalSubclass(self):
+    m = FunctionalSubclassModel()
+    # Some smoke test for the weights and output shape of the model
+    self.assertLen(m.weights, 2)
+    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
+
+  def testFunctionalSubclassPreMixin(self):
+    class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel):
+      pass
+
+    m = MixedFunctionalSubclassModel(foo='123')
+    self.assertTrue(m._is_graph_network)
+    self.assertLen(m.weights, 2)
+    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
+    self.assertEqual(m.get_foo(), '123')
+
+  def testFunctionalSubclassPostMixin(self):
+    # Make sure the mixin class is also initialized correctly when the order
+    # is changed.
+
+    class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass):
+      pass
+
+    m = MixedFunctionalSubclassModel(foo='123')
+    self.assertTrue(m._is_graph_network)
+    self.assertLen(m.weights, 2)
+    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
+    self.assertEqual(m.get_foo(), '123')
+
+  def testSubclassModelPreMixin(self):
+    class MixedSubclassModel(MixinClass, SubclassedModel):
+      pass
+
+    m = MixedSubclassModel(foo='123', bar='456')
+    self.assertFalse(m._is_graph_network)
+    self.assertEqual(m.get_foo(), '123')
+    self.assertEqual(m.get_bar(), '456')
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 55f71e3a94c..3feb39172f4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -115,6 +115,10 @@ def inject_functional_model_class(cls):
   from tensorflow.python.keras.engine import training_v1  # pylint: disable=g-import-not-at-top
   if cls == Model or cls == training_v1.Model:
     return functional.Functional
+  # In the case of multiple inheritance, stop injecting the functional class
+  # once we reach `object`, since the Keras model is then no longer in the
+  # class hierarchy.
+  if cls == object:
+    return object
 
   cls.__bases__ = tuple(inject_functional_model_class(base)
                         for base in cls.__bases__)
@@ -230,8 +234,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
     if (is_functional_model_init_params(args, kwargs) and
         not isinstance(self, functional.Functional)):
+      # Filter the kwargs for multiple inheritance.
+      supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
+      model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
+      other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
       inject_functional_model_class(self.__class__)
-      functional.Functional.__init__(self, *args, **kwargs)
+      functional.Functional.__init__(self, *args, **model_kwargs)
+
+      # In case there is any multiple inheritance here, we need to call the
+      # __init__ for any class that appears after the Functional class.
+      clz_to_init = []
+      found_functional_class = False
+      for clz in self.__class__.__bases__:
+        if issubclass(clz, functional.Functional):
+          found_functional_class = True
+          continue
+        if found_functional_class:
+          clz_to_init.append(clz)
+
+      if clz_to_init:
+        for clz in clz_to_init:
+          clz.__init__(self, *args, **other_kwargs)
+      elif other_kwargs:
+        # If there are unused kwargs, raise an error to the user, in case they
+        # have a typo in the param name.
+        raise TypeError(
+            'The following keyword arguments aren\'t supported: {}'.format(
+                other_kwargs))
       return
 
     base_layer.keras_api_gauge.get_cell('Model subclass').set(True)
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 456b6758dc6..06e369dab21 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -408,7 +408,7 @@ class ReLU(Layer):
       raise ValueError('threshold of Relu layer '
                        'cannot be None. Required a float')
 
-    self.support_masking = True
+    self.supports_masking = True
     if max_value is not None:
       max_value = K.cast_to_floatx(max_value)
     self.max_value = max_value
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index 80eae8e72de..96a0c217cac 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -34,37 +34,44 @@ class AdvancedActivationsTest(keras_parameterized.TestCase):
     for alpha in [0., .5, -1.]:
       testing_utils.layer_test(keras.layers.LeakyReLU,
                                kwargs={'alpha': alpha},
-                               input_shape=(2, 3, 4))
+                               input_shape=(2, 3, 4),
+                               supports_masking=True)
 
   def test_prelu(self):
     testing_utils.layer_test(keras.layers.PReLU, kwargs={},
-                             input_shape=(2, 3, 4))
+                             input_shape=(2, 3, 4),
+                             supports_masking=True)
 
   def test_prelu_share(self):
     testing_utils.layer_test(keras.layers.PReLU,
                              kwargs={'shared_axes': 1},
-                             input_shape=(2, 3, 4))
+                             input_shape=(2, 3, 4),
+                             supports_masking=True)
 
   def test_elu(self):
     for alpha in [0., .5, -1.]:
       testing_utils.layer_test(keras.layers.ELU,
                                kwargs={'alpha': alpha},
-                               input_shape=(2, 3, 4))
+                               input_shape=(2, 3, 4),
+                               supports_masking=True)
 
   def test_thresholded_relu(self):
     testing_utils.layer_test(keras.layers.ThresholdedReLU,
                              kwargs={'theta': 0.5},
-                             input_shape=(2, 3, 4))
+                             input_shape=(2, 3, 4),
+                             supports_masking=True)
 
   def test_softmax(self):
     testing_utils.layer_test(keras.layers.Softmax,
                              kwargs={'axis': 1},
-                             input_shape=(2, 3, 4))
+                             input_shape=(2, 3, 4),
+                             supports_masking=True)
 
   def test_relu(self):
     testing_utils.layer_test(keras.layers.ReLU,
                              kwargs={'max_value': 10},
-                             input_shape=(2, 3, 4))
+                             input_shape=(2, 3, 4),
+                             supports_masking=True)
     x = keras.backend.ones((3, 4))
     if not context.executing_eagerly():
       # Test that we use `leaky_relu` when appropriate in graph mode.
@@ -80,7 +87,8 @@ class AdvancedActivationsTest(keras_parameterized.TestCase):
         ValueError, 'max_value of Relu layer cannot be negative value: -10'):
       testing_utils.layer_test(keras.layers.ReLU,
                                kwargs={'max_value': -10},
-                               input_shape=(2, 3, 4))
+                               input_shape=(2, 3, 4),
+                               supports_masking=True)
     with self.assertRaisesRegex(
         ValueError,
         'negative_slope of Relu layer cannot be negative value: -2'):
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 51dc5131a8a..6eb3e8e97ac 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -338,10 +338,11 @@ class MaxPooling2D(Pooling2D):
   window defined by `pool_size` for each dimension along the features axis.
   The window is shifted by `strides` in each dimension.  The resulting output
   when using "valid" padding option has a shape(number of rows or columns) of:
-  `output_shape = (input_shape - pool_size + 1) / strides)`
+  `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
+  (when input_shape >= pool_size)
 
   The resulting output shape when using the "same" padding option is:
-  `output_shape = input_shape / strides`
+  `output_shape = math.floor((input_shape - 1) / strides) + 1`
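+
+  For instance, with `input_shape=5`, `pool_size=2`, and `strides=2`:
+  "valid" padding yields `floor((5 - 2) / 2) + 1 = 2`, while "same" padding
+  yields `floor((5 - 1) / 2) + 1 = 3`.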
 
   For example, for stride=(1,1) and padding="valid":
 
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
index f3f689fda3a..8a1eb3dfbad 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
@@ -223,7 +223,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     if not reset_state:
       raise ValueError("IndexLookup does not support streaming adapts.")
     super(IndexLookup, self).adapt(data, reset_state)
-    self.max_tokens = self._table_handler.vocab_size()
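+    # Cast to a Python int, since the table handler may return a NumPy
+    # integer, which would not serialize cleanly (e.g. in `get_config`).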
+    self.max_tokens = int(self._table_handler.vocab_size())
 
   def get_vocabulary(self):
     if self._table_handler.vocab_size() == 0:
@@ -239,7 +239,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
       return [x for _, x in sorted(zip(values, keys))]
 
   def vocab_size(self):
-    return self._table_handler.vocab_size()
+    return int(self._table_handler.vocab_size())
 
   def get_config(self):
     config = {
@@ -402,7 +402,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
       self._set_inverse_vocabulary(vocab)
     else:
       self._set_forward_vocabulary(vocab)
-    self.max_tokens = self._table_handler.vocab_size()
+    self.max_tokens = int(self._table_handler.vocab_size())
 
   def _set_state_variables(self, updates):
     if not self.built:
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index c2accf24e58..4aaf159e11b 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -86,8 +86,8 @@ class Wrapper(Layer):
 class TimeDistributed(Wrapper):
   """This wrapper allows to apply a layer to every temporal slice of an input.
 
-  The input should be at least 3D, and the dimension of index one
-  will be considered to be the temporal dimension.
+  Every input should be at least 3D, and the dimension at index one of the
+  first input will be considered to be the temporal dimension.
 
   Consider a batch of 32 video samples, where each sample is a 128x128 RGB image
   with `channels_last` data format, across 10 timesteps.
@@ -109,7 +109,8 @@ class TimeDistributed(Wrapper):
     layer: a `tf.keras.layers.Layer` instance.
 
   Call arguments:
-    inputs: Input tensor.
+    inputs: Input tensor of shape (batch, time, ...), or nested tensors,
+      each of which has shape (batch, time, ...).
     training: Python boolean indicating whether the layer should behave in
       training mode or in inference mode. This argument is passed to the
       wrapped layer (only if the layer supports this argument).
@@ -141,7 +142,6 @@ class TimeDistributed(Wrapper):
 
     The static shapes are replaced with the corresponding dynamic shapes of the
     tensor.
-
     Arguments:
       init_tuple: a tuple, the first part of the output shape
       tensor: the tensor from which to get the (static and dynamic) shapes
@@ -150,7 +150,6 @@ class TimeDistributed(Wrapper):
         the static shape of the tensor
       int_shape: an alternative static shape to take as the last part
         of the output shape
-
     Returns:
       The new int_shape with the first part from init_tuple
       and the last part from either `int_shape` (if provided)
@@ -160,6 +159,8 @@ class TimeDistributed(Wrapper):
     # replace all None in int_shape by K.shape
     if int_shape is None:
       int_shape = K.int_shape(tensor)[start_idx:]
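+    # `int_shape` may now arrive as a TensorShape (e.g. from
+    # `compute_output_shape`); normalize it to a plain list first.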
+    if isinstance(int_shape, tensor_shape.TensorShape):
+      int_shape = int_shape.as_list()
     if not any(not s for s in int_shape):
       return init_tuple + tuple(int_shape)
     shape = K.shape(tensor)
@@ -169,39 +170,56 @@ class TimeDistributed(Wrapper):
         int_shape[i] = shape[start_idx + i]
     return init_tuple + tuple(int_shape)
 
+  def _remove_timesteps(self, dims):
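+    """Drops the time dimension (axis 1) from a `TensorShape`."""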
+    dims = dims.as_list()
+    return tensor_shape.TensorShape([dims[0]] + dims[2:])
+
   def build(self, input_shape):
-    input_shape = tensor_shape.TensorShape(input_shape).as_list()
-    if len(input_shape) < 3:
+    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
+    input_dims = nest.flatten(
+        nest.map_structure(lambda x: x.ndims, input_shape))
+    if any(dim < 3 for dim in input_dims):
       raise ValueError(
           '`TimeDistributed` Layer should be passed an `input_shape ` '
           'with at least 3 dimensions, received: ' + str(input_shape))
     # Don't enforce the batch or time dimension.
-    self.input_spec = InputSpec(shape=[None, None] + input_shape[2:])
-    child_input_shape = [input_shape[0]] + input_shape[2:]
+    self.input_spec = nest.map_structure(
+        lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]), input_shape)
+    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
+    child_input_shape = tf_utils.convert_shapes(child_input_shape)
     super(TimeDistributed, self).build(tuple(child_input_shape))
     self.built = True
 
   def compute_output_shape(self, input_shape):
-    input_shape = tensor_shape.TensorShape(input_shape).as_list()
-    child_input_shape = tensor_shape.TensorShape([input_shape[0]] +
-                                                 input_shape[2:])
+    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
+
+    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
     child_output_shape = self.layer.compute_output_shape(child_input_shape)
-    if not isinstance(child_output_shape, tensor_shape.TensorShape):
-      child_output_shape = tensor_shape.TensorShape(child_output_shape)
-    child_output_shape = child_output_shape.as_list()
-    timesteps = input_shape[1]
-    return tensor_shape.TensorShape([child_output_shape[0], timesteps] +
-                                    child_output_shape[1:])
+    child_output_shape = tf_utils.convert_shapes(
+        child_output_shape, to_tuples=False)
+    timesteps = tf_utils.convert_shapes(input_shape)
+    timesteps = nest.flatten(timesteps)[1]
+
+    def insert_timesteps(dims):
+      dims = dims.as_list()
+      return tensor_shape.TensorShape([dims[0], timesteps] + dims[1:])
+
+    return nest.map_structure(insert_timesteps, child_output_shape)
 
   def call(self, inputs, training=None, mask=None):
     kwargs = {}
     if generic_utils.has_arg(self.layer.call, 'training'):
       kwargs['training'] = training
 
-    input_shape = K.int_shape(inputs)
-    if input_shape[0] and not self._always_use_reshape:
+    input_shape = nest.map_structure(
+        lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs)
+    batch_size = tf_utils.convert_shapes(input_shape)
+    batch_size = nest.flatten(batch_size)[0]
+    if batch_size and not self._always_use_reshape:
       inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
       is_ragged_input = row_lengths is not None
+      input_length = tf_utils.convert_shapes(input_shape)
+      input_length = nest.flatten(input_length)[1]
 
       # batch size matters, use rnn-based implementation
       def step(x, _):
@@ -212,27 +230,44 @@ class TimeDistributed(Wrapper):
           step,
           inputs,
           initial_states=[],
-          input_length=row_lengths[0] if is_ragged_input else input_shape[1],
+          input_length=row_lengths[0] if is_ragged_input else input_length,
           mask=mask,
           unroll=False)
-      y = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
+      # pylint: disable=g-long-lambda
+      y = nest.map_structure(
+          lambda output: K.maybe_convert_to_ragged(is_ragged_input, output,
+                                                   row_lengths), outputs)
     else:
       # No batch size specified, therefore the layer will be able
       # to process batches of any size.
       # We can go with reshape-based implementation for performance.
-      if isinstance(inputs, ragged_tensor.RaggedTensor):
-        y = self.layer(inputs.values, **kwargs)
-        y = ragged_tensor.RaggedTensor.from_row_lengths(
-            y,
-            inputs.nested_row_lengths()[0])
+      is_ragged_input = nest.map_structure(
+          lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
+      is_ragged_input = nest.flatten(is_ragged_input)
+      if all(is_ragged_input):
+        input_values = nest.map_structure(lambda x: x.values, inputs)
+        input_row_lengths = nest.map_structure(
+            lambda x: x.nested_row_lengths()[0], inputs)
+        y = self.layer(input_values, **kwargs)
+        y = nest.map_structure(ragged_tensor.RaggedTensor.from_row_lengths, y,
+                               input_row_lengths)
+      elif any(is_ragged_input):
+        raise ValueError('All inputs have to be either ragged or not, '
+                         'but not mixed. You passed: {}'.format(inputs))
       else:
-        input_length = input_shape[1]
+        input_length = tf_utils.convert_shapes(input_shape)
+        input_length = nest.flatten(input_length)[1]
         if not input_length:
-          input_length = array_ops.shape(inputs)[1]
-        inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
+          input_length = nest.map_structure(lambda x: array_ops.shape(x)[1],
+                                            inputs)
+          input_length = generic_utils.to_list(nest.flatten(input_length))[0]
+
+        inner_input_shape = nest.map_structure(
+            lambda x: self._get_shape_tuple((-1,), x, 2), inputs)
         # Shape: (num_samples * timesteps, ...). And track the
         # transformation in self._input_map.
-        inputs = array_ops.reshape(inputs, inner_input_shape)
+        inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
+                                          inner_input_shape)
         # (num_samples * timesteps, ...)
         if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
           inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
@@ -241,15 +276,19 @@ class TimeDistributed(Wrapper):
         y = self.layer(inputs, **kwargs)
 
         # Shape: (num_samples, timesteps, ...)
-        output_shape = self.compute_output_shape(input_shape).as_list()
-        output_shape = self._get_shape_tuple((-1, input_length), y, 1,
-                                             output_shape[2:])
-        y = array_ops.reshape(y, output_shape)
+        output_shape = self.compute_output_shape(input_shape)
+        # pylint: disable=g-long-lambda
+        output_shape = nest.map_structure(
+            lambda tensor, int_shape: self._get_shape_tuple(
+                (-1, input_length), tensor, 1, int_shape[2:]), y, output_shape)
+        y = nest.map_structure_up_to(y, array_ops.reshape, y, output_shape)
         if not context.executing_eagerly():
           # Set the static shape for the result since it might be lost during
           # array_ops reshape, eg, some `None` dim in the result could be
           # inferred.
-          y.set_shape(self.compute_output_shape(input_shape))
+          nest.map_structure_up_to(
+              y, lambda tensor, shape: tensor.set_shape(shape), y,
+              self.compute_output_shape(input_shape))
 
     return y
 
@@ -290,9 +329,15 @@ class TimeDistributed(Wrapper):
     """
     # cases need to call the layer.compute_mask when input_mask is None:
     # Masking layer and Embedding layer with mask_zero
-    input_shape = K.int_shape(inputs)
-    if input_shape[0] and not self._always_use_reshape or isinstance(
-        inputs, ragged_tensor.RaggedTensor):
+    input_shape = nest.map_structure(
+        lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs)
+    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
+    batch_size = tf_utils.convert_shapes(input_shape)
+    batch_size = nest.flatten(batch_size)[0]
+    is_ragged_input = nest.map_structure(
+        lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
+    is_ragged_input = generic_utils.to_list(nest.flatten(is_ragged_input))
+    if batch_size and not self._always_use_reshape or any(is_ragged_input):
       # batch size matters, we currently do not handle mask explicitly, or if
       # the layer always uses reshape approach, or the input is a ragged tensor.
       return mask
@@ -300,8 +345,10 @@ class TimeDistributed(Wrapper):
     if inner_mask is not None:
       inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
       inner_mask = K.reshape(inner_mask, inner_mask_shape)
-    inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
-    inner_inputs = array_ops.reshape(inputs, inner_input_shape)
+    inner_input_shape = nest.map_structure(
+        lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs)
+    inner_inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
+                                            inner_input_shape)
     output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
     if output_mask is None:
       if mask is None:
@@ -313,9 +360,11 @@ class TimeDistributed(Wrapper):
         output_mask = K.any(output_mask, axis=-1)
     else:
       # output_mask is not None. We need to reshape it
-      input_length = input_shape[1]
+      input_length = tf_utils.convert_shapes(input_shape)
+      input_length = nest.flatten(input_length)[1]
       if not input_length:
-        input_length = K.shape(inputs)[1]
+        input_length = nest.map_structure(lambda x: K.shape(x)[1], inputs)
+        input_length = nest.flatten(input_length)[0]
       output_mask_int_shape = K.int_shape(output_mask)
       if output_mask_int_shape is None:
         # if the output_mask does not have a static shape,
@@ -323,6 +372,7 @@ class TimeDistributed(Wrapper):
         if mask is not None:
           output_mask_int_shape = K.int_shape(mask)
         else:
+          input_shape = generic_utils.to_list(nest.flatten(input_shape))[0]
           output_mask_int_shape = K.compute_output_shape(input_shape)[:-1]
       output_mask_shape = self._get_shape_tuple(
           (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
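
The rewritten `call` and `compute_mask` above route everything through `tf.nest`, so `TimeDistributed` can now wrap layers whose inputs and outputs are arbitrary structures, provided each structure is uniformly ragged or uniformly dense (a mix now fails fast with the `ValueError` added above). A minimal sketch of the ragged path, assuming a plain `Dense` inner layer; the data is illustrative, not from the patch:

```python
import tensorflow as tf

# A ragged batch of variable-length sequences. Per the code above, the
# wrapper applies the inner layer to `rt.values` and rebuilds the ragged
# structure from the saved row lengths.
rt = tf.ragged.constant([[[1., 2.], [3., 4.]], [[5., 6.]]], ragged_rank=1)
out = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))(rt)
print(out.row_lengths())  # same as rt.row_lengths(): [2, 1]
```

The multi-input/multi-output path is exercised by the `test_TimeDistributed_with_mimo` test added below.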
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index c60e950794f..d8c6d4ff220 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -471,6 +471,62 @@ class TimeDistributedTest(keras_parameterized.TestCase):
     # Make sure the batch dim is not lost after array_ops.reshape.
     self.assertListEqual(outputs.shape.as_list(), [1, None, 30, 30, 16])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_TimeDistributed_with_mimo(self):
+    dense_1 = keras.layers.Dense(8)
+    dense_2 = keras.layers.Dense(16)
+
+    class TestLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(TestLayer, self).__init__()
+        self.dense_1 = dense_1
+        self.dense_2 = dense_2
+
+      def call(self, inputs):
+        return self.dense_1(inputs[0]), self.dense_2(inputs[1])
+
+      def compute_output_shape(self, input_shape):
+        output_shape_1 = self.dense_1.compute_output_shape(input_shape[0])
+        output_shape_2 = self.dense_2.compute_output_shape(input_shape[1])
+        return output_shape_1, output_shape_2
+
+    np.random.seed(100)
+    layer = TestLayer()
+
+    data_1 = array_ops.constant([[[[1.0], [1.0]], [[2.0], [2.0]]],
+                                 [[[4.0], [4.0]], [[5.0], [5.0]]],
+                                 [[[7.0], [7.0]], [[8.0], [8.0]]]])
+
+    data_2 = array_ops.constant([[[[1.0], [1.0]], [[2.0], [2.0]]],
+                                 [[[4.0], [4.0]], [[5.0], [5.0]]],
+                                 [[[7.0], [7.0]], [[8.0], [8.0]]]])
+
+    x1 = keras.Input(shape=(None, 2, 1), dtype='float32')
+    x2 = keras.Input(shape=(None, 2, 1), dtype='float32')
+    y1, y2 = keras.layers.TimeDistributed(layer)([x1, x2])
+    model_1 = keras.models.Model([x1, x2], [y1, y2])
+    model_1.compile(
+        optimizer='rmsprop',
+        loss='mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    output_1 = model_1.predict((data_1, data_2), steps=1)
+
+    y1 = dense_1(x1)
+    y2 = dense_2(x2)
+    model_2 = keras.models.Model([x1, x2], [y1, y2])
+    output_2 = model_2.predict((data_1, data_2), steps=1)
+
+    self.assertAllClose(output_1, output_2)
+
+    model_1.fit(
+        x=[np.random.random((10, 2, 2, 1)),
+           np.random.random((10, 2, 2, 1))],
+        y=[np.random.random((10, 2, 2, 8)),
+           np.random.random((10, 2, 2, 16))],
+        epochs=1,
+        batch_size=3)
+
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class BidirectionalTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 1bd13d82834..a7d937344f4 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -855,8 +855,22 @@ class OptimizerV2(trackable.Trackable):
     """A list of names for this optimizer's slots."""
     return self._slot_names
 
-  def add_slot(self, var, slot_name, initializer="zeros"):
-    """Add a new slot variable for `var`."""
+  def add_slot(self, var, slot_name, initializer="zeros", shape=None):
+    """Add a new slot variable for `var`.
+
+    A slot variable is an additional variable associated with `var`, used
+    during training. It is allocated and managed by optimizers, e.g. `Adam`.
+
+    Args:
+      var: a `Variable` object.
+      slot_name: name of the slot variable.
+      initializer: initializer of the slot variable.
+      shape: (Optional) shape of the slot variable. If not set, it will
+        default to the shape of `var`.
+
+    Returns:
+      A slot variable.
+    """
     if slot_name not in self._slot_names:
       self._slot_names.append(slot_name)
     var_key = _var_key(var)
@@ -865,26 +879,29 @@ class OptimizerV2(trackable.Trackable):
     if weight is None:
       if isinstance(initializer, six.string_types) or callable(initializer):
         initializer = initializers.get(initializer)
+        slot_shape = var.shape if shape is None else shape
         initial_value = functools.partial(
-            initializer, shape=var.shape, dtype=var.dtype)
+            initializer, shape=slot_shape, dtype=var.dtype)
       else:
         initial_value = initializer
-      strategy = distribute_ctx.get_strategy()
-      if not strategy.extended.variable_created_in_scope(var):
-        raise ValueError(
-            "Trying to create optimizer slot variable under the scope for "
-            "tf.distribute.Strategy ({}), which is different from the scope "
-            "used for the original variable ({}). Make sure the slot "
-            "variables are created under the same strategy scope. This may "
-            "happen if you're restoring from a checkpoint outside the scope"
-            .format(strategy, var))
 
-      with strategy.extended.colocate_vars_with(var):
-        weight = tf_variables.Variable(
-            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
-            dtype=var.dtype,
-            trainable=False,
-            initial_value=initial_value)
+      with self._distribution_strategy_scope():
+        strategy = distribute_ctx.get_strategy()
+        if not strategy.extended.variable_created_in_scope(var):
+          raise ValueError(
+              "Trying to create optimizer slot variable under the scope for "
+              "tf.distribute.Strategy ({}), which is different from the scope "
+              "used for the original variable ({}). Make sure the slot "
+              "variables are created under the same strategy scope. This may "
+              "happen if you're restoring from a checkpoint outside the scope"
+              .format(strategy, var))
+
+        with strategy.extended.colocate_vars_with(var):
+          weight = tf_variables.Variable(
+              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
+              dtype=var.dtype,
+              trainable=False,
+              initial_value=initial_value)
       backend.track_variable(weight)
       slot_dict[slot_name] = weight
       self._restore_slot_variable(
@@ -1344,7 +1361,13 @@ class OptimizerV2(trackable.Trackable):
         # a slot variable if not for this case). Deferring is mostly harmless
         # (aside from double initialization), and makes variable creator scopes
         # behave the same way they do when graph building.
-        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
+        #
+        # One notable case is with distribution strategy, which uses variable
+        # creator scope but always desires the `variable` and the slot to use
+        # the same scope, thus we can safely eagerly create/restore slot
+        # variables.
+        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
+             self._distribution_strategy)):
       initializer = trackable.CheckpointInitialValueCallable(
           checkpoint_position=slot_variable_position)
       slot_variable = self.add_slot(
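
To illustrate the new `shape` argument to `add_slot`, here is a minimal sketch (run eagerly; the scalar "scale" slot is an invented example, not something `Adam` itself creates):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adam()
var = tf.Variable(tf.zeros([4, 8]))

m = opt.add_slot(var, "m")  # default: slot shape follows the variable
scale = opt.add_slot(var, "scale", initializer="ones", shape=[])  # scalar slot

print(m.shape, scale.shape)  # (4, 8) ()
```

Note also that slot creation now happens under `self._distribution_strategy_scope()`, so restoring slots from a checkpoint outside a `tf.distribute` scope no longer trips the strategy-mismatch check.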
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index fecf52e71b3..9045f0ec3de 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -100,7 +100,8 @@ def layer_test(layer_cls,
                validate_training=True,
                adapt_data=None,
                custom_objects=None,
-               test_harness=None):
+               test_harness=None,
+               supports_masking=None):
   """Test routine for a layer with a single input and single output.
 
   Arguments:
@@ -122,6 +123,8 @@ def layer_test(layer_cls,
       in the layer class. This is helpful for testing custom layers.
     test_harness: The Tensorflow test, if any, that this function is being
       called in.
+    supports_masking: Optional boolean to check the `supports_masking` property
+      of the layer. If None, the check will not be performed.
 
   Returns:
     The output data (Numpy array) returned by the layer, for additional
@@ -165,6 +168,13 @@ def layer_test(layer_cls,
   kwargs = kwargs or {}
   layer = layer_cls(**kwargs)
 
+  if (supports_masking is not None
+      and layer.supports_masking != supports_masking):
+    raise AssertionError(
+        'When testing layer %s, the `supports_masking` property is %r '
+        'but expected to be %r.\nFull kwargs: %s' %
+        (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs))
+
   # Test adapt, if data was passed.
   if adapt_data is not None:
     layer.adapt(adapt_data)
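
A sketch of how a layer test might use the new check; `layer_test` is an internal utility, and `Activation` is just a convenient layer that sets `supports_masking = True`:

```python
from tensorflow.python import keras
from tensorflow.python.keras import testing_utils

# Raises the AssertionError above if the property ever flips.
testing_utils.layer_test(
    keras.layers.Activation,
    kwargs={'activation': 'relu'},
    input_shape=(2, 4),
    supports_masking=True)
```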
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b6f3b6b8c82..8fa5e2b33f0 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -987,11 +987,13 @@ tf_py_test(
     ],
 )
 
-tf_py_test(
+cuda_py_test(
     name = "segment_reduction_ops_test",
     size = "medium",
     srcs = ["segment_reduction_ops_test.py"],
     shard_count = 10,
+    # TODO(b/173835746): the test fails with XLA.
+    xla_enable_strict_auto_jit = False,
     deps = [
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD
index c10c44e1ef4..21e11d3ad62 100644
--- a/tensorflow/python/kernel_tests/array_ops/BUILD
+++ b/tensorflow/python/kernel_tests/array_ops/BUILD
@@ -23,6 +23,9 @@ cuda_py_test(
     name = "unstack_op_test",
     size = "small",
     srcs = ["unstack_op_test.py"],
+    tags = [
+        "no_cuda_asan",  # b/173806679
+    ],
     xla_tags = [
         "no_cuda_asan",  # times out
     ],
@@ -54,6 +57,7 @@ cuda_py_test(
     name = "gather_op_test",
     size = "medium",
     srcs = ["gather_op_test.py"],
+    tags = ["no_cuda_asan"],  # b/173806733
     deps = [
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index fdaf3213759..bb6c04019e9 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -595,6 +595,7 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase):
               mode='eager'), device_combination))
   def testOpErrorNotAbortWithCollective(self, collective_op, device,
                                         communication):
+    self.skipTest('b/173733368: currently it may timeout on guitar.')
     # Do not abort v2 collective ops even if there're active collective ops at
     # the time of an op error. We rely cancellation to terminate active
     # collective ops.
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 2011b3b4b45..7ccf4b89e67 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
@@ -1296,6 +1297,24 @@ class CondV2Test(test.TestCase):
     i = f(constant_op.constant(False))
     self.assertEqual(self.evaluate(i), 2.0)
 
+  def testGradientOfMixedOptionals(self):
+
+    @def_function.function
+    def f(c):
+      x = constant_op.constant(1., name="x")
+
+      def then_branch():
+        return x ** 2., gen_dataset_ops.optional_from_value(
+            [constant_op.constant(1)])
+
+      def else_branch():
+        return x ** 3., gen_dataset_ops.optional_from_value(
+            [constant_op.constant(1.)])
+
+      y, _ = cond_v2.cond_v2(c, then_branch, else_branch)
+      return gradients_impl.gradients(y, x)
+    self.assertAllClose([2.], f(constant_op.constant(True)))
+
 
 class CondV2CollectionTest(test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 95483881629..81f105899f3 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -19,9 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
 
-from tensorflow.python import tf2
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -29,7 +27,6 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gradient_checker_v2
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
@@ -117,45 +114,19 @@ class ReluTest(test.TestCase):
           order="F")
       err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient(
           nn_ops.relu, [x], delta=1.0 / 1024))
-    self.assertLess(err, 1e-4)
+    self.assertLess(err, 1e-6)
 
-  # The gradient for fp16 is inaccurate due to the low-precision.
-  # We compare the fp16 analytical gradient against their fp32 counterpart.
+  # The gradient test for ReLU is a bit tricky because the derivative is not
+  # well defined around zero, so we keep the input values away from zero.
   def testGradientFloat16(self):
-
-    def grad(x):
-      with backprop.GradientTape() as tape:
-        tape.watch(x)
-        y = nn_ops.l2_loss(nn_ops.relu(x))
-      return tape.gradient(y, x)
-
-    def f():
-      with test_util.use_gpu():
-        # Randomly construct a 1D shape from [1, 40)
-        shape = random_ops.random_uniform([1],
-                                          minval=1,
-                                          maxval=40,
-                                          dtype=dtypes.int32)
-        x32 = random_ops.random_uniform(shape, minval=-1, maxval=1)
-        x16 = math_ops.cast(x32, dtype=dtypes.float16)
-        return grad(x32), grad(x16)
-
-    # We're going to ensure that the fp16 and fp32 gradients
-    # are "close" to each other for ~100 random values.
-    #
-    # In TensorFlow 1.x, invoking f() (without eager execution enabled)
-    # would construct a graph. Instead of construct a graph with O(100) nodes,
-    # we construct a single graph to be executed ~100 times in a Session.
-    if not tf2.enabled():
-      d32_tensor, d16_tensor = f()
-      with self.cached_session() as sess:
-        f = lambda: sess.run([d32_tensor, d16_tensor])
-
-    # Repeat the experiment for 100 times. All tensor shapes and its tensor
-    # values are randomly generated for each run.
-    for _ in xrange(100):
-      d32, d16 = f()
-      self.assertAllClose(d32, d16, atol=3e-4)
+    with self.cached_session():
+      x = np.asarray(
+          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+          dtype=np.float16,
+          order="F")
+      err = gradient_checker_v2.max_error(
+          *gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
+    self.assertLess(err, 1e-6)
 
   def testGradientFloat64(self):
     with self.cached_session():
@@ -165,7 +136,7 @@ class ReluTest(test.TestCase):
           order="F")
       err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient(
           nn_ops.relu, [x], delta=1.0 / 1024))
-    self.assertLess(err, 1e-10)
+    self.assertLess(err, 1e-15)
 
   def testGradGradFloat32(self):
     with self.cached_session():
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index e64e75bde0d..d5cd9d7bf43 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -1216,6 +1216,26 @@ class SummaryOpsTest(test_util.TensorFlowTestCase):
       # Reset to default state for other tests.
       summary_ops.set_step(None)
 
+  @test_util.run_v2_only
+  def testTrace_withProfiler(self):
+
+    @def_function.function
+    def f():
+      x = constant_op.constant(2)
+      y = constant_op.constant(3)
+      return x**y
+
+    assert context.executing_eagerly()
+    logdir = self.get_temp_dir()
+    writer = summary_ops.create_file_writer(logdir)
+    summary_ops.trace_on(graph=True, profiler=True)
+    profiler_outdir = self.get_temp_dir()
+    with writer.as_default():
+      f()
+      summary_ops.trace_export(
+          name='foo', step=1, profiler_outdir=profiler_outdir)
+    writer.close()
+
   @test_util.run_v2_only
   def testGraph_graph(self):
 
diff --git a/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py b/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py
index df397d449c3..d18f66f74e3 100644
--- a/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py
+++ b/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import template
@@ -29,28 +30,29 @@ from tensorflow.python.platform import test
 
 class TemplateMirroredStrategyTest(test.TestCase):
 
-  @test_util.run_deprecated_v1
   @test_util.disable_tfrt("Strategy not supported yet.")
   def test_merge_call(self):
-    if not test.is_gpu_available():
-      self.skipTest("No GPU available")
+    with ops.Graph().as_default():
+      # The test is testing a v1 only function.
+      if not test.is_gpu_available():
+        self.skipTest("No GPU available")
 
-    def fn():
-      var1 = variable_scope.get_variable(
-          "var1", shape=[], initializer=init_ops.constant_initializer(21.))
-      ds_context.get_replica_context().merge_call(lambda _: ())
-      var2 = variable_scope.get_variable(
-          "var2", shape=[], initializer=init_ops.constant_initializer(2.))
-      return var1 * var2
+      def fn():
+        var1 = variable_scope.get_variable(
+            "var1", shape=[], initializer=init_ops.constant_initializer(21.))
+        ds_context.get_replica_context().merge_call(lambda _: ())
+        var2 = variable_scope.get_variable(
+            "var2", shape=[], initializer=init_ops.constant_initializer(2.))
+        return var1 * var2
 
-    temp = template.make_template("my_template", fn)
+      temp = template.make_template("my_template", fn)
 
-    strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
-    out = strategy.experimental_local_results(
-        strategy.run(temp))
+      strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
+      out = strategy.experimental_local_results(
+          strategy.run(temp))
 
-    self.evaluate(variables.global_variables_initializer())
-    self.assertAllEqual([42., 42.], self.evaluate(out))
+      self.evaluate(variables.global_variables_initializer())
+      self.assertAllEqual([42., 42.], self.evaluate(out))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index f3902fb28f3..a161b4b1720 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -288,9 +288,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
       dx = _TapeFromGraphMode(x)
       theoretical, numerical = gradient_checker_v2.compute_gradient(
           target_function, [x])
-      self.assertAllClose(numerical, theoretical, rtol=1e-3)
-      self.assertAllClose(array_ops.reshape(numerical, []),
-                          dx, rtol=1e-3)
+      self.assertAllClose(numerical, theoretical, rtol=3e-3)
+      self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3)
 
   def testDeviceLabelsInherited(self):
     def _LoopBody(i, y):
diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py
index 5749be96033..15c670e439d 100644
--- a/tensorflow/python/ops/batch_ops_test.py
+++ b/tensorflow/python/ops/batch_ops_test.py
@@ -25,12 +25,17 @@ import numpy as np
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework.errors import InvalidArgumentError
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import batch_ops
 from tensorflow.python.ops import gen_batch_ops
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -50,7 +55,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that a single batched tensor executes together and only once."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
@@ -92,7 +97,7 @@ class BatchOpsTest(test.TestCase):
     """Test that batching with padding up to an allowed batch size works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -124,7 +129,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that multiple batched tensors execute together."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, _, _ = batch_ops.batch(
@@ -165,7 +170,7 @@ class BatchOpsTest(test.TestCase):
     """Tests illegally feeding tensors with different dim0 sizes."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
@@ -181,7 +186,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that batch and unbatch work together."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -207,7 +212,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       # TODO(apassos): Removing this line causes test flakiness! Ideally should
       # be investigated.
       default_inp = array_ops.placeholder_with_default(2, shape=[])  # pylint: disable=unused-variable
@@ -235,33 +240,62 @@ class BatchOpsTest(test.TestCase):
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
-      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
-      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+    with self.cached_session(use_gpu=True) as sess:
+      captured_inp0 = array_ops.placeholder_with_default(2., shape=[])
+      captured_inp1 = resource_variable_ops.ResourceVariable(3.)
+      with ops.device("/cpu:0"):
+        captured_inp2 = resource_variable_ops.ResourceVariable(4.)
 
       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
-        return in_t + captured_inp0 - captured_inp1
+        return in_t + captured_inp0 + captured_inp1 + captured_inp2
 
-      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       result = computation(inp)
       thread_results = []
 
       def worker():
         thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
 
+      sess.run(variables.global_variables_initializer())
       worker_thread = threading.Thread(target=worker)
       worker_thread.start()
       main_results = sess.run([result], feed_dict={inp: [2]})
       worker_thread.join()
-      self.assertEqual(thread_results[0], [2])
-      self.assertEqual(main_results[0], [3])
+      self.assertEqual(thread_results[0], [10])
+      self.assertEqual(main_results[0], [11])
+
+  @test_util.disable_xla("DeviceIndex returns sentinel value with XLA")
+  def testBatchDecoratedGpu(self):
+    if context.executing_eagerly():
+      return
+    with self.cached_session(use_gpu=True) as sess:
+
+      @batch_ops.batch_function(1, 10, 100000)
+      def computation(in_t):
+        # index is 0 on CPU and 1 on GPU
+        index = gen_functional_ops.DeviceIndex(device_names=["CPU", "GPU"])
+        return in_t + math_ops.cast(index, dtypes.float32)
+
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
+      result = computation(inp)
+      thread_results = []
+
+      def worker():
+        thread_results.extend(sess.run([result], feed_dict={inp: [10.]}))
+
+      worker_thread = threading.Thread(target=worker)
+      worker_thread.start()
+      main_results = sess.run([result], feed_dict={inp: [20.]})
+      worker_thread.join()
+      self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
+      self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])
 
   def testBatchFunctionOp(self):
     """Tests that the batch_function op works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
 
       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -292,7 +326,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that batch_function op works with captured input."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
       captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -328,7 +362,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that batch_function op works with error in the inputs."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
 
       @function.Defun(dtypes.int32, dtypes.int32)
@@ -345,8 +379,9 @@ class BatchOpsTest(test.TestCase):
           captured_tensors=computation.captured_inputs,
           Tout=[o.type for o in computation.definition.signature.output_arg])
 
-      with self.assertRaisesRegex(InvalidArgumentError,
-                                  ".*2 arguments.*but 1.*"):
+      with self.assertRaisesRegex(
+          InvalidArgumentError,
+          r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed"):
         sess.run([result], feed_dict={inp: [2]})
 
   def testBatchFunctionOpWithLargeBatchSplitted(self):
@@ -354,7 +389,7 @@ class BatchOpsTest(test.TestCase):
     if context.executing_eagerly():
       return
 
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
 
       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -408,7 +443,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
 
       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
@@ -432,7 +467,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that the unbatch timeout works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 282eda948db..92c09f7dcb7 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -25,6 +25,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
@@ -829,14 +830,22 @@ def _copy_handle_data(external_tensors, *branch_graph_outputs):
       internal_handle_data.append(handle_data)
     else:  # There is handle data, so we need to combine it.
       combined_shape = tensor_shape.TensorShape(None)
+      combined_dtype = None
       for handle_data in internal_handle_data:
         handle_shape = tensor_shape.TensorShape(
             handle_data.shape_and_type[0].shape)
         combined_shape = combined_shape.most_specific_compatible_shape(
             handle_shape)
+        if combined_dtype is None:
+          combined_dtype = handle_data.shape_and_type[0].dtype
+        elif handle_data.shape_and_type[0].dtype != combined_dtype:
+          # Variants from different branches have different dtypes. The
+          # combined variant has no static dtype.
+          combined_dtype = types_pb2.DT_INVALID
       combined_handle_data = internal_handle_data[0]
       combined_handle_data.shape_and_type[0].shape.CopyFrom(
           combined_shape.as_proto())
+      combined_handle_data.shape_and_type[0].dtype = combined_dtype
       handle_data_util.set_handle_data(external, combined_handle_data)
 
 
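The dtype merge matters when `tf.cond` branches return variants holding different element dtypes; the `testGradientOfMixedOptionals` case added earlier in this patch exercises exactly this. The same scenario, restated with public `tf.cond` (which lowers to `cond_v2` inside `tf.function`):

```python
import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops

@tf.function
def f(pred):
  x = tf.constant(1.)

  def then_branch():
    # The optional holds an int32 value here...
    return x ** 2., gen_dataset_ops.optional_from_value([tf.constant(1)])

  def else_branch():
    # ...and a float32 value here, so the merged variant keeps no static
    # dtype (DT_INVALID) rather than wrongly inheriting one branch's dtype.
    return x ** 3., gen_dataset_ops.optional_from_value([tf.constant(1.)])

  y, _ = tf.cond(pred, then_branch, else_branch)
  return tf.gradients(y, x)

print(f(tf.constant(True)))  # [2.0]
```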
diff --git a/tensorflow/python/ops/gradient_checker_v2.py b/tensorflow/python/ops/gradient_checker_v2.py
index 3ca0903c80c..ce5a4f76678 100644
--- a/tensorflow/python/ops/gradient_checker_v2.py
+++ b/tensorflow/python/ops/gradient_checker_v2.py
@@ -292,7 +292,7 @@ def _compute_gradient_list(f, xs, delta):
 
 
 @tf_export("test.compute_gradient", v1=[])
-def compute_gradient(f, x, delta=1e-3):
+def compute_gradient(f, x, delta=None):
   """Computes the theoretical and numeric Jacobian of `f`.
 
   With y = f(x), computes the theoretical and numeric Jacobian dy/dx.
@@ -329,6 +329,12 @@ def compute_gradient(f, x, delta=1e-3):
     raise ValueError(
         "`x` must be a list or tuple of values convertible to a Tensor "
         "(arguments to `f`), not a %s" % type(x))
+  if delta is None:
+    # By default, we use a step size for the central finite difference
+    # approximation that is exactly representable as a binary floating
+    # point number, since this reduces the amount of noise due to rounding
+    # in the approximation of some functions.
+    delta = 1.0 / 1024
   return _compute_gradient_list(f, x, delta)
 
 
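For context, 1/1024 is 2**-10, so `x + delta` and `x - delta` are computed without rounding error of their own; only the function evaluations contribute noise to the central difference. A quick sketch with the exported wrapper (the tolerance is illustrative):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([[-0.9, -0.5, 0.1], [0.3, 0.7, 0.9]])

# delta now defaults to 1/1024; no need to pass it explicitly.
theoretical, numerical = tf.test.compute_gradient(tf.nn.relu, [x])
assert np.max(np.abs(theoretical[0] - numerical[0])) < 1e-6
```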
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 28a5a9f8509..ac4ba701713 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -5137,6 +5137,7 @@ def non_max_suppression_padded(boxes,
     selected_indices = tf.slice(
         selected_indices_padded, tf.constant([0]), num_valid)
     selected_boxes = tf.gather(boxes, selected_indices)
+    ```
 
   Args:
     boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4].
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 145c2b0195c..63773ee0f95 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -458,7 +458,7 @@ class DatasetInitializer(TableInitializerBase):
   """
 
   def __init__(self, dataset):
-    """Creates a table initializser from a `tf.data.Dataset`.
+    """Creates a table initializer from a `tf.data.Dataset`.
 
     Args:
       dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
@@ -633,8 +633,7 @@ class TextFileInitializer(TableInitializerBase):
   >>> init = tf.lookup.TextFileInitializer(
   ...   filename=f.name,
   ...   key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
-  ...   value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
-  ...   delimiter=" ")
+  ...   value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
   >>> table = tf.lookup.StaticHashTable(init, -1)
   >>> table.lookup(tf.constant('palmer 30')).numpy()
   2
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index f641687e990..b7650846d33 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_list_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker_v2
@@ -1159,6 +1160,21 @@ class TensorListTest(PForTestCase):
     self._test_loop_fn(loop_fn, 2)
 
 
+class OptionalTest(PForTestCase):
+
+  def test_optional_from_value(self):
+
+    def loop_fn(i):
+      o = gen_dataset_ops.optional_from_value(
+          [i, i + 1, constant_op.constant(3)])
+      gen_dataset_ops.optional_none()
+      return gen_dataset_ops.optional_get_value(
+          o, [dtypes.int32, dtypes.int32, dtypes.int32],
+          [[], [], []])
+
+    self._test_loop_fn(loop_fn, 2)
+
+
 class StackTest(PForTestCase):
 
   @test_util.run_v1_only("b/122612051")
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 2b02d5e30d3..c9431fa8fa7 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -46,6 +46,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import gen_list_ops
@@ -83,22 +84,45 @@ def _variant_handle_data(t):
   handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
   if not handle_data.is_set:
     return None
-  if len(handle_data.shape_and_type) != 1:
-    raise ValueError("Expected handle data of length 1, got {!r} of length {}"
-                     .format(handle_data, len(handle_data.shape_and_type)))
-  return handle_data.shape_and_type[0]
+  return handle_data.shape_and_type
 
 
-def _is_tensor_list(t):
-  """True if `t` is a TensorList, False if it isn't, None if unknown."""
+def _is_variant_with_internal_stacking(t):
+  """Identifies variant tensors which pfor always maintains as scalars.
+
+  For these, the pfor tensor is recorded as "stacked" if the contents of the
+  variant tensor (e.g. the elements of a TensorList) are all stacked.
+
+  Args:
+    t: A tensor to identify.
+
+  Returns:
+    True if `t` is a TensorList/Optional, False if it is not, None if unknown.
+  """
   if t.dtype != dtypes.variant:
     return False
-  shape_and_type = _variant_handle_data(t)
-  if shape_and_type is None:
-    # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
-    # can make this an error instead of assuming TensorLists have handle data.
-    return None  # Presumed not a TensorList
-  return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
+    # make this an error instead of assuming TensorLists have handle data.
+    return None  # Presumed not a TensorList/Optional
+  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
+          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)
+
+
+def _parse_variant_shapes_and_types(t):
+  """Extracts shape and dtype information from a variant tensor `t`."""
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    raise ValueError("Required handle data not set for {!r}".format(t))
+  if shapes_and_types[0].specialized_type == types_pb2.ST_INVALID:
+    raise ValueError(
+        "Attempted to stack a variant-dtype tensor with no type set ({!r})"
+        .format(t))
+  return shapes_and_types
 
 
 def _stack(t, length):
@@ -109,23 +133,19 @@ def _stack(t, length):
   # suitable since operations on stacked handles may expect a vectorized version
   # of the variant.
   if t.dtype == dtypes.variant:
-    shape_and_type = _variant_handle_data(t)
-    if shape_and_type is None:
-      raise ValueError("Required handle data not set for {!r}".format(t))
-    if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
+    shapes_and_types = _parse_variant_shapes_and_types(t)
+    if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+      if len(shapes_and_types) != 1:
+        raise ValueError(
+            "Expected handle data of length 1, got {!r} of length {}"
+            .format(shapes_and_types, len(shapes_and_types)))
       return wrap(
-          _stack_tensor_list(t, shape_and_type.dtype, length),
+          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
           True)
     else:
-      if shape_and_type.specialized_type != types_pb2.ST_INVALID:
-        raise ValueError(
-            ("Attempted to stack an unhandled variant-dtype tensor of "
-             "type {!r} ({!r})").format(
-                 shape_and_type.specialized_type, t))
-      else:
-        raise ValueError(
-            "Attempted to stack a variant-dtype tensor with no type set ({!r})"
-            .format(t))
+      raise ValueError(
+          ("Attempted to stack an unhandled variant-dtype tensor of "
+           "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
   ones = array_ops.ones_like(array_ops.shape(t))
   ones = array_ops.reshape(ones, [-1])
   length = array_ops.reshape(length, [-1])
@@ -1629,7 +1649,7 @@ class PFor(object):
                 else:
                   batch_dim = tensor_shape.TensorShape(loop_len)
                 output_shape = batch_dim.concatenate(output_shape)
-              if _is_tensor_list(new_output.t):
+              if _is_variant_with_internal_stacking(new_output.t):
                 new_output.t.set_shape([])
               else:
                 new_output.t.set_shape(output_shape)
@@ -3602,7 +3622,7 @@ def _stack_tensor_list_shape(shape, first_dim):
 
 def _tile_variant_with_length(t, length):
   """stacks `t` `length` times."""
-  if _is_tensor_list(t):
+  if _is_variant_with_internal_stacking(t):
     # The content of TensorLists is vectorized, not the variant itself.
     return t
   original_tensor = t
@@ -3622,16 +3642,41 @@ def _tile_variant(t, pfor_input):
 
 
 def _untile_variant(t):
-  if _is_tensor_list(t):
+  if _is_variant_with_internal_stacking(t):
     # The content of TensorLists is vectorized, not the variant itself.
     if not t.shape.is_compatible_with([]):
       raise AssertionError(
-          "Unexpectedly saw a TensorList with non-scalar shape: {!r}"
-          .format(t))
+          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
+           "non-scalar shape: {!r}").format(t))
     return t
   return array_ops.gather(t, 0)
 
 
+@RegisterPFor("OptionalFromValue")
+def _convert_optional_from_value(pfor_input):
+  pfor_input.stack_inputs()
+  return wrap(
+      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
+      True)
+
+
+@RegisterPFor("OptionalGetValue")
+def _convert_optional_get_value(pfor_input):
+  handle = pfor_input.stacked_input(0)
+  output_types = pfor_input.get_attr("output_types")
+  original_output_shapes = pfor_input.get_attr("output_shapes")
+  output_shapes = []
+  for shape in original_output_shapes:
+    shape = tensor_shape.TensorShape(shape)
+    loop_len_shape = tensor_shape.TensorShape(
+        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
+    shape = loop_len_shape.concatenate(shape)
+    output_shapes.append(shape.as_proto())
+  results = gen_dataset_ops.optional_get_value(handle, output_types,
+                                               output_shapes)
+  return [wrap(t, True) for t in results]
+
+
 @RegisterPFor("TensorListReserve")
 def _convert_tensor_list_reserve(pfor_input):
   element_shape = pfor_input.unstacked_input(0)
@@ -4275,7 +4320,7 @@ class WhileV2(object):
       shape = shape.merge_with(output_shapes[i])
       pfor_input = self._pfor_input.input(i)
       if pfor_input.is_stacked:
-        if _is_tensor_list(pfor_input.t):
+        if _is_variant_with_internal_stacking(pfor_input.t):
           shape = tensor_shape.TensorShape([]).concatenate(shape)
         else:
           shape = tensor_shape.TensorShape([None]).concatenate(shape)
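
With the `OptionalFromValue`/`OptionalGetValue` converters registered above, `tf.vectorized_map` can now trace through per-example optionals. A minimal sketch, assuming the public `tf.experimental.Optional` wrapper (which emits those same ops):

```python
import tensorflow as tf

def per_example(x):
  opt = tf.experimental.Optional.from_value(x + 1.)
  return opt.get_value()

# The optional stays a scalar variant whose *contents* are stacked,
# matching _is_variant_with_internal_stacking above.
print(tf.vectorized_map(per_example, tf.constant([1., 2., 3.])))
# tf.Tensor([2. 3. 4.], shape=(3,), dtype=float32)
```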
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 53f6a2b0492..8575cdf3da5 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -545,7 +545,31 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
     `tf.compat.v1.py_func()` and you must pin the created operation to a device
     in that
     server (e.g. using `with tf.device():`).
+
+  Note: this op produces tensors of unknown shape and rank, since shape
+    inference does not run on arbitrary Python code. If you need the shape,
+    set it from statically available information, e.g.:
+
+    ```python
+    import tensorflow as tf
+    import numpy as np

+    def make_synthetic_data(i):
+      return np.cast[np.uint8](i) * np.ones([20, 256, 256, 3],
+                                             dtype=np.float32) / 10.
+
+    def preprocess_fn(i):
+      ones = tf.py_function(make_synthetic_data, [i], tf.float32)
+      ones.set_shape(tf.TensorShape([None, None, None, None]))
+      ones = tf.image.resize(ones, [224, 224])
+      return ones
+
+    ds = tf.data.Dataset.range(10)
+    ds = ds.map(preprocess_fn)
+    ```
+
   Args:
     func: A Python function, which accepts `ndarray` objects as arguments and
       returns a list of `ndarray` objects (or a single `ndarray`). This function
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 3e3751e3ca6..5f6b6453e15 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -111,6 +111,15 @@ def from_dense(tensor, name=None):
   Only elements not equal to zero will be present in the result. The resulting
   `SparseTensor` has the same dtype and shape as the input.
 
+  >>> sp = tf.sparse.from_dense([0, 0, 3, 0, 1])
+  >>> sp.shape.as_list()
+  [5]
+  >>> sp.values.numpy()
+  array([3, 1], dtype=int32)
+  >>> sp.indices.numpy()
+  array([[2],
+         [4]])
+
   Args:
     tensor: A dense `Tensor` to be converted to a `SparseTensor`.
     name: Optional name for the op.
@@ -1602,23 +1611,24 @@ def sparse_tensor_to_dense(sp_input,
                            name=None):
   """Converts a `SparseTensor` into a dense tensor.
 
-  This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s.
+  For this sparse tensor with three non-empty values:
 
-  For example, if `sp_input` has shape `[3, 5]` and non-empty string values:
+  >>> sp_input = tf.SparseTensor(
+  ...   dense_shape=[3, 5],
+  ...   values=[7, 8, 9],
+  ...   indices =[[0, 1],
+  ...             [0, 3],
+  ...             [2, 0]])
 
-      [0, 1]: a
-      [0, 3]: b
-      [2, 0]: c
+  The output will be a dense `[3, 5]` tensor with values:
 
-  and `default_value` is `x`, then the output will be a dense `[3, 5]`
-  string tensor with values:
+  >>> tf.sparse.to_dense(sp_input).numpy()
+  array([[0, 7, 0, 8, 0],
+         [0, 0, 0, 0, 0],
+         [9, 0, 0, 0, 0]], dtype=int32)
 
-      [[x a x b x]
-       [x x x x x]
-       [c x x x x]]
-
-  Indices must be without repeats.  This is only
-  tested if `validate_indices` is `True`.
+  Note: Indices must be without repeats. This is only tested if
+  `validate_indices` is `True`.
 
   Args:
     sp_input: The input `SparseTensor`.
diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD
index 33834f0e914..81e8b37dc7d 100644
--- a/tensorflow/python/ops/structured/BUILD
+++ b/tensorflow/python/ops/structured/BUILD
@@ -18,17 +18,60 @@ py_library(
     srcs_version = "PY2AND3",
     tags = ["nofixdeps"],
     deps = [
+        ":structured_array_ops",
         ":structured_tensor",
     ],
 )
 
 py_library(
     name = "structured_tensor",
-    srcs = ["structured_tensor.py"],
+    srcs = [
+        "structured_array_ops.py",
+        "structured_tensor.py",
+    ],
     deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:composite_tensor",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:type_spec",
+        "//tensorflow/python:util",
+        "//tensorflow/python/ops/ragged:ragged_factory_ops",
         "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/ops/ragged:row_partition",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
+    name = "structured_array_ops",
+    srcs = [
+        "structured_array_ops.py",
+    ],
+    deps = [
+        ":structured_tensor",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:composite_tensor",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:type_spec",
+        "//tensorflow/python:util",
+        "//tensorflow/python/ops/ragged:ragged_factory_ops",
+        "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/ops/ragged:row_partition",
+        "//third_party/py/numpy",
     ],
 )
 
@@ -37,13 +80,23 @@ py_test(
     srcs = ["structured_tensor_test.py"],
     python_version = "PY3",
     deps = [
+        ":structured_array_ops",
         ":structured_tensor",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python/ops/ragged:ragged_factory_ops",
         "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/ops/ragged:row_partition",
+        "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py
new file mode 100644
index 00000000000..dca8084575e
--- /dev/null
+++ b/tensorflow/python/ops/structured/structured_array_ops.py
@@ -0,0 +1,157 @@
+# Lint as python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""StructuredTensor array ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.ragged.row_partition import RowPartition
+from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
+from tensorflow.python.util import deprecation
+from tensorflow.python.util import dispatch
+
+
+@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
+@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
+def expand_dims(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
+  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
+
+  This is an implementation of tf.expand_dims for StructuredTensor. Note
+  that the `axis` must be less than or equal to rank.
+
+  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
+  >>> tf.expand_dims(st, 0).to_pyval()
+  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
+  >>> tf.expand_dims(st, 1).to_pyval()
+  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
+  >>> tf.expand_dims(st, 2).to_pyval()
+  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
+  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
+  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
+
+  Args:
+    input: the original StructuredTensor.
+    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
+    name: the name of the op.
+    dim: deprecated: use axis.
+
+  Returns:
+    a new structured tensor with larger rank.
+
+  Raises:
+    ValueError: if `axis < -(rank + 1)` or `rank < axis`.
+  """
+  axis = deprecation.deprecated_argument_lookup('axis', axis, 'dim', dim)
+  return _expand_dims_impl(input, axis, name=name)
+
+
+@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
+def expand_dims_v2(input, axis, name=None):  # pylint: disable=redefined-builtin
+  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
+
+  This is an implementation of tf.expand_dims for StructuredTensor. Note
+  that the `axis` must be less than or equal to rank.
+
+  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
+  >>> tf.expand_dims(st, 0).to_pyval()
+  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
+  >>> tf.expand_dims(st, 1).to_pyval()
+  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
+  >>> tf.expand_dims(st, 2).to_pyval()
+  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
+  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
+  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
+
+  Args:
+    input: the original StructuredTensor.
+    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
+    name: the name of the op.
+
+  Returns:
+    a new structured tensor with larger rank.
+
+  Raises:
+    ValueError: if `axis < -(rank + 1)` or `rank < axis`.
+  """
+  return _expand_dims_impl(input, axis, name=name)
+
+
+def _expand_dims_impl(st, axis, name=None):  # pylint: disable=redefined-builtin
+  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
+
+  This is an implementation of tf.expand_dims for StructuredTensor. Note
+  that the `axis` must be less than or equal to rank.
+
+  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
+  >>> tf.expand_dims(st, 0).to_pyval()
+  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
+  >>> tf.expand_dims(st, 1).to_pyval()
+  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
+  >>> tf.expand_dims(st, 2).to_pyval()
+  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
+  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
+  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
+
+  Args:
+    st: the original StructuredTensor.
+    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
+    name: the name of the op.
+
+  Returns:
+    a new structured tensor with larger rank.
+
+  Raises:
+    an error if `axis < -(rank + 1)` or `rank < axis`.
+  """
+  axis = array_ops.get_positive_axis(
+      axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
+  with ops.name_scope(name, 'ExpandDims', [st, axis]):
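+    # Expand each field tensor along `axis`, then rebuild the shape, row
+    # partitions, and nrows of the result to describe the inserted axis.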
+    new_fields = {
+        k: array_ops.expand_dims(v, axis)
+        for (k, v) in st._fields.items()
+    }
+    new_shape = st.shape[:axis] + (1,) + st.shape[axis:]
+    new_row_partitions = _expand_st_row_partitions(st, axis)
+    new_nrows = st.nrows() if (axis > 0) else 1
+    return StructuredTensor.from_fields(
+        new_fields,
+        shape=new_shape,
+        row_partitions=new_row_partitions,
+        nrows=new_nrows)
+
+
+def _expand_st_row_partitions(st, axis):
+  """Create the row_partitions for expand_dims."""
+  if axis == 0:
+    if st.shape.rank == 0:
+      return ()
+    nvals = st.nrows()
+    new_partition = RowPartition.from_uniform_row_length(
+        nvals, nvals, nrows=1, validate=False)
+    return (new_partition,) + st.row_partitions
+  elif axis == st.rank:
+    nvals = (
+        st.row_partitions[axis - 2].nvals() if (axis - 2 >= 0) else st.nrows())
+    return st.row_partitions + (RowPartition.from_uniform_row_length(
+        1, nvals, nrows=nvals, validate=False),)
+  else:
+    nvals = (
+        st.row_partitions[axis - 1].nrows() if (axis - 1 >= 0) else st.nrows())
+    return st.row_partitions[:axis - 1] + (RowPartition.from_uniform_row_length(
+        1, nvals, nrows=nvals, validate=False),) + st.row_partitions[axis - 1:]
diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py
index 28acfbb3304..9e064bb9dcd 100644
--- a/tensorflow/python/ops/structured/structured_tensor_test.py
+++ b/tensorflow/python/ops/structured/structured_tensor_test.py
@@ -36,6 +36,10 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import row_partition
+
+# TODO(b/173144447): remove when structured_array_ops is included in init.
+from tensorflow.python.ops.structured import structured_array_ops  # pylint: disable=unused-import
+
 from tensorflow.python.ops.structured import structured_tensor
 from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
 from tensorflow.python.platform import googletest
@@ -500,8 +504,6 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
     self.assertAllEqual(struct4.field_value("s"), struct2)
 
   def testPartitionOuterDims(self):
-    if not context.executing_eagerly():
-      return  # TESTING
     a = dict(x=1, y=[1, 2])
     b = dict(x=2, y=[3, 4])
     c = dict(x=3, y=[5, 6])
@@ -977,6 +979,114 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
         r"or equal to inner_axis \(1\)"):
       st.merge_dims(2, 1)
 
+  @parameterized.named_parameters([
+      dict(
+          testcase_name="0D_0",
+          st={"x": 1},
+          axis=0,
+          expected=[{"x": 1}]),
+      dict(
+          testcase_name="0D_minus_1",
+          st={"x": 1},
+          axis=-1,
+          expected=[{"x": 1}]),
+      dict(
+          testcase_name="1D_0",
+          st=[{"x": [1, 3]}, {"x": [2, 7, 9]}],
+          axis=0,
+          expected=[[{"x": [1, 3]}, {"x": [2, 7, 9]}]]),
+      dict(
+          testcase_name="1D_1",
+          st=[{"x": [1]}, {"x": [2, 10]}],
+          axis=1,
+          expected=[[{"x": [1]}], [{"x": [2, 10]}]]),
+      dict(
+          testcase_name="2D_0",
+          st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
+          axis=0,
+          expected=[[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]]),
+      dict(
+          testcase_name="2D_1",
+          st=[[{"x": 1}, {"x": 2}], [{"x": 3}]],
+          axis=1,
+          expected=[[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]]),
+      dict(
+          testcase_name="2D_2",
+          st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
+          axis=2,
+          expected=[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]]),
+      dict(
+          testcase_name="3D_0",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=0,
+          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
+                     [[{"x": [4, 5]}]]]]),
+      dict(
+          testcase_name="3D_minus_4",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=-4,  # same as zero
+          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
+                     [[{"x": [4, 5]}]]]]),
+      dict(
+          testcase_name="3D_1",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=1,
+          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
+                    [[[{"x": [4, 5]}]]]]),
+      dict(
+          testcase_name="3D_minus_3",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=-3,  # same as 1
+          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
+                    [[[{"x": [4, 5]}]]]]),
+      dict(
+          testcase_name="3D_2",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=2,
+          expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
+                    [[[{"x": [4, 5]}]]]]),
+      dict(
+          testcase_name="3D_minus_2",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=-2,  # same as 2
+          expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
+                    [[[{"x": [4, 5]}]]]]),
+      dict(
+          testcase_name="3D_3",
+          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
+          axis=3,
+          expected=[[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]],
+                    [[[{"x": [4, 5]}]]]]),
+  ])  # pyformat: disable
+  def testExpandDims(self, st, axis, expected):
+    st = StructuredTensor.from_pyval(st)
+    result = array_ops.expand_dims(st, axis)
+    self.assertAllEqual(result, expected)
+
+  def testExpandDimsAxisTooBig(self):
+    st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
+    st = StructuredTensor.from_pyval(st)
+    with self.assertRaisesRegex(ValueError,
+                                "axis=4 out of bounds: expected -4<=axis<4"):
+      array_ops.expand_dims(st, 4)
+
+  def testExpandDimsAxisTooSmall(self):
+    st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
+    st = StructuredTensor.from_pyval(st)
+    with self.assertRaisesRegex(ValueError,
+                                "axis=-5 out of bounds: expected -4<=axis<4"):
+      array_ops.expand_dims(st, -5)
+
+  def testExpandDimsScalar(self):
+    # Note that if we expand_dims along the final dimension and there are
+    # scalar fields, the resulting shape is (2, None, None, 1), whereas a
+    # StructuredTensor constructed directly from the equivalent pyval has
+    # shape (2, None, None, None).
+    st = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]
+    st = StructuredTensor.from_pyval(st)
+    result = array_ops.expand_dims(st, 3)
+    expected_shape = tensor_shape.TensorShape([2, None, None, 1])
+    self.assertEqual(repr(expected_shape), repr(result.shape))
+
   def testTupleFieldValue(self):
     st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
     self.assertAllEqual(st.field_value(("a",)), 5)
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 6be1ea04968..bd344fccca9 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -1385,4 +1385,7 @@ def trace_off():
     context.context().disable_run_metadata()
 
   if profiler:
-    _profiler.stop()
+    try:
+      _profiler.stop()
+    except _profiler.ProfilerNotRunningError:
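+      # The profiler may already have been stopped (e.g. stopped explicitly
+      # before trace_off() was called); treat that as a no-op.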
+      pass
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md
index fe69f3beb01..a5c0aa894f6 100644
--- a/tensorflow/python/saved_model/README.md
+++ b/tensorflow/python/saved_model/README.md
@@ -3,49 +3,37 @@
 [TOC]
 
 ## Overview
-This document describes SavedModel, the universal serialization format for
+
+SavedModel is the universal serialization format for
 [TensorFlow](https://www.tensorflow.org/) models.
 
-SavedModel provides a language-neutral format to save machine-learned models
+SavedModel provides a language-neutral format to save machine-learning models
 that is recoverable and hermetic. It enables higher-level systems and tools to
 produce, consume and transform TensorFlow models.
 
-## Features
+## Guides
+* [Using the SavedModel Format](https://www.tensorflow.org/guide/saved_model)
+* [Save and load Keras models](https://www.tensorflow.org/guide/keras/save_and_serialize)
+* [Save and load with checkpointing in Keras](https://www.tensorflow.org/tutorials/keras/save_and_load)
+* [Training checkpoints](https://www.tensorflow.org/guide/checkpoint)
+* [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load)
 
-The following is a summary of the features in SavedModel:
 
-* Multiple graphs sharing a single set of variables and assets can be added to a
-  single SavedModel. Each graph is associated with a specific set of tags to
-  allow identification during a load or restore operation.
-* Support for `SignatureDefs`
-    * Graphs that are used for inference tasks typically have a set of inputs
-      and outputs. This is called a `Signature`.
-    * SavedModel uses [SignatureDefs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto)
-      to allow generic support for signatures that may need to be saved with the graphs.
-    * For commonly used SignatureDefs in the context of TensorFlow Serving,
-      please see documentation [here](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md).
-* Support for `Assets`.
-    * For cases where ops depend on external files for initialization, such as
-      vocabularies, SavedModel supports this via `assets`.
-    * Assets are copied to the SavedModel location and can be read when loading
-      a specific meta graph def.
-* Support to clear devices before generating the SavedModel.
+## [Public API](https://www.tensorflow.org/api_docs/python/tf/saved_model)
+* [`tf.saved_model.save`](https://www.tensorflow.org/api_docs/python/tf/saved_model/save)
+* [`tf.saved_model.load`](https://www.tensorflow.org/api_docs/python/tf/saved_model/load)
+* [`tf.saved_model.SaveOptions`](https://www.tensorflow.org/api_docs/python/tf/saved_model/SaveOptions)
+* [`tf.saved_model.LoadOptions`](https://www.tensorflow.org/api_docs/python/tf/saved_model/LoadOptions)
+* [`tf.saved_model.Asset`](https://www.tensorflow.org/api_docs/python/tf/saved_model/Asset)
+* [`tf.saved_model.contains_saved_model`](https://www.tensorflow.org/api_docs/python/tf/saved_model/contains_saved_model)
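+
+A minimal sketch of the save/load round trip with this API; the module and
+path below are placeholder examples:
+
+```python
+import tensorflow as tf
+
+class Adder(tf.Module):
+
+  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
+  def add(self, x):
+    return x + 1.0
+
+module = Adder()
+tf.saved_model.save(module, "/tmp/adder")     # writes saved_model.pb and variables/
+reloaded = tf.saved_model.load("/tmp/adder")
+reloaded.add(tf.constant(3.0))                # => tf.Tensor(4.0, shape=(), dtype=float32)
+```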
 
-The following is a summary of features that are NOT supported in SavedModel.
-Higher-level frameworks and tools that use SavedModel may provide these.
+### Related Modules and Functions
+* [`tf.keras.models.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model)
+* [`tf.keras.models.load_model`](https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model)
+* [`tf.train.Checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint)
 
-* Implicit versioning.
-* Garbage collection.
-* Atomic writes to the SavedModel location.
 
-## Background
-SavedModel manages and builds upon existing TensorFlow primitives such as
-`TensorFlow Saver` and `MetaGraphDef`. Specifically, SavedModel wraps a [TensorFlow Saver](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py).
-The Saver is primarily used to generate the variable checkpoints. SavedModel
-will replace the existing [TensorFlow Inference Model Format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md)
-as the canonical way to export TensorFlow graphs for serving.
-
-## Components
+## The SavedModel Format
 A SavedModel directory has the following structure:
 
 ```
@@ -57,72 +45,23 @@ variables/
 saved_model.pb
 ```
 
-* SavedModel protocol buffer
-    * `saved_model.pb` or `saved_model.pbtxt`
-    * Includes the graph definitions as `MetaGraphDef` protocol buffers.
-* Assets
-    * Subfolder called `assets`.
-    * Contains auxiliary files such as vocabularies, etc.
-* Extra assets
-    * Subfolder where higher-level libraries and users can add their own assets
-      that co-exist with the model, but are not loaded by the graph.
-    * This subfolder is not managed by the SavedModel libraries.
-* Variables
-    * Subfolder called `variables`.
-    * Includes output from the [TensorFlow Saver](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py).
-        * `variables.data-?????-of-?????`
-        * `variables.index`
+*   SavedModel protocol buffer
+    *   [`saved_model.pb`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_model.proto)
+        or `saved_model.pbtxt`
+    *   Includes the graph definitions as `MetaGraphDef` protocol buffers.
+*   Assets
+    *   Subfolder called `assets`.
+    *   Contains auxiliary files such as vocabularies, etc.
+*   Extra assets
+    *   Subfolder where higher-level libraries and users can add their own
+        assets that co-exist with the model, but are not loaded by the graph.
+    *   This subfolder is not managed by the SavedModel libraries.
+*   Variables
+    *   Subfolder called `variables`.
+        *   `variables.data-?????-of-?????`
+        *   `variables.index`
 
-## APIs
-The APIs for building and loading a SavedModel are described in this section.
-
-### Builder
-The SavedModel [builder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py)
-is implemented in Python.
-
-The `SavedModelBuilder` class provides functionality to save multiple meta graph
-defs, associated variables and assets.
-
-To build a SavedModel, the first meta graph must be saved with variables.
-Subsequent meta graphs will simply be saved with their graph definitions. If
-assets need to be saved and written or copied to disk, they can be provided
-when the meta graph def is added. If multiple meta graph defs are associated
-with an asset of the same name, only the first version is retained.
-
-#### Tags
-Each meta graph added to the SavedModel must be annotated with user specified
-tags, which reflect the meta graph capabilities or use-cases.
-More specifically, these tags typically annotate a meta graph with its
-functionality (e.g. serving or training), and possibly hardware specific aspects
-such as GPU.
-In the SavedModel, the meta graph def whose tag-set exactly matches those
-specified in the loader API, will be the one loaded by the loader.
-If no meta graph def is found matching the specified tags, an error is returned.
-For example, a loader with a requirement to serve on GPU hardware would be able
-to load only meta graph annotated with tags='serve,gpu' by specifying this set
-of tags in tensorflow::LoadSavedModel(...).
-
-
-#### Usage
-The typical usage of `builder` is as follows:
-
-~~~python
-export_dir = ...
-...
-builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
-with tf.Session(graph=tf.Graph()) as sess:
-  ...
-  builder.add_meta_graph_and_variables(sess,
-                                       [tf.saved_model.tag_constants.TRAINING],
-                                       signature_def_map=foo_signatures,
-                                       assets_collection=foo_assets)
-...
-with tf.Session(graph=tf.Graph()) as sess:
-  ...
-  builder.add_meta_graph(["bar-tag", "baz-tag"])
-...
-builder.save()
-~~~
+---
 
 #### Stripping Default valued attributes
 The SavedModelBuilder class allows users to control whether default-valued
@@ -152,60 +91,3 @@ models regenerated with newer training binaries.
 TIP: If you care about forward compatibility, then set `strip_default_attrs`
 to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and
 `SavedModelBuilder.add_meta_graph`.
-
-### Loader
-The SavedModel loader is implemented in C++ and Python.
-
-#### Python
-The Python version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/loader.py)
-provides load and restore capability for a SavedModel. The `load` operation
-requires the session in which to restore the graph definition and variables, the
-tags used to identify the meta graph def to load and the location of the
-SavedModel. Upon a load, the subset of variables and assets supplied as part of
-the specific meta graph def, will be restored into the supplied session.
-
-~~~python
-export_dir = ...
-...
-with tf.Session(graph=tf.Graph()) as sess:
-  tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
-  ...
-~~~
-
-#### C++
-The C++ version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h)
-provides an API to load a SavedModel from a path, while allowing
-`SessionOptions` and `RunOptions`. Similar to the Python version, the C++
-version requires the tags associated with the graph to be loaded, to be
-specified. The loaded version of SavedModel is referred to as `SavedModelBundle`
-and contains the meta graph def and the session within which it is loaded.
-
-~~~c++
-const string export_dir = ...
-SavedModelBundle bundle;
-...
-LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},
-               &bundle);
-~~~
-
-### Constants
-SavedModel offers the flexibility to build and load TensorFlow graphs for a
-variety of use-cases. For the set of most common expected use-cases,
-SavedModel's APIs provide a set of constants in Python and C++ that are easy to
-reuse and share across tools consistently.
-
-#### Tag constants
-Sets of tags can be used to uniquely identify a `MetaGraphDef` saved in a
-SavedModel. A subset of commonly used tags is specified in:
-
-* [Python](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/tag_constants.py)
-* [C++](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h).
-
-#### Signature constants
-SignatureDefs are used to define the signature of a computation supported in a
-TensorFlow graph. Commonly used input keys, output keys and method names are
-defined in:
-
-* [Python](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/signature_constants.py)
-* [C++](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/signature_constants.h).
-
diff --git a/tensorflow/python/tpu/session_support.py b/tensorflow/python/tpu/session_support.py
index 6e85c417f54..24da6efb4a4 100644
--- a/tensorflow/python/tpu/session_support.py
+++ b/tensorflow/python/tpu/session_support.py
@@ -106,7 +106,7 @@ class WorkerHeartbeatManager(object):
     self._session.run(self._ops,
                       {self._request_placeholder: message.SerializeToString()})
 
-  def ping(self, request=None, timeout_in_ms=5000):
+  def ping(self, request=None, timeout_in_ms=60000):
     """Ping all workers, returning the parsed status results."""
     if request is None:
       request = event_pb2.WorkerHeartbeatRequest()
diff --git a/tensorflow/security/fuzzing/BUILD b/tensorflow/security/fuzzing/BUILD
index a05b287e7ba..994191eb5d4 100644
--- a/tensorflow/security/fuzzing/BUILD
+++ b/tensorflow/security/fuzzing/BUILD
@@ -27,6 +27,25 @@ tf_fuzz_target(
     ],
 )
 
+tf_fuzz_target(
+    name = "parseURI_fuzz",
+    srcs = ["parseURI_fuzz.cc"],
+    deps = [
+        "//tensorflow/core/platform:path",
+        "//tensorflow/core/platform:stringpiece",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_fuzz_target(
+    name = "cleanpath_fuzz",
+    srcs = ["cleanpath_fuzz.cc"],
+    deps = [
+        "//tensorflow/core/platform:path",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 tf_fuzz_target(
     name = "consume_leading_digits_fuzz",
     srcs = ["consume_leading_digits_fuzz.cc"],
diff --git a/tensorflow/security/fuzzing/cleanpath_fuzz.cc b/tensorflow/security/fuzzing/cleanpath_fuzz.cc
new file mode 100644
index 00000000000..b535bb31fbf
--- /dev/null
+++ b/tensorflow/security/fuzzing/cleanpath_fuzz.cc
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <regex>  // NOLINT
+
+#include "absl/strings/match.h"
+#include "tensorflow/core/platform/path.h"
+
+// This is a fuzzer for tensorflow::io::CleanPath.
+
+namespace {
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  std::string input_path(reinterpret_cast<const char *>(data), size);
+  std::string clean_path = tensorflow::io::CleanPath(input_path);
+
+  // Assert there are no '/./' no-op directory changes left.
+  assert(!absl::StrContains(clean_path, "/./"));
+  // Assert there are no duplicate '/'.
+  assert(!absl::StrContains(clean_path, "//"));
+  // Assert there are no parent-directory components after entering a
+  // directory. Use regex_search (substring match) rather than regex_match,
+  // which would only fire if the whole path matched the pattern.
+  std::regex higher_up_directory("[^.]{1}/[.]{2}");
+  assert(!std::regex_search(clean_path, higher_up_directory));
+
+  return 0;
+}
+
+}  // namespace
diff --git a/tensorflow/security/fuzzing/parseURI_fuzz.cc b/tensorflow/security/fuzzing/parseURI_fuzz.cc
new file mode 100644
index 00000000000..f2230c0c444
--- /dev/null
+++ b/tensorflow/security/fuzzing/parseURI_fuzz.cc
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstdlib>
+
+#include "absl/strings/match.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+// This is a fuzzer for tensorflow::io::ParseURI.
+
+namespace {
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  std::string uri(reinterpret_cast<const char *>(data), size);
+  tensorflow::StringPiece scheme, host, path;
+  tensorflow::io::ParseURI(uri, &scheme, &host, &path);
+
+  // If the input cannot be parsed as a URI, ParseURI returns the entire
+  // string as `path` and leaves `scheme` and `host` empty.
+  if (path == uri) {
+    assert(host == "");
+    assert(scheme == "");
+  } else {
+    assert(absl::StrContains(uri, host));
+    assert(absl::StrContains(uri, scheme));
+    assert(absl::StrContains(uri, path));
+    assert(absl::StrContains(uri, "://"));
+  }
+
+  return 0;
+}
+
+}  // namespace
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index c9419753437..fd7ce8cd515 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -119,6 +119,12 @@ def tf_portable_full_lite_protos(full, lite):
         "//conditions:default": full,
     })
 
+def if_no_default_logger(a):
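+    """Returns `a` if the default logger is disabled, otherwise [].
+
+    The //tensorflow:no_default_logger config_setting is enabled with, e.g.,
+    `--define=no_default_logger=true`.
+    """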
+    return select({
+        clean_dep("//tensorflow:no_default_logger"): a,
+        "//conditions:default": [],
+    })
+
 def if_android_x86(a):
     return select({
         clean_dep("//tensorflow:android_x86"): a,
@@ -332,6 +338,7 @@ def tf_copts(
         if_android_arm(["-mfpu=neon"]) +
         if_linux_x86_64(["-msse3"]) +
         if_ios_x86_64(["-msse4.1"]) +
+        if_no_default_logger(["-DNO_DEFAULT_LOGGER"]) +
         select({
             clean_dep("//tensorflow:framework_shared_object"): [],
             "//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"],
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
index 88e4ecfbb62..0a297df71a9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
index 89e0718d5b6..1ec4e30e3f2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
index 29b1fba5aae..a6deef2da30 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
index c481aa07ace..ccb9bebb195 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt
index a2b9d310eb9..eab98ac6d06 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt
index 650ac77d6df..81cce2fbbec 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
index 50e3da3eda5..badb430ca1f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
index ab8391e0465..bcc8c6019db 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
index 2bad07d9998..db673eedc11 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt
index 5eed1aa7d0a..e3f7e3639da 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt
@@ -7,14 +7,6 @@ tf_class {
     name: "allow_build_at_runtime"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "is_dynamic_op"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "max_batch_size"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "max_workspace_size_bytes"
     mtype: "<type \'property\'>"
@@ -31,10 +23,6 @@ tf_class {
     name: "precision_mode"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "rewriter_config_template"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "use_calibration"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
index 88e4ecfbb62..0a297df71a9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
index 89e0718d5b6..1ec4e30e3f2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
index 29b1fba5aae..a6deef2da30 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
index c481aa07ace..ccb9bebb195 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt
index a2b9d310eb9..eab98ac6d06 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt
index 650ac77d6df..81cce2fbbec 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
index 50e3da3eda5..badb430ca1f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
index ab8391e0465..bcc8c6019db 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
index 2bad07d9998..db673eedc11 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt
index 605cb27c36b..e4552623c02 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt
index a436583fdd6..8b97923bcbf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt
index f874658cc25..d92c792df1b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt
index 6798187be77..e015f67cdcb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt
index 15efc6ada39..43252585d22 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt
index 00cf3e0e24e..a1866674ec5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt
index 881d15c5306..77ae5d5ccff 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt
index 661e9cb5a58..068c4b7cf40 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt
index a14c9a4ce57..e2defb8a290 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt
@@ -30,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "add_slot"
-    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
+    argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], "
   }
   member_method {
     name: "add_weight"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
index b23d3b9f01b..2d4729f1867 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
@@ -18,7 +18,7 @@ tf_module {
   }
   member_method {
     name: "compute_gradient"
-    argspec: "args=[\'f\', \'x\', \'delta\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+    argspec: "args=[\'f\', \'x\', \'delta\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "create_local_cluster"
diff --git a/tensorflow/tools/ci_build/a100/nightly.sh b/tensorflow/tools/ci_build/a100/nightly.sh
index e7756da8089..f96429cbbba 100644
--- a/tensorflow/tools/ci_build/a100/nightly.sh
+++ b/tensorflow/tools/ci_build/a100/nightly.sh
@@ -20,8 +20,10 @@ cd tensorflow/tools/ci_build
 docker build -t gpu_test_container:latest -f \
    Dockerfile.rbe.cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython .
 
+DEFAULT_BAZEL_TARGETS="//tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/mlir/tosa/... -//tensorflow/compiler/xrt/... //tensorflow/compiler/mlir/lite/... -//tensorflow/lite/micro/examples/... -//tensorflow/core/tpu/..."
+
 docker run --rm \
   --gpus all \
   --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 \
-  gpu_test_container:latest bash -c "python3.7 -m pytest"
+  gpu_test_container:latest bash -c "cd tensorflow_src; bazel test --config=rbe_linux_cuda11.0_nvcc_py3.6 --config=tensorflow_testing_rbe_linux --test_tag_filters=gpu,-no_gpu,-nogpu,-benchmark-test,-no_oss,-oss_serial,-v1only,-no_gpu_presubmit,-no_cuda11 -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/..."
 
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
index 1049939c94b..2d1c5138866 100644
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
@@ -168,9 +168,17 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file):
     lib_paths_file: String that is the path to txt file that lists
                     cc_library target execpaths for exporting symbols.
   """
-  # cc_library target name is always in [target_name] format in
-  # `symbols_pybind.txt`.
-  section_header_filter = r"\[(\S+)\]"  # e.g. `[cpp_python_util]`
+  # A cc_library target name must begin its own line, and it must begin with
+  # `//tensorflow`. It can then optionally have some number of directories, and
+  # it must end with a target name directly preceded by either a slash or a
+  # colon. A directory or target name is any combination of letters, numbers,
+  # underscores, and dashes.
+  # Examples of possible headers:
+  # `[//tensorflow/core/util/tensor_bundle]`
+  # `[//tensorflow/python:safe_ptr]`
+  # `[//tensorflow:target_name_v2_25]`
+  # `[//tensorflow/-/24/util_:port-5]`
+  section_header_filter = r"^\[\/\/(tensorflow(\/[\w-]+)*(:|\/)[\w-]+)\]"
 
   # Create a dict of target libs and their symbols to be exported and populate
   # it. (key = cc_library target, value = list of symbols) that we need to
@@ -199,20 +207,31 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file):
   symbols_all = []
   for lib in lib_paths:
     if lib:
-      for cc_lib in symbols:  # keys in symbols = cc_library target name
-        path_to_lib = cc_lib.split("/")
-        cc_target = path_to_lib[-1]
-        # if `len(path_to_lib)` is larger than 1, that means, we are given one
-        # or more parent directory of the target. e.g. `[foo/bar]` instead of
-        # just the target name `[bar]`.
-        if len(path_to_lib) > 1:
-          parent_dir = path_to_lib[0]
+      for cc_lib in symbols:   # keys in symbols = cc_library target name
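+        # Normalize a `pkg:target` header to `pkg/target` form so each path
+        # component can be matched against the bazel-out library path below.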
+        if cc_lib.count(":") == 1:
+          formatted_cc_lib = cc_lib.replace(":", "/")
+        elif cc_lib.count(":") == 0:
+          formatted_cc_lib = cc_lib
         else:
-          parent_dir = ""
-        if cc_target in lib and parent_dir in lib:
-          symbols_all.extend(
-            get_symbols(lib, "|".join(symbols[cc_lib])))
-
+          raise ValueError(f"Detected wrong format for symbols header in"
+                           "`symbols_pybind.txt`. Header must have 0 or 1 "
+                           "colon (e.g. `[//third_party/tensorflow/python:safe_ptr]`"
+                           "or `[tensorflow/core/util/tensor_bundle]`) but "
+                           "detected: {cc_lib}")
+        path_to_lib = formatted_cc_lib.split("/")
+        # `lib` is a bazel-out path, so the string we match against differs
+        # from the package path listed in
+        # `win_lib_files_for_exported_symbols` and `symbols_pybind.txt`.
+        # For example, the target `tensorflow/core:op_gen_lib` in
+        # `win_lib_files_for_exported_symbols` generates the bazel library path
+        # `bazel-out/x64_windows-opt/bin/tensorflow/core/framework/op_gen_lib.lib`,
+        # so we accept `lib` if every component of the header path appears in it.
+        lib_and_cc_lib_match = True
+        for p in path_to_lib:
+          if p not in lib:
+            lib_and_cc_lib_match = False
+            break
+        if lib_and_cc_lib_match:
+          symbols_all.extend(get_symbols(lib, "|".join(symbols[cc_lib])))
   return symbols_all
 
 def main():
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index a6082788413..361283c5c90 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -1,4 +1,4 @@
-[cpp_python_util] # util tfe
+[//tensorflow/python:cpp_python_util] # util tfe
 tensorflow::swig::IsSequence
 tensorflow::swig::IsSequenceOrComposite
 tensorflow::swig::IsCompositeTensor
@@ -22,18 +22,18 @@ tensorflow::swig::RegisterType
 tensorflow::swig::IsEagerTensorSlow
 tensorflow::swig::GetRegisteredPyObject
 
-[util_port] # util_port
+[//tensorflow/core/util:port] # util_port
 tensorflow::IsGoogleCudaEnabled
 tensorflow::IsBuiltWithROCm
 tensorflow::IsBuiltWithNvcc
 tensorflow::GpuSupportsHalfMatMulAndConv
 tensorflow::IsMklEnabled
 
-[stream_executor_pimpl] # stat_summarizer
+[//tensorflow/stream_executor:stream_executor_pimpl] # stat_summarizer
 stream_executor::StreamExecutor::EnablePeerAccessTo
 stream_executor::StreamExecutor::CanEnablePeerAccessTo
 
-[print_model_analysis] # tfprof
+[//tensorflow/core/profiler/internal:print_model_analysis] # tfprof
 tensorflow::tfprof::NewProfiler
 tensorflow::tfprof::DeleteProfiler
 tensorflow::tfprof::AddStep
@@ -43,24 +43,17 @@ tensorflow::tfprof::Profile
 tensorflow::tfprof::PrintModelAnalysis
 tensorflow::tfprof::SerializeToString
 
-[graph_analyzer_tool] # graph_analyzer
+[//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool] # graph_analyzer
 tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
 
-[bfloat16_lib] # bfloat16
+[//tensorflow/python:bfloat16_lib] # bfloat16
 tensorflow::RegisterNumpyBfloat16
 tensorflow::Bfloat16PyType
 
-[events_writer] # events_writer
-tensorflow::EventsWriter::Init
-tensorflow::EventsWriter::InitWithSuffix
-tensorflow::EventsWriter::WriteSerializedEvent
-tensorflow::EventsWriter::Flush
-tensorflow::EventsWriter::Close
-
-[py_func_lib] # py_func
+[//tensorflow/python:py_func_lib] # py_func
 tensorflow::InitializePyTrampoline
 
-[framework_internal_impl] # op_def_registry, dtypes
+[//tensorflow/core:framework_internal_impl] # op_def_registry, dtypes
 tensorflow::BaseType
 tensorflow::DataTypeString
 tensorflow::DataTypeIsComplex
@@ -73,29 +66,26 @@ tensorflow::OpRegistry::Global
 tensorflow::OpRegistry::LookUpOpDef
 tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
 
-[lib_internal_impl]  # device_lib
+[//tensorflow/core:lib_internal_impl]  # device_lib
 tensorflow::Status::code
 tensorflow::Status::error_message
 tensorflow::Status::ok()
 
-[device]  # device_lib, tfe, tf_session
-tensorflow::Device::attributes
-
-[device_factory]  # device_lib, tfe, tf_session
+[//tensorflow/core/common_runtime:device_factory]  # device_lib, tfe, tf_session
 tensorflow::DeviceFactory::AddDevices
 tensorflow::DeviceFactory::ListAllPhysicalDevices
 tensorflow::DeviceFactory::GetAnyDeviceDetails
 
-[session_options]  # device_lib, tfe, tf_session
+[//tensorflow/core/common_runtime:session_options]  # device_lib, tfe, tf_session
 tensorflow::SessionOptions::SessionOptions
 
-[quantize_training]  # quantize_training
+[//tensorflow/core/common_runtime:quantize_training]  # quantize_training
 tensorflow::DoQuantizeTrainingOnSerializedGraphDef
 
-[session_state]  # tf_session
+[//tensorflow/core/common_runtime:session_state]  # tf_session
 tensorflow::SessionState::kTensorHandleResourceTypeName
 
-[server_lib] # server_lib
+[//tensorflow/core/data/service:server_lib] # server_lib
 tensorflow::data::GrpcDataServerBase::Join
 tensorflow::data::GrpcDataServerBase::Start
 tensorflow::data::GrpcDataServerBase::Stop
@@ -105,31 +95,25 @@ tensorflow::data::WorkerGrpcDataServer::NumTasks
 tensorflow::data::NewDispatchServer
 tensorflow::data::NewWorkerServer
 
-[protos_all]  # device_lib, dtypes
-tensorflow::DataType_IsValid
-tensorflow::ConfigProto::ConfigProto
-tensorflow::ConfigProto::ParseFromString
-tensorflow::DeviceAttributes::SerializeToString
-
-[py_exception_registry] # py_exception_registry
+[//tensorflow/python:py_exception_registry] # py_exception_registry
 tensorflow::PyExceptionRegistry::Init
 tensorflow::PyExceptionRegistry::Lookup
 
-[kernel_registry] # kernel_registry
+[//tensorflow/python:kernel_registry] # kernel_registry
 tensorflow::swig::TryFindKernelClass
 
-[toco_python_api] # toco_python_api
+[//tensorflow/lite/toco/python:toco_python_api] # toco_python_api
 toco::TocoConvert
 toco::TocoGetPotentiallySupportedOps
 toco::MlirQuantizeModel
 toco::MlirSparsifyModel
 toco::RegisterCustomOpdefs
 
-[transform_graph_lib] # transform_graph
+[//tensorflow/tools/graph_transforms:transform_graph_lib] # transform_graph
 tensorflow::graph_transforms::TransformGraph
 tensorflow::graph_transforms::ParseTransformParameters
 
-[checkpoint_reader] # py_checkpoint_reader
+[//tensorflow/c:checkpoint_reader] # py_checkpoint_reader
 tensorflow::checkpoint::CheckpointReader
 tensorflow::checkpoint::CheckpointReader::Init
 tensorflow::checkpoint::CheckpointReader::DebugString
@@ -138,21 +122,21 @@ tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap
 tensorflow::checkpoint::CheckpointReader::GetTensor
 tensorflow::checkpoint::CheckpointReader::HasTensor
 
-[tensor_bundle] # py_checkpoint_reader
+[//tensorflow/core/util/tensor_bundle] # py_checkpoint_reader
 tensorflow::BundleReader::BundleReader
 tensorflow::BundleReader::~BundleReader
 
-[ndarray_tensor] # py_checkpoint_reader
+[//tensorflow/python:ndarray_tensor] # py_checkpoint_reader
 tensorflow::TensorToNdarray
 
-[safe_ptr] # py_checkpoint_reader
+[//tensorflow/python:safe_ptr] # py_checkpoint_reader
 tensorflow::detail::PyDecrefDeleter
 tensorflow::make_safe
 
-[python_op_gen] # python_op_gen
+[//tensorflow/python:python_op_gen] # python_op_gen
 tensorflow::GetPythonWrappers
 
-[pywrap_tfe_lib] # tfe
+[//tensorflow/python/eager:pywrap_tfe_lib] # tfe
 tensorflow::TFE_TensorHandleCache
 tensorflow::TFE_TensorHandleCache::Clear
 EagerTensor_CheckExact
@@ -210,17 +194,17 @@ tensorflow::MakeEagerContextThreadLocalData
 tensorflow::GetEagerContextThreadLocalData
 tensorflow::DestroyEagerContextThreadLocalData
 
-[eager_executor] # tfe
+[//tensorflow/core/common_runtime/eager:eager_executor] # tfe
 tensorflow::EagerExecutor::~EagerExecutor
 tensorflow::EagerContext::WaitForAndCloseRemoteContexts
 
-[tf_status_helper] # tfe
+[//tensorflow/c:tf_status_helper] # tfe
 tensorflow::Set_TF_Status_from_Status
 
-[context] # tfe
+[//tensorflow/core/common_runtime/eager:context] # tfe
 tensorflow::EagerContext::WaitForAndCloseRemoteContexts
 
-[mlir] # mlir
+[//tensorflow/compiler/mlir/python:mlir] # mlir
 tensorflow::ExperimentalRunPassPipeline
 tensorflow::ExperimentalConvertSavedModelV1ToMlirLite
 tensorflow::ExperimentalConvertSavedModelV1ToMlir
@@ -228,13 +212,13 @@ tensorflow::ExperimentalConvertSavedModelToMlir
 tensorflow::ImportGraphDef
 tensorflow::ImportFunction
 
-[op_gen_lib] # tf_session
+[//tensorflow/core:op_gen_lib] # tf_session
 tensorflow::ApiDefMap::~ApiDefMap
 
-[graph_constructor] # tf_session
+[//tensorflow/core/common_runtime:graph_constructor] # tf_session
 tensorflow::ShapeRefiner::~ShapeRefiner
 
-[python_api] # tf_session
+[//tensorflow/c:python_api] # tf_session
 tensorflow::AddControlInput
 tensorflow::SetAttr
 tensorflow::ClearAttr
@@ -247,11 +231,11 @@ tensorflow::GetHandleShapeAndType
 tensorflow::SetHandleShapeAndType
 tensorflow::AddWhileInputHack
 
-[numpy_lib] # tf_session
+[//tensorflow/python:numpy_lib] # tf_session
 tensorflow::ImportNumpy
 _tensorflow_numpy_api
 
-[tf_session_helper] # tf_session
+[//tensorflow/python:tf_session_helper] # tf_session
 tensorflow::TF_NewSessionRef
 tensorflow::TF_SessionMakeCallable
 tensorflow::TF_SessionRunCallable
@@ -274,76 +258,76 @@ tensorflow::TF_GraphSetTensorShape_wrapper
 tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper
 tensorflow::TF_TryEvaluateConstant_wrapper
 
-[grappler_item] # tf_item
+[//tensorflow/core/grappler:grappler_item] # tf_item
 tensorflow::grappler::GrapplerItem::MainOpsFanin
 tensorflow::grappler::GrapplerItem::EnqueueOpsFanin
 
-[graph_properties] # tf_item
+[//tensorflow/core/grappler/costs:graph_properties] # tf_item
 tensorflow::grappler::GraphProperties::InferStatically
 tensorflow::grappler::GraphProperties::GetOutputProperties
 
-[grappler_item_builder] # tf_item
+[//tensorflow/core/grappler:grappler_item_builder] # tf_item
 tensorflow::grappler::GrapplerItemFromMetaGraphDef
 
-[topological_sort] # tf_item
+[//tensorflow/core/grappler/utils:topological_sort] # tf_item
 tensorflow::grappler::TopologicalSort
 
-[clusters/utils] # tf_cluster tf_optimizer
+[//tensorflow/core/grappler/clusters:utils] # tf_cluster tf_optimizer
 tensorflow::grappler::GetDeviceInfo
 
-[costs/utils] # tf_optimizer tf_cluster
+[//tensorflow/core/grappler/costs:utils] # tf_optimizer tf_cluster
 tensorflow::grappler::CostGraphToOpPerformanceData
 tensorflow::grappler::GetDeviceInfo
 
-[meta_optimizer] # tf_optimizer
+[//tensorflow/core/grappler/optimizers:meta_optimizer] # tf_optimizer
 tensorflow::grappler::MetaOptimizer::MetaOptimizer
 tensorflow::grappler::MetaOptimizer::Optimize
 tensorflow::grappler::MetaOptimizer::PrintResult
 
-[clusters/cluster] # tf_cluster
+[//tensorflow/core/grappler/clusters:cluster] # tf_cluster
 tensorflow::grappler::Cluster::AllowSoftPlacement
 tensorflow::grappler::Cluster::SetNumWarmupSteps
 tensorflow::grappler::Cluster::DisableDetailedStats
 tensorflow::grappler::Cluster::DetailedStatsEnabled
 
-[single_machine] # tf_cluster
+[//tensorflow/core/grappler/clusters:single_machine] # tf_cluster
 tensorflow::grappler::SingleMachine::SingleMachine
 
-[op_level_cost_estimator] # tf_cluster
+[//tensorflow/core/grappler/costs:op_level_cost_estimator] # tf_cluster
 tensorflow::grappler::OpLevelCostEstimator::OpLevelCostEstimator
 tensorflow::grappler::OpLevelCostEstimator::PredictCosts
 tensorflow::grappler::OpLevelCostEstimator::GetDeviceInfo
 
-[virtual_cluster] # tf_cluster
+[//tensorflow/core/grappler/clusters:virtual_cluster] # tf_cluster
 tensorflow::grappler::VirtualCluster::VirtualCluster
 
-[graph_memory] # tf_cluster
+[//tensorflow/core/grappler/costs:graph_memory] # tf_cluster
 tensorflow::grappler::GraphMemory::InferStatically
 tensorflow::grappler::GraphMemory::InferDynamically
 
-[measuring_cost_estimator] # tf_cluster
+[//tensorflow/core/grappler/costs:measuring_cost_estimator] # tf_cluster
 tensorflow::grappler::MeasuringCostEstimator::MeasuringCostEstimator
 tensorflow::grappler::MeasuringCostEstimator::Initialize
 tensorflow::grappler::MeasuringCostEstimator::PredictCosts
 
-[devices] # tf_cluster
+[//tensorflow/core/grappler:devices] # tf_cluster
 tensorflow::grappler::GetNumAvailableGPUs
 tensorflow::grappler::GetNumAvailableLogicalCPUCores
 
-[traceme_recorder_impl] # profiler
+[//tensorflow/core/profiler/internal:traceme_recorder_impl] # profiler
 tensorflow::profiler::TraceMeRecorder::Record
 
-[profiler_session_impl] # profiler
+[//tensorflow/core/profiler/lib:profiler_session_impl] # profiler
 tensorflow::ProfilerSession::Create
 tensorflow::ProfilerSession::CollectData
 tensorflow::ProfilerSession::Status
 tensorflow::ProfilerSession::~ProfilerSession
 
-[profiler_server_impl] # profiler
+[//tensorflow/core/profiler/rpc:profiler_server_impl] # profiler
 tensorflow::profiler::ProfilerServer::StartProfilerServer
 tensorflow::profiler::ProfilerServer::~ProfilerServer
 
-[profiler_client_impl] # profiler
+[//tensorflow/core/profiler/rpc/client:profiler_client_impl] # profiler
 tensorflow::profiler::ProfileGrpc
 tensorflow::profiler::NewSessionGrpc
 tensorflow::profiler::MonitorGrpc
@@ -352,13 +336,13 @@ tensorflow::profiler::RemoteProfilerSession::GetServiceAddress
 tensorflow::profiler::RemoteProfilerSession::WaitForCompletion
 tensorflow::profiler::RemoteProfilerSession::~RemoteProfilerSession
 
-[status_macros] # tfcompile
+[//tensorflow/compiler/xla:status_macros] # tfcompile
 xla::status_macros::MakeErrorStream::Impl::Impl
 xla::status_macros::MakeErrorStream::Impl::~Impl
 xla::status_macros::MakeErrorStream::Impl::GetStatus
 xla::status_macros::MakeErrorStream::CheckNotDone
 
-[hlo] # tfcompile
+[//tensorflow/compiler/xla/service:hlo] # tfcompile
 xla::DfsHloVisitorBase::SetVisited
 xla::DfsHloVisitorBase<class xla::HloInstruction.*>::SetVisited
 xla::HloComputation::Accept
@@ -368,42 +352,42 @@ xla::HloInstruction::ToString
 xla::HloInstruction::Accept
 xla::HloInstruction::Visit
 
-[tfcompile_lib] # tfcompile
+[//tensorflow/compiler/aot:tfcompile_lib] # tfcompile
 tensorflow::tfcompile::Main
 
-[model_analyzer_lib] # model_analyzer
+[//tensorflow/python:model_analyzer_lib] # model_analyzer
 tensorflow::grappler::ModelAnalyzer::GenerateReport
 tensorflow::grappler::ModelAnalyzer::ModelAnalyzer
 
-[analytical_cost_estimator] # cost_analyzer
+[//tensorflow/core/grappler/costs:analytical_cost_estimator] # cost_analyzer
 tensorflow::grappler::AnalyticalCostEstimator::Initialize
 tensorflow::grappler::AnalyticalCostEstimator::PredictCosts
 
-[cost_analyzer_lib] # cost_analyzer
+[//tensorflow/python:cost_analyzer_lib] # cost_analyzer
 tensorflow::grappler::CostAnalyzer::CostAnalyzer
 tensorflow::grappler::CostAnalyzer::GenerateReport
 
-[flags] # tfe
+[//tensorflow/compiler/jit:flags] # tfe
 tensorflow::IsXlaEnabled
 tensorflow::GetMlirCommonFlags
 tensorflow::GetXlaDeviceFlags
 
-[tensor_float_32_utils] # tensor_float_32
+[//tensorflow/core/platform:tensor_float_32_utils] # tensor_float_32
 tensorflow::enable_tensor_float_32_execution
 tensorflow::tensor_float_32_execution_enabled
 
-[get_compiler_ir] # tfe
+[//tensorflow/compiler/jit:get_compiler_ir] # tfe
 tensorflow::GetCompilerIr
 stream_executor::port::internal_statusor::Helper::Crash
 
-[tensor_handle] # tfe
+[//tensorflow/core/common_runtime/eager:tensor_handle] # tfe
 tensorflow::TensorHandle::Tensor
 
-[python_api_dispatcher] # python_api_dispatcher
+[//tensorflow/python:python_api_dispatcher] # python_api_dispatcher
 tensorflow::PythonAPIDispatcher
 
-[python_tensor_converter] # python_tensor_converter
+[//tensorflow/python:python_tensor_converter] # python_tensor_converter
 tensorflow::PythonTensorConverter
 
-[python_api_info] # python_api_info
+[//tensorflow/python:python_api_info] # python_api_info
 tensorflow::PythonAPIInfo
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 0d8ae1cc67b..f2029dc1c5c 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -203,11 +203,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
-        sha256 = "cce3143f6ed22dadff4ef0b43ce31c9632ace8e75bf41e401b6ec4668968d4c9",  # SHARED_EIGEN_SHA
-        strip_prefix = "eigen-41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed",
+        sha256 = "306f15c04fbd514b4adc3a327a2c6f63521ea6805cab75691fa30c30fea55193",  # SHARED_EIGEN_SHA
+        strip_prefix = "eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed/eigen-41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed.tar.gz",
-            "https://gitlab.com/libeigen/eigen/-/archive/41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed/eigen-41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed/eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed.tar.gz",
+            "https://gitlab.com/libeigen/eigen/-/archive/fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed/eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed.tar.gz",
         ],
     )
 
@@ -686,8 +686,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "8c1e3cbebfe9ff14829a279b9d229d4fc3f190b1"
-    LLVM_SHA256 = "ac1e5de135b12b8008c7998e9e07a657524c2bfee79406273da5f18c356e3892"
+    LLVM_COMMIT = "76bd4444e36197465f1c72f4b6f1d59721012a59"
+    LLVM_SHA256 = "805b737a8ff996ea7216818b0969e7f60c3c9d7a9be82198902f53e2a36c44eb"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index a7eff42d00f..fa4ccb71ece 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -675,6 +675,7 @@ cc_library(
     hdrs = ["include/mlir/Dialect/Async/Passes.h"],
     includes = ["include"],
     deps = [
+        ":Analysis",
         ":Async",
         ":AsyncPassIncGen",
         ":IR",
@@ -3451,6 +3452,7 @@ cc_library(
         "include/mlir/ExecutionEngine/AsyncRuntime.h",
     ],
     includes = ["include"],
+    deps = ["@llvm-project//llvm:Support"],
 )
 
 cc_library(
diff --git a/third_party/nccl/archive.patch b/third_party/nccl/archive.patch
index 9dfe432d60b..695fd718cc5 100644
--- a/third_party/nccl/archive.patch
+++ b/third_party/nccl/archive.patch
@@ -36,20 +36,20 @@ index 985274e..7ebb1e1 100644
 @@ -10,12 +10,12 @@
  #include <cuda_runtime.h>
  #include <cuda_fp16.h>
- 
+
 -#define NCCL_MAJOR ${nccl:Major}
 -#define NCCL_MINOR ${nccl:Minor}
 -#define NCCL_PATCH ${nccl:Patch}
 -#define NCCL_SUFFIX "${nccl:Suffix}"
 +#define NCCL_MAJOR 2
 +#define NCCL_MINOR 7
-+#define NCCL_PATCH 3
++#define NCCL_PATCH 6
 +#define NCCL_SUFFIX ""
- 
+
 -#define NCCL_VERSION_CODE ${nccl:Version}
-+#define NCCL_VERSION_CODE 2703
++#define NCCL_VERSION_CODE 2706
  #define NCCL_VERSION(X,Y,Z) ((X) * 1000 + (Y) * 100 + (Z))
- 
+
  #ifdef __cplusplus
 See https://github.com/NVIDIA/nccl/pull/322.patch
 From 410d341bd4569f60282576daa5c991717dbd560e Mon Sep 17 00:00:00 2001
@@ -127,7 +127,7 @@ index 550cfcd0c..8fea91950 100644
    if (parent == NULL) {
 -    if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
 +    NCCLCHECK(getPciPath(busId, path));
- 
+
      // Save that for later in case next step is a CPU
      char numaIdStr[MAX_STR_LEN];
 @@ -544,7 +546,6 @@ ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml*
@@ -137,7 +137,7 @@ index 550cfcd0c..8fea91950 100644
 -  free(path);
    return ncclSuccess;
  }
- 
+
 @@ -644,8 +644,8 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
          // Remote NVLink device is not visible inside this VM. Assume NVSwitch.
          NCCLCHECK(xmlSetAttr(sub, "tclass", "0x068000"));
@@ -169,7 +169,7 @@ index 8fea91950..42eb68a4b 100644
 @@ -460,20 +460,21 @@ int checkBDFFormat(char* bdf) {
    return 1;
  }
- 
+
 -ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml* xml) {
 +ncclResult_t ncclTopoGetXmlNodeFromSys(struct ncclXmlNode* pciNode,
 +                                       struct ncclXml* xml,