Add kPad to dynamism inference

PiperOrigin-RevId: 346416644
Change-Id: Ib601361be72e0047932ec56219892f5bb1e5cb12
This commit is contained in:
Yunxing Dai 2020-12-08 14:54:56 -08:00 committed by TensorFlower Gardener
parent 5ff53a1d66
commit 17f251ee93
2 changed files with 25 additions and 0 deletions

View File

@ -3347,6 +3347,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
// contant False if dimension is static.
// - Reduce: Convert to reduce or.
// - Constant: Convert to constant False.
// - Reshape, slice, transpose, pad:
// Convert into predicate type with same opcode.
// - Other ops: Not supported.
// Create the instruction for the new handle.
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
@ -3449,6 +3451,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad:
break;
case HloOpcode::kGetDimensionSize: {
int64 dimension = instr_proto->dimensions(0);

View File

@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
}
}
TEST_F(DynamismInferenceTest, InferThroughPad) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a gather.
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
PaddingConfig padding_config;
padding_config.add_dimensions()->set_edge_padding_high(1);
// After pad the value is [constant, constant, parameter].
auto pad = Pad(operand1, parameter, padding_config);
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// Everything is constant, result is also contant.
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
EXPECT_TRUE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
}
}
} // namespace
} // namespace xla