diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index d3b885643fc..c23b40ab6cd 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3347,6 +3347,8 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { // contant False if dimension is static. // - Reduce: Convert to reduce or. // - Constant: Convert to constant False. + // - Reshape, slice, transpose, pad: + // Convert into predicate type with same opcode. // - Other ops: Not supported. // Create the instruction for the new handle. TF_ASSIGN_OR_RETURN(HloOpcode opcode, @@ -3449,6 +3451,7 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { case HloOpcode::kBroadcast: case HloOpcode::kConcatenate: case HloOpcode::kReshape: + case HloOpcode::kPad: break; case HloOpcode::kGetDimensionSize: { int64 dimension = instr_proto->dimensions(0); diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc index 96ba73ac9f0..1763ed6090e 100644 --- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc +++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc @@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) { } } +TEST_F(DynamismInferenceTest, InferThroughPad) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + // Test the analysis on a gather. + auto operand1 = ConstantR1(&b, {1, 2}); + auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0"); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_edge_padding_high(1); + // After pad the value is [constant, constant, parameter]. + auto pad = Pad(operand1, parameter, padding_config); + ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); + // Everything is constant, result is also contant. + EXPECT_FALSE( + ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get({0})); + EXPECT_FALSE( + ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get({1})); + EXPECT_TRUE( + ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get({2})); + } +} + } // namespace } // namespace xla