From 4856f23a49dfe44fb697d5affcf10fb2a2585523 Mon Sep 17 00:00:00 2001
From: David Majnemer <majnemer@google.com>
Date: Mon, 22 Feb 2021 11:15:54 -0800
Subject: [PATCH] Introduce Xla{Dot,Conv}V2

These support mixed operand precision (bf16 @ f32 -> f32), (int8 @ int8 -> int32) and thus diverge significantly from the original Xla{Dot,Conv}.

PiperOrigin-RevId: 358860920
Change-Id: I316f1b43268b268ce50528f7aa307182bce304f6
---
 .../compiler/jit/compilability_check_util.cc  |   2 +
 .../compiler/jit/mark_for_compilation_pass.cc |   2 +
 .../jit/xla_ops_on_regular_devices.cc         |  10 +
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  56 ++++
 .../xla/transforms/legalize_tf_with_tf2xla.cc |   2 +
 tensorflow/compiler/tests/xla_ops_test.py     |  29 ++
 .../compiler/tf2xla/kernels/xla_conv_op.cc    |  31 +-
 .../compiler/tf2xla/kernels/xla_dot_op.cc     |  30 +-
 tensorflow/compiler/tf2xla/ops/xla_ops.cc     | 278 +++++++++++-------
 tensorflow/compiler/tf2xla/python/xla.py      |  20 +-
 .../core/profiler/utils/kernel_stats_utils.cc |   3 +-
 11 files changed, 341 insertions(+), 122 deletions(-)

diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index fb4c187f5bd..e155becab11 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -694,8 +694,10 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes(
 static auto const ops_triggering_xla_compilation =
     new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                          "XlaConv",
+                                         "XlaConvV2",
                                          "XlaDequantize",
                                          "XlaDot",
+                                         "XlaDotV2",
                                          "XlaDynamicSlice",
                                          "XlaDynamicUpdateSlice",
                                          "XlaEinsum",
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 4d2c9f45318..8b74ce8042d 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2057,8 +2057,10 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
                                      "While",
                                      "XlaBroadcastHelper",
                                      "XlaConv",
+                                     "XlaConvV2",
                                      "XlaDequantize",
                                      "XlaDot",
+                                     "XlaDotV2",
                                      "XlaDynamicSlice",
                                      "XlaDynamicUpdateSlice",
                                      "XlaEinsum",
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 7ddb1a60b73..7c4378415a9 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -29,6 +29,14 @@ namespace tensorflow {
                               .HostMemory("feature_group_count")               \
                               .Device(DEVICE),                                 \
                           XlaCompileOnDemandOp);                               \
+  REGISTER_KERNEL_BUILDER(Name("XlaConvV2")                                    \
+                              .HostMemory("window_strides")                    \
+                              .HostMemory("padding")                           \
+                              .HostMemory("lhs_dilation")                      \
+                              .HostMemory("rhs_dilation")                      \
+                              .HostMemory("feature_group_count")               \
+                              .Device(DEVICE),                                 \
+                          XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE),  \
       XlaCompileOnDemandOp);                                                   \
@@ -38,6 +46,8 @@ namespace tensorflow {
                           XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE),                       \
                           XlaCompileOnDemandOp);                               \
