From aec80b24204584b376a7f6c2423a6df80e35cec9 Mon Sep 17 00:00:00 2001 From: Ce Zheng Date: Tue, 17 Nov 2020 12:42:31 -0800 Subject: [PATCH] [XLA] Allow integral dot results to be accumulated as a wider type. PiperOrigin-RevId: 342923991 Change-Id: Ic3fee996c9a1c5fd10281b65264a2cef4c0cfcb4 --- tensorflow/compiler/xla/service/BUILD | 24 +++++ .../xla/service/algebraic_simplifier.cc | 4 + tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 4 + tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../compiler/xla/service/gpu/gpu_compiler.cc | 3 + .../xla/service/hlo_evaluator_test.cc | 41 ++++++++ .../compiler/xla/service/hlo_verifier.cc | 10 +- .../compiler/xla/service/integral_upcaster.cc | 76 ++++++++++++++ .../compiler/xla/service/integral_upcaster.h | 39 ++++++++ .../xla/service/integral_upcaster_test.cc | 99 +++++++++++++++++++ tensorflow/compiler/xla/shape_util.cc | 9 ++ tensorflow/compiler/xla/shape_util.h | 4 + .../compiler/xla/tests/dot_operation_test.cc | 16 +++ 14 files changed, 329 insertions(+), 2 deletions(-) create mode 100644 tensorflow/compiler/xla/service/integral_upcaster.cc create mode 100644 tensorflow/compiler/xla/service/integral_upcaster.h create mode 100644 tensorflow/compiler/xla/service/integral_upcaster_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index df2dc34be10..03a30695d13 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -5223,3 +5223,27 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "integral_upcaster", + srcs = ["integral_upcaster.cc"], + hdrs = ["integral_upcaster.h"], + deps = [ + ":hlo", + ":op_expander_pass", + ":shape_inference", + ], +) + +tf_cc_test( + name = "integral_upcaster_test", + srcs = ["integral_upcaster_test.cc"], + deps = [ + ":hlo_matchers", + ":integral_upcaster", + "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 046701c564f..39d1eddc569 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1800,6 +1800,8 @@ StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( : dot->mutable_operand(1); TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config())); + // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo. + new_dot->mutable_shape()->set_element_type(dot->shape().element_type()); if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); } else { @@ -4678,6 +4680,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { TF_ASSIGN_OR_RETURN( auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config())); dot->SetupDerivedInstruction(new_dot); + // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo. 
+ new_dot->mutable_shape()->set_element_type(dot->shape().element_type()); if (reduce_dims.empty()) { return ReplaceInstruction(hlo, new_dot); } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e5c59fc0c7a..5112db30e08 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:slice_sinker", "//tensorflow/compiler/xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:integral_upcaster", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e92f890ba67..ca67fe66994 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -102,6 +102,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/service/integral_upcaster.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" #include "tensorflow/compiler/xla/service/map_inliner.h" @@ -269,6 +270,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + + pipeline.AddPass(); + // Expand random number generation. 
pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bf2fbfe1973..5edbfebddda 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1238,6 +1238,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_util", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:integral_upcaster", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:logistic_expander", "//tensorflow/compiler/xla/service:loop_schedule_linearizer", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 1c1a028e2f9..b8b18d1cff3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -90,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/integral_upcaster.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h" @@ -142,6 +143,8 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); + // Expand random number generation. 
pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 4cd5714de1a..01ad536e033 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4567,5 +4567,46 @@ TEST_F(HloEvaluatorTest, MapBF16) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(HloEvaluatorTest, DotUpcast) { + const absl::string_view hlo_text = R"( + HloModule test + ENTRY DotUpcast { + l = s16[4,3]{1,0} parameter(0) + r = s8[3,2]{1,0} parameter(1) + ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + // lhs: + // s16[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto lhs_array = absl::make_unique>(4, 3); + lhs_array->FillUnique(1); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); + + // rhs: + // s8[3,2] { + // { 1, 2 }, + // { 3, 4 }, + // { 5, 6 }, + // } + auto rhs_array = absl::make_unique>(3, 2); + rhs_array->FillUnique(1); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&lhs_literal, &rhs_literal})); + + auto expected_array = + Array2D({{22, 28}, {58, 76}, {94, 124}, {130, 172}}); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 2f5a633fde9..13c754438eb 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -136,20 +136,26 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status 
ShapeVerifier::HandleDot(HloInstruction* dot) { - TF_ASSIGN_OR_RETURN(const Shape expected, + TF_ASSIGN_OR_RETURN(Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), dot->dot_dimension_numbers())); + if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) { + expected.set_element_type(dot->shape().element_type()); + } return CheckShape(dot, expected); } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN( - const Shape expected, + Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), convolution->feature_group_count(), convolution->batch_group_count(), convolution->window(), convolution->convolution_dimension_numbers())); + if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) { + expected.set_element_type(convolution->shape().element_type()); + } return CheckShape(convolution, expected); } diff --git a/tensorflow/compiler/xla/service/integral_upcaster.cc b/tensorflow/compiler/xla/service/integral_upcaster.cc new file mode 100644 index 00000000000..d8383b25c84 --- /dev/null +++ b/tensorflow/compiler/xla/service/integral_upcaster.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/integral_upcaster.h" + +#include "tensorflow/compiler/xla/service/shape_inference.h" + +namespace xla { +namespace { + +StatusOr> MaybeInferShape( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kDot: + return ShapeInference::InferDotOpShape( + instruction->operand(0)->shape(), instruction->operand(1)->shape(), + instruction->dot_dimension_numbers()); + case HloOpcode::kConvolution: + return ShapeInference::InferConvolveShape( + instruction->operand(0)->shape(), instruction->operand(1)->shape(), + instruction->feature_group_count(), instruction->batch_group_count(), + instruction->window(), instruction->convolution_dimension_numbers()); + default: + return absl::make_optional(); + } +} + +} // namespace + +bool IntegralUpcaster::InstructionMatchesPattern(HloInstruction* instruction) { + if (!ShapeUtil::ElementIsIntegral(instruction->shape())) { + return false; + } + auto status_or_inferred_shape = MaybeInferShape(instruction); + if (!status_or_inferred_shape.ok() || + !status_or_inferred_shape->has_value()) { + return false; + } + const Shape& inferred_shape = status_or_inferred_shape.ValueOrDie().value(); + + return inferred_shape.element_type() != instruction->shape().element_type() && + ShapeUtil::CanUpcastIntegral(inferred_shape, instruction->shape()); +} + +StatusOr IntegralUpcaster::ExpandInstruction( + HloInstruction* instruction) { + auto* computation = instruction->parent(); + auto type = instruction->shape().element_type(); + for (int i = 0; i < instruction->operand_count(); ++i) { + auto* operand = instruction->mutable_operand(i); + if (operand->shape().element_type() == type) { + continue; + } + auto upcast_shape = operand->shape(); + upcast_shape.set_element_type(type); + auto* convert_inst = computation->AddInstruction( + HloInstruction::CreateConvert(upcast_shape, operand)); + 
TF_RETURN_IF_ERROR( + instruction->ReplaceOperandWithDifferentShape(i, convert_inst)); + } + return nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/integral_upcaster.h b/tensorflow/compiler/xla/service/integral_upcaster.h new file mode 100644 index 00000000000..81915aa415f --- /dev/null +++ b/tensorflow/compiler/xla/service/integral_upcaster.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +// Inserts Convert to integral operands of instructions that allow result +// accumulation as wider integral types. 
+class IntegralUpcaster : public OpExpanderPass { + public: + absl::string_view name() const override { return "integral_upcaster"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_ diff --git a/tensorflow/compiler/xla/service/integral_upcaster_test.cc b/tensorflow/compiler/xla/service/integral_upcaster_test.cc new file mode 100644 index 00000000000..6cef3e0643a --- /dev/null +++ b/tensorflow/compiler/xla/service/integral_upcaster_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/integral_upcaster.h" + +#include "absl/strings/substitute.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class IntegralUpcasterTest + : public HloTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +bool ShouldUpcast(PrimitiveType operand_type, PrimitiveType result_type) { + return primitive_util::BitWidth(operand_type) < + primitive_util::BitWidth(result_type); +} + +TEST_P(IntegralUpcasterTest, ConvertInserted) { + PrimitiveType lhs_type, rhs_type, result_type; + std::tie(lhs_type, rhs_type, result_type) = GetParam(); + absl::string_view module_tmpl = R"( + HloModule module + + ENTRY main { + p0 = $0[2,3]{1,0} parameter(0) + p1 = $1[3,2]{1,0} parameter(1) + ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + auto module_string = absl::Substitute( + module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type), + primitive_util::LowercasePrimitiveTypeName(rhs_type), + primitive_util::LowercasePrimitiveTypeName(result_type)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool upcasted, IntegralUpcaster().Run(module.get())); + EXPECT_EQ(upcasted, ShouldUpcast(lhs_type, result_type) || + ShouldUpcast(rhs_type, result_type)); + auto original_lhs = op::Parameter(0); + auto original_rhs = op::Parameter(1); + auto upcasted_lhs = + ShouldUpcast(lhs_type, result_type) + ? AllOf(op::Convert(original_lhs), + op::Shape(absl::Substitute( + "$0[2,3]{1,0}", + primitive_util::LowercasePrimitiveTypeName(result_type)))) + : original_lhs; + auto upcasted_rhs = + ShouldUpcast(rhs_type, result_type) + ? 
AllOf(op::Convert(original_rhs), + op::Shape(absl::Substitute( + "$0[3,2]{1,0}", + primitive_util::LowercasePrimitiveTypeName(result_type)))) + : original_rhs; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf(op::Dot(upcasted_lhs, upcasted_rhs), + op::Shape(absl::Substitute( + "$0[2,2]{1,0}", + primitive_util::LowercasePrimitiveTypeName(result_type))))); +} + +INSTANTIATE_TEST_SUITE_P(S16U16, IntegralUpcasterTest, + ::testing::Values(std::make_tuple(S8, S8, S16), + std::make_tuple(U8, U8, U16))); + +INSTANTIATE_TEST_SUITE_P(S32, IntegralUpcasterTest, + ::testing::Combine(::testing::Values(S8, S16), + ::testing::Values(S8, S16), + ::testing::Values(S32))); + +INSTANTIATE_TEST_SUITE_P(U32, IntegralUpcasterTest, + ::testing::Combine(::testing::Values(U8, U16), + ::testing::Values(U8, U16), + ::testing::Values(U32))); + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 0c877bf6102..cb0edfb6be6 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1633,4 +1633,13 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { return s; } +/*static*/ bool ShapeUtil::CanUpcastIntegral(const Shape& from, + const Shape& to) { + return ElementIsIntegral(from) && ElementIsIntegral(to) && + ElementIsSigned(from) == ElementIsSigned(to) && + primitive_util::BitWidth(from.element_type()) <= + primitive_util::BitWidth(to.element_type()) && + CompatibleIgnoringElementType(from, to); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 5a5695d32ee..c1a6a2c8b1d 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -787,6 +787,10 @@ class ShapeUtil { // information, from a shape. static Shape DeviceShapeToHostShape(Shape s); + // Returns true iff integral shape `from` can be safely upcasted to integral + // shape `to`. 
+ static bool CanUpcastIntegral(const Shape& from, const Shape& to); + private: // Validates the shape size is sane. This makes sure it's safe to do // calculations in int64 without overflowing. diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index e06e2972f1c..72f27082fda 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1753,6 +1753,22 @@ XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) { ComputeAndCompare(&builder, {}, error_spec_); } +XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) { + absl::string_view hlo_string = + R"( +HloModule WiderIntegralAccumulation + +ENTRY MatrixVectorComplex { + p0 = s8[5,5]{1,0} parameter(0) + p1 = s16[5,1]{0,1} parameter(1) + ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + // This benchmark is to show the performance impact of the following // transformation: // dot(reshape(transpose(A)), Const) ==>