Merge pull request #42326 from WindQAQ:verify-and-fold-tile
PiperOrigin-RevId: 328953167 Change-Id: Ic8c0837f3b39051cb5791dc3f744b904cbc8b024
This commit is contained in:
commit
ccd2fff028
@ -11749,9 +11749,9 @@ array([[1, 2, 3, 1, 2, 3],
|
||||
TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
// TODO(parkers): Add folds for multiples = [1,...].
|
||||
// TODO(parkers): Add errors for negative multiples and multiples.size() !=
|
||||
// input.rank()
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
|
||||
|
||||
@ -1783,6 +1783,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TileOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies that,
|
||||
//
|
||||
// - input has at least rank 1
|
||||
// - multiples is rank 1
|
||||
// - multiples.size() == input.rank()
|
||||
// - input.rank() == output.rank()
|
||||
// - Elements in multiples are non-negative
|
||||
// - input.shape[i] * multiples[i] == output.shape[i]
|
||||
// for i in [0, input.rank() - 1]
|
||||
|
||||
static LogicalResult Verify(TileOp op) {
|
||||
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
|
||||
auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
|
||||
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (multiples_type && multiples_type.getRank() != 1) {
|
||||
return op.emitOpError() << "expected multiples to be rank 1, got rank = "
|
||||
<< multiples_type.getRank();
|
||||
}
|
||||
|
||||
if (input_type && multiples_type && multiples_type.hasStaticShape() &&
|
||||
(input_type.getRank() != multiples_type.getNumElements() ||
|
||||
(input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
|
||||
return op.emitOpError()
|
||||
<< "expected size of multiples equal to rank of input"
|
||||
<< ", got multiples of size " << multiples_type.getNumElements()
|
||||
<< ", and input of rank " << input_type.getRank();
|
||||
}
|
||||
|
||||
if (input_type && output_type) {
|
||||
if (input_type.getRank() != output_type.getRank()) {
|
||||
return op.emitOpError()
|
||||
<< "expected rank of input to equal to rank of output"
|
||||
<< ", got input of rank " << input_type.getRank()
|
||||
<< ", and output of rank " << output_type.getRank();
|
||||
}
|
||||
|
||||
DenseIntElementsAttr multiples_attr;
|
||||
if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
|
||||
for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
|
||||
const int64_t input_dim = input_type.getDimSize(i);
|
||||
const int64_t output_dim = output_type.getDimSize(i);
|
||||
const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();
|
||||
|
||||
if (m < 0) {
|
||||
return op.emitOpError()
|
||||
<< "expected multiples to be non-negative, got "
|
||||
<< "multiples[" << i << "] = " << m;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(input_dim) &&
|
||||
!ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
|
||||
return op.emitOpError()
|
||||
<< "requires input.shape[" << i << "] (" << input_dim << ")"
|
||||
<< " * " << m << " to be equal to "
|
||||
<< "output.shape[" << i << "] (" << output_dim << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
|
||||
DenseIntElementsAttr multiples_attr;
|
||||
if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
|
||||
// Return input directly when multiples are all ones,
|
||||
// regardless what input is.
|
||||
if (multiples_attr.isSplat() &&
|
||||
multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
|
||||
return input();
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TopKV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -568,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2:
|
||||
return %0: tensor<*xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testTileMultiplesAllOnes
|
||||
func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%cst = constant dense <[1, 1]> : tensor<2xi32>
|
||||
// CHECK: return %arg0
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
|
||||
return %0: tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLogicalNotOfEqual
|
||||
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
|
||||
@ -3468,3 +3468,53 @@ func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTile(%arg0: tensor<2x3x?xf32>) {
|
||||
%cst = constant dense <[2, 3, 4]> : tensor<3xi32>
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) {
|
||||
// expected-error @+1 {{expected multiples to be rank 1, got rank = 2}}
|
||||
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) {
|
||||
// expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}}
|
||||
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) {
|
||||
// expected-error @+1 {{expected rank of input to equal to rank of output, got input of rank 2, and output of rank 3}}
|
||||
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) {
|
||||
%cst = constant dense <[-1, 1]> : tensor<2xi32>
|
||||
// expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}}
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) {
|
||||
%cst = constant dense <[2, 3]> : tensor<2xi32>
|
||||
// expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}}
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32>
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user