+  REGISTER_KERNEL_BUILDER(Name("XlaDotV2").Device(DEVICE),                     \
+                          XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE),       \
       XlaCompileOnDemandOp);                                                   \
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 60f727dd97c..ab6eb0c7cf7 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -17682,6 +17682,37 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaConvV2Op : TF_Op<"XlaConvV2", [NoSideEffect]> {
+  let summary = "Wraps the XLA ConvGeneralDilated operator, documented at";
+
+  let description = [{
+https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+  }];
+
+  let arguments = (ins
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the input tensor}]>:$lhs,
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the kernel tensor}]>:$rhs,
+    Arg<TF_I32OrI64Tensor, [{the inter-window strides}]>:$window_strides,
+    Arg<TF_I32OrI64Tensor, [{the padding to apply at the start and end of each input dimensions}]>:$padding,
+    Arg<TF_I32OrI64Tensor, [{dilation to apply between input elements}]>:$lhs_dilation,
+    Arg<TF_I32OrI64Tensor, [{dilation to apply between kernel elements}]>:$rhs_dilation,
+    Arg<TF_I32OrI64Tensor, [{number of feature groups for grouped convolution.}]>:$feature_group_count,
+
+    StrAttr:$dimension_numbers,
+    StrAttr:$precision_config
+  );
+
+  let results = (outs
+    TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>;
+  TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>;
+}
+
 def TF_XlaDotOp : TF_Op<"XlaDot", [NoSideEffect]> {
   let summary = "Wraps the XLA DotGeneral operator, documented at";
 
@@ -17705,6 +17736,31 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaDotV2Op : TF_Op<"XlaDotV2", [NoSideEffect]> {
+  let summary = "Wraps the XLA DotGeneral operator, documented at";
+
+  let description = [{
+https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+  }];
+
+  let arguments = (ins
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the LHS tensor}]>:$lhs,
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the RHS tensor}]>:$rhs,
+
+    StrAttr:$dimension_numbers,
+    StrAttr:$precision_config
+  );
+
+  let results = (outs
+    TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>;
+}
+
 def TF_XlaDynamicSliceOp : TF_Op<"XlaDynamicSlice", [NoSideEffect]> {
   let summary = "Wraps the XLA DynamicSlice operator, documented at";
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index d34f40c6b05..bae0196fc63 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -264,7 +264,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::UpperBoundOp>(),
     TypeID::get<TF::XlaBroadcastHelperOp>(),
     TypeID::get<TF::XlaConvOp>(),
+    TypeID::get<TF::XlaConvV2Op>(),
     TypeID::get<TF::XlaDotOp>(),
+    TypeID::get<TF::XlaDotV2Op>(),
     TypeID::get<TF::XlaDynamicSliceOp>(),
     TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
     TypeID::get<TF::XlaEinsumOp>(),
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 7a99c27a075..cf7d342ddb8 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -205,6 +205,35 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
               ],
               dtype=dtype))
 
