From e2eb7e36417c4b4582161c5a12367a83da5ae82b Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Mon, 14 Oct 2019 17:46:05 -0700 Subject: [PATCH] [XLA] HandleClamp in DynamicPadder. PiperOrigin-RevId: 274700208 --- .../service/dynamic_dimension_inference.cc | 6 ++++ .../xla/service/dynamic_padder_test.cc | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index b36d125ed76..39a47501bfb 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -85,6 +85,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleClamp(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; Status HandleSlice(HloInstruction* hlo) override; @@ -511,6 +513,10 @@ Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( return PassThroughDynamicDimension(hlo); } +Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index 94f33a22714..e09d8235a63 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -461,6 +461,40 @@ ENTRY main { EXPECT_EQ(padded, not_padded); } +XLA_TEST_F(ExecutionTest, DynamicDimensionClamp) { + const string hlo_text = R"( +HloModule TensorFlowTenaryV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[5] parameter(0) + const = s32[] constant(3) + param_padded = s32[5] set-dimension-size(param, const), dimensions={0} + clamp = s32[5] clamp(param_padded, param_padded, param_padded) + init = s32[] constant(0) + ROOT reduce = s32[] reduce(clamp, init), + dimensions={0}, + to_apply=update_s32 +} +)"; + + // Input has upper bound of 5, dynamic dimension is 3. + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4, 5}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}); + + // only first 3 elements will be reduced. + Literal expected = LiteralUtil::CreateR0(6); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) { const string hlo_text = R"( HloModule TensorFlowScatterV1