diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 03ae23ea18b..56e9aba6112 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2762,9 +2762,11 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val)); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape( - *operand_shape, dimension)); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferSetDimensionSizeShape( + *operand_shape, *val_shape, dimension)); // Setting an op's dynamic dimension to the static size is a noop. TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto, LookUpInstruction(val)); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 4661b8fd9e3..d8baebd6fdd 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1123,7 +1123,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) { return CheckShape(set_size, ShapeInference::InferSetDimensionSizeShape( - set_size->operand(0)->shape(), set_size->dimension())); + set_size->operand(0)->shape(), + set_size->operand(1)->shape(), set_size->dimension())); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 75a80747c1d..bb4a38ded1e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2248,12 +2248,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferSetDimensionSizeShape( - const Shape& shape, int64 dimension) { + const Shape& shape, const Shape& val_shape, int64 dimension) { if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("SetDimensionSize dimension out of bounds: %d.", dimension); } + if (val_shape.rank() != 0 || val_shape.element_type() != S32) { + return InvalidArgument( + "SetDimensionSize's value has to be S32 scalar, got %s", + val_shape.ToString()); + } // TODO(b/119580730): Remove this restriction when very large dimension size // is needed. if (shape.dimensions(dimension) > std::numeric_limits::max()) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 2cb5930d098..d47d96ab52d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -303,10 +303,13 @@ class ShapeInference { const Shape& updates_shape, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); + // Helper that validates the given input shape to GetDimensionSize. static StatusOr InferGetDimensionSizeShape(const Shape& shape, int64 dimension); - static StatusOr InferSetDimensionSizeShape(const Shape& shape, + // Helper that validates the given input shape to SetDimensionSize. + static StatusOr InferSetDimensionSizeShape(const Shape& operand_shape, + const Shape& val_shape, int64 dimension); // Helper function for creating a Window proto from user-supplied data. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index b5ecf6e583e..916d3ab15c8 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1365,6 +1365,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } +TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) { + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape val_shape = ShapeUtil::MakeShape(S32, {1}); + auto inferred_status = ShapeInference::InferSetDimensionSizeShape( + arg_shape, val_shape, /*dimension=*/0); + + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("value has to be S32 scalar")); +} + +TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) { + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape val_shape = ShapeUtil::MakeShape(U32, {}); + auto inferred_status = ShapeInference::InferSetDimensionSizeShape( + arg_shape, val_shape, /*dimension=*/0); + + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("value has to be S32 scalar")); +} + // BatchMatMul with different batch dimension sizes fails. TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});