+  def testDotGeneralInt8xInt8ToInt32(self):
+
+    def dot_fn(lhs, rhs):
+      dnums = xla_data_pb2.DotDimensionNumbers()
+      dnums.lhs_contracting_dimensions.append(2)
+      dnums.rhs_contracting_dimensions.append(1)
+      dnums.lhs_batch_dimensions.append(0)
+      dnums.rhs_batch_dimensions.append(0)
+      return xla.dot_general(
+          lhs, rhs, dimension_numbers=dnums, preferred_element_type=np.int32)
+
+    lhs = np.array([
+        [[1, 2], [3, 4]],
+        [[5, 6], [7, 8]],
+    ], dtype=np.int8)
+    rhs = np.array([
+        [[1, 2, 3], [4, 5, 6]],
+        [[7, 8, 9], [10, 11, 12]],
+    ],
+                   dtype=np.int8)
+    self._assertOpOutputMatchesExpected(
+        dot_fn,
+        args=(lhs, rhs),
+        expected=np.array([
+            [[9, 12, 15], [19, 26, 33]],
+            [[95, 106, 117], [129, 144, 159]],
+        ],
+                          dtype=np.int32))
+
   def testNeg(self):
     for dtype in self.numeric_types - {np.uint8, np.int8}:
       self._assertOpOutputMatchesExpected(
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
index 7a8aec295a6..4cc49c34363 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -38,6 +39,7 @@ class XlaConvOp : public XlaOpKernel {
     OP_REQUIRES(context,
                 precision_config_.ParsePartialFromString(precision_config_attr),
                 errors::InvalidArgument("Error parsing precision config."));
+    preferred_element_type_ = absl::nullopt;
   }
 
   void Compile(XlaOpKernelContext* context) override {
@@ -77,10 +79,13 @@ class XlaConvOp : public XlaOpKernel {
     xla::XlaOp output = xla::ConvGeneralDilated(
         context->Input(0), context->Input(1), window_strides, padding,
         lhs_dilation, rhs_dilation, dnums_, feature_group_count,
-        /*batch_group_count=*/1, &precision_config_);
+        /*batch_group_count=*/1, &precision_config_, preferred_element_type_);
     context->SetOutput(0, output);
   }
 
+ protected:
+  absl::optional<xla::PrimitiveType> preferred_element_type_;
+
  private:
   xla::ConvolutionDimensionNumbers dnums_;
   xla::PrecisionConfig precision_config_;
@@ -96,5 +101,29 @@ REGISTER_XLA_OP(Name("XlaConv")
                     .CompileTimeConstantInput("padding"),
                 XlaConvOp);
 
+class XlaConvV2Op : public XlaConvOp {
+ public:
+  explicit XlaConvV2Op(OpKernelConstruction* context) : XlaConvOp(context) {
+    DataType preferred_element_dtype;
+    OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type",
+                                             &preferred_element_dtype));
+    xla::PrimitiveType preferred_element_type;
+    OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype,
+                                                    &preferred_element_type));
+    preferred_element_type_ = preferred_element_type;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaConvV2Op);
+};
+
+REGISTER_XLA_OP(Name("XlaConvV2")
+                    .CompileTimeConstantInput("window_strides")
+                    .CompileTimeConstantInput("lhs_dilation")
+                    .CompileTimeConstantInput("rhs_dilation")
+                    .CompileTimeConstantInput("feature_group_count")
+                    .CompileTimeConstantInput("padding"),
+                XlaConvV2Op);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
index 40b15b5579a..f7e938aa2e3 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -14,12 +14,14 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
 
 namespace tensorflow {
 namespace {
@@ -39,6 +41,7 @@ class XlaDotOp : public XlaOpKernel {
         context,
         precision_config_.ParsePartialFromString(precision_config_attr),
         errors::InvalidArgument("Error parsing convolution dimension numbers"));
+    preferred_element_type_ = absl::nullopt;
   }
 
   void Compile(XlaOpKernelContext* context) override {
@@ -47,19 +50,40 @@ class XlaDotOp : public XlaOpKernel {
 
     // We do only minimal checking, relying on XLA to check the shape
     // invariants.
-    xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
-                                        dnums_, &precision_config_);
+    xla::XlaOp output =
+        xla::DotGeneral(context->Input(0), context->Input(1), dnums_,
+                        &precision_config_, preferred_element_type_);
     context->SetOutput(0, output);
   }
 
+ protected:
+  absl::optional<xla::PrimitiveType> preferred_element_type_;
+
  private:
   xla::DotDimensionNumbers dnums_;
   xla::PrecisionConfig precision_config_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
 };
 
 REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
 
+class XlaDotV2Op : public XlaDotOp {
+ public:
+  explicit XlaDotV2Op(OpKernelConstruction* context) : XlaDotOp(context) {
+    DataType preferred_element_dtype;
+    OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type",
+                                             &preferred_element_dtype));
+    xla::PrimitiveType preferred_element_type;
+    OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype,
+                                                    &preferred_element_type));
+    preferred_element_type_ = preferred_element_type;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaDotV2Op);
+};
+
+REGISTER_XLA_OP(Name("XlaDotV2"), XlaDotV2Op);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index e311424ae5f..217bb19f952 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -161,6 +161,147 @@ dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
 precision_config: a serialized xla::PrecisionConfig proto.
 )doc");
 
