[XLA] HandleClamp in DynamicPadder.

PiperOrigin-RevId: 274700208
This commit is contained in:
Yunxing Dai 2019-10-14 17:46:05 -07:00 committed by TensorFlower Gardener
parent 68fedf4c80
commit e2eb7e3641
2 changed files with 40 additions and 0 deletions

View File

@ -85,6 +85,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleClamp(HloInstruction* hlo) override;
Status HandleWhile(HloInstruction* hlo) override;
Status HandleSlice(HloInstruction* hlo) override;
@ -511,6 +513,10 @@ Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
return PassThroughDynamicDimension(hlo);
}
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
return PassThroughDynamicDimension(hlo);
}
Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo,

View File

@ -461,6 +461,40 @@ ENTRY main {
EXPECT_EQ(padded, not_padded);
}
XLA_TEST_F(ExecutionTest, DynamicDimensionClamp) {
const string hlo_text = R"(
HloModule TensorFlowTenaryV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
ENTRY main {
param = s32[5] parameter(0)
const = s32[] constant(3)
param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
clamp = s32[5] clamp(param_padded, param_padded, param_padded)
init = s32[] constant(0)
ROOT reduce = s32[] reduce(clamp, init),
dimensions={0},
to_apply=update_s32
}
)";
// Input has upper bound of 5, dynamic dimension is 3.
Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand});
// only first 3 elements will be reduced.
Literal expected = LiteralUtil::CreateR0<int32>(6);
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1