Check the value shape of set dimension size in shape inference.

PiperOrigin-RevId: 317351738
Change-Id: Ia185e7745753711ca9ebf657522ef7422c9696ca
This commit is contained in:
Yunxing Dai 2020-06-19 11:59:51 -07:00 committed by TensorFlower Gardener
parent f53e1aac65
commit a8456eae42
5 changed files with 38 additions and 5 deletions

View File

@ -2762,9 +2762,11 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
*operand_shape, dimension));
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferSetDimensionSizeShape(
*operand_shape, *val_shape, dimension));
// Setting an op's dynamic dimension to the static size is a noop.
TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
LookUpInstruction(val));

View File

@ -1123,7 +1123,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
return CheckShape(set_size,
ShapeInference::InferSetDimensionSizeShape(
set_size->operand(0)->shape(), set_size->dimension()));
set_size->operand(0)->shape(),
set_size->operand(1)->shape(), set_size->dimension()));
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,

View File

@ -2248,12 +2248,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
const Shape& shape, int64 dimension) {
const Shape& shape, const Shape& val_shape, int64 dimension) {
if (dimension < 0 || dimension >= shape.rank()) {
return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
dimension);
}
if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
return InvalidArgument(
"SetDimensionSize's value has to be S32 scalar, got %s",
val_shape.ToString());
}
// TODO(b/119580730): Remove this restriction when very large dimension size
// is needed.
if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {

View File

@ -303,10 +303,13 @@ class ShapeInference {
const Shape& updates_shape, const ProgramShape& to_apply_shape,
const ScatterDimensionNumbers& scatter_dim_numbers);
// Helper that validates the given input shape to GetDimensionSize.
static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
int64 dimension);
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& shape,
// Helper that validates the given input shape to SetDimensionSize.
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
const Shape& val_shape,
int64 dimension);
// Helper function for creating a Window proto from user-supplied data.

View File

@ -1365,6 +1365,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
Shape val_shape = ShapeUtil::MakeShape(S32, {1});
auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
arg_shape, val_shape, /*dimension=*/0);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
HasSubstr("value has to be S32 scalar"));
}
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
Shape val_shape = ShapeUtil::MakeShape(U32, {});
auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
arg_shape, val_shape, /*dimension=*/0);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
HasSubstr("value has to be S32 scalar"));
}
// BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});