+REGISTER_OP("XlaConvV2")
+    .Input("lhs: LhsT")
+    .Input("rhs: RhsT")
+    .Input("window_strides: Tindices")
+    .Input("padding: Tindices")
+    .Input("lhs_dilation: Tindices")
+    .Input("rhs_dilation: Tindices")
+    .Input("feature_group_count: Tindices")
+    .Attr("LhsT: numbertype")
+    .Attr("RhsT: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("dimension_numbers: string")
+    .Attr("precision_config: string")
+    .Attr("preferred_element_type: numbertype")
+    .Output("output: preferred_element_type")
+    .SetShapeFn(UnchangedRank)
+    .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+
+lhs: the input tensor
+rhs: the kernel tensor
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimensions
+lhs_dilation: dilation to apply between input elements
+rhs_dilation: dilation to apply between kernel elements
+feature_group_count: number of feature groups for grouped convolution.
+dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+preferred_element_type: The type of the tensor.
+)doc");
+
+static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
+  shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
+  shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
+  if (!c->FullyDefined(lhs_shape_handle) ||
+      !c->FullyDefined(rhs_shape_handle)) {
+    return shape_inference::UnknownShape(c);
+  }
+
+  string dimension_numbers_string;
+  TF_RETURN_IF_ERROR(
+      c->GetAttr("dimension_numbers", &dimension_numbers_string));
+
+  xla::DotDimensionNumbers dimension_numbers;
+  dimension_numbers.ParseFromString(dimension_numbers_string);
+
+  // Check that number of contracting dimensions match.
+  if (dimension_numbers.lhs_contracting_dimensions_size() !=
+      dimension_numbers.rhs_contracting_dimensions_size())
+    return errors::InvalidArgument(
+        "Must specify the same number of contracting dimensions for lhs "
+        "and rhs. Got: ",
+        dimension_numbers.lhs_contracting_dimensions_size(), " and ",
+        dimension_numbers.rhs_contracting_dimensions_size());
+
+  // Check that contracting dimension sizes match.
+  for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
+       ++i) {
+    const int64 lhs_contracting_dimension =
+        dimension_numbers.lhs_contracting_dimensions(i);
+    const int64 rhs_contracting_dimension =
+        dimension_numbers.rhs_contracting_dimensions(i);
+    shape_inference::DimensionOrConstant lhs_contracting_dimension_or_constant(
+        c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
+    shape_inference::DimensionOrConstant rhs_contracting_dimension_or_constant(
+        c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
+    const int64 lhs_contracting_dimension_size =
+        c->Value(lhs_contracting_dimension_or_constant);
+    const int64 rhs_contracting_dimension_size =
+        c->Value(rhs_contracting_dimension_or_constant);
+    if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
+      return errors::InvalidArgument(
+          "Contracting dimension sizes do not match. Got: ",
+          lhs_contracting_dimension_size, " and ",
+          rhs_contracting_dimension_size);
+    }
+  }
+
+  // Check that number of batch dimensions match.
+  if (dimension_numbers.lhs_batch_dimensions_size() !=
+      dimension_numbers.rhs_batch_dimensions_size())
+    return errors::InvalidArgument(
+        "Must specify the same number of batch dimensions for lhs "
+        "and rhs. Got: ",
+        dimension_numbers.lhs_batch_dimensions_size(), " and ",
+        dimension_numbers.rhs_batch_dimensions_size());
+
+  // Check that batch dimension sizes match.
+  for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
+    const int64 lhs_batch_dimension = dimension_numbers.lhs_batch_dimensions(i);
+    const int64 rhs_batch_dimension = dimension_numbers.rhs_batch_dimensions(i);
+    shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
+        c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
+    shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
+        c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
+    const int64 lhs_batch_dimension_size =
+        c->Value(lhs_batch_dimension_or_constant);
+    const int64 rhs_batch_dimension_size =
+        c->Value(rhs_batch_dimension_or_constant);
+    if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
+      return errors::InvalidArgument(
+          "Batch dimension sizes do not match. Got: ", lhs_batch_dimension_size,
+          " and ", rhs_batch_dimension_size);
+    }
+  }
+
+  // Build the output shape. The batch dimensions come first (taken from
+  // lhs, in the order listed in dimension_numbers), followed by the lhs
+  // dimensions that are neither contracting nor batch, followed by the rhs
+  // dimensions that are neither contracting nor batch. Contracting
+  // dimensions do not appear in the result.
+  std::vector<shape_inference::DimensionHandle> output_dims;
+  for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
+    output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
+  }
+  const int32 lhs_rank = c->Rank(lhs_shape_handle);
+  for (int64 i = 0; i < lhs_rank; ++i) {
+    if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
+                              i) ||
+        absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
+      continue;
+    }
+    output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
+  }
+
+  const int32 rhs_rank = c->Rank(rhs_shape_handle);
+  for (int64 i = 0; i < rhs_rank; ++i) {
+    if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
+                              i) ||
+        absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
+      continue;
+    }
+    output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
+  }
+
+  c->set_output(0, c->MakeShape(output_dims));
+  return Status::OK();
+}
+
 REGISTER_OP("XlaDot")
     .Input("lhs: T")
     .Input("rhs: T")
