[NFC] Remove HloGetDimensionSizeRewriter.

- All backends support dynamic padder now, no need for a separate pass.
- This allows DynamicDimensionInference to run just once.

PiperOrigin-RevId: 326719617
Change-Id: I4a49ef16c3868224af0431d90e8fd164a367ea81
This commit is contained in:
Yunxing Dai 2020-08-14 13:31:12 -07:00 committed by TensorFlower Gardener
parent ba337c699f
commit 30f38ef537
9 changed files with 136 additions and 269 deletions

View File

@ -2684,6 +2684,7 @@ cc_library(
":hlo_casting_utils",
":hlo_dce",
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
@ -2707,7 +2708,6 @@ xla_test(
":dynamic_padder",
":hlo",
":hlo_dce",
":hlo_get_dimension_size_rewriter",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:debug_options_flags",
@ -3997,42 +3997,6 @@ tf_cc_test(
],
)
cc_library(
name = "hlo_get_dimension_size_rewriter",
srcs = ["hlo_get_dimension_size_rewriter.cc"],
hdrs = ["hlo_get_dimension_size_rewriter.h"],
deps = [
":dynamic_dimension_inference",
":hlo",
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:literal_util",
"@com_google_absl//absl/algorithm:container",
],
)
tf_cc_test(
name = "hlo_get_dimension_size_rewriter_test",
srcs = ["hlo_get_dimension_size_rewriter_test.cc"],
deps = [
":hlo",
":hlo_get_dimension_size_rewriter",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
cc_library(
name = "maybe_owning_device_memory",
srcs = [

View File

@ -140,7 +140,6 @@ cc_library(
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
"//tensorflow/compiler/xla/service:conditional_to_select",
"//tensorflow/compiler/xla/service:slow_operation_alarm",

View File

@ -85,7 +85,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -292,7 +291,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<ConditionalCanonicalizer>();
pipeline.AddPass<DynamicPadder>();
pipeline.AddPass<ScatterExpander>();
pipeline.AddPass<HloGetDimensionSizeRewriter>();
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =

View File

@ -32,6 +32,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@ -125,6 +127,58 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
}
}
StatusOr<bool> ReplaceGetSize(
HloInstruction* instr,
DynamicDimensionInference* dynamic_dimension_inference) {
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
return false;
}
HloComputation* computation = instr->parent();
TF_ASSIGN_OR_RETURN(auto legal_shape,
ShapeInference::InferGetDimensionSizeShape(
instr->operand(0)->shape(), instr->dimension()));
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
<< "instr->shape() " << instr->shape().ToString() << " , "
<< "legal_shape " << legal_shape.ToString();
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
HloInstruction* operand = instr->mutable_operand(0);
int64 dim = instr->dimension();
HloInstruction* dynamic_size =
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
// The dependency between a instruction and its dynamic dimensions is not
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
// dynamic dimension inference that the instruction is being replaced.
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
instr, dynamic_size);
} else {
int32 size = instr->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
new_instr);
}
return true;
}
StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
if (instr->opcode() != HloOpcode::kSetDimensionSize) {
return false;
}
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
instr->shape(), instr->operand(0)->shape()))
<< "instr->shape() " << instr->shape().ToString() << " , "
<< "instruction operand shape " << instr->operand(0)->shape();
HloInstruction* operand = instr->mutable_operand(0);
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
return true;
}
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
int64 dimension) {
if ((inst->opcode() == HloOpcode::kReduceWindow ||
@ -1292,6 +1346,22 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
/*require_dynamic_output=*/require_dynamic_output));
}
for (auto* computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(
bool replaced_get_size,
ReplaceGetSize(instruction, &dynamic_dimension_inference));
changed = changed || replaced_get_size;
}
}
for (auto* computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
changed = changed || replaced_set_size;
}
}
HloDCE dce;
TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
VLOG(2) << "Post DynamicPadder HLO:";

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -382,8 +381,6 @@ class ExecutionTest : public HloTestBase {
bool slice_dynamic_output = true) {
DynamicPadder padder(slice_dynamic_output);
TF_CHECK_OK(padder.Run(module.get()).status());
HloGetDimensionSizeRewriter rewriter;
TF_CHECK_OK(rewriter.Run(module.get()).status());
HloDCE dce;
TF_CHECK_OK(dce.Run(module.get()).status());
return ExecuteAndTransfer(std::move(module), arguments);
@ -1371,5 +1368,70 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
namespace op = xla::testing::opcode_matchers;
class HloDimensionSizeLegalizerTest : public HloTestBase {
protected:
HloDimensionSizeLegalizerTest() {}
};
TEST_F(HloDimensionSizeLegalizerTest, Ok) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)
size0 = s32[] get-dimension-size(p), dimensions={0}
size1 = s32[] get-dimension-size(p), dimensions={1}
ROOT mul = s32[] multiply(size0, size1)
})")
.ValueOrDie();
DynamicPadder pass;
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Multiply(op::Constant(), op::Constant()));
}
TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)
size0 = s32[] get-dimension-size(p), dimensions={0}
p_copy = s32[3,4] copy(p)
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
ROOT mul = s32[] multiply(size0, size1)
})")
.ValueOrDie();
DynamicPadder pass;
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Multiply(op::Constant(), op::Constant()));
}
TEST_F(HloDimensionSizeLegalizerTest, IllegalType) {
auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3]{0} parameter(0)
ROOT gds = s64[] get-dimension-size(p), dimensions={0}
})")
.ValueOrDie();
DynamicPadder pass;
EXPECT_FALSE(pass.Run(module.get()).ok());
}
TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) {
auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
p = f32[2,5] parameter(0)
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})")
.ValueOrDie();
DynamicPadder pass;
EXPECT_FALSE(pass.Run(module.get()).ok());
}
} // namespace
} // namespace xla

