From 9ac993ae57b4a209f80edfb27727536360304caa Mon Sep 17 00:00:00 2001 From: Ce Zheng Date: Tue, 29 Dec 2020 11:10:09 -0800 Subject: [PATCH] [XLA] IntegralUpcaster->OperandUpcaster to allow auto upcasting floating point types. PiperOrigin-RevId: 349445605 Change-Id: I87d081c7f784155ffb885fcc3561bcc46a32c163 --- tensorflow/compiler/xla/service/BUILD | 12 +-- .../xla/service/convert_operand_folding.cc | 13 ++-- .../service/convert_operand_folding_test.cc | 76 ++++++++++++------- tensorflow/compiler/xla/service/cpu/BUILD | 2 +- .../compiler/xla/service/cpu/cpu_compiler.cc | 4 +- tensorflow/compiler/xla/service/gpu/BUILD | 2 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 4 +- ...tegral_upcaster.cc => operand_upcaster.cc} | 18 ++--- ...integral_upcaster.h => operand_upcaster.h} | 8 +- ...aster_test.cc => operand_upcaster_test.cc} | 19 +++-- tensorflow/compiler/xla/shape_util.cc | 10 +-- tensorflow/compiler/xla/shape_util.h | 6 +- 12 files changed, 99 insertions(+), 75 deletions(-) rename tensorflow/compiler/xla/service/{integral_upcaster.cc => operand_upcaster.cc} (84%) rename tensorflow/compiler/xla/service/{integral_upcaster.h => operand_upcaster.h} (83%) rename tensorflow/compiler/xla/service/{integral_upcaster_test.cc => operand_upcaster_test.cc} (85%) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 99572303cdb..3b0fe1190f8 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -5222,9 +5222,9 @@ tf_cc_test( ) cc_library( - name = "integral_upcaster", - srcs = ["integral_upcaster.cc"], - hdrs = ["integral_upcaster.h"], + name = "operand_upcaster", + srcs = ["operand_upcaster.cc"], + hdrs = ["operand_upcaster.h"], deps = [ ":hlo", ":op_expander_pass", @@ -5233,11 +5233,11 @@ cc_library( ) tf_cc_test( - name = "integral_upcaster_test", - srcs = ["integral_upcaster_test.cc"], + name = "operand_upcaster_test", + srcs = ["operand_upcaster_test.cc"], deps = [ ":hlo_matchers", - 
":integral_upcaster", + ":operand_upcaster", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/service/convert_operand_folding.cc b/tensorflow/compiler/xla/service/convert_operand_folding.cc index 312f155788e..1e102a5e644 100644 --- a/tensorflow/compiler/xla/service/convert_operand_folding.cc +++ b/tensorflow/compiler/xla/service/convert_operand_folding.cc @@ -20,18 +20,19 @@ namespace { bool IsUpcastConvert(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kConvert && - ShapeUtil::CanUpcastIntegral(hlo->operand(0)->shape(), hlo->shape()) && - ShapeUtil::EqualIgnoringElementType(hlo->operand(0)->shape(), - hlo->shape()); + ShapeUtil::ElementIsFloating(hlo->shape()) == + ShapeUtil::ElementIsFloating(hlo->operand(0)->shape()) && + ShapeUtil::ElementIsSigned(hlo->shape()) == + ShapeUtil::ElementIsSigned(hlo->operand(0)->shape()) && + ShapeUtil::HigherPrecisionElementType(hlo->operand(0)->shape(), + hlo->shape()) == + hlo->shape().element_type(); } } // namespace bool ConvertOperandFolding::InstructionMatchesPattern( HloInstruction* instruction) { - if (!ShapeUtil::ElementIsIntegral(instruction->shape())) { - return false; - } if (instruction->opcode() != HloOpcode::kDot && instruction->opcode() != HloOpcode::kConvolution) { return false; diff --git a/tensorflow/compiler/xla/service/convert_operand_folding_test.cc b/tensorflow/compiler/xla/service/convert_operand_folding_test.cc index ee96e17ae2d..658cdf79d5f 100644 --- a/tensorflow/compiler/xla/service/convert_operand_folding_test.cc +++ b/tensorflow/compiler/xla/service/convert_operand_folding_test.cc @@ -26,7 +26,7 @@ namespace op = ::xla::testing::opcode_matchers; using ConvertOperandFoldingTest = HloTestBase; -TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) { +TEST_F(ConvertOperandFoldingTest, IntegralUpcastConvertFolded) { absl::string_view module_string = R"( 
HloModule module @@ -48,6 +48,54 @@ TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) { op::Shape("s16[2,2]{1,0}"))); } +TEST_F(ConvertOperandFoldingTest, FloatingUpcastConvertFolded) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f16[2,3]{1,0} parameter(0) + p1 = bf16[3,2]{0,1} parameter(1) + c0 = f32[2,3]{1,0} convert(p0) + c1 = f32[3,2]{0,1} convert(p1) + ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool folded, + ConvertOperandFolding().Run(module.get())); + EXPECT_TRUE(folded); + EXPECT_THAT(module->entry_computation()->root_instruction(), + AllOf(op::Dot(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[2,2]{1,0}"))); +} + +TEST_F(ConvertOperandFoldingTest, IntegralToFloatingConvertNotFolded) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = s8[2,3]{1,0} parameter(0) + p1 = s16[3,2]{0,1} parameter(1) + c0 = f16[2,3]{1,0} convert(p0) + c1 = f32[3,2]{0,1} convert(p1) + ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool folded, + ConvertOperandFolding().Run(module.get())); + EXPECT_FALSE(folded); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf( + op::Dot( + AllOf(op::Convert(op::Parameter(0)), op::Shape("f16[2,3]{1,0}")), + AllOf(op::Convert(op::Parameter(1)), op::Shape("f32[3,2]{0,1}"))), + op::Shape("f32[2,2]{1,0}"))); +} + TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) { absl::string_view module_string = R"( HloModule module @@ -74,32 +122,6 @@ TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) { op::Shape("s16[2,2]{1,0}"))); } -TEST_F(ConvertOperandFoldingTest,
LayoutChangingConvertNotFolded) { - absl::string_view module_string = R"( - HloModule module - - ENTRY main { - p0 = s8[2,3]{1,0} parameter(0) - p1 = s16[3,2]{0,1} parameter(1) - c0 = s16[2,3]{0,1} convert(p0) - c1 = s16[3,2]{1,0} convert(p1) - ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1}, - rhs_contracting_dims={0} - })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseAndReturnVerifiedModule(module_string)); - TF_ASSERT_OK_AND_ASSIGN(bool folded, - ConvertOperandFolding().Run(module.get())); - EXPECT_FALSE(folded); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - AllOf( - op::Dot( - AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{0,1}")), - AllOf(op::Convert(op::Parameter(1)), op::Shape("s16[3,2]{1,0}"))), - op::Shape("s16[2,2]{1,0}"))); -} - TEST_F(ConvertOperandFoldingTest, OneOperandFolded) { absl::string_view module_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 5112db30e08..5d7ff7481f1 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -169,7 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:slice_sinker", "//tensorflow/compiler/xla:cpu_function_runtime", - "//tensorflow/compiler/xla/service:integral_upcaster", + "//tensorflow/compiler/xla/service:operand_upcaster", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5bd2d13688b..6b801585fef 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -102,10 +102,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/service/integral_upcaster.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/operand_upcaster.h" #include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" @@ -271,7 +271,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass<IntegralUpcaster>(); + pipeline.AddPass<OperandUpcaster>(); // Expand random number generation. pipeline.AddPass<RngExpander>(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a456b3f026d..09957450293 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1442,10 +1442,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_util", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/service:integral_upcaster", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:logistic_expander", "//tensorflow/compiler/xla/service:loop_schedule_linearizer", + "//tensorflow/compiler/xla/service:operand_upcaster", "//tensorflow/compiler/xla/service:qr_expander", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 8084e0eb71d..0b1095e7683
100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -94,10 +94,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/service/integral_upcaster.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h" +#include "tensorflow/compiler/xla/service/operand_upcaster.h" #include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" @@ -149,7 +149,7 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass<IntegralUpcaster>(); + pipeline.AddPass<OperandUpcaster>(); // Expand random number generation. pipeline.AddPass<RngExpander>(); diff --git a/tensorflow/compiler/xla/service/integral_upcaster.cc b/tensorflow/compiler/xla/service/operand_upcaster.cc similarity index 84% rename from tensorflow/compiler/xla/service/integral_upcaster.cc rename to tensorflow/compiler/xla/service/operand_upcaster.cc index 9bb8e468ad4..eff0b557d4b 100644 --- a/tensorflow/compiler/xla/service/integral_upcaster.cc +++ b/tensorflow/compiler/xla/service/operand_upcaster.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ -#include "tensorflow/compiler/xla/service/integral_upcaster.h" +#include "tensorflow/compiler/xla/service/operand_upcaster.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -35,28 +35,26 @@ StatusOr<absl::optional<Shape>> MaybeInferShape( instruction->window(), instruction->convolution_dimension_numbers(), /*preferred_element_type=*/absl::nullopt); default: - return absl::make_optional<Shape>(); + return absl::optional<Shape>(absl::nullopt); } } } // namespace -bool IntegralUpcaster::InstructionMatchesPattern(HloInstruction* instruction) { - if (!ShapeUtil::ElementIsIntegral(instruction->shape())) { - return false; - } +bool OperandUpcaster::InstructionMatchesPattern(HloInstruction* instruction) { auto status_or_inferred_shape = MaybeInferShape(instruction); if (!status_or_inferred_shape.ok() || !status_or_inferred_shape->has_value()) { return false; } const Shape& inferred_shape = status_or_inferred_shape.ValueOrDie().value(); - - return inferred_shape.element_type() != instruction->shape().element_type() && - ShapeUtil::CanUpcastIntegral(inferred_shape, instruction->shape()); + if (inferred_shape.element_type() == instruction->shape().element_type()) { + return false; + } + return ShapeUtil::ElementCanUpcast(inferred_shape, instruction->shape()); } -StatusOr<HloInstruction*> IntegralUpcaster::ExpandInstruction( +StatusOr<HloInstruction*> OperandUpcaster::ExpandInstruction( HloInstruction* instruction) { auto* computation = instruction->parent(); auto type = instruction->shape().element_type(); diff --git a/tensorflow/compiler/xla/service/integral_upcaster.h b/tensorflow/compiler/xla/service/operand_upcaster.h similarity index 83% rename from tensorflow/compiler/xla/service/integral_upcaster.h rename to tensorflow/compiler/xla/service/operand_upcaster.h index 81915aa415f..15d4ea56d62 100644 --- a/tensorflow/compiler/xla/service/integral_upcaster.h +++ b/tensorflow/compiler/xla/service/operand_upcaster.h @@ -21,11 +21,11 @@ limitations
under the License. namespace xla { -// Inserts Convert to integral operands of instructions that allows result -// accumulation as wider integral types. -class IntegralUpcaster : public OpExpanderPass { +// Inserts Convert to operands of instructions that allows result accumulation +// as wider integral types. +class OperandUpcaster : public OpExpanderPass { public: - absl::string_view name() const override { return "integral_upcaster"; } + absl::string_view name() const override { return "operand_upcaster"; } protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; diff --git a/tensorflow/compiler/xla/service/integral_upcaster_test.cc b/tensorflow/compiler/xla/service/operand_upcaster_test.cc similarity index 85% rename from tensorflow/compiler/xla/service/integral_upcaster_test.cc rename to tensorflow/compiler/xla/service/operand_upcaster_test.cc index 6cef3e0643a..af4a4a57b2e 100644 --- a/tensorflow/compiler/xla/service/integral_upcaster_test.cc +++ b/tensorflow/compiler/xla/service/operand_upcaster_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/integral_upcaster.h" +#include "tensorflow/compiler/xla/service/operand_upcaster.h" #include "absl/strings/substitute.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -25,7 +25,7 @@ namespace { namespace op = ::xla::testing::opcode_matchers; -class IntegralUpcasterTest +class OperandUpcasterTest : public HloTestBase, public ::testing::WithParamInterface< std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>> {}; @@ -35,7 +35,7 @@ bool ShouldUpcast(PrimitiveType operand_type, PrimitiveType result_type) { primitive_util::BitWidth(result_type); } -TEST_P(IntegralUpcasterTest, ConvertInserted) { +TEST_P(OperandUpcasterTest, ConvertInserted) { PrimitiveType lhs_type, rhs_type, result_type; std::tie(lhs_type, rhs_type, result_type) = GetParam(); absl::string_view module_tmpl = R"( @@ -53,7 +53,7 @@ TEST_P(IntegralUpcasterTest, ConvertInserted) { primitive_util::LowercasePrimitiveTypeName(result_type)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnVerifiedModule(module_string)); - TF_ASSERT_OK_AND_ASSIGN(bool upcasted, IntegralUpcaster().Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool upcasted, OperandUpcaster().Run(module.get())); EXPECT_EQ(upcasted, ShouldUpcast(lhs_type, result_type) || ShouldUpcast(rhs_type, result_type)); auto original_lhs = op::Parameter(0); @@ -80,20 +80,25 @@ TEST_P(IntegralUpcasterTest, ConvertInserted) { primitive_util::LowercasePrimitiveTypeName(result_type))))); } -INSTANTIATE_TEST_SUITE_P(S16U16, IntegralUpcasterTest, +INSTANTIATE_TEST_SUITE_P(S16U16, OperandUpcasterTest, ::testing::Values(std::make_tuple(S8, S8, S16), std::make_tuple(U8, U8, U16))); -INSTANTIATE_TEST_SUITE_P(S32, IntegralUpcasterTest, +INSTANTIATE_TEST_SUITE_P(S32, OperandUpcasterTest, ::testing::Combine(::testing::Values(S8, S16), ::testing::Values(S8, S16), ::testing::Values(S32))); -INSTANTIATE_TEST_SUITE_P(U32, IntegralUpcasterTest, +INSTANTIATE_TEST_SUITE_P(U32,
OperandUpcasterTest, ::testing::Combine(::testing::Values(U8, U16), ::testing::Values(U8, U16), ::testing::Values(U32))); +INSTANTIATE_TEST_SUITE_P(F32, OperandUpcasterTest, + ::testing::Combine(::testing::Values(BF16, F16), + ::testing::Values(BF16, F16), + ::testing::Values(F32))); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index e84a2591707..1da01ca5f8e 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1711,13 +1711,11 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { return s; } -/*static*/ bool ShapeUtil::CanUpcastIntegral(const Shape& from, - const Shape& to) { - return ElementIsIntegral(from) && ElementIsIntegral(to) && +/*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from, + const Shape& to) { + return ElementIsFloating(from) == ElementIsFloating(to) && ElementIsSigned(from) == ElementIsSigned(to) && - primitive_util::BitWidth(from.element_type()) <= - primitive_util::BitWidth(to.element_type()) && - CompatibleIgnoringElementType(from, to); + HigherPrecisionElementType(from, to) == to.element_type(); } } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index ff47ab6ea80..584d948e92c 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -787,9 +787,9 @@ class ShapeUtil { // information, from a shape. static Shape DeviceShapeToHostShape(Shape s); - // Returns true iff integral shape `from` can be safely upcasted to integral - // shape `to`. - static bool CanUpcastIntegral(const Shape& from, const Shape& to); + // Returns true iff element type of shape `from` can be safely upcasted to + // element type of shape `to`. + static bool ElementCanUpcast(const Shape& from, const Shape& to); private: // Fills *shape. Returns true on success.