@@ -168,120 +309,7 @@ REGISTER_OP("XlaDot")
     .Attr("dimension_numbers: string")
     .Attr("precision_config: string")
     .Output("output: T")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
-      shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
-      if (!c->FullyDefined(lhs_shape_handle) ||
-          !c->FullyDefined(rhs_shape_handle)) {
-        return shape_inference::UnknownShape(c);
-      }
-
-      string dimension_numbers_string;
-      TF_RETURN_IF_ERROR(
-          c->GetAttr("dimension_numbers", &dimension_numbers_string));
-
-      xla::DotDimensionNumbers dimension_numbers;
-      dimension_numbers.ParseFromString(dimension_numbers_string);
-
-      // Check that number of contracting dimensions match.
-      if (dimension_numbers.lhs_contracting_dimensions_size() !=
-          dimension_numbers.rhs_contracting_dimensions_size())
-        return errors::InvalidArgument(
-            "Must specify the same number of contracting dimensions for lhs "
-            "and rhs. Got: ",
-            dimension_numbers.lhs_contracting_dimensions_size(), " and ",
-            dimension_numbers.rhs_contracting_dimensions_size());
-
-      // Check that contracting dimension sizes match.
-      for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
-           ++i) {
-        const int64 lhs_contracting_dimension =
-            dimension_numbers.lhs_contracting_dimensions(i);
-        const int64 rhs_contracting_dimension =
-            dimension_numbers.rhs_contracting_dimensions(i);
-        shape_inference::DimensionOrConstant
-            lhs_contracting_dimension_or_constant(
-                c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
-        shape_inference::DimensionOrConstant
-            rhs_contracting_dimension_or_constant(
-                c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
-        const int64 lhs_contracting_dimension_size =
-            c->Value(lhs_contracting_dimension_or_constant);
-        const int64 rhs_contracting_dimension_size =
-            c->Value(rhs_contracting_dimension_or_constant);
-        if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
-          return errors::InvalidArgument(
-              "Contracting dimension sizes do not match. Got: ",
-              lhs_contracting_dimension_size, " and ",
-              rhs_contracting_dimension_size);
-        }
-      }
-
-      // Check that number of batch dimensions match.
-      if (dimension_numbers.lhs_batch_dimensions_size() !=
-          dimension_numbers.rhs_batch_dimensions_size())
-        return errors::InvalidArgument(
-            "Must specify the same number of batch dimensions for lhs "
-            "and rhs. Got: ",
-            dimension_numbers.lhs_batch_dimensions_size(), " and ",
-            dimension_numbers.rhs_batch_dimensions_size());
-
-      // Check that batch dimension sizes match.
-      for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size();
-           ++i) {
-        const int64 lhs_batch_dimension =
-            dimension_numbers.lhs_batch_dimensions(i);
-        const int64 rhs_batch_dimension =
-            dimension_numbers.rhs_batch_dimensions(i);
-        shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
-            c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
-        shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
-            c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
-        const int64 lhs_batch_dimension_size =
-            c->Value(lhs_batch_dimension_or_constant);
-        const int64 rhs_batch_dimension_size =
-            c->Value(rhs_batch_dimension_or_constant);
-        if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
-          return errors::InvalidArgument(
-              "Batch dimension sizes do not match. Got: ",
-              lhs_batch_dimension_size, " and ", rhs_batch_dimension_size);
-        }
-      }
-
-      // The ranks of lhs and rhs are decremented by 1 respectively due to the
-      // contraction, and added for the rank of the result. When an input tensor
-      // is a scalar, its contribution to the rank of the result is 0. Generate
-      // the result dimensions in order, rhs dimensions followed by lhs
-      // dimensions except the contracted and batch dimensions.
-      std::vector<shape_inference::DimensionHandle> output_dims;
-      for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
-        output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
-      }
-      const int32 lhs_rank = c->Rank(lhs_shape_handle);
-      for (int64 i = 0; i < lhs_rank; ++i) {
-        if (absl::c_linear_search(
-                dimension_numbers.lhs_contracting_dimensions(), i) ||
-            absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
-                                  i)) {
-          continue;
-        }
-        output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
-      }
-
-      const int32 rhs_rank = c->Rank(rhs_shape_handle);
-      for (int64 i = 0; i < rhs_rank; ++i) {
-        if (absl::c_linear_search(
-                dimension_numbers.rhs_contracting_dimensions(), i) ||
-            absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
-                                  i)) {
-          continue;
-        }
-        output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
-      }
-
-      c->set_output(0, c->MakeShape(output_dims));
-      return Status::OK();
-    })
+    .SetShapeFn(XlaDotShapeFunction)
     .Doc(R"doc(
 Wraps the XLA DotGeneral operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
@@ -293,6 +321,28 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
 precision_config: a serialized xla::PrecisionConfig proto.
 )doc");
 
