Add reshape op verification codes before applying op lowering rules.

PiperOrigin-RevId: 342772267
Change-Id: I94ae0941284452db249ba4e48e45b566a4154b4e
This commit is contained in:
Jaesung Chung 2020-11-16 19:25:49 -08:00 committed by TensorFlower Gardener
parent a01917bcfa
commit e933e37d90
3 changed files with 17 additions and 0 deletions

View File

@ -428,6 +428,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/xla:status",
@ -465,6 +466,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -520,6 +520,17 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
return %1 : tensor<1x4x64x64xf32>
}
// CHECK-LABEL: @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes
func @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes(%arg0: tensor<?x?xf32>) -> tensor<1x?x?x1xf32> {
%cst = constant dense<0> : tensor<4xi32>
%cst_0 = constant dense<1> : tensor<4xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x?xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x?x1xf32>
return %0 : tensor<1x?x?x1xf32>
// CHECK-NOT: "tf.Reshape"
// CHECK: "tf.StridedSlice"
}
// CHECK-LABEL: @StridedSliceRewriteMasks
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>

View File

@ -64,6 +64,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@ -540,6 +541,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
new_axis_mask >>= 1;
}
if (failed(TF::VerifyShapeOfReshapeOp(new_shape))) return failure();
const int dim_size = new_shape.size();
Location loc = strided_slice_op.getLoc();
auto shape_type =
@ -549,6 +552,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
result_shape_data[i] =
rewriter.getI32IntegerAttr(static_cast<int32_t>(new_shape[i]));
}
auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
auto new_output_type =