diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3da904beb36..67a8292d6cc 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2794,47 +2794,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "dynamic_dimension_simplifier",
-    srcs = ["dynamic_dimension_simplifier.cc"],
-    hdrs = ["dynamic_dimension_simplifier.h"],
-    deps = [
-        ":hlo",
-        ":hlo_pass",
-        "//tensorflow/compiler/xla:status_macros",
-    ],
-)
-
-tf_cc_test(
-    name = "dynamic_dimension_simplifier_test",
-    srcs = ["dynamic_dimension_simplifier_test.cc"],
-    deps = [
-        ":dynamic_dimension_simplifier",
-        ":hlo",
-        ":hlo_casting_utils",
-        ":hlo_creation_utils",
-        ":hlo_parser",
-        ":hlo_pass",
-        ":hlo_pass_pipeline",
-        ":pattern_matcher",
-        ":pattern_matcher_gmock",
-        ":shape_inference",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:test",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:window_util",
-        "//tensorflow/compiler/xla:xla_data_proto_cc",
-        "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "//tensorflow/core:lib",
-        "//tensorflow/core:test",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
 cc_library(
     name = "dynamic_padder",
     srcs = ["dynamic_padder.cc"],
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 2389a33e52c..2328ad99113 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
 }
 
 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
-    HloInstruction* gds) {
+    HloInstruction*) {
   // Dynamic dimension doesn't propagate through GetDimensionSize:
   //
   // Input: F32[x, y, z]
   //
@@ -646,24 +646,6 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
   // The returned value is a scalar, which doesn't have any dynamic dimension in
   // the shape (although the value contains the real size of the dynamic
   // dimension of the input).
-  int64 dim = gds->dimension();
-  HloInstruction* operand = gds->mutable_operand(0);
-  HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
-  HloComputation* computation = gds->parent();
-  if (dynamic_size != nullptr) {
-    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
-    // The dependency between an instruction and its dynamic dimensions is not
-    // modeled in the IR. As instr is being replaced by dynamic_size, also tell
-    // dynamic dimension inference that the instruction is being replaced.
-    parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
-  } else {
-    TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
-    int32 size = gds->operand(0)->shape().dimensions(dim);
-    HloInstruction* new_instr = computation->AddInstruction(
-        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
-    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
-    parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
-  }
   return Status::OK();
 }
 
@@ -812,23 +794,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
 
 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
     HloInstruction* hlo) {
-  HloComputation* comp = hlo->parent();
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
-               int64 operand_index, HloInstruction* dynamic_size) {
-        HloInstruction* existing_size =
-            parent_->GetDynamicSize(hlo, index, dimension);
-        if (existing_size == nullptr || existing_size == dynamic_size) {
-          parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
-        } else {
-          HloInstruction* max =
-              comp->AddInstruction(HloInstruction::CreateBinary(
-                  ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
-                  dynamic_size, existing_size));
-          parent_->SetDynamicSize(hlo, index, dimension, max);
-        }
-        return Status::OK();
-      });
+  return PassThroughDynamicDimension(hlo);
 }
 
 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
