diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 67a8292d6cc..3da904beb36 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2794,6 +2794,47 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "dynamic_dimension_simplifier",
+    srcs = ["dynamic_dimension_simplifier.cc"],
+    hdrs = ["dynamic_dimension_simplifier.h"],
+    deps = [
+        ":hlo",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:status_macros",
+    ],
+)
+
+tf_cc_test(
+    name = "dynamic_dimension_simplifier_test",
+    srcs = ["dynamic_dimension_simplifier_test.cc"],
+    deps = [
+        ":dynamic_dimension_simplifier",
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_creation_utils",
+        ":hlo_parser",
+        ":hlo_pass",
+        ":hlo_pass_pipeline",
+        ":pattern_matcher",
+        ":pattern_matcher_gmock",
+        ":shape_inference",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "dynamic_padder",
     srcs = ["dynamic_padder.cc"],
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 2328ad99113..2389a33e52c 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
 }
 
 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
-    HloInstruction*) {
+    HloInstruction* gds) {
   // Dynamic dimension doesn't propagate through GetDimensionSize:
   //
   //   Input: F32[x, y, z]
   //
@@ -646,6 +646,24 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
   // The returned value is a scalar, which doesn't have any dynamic dimension in
   // the shape (although the value contains the real size of the dynamic
   // dimension of the input).
+  int64 dim = gds->dimension();
+  HloInstruction* operand = gds->mutable_operand(0);
+  HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
+  HloComputation* computation = gds->parent();
+  if (dynamic_size != nullptr) {
+    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
+    // The dependency between an instruction and its dynamic dimensions is not
+    // modeled in the IR. As gds is being replaced by dynamic_size, also tell
+    // dynamic dimension inference that the instruction is being replaced.
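+    // Otherwise the inference would keep handing out the dead gds instruction
+    // wherever it was recorded as a dynamic size.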
+    parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
+  } else {
+    TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
+    int32 size = gds->operand(0)->shape().dimensions(dim);
+    HloInstruction* new_instr = computation->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
+    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
+    parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
+  }
   return Status::OK();
 }
 
@@ -794,7 +812,23 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
 
 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
     HloInstruction* hlo) {
-  return PassThroughDynamicDimension(hlo);
+  HloComputation* comp = hlo->parent();
+  return ForEachOperandDynamicDimension(
+      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
+               int64 operand_index, HloInstruction* dynamic_size) {
+        HloInstruction* existing_size =
+            parent_->GetDynamicSize(hlo, index, dimension);
+        if (existing_size == nullptr || existing_size == dynamic_size) {
+          parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
+        } else {
+          // The operands disagree on the dynamic size; record the maximum of
+          // the two, since the dynamic padder later broadcasts the smaller
+          // operand to match.
+          HloInstruction* max =
+              comp->AddInstruction(HloInstruction::CreateBinary(
+                  ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
+                  dynamic_size, existing_size));
+          parent_->SetDynamicSize(hlo, index, dimension, max);
+        }
+        return Status::OK();
+      });
 }
 
 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
new file mode 100644
index 00000000000..d7253a3fbad
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
@@ -0,0 +1,214 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace xla {
+namespace {
+
+// Concat(Concat(A, B), C) => Concat(A, B, C)
+StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
+  if (concat->opcode() != HloOpcode::kConcatenate) {
+    return false;
+  }
+  bool changed = false;
+
+  auto parent = concat->parent();
+  std::vector<HloInstruction*> new_operands;
+  for (HloInstruction* operand : concat->operands()) {
+    if (operand->opcode() != HloOpcode::kConcatenate ||
+        operand->concatenate_dimension() != concat->concatenate_dimension()) {
+      new_operands.push_back(operand);
+    } else {
+      changed = true;
+      for (HloInstruction* operand_operand : operand->operands()) {
+        new_operands.push_back(operand_operand);
+      }
+    }
+  }
+  if (changed) {
+    auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
+        concat->shape(), new_operands, concat->concatenate_dimension()));
+    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
+  }
+  return changed;
+}
+
+// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
+StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
+  if (slice->opcode() != HloOpcode::kSlice) {
+    return false;
+  }
+  auto concat = slice->mutable_operand(0);
+  if (concat->opcode() != HloOpcode::kConcatenate) {
+    return false;
+  }
+
+  if (slice->shape().rank() != 1) {
+    // Slice-concat forwarding only works for rank-1 tensors.
+    return false;
+  }
+
+  int64 concat_dim = concat->concatenate_dimension();
+
+  std::vector<HloInstruction*> new_operands;
+  int64 size_so_far = 0;
+  int64 slice_size = slice->shape().dimensions(concat_dim);
+  if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
+    return false;
+  }
+  if (slice->slice_strides(0) != 1) {
+    return false;
+  }
+  for (HloInstruction* operand : concat->operands()) {
+    if (size_so_far == slice->slice_starts(0) &&
+        operand->shape().dimensions(0) == slice_size) {
+      // Found an operand that can be forwarded.
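+      // The slice covers this operand of the concat exactly, so all users of
+      // the slice can read the operand directly.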
+      TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
+      return true;
+    }
+    size_so_far += operand->shape().dimensions(concat_dim);
+  }
+
+  return false;
+}
+
+// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
+StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
+  if (reshape->opcode() != HloOpcode::kReshape) {
+    return false;
+  }
+  auto broadcast = reshape->mutable_operand(0);
+  if (broadcast->opcode() != HloOpcode::kBroadcast) {
+    return false;
+  }
+
+  if (reshape->shape().rank() != 0) {
+    return false;
+  }
+
+  if (broadcast->shape().rank() != 1) {
+    return false;
+  }
+
+  if (broadcast->mutable_operand(0)->shape().rank() != 0) {
+    return false;
+  }
+
+  TF_RETURN_IF_ERROR(
+      reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
+
+  return true;
+}
+
+// Reshape(Reshape(A, a->b), b->a) ==> A
+StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
+  if (reshape->opcode() != HloOpcode::kReshape) {
+    return false;
+  }
+  auto reshape_2 = reshape->mutable_operand(0);
+  if (reshape_2->opcode() != HloOpcode::kReshape) {
+    return false;
+  }
+
+  if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
+    return false;
+  }
+  TF_RETURN_IF_ERROR(
+      reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
+
+  return true;
+}
+
+// Convert(A, T->T) ==> A
+StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
+  if (convert->opcode() != HloOpcode::kConvert) {
+    return false;
+  }
+  auto operand = convert->mutable_operand(0);
+  if (Shape::Equal()(convert->shape(), operand->shape())) {
+    TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
+    return true;
+  }
+  return false;
+}
+
+// Reshape(A, S->S) ==> A
+StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
+  if (reshape->opcode() != HloOpcode::kReshape) {
+    return false;
+  }
+  auto operand = reshape->mutable_operand(0);
+  if (Shape::Equal()(reshape->shape(), operand->shape())) {
+    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
+    return true;
+  }
+  return false;
+}
+
+}  // namespace
+
+StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
+  XLA_VLOG_LINES(
+      2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
+  bool changed = false;
+
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    for (auto* inst : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
+      changed |= local_changed;
+    }
+  }
+
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    for (auto* inst : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
+      changed |= local_changed;
+    }
+  }
+
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    for (auto* inst : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
+      changed |= local_changed;
+    }
+  }
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    for (auto* inst : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
+      changed |= local_changed;
+    }
+  }
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    for (auto* inst : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
+      changed |= local_changed;
+    }
+  }
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    for (auto* inst : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
+      changed |= local_changed;
+    }
+  }
+  XLA_VLOG_LINES(
+      2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
+  return changed;
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
new file mode 100644
index 00000000000..e9b99212172
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
@@ -0,0 +1,37 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This pass simplifies operations on dynamic dimension sizes so that they can
+// be easily analyzed by later passes.
+class DynamicDimensionSimplifier : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "dynamic dimension simplifier";
+  }
+
+  StatusOr<bool> Run(HloModule* module) override;
+};
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc
new file mode 100644
index 00000000000..1389d06953c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc
@@ -0,0 +1,201 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
+
+#include <memory>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+namespace m = match;
+
+class DynamicDimensionSimplifierTest : public HloTestBase {};
+
+TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1] parameter(0)
+      p1 = s32[1] parameter(1)
+      p2 = s32[1] parameter(2)
+      concat1 = s32[2] concatenate(p0, p1), dimensions={0}
+      ROOT concat2 = s32[3] concatenate(concat1, p2), dimensions={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1),
+                                        m::Parameter(2))));
+}
+
+TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatMultipleDims) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1, 1] parameter(0)
+      p1 = s32[1, 1] parameter(1)
+      p2 = s32[2, 1] parameter(2)
+      concat1 = s32[2, 1] concatenate(p0, p1), dimensions={0}
+      ROOT concat2 = s32[2, 2] concatenate(concat1, p2), dimensions={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
+}
+
+TEST_F(DynamicDimensionSimplifierTest, ForwardConcatSlice) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1] parameter(0)
+      p1 = s32[1] parameter(1)
+      p2 = s32[1] parameter(2)
+      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
+      ROOT slice = s32[1] slice(concat), slice={[1:2]}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Parameter(1)));
+}
+
+TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceSizeMismatch) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1] parameter(0)
+      p1 = s32[1] parameter(1)
+      p2 = s32[1] parameter(2)
+      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
+      ROOT slice = s32[2] slice(concat), slice={[1:3]}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
+}
+
+TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceStrided) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1] parameter(0)
+      p1 = s32[1] parameter(1)
+      p2 = s32[1] parameter(2)
+      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
+      ROOT slice = s32[1] slice(concat), slice={[1:2:2]}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
+}
+
+TEST_F(DynamicDimensionSimplifierTest, BroadcastReshapeForwarding) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[] parameter(0)
+      broadcast = s32[1] broadcast(p0), dimensions={}
+      ROOT reshape = s32[] reshape(broadcast)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Parameter(0)));
+}
+
+TEST_F(DynamicDimensionSimplifierTest, ReshapeReshapeForwarding) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[] parameter(0)
+      reshape = s32[1] reshape(p0)
+      ROOT reshape2 = s32[] reshape(reshape)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Parameter(0)));
+}
+
+TEST_F(DynamicDimensionSimplifierTest,
+       DoNotReshapeReshapeForwardingShapeMismatch) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1, 1] parameter(0)
+      reshape = s32[1] reshape(p0)
+      ROOT reshape2 = s32[] reshape(reshape)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
+}
+
+TEST_F(DynamicDimensionSimplifierTest, IdConvertRemoving) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = s32[1] parameter(0)
+      ROOT reshape2 = s32[1] convert(p0)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  DynamicDimensionSimplifier simplifier;
+  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Parameter(0)));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index ab94695c1e2..7785908e15a 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -1282,6 +1282,97 @@ StatusOr<bool> RewriteDynamicSort(
   return true;
 }
 
+StatusOr<bool> RewriteDynamicBinaryOp(
+    HloInstruction* binary,
+    DynamicDimensionInference* dynamic_dimension_inference) {
+  HloInstruction* operand_0 = binary->mutable_operand(0);
+  HloInstruction* operand_1 = binary->mutable_operand(1);
+
+  HloComputation* comp = binary->parent();
+  TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
+  auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
+  auto dims_1 =
+      dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
+  bool changed = false;
+  for (int64 i = 0; i < dims_0.size(); ++i) {
+    HloInstruction* dim_0 = dims_0[i];
+    HloInstruction* dim_1 = dims_1[i];
+
+    if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
+        dims_1[i] != nullptr) {
+      changed = true;
+      // It is possible that a dynamic dimension of one operand is size 1 while
+      // the other is greater than one. According to implicit broadcast
+      // semantics, we need to insert a broadcast in this case to make the
+      // dynamic shapes match.
+
+      // An implicit broadcast is inserted by slicing the small shape into a
+      // size 1 slice, reshaping out the size 1 dimension, then broadcasting to
+      // the full shape:
+      //
+      //  Input [2, <=5, 3]
+      //   |
+      //  Slice [2, 1, 3]
+      //   |
+      //  Reshape [2, 3]
+      //   |
+      //  Broadcast [2, 5, 3]
+      auto rewrite_operand = [&](HloInstruction* pred,
+                                 HloInstruction* operand) -> HloInstruction* {
+        Shape static_shape = operand->shape();
+        static_shape.clear_dynamic_dimensions();
+        pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
+            ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
+        Shape slice_shape = static_shape;
+        slice_shape.set_dimensions(i, 1);
+        std::vector<int64> start_indices(slice_shape.rank(), 0);
+        std::vector<int64> strides(slice_shape.rank(), 1);
+        HloInstruction* slice = comp->AddInstruction(
+            HloInstruction::CreateSlice(slice_shape, operand, start_indices,
+                                        slice_shape.dimensions(), strides));
+        Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
+        HloInstruction* reshape = comp->AddInstruction(
+            HloInstruction::CreateReshape(reshape_shape, slice));
+        std::vector<int64> broadcast_dims;
+        broadcast_dims.reserve(static_shape.rank() - 1);
+        // Broadcast to all dims except for i.
+        for (int64 j = 0; j < static_shape.rank(); ++j) {
+          if (j != i) {
+            broadcast_dims.push_back(j);
+          }
+        }
+
+        HloInstruction* broadcast =
+            comp->AddInstruction(HloInstruction::CreateBroadcast(
+                                     static_shape, reshape, broadcast_dims),
+                                 "implicit_broadcast");
+
+        // Use a select instead of a conditional, as elementwise operations
+        // enable more fusion.
+        HloInstruction* select =
+            comp->AddInstruction(HloInstruction::CreateTernary(
+                static_shape, HloOpcode::kSelect, pred, broadcast, operand));
+        return select;
+      };
+      auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
+          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
+                                        dim_1, ComparisonDirection::kLt),
+          "lhs_needs_implicit_broadcast");
+      operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
+
+      auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
+          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
+                                        dim_0, ComparisonDirection::kLt),
+          "rhs_needs_implicit_broadcast");
+      operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
+    }
+  }
+  if (changed) {
+    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
+    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
+  }
+  return changed;
+}
+
 StatusOr<bool> RewriteDynamicReshape(
     HloInstruction* reshape,
     DynamicDimensionInference* dynamic_dimension_inference) {
@@ -1732,6 +1823,14 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
       continue;
     }
 
+    // Elementwise binary ops with dynamic shapes have implicit broadcast
+    // semantics.
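+    // RewriteDynamicBinaryOp compares the operands' dynamic sizes at runtime
+    // and, where they differ, replaces the smaller operand with an explicitly
+    // broadcast copy (slice -> reshape -> broadcast -> select).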
+    if (inst->IsElementwiseBinary()) {
+      TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
+                                       inst, &dynamic_dimension_inference));
+      continue;
+    }
+
     if (inst->opcode() == HloOpcode::kDynamicReshape) {
       TF_ASSIGN_OR_RETURN(
           changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));