Merge pull request #42326 from WindQAQ:verify-and-fold-tile

PiperOrigin-RevId: 328953167
Change-Id: Ic8c0837f3b39051cb5791dc3f744b904cbc8b024
This commit is contained in:
TensorFlower Gardener 2020-08-28 09:43:30 -07:00
commit ccd2fff028
4 changed files with 142 additions and 3 deletions

View File

@ -11749,9 +11749,9 @@ array([[1, 2, 3, 1, 2, 3],
TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
// TODO(parkers): Add folds for multiples = [1,...].
// TODO(parkers): Add errors for negative multiples and multiples.size() !=
// input.rank()
let verifier = [{ return Verify(*this); }];
let hasFolder = 1;
}
def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {

View File

@ -1783,6 +1783,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// - input has at least rank 1
// - multiples is rank 1
// - multiples.size() == input.rank()
// - input.rank() == output.rank()
// - Elements in multiples are non-negative
// - input.shape[i] * multiples[i] == output.shape[i]
// for i in [0, input.rank() - 1]
static LogicalResult Verify(TileOp op) {
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
if (multiples_type && multiples_type.getRank() != 1) {
return op.emitOpError() << "expected multiples to be rank 1, got rank = "
<< multiples_type.getRank();
}
if (input_type && multiples_type && multiples_type.hasStaticShape() &&
(input_type.getRank() != multiples_type.getNumElements() ||
(input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
return op.emitOpError()
<< "expected size of multiples equal to rank of input"
<< ", got multiples of size " << multiples_type.getNumElements()
<< ", and input of rank " << input_type.getRank();
}
if (input_type && output_type) {
if (input_type.getRank() != output_type.getRank()) {
return op.emitOpError()
<< "expected rank of input to equal to rank of output"
<< ", got input of rank " << input_type.getRank()
<< ", and output of rank " << output_type.getRank();
}
DenseIntElementsAttr multiples_attr;
if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
const int64_t input_dim = input_type.getDimSize(i);
const int64_t output_dim = output_type.getDimSize(i);
const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();
if (m < 0) {
return op.emitOpError()
<< "expected multiples to be non-negative, got "
<< "multiples[" << i << "] = " << m;
}
if (!ShapedType::isDynamic(input_dim) &&
!ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
return op.emitOpError()
<< "requires input.shape[" << i << "] (" << input_dim << ")"
<< " * " << m << " to be equal to "
<< "output.shape[" << i << "] (" << output_dim << ")";
}
}
}
}
return success();
}
OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
DenseIntElementsAttr multiples_attr;
if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
// Return input directly when multiples are all ones,
// regardless what input is.
if (multiples_attr.isSplat() &&
multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
return input();
}
}
return {};
}
//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//

View File

@ -568,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2:
return %0: tensor<*xf16>
}
// CHECK-LABEL: testTileMultiplesAllOnes
func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense <[1, 1]> : tensor<2xi32>
// CHECK: return %arg0
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
return %0: tensor<2x3xf32>
}
// CHECK-LABEL: testLogicalNotOfEqual
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>

View File

@ -3468,3 +3468,53 @@ func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
// -----
func @testTile(%arg0: tensor<2x3x?xf32>) {
%cst = constant dense <[2, 3, 4]> : tensor<3xi32>
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32>
return
}
// -----
func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) {
// expected-error @+1 {{expected multiples to be rank 1, got rank = 2}}
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32>
return
}
// -----
func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) {
// expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}}
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32>
return
}
// -----
func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) {
// expected-error @+1 {{expected rank of input to equal to rank of output, got input of rank 2, and output of rank 3}}
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32>
return
}
// -----
func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) {
%cst = constant dense <[-1, 1]> : tensor<2xi32>
// expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}}
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
return
}
// -----
func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) {
%cst = constant dense <[2, 3]> : tensor<2xi32>
// expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}}
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32>
return
}