deleted file mode 100644
index d7253a3fbad..00000000000
--- a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
+++ /dev/null
@@ -1,214 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
-
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-
-namespace xla {
-namespace {
-
-// Concat(Concat(A, B), C) => Concat(A, B, C)
-StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
-  if (concat->opcode() != HloOpcode::kConcatenate) {
-    return false;
-  }
-  bool changed = false;
-
-  auto parent = concat->parent();
-  std::vector<HloInstruction*> new_operands;
-  for (HloInstruction* operand : concat->operands()) {
-    if (operand->opcode() != HloOpcode::kConcatenate ||
-        operand->concatenate_dimension() != concat->concatenate_dimension()) {
-      new_operands.push_back(operand);
-    } else {
-      changed = true;
-      for (HloInstruction* operand_operand : operand->operands()) {
-        new_operands.push_back(operand_operand);
-      }
-    }
-  }
-  if (changed) {
-    auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
-        concat->shape(), new_operands, concat->concatenate_dimension()));
-    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
-  }
-  return changed;
-}
-
-// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
-StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
-  if (slice->opcode() != HloOpcode::kSlice) {
-    return false;
-  }
-  auto concat = slice->mutable_operand(0);
-  if (concat->opcode() != HloOpcode::kConcatenate) {
-    return false;
-  }
-
-  if (slice->shape().rank() != 1) {
-    // Slice-concat forwarding only works for rank 1 tensors.
-    return false;
-  }
-
-  int64 concat_dim = concat->concatenate_dimension();
-
-  std::vector<HloInstruction*> new_operands;
-  int64 size_so_far = 0;
-  int64 slice_size = slice->shape().dimensions(concat_dim);
-  if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
-    return false;
-  }
-  if (slice->slice_strides(0) != 1) {
-    return false;
-  }
-  for (HloInstruction* operand : concat->operands()) {
-    if (size_so_far == slice->slice_starts(0) &&
-        operand->shape().dimensions(0) == slice_size) {
-      // Found an operand that can be forwarded.
-      TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
-      return true;
-    }
-    size_so_far += operand->shape().dimensions(concat_dim);
-  }
-
-  return false;
-}
-
-// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
-StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
-  if (reshape->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto broadcast = reshape->mutable_operand(0);
-  if (broadcast->opcode() != HloOpcode::kBroadcast) {
-    return false;
-  }
-
-  if (reshape->shape().rank() != 0) {
-    return false;
-  }
-
-  if (broadcast->shape().rank() != 1) {
-    return false;
-  }
-
-  if (broadcast->mutable_operand(0)->shape().rank() != 0) {
-    return false;
-  }
-
-  TF_RETURN_IF_ERROR(
-      reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
-
-  return true;
-}
-
-// Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
-StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
-  if (reshape->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto reshape_2 = reshape->mutable_operand(0);
-  if (reshape_2->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-
-  if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
-    return false;
-  }
-  TF_RETURN_IF_ERROR(
-      reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
-
-  return true;
-}
-
-// Convert(A, T->T) ==> A
-StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
-  if (convert->opcode() != HloOpcode::kConvert) {
-    return false;
-  }
-  auto operand = convert->mutable_operand(0);
-  if (Shape::Equal()(convert->shape(), operand->shape())) {
-    TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
-    return true;
-  }
-  return false;
-}
-
-// Reshape(A, S->S) ==> A
-StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
-  if (reshape->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto operand = reshape->mutable_operand(0);
-  if (Shape::Equal()(reshape->shape(), operand->shape())) {
-    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
-    return true;
-  }
-  return false;
-}
-
-}  // namespace
-
-StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
-  XLA_VLOG_LINES(
-      2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
-  bool changed = false;
-
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
-      changed |= local_changed;
-    }
-  }
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed,
-                          IdentityReshapeRemoving(inst));
-      changed |= local_changed;
-    }
-  }
-  XLA_VLOG_LINES(
-      2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
-  return changed;
-}
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
deleted file mode 100644
index e9b99212172..00000000000
--- a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
-
-#include <utility>
-
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// This pass simplifies operations on dynamic dimension sizes so that they can
-// be easily analyzed by later passes.
-class DynamicDimensionSimplifier : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "dynamic dimension simplifier";
-  }
-
-  StatusOr<bool> Run(HloModule* module) override;
-};
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc
deleted file mode 100644
index 1389d06953c..00000000000
--- a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
-
-#include <memory>
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_instructions.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
-#include "tensorflow/compiler/xla/service/shape_inference.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/window_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-
-namespace xla {
-namespace {
-
-namespace m = match;
-
-class DynamicDimensionSimplifierTest : public HloTestBase {};
-
-TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat1 = s32[2] concatenate(p0, p1), dimensions={0}
-      ROOT concat2 = s32[3] concatenate(concat1, p2), dimensions={0}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1),
-                                        m::Parameter(2))));
-}
-
-TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatMultipleDims) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1, 1] parameter(0)
-      p1 = s32[1, 1] parameter(1)
-      p2 = s32[2, 1] parameter(2)
-      concat1 = s32[2, 1] concatenate(p0, p1), dimensions={0}
-      ROOT concat2 = s32[2, 2] concatenate(concat1, p2), dimensions={1}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, ForwardConcatSlice) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
-      ROOT slice = s32[1] slice(concat), slice={[1:2]}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(1)));
-}
-
-TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceSizeMismatch) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
-      ROOT slice = s32[2] slice(concat), slice={[1:3]}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceStrided) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
-      ROOT slice = s32[1] slice(concat), slice={[1:2:2]}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, BroadcastReshapeForwarding) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[] parameter(0)
-      broadcast = s32[1] broadcast(p0), dimensions={}
-      ROOT reshape = s32[] reshape(broadcast)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(0)));
-}
-
-TEST_F(DynamicDimensionSimplifierTest, ReshapeReshapeForwarding) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[] parameter(0)
-      reshape = s32[1] reshape(p0)
-      ROOT reshape2 = s32[] reshape(reshape)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(0)));
-}
-
-TEST_F(DynamicDimensionSimplifierTest,
-       DoNotReshapeReshapeForwardingShapeMismatch) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1, 1] parameter(0)
-      reshape = s32[1] reshape(p0)
-      ROOT reshape2 = s32[] reshape(reshape)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, IdConvertRemoving) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      ROOT reshape2 = s32[1] convert(p0)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(0)));
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index 7785908e15a..ab94695c1e2 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -1282,97 +1282,6 @@ StatusOr<bool> RewriteDynamicSort(
   return true;
 }
 
