From 4856f23a49dfe44fb697d5affcf10fb2a2585523 Mon Sep 17 00:00:00 2001 From: David Majnemer <majnemer@google.com> Date: Mon, 22 Feb 2021 11:15:54 -0800 Subject: [PATCH] Introduce Xla{Dot,Conv}V2 These support mixed operand precision (bf16 @ f32 -> f32), (int8 @ int8 -> int32) and thus diverge significantly from the original Xla{Dot,Conv}. PiperOrigin-RevId: 358860920 Change-Id: I316f1b43268b268ce50528f7aa307182bce304f6 --- .../compiler/jit/compilability_check_util.cc | 2 + .../compiler/jit/mark_for_compilation_pass.cc | 2 + .../jit/xla_ops_on_regular_devices.cc | 10 + .../mlir/tensorflow/ir/tf_generated_ops.td | 56 ++++ .../xla/transforms/legalize_tf_with_tf2xla.cc | 2 + tensorflow/compiler/tests/xla_ops_test.py | 29 ++ .../compiler/tf2xla/kernels/xla_conv_op.cc | 31 +- .../compiler/tf2xla/kernels/xla_dot_op.cc | 30 +- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 278 +++++++++++------- tensorflow/compiler/tf2xla/python/xla.py | 20 +- .../core/profiler/utils/kernel_stats_utils.cc | 3 +- 11 files changed, 341 insertions(+), 122 deletions(-) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index fb4c187f5bd..e155becab11 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -694,8 +694,10 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes( static auto const ops_triggering_xla_compilation = new absl::flat_hash_set<std::string>{"XlaBroadcastHelper", "XlaConv", + "XlaConvV2", "XlaDequantize", "XlaDot", + "XlaDotV2", "XlaDynamicSlice", "XlaDynamicUpdateSlice", "XlaEinsum", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4d2c9f45318..8b74ce8042d 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2057,8 +2057,10 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() { "While", "XlaBroadcastHelper", "XlaConv", + "XlaConvV2", "XlaDequantize", "XlaDot", + "XlaDotV2", "XlaDynamicSlice", "XlaDynamicUpdateSlice", "XlaEinsum", diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc index 7ddb1a60b73..7c4378415a9 100644 --- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc +++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc @@ -29,6 +29,14 @@ namespace tensorflow { .HostMemory("feature_group_count") \ .Device(DEVICE), \ XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaConvV2") \ + .HostMemory("window_strides") \ + .HostMemory("padding") \ + .HostMemory("lhs_dilation") \ + .HostMemory("rhs_dilation") \ + .HostMemory("feature_group_count") \ + .Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER( \ Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \ XlaCompileOnDemandOp); \ @@ -38,6 +46,8 @@ namespace tensorflow { XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \ XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaDotV2").Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER( \ Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE), \ XlaCompileOnDemandOp); \ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 60f727dd97c..ab6eb0c7cf7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -17682,6 +17682,37 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaConvV2Op : TF_Op<"XlaConvV2", [NoSideEffect]> { + let summary = "Wraps the XLA ConvGeneralDilated operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + }]; + + let arguments = (ins + Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the input tensor}]>:$lhs, + Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the kernel tensor}]>:$rhs, + Arg<TF_I32OrI64Tensor, [{the inter-window strides}]>:$window_strides, + Arg<TF_I32OrI64Tensor, [{the padding to apply at the start and end of each input dimensions}]>:$padding, + Arg<TF_I32OrI64Tensor, [{dilation to apply between input elements}]>:$lhs_dilation, + Arg<TF_I32OrI64Tensor, [{dilation to apply between kernel elements}]>:$rhs_dilation, + Arg<TF_I32OrI64Tensor, [{number of feature groups for grouped convolution.}]>:$feature_group_count, + + StrAttr:$dimension_numbers, + StrAttr:$precision_config + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>; +} + def TF_XlaDotOp : TF_Op<"XlaDot", [NoSideEffect]> { let summary = "Wraps the XLA DotGeneral operator, documented at"; @@ -17705,6 +17736,31 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaDotV2Op : TF_Op<"XlaDotV2", [NoSideEffect]> { + let summary = "Wraps the XLA DotGeneral operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + }]; + + let arguments = (ins + Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the LHS tensor}]>:$lhs, + Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the RHS tensor}]>:$rhs, + + StrAttr:$dimension_numbers, + StrAttr:$precision_config + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>; +} + def TF_XlaDynamicSliceOp : TF_Op<"XlaDynamicSlice", [NoSideEffect]> { let summary = "Wraps the XLA DynamicSlice operator, documented at"; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index d34f40c6b05..bae0196fc63 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -264,7 +264,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get<TF::UpperBoundOp>(), TypeID::get<TF::XlaBroadcastHelperOp>(), TypeID::get<TF::XlaConvOp>(), + TypeID::get<TF::XlaConvV2Op>(), TypeID::get<TF::XlaDotOp>(), + TypeID::get<TF::XlaDotV2Op>(), TypeID::get<TF::XlaDynamicSliceOp>(), TypeID::get<TF::XlaDynamicUpdateSliceOp>(), TypeID::get<TF::XlaEinsumOp>(), diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 7a99c27a075..cf7d342ddb8 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -205,6 +205,35 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): ], dtype=dtype)) + def testDotGeneralInt8xInt8ToInt32(self): + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + return xla.dot_general( + lhs, rhs, dimension_numbers=dnums, preferred_element_type=np.int32) + + lhs = np.array([ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=np.int8) + rhs = np.array([ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], + dtype=np.int8) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array([ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=np.int32)) + def testNeg(self): for dtype in self.numeric_types - {np.uint8, np.int8}: self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 7a8aec295a6..4cc49c34363 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -38,6 +39,7 @@ class XlaConvOp : public XlaOpKernel { OP_REQUIRES(context, precision_config_.ParsePartialFromString(precision_config_attr), errors::InvalidArgument("Error parsing precision config.")); + preferred_element_type_ = absl::nullopt; } void Compile(XlaOpKernelContext* context) override { @@ -77,10 +79,13 @@ class XlaConvOp : public XlaOpKernel { xla::XlaOp output = xla::ConvGeneralDilated( context->Input(0), context->Input(1), window_strides, padding, lhs_dilation, rhs_dilation, dnums_, feature_group_count, - /*batch_group_count=*/1, &precision_config_); + /*batch_group_count=*/1, &precision_config_, preferred_element_type_); context->SetOutput(0, output); } + protected: + absl::optional<xla::PrimitiveType> preferred_element_type_; + private: xla::ConvolutionDimensionNumbers dnums_; xla::PrecisionConfig precision_config_; @@ -96,5 +101,29 @@ REGISTER_XLA_OP(Name("XlaConv") .CompileTimeConstantInput("padding"), XlaConvOp); +class XlaConvV2Op : public XlaConvOp { + public: + explicit XlaConvV2Op(OpKernelConstruction* context) : XlaConvOp(context) { + DataType preferred_element_dtype; + OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type", + &preferred_element_dtype)); + xla::PrimitiveType preferred_element_type; + OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype, + &preferred_element_type)); + preferred_element_type_ = preferred_element_type; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaConvV2Op); +}; + +REGISTER_XLA_OP(Name("XlaConvV2") + .CompileTimeConstantInput("window_strides") + .CompileTimeConstantInput("lhs_dilation") + .CompileTimeConstantInput("rhs_dilation") + .CompileTimeConstantInput("feature_group_count") + .CompileTimeConstantInput("padding"), + XlaConvOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 40b15b5579a..f7e938aa2e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -14,12 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -39,6 +41,7 @@ class XlaDotOp : public XlaOpKernel { context, precision_config_.ParsePartialFromString(precision_config_attr), errors::InvalidArgument("Error parsing convolution dimension numbers")); + preferred_element_type_ = absl::nullopt; } void Compile(XlaOpKernelContext* context) override { @@ -47,19 +50,40 @@ class XlaDotOp : public XlaOpKernel { // We do only minimal checking, relying on XLA to check the shape // invariants. - xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1), - dnums_, &precision_config_); + xla::XlaOp output = + xla::DotGeneral(context->Input(0), context->Input(1), dnums_, + &precision_config_, preferred_element_type_); context->SetOutput(0, output); } + protected: + absl::optional<xla::PrimitiveType> preferred_element_type_; + private: xla::DotDimensionNumbers dnums_; xla::PrecisionConfig precision_config_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); }; REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp); +class XlaDotV2Op : public XlaDotOp { + public: + explicit XlaDotV2Op(OpKernelConstruction* context) : XlaDotOp(context) { + DataType preferred_element_dtype; + OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type", + &preferred_element_dtype)); + xla::PrimitiveType preferred_element_type; + OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype, + &preferred_element_type)); + preferred_element_type_ = preferred_element_type; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaDotV2Op); +}; + +REGISTER_XLA_OP(Name("XlaDotV2"), XlaDotV2Op); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index e311424ae5f..217bb19f952 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -161,6 +161,147 @@ dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. )doc"); +REGISTER_OP("XlaConvV2") + .Input("lhs: LhsT") + .Input("rhs: RhsT") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("lhs_dilation: Tindices") + .Input("rhs_dilation: Tindices") + .Input("feature_group_count: Tindices") + .Attr("LhsT: numbertype") + .Attr("RhsT: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Attr("preferred_element_type: numbertype") + .Output("output: preferred_element_type") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + +lhs: the input tensor +rhs: the kernel tensor +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +lhs_dilation: dilation to apply between input elements +rhs_dilation: dilation to apply between kernel elements +feature_group_count: number of feature groups for grouped convolution. +dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfig proto. +preferred_element_type: The type of the tensor. +)doc"); + +static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle lhs_shape_handle = c->input(0); + shape_inference::ShapeHandle rhs_shape_handle = c->input(1); + if (!c->FullyDefined(lhs_shape_handle) || + !c->FullyDefined(rhs_shape_handle)) { + return shape_inference::UnknownShape(c); + } + + string dimension_numbers_string; + TF_RETURN_IF_ERROR( + c->GetAttr("dimension_numbers", &dimension_numbers_string)); + + xla::DotDimensionNumbers dimension_numbers; + dimension_numbers.ParseFromString(dimension_numbers_string); + + // Check that number of contracting dimensions match. + if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size()) + return errors::InvalidArgument( + "Must specify the same number of contracting dimensions for lhs " + "and rhs. Got: ", + dimension_numbers.lhs_contracting_dimensions_size(), " and ", + dimension_numbers.rhs_contracting_dimensions_size()); + + // Check that contracting dimension sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size(); + ++i) { + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(i); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(i); + shape_inference::DimensionOrConstant lhs_contracting_dimension_or_constant( + c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension)); + shape_inference::DimensionOrConstant rhs_contracting_dimension_or_constant( + c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension)); + const int64 lhs_contracting_dimension_size = + c->Value(lhs_contracting_dimension_or_constant); + const int64 rhs_contracting_dimension_size = + c->Value(rhs_contracting_dimension_or_constant); + if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) { + return errors::InvalidArgument( + "Contracting dimension sizes do not match. Got: ", + lhs_contracting_dimension_size, " and ", + rhs_contracting_dimension_size); + } + } + + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) + return errors::InvalidArgument( + "Must specify the same number of batch dimensions for lhs " + "and rhs. Got: ", + dimension_numbers.lhs_batch_dimensions_size(), " and ", + dimension_numbers.rhs_batch_dimensions_size()); + + // Check that batch dimension sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + const int64 lhs_batch_dimension = dimension_numbers.lhs_batch_dimensions(i); + const int64 rhs_batch_dimension = dimension_numbers.rhs_batch_dimensions(i); + shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant( + c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension)); + shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant( + c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension)); + const int64 lhs_batch_dimension_size = + c->Value(lhs_batch_dimension_or_constant); + const int64 rhs_batch_dimension_size = + c->Value(rhs_batch_dimension_or_constant); + if (lhs_batch_dimension_size != rhs_batch_dimension_size) { + return errors::InvalidArgument( + "Batch dimension sizes do not match. Got: ", lhs_batch_dimension_size, + " and ", rhs_batch_dimension_size); + } + } + + // The ranks of lhs and rhs are decremented by 1 respectively due to the + // contraction, and added for the rank of the result. When an input tensor + // is a scalar, its contribution to the rank of the result is 0. Generate + // the result dimensions in order, rhs dimensions followed by lhs + // dimensions except the contracted and batch dimensions. + std::vector<shape_inference::DimensionHandle> output_dims; + for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { + output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim)); + } + const int32 lhs_rank = c->Rank(lhs_shape_handle); + for (int64 i = 0; i < lhs_rank; ++i) { + if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(), + i) || + absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { + continue; + } + output_dims.emplace_back(c->Dim(lhs_shape_handle, i)); + } + + const int32 rhs_rank = c->Rank(rhs_shape_handle); + for (int64 i = 0; i < rhs_rank; ++i) { + if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(), + i) || + absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) { + continue; + } + output_dims.emplace_back(c->Dim(rhs_shape_handle, i)); + } + + c->set_output(0, c->MakeShape(output_dims)); + return Status::OK(); +} + REGISTER_OP("XlaDot") .Input("lhs: T") .Input("rhs: T") @@ -168,120 +309,7 @@ REGISTER_OP("XlaDot") .Attr("dimension_numbers: string") .Attr("precision_config: string") .Output("output: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle lhs_shape_handle = c->input(0); - shape_inference::ShapeHandle rhs_shape_handle = c->input(1); - if (!c->FullyDefined(lhs_shape_handle) || - !c->FullyDefined(rhs_shape_handle)) { - return shape_inference::UnknownShape(c); - } - - string dimension_numbers_string; - TF_RETURN_IF_ERROR( - c->GetAttr("dimension_numbers", &dimension_numbers_string)); - - xla::DotDimensionNumbers dimension_numbers; - dimension_numbers.ParseFromString(dimension_numbers_string); - - // Check that number of contracting dimensions match. - if (dimension_numbers.lhs_contracting_dimensions_size() != - dimension_numbers.rhs_contracting_dimensions_size()) - return errors::InvalidArgument( - "Must specify the same number of contracting dimensions for lhs " - "and rhs. Got: ", - dimension_numbers.lhs_contracting_dimensions_size(), " and ", - dimension_numbers.rhs_contracting_dimensions_size()); - - // Check that contracting dimension sizes match. - for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size(); - ++i) { - const int64 lhs_contracting_dimension = - dimension_numbers.lhs_contracting_dimensions(i); - const int64 rhs_contracting_dimension = - dimension_numbers.rhs_contracting_dimensions(i); - shape_inference::DimensionOrConstant - lhs_contracting_dimension_or_constant( - c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension)); - shape_inference::DimensionOrConstant - rhs_contracting_dimension_or_constant( - c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension)); - const int64 lhs_contracting_dimension_size = - c->Value(lhs_contracting_dimension_or_constant); - const int64 rhs_contracting_dimension_size = - c->Value(rhs_contracting_dimension_or_constant); - if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) { - return errors::InvalidArgument( - "Contracting dimension sizes do not match. Got: ", - lhs_contracting_dimension_size, " and ", - rhs_contracting_dimension_size); - } - } - - // Check that number of batch dimensions match. - if (dimension_numbers.lhs_batch_dimensions_size() != - dimension_numbers.rhs_batch_dimensions_size()) - return errors::InvalidArgument( - "Must specify the same number of batch dimensions for lhs " - "and rhs. Got: ", - dimension_numbers.lhs_batch_dimensions_size(), " and ", - dimension_numbers.rhs_batch_dimensions_size()); - - // Check that batch dimension sizes match. - for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); - ++i) { - const int64 lhs_batch_dimension = - dimension_numbers.lhs_batch_dimensions(i); - const int64 rhs_batch_dimension = - dimension_numbers.rhs_batch_dimensions(i); - shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant( - c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension)); - shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant( - c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension)); - const int64 lhs_batch_dimension_size = - c->Value(lhs_batch_dimension_or_constant); - const int64 rhs_batch_dimension_size = - c->Value(rhs_batch_dimension_or_constant); - if (lhs_batch_dimension_size != rhs_batch_dimension_size) { - return errors::InvalidArgument( - "Batch dimension sizes do not match. Got: ", - lhs_batch_dimension_size, " and ", rhs_batch_dimension_size); - } - } - - // The ranks of lhs and rhs are decremented by 1 respectively due to the - // contraction, and added for the rank of the result. When an input tensor - // is a scalar, its contribution to the rank of the result is 0. Generate - // the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted and batch dimensions. - std::vector<shape_inference::DimensionHandle> output_dims; - for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { - output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim)); - } - const int32 lhs_rank = c->Rank(lhs_shape_handle); - for (int64 i = 0; i < lhs_rank; ++i) { - if (absl::c_linear_search( - dimension_numbers.lhs_contracting_dimensions(), i) || - absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), - i)) { - continue; - } - output_dims.emplace_back(c->Dim(lhs_shape_handle, i)); - } - - const int32 rhs_rank = c->Rank(rhs_shape_handle); - for (int64 i = 0; i < rhs_rank; ++i) { - if (absl::c_linear_search( - dimension_numbers.rhs_contracting_dimensions(), i) || - absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), - i)) { - continue; - } - output_dims.emplace_back(c->Dim(rhs_shape_handle, i)); - } - - c->set_output(0, c->MakeShape(output_dims)); - return Status::OK(); - }) + .SetShapeFn(XlaDotShapeFunction) .Doc(R"doc( Wraps the XLA DotGeneral operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral @@ -293,6 +321,28 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. )doc"); +REGISTER_OP("XlaDotV2") + .Input("lhs: LhsT") + .Input("rhs: RhsT") + .Attr("LhsT: numbertype") + .Attr("RhsT: numbertype") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Attr("preferred_element_type: numbertype") + .Output("output: preferred_element_type") + .SetShapeFn(XlaDotShapeFunction) + .Doc(R"doc( +Wraps the XLA DotGeneral operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + +lhs: the LHS tensor +rhs: the RHS tensor +dimension_numbers: a serialized xla::DotDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfig proto. +preferred_element_type: The type of the tensor. +)doc"); + REGISTER_OP("XlaSetBound") .Input("input: int32") .Input("bound: int32") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 4a12017bbce..5acd3025e1a 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.numpy_ops import np_utils # TODO(phawkins): provide wrappers for all XLA operators. Currently the missing # ops include: @@ -249,6 +250,7 @@ def conv(lhs, dimension_numbers, feature_group_count=1, precision_config=None, + preferred_element_type=None, name=None): """Wraps the XLA ConvGeneralDilated operator. @@ -266,6 +268,7 @@ def conv(lhs, dimension_numbers: a `ConvolutionDimensionNumbers` proto. feature_group_count: number of feature groups for grouped convolution. precision_config: a `xla.PrecisionConfig` proto. + preferred_element_type: the result `dtype`. name: an optional name for the operator Returns: @@ -274,7 +277,9 @@ def conv(lhs, precision_config_proto = "" if precision_config: precision_config_proto = precision_config.SerializeToString() - return gen_xla_ops.xla_conv( + if preferred_element_type is None: + preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype) + return gen_xla_ops.xla_conv_v2( lhs, rhs, window_strides=window_strides, @@ -284,6 +289,7 @@ def conv(lhs, feature_group_count=feature_group_count, dimension_numbers=dimension_numbers.SerializeToString(), precision_config=precision_config_proto, + preferred_element_type=preferred_element_type, name=name) @@ -294,15 +300,23 @@ def dot(lhs, rhs, name=None): return math_ops.tensordot(lhs, rhs, axes=1, name=name) -def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): +def dot_general(lhs, + rhs, + dimension_numbers, + precision_config=None, + preferred_element_type=None, + name=None): precision_config_proto = "" if precision_config: precision_config_proto = precision_config.SerializeToString() - return gen_xla_ops.xla_dot( + if preferred_element_type is None: + preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype) + return gen_xla_ops.xla_dot_v2( lhs, rhs, dimension_numbers=dimension_numbers.SerializeToString(), precision_config=precision_config_proto, + preferred_element_type=preferred_element_type, name=name) diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc index 7aa635a7711..e5084d7edd8 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc @@ -143,7 +143,8 @@ bool IsOpTensorCoreEligible(absl::string_view tf_op_name) { || absl::StrContains(tf_op_name, "CudnnRNNForward") || absl::StrContains(tf_op_name, "CudnnRNNBackprop") // Special cases. - || absl::EndsWith(tf_op_name, "XlaDot"); + || absl::EndsWith(tf_op_name, "XlaDot") + || absl::EndsWith(tf_op_name, "XlaDotV2"); // clang-format on }