Add reshape op verification codes before applying op lowering rules.
PiperOrigin-RevId: 342772267 Change-Id: I94ae0941284452db249ba4e48e45b566a4154b4e
This commit is contained in:
parent
a01917bcfa
commit
e933e37d90
@ -428,6 +428,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -465,6 +466,7 @@ cc_library(
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -520,6 +520,17 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
|
||||
return %1 : tensor<1x4x64x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes
|
||||
func @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes(%arg0: tensor<?x?xf32>) -> tensor<1x?x?x1xf32> {
|
||||
%cst = constant dense<0> : tensor<4xi32>
|
||||
%cst_0 = constant dense<1> : tensor<4xi32>
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x?xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x?x1xf32>
|
||||
return %0 : tensor<1x?x?x1xf32>
|
||||
|
||||
// CHECK-NOT: "tf.Reshape"
|
||||
// CHECK: "tf.StridedSlice"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @StridedSliceRewriteMasks
|
||||
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
|
||||
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
|
@ -64,6 +64,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||
@ -540,6 +541,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
new_axis_mask >>= 1;
|
||||
}
|
||||
|
||||
if (failed(TF::VerifyShapeOfReshapeOp(new_shape))) return failure();
|
||||
|
||||
const int dim_size = new_shape.size();
|
||||
Location loc = strided_slice_op.getLoc();
|
||||
auto shape_type =
|
||||
@ -549,6 +552,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
result_shape_data[i] =
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(new_shape[i]));
|
||||
}
|
||||
|
||||
auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
|
||||
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
|
||||
auto new_output_type =
|
||||
|
Loading…
x
Reference in New Issue
Block a user