[NFC] Remove HloGetDimensionSizeRewriter.
- All backends support dynamic padder now, no need for a separate pass. - This allows DynamicDimensionInference to run just once. PiperOrigin-RevId: 326719617 Change-Id: I4a49ef16c3868224af0431d90e8fd164a367ea81
This commit is contained in:
parent
ba337c699f
commit
30f38ef537
@ -2684,6 +2684,7 @@ cc_library(
|
||||
":hlo_casting_utils",
|
||||
":hlo_dce",
|
||||
":hlo_pass",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
@ -2707,7 +2708,6 @@ xla_test(
|
||||
":dynamic_padder",
|
||||
":hlo",
|
||||
":hlo_dce",
|
||||
":hlo_get_dimension_size_rewriter",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
@ -3997,42 +3997,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_get_dimension_size_rewriter",
|
||||
srcs = ["hlo_get_dimension_size_rewriter.cc"],
|
||||
hdrs = ["hlo_get_dimension_size_rewriter.h"],
|
||||
deps = [
|
||||
":dynamic_dimension_inference",
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_get_dimension_size_rewriter_test",
|
||||
srcs = ["hlo_get_dimension_size_rewriter_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_get_dimension_size_rewriter",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "maybe_owning_device_memory",
|
||||
srcs = [
|
||||
|
@ -140,7 +140,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
|
||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||
"//tensorflow/compiler/xla/service:conditional_to_select",
|
||||
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
||||
|
@ -85,7 +85,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -292,7 +291,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
pipeline.AddPass<ConditionalCanonicalizer>();
|
||||
pipeline.AddPass<DynamicPadder>();
|
||||
pipeline.AddPass<ScatterExpander>();
|
||||
pipeline.AddPass<HloGetDimensionSizeRewriter>();
|
||||
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
|
||||
{
|
||||
auto& pass =
|
||||
|
@ -32,6 +32,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -125,6 +127,58 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<bool> ReplaceGetSize(
|
||||
HloInstruction* instr,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
||||
return false;
|
||||
}
|
||||
HloComputation* computation = instr->parent();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto legal_shape,
|
||||
ShapeInference::InferGetDimensionSizeShape(
|
||||
instr->operand(0)->shape(), instr->dimension()));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
|
||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||
<< "legal_shape " << legal_shape.ToString();
|
||||
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
|
||||
HloInstruction* operand = instr->mutable_operand(0);
|
||||
int64 dim = instr->dimension();
|
||||
HloInstruction* dynamic_size =
|
||||
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
||||
if (dynamic_size != nullptr) {
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
||||
// The dependency between a instruction and its dynamic dimensions is not
|
||||
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
||||
// dynamic dimension inference that the instruction is being replaced.
|
||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
|
||||
instr, dynamic_size);
|
||||
} else {
|
||||
int32 size = instr->operand(0)->shape().dimensions(dim);
|
||||
HloInstruction* new_instr = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
|
||||
new_instr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
|
||||
if (instr->opcode() != HloOpcode::kSetDimensionSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
|
||||
instr->shape(), instr->operand(0)->shape()))
|
||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||
<< "instruction operand shape " << instr->operand(0)->shape();
|
||||
HloInstruction* operand = instr->mutable_operand(0);
|
||||
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
|
||||
int64 dimension) {
|
||||
if ((inst->opcode() == HloOpcode::kReduceWindow ||
|
||||
@ -1292,6 +1346,22 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
/*require_dynamic_output=*/require_dynamic_output));
|
||||
}
|
||||
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool replaced_get_size,
|
||||
ReplaceGetSize(instruction, &dynamic_dimension_inference));
|
||||
changed = changed || replaced_get_size;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||
changed = changed || replaced_set_size;
|
||||
}
|
||||
}
|
||||
|
||||
HloDCE dce;
|
||||
TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
|
||||
VLOG(2) << "Post DynamicPadder HLO:";
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
@ -382,8 +381,6 @@ class ExecutionTest : public HloTestBase {
|
||||
bool slice_dynamic_output = true) {
|
||||
DynamicPadder padder(slice_dynamic_output);
|
||||
TF_CHECK_OK(padder.Run(module.get()).status());
|
||||
HloGetDimensionSizeRewriter rewriter;
|
||||
TF_CHECK_OK(rewriter.Run(module.get()).status());
|
||||
HloDCE dce;
|
||||
TF_CHECK_OK(dce.Run(module.get()).status());
|
||||
return ExecuteAndTransfer(std::move(module), arguments);
|
||||
@ -1371,5 +1368,70 @@ ENTRY main {
|
||||
EXPECT_EQ(result, expected);
|
||||
}
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class HloDimensionSizeLegalizerTest : public HloTestBase {
|
||||
protected:
|
||||
HloDimensionSizeLegalizerTest() {}
|
||||
};
|
||||
|
||||
TEST_F(HloDimensionSizeLegalizerTest, Ok) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3,4] parameter(0)
|
||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||
size1 = s32[] get-dimension-size(p), dimensions={1}
|
||||
ROOT mul = s32[] multiply(size0, size1)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
DynamicPadder pass;
|
||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Multiply(op::Constant(), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3,4] parameter(0)
|
||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||
p_copy = s32[3,4] copy(p)
|
||||
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
|
||||
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
|
||||
ROOT mul = s32[] multiply(size0, size1)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
DynamicPadder pass;
|
||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Multiply(op::Constant(), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(HloDimensionSizeLegalizerTest, IllegalType) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3]{0} parameter(0)
|
||||
ROOT gds = s64[] get-dimension-size(p), dimensions={0}
|
||||
})")
|
||||
.ValueOrDie();
|
||||
DynamicPadder pass;
|
||||
EXPECT_FALSE(pass.Run(module.get()).ok());
|
||||
}
|
||||
|
||||
TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = f32[2,5] parameter(0)
|
||||
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
|
||||
})")
|
||||
.ValueOrDie();
|
||||
DynamicPadder pass;
|
||||
EXPECT_FALSE(pass.Run(module.get()).ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -1194,7 +1194,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
|
||||
"//tensorflow/compiler/xla/service:hlo_dce",
|
||||
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
|
||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_util",
|
||||
|
@ -83,7 +83,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
@ -197,8 +196,6 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/false);
|
||||
|
||||
pass.AddPass<HloGetDimensionSizeRewriter>();
|
||||
|
||||
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
|
||||
// elimination has to come after that pass.
|
||||
pass.AddPass<ZeroSizedHloElimination>();
|
||||
|
@ -1,120 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
StatusOr<bool> ReplaceGetSize(
|
||||
HloInstruction* instr,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
||||
return false;
|
||||
}
|
||||
HloComputation* computation = instr->parent();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto legal_shape,
|
||||
ShapeInference::InferGetDimensionSizeShape(
|
||||
instr->operand(0)->shape(), instr->dimension()));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
|
||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||
<< "legal_shape " << legal_shape.ToString();
|
||||
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
|
||||
HloInstruction* operand = instr->mutable_operand(0);
|
||||
int64 dim = instr->dimension();
|
||||
HloInstruction* dynamic_size =
|
||||
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
||||
if (dynamic_size != nullptr) {
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
||||
// The dependency between a instruction and its dynamic dimensions is not
|
||||
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
||||
// dynamic dimension inference that the instruction is being replaced.
|
||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
|
||||
instr, dynamic_size);
|
||||
} else {
|
||||
int32 size = instr->operand(0)->shape().dimensions(dim);
|
||||
HloInstruction* new_instr = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
|
||||
new_instr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
|
||||
if (instr->opcode() != HloOpcode::kSetDimensionSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
|
||||
instr->shape(), instr->operand(0)->shape()))
|
||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||
<< "instruction operand shape " << instr->operand(0)->shape();
|
||||
HloInstruction* operand = instr->mutable_operand(0);
|
||||
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
HloProto proto;
|
||||
TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
|
||||
DynamicDimensionInference::Run(module));
|
||||
*proto.mutable_hlo_module() = module->ToProto();
|
||||
// It's important to replace get-dimension-size first before
|
||||
// set-dimension-size for the case below:
|
||||
// static_op dynamic_size
|
||||
// | |
|
||||
// set-dimension-size // Marks the dimension as dynamic
|
||||
// |
|
||||
// get-dimension-size
|
||||
//
|
||||
// If we replace set dimension size first, we'd have
|
||||
//
|
||||
// static_op
|
||||
// |
|
||||
// get-dimension-size
|
||||
//
|
||||
// This will get static size of the op, which is incorrect.
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
|
||||
ReplaceGetSize(instruction, &inference));
|
||||
changed = changed || replaced_get_size;
|
||||
}
|
||||
}
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||
changed = changed || replaced_set_size;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -1,102 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class HloGetDimensionSizeRewriterTest : public HloTestBase {
|
||||
protected:
|
||||
HloGetDimensionSizeRewriterTest() {}
|
||||
};
|
||||
|
||||
TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3,4] parameter(0)
|
||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||
size1 = s32[] get-dimension-size(p), dimensions={1}
|
||||
ROOT mul = s32[] multiply(size0, size1)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
HloGetDimensionSizeRewriter pass;
|
||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Multiply(op::Constant(), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3,4] parameter(0)
|
||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||
p_copy = s32[3,4] copy(p)
|
||||
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
|
||||
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
|
||||
ROOT mul = s32[] multiply(size0, size1)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
HloGetDimensionSizeRewriter pass;
|
||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Multiply(op::Constant(), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3]{0} parameter(0)
|
||||
ROOT gds = s64[] get-dimension-size(p), dimensions={0}
|
||||
})")
|
||||
.ValueOrDie();
|
||||
HloGetDimensionSizeRewriter pass;
|
||||
EXPECT_FALSE(pass.Run(module.get()).ok());
|
||||
}
|
||||
|
||||
TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = f32[2,5] parameter(0)
|
||||
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
|
||||
})")
|
||||
.ValueOrDie();
|
||||
HloGetDimensionSizeRewriter pass;
|
||||
EXPECT_FALSE(pass.Run(module.get()).ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user