Check the value shape of set dimension size in shape inference.
PiperOrigin-RevId: 317351738 Change-Id: Ia185e7745753711ca9ebf657522ef7422c9696ca
This commit is contained in:
		
							parent
							
								
									f53e1aac65
								
							
						
					
					
						commit
						a8456eae42
					
				@ -2762,9 +2762,11 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
 | 
			
		||||
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
 | 
			
		||||
    HloInstructionProto instr;
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));
 | 
			
		||||
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
 | 
			
		||||
                                         *operand_shape, dimension));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Shape shape,
 | 
			
		||||
                        ShapeInference::InferSetDimensionSizeShape(
 | 
			
		||||
                            *operand_shape, *val_shape, dimension));
 | 
			
		||||
    // Setting an op's dynamic dimension to the static size is a noop.
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
 | 
			
		||||
                        LookUpInstruction(val));
 | 
			
		||||
 | 
			
		||||
@ -1123,7 +1123,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
 | 
			
		||||
Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
 | 
			
		||||
  return CheckShape(set_size,
 | 
			
		||||
                    ShapeInference::InferSetDimensionSizeShape(
 | 
			
		||||
                        set_size->operand(0)->shape(), set_size->dimension()));
 | 
			
		||||
                        set_size->operand(0)->shape(),
 | 
			
		||||
                        set_size->operand(1)->shape(), set_size->dimension()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
 | 
			
		||||
 | 
			
		||||
@ -2248,12 +2248,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
 | 
			
		||||
    const Shape& shape, int64 dimension) {
 | 
			
		||||
    const Shape& shape, const Shape& val_shape, int64 dimension) {
 | 
			
		||||
  if (dimension < 0 || dimension >= shape.rank()) {
 | 
			
		||||
    return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
 | 
			
		||||
                           dimension);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
 | 
			
		||||
    return InvalidArgument(
 | 
			
		||||
        "SetDimensionSize's value has to be S32 scalar, got %s",
 | 
			
		||||
        val_shape.ToString());
 | 
			
		||||
  }
 | 
			
		||||
  // TODO(b/119580730): Remove this restriction when very large dimension size
 | 
			
		||||
  // is needed.
 | 
			
		||||
  if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
 | 
			
		||||
 | 
			
		||||
@ -303,10 +303,13 @@ class ShapeInference {
 | 
			
		||||
      const Shape& updates_shape, const ProgramShape& to_apply_shape,
 | 
			
		||||
      const ScatterDimensionNumbers& scatter_dim_numbers);
 | 
			
		||||
 | 
			
		||||
  // Helper that validates the given input shape to GetDimensionSize.
 | 
			
		||||
  static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
 | 
			
		||||
                                                    int64 dimension);
 | 
			
		||||
 | 
			
		||||
  static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& shape,
 | 
			
		||||
  // Helper that validates the given input shape to SetDimensionSize.
 | 
			
		||||
  static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
 | 
			
		||||
                                                    const Shape& val_shape,
 | 
			
		||||
                                                    int64 dimension);
 | 
			
		||||
 | 
			
		||||
  // Helper function for creating a Window proto from user-supplied data.
 | 
			
		||||
 | 
			
		||||
@ -1365,6 +1365,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
 | 
			
		||||
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
 | 
			
		||||
  Shape val_shape = ShapeUtil::MakeShape(S32, {1});
 | 
			
		||||
  auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
 | 
			
		||||
      arg_shape, val_shape, /*dimension=*/0);
 | 
			
		||||
 | 
			
		||||
  EXPECT_FALSE(inferred_status.ok());
 | 
			
		||||
  EXPECT_THAT(inferred_status.status().error_message(),
 | 
			
		||||
              HasSubstr("value has to be S32 scalar"));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
 | 
			
		||||
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
 | 
			
		||||
  Shape val_shape = ShapeUtil::MakeShape(U32, {});
 | 
			
		||||
  auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
 | 
			
		||||
      arg_shape, val_shape, /*dimension=*/0);
 | 
			
		||||
 | 
			
		||||
  EXPECT_FALSE(inferred_status.ok());
 | 
			
		||||
  EXPECT_THAT(inferred_status.status().error_message(),
 | 
			
		||||
              HasSubstr("value has to be S32 scalar"));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BatchMatMul with different batch dimension sizes fails.
 | 
			
		||||
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
 | 
			
		||||
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user