+REGISTER_OP("XlaDotV2")
+    .Input("lhs: LhsT")
+    .Input("rhs: RhsT")
+    .Attr("LhsT: numbertype")
+    .Attr("RhsT: numbertype")
+    .Attr("dimension_numbers: string")
+    .Attr("precision_config: string")
+    .Attr("preferred_element_type: numbertype")
+    .Output("output: preferred_element_type")
+    .SetShapeFn(XlaDotShapeFunction)
+    .Doc(R"doc(
+Wraps the XLA DotGeneral operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+
+lhs: the LHS tensor
+rhs: the RHS tensor
+dimension_numbers: a serialized xla::DotDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+preferred_element_type: The type of the tensor.
+)doc");
+
 REGISTER_OP("XlaSetBound")
     .Input("input: int32")
     .Input("bound: int32")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 4a12017bbce..5acd3025e1a 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_random_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops.numpy_ops import np_utils
 
 # TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
 # ops include:
@@ -249,6 +250,7 @@ def conv(lhs,
          dimension_numbers,
          feature_group_count=1,
          precision_config=None,
+         preferred_element_type=None,
          name=None):
   """Wraps the XLA ConvGeneralDilated operator.
 
@@ -266,6 +268,7 @@ def conv(lhs,
     dimension_numbers: a `ConvolutionDimensionNumbers` proto.
     feature_group_count: number of feature groups for grouped convolution.
     precision_config: a `xla.PrecisionConfig` proto.
+    preferred_element_type: the result `dtype`.
     name: an optional name for the operator
 
   Returns:
@@ -274,7 +277,9 @@ def conv(lhs,
   precision_config_proto = ""
   if precision_config:
     precision_config_proto = precision_config.SerializeToString()
-  return gen_xla_ops.xla_conv(
+  if preferred_element_type is None:
+    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
+  return gen_xla_ops.xla_conv_v2(
       lhs,
       rhs,
       window_strides=window_strides,
@@ -284,6 +289,7 @@ def conv(lhs,
       feature_group_count=feature_group_count,
       dimension_numbers=dimension_numbers.SerializeToString(),
       precision_config=precision_config_proto,
+      preferred_element_type=preferred_element_type,
       name=name)
 
 
@@ -294,15 +300,23 @@ def dot(lhs, rhs, name=None):
   return math_ops.tensordot(lhs, rhs, axes=1, name=name)
 
 
-def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
+def dot_general(lhs,
+                rhs,
+                dimension_numbers,
+                precision_config=None,
+                preferred_element_type=None,
+                name=None):
   precision_config_proto = ""
   if precision_config:
     precision_config_proto = precision_config.SerializeToString()
-  return gen_xla_ops.xla_dot(
+  if preferred_element_type is None:
+    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
+  return gen_xla_ops.xla_dot_v2(
       lhs,
       rhs,
       dimension_numbers=dimension_numbers.SerializeToString(),
       precision_config=precision_config_proto,
+      preferred_element_type=preferred_element_type,
       name=name)
 
 
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
index 7aa635a7711..e5084d7edd8 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
@@ -143,7 +143,8 @@ bool IsOpTensorCoreEligible(absl::string_view tf_op_name) {
       || absl::StrContains(tf_op_name, "CudnnRNNForward")
       || absl::StrContains(tf_op_name, "CudnnRNNBackprop")
       // Special cases.
-      || absl::EndsWith(tf_op_name, "XlaDot");
+      || absl::EndsWith(tf_op_name, "XlaDot")
+      || absl::EndsWith(tf_op_name, "XlaDotV2");
   // clang-format on
 }