Internal change
PiperOrigin-RevId: 356417008 Change-Id: If13d7ebaf8a66cb8fbc9296cd5dc92b75e528a34
This commit is contained in:
parent
d61b9211bc
commit
a0b5d0bacb
@ -2794,47 +2794,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dynamic_dimension_simplifier",
|
||||
srcs = ["dynamic_dimension_simplifier.cc"],
|
||||
hdrs = ["dynamic_dimension_simplifier.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "dynamic_dimension_simplifier_test",
|
||||
srcs = ["dynamic_dimension_simplifier_test.cc"],
|
||||
deps = [
|
||||
":dynamic_dimension_simplifier",
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_creation_utils",
|
||||
":hlo_parser",
|
||||
":hlo_pass",
|
||||
":hlo_pass_pipeline",
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dynamic_padder",
|
||||
srcs = ["dynamic_padder.cc"],
|
||||
|
@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
|
||||
HloInstruction* gds) {
|
||||
HloInstruction*) {
|
||||
// Dynamic dimension doesn't propagate through GetDimensionSize:
|
||||
//
|
||||
// Input: F32[x, y, z]
|
||||
@ -646,24 +646,6 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
|
||||
// The returned value is a scalar, which doesn't have any dynamic dimension in
|
||||
// the shape (although the value contains the real size of the dynamic
|
||||
// dimension of the input).
|
||||
int64 dim = gds->dimension();
|
||||
HloInstruction* operand = gds->mutable_operand(0);
|
||||
HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
|
||||
HloComputation* computation = gds->parent();
|
||||
if (dynamic_size != nullptr) {
|
||||
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
|
||||
// The dependency between an instruction and its dynamic dimensions is not
|
||||
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
||||
// dynamic dimension inference that the instruction is being replaced.
|
||||
parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
|
||||
} else {
|
||||
TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
|
||||
int32 size = gds->operand(0)->shape().dimensions(dim);
|
||||
HloInstruction* new_instr = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
||||
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
|
||||
parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -812,23 +794,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
|
||||
HloInstruction* hlo) {
|
||||
HloComputation* comp = hlo->parent();
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
HloInstruction* existing_size =
|
||||
parent_->GetDynamicSize(hlo, index, dimension);
|
||||
if (existing_size == nullptr || existing_size == dynamic_size) {
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
|
||||
} else {
|
||||
HloInstruction* max =
|
||||
comp->AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
|
||||
dynamic_size, existing_size));
|
||||
parent_->SetDynamicSize(hlo, index, dimension, max);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
return PassThroughDynamicDimension(hlo);
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
|
||||
|
@ -1,214 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Concat(Concat(A, B), C) => Concat(A, B, C)
|
||||
StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
|
||||
if (concat->opcode() != HloOpcode::kConcatenate) {
|
||||
return false;
|
||||
}
|
||||
bool changed = false;
|
||||
|
||||
auto parent = concat->parent();
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* operand : concat->operands()) {
|
||||
if (operand->opcode() != HloOpcode::kConcatenate ||
|
||||
operand->concatenate_dimension() != concat->concatenate_dimension()) {
|
||||
new_operands.push_back(operand);
|
||||
} else {
|
||||
changed = true;
|
||||
for (HloInstruction* operand_operand : operand->operands()) {
|
||||
new_operands.push_back(operand_operand);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
|
||||
concat->shape(), new_operands, concat->concatenate_dimension()));
|
||||
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
|
||||
StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
|
||||
if (slice->opcode() != HloOpcode::kSlice) {
|
||||
return false;
|
||||
}
|
||||
auto concat = slice->mutable_operand(0);
|
||||
if (concat->opcode() != HloOpcode::kConcatenate) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (slice->shape().rank() != 1) {
|
||||
// Slice concat forwarding only work for size 1 tensor.
|
||||
return false;
|
||||
}
|
||||
|
||||
int64 concat_dim = concat->concatenate_dimension();
|
||||
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
int64 size_so_far = 0;
|
||||
int64 slice_size = slice->shape().dimensions(concat_dim);
|
||||
if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
|
||||
return false;
|
||||
}
|
||||
if (slice->slice_strides(0) != 1) {
|
||||
return false;
|
||||
}
|
||||
for (HloInstruction* operand : concat->operands()) {
|
||||
if (size_so_far == slice->slice_starts(0) &&
|
||||
operand->shape().dimensions(0) == slice_size) {
|
||||
// Found an operand that can be forwarded.
|
||||
TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
|
||||
return true;
|
||||
}
|
||||
size_so_far += operand->shape().dimensions(concat_dim);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
|
||||
StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
|
||||
if (reshape->opcode() != HloOpcode::kReshape) {
|
||||
return false;
|
||||
}
|
||||
auto broadcast = reshape->mutable_operand(0);
|
||||
if (broadcast->opcode() != HloOpcode::kBroadcast) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reshape->shape().rank() != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (broadcast->shape().rank() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (broadcast->mutable_operand(0)->shape().rank() != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
|
||||
StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
|
||||
if (reshape->opcode() != HloOpcode::kReshape) {
|
||||
return false;
|
||||
}
|
||||
auto reshape_2 = reshape->mutable_operand(0);
|
||||
if (reshape_2->opcode() != HloOpcode::kReshape) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
|
||||
return false;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Convert(A, T->T) ==> A
|
||||
StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
|
||||
if (convert->opcode() != HloOpcode::kConvert) {
|
||||
return false;
|
||||
}
|
||||
auto operand = convert->mutable_operand(0);
|
||||
if (Shape::Equal()(convert->shape(), operand->shape())) {
|
||||
TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reshape(A, S->S) ==> A
|
||||
StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
|
||||
if (reshape->opcode() != HloOpcode::kReshape) {
|
||||
return false;
|
||||
}
|
||||
auto operand = reshape->mutable_operand(0);
|
||||
if (Shape::Equal()(reshape->shape(), operand->shape())) {
|
||||
TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
|
||||
XLA_VLOG_LINES(
|
||||
2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
|
||||
bool changed = false;
|
||||
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
|
||||
changed |= local_changed;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
|
||||
changed |= local_changed;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
|
||||
changed |= local_changed;
|
||||
}
|
||||
}
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
|
||||
changed |= local_changed;
|
||||
}
|
||||
}
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
|
||||
changed |= local_changed;
|
||||
}
|
||||
}
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
|
||||
changed |= local_changed;
|
||||
}
|
||||
}
|
||||
XLA_VLOG_LINES(
|
||||
2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
|
||||
return changed;
|
||||
}
|
||||
} // namespace xla
|
@ -1,37 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// This pass simplifies operations on dynamic dimension sizes so that it can be
|
||||
// easily analyzed by later passes.
|
||||
class DynamicDimensionSimplifier : public HloModulePass {
|
||||
public:
|
||||
absl::string_view name() const override {
|
||||
return "dynamic dimension simplifier";
|
||||
}
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
|
@ -1,201 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = match;
|
||||
|
||||
class DynamicDimensionSimplifierTest : public HloTestBase {};
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1] parameter(0)
|
||||
p1 = s32[1] parameter(1)
|
||||
p2 = s32[1] parameter(2)
|
||||
concat1 = s32[2] concatenate(p0, p1), dimensions={0}
|
||||
ROOT concat2 = s32[3] concatenate(concat1, p2), dimensions={0}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1),
|
||||
m::Parameter(2))));
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatMultipleDims) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1, 1] parameter(0)
|
||||
p1 = s32[1, 1] parameter(1)
|
||||
p2 = s32[2, 1] parameter(2)
|
||||
concat1 = s32[2, 1] concatenate(p0, p1), dimensions={0}
|
||||
ROOT concat2 = s32[2, 2] concatenate(concat1, p2), dimensions={1}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, ForwardConcatSlice) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1] parameter(0)
|
||||
p1 = s32[1] parameter(1)
|
||||
p2 = s32[1] parameter(2)
|
||||
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
|
||||
ROOT slice = s32[1] slice(concat), slice={[1:2]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Parameter(1)));
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceSizeMismatch) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1] parameter(0)
|
||||
p1 = s32[1] parameter(1)
|
||||
p2 = s32[1] parameter(2)
|
||||
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
|
||||
ROOT slice = s32[2] slice(concat), slice={[1:3]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceStrided) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1] parameter(0)
|
||||
p1 = s32[1] parameter(1)
|
||||
p2 = s32[1] parameter(2)
|
||||
concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
|
||||
ROOT slice = s32[1] slice(concat), slice={[1:2:2]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, BroadcastReshapeForwarding) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[] parameter(0)
|
||||
broadcast = s32[1] broadcast(p0), dimensions={}
|
||||
ROOT reshape = s32[] reshape(broadcast)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Parameter(0)));
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, ReshapeReshapeForwarding) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[] parameter(0)
|
||||
reshape = s32[1] reshape(p0)
|
||||
ROOT reshape2 = s32[] reshape(reshape)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Parameter(0)));
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest,
|
||||
DoNotReshapeReshapeForwardingShapeMismatch) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1, 1] parameter(0)
|
||||
reshape = s32[1] reshape(p0)
|
||||
ROOT reshape2 = s32[] reshape(reshape)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionSimplifierTest, IdConvertRemoving) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = s32[1] parameter(0)
|
||||
ROOT reshape2 = s32[1] convert(p0)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
DynamicDimensionSimplifier simplifier;
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Parameter(0)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -1282,97 +1282,6 @@ StatusOr<bool> RewriteDynamicSort(
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> RewriteDynamicBinaryOp(
|
||||
HloInstruction* binary,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
HloInstruction* operand_0 = binary->mutable_operand(0);
|
||||
HloInstruction* operand_1 = binary->mutable_operand(1);
|
||||
|
||||
HloComputation* comp = binary->parent();
|
||||
TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
|
||||
auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
|
||||
auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
|
||||
bool changed = false;
|
||||
for (int64 i = 0; i < dims_0.size(); ++i) {
|
||||
HloInstruction* dim_0 = dims_0[i];
|
||||
HloInstruction* dim_1 = dims_1[i];
|
||||
|
||||
if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
|
||||
dims_1[i] != nullptr) {
|
||||
changed = true;
|
||||
// It is possible that a dynamic dimension of one operand is size 1 while
|
||||
// the other is greater than one. According to implicit broadcast
|
||||
// semantics, we need to insert broadcast in this case to make the dynamic
|
||||
// shape match.
|
||||
|
||||
// An implicit broadcast is inserted by slicing the small shape into a
|
||||
// size 1 slice, reshape out the size 1 dimension then broadcast to the
|
||||
// full shape:
|
||||
//
|
||||
// Input [2, <=5, 3]
|
||||
// |
|
||||
// Slice [2, 1, 3]
|
||||
// |
|
||||
// Reshape [2, 3]
|
||||
// |
|
||||
// Broadcast [2, 5, 3]
|
||||
auto rewrite_operand = [&](HloInstruction* pred,
|
||||
HloInstruction* operand) -> HloInstruction* {
|
||||
Shape static_shape = operand->shape();
|
||||
static_shape.clear_dynamic_dimensions();
|
||||
pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
|
||||
Shape slice_shape = static_shape;
|
||||
slice_shape.set_dimensions(i, 1);
|
||||
std::vector<int64> start_indices(slice_shape.rank(), 0);
|
||||
std::vector<int64> strides(slice_shape.rank(), 1);
|
||||
HloInstruction* slice = comp->AddInstruction(
|
||||
HloInstruction::CreateSlice(slice_shape, operand, start_indices,
|
||||
slice_shape.dimensions(), strides));
|
||||
Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
|
||||
HloInstruction* reshape = comp->AddInstruction(
|
||||
HloInstruction::CreateReshape(reshape_shape, slice));
|
||||
std::vector<int64> broadcast_dims;
|
||||
broadcast_dims.reserve(static_shape.rank() - 1);
|
||||
// Broadcast to all dims execpt for i.
|
||||
for (int64 j = 0; j < static_shape.rank(); ++j) {
|
||||
if (j != i) {
|
||||
broadcast_dims.push_back(j);
|
||||
}
|
||||
}
|
||||
|
||||
HloInstruction* broadcast =
|
||||
comp->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
static_shape, reshape, broadcast_dims),
|
||||
"implicit_broadcast");
|
||||
|
||||
// Use a select instead of conditional as elementwise operations promote
|
||||
// more fusion.
|
||||
HloInstruction* select =
|
||||
comp->AddInstruction(HloInstruction::CreateTernary(
|
||||
static_shape, HloOpcode::kSelect, pred, broadcast, operand));
|
||||
return select;
|
||||
};
|
||||
auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
|
||||
dim_1, ComparisonDirection::kLt),
|
||||
"lhs_needs_implicit_broadcast");
|
||||
operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
|
||||
|
||||
auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
|
||||
dim_0, ComparisonDirection::kLt),
|
||||
"rhs_needs_implicit_broadcast");
|
||||
operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
|
||||
TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
StatusOr<bool> RewriteDynamicReshape(
|
||||
HloInstruction* reshape,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
@ -1823,14 +1732,6 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Elementwise binary with dynamic shapes have implicit broadcast
|
||||
// semantics.
|
||||
if (inst->IsElementwiseBinary()) {
|
||||
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
|
||||
inst, &dynamic_dimension_inference));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->opcode() == HloOpcode::kDynamicReshape) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
|
||||
|
Loading…
Reference in New Issue
Block a user