Add kPad to dynamism inference
PiperOrigin-RevId: 346416644 Change-Id: Ib601361be72e0047932ec56219892f5bb1e5cb12
This commit is contained in:
parent
5ff53a1d66
commit
17f251ee93
@ -3347,6 +3347,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
// contant False if dimension is static.
|
||||
// - Reduce: Convert to reduce or.
|
||||
// - Constant: Convert to constant False.
|
||||
// - Reshape, slice, transpose, pad:
|
||||
// Convert into predicate type with same opcode.
|
||||
// - Other ops: Not supported.
|
||||
// Create the instruction for the new handle.
|
||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
|
||||
@ -3449,6 +3451,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kPad:
|
||||
break;
|
||||
case HloOpcode::kGetDimensionSize: {
|
||||
int64 dimension = instr_proto->dimensions(0);
|
||||
|
||||
@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DynamismInferenceTest, InferThroughPad) {
|
||||
for (ClientType client_type : client_types) {
|
||||
Client* client = ClientOrDie(platform_, client_type);
|
||||
XlaBuilder b(TestName());
|
||||
// Test the analysis on a gather.
|
||||
auto operand1 = ConstantR1<int32>(&b, {1, 2});
|
||||
auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
|
||||
PaddingConfig padding_config;
|
||||
padding_config.add_dimensions()->set_edge_padding_high(1);
|
||||
// After pad the value is [constant, constant, parameter].
|
||||
auto pad = Pad(operand1, parameter, padding_config);
|
||||
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
|
||||
// Everything is constant, result is also contant.
|
||||
EXPECT_FALSE(
|
||||
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
|
||||
EXPECT_FALSE(
|
||||
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
|
||||
EXPECT_TRUE(
|
||||
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user