-StatusOr<bool> RewriteDynamicBinaryOp(
-    HloInstruction* binary,
-    DynamicDimensionInference* dynamic_dimension_inference) {
-  HloInstruction* operand_0 = binary->mutable_operand(0);
-  HloInstruction* operand_1 = binary->mutable_operand(1);
-
-  HloComputation* comp = binary->parent();
-  TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
-  auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
-  auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
-  bool changed = false;
-  for (int64 i = 0; i < dims_0.size(); ++i) {
-    HloInstruction* dim_0 = dims_0[i];
-    HloInstruction* dim_1 = dims_1[i];
-
-    if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
-        dims_1[i] != nullptr) {
-      changed = true;
-      // It is possible that a dynamic dimension of one operand is size 1 while
-      // the other is greater than one. According to implicit broadcast
-      // semantics, we need to insert a broadcast in this case to make the
-      // dynamic shape match.
-
-      // An implicit broadcast is inserted by slicing the small shape into a
-      // size 1 slice, reshape out the size 1 dimension then broadcast to the
-      // full shape:
-      //
-      // Input     [2, <=5, 3]
-      //   |
-      // Slice     [2, 1, 3]
-      //   |
-      // Reshape   [2, 3]
-      //   |
-      // Broadcast [2, 5, 3]
-      auto rewrite_operand = [&](HloInstruction* pred,
-                                 HloInstruction* operand) -> HloInstruction* {
-        Shape static_shape = operand->shape();
-        static_shape.clear_dynamic_dimensions();
-        pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
-            ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
-        Shape slice_shape = static_shape;
-        slice_shape.set_dimensions(i, 1);
-        std::vector<int64> start_indices(slice_shape.rank(), 0);
-        std::vector<int64> strides(slice_shape.rank(), 1);
-        HloInstruction* slice = comp->AddInstruction(
-            HloInstruction::CreateSlice(slice_shape, operand, start_indices,
-                                        slice_shape.dimensions(), strides));
-        Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
-        HloInstruction* reshape = comp->AddInstruction(
-            HloInstruction::CreateReshape(reshape_shape, slice));
-        std::vector<int64> broadcast_dims;
-        broadcast_dims.reserve(static_shape.rank() - 1);
-        // Broadcast to all dims except for i.
-        for (int64 j = 0; j < static_shape.rank(); ++j) {
-          if (j != i) {
-            broadcast_dims.push_back(j);
-          }
-        }
-
-        HloInstruction* broadcast = comp->AddInstruction(
-            HloInstruction::CreateBroadcast(static_shape, reshape,
-                                            broadcast_dims),
-            "implicit_broadcast");
-
-        // Use a select instead of conditional as elementwise operations promote
-        // more fusion.
-        HloInstruction* select =
-            comp->AddInstruction(HloInstruction::CreateTernary(
-                static_shape, HloOpcode::kSelect, pred, broadcast, operand));
-        return select;
-      };
-      auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
-          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
-                                        dim_1, ComparisonDirection::kLt),
-          "lhs_needs_implicit_broadcast");
-      operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
-
-      auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
-          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
-                                        dim_0, ComparisonDirection::kLt),
-          "rhs_needs_implicit_broadcast");
-      operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
-    }
-  }
-  if (changed) {
-    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
-    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
-  }
-  return changed;
-}
-
 StatusOr<bool> RewriteDynamicReshape(
     HloInstruction* reshape,
     DynamicDimensionInference* dynamic_dimension_inference) {
@@ -1823,14 +1732,6 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
       continue;
     }
 
-    // Elementwise binary ops with dynamic shapes have implicit broadcast
-    // semantics.
-    if (inst->IsElementwiseBinary()) {
-      TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
-                                       inst, &dynamic_dimension_inference));
-      continue;
-    }
-
     if (inst->opcode() == HloOpcode::kDynamicReshape) {
       TF_ASSIGN_OR_RETURN(
           changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));