Implement implicit broadcast for binary operations on dynamic shapes.

- This CL instructs the dynamic padder to insert implicit broadcasts into the graph when a binary operation is performed on two dynamic tensors (a sketch follows the list below).
- Optimization #1: The implicit broadcast is only inserted when we can't prove that the two dynamic dimensions are the same.
- Optimization #2: Added a simplification pass that simplifies operations on dynamic dimension sizes; this opens up more opportunities for optimization #1.
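
A minimal HLO sketch of the idea (shapes and names are illustrative only, not taken from this change). Both operands carry a bounded dynamic dimension whose runtime sizes d0 and d1 may differ:

  a = f32[<=10] parameter(0)    // runtime size d0
  b = f32[<=10] parameter(1)    // runtime size d1
  ROOT add = f32[<=10] add(a, b)

When d0 and d1 cannot be proven equal, the padder conceptually wraps each operand in select(d_self < d_other, broadcast(reshape(slice(operand))), operand), so that a size-1 dynamic dimension is implicitly broadcast to the other operand's size before the add.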

PiperOrigin-RevId: 356407626
Change-Id: I980477ee6f3ccb42342226afaab03b4b09549360
Author: Yunxing Dai, 2021-02-08 19:21:10 -08:00 (committed by TensorFlower Gardener)
parent 969371e28f
commit d3e7ce26d9
6 changed files with 628 additions and 2 deletions

tensorflow/compiler/xla/service/BUILD

@@ -2794,6 +2794,47 @@ cc_library(
],
)
cc_library(
name = "dynamic_dimension_simplifier",
srcs = ["dynamic_dimension_simplifier.cc"],
hdrs = ["dynamic_dimension_simplifier.h"],
deps = [
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status_macros",
],
)
tf_cc_test(
name = "dynamic_dimension_simplifier_test",
srcs = ["dynamic_dimension_simplifier_test.cc"],
deps = [
":dynamic_dimension_simplifier",
":hlo",
":hlo_casting_utils",
":hlo_creation_utils",
":hlo_parser",
":hlo_pass",
":hlo_pass_pipeline",
":pattern_matcher",
":pattern_matcher_gmock",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "dynamic_padder",
srcs = ["dynamic_padder.cc"],

tensorflow/compiler/xla/service/dynamic_dimension_inference.cc

@@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
}
Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
HloInstruction*) {
HloInstruction* gds) {
// Dynamic dimension doesn't propagate through GetDimensionSize:
//
// Input: F32[x, y, z]
@@ -646,6 +646,24 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
// The returned value is a scalar, which doesn't have any dynamic dimension in
// the shape (although the value contains the real size of the dynamic
// dimension of the input).
int64 dim = gds->dimension();
HloInstruction* operand = gds->mutable_operand(0);
HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
HloComputation* computation = gds->parent();
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
// The dependency between an instruction and its dynamic dimensions is not
// modeled in the IR. As gds is being replaced by dynamic_size, also tell
// dynamic dimension inference that the instruction is being replaced.
parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
} else {
TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
int32 size = gds->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
}
return Status::OK();
}
@@ -794,7 +812,23 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
HloInstruction* hlo) {
return PassThroughDynamicDimension(hlo);
HloComputation* comp = hlo->parent();
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size) {
HloInstruction* existing_size =
parent_->GetDynamicSize(hlo, index, dimension);
if (existing_size == nullptr || existing_size == dynamic_size) {
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
} else {
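// The two operands disagree on this dynamic dimension. Under implicit
// broadcast semantics the output dimension takes the larger of the two
// runtime sizes, so model it as max(dynamic_size, existing_size).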
HloInstruction* max =
comp->AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
dynamic_size, existing_size));
parent_->SetDynamicSize(hlo, index, dimension, max);
}
return Status::OK();
});
}
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {

tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc

@@ -0,0 +1,214 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
namespace {
// Concat(Concat(A, B), C) => Concat(A, B, C)
StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
if (concat->opcode() != HloOpcode::kConcatenate) {
return false;
}
bool changed = false;
auto parent = concat->parent();
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : concat->operands()) {
if (operand->opcode() != HloOpcode::kConcatenate ||
operand->concatenate_dimension() != concat->concatenate_dimension()) {
new_operands.push_back(operand);
} else {
changed = true;
for (HloInstruction* operand_operand : operand->operands()) {
new_operands.push_back(operand_operand);
}
}
}
if (changed) {
auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
concat->shape(), new_operands, concat->concatenate_dimension()));
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
}
return changed;
}
// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
if (slice->opcode() != HloOpcode::kSlice) {
return false;
}
auto concat = slice->mutable_operand(0);
if (concat->opcode() != HloOpcode::kConcatenate) {
return false;
}
if (slice->shape().rank() != 1) {
// Slice-concat forwarding only works for rank-1 tensors.
return false;
}
int64 concat_dim = concat->concatenate_dimension();
std::vector<HloInstruction*> new_operands;
int64 size_so_far = 0;
int64 slice_size = slice->shape().dimensions(concat_dim);
if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
return false;
}
if (slice->slice_strides(0) != 1) {
return false;
}
for (HloInstruction* operand : concat->operands()) {
if (size_so_far == slice->slice_starts(0) &&
operand->shape().dimensions(0) == slice_size) {
// Found an operand that can be forwarded.
TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
return true;
}
size_so_far += operand->shape().dimensions(concat_dim);
}
return false;
}
// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto broadcast = reshape->mutable_operand(0);
if (broadcast->opcode() != HloOpcode::kBroadcast) {
return false;
}
if (reshape->shape().rank() != 0) {
return false;
}
if (broadcast->shape().rank() != 1) {
return false;
}
if (broadcast->mutable_operand(0)->shape().rank() != 0) {
return false;
}
TF_RETURN_IF_ERROR(
reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
return true;
}
// Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto reshape_2 = reshape->mutable_operand(0);
if (reshape_2->opcode() != HloOpcode::kReshape) {
return false;
}
if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
return false;
}
TF_RETURN_IF_ERROR(
reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
return true;
}
// Convert(A, T->T) ==> A
StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
if (convert->opcode() != HloOpcode::kConvert) {
return false;
}
auto operand = convert->mutable_operand(0);
if (Shape::Equal()(convert->shape(), operand->shape())) {
TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
return true;
}
return false;
}
// Reshape(A, S->S) ==> A
StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto operand = reshape->mutable_operand(0);
if (Shape::Equal()(reshape->shape(), operand->shape())) {
TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
return true;
}
return false;
}
} // namespace
StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(
2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
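// Each rewrite rule below gets its own sweep over all non-fusion
// computations, visiting instructions in post order.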
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
changed |= local_changed;
}
}
XLA_VLOG_LINES(
2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
return changed;
}
} // namespace xla

tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h

@@ -0,0 +1,37 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
#include <utility>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// This pass simplifies operations on dynamic dimension sizes so that they can
// be easily analyzed by later passes.
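// The simplifications currently applied are: forwarding nested concatenates,
// forwarding a slice of a concatenate to the matching operand, collapsing
// scalar broadcast+reshape and reshape+reshape round trips, and removing
// identity converts and reshapes.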
class DynamicDimensionSimplifier : public HloModulePass {
public:
absl::string_view name() const override {
return "dynamic dimension simplifier";
}
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_

tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc

@@ -0,0 +1,201 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include <memory>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace m = match;
class DynamicDimensionSimplifierTest : public HloTestBase {};
TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat1 = s32[2] concatenate(p0, p1), dimensions={0}
ROOT concat2 = s32[3] concatenate(concat1, p2), dimensions={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1),
m::Parameter(2))));
}
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatMultipleDims) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1, 1] parameter(0)
p1 = s32[1, 1] parameter(1)
p2 = s32[2, 1] parameter(2)
concat1 = s32[2, 1] concatenate(p0, p1), dimensions={0}
ROOT concat2 = s32[2, 2] concatenate(concat1, p2), dimensions={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, ForwardConcatSlice) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
ROOT slice = s32[1] slice(concat), slice={[1:2]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(1)));
}
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceSizeMismatch) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
ROOT slice = s32[2] slice(concat), slice={[1:3]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceStrided) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
ROOT slice = s32[1] slice(concat), slice={[1:2:2]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, BroadcastReshapeForwarding) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[] parameter(0)
broadcast = s32[1] broadcast(p0), dimensions={}
ROOT reshape = s32[] reshape(broadcast)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}
TEST_F(DynamicDimensionSimplifierTest, ReshapeReshapeForwarding) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[] parameter(0)
reshape = s32[1] reshape(p0)
ROOT reshape2 = s32[] reshape(reshape)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}
TEST_F(DynamicDimensionSimplifierTest,
DoNotReshapeReshapeForwardingShapeMismatch) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1, 1] parameter(0)
reshape = s32[1] reshape(p0)
ROOT reshape2 = s32[] reshape(reshape)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, IdConvertRemoving) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
ROOT reshape2 = s32[1] convert(p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}
} // namespace
} // namespace xla

tensorflow/compiler/xla/service/dynamic_padder.cc

@@ -1282,6 +1282,97 @@ StatusOr<bool> RewriteDynamicSort(
return true;
}
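// Inserts implicit broadcasts for an elementwise binary op whose operands
// have dynamic dimensions that may disagree at runtime.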
StatusOr<bool> RewriteDynamicBinaryOp(
HloInstruction* binary,
DynamicDimensionInference* dynamic_dimension_inference) {
HloInstruction* operand_0 = binary->mutable_operand(0);
HloInstruction* operand_1 = binary->mutable_operand(1);
HloComputation* comp = binary->parent();
TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
bool changed = false;
for (int64 i = 0; i < dims_0.size(); ++i) {
HloInstruction* dim_0 = dims_0[i];
HloInstruction* dim_1 = dims_1[i];
if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
dims_1[i] != nullptr) {
changed = true;
// It is possible that a dynamic dimension of one operand is size 1 while
// the other is greater than one. Under implicit broadcast semantics, we
// need to insert a broadcast in this case to make the dynamic shapes
// match.
// The implicit broadcast is inserted by slicing the smaller shape down to
// a size-1 slice, reshaping out the size-1 dimension, and then
// broadcasting back to the full shape:
//
// Input [2, <=5, 3]
// |
// Slice [2, 1, 3]
// |
// Reshape [2, 3]
// |
// Broadcast [2, 5, 3]
auto rewrite_operand = [&](HloInstruction* pred,
HloInstruction* operand) -> HloInstruction* {
Shape static_shape = operand->shape();
static_shape.clear_dynamic_dimensions();
pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
Shape slice_shape = static_shape;
slice_shape.set_dimensions(i, 1);
std::vector<int64> start_indices(slice_shape.rank(), 0);
std::vector<int64> strides(slice_shape.rank(), 1);
HloInstruction* slice = comp->AddInstruction(
HloInstruction::CreateSlice(slice_shape, operand, start_indices,
slice_shape.dimensions(), strides));
Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
HloInstruction* reshape = comp->AddInstruction(
HloInstruction::CreateReshape(reshape_shape, slice));
std::vector<int64> broadcast_dims;
broadcast_dims.reserve(static_shape.rank() - 1);
// Broadcast to all dims except i.
for (int64 j = 0; j < static_shape.rank(); ++j) {
if (j != i) {
broadcast_dims.push_back(j);
}
}
HloInstruction* broadcast =
comp->AddInstruction(HloInstruction::CreateBroadcast(
static_shape, reshape, broadcast_dims),
"implicit_broadcast");
// Use a select instead of a conditional, as elementwise operations
// enable more fusion.
HloInstruction* select =
comp->AddInstruction(HloInstruction::CreateTernary(
static_shape, HloOpcode::kSelect, pred, broadcast, operand));
return select;
};
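// An operand whose runtime size is smaller than the other's is the one
// that needs the implicit broadcast (under implicit broadcast semantics
// its dynamic size is expected to be 1).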
auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
dim_1, ComparisonDirection::kLt),
"lhs_needs_implicit_broadcast");
operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
dim_0, ComparisonDirection::kLt),
"rhs_needs_implicit_broadcast");
operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
}
}
if (changed) {
TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
}
return changed;
}
StatusOr<bool> RewriteDynamicReshape(
HloInstruction* reshape,
DynamicDimensionInference* dynamic_dimension_inference) {
@@ -1732,6 +1823,14 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
continue;
}
// Elementwise binary ops with dynamic shapes have implicit broadcast
// semantics.
if (inst->IsElementwiseBinary()) {
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
inst, &dynamic_dimension_inference));
continue;
}
if (inst->opcode() == HloOpcode::kDynamicReshape) {
TF_ASSIGN_OR_RETURN(
changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));