View File

@ -1194,7 +1194,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto_util",

View File

@ -83,7 +83,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
@ -197,8 +196,6 @@ Status GpuCompiler::OptimizeHloModule(
/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pass.AddPass<HloGetDimensionSizeRewriter>();
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.
pass.AddPass<ZeroSizedHloElimination>();

View File

@ -1,120 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
namespace {
StatusOr<bool> ReplaceGetSize(
HloInstruction* instr,
DynamicDimensionInference* dynamic_dimension_inference) {
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
return false;
}
HloComputation* computation = instr->parent();
TF_ASSIGN_OR_RETURN(auto legal_shape,
ShapeInference::InferGetDimensionSizeShape(
instr->operand(0)->shape(), instr->dimension()));
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
<< "instr->shape() " << instr->shape().ToString() << " , "
<< "legal_shape " << legal_shape.ToString();
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
HloInstruction* operand = instr->mutable_operand(0);
int64 dim = instr->dimension();
HloInstruction* dynamic_size =
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
// The dependency between a instruction and its dynamic dimensions is not
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
// dynamic dimension inference that the instruction is being replaced.
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
instr, dynamic_size);
} else {
int32 size = instr->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
new_instr);
}
return true;
}
StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
if (instr->opcode() != HloOpcode::kSetDimensionSize) {
return false;
}
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
instr->shape(), instr->operand(0)->shape()))
<< "instr->shape() " << instr->shape().ToString() << " , "
<< "instruction operand shape " << instr->operand(0)->shape();
HloInstruction* operand = instr->mutable_operand(0);
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
return true;
}
} // namespace
StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
bool changed = false;
HloProto proto;
TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
DynamicDimensionInference::Run(module));
*proto.mutable_hlo_module() = module->ToProto();
// It's important to replace get-dimension-size first before
// set-dimension-size for the case below:
// static_op dynamic_size
// | |
// set-dimension-size // Marks the dimension as dynamic
// |
// get-dimension-size
//
// If we replace set dimension size first, we'd have
//
// static_op
// |
// get-dimension-size
//
// This will get static size of the op, which is incorrect.
for (auto* computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
ReplaceGetSize(instruction, &inference));
changed = changed || replaced_get_size;
}
}
for (auto* computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
changed = changed || replaced_set_size;
}
}
return changed;
}
} // namespace xla

View File

@ -1,102 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
class HloGetDimensionSizeRewriterTest : public HloTestBase {
protected:
HloGetDimensionSizeRewriterTest() {}
};
TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)
size0 = s32[] get-dimension-size(p), dimensions={0}
size1 = s32[] get-dimension-size(p), dimensions={1}
ROOT mul = s32[] multiply(size0, size1)
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Multiply(op::Constant(), op::Constant()));
}
TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)
size0 = s32[] get-dimension-size(p), dimensions={0}
p_copy = s32[3,4] copy(p)
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
ROOT mul = s32[] multiply(size0, size1)
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Multiply(op::Constant(), op::Constant()));
}
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3]{0} parameter(0)
ROOT gds = s64[] get-dimension-size(p), dimensions={0}
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;
EXPECT_FALSE(pass.Run(module.get()).ok());
}
TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) {
auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _
ENTRY gds {
p = f32[2,5] parameter(0)
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;
EXPECT_FALSE(pass.Run(module.get()).ok());
}
} // namespace
} // namespace xla