[NFC] Remove HloGetDimensionSizeRewriter.
- All backends support dynamic padder now, no need for a separate pass. - This allows DynamicDimensionInference to run just once. PiperOrigin-RevId: 326719617 Change-Id: I4a49ef16c3868224af0431d90e8fd164a367ea81
This commit is contained in:
parent
ba337c699f
commit
30f38ef537
@ -2684,6 +2684,7 @@ cc_library(
|
|||||||
":hlo_casting_utils",
|
":hlo_casting_utils",
|
||||||
":hlo_dce",
|
":hlo_dce",
|
||||||
":hlo_pass",
|
":hlo_pass",
|
||||||
|
":shape_inference",
|
||||||
"//tensorflow/compiler/xla:comparison_util",
|
"//tensorflow/compiler/xla:comparison_util",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
@ -2707,7 +2708,6 @@ xla_test(
|
|||||||
":dynamic_padder",
|
":dynamic_padder",
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_dce",
|
":hlo_dce",
|
||||||
":hlo_get_dimension_size_rewriter",
|
|
||||||
":hlo_matchers",
|
":hlo_matchers",
|
||||||
":hlo_parser",
|
":hlo_parser",
|
||||||
"//tensorflow/compiler/xla:debug_options_flags",
|
"//tensorflow/compiler/xla:debug_options_flags",
|
||||||
@ -3997,42 +3997,6 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "hlo_get_dimension_size_rewriter",
|
|
||||||
srcs = ["hlo_get_dimension_size_rewriter.cc"],
|
|
||||||
hdrs = ["hlo_get_dimension_size_rewriter.h"],
|
|
||||||
deps = [
|
|
||||||
":dynamic_dimension_inference",
|
|
||||||
":hlo",
|
|
||||||
":hlo_pass",
|
|
||||||
":shape_inference",
|
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
|
||||||
"@com_google_absl//absl/algorithm:container",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cc_test(
|
|
||||||
name = "hlo_get_dimension_size_rewriter_test",
|
|
||||||
srcs = ["hlo_get_dimension_size_rewriter_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":hlo",
|
|
||||||
":hlo_get_dimension_size_rewriter",
|
|
||||||
":hlo_matchers",
|
|
||||||
":hlo_parser",
|
|
||||||
"//tensorflow/compiler/xla:literal",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
|
||||||
"//tensorflow/compiler/xla:types",
|
|
||||||
"//tensorflow/compiler/xla:util",
|
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
|
||||||
"//tensorflow/compiler/xla/tests:test_utils",
|
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "maybe_owning_device_memory",
|
name = "maybe_owning_device_memory",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -140,7 +140,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:map_inliner",
|
"//tensorflow/compiler/xla/service:map_inliner",
|
||||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||||
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
|
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
|
||||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
|
||||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||||
"//tensorflow/compiler/xla/service:conditional_to_select",
|
"//tensorflow/compiler/xla/service:conditional_to_select",
|
||||||
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
||||||
|
@ -85,7 +85,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
@ -292,7 +291,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
|||||||
pipeline.AddPass<ConditionalCanonicalizer>();
|
pipeline.AddPass<ConditionalCanonicalizer>();
|
||||||
pipeline.AddPass<DynamicPadder>();
|
pipeline.AddPass<DynamicPadder>();
|
||||||
pipeline.AddPass<ScatterExpander>();
|
pipeline.AddPass<ScatterExpander>();
|
||||||
pipeline.AddPass<HloGetDimensionSizeRewriter>();
|
|
||||||
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
|
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
|
||||||
{
|
{
|
||||||
auto& pass =
|
auto& pass =
|
||||||
|
@ -32,6 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
@ -125,6 +127,58 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<bool> ReplaceGetSize(
|
||||||
|
HloInstruction* instr,
|
||||||
|
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||||
|
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
HloComputation* computation = instr->parent();
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto legal_shape,
|
||||||
|
ShapeInference::InferGetDimensionSizeShape(
|
||||||
|
instr->operand(0)->shape(), instr->dimension()));
|
||||||
|
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
|
||||||
|
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||||
|
<< "legal_shape " << legal_shape.ToString();
|
||||||
|
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
|
||||||
|
HloInstruction* operand = instr->mutable_operand(0);
|
||||||
|
int64 dim = instr->dimension();
|
||||||
|
HloInstruction* dynamic_size =
|
||||||
|
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
||||||
|
if (dynamic_size != nullptr) {
|
||||||
|
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
||||||
|
// The dependency between a instruction and its dynamic dimensions is not
|
||||||
|
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
||||||
|
// dynamic dimension inference that the instruction is being replaced.
|
||||||
|
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
|
||||||
|
instr, dynamic_size);
|
||||||
|
} else {
|
||||||
|
int32 size = instr->operand(0)->shape().dimensions(dim);
|
||||||
|
HloInstruction* new_instr = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
||||||
|
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
||||||
|
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
|
||||||
|
new_instr);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
|
||||||
|
if (instr->opcode() != HloOpcode::kSetDimensionSize) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
|
||||||
|
instr->shape(), instr->operand(0)->shape()))
|
||||||
|
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||||
|
<< "instruction operand shape " << instr->operand(0)->shape();
|
||||||
|
HloInstruction* operand = instr->mutable_operand(0);
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
|
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
|
||||||
int64 dimension) {
|
int64 dimension) {
|
||||||
if ((inst->opcode() == HloOpcode::kReduceWindow ||
|
if ((inst->opcode() == HloOpcode::kReduceWindow ||
|
||||||
@ -1292,6 +1346,22 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
|||||||
/*require_dynamic_output=*/require_dynamic_output));
|
/*require_dynamic_output=*/require_dynamic_output));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto* computation : module->computations()) {
|
||||||
|
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
bool replaced_get_size,
|
||||||
|
ReplaceGetSize(instruction, &dynamic_dimension_inference));
|
||||||
|
changed = changed || replaced_get_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto* computation : module->computations()) {
|
||||||
|
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||||
|
changed = changed || replaced_set_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
HloDCE dce;
|
HloDCE dce;
|
||||||
TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
|
TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
|
||||||
VLOG(2) << "Post DynamicPadder HLO:";
|
VLOG(2) << "Post DynamicPadder HLO:";
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
@ -382,8 +381,6 @@ class ExecutionTest : public HloTestBase {
|
|||||||
bool slice_dynamic_output = true) {
|
bool slice_dynamic_output = true) {
|
||||||
DynamicPadder padder(slice_dynamic_output);
|
DynamicPadder padder(slice_dynamic_output);
|
||||||
TF_CHECK_OK(padder.Run(module.get()).status());
|
TF_CHECK_OK(padder.Run(module.get()).status());
|
||||||
HloGetDimensionSizeRewriter rewriter;
|
|
||||||
TF_CHECK_OK(rewriter.Run(module.get()).status());
|
|
||||||
HloDCE dce;
|
HloDCE dce;
|
||||||
TF_CHECK_OK(dce.Run(module.get()).status());
|
TF_CHECK_OK(dce.Run(module.get()).status());
|
||||||
return ExecuteAndTransfer(std::move(module), arguments);
|
return ExecuteAndTransfer(std::move(module), arguments);
|
||||||
@ -1371,5 +1368,70 @@ ENTRY main {
|
|||||||
EXPECT_EQ(result, expected);
|
EXPECT_EQ(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace op = xla::testing::opcode_matchers;
|
||||||
|
|
||||||
|
class HloDimensionSizeLegalizerTest : public HloTestBase {
|
||||||
|
protected:
|
||||||
|
HloDimensionSizeLegalizerTest() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(HloDimensionSizeLegalizerTest, Ok) {
|
||||||
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
|
HloModule _
|
||||||
|
ENTRY gds {
|
||||||
|
p = s32[3,4] parameter(0)
|
||||||
|
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||||
|
size1 = s32[] get-dimension-size(p), dimensions={1}
|
||||||
|
ROOT mul = s32[] multiply(size0, size1)
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
DynamicPadder pass;
|
||||||
|
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Multiply(op::Constant(), op::Constant()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) {
|
||||||
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
|
HloModule _
|
||||||
|
ENTRY gds {
|
||||||
|
p = s32[3,4] parameter(0)
|
||||||
|
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||||
|
p_copy = s32[3,4] copy(p)
|
||||||
|
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
|
||||||
|
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
|
||||||
|
ROOT mul = s32[] multiply(size0, size1)
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
DynamicPadder pass;
|
||||||
|
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Multiply(op::Constant(), op::Constant()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloDimensionSizeLegalizerTest, IllegalType) {
|
||||||
|
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||||
|
HloModule _
|
||||||
|
ENTRY gds {
|
||||||
|
p = s32[3]{0} parameter(0)
|
||||||
|
ROOT gds = s64[] get-dimension-size(p), dimensions={0}
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
DynamicPadder pass;
|
||||||
|
EXPECT_FALSE(pass.Run(module.get()).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) {
|
||||||
|
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||||
|
HloModule _
|
||||||
|
ENTRY gds {
|
||||||
|
p = f32[2,5] parameter(0)
|
||||||
|
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
DynamicPadder pass;
|
||||||
|
EXPECT_FALSE(pass.Run(module.get()).ok());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -1194,7 +1194,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
|
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
|
||||||
"//tensorflow/compiler/xla/service:hlo_dce",
|
"//tensorflow/compiler/xla/service:hlo_dce",
|
||||||
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
|
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
|
||||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||||
"//tensorflow/compiler/xla/service:hlo_proto_util",
|
"//tensorflow/compiler/xla/service:hlo_proto_util",
|
||||||
|
@ -83,7 +83,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||||
@ -197,8 +196,6 @@ Status GpuCompiler::OptimizeHloModule(
|
|||||||
/*layout_sensitive=*/false,
|
/*layout_sensitive=*/false,
|
||||||
/*allow_mixed_precision=*/false);
|
/*allow_mixed_precision=*/false);
|
||||||
|
|
||||||
pass.AddPass<HloGetDimensionSizeRewriter>();
|
|
||||||
|
|
||||||
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
|
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
|
||||||
// elimination has to come after that pass.
|
// elimination has to come after that pass.
|
||||||
pass.AddPass<ZeroSizedHloElimination>();
|
pass.AddPass<ZeroSizedHloElimination>();
|
||||||
|
@ -1,120 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
StatusOr<bool> ReplaceGetSize(
|
|
||||||
HloInstruction* instr,
|
|
||||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
|
||||||
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
HloComputation* computation = instr->parent();
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto legal_shape,
|
|
||||||
ShapeInference::InferGetDimensionSizeShape(
|
|
||||||
instr->operand(0)->shape(), instr->dimension()));
|
|
||||||
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
|
|
||||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
|
||||||
<< "legal_shape " << legal_shape.ToString();
|
|
||||||
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
|
|
||||||
HloInstruction* operand = instr->mutable_operand(0);
|
|
||||||
int64 dim = instr->dimension();
|
|
||||||
HloInstruction* dynamic_size =
|
|
||||||
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
|
||||||
if (dynamic_size != nullptr) {
|
|
||||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
|
||||||
// The dependency between a instruction and its dynamic dimensions is not
|
|
||||||
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
|
||||||
// dynamic dimension inference that the instruction is being replaced.
|
|
||||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
|
|
||||||
instr, dynamic_size);
|
|
||||||
} else {
|
|
||||||
int32 size = instr->operand(0)->shape().dimensions(dim);
|
|
||||||
HloInstruction* new_instr = computation->AddInstruction(
|
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
|
||||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
|
||||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
|
|
||||||
new_instr);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
|
|
||||||
if (instr->opcode() != HloOpcode::kSetDimensionSize) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
|
|
||||||
instr->shape(), instr->operand(0)->shape()))
|
|
||||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
|
||||||
<< "instruction operand shape " << instr->operand(0)->shape();
|
|
||||||
HloInstruction* operand = instr->mutable_operand(0);
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
|
|
||||||
bool changed = false;
|
|
||||||
HloProto proto;
|
|
||||||
TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
|
|
||||||
DynamicDimensionInference::Run(module));
|
|
||||||
*proto.mutable_hlo_module() = module->ToProto();
|
|
||||||
// It's important to replace get-dimension-size first before
|
|
||||||
// set-dimension-size for the case below:
|
|
||||||
// static_op dynamic_size
|
|
||||||
// | |
|
|
||||||
// set-dimension-size // Marks the dimension as dynamic
|
|
||||||
// |
|
|
||||||
// get-dimension-size
|
|
||||||
//
|
|
||||||
// If we replace set dimension size first, we'd have
|
|
||||||
//
|
|
||||||
// static_op
|
|
||||||
// |
|
|
||||||
// get-dimension-size
|
|
||||||
//
|
|
||||||
// This will get static size of the op, which is incorrect.
|
|
||||||
for (auto* computation : module->computations()) {
|
|
||||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
|
||||||
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
|
|
||||||
ReplaceGetSize(instruction, &inference));
|
|
||||||
changed = changed || replaced_get_size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto* computation : module->computations()) {
|
|
||||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
|
||||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
|
||||||
changed = changed || replaced_set_size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
|
@ -1,102 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
namespace op = xla::testing::opcode_matchers;
|
|
||||||
|
|
||||||
class HloGetDimensionSizeRewriterTest : public HloTestBase {
|
|
||||||
protected:
|
|
||||||
HloGetDimensionSizeRewriterTest() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
|
|
||||||
auto module = ParseAndReturnVerifiedModule(R"(
|
|
||||||
HloModule _
|
|
||||||
ENTRY gds {
|
|
||||||
p = s32[3,4] parameter(0)
|
|
||||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
|
||||||
size1 = s32[] get-dimension-size(p), dimensions={1}
|
|
||||||
ROOT mul = s32[] multiply(size0, size1)
|
|
||||||
})")
|
|
||||||
.ValueOrDie();
|
|
||||||
HloGetDimensionSizeRewriter pass;
|
|
||||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
|
||||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
|
||||||
op::Multiply(op::Constant(), op::Constant()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) {
|
|
||||||
auto module = ParseAndReturnVerifiedModule(R"(
|
|
||||||
HloModule _
|
|
||||||
ENTRY gds {
|
|
||||||
p = s32[3,4] parameter(0)
|
|
||||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
|
||||||
p_copy = s32[3,4] copy(p)
|
|
||||||
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
|
|
||||||
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
|
|
||||||
ROOT mul = s32[] multiply(size0, size1)
|
|
||||||
})")
|
|
||||||
.ValueOrDie();
|
|
||||||
HloGetDimensionSizeRewriter pass;
|
|
||||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
|
||||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
|
||||||
op::Multiply(op::Constant(), op::Constant()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
|
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
|
||||||
HloModule _
|
|
||||||
ENTRY gds {
|
|
||||||
p = s32[3]{0} parameter(0)
|
|
||||||
ROOT gds = s64[] get-dimension-size(p), dimensions={0}
|
|
||||||
})")
|
|
||||||
.ValueOrDie();
|
|
||||||
HloGetDimensionSizeRewriter pass;
|
|
||||||
EXPECT_FALSE(pass.Run(module.get()).ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) {
|
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
|
||||||
HloModule _
|
|
||||||
ENTRY gds {
|
|
||||||
p = f32[2,5] parameter(0)
|
|
||||||
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
|
|
||||||
})")
|
|
||||||
.ValueOrDie();
|
|
||||||
HloGetDimensionSizeRewriter pass;
|
|
||||||
EXPECT_FALSE(pass.Run(module.get()).ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace xla
|
|
Loading…
x
Reference in New Issue
Block a user