Internal change

PiperOrigin-RevId: 356417008
Change-Id: If13d7ebaf8a66cb8fbc9296cd5dc92b75e528a34
A. Unique TensorFlower 2021-02-08 20:44:40 -08:00 committed by TensorFlower Gardener
parent d61b9211bc
commit a0b5d0bacb
6 changed files with 2 additions and 628 deletions

tensorflow/compiler/xla/service/BUILD

@@ -2794,47 +2794,6 @@ cc_library(
],
)
cc_library(
name = "dynamic_dimension_simplifier",
srcs = ["dynamic_dimension_simplifier.cc"],
hdrs = ["dynamic_dimension_simplifier.h"],
deps = [
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status_macros",
],
)
tf_cc_test(
name = "dynamic_dimension_simplifier_test",
srcs = ["dynamic_dimension_simplifier_test.cc"],
deps = [
":dynamic_dimension_simplifier",
":hlo",
":hlo_casting_utils",
":hlo_creation_utils",
":hlo_parser",
":hlo_pass",
":hlo_pass_pipeline",
":pattern_matcher",
":pattern_matcher_gmock",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "dynamic_padder",
srcs = ["dynamic_padder.cc"],

tensorflow/compiler/xla/service/dynamic_dimension_inference.cc

@@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
}
Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
HloInstruction* gds) {
HloInstruction*) {
// Dynamic dimensions don't propagate through GetDimensionSize:
//
// Input: F32[x, y, z]
@@ -646,24 +646,6 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
// The returned value is a scalar, which doesn't have any dynamic dimension in
// the shape (although the value contains the real size of the dynamic
// dimension of the input).
int64 dim = gds->dimension();
HloInstruction* operand = gds->mutable_operand(0);
HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
HloComputation* computation = gds->parent();
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
// The dependency between an instruction and its dynamic dimensions is not
// modeled in the IR. As gds is being replaced by dynamic_size, also tell
// dynamic dimension inference that the instruction is being replaced.
parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
} else {
TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
int32 size = gds->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
}
return Status::OK();
}
@@ -812,23 +794,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
HloInstruction* hlo) {
HloComputation* comp = hlo->parent();
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size) {
HloInstruction* existing_size =
parent_->GetDynamicSize(hlo, index, dimension);
if (existing_size == nullptr || existing_size == dynamic_size) {
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
} else {
HloInstruction* max =
comp->AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
dynamic_size, existing_size));
parent_->SetDynamicSize(hlo, index, dimension, max);
}
return Status::OK();
});
return PassThroughDynamicDimension(hlo);
}
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {

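For reference on the HandleElementwiseBinary change above: the new body delegates to PassThroughDynamicDimension, which forwards each operand's dynamic size to the result instead of merging sizes with a kMaximum. A minimal sketch of that forwarding, assuming the helper follows the same ForEachOperandDynamicDimension/SetDynamicSize pattern used by the code it replaces:

Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
    HloInstruction* hlo) {
  // Copy each dynamic size found on an operand onto the same (index,
  // dimension) of the result, without emitting any new instructions.
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* /*operand*/, ShapeIndex index, int64 dimension,
               int64 /*operand_index*/, HloInstruction* dynamic_size) {
        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
        return Status::OK();
      });
}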
tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc

@@ -1,214 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
namespace {
// Concat(Concat(A, B), C) => Concat(A, B, C)
StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
if (concat->opcode() != HloOpcode::kConcatenate) {
return false;
}
bool changed = false;
auto parent = concat->parent();
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : concat->operands()) {
if (operand->opcode() != HloOpcode::kConcatenate ||
operand->concatenate_dimension() != concat->concatenate_dimension()) {
new_operands.push_back(operand);
} else {
changed = true;
for (HloInstruction* operand_operand : operand->operands()) {
new_operands.push_back(operand_operand);
}
}
}
if (changed) {
auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
concat->shape(), new_operands, concat->concatenate_dimension()));
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
}
return changed;
}
// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
if (slice->opcode() != HloOpcode::kSlice) {
return false;
}
auto concat = slice->mutable_operand(0);
if (concat->opcode() != HloOpcode::kConcatenate) {
return false;
}
if (slice->shape().rank() != 1) {
// Slice-concat forwarding only works for rank-1 tensors.
return false;
}
int64 concat_dim = concat->concatenate_dimension();
std::vector<HloInstruction*> new_operands;
int64 size_so_far = 0;
int64 slice_size = slice->shape().dimensions(concat_dim);
if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
return false;
}
if (slice->slice_strides(0) != 1) {
return false;
}
for (HloInstruction* operand : concat->operands()) {
if (size_so_far == slice->slice_starts(0) &&
operand->shape().dimensions(0) == slice_size) {
// Found an operand that can be forwarded.
TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
return true;
}
size_so_far += operand->shape().dimensions(concat_dim);
}
return false;
}
// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto broadcast = reshape->mutable_operand(0);
if (broadcast->opcode() != HloOpcode::kBroadcast) {
return false;
}
if (reshape->shape().rank() != 0) {
return false;
}
if (broadcast->shape().rank() != 1) {
return false;
}
if (broadcast->mutable_operand(0)->shape().rank() != 0) {
return false;
}
TF_RETURN_IF_ERROR(
reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
return true;
}
// Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto reshape_2 = reshape->mutable_operand(0);
if (reshape_2->opcode() != HloOpcode::kReshape) {
return false;
}
if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
return false;
}
TF_RETURN_IF_ERROR(
reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
return true;
}
// Convert(A, T->T) ==> A
StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
if (convert->opcode() != HloOpcode::kConvert) {
return false;
}
auto operand = convert->mutable_operand(0);
if (Shape::Equal()(convert->shape(), operand->shape())) {
TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
return true;
}
return false;
}
// Reshape(A, S->S) ==> A
StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto operand = reshape->mutable_operand(0);
if (Shape::Equal()(reshape->shape(), operand->shape())) {
TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
return true;
}
return false;
}
} // namespace
StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(
2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
changed |= local_changed;
}
}
XLA_VLOG_LINES(
2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
return changed;
}
} // namespace xla
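Note that Run above applies each local rewrite exactly once, in a fixed order, over every non-fusion computation, so one rewrite can expose a pattern that an earlier rewrite would have handled. A caller that wants a fixed point could wrap the pass; a minimal sketch, assuming the standard HloPassFix and HloPassPipeline utilities (the helper name SimplifyDynamicDimensions is hypothetical):

// Sketch only: drive the simplifier's rewrites to a fixed point.
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

namespace xla {

StatusOr<bool> SimplifyDynamicDimensions(HloModule* module) {
  HloPassPipeline pipeline("dynamic-dimension-simplification");
  // HloPassFix re-runs the wrapped pass until it reports no further change.
  pipeline.AddPass<HloPassFix<DynamicDimensionSimplifier>>();
  return pipeline.Run(module);
}

}  // namespace xla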

tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h

@@ -1,37 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
#include <utility>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// This pass simplifies operations on dynamic dimension sizes so that they can
// be easily analyzed by later passes.
class DynamicDimensionSimplifier : public HloModulePass {
public:
absl::string_view name() const override {
return "dynamic dimension simplifier";
}
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_

tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc

@@ -1,201 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include <memory>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace m = match;
class DynamicDimensionSimplifierTest : public HloTestBase {};
TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat1 = s32[2] concatenate(p0, p1), dimensions={0}
ROOT concat2 = s32[3] concatenate(concat1, p2), dimensions={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1),
m::Parameter(2))));
}
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatMultipleDims) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1, 1] parameter(0)
p1 = s32[1, 1] parameter(1)
p2 = s32[2, 1] parameter(2)
concat1 = s32[2, 1] concatenate(p0, p1), dimensions={0}
ROOT concat2 = s32[2, 2] concatenate(concat1, p2), dimensions={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, ForwardConcatSlice) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
ROOT slice = s32[1] slice(concat), slice={[1:2]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(1)));
}
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceSizeMismatch) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
ROOT slice = s32[2] slice(concat), slice={[1:3]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceStrided) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
p1 = s32[1] parameter(1)
p2 = s32[1] parameter(2)
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
ROOT slice = s32[1] slice(concat), slice={[1:2:2]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, BroadcastReshapeForwarding) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[] parameter(0)
broadcast = s32[1] broadcast(p0), dimensions={}
ROOT reshape = s32[] reshape(broadcast)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}
TEST_F(DynamicDimensionSimplifierTest, ReshapeReshapeForwarding) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[] parameter(0)
reshape = s32[1] reshape(p0)
ROOT reshape2 = s32[] reshape(reshape)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}
TEST_F(DynamicDimensionSimplifierTest,
DoNotReshapeReshapeForwardingShapeMismatch) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1, 1] parameter(0)
reshape = s32[1] reshape(p0)
ROOT reshape2 = s32[] reshape(reshape)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
TEST_F(DynamicDimensionSimplifierTest, IdConvertRemoving) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = s32[1] parameter(0)
ROOT reshape2 = s32[1] convert(p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
DynamicDimensionSimplifier simplifier;
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}
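// Illustrative sketch only (not in the original file): the remaining
// Reshape(A, S->S) ==> A rewrite could be exercised in the same style as the
// tests above.
TEST_F(DynamicDimensionSimplifierTest, IdentityReshapeRemoving) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = s32[1] parameter(0)
      ROOT reshape = s32[1] reshape(p0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  DynamicDimensionSimplifier simplifier;
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}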
} // namespace
} // namespace xla

tensorflow/compiler/xla/service/dynamic_padder.cc

@@ -1282,97 +1282,6 @@ StatusOr<bool> RewriteDynamicSort(
return true;
}
StatusOr<bool> RewriteDynamicBinaryOp(
HloInstruction* binary,
DynamicDimensionInference* dynamic_dimension_inference) {
HloInstruction* operand_0 = binary->mutable_operand(0);
HloInstruction* operand_1 = binary->mutable_operand(1);
HloComputation* comp = binary->parent();
TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
bool changed = false;
for (int64 i = 0; i < dims_0.size(); ++i) {
HloInstruction* dim_0 = dims_0[i];
HloInstruction* dim_1 = dims_1[i];
if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
dims_1[i] != nullptr) {
changed = true;
// It is possible that a dynamic dimension of one operand is size 1 while
// the other is greater than one. According to implicit broadcast
// semantics, we need to insert a broadcast in this case to make the
// dynamic shapes match.
// The implicit broadcast is inserted by slicing the small shape into a
// size-1 slice, reshaping out the size-1 dimension, and then broadcasting
// to the full shape:
//
// Input [2, <=5, 3]
// |
// Slice [2, 1, 3]
// |
// Reshape [2, 3]
// |
// Broadcast [2, 5, 3]
auto rewrite_operand = [&](HloInstruction* pred,
HloInstruction* operand) -> HloInstruction* {
Shape static_shape = operand->shape();
static_shape.clear_dynamic_dimensions();
pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
Shape slice_shape = static_shape;
slice_shape.set_dimensions(i, 1);
std::vector<int64> start_indices(slice_shape.rank(), 0);
std::vector<int64> strides(slice_shape.rank(), 1);
HloInstruction* slice = comp->AddInstruction(
HloInstruction::CreateSlice(slice_shape, operand, start_indices,
slice_shape.dimensions(), strides));
Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
HloInstruction* reshape = comp->AddInstruction(
HloInstruction::CreateReshape(reshape_shape, slice));
std::vector<int64> broadcast_dims;
broadcast_dims.reserve(static_shape.rank() - 1);
// Broadcast to all dims except i.
for (int64 j = 0; j < static_shape.rank(); ++j) {
if (j != i) {
broadcast_dims.push_back(j);
}
}
HloInstruction* broadcast =
comp->AddInstruction(HloInstruction::CreateBroadcast(
static_shape, reshape, broadcast_dims),
"implicit_broadcast");
// Use a select instead of a conditional, as elementwise operations promote
// more fusion.
HloInstruction* select =
comp->AddInstruction(HloInstruction::CreateTernary(
static_shape, HloOpcode::kSelect, pred, broadcast, operand));
return select;
};
auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
dim_1, ComparisonDirection::kLt),
"lhs_needs_implicit_broadcast");
operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
dim_0, ComparisonDirection::kLt),
"rhs_needs_implicit_broadcast");
operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
}
}
if (changed) {
TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
}
return changed;
}
StatusOr<bool> RewriteDynamicReshape(
HloInstruction* reshape,
DynamicDimensionInference* dynamic_dimension_inference) {
@@ -1823,14 +1732,6 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
continue;
}
// Elementwise binary ops with dynamic shapes have implicit broadcast
// semantics.
if (inst->IsElementwiseBinary()) {
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
inst, &dynamic_dimension_inference));
continue;
}
if (inst->opcode() == HloOpcode::kDynamicReshape) {
TF_ASSIGN_OR_RETURN(
changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));