Check the value shape of set dimension size in shape inference.
PiperOrigin-RevId: 317351738 Change-Id: Ia185e7745753711ca9ebf657522ef7422c9696ca
This commit is contained in:
parent
f53e1aac65
commit
a8456eae42
@ -2762,9 +2762,11 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
|
|||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
|
TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
|
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||||
*operand_shape, dimension));
|
ShapeInference::InferSetDimensionSizeShape(
|
||||||
|
*operand_shape, *val_shape, dimension));
|
||||||
// Setting an op's dynamic dimension to the static size is a noop.
|
// Setting an op's dynamic dimension to the static size is a noop.
|
||||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
|
||||||
LookUpInstruction(val));
|
LookUpInstruction(val));
|
||||||
|
@ -1123,7 +1123,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
|
|||||||
Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
|
Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
|
||||||
return CheckShape(set_size,
|
return CheckShape(set_size,
|
||||||
ShapeInference::InferSetDimensionSizeShape(
|
ShapeInference::InferSetDimensionSizeShape(
|
||||||
set_size->operand(0)->shape(), set_size->dimension()));
|
set_size->operand(0)->shape(),
|
||||||
|
set_size->operand(1)->shape(), set_size->dimension()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||||
|
@ -2248,12 +2248,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
|
/* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
|
||||||
const Shape& shape, int64 dimension) {
|
const Shape& shape, const Shape& val_shape, int64 dimension) {
|
||||||
if (dimension < 0 || dimension >= shape.rank()) {
|
if (dimension < 0 || dimension >= shape.rank()) {
|
||||||
return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
|
return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
|
||||||
dimension);
|
dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"SetDimensionSize's value has to be S32 scalar, got %s",
|
||||||
|
val_shape.ToString());
|
||||||
|
}
|
||||||
// TODO(b/119580730): Remove this restriction when very large dimension size
|
// TODO(b/119580730): Remove this restriction when very large dimension size
|
||||||
// is needed.
|
// is needed.
|
||||||
if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
|
if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
|
||||||
|
@ -303,10 +303,13 @@ class ShapeInference {
|
|||||||
const Shape& updates_shape, const ProgramShape& to_apply_shape,
|
const Shape& updates_shape, const ProgramShape& to_apply_shape,
|
||||||
const ScatterDimensionNumbers& scatter_dim_numbers);
|
const ScatterDimensionNumbers& scatter_dim_numbers);
|
||||||
|
|
||||||
|
// Helper that validates the given input shape to GetDimensionSize.
|
||||||
static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
|
static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
|
||||||
int64 dimension);
|
int64 dimension);
|
||||||
|
|
||||||
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& shape,
|
// Helper that validates the given input shape to SetDimensionSize.
|
||||||
|
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
|
||||||
|
const Shape& val_shape,
|
||||||
int64 dimension);
|
int64 dimension);
|
||||||
|
|
||||||
// Helper function for creating a Window proto from user-supplied data.
|
// Helper function for creating a Window proto from user-supplied data.
|
||||||
|
@ -1365,6 +1365,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
|
|||||||
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
|
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
|
||||||
|
Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
|
||||||
|
Shape val_shape = ShapeUtil::MakeShape(S32, {1});
|
||||||
|
auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
|
||||||
|
arg_shape, val_shape, /*dimension=*/0);
|
||||||
|
|
||||||
|
EXPECT_FALSE(inferred_status.ok());
|
||||||
|
EXPECT_THAT(inferred_status.status().error_message(),
|
||||||
|
HasSubstr("value has to be S32 scalar"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
|
||||||
|
Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
|
||||||
|
Shape val_shape = ShapeUtil::MakeShape(U32, {});
|
||||||
|
auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
|
||||||
|
arg_shape, val_shape, /*dimension=*/0);
|
||||||
|
|
||||||
|
EXPECT_FALSE(inferred_status.ok());
|
||||||
|
EXPECT_THAT(inferred_status.status().error_message(),
|
||||||
|
HasSubstr("value has to be S32 scalar"));
|
||||||
|
}
|
||||||
|
|
||||||
// BatchMatMul with different batch dimension sizes fails.
|
// BatchMatMul with different batch dimension sizes fails.
|
||||||
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
|
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
|
||||||
Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
|
Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
|
||||||
|
Loading…
Reference in New Issue
Block a user