Check the value shape of set dimension size in shape inference.
PiperOrigin-RevId: 317351738 Change-Id: Ia185e7745753711ca9ebf657522ef7422c9696ca
This commit is contained in:
parent
f53e1aac65
commit
a8456eae42
@ -2762,9 +2762,11 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
|
||||
*operand_shape, dimension));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferSetDimensionSizeShape(
|
||||
*operand_shape, *val_shape, dimension));
|
||||
// Setting an op's dynamic dimension to the static size is a noop.
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
|
||||
LookUpInstruction(val));
|
||||
|
@ -1123,7 +1123,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
|
||||
Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
|
||||
return CheckShape(set_size,
|
||||
ShapeInference::InferSetDimensionSizeShape(
|
||||
set_size->operand(0)->shape(), set_size->dimension()));
|
||||
set_size->operand(0)->shape(),
|
||||
set_size->operand(1)->shape(), set_size->dimension()));
|
||||
}
|
||||
|
||||
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||
|
@ -2248,12 +2248,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
|
||||
const Shape& shape, int64 dimension) {
|
||||
const Shape& shape, const Shape& val_shape, int64 dimension) {
|
||||
if (dimension < 0 || dimension >= shape.rank()) {
|
||||
return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
|
||||
dimension);
|
||||
}
|
||||
|
||||
if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
|
||||
return InvalidArgument(
|
||||
"SetDimensionSize's value has to be S32 scalar, got %s",
|
||||
val_shape.ToString());
|
||||
}
|
||||
// TODO(b/119580730): Remove this restriction when very large dimension size
|
||||
// is needed.
|
||||
if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
|
||||
|
@ -303,10 +303,13 @@ class ShapeInference {
|
||||
const Shape& updates_shape, const ProgramShape& to_apply_shape,
|
||||
const ScatterDimensionNumbers& scatter_dim_numbers);
|
||||
|
||||
// Helper that validates the given input shape to GetDimensionSize.
|
||||
static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
|
||||
int64 dimension);
|
||||
|
||||
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& shape,
|
||||
// Helper that validates the given input shape to SetDimensionSize.
|
||||
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
|
||||
const Shape& val_shape,
|
||||
int64 dimension);
|
||||
|
||||
// Helper function for creating a Window proto from user-supplied data.
|
||||
|
@ -1365,6 +1365,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
|
||||
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
|
||||
Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
|
||||
Shape val_shape = ShapeUtil::MakeShape(S32, {1});
|
||||
auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
|
||||
arg_shape, val_shape, /*dimension=*/0);
|
||||
|
||||
EXPECT_FALSE(inferred_status.ok());
|
||||
EXPECT_THAT(inferred_status.status().error_message(),
|
||||
HasSubstr("value has to be S32 scalar"));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
|
||||
Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
|
||||
Shape val_shape = ShapeUtil::MakeShape(U32, {});
|
||||
auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
|
||||
arg_shape, val_shape, /*dimension=*/0);
|
||||
|
||||
EXPECT_FALSE(inferred_status.ok());
|
||||
EXPECT_THAT(inferred_status.status().error_message(),
|
||||
HasSubstr("value has to be S32 scalar"));
|
||||
}
|
||||
|
||||
// BatchMatMul with different batch dimension sizes fails.
|
||||
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
|
||||
Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
|
||||
|
Loading…
Reference in New Issue
Block a user