[XLA] HandleClamp in DynamicPadder.
PiperOrigin-RevId: 274700208
This commit is contained in:
parent
68fedf4c80
commit
e2eb7e3641
@ -85,6 +85,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleClamp(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleWhile(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleSlice(HloInstruction* hlo) override;
|
||||
@ -511,6 +513,10 @@ Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
|
||||
return PassThroughDynamicDimension(hlo);
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
|
||||
return PassThroughDynamicDimension(hlo);
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo,
|
||||
|
@ -461,6 +461,40 @@ ENTRY main {
|
||||
EXPECT_EQ(padded, not_padded);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ExecutionTest, DynamicDimensionClamp) {
|
||||
const string hlo_text = R"(
|
||||
HloModule TensorFlowTenaryV1
|
||||
|
||||
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||
lhs = s32[] parameter(0)
|
||||
rhs = s32[] parameter(1)
|
||||
ROOT add = s32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = s32[5] parameter(0)
|
||||
const = s32[] constant(3)
|
||||
param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
|
||||
clamp = s32[5] clamp(param_padded, param_padded, param_padded)
|
||||
init = s32[] constant(0)
|
||||
ROOT reduce = s32[] reduce(clamp, init),
|
||||
dimensions={0},
|
||||
to_apply=update_s32
|
||||
}
|
||||
)";
|
||||
|
||||
// Input has upper bound of 5, dynamic dimension is 3.
|
||||
Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5});
|
||||
auto module = GetHloModule(hlo_text);
|
||||
|
||||
Literal result = PadAndExecute(std::move(module), {&operand});
|
||||
|
||||
// only first 3 elements will be reduced.
|
||||
Literal expected = LiteralUtil::CreateR0<int32>(6);
|
||||
|
||||
EXPECT_EQ(result, expected);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) {
|
||||
const string hlo_text = R"(
|
||||
HloModule TensorFlowScatterV1
|
||||
|
Loading…
x
Reference in New Issue
Block a user