- Add layout optimization passes to TFLite passes. This will change data layout for some ops (e.g. conv2d) so it will change the layout to NHWC and can be converted to tflite.
- Update Transpose legalization from TF -> TFL to handle int64 permutation. - Update unit-test of legalization to also do CSE to simplify some tests and avoid duplicate tensors which makes problems when matching params (which one to match) PiperOrigin-RevId: 331672423 Change-Id: I39d943d197bb0288bf441f0c4dd80b9c9deb8257
This commit is contained in:
parent
8dee182572
commit
d8f9b4a0d2
@ -407,6 +407,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
|
||||
|
232
tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
Normal file
232
tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
Normal file
@ -0,0 +1,232 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/w"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: ";;\177<5\241i\275\312f\211>#\346j>\033W\325\275\253>\210=Vr\r\276\304\222\313\276\374\346\214>\016e\211>)\253\000>\3241\337\275\235g-\276*(\216\276\326#\367\274\023\213\300\276\227\031\206>PUF=\253\330\263<\337IL\276\334\320\215>\377\306v\276\372C\302\273baM>H\314\270<2\221\352=J\026{\276\221\243\245\276?\314\240=UW2\2755\207\253\274\256\207\333\273\335\372\227>\246\232;\276%\r\374<Z\346\204>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/w/read"
|
||||
op: "Identity"
|
||||
input: "conv_net_2d/conv_2d_0/w"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@conv_net_2d/conv_2d_0/w"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d_1/conv_2d_0/convolution"
|
||||
op: "Conv2D"
|
||||
input: "input"
|
||||
input: "conv_net_2d/conv_2d_0/w/read"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "data_format"
|
||||
value {
|
||||
s: "NCHW"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dilations"
|
||||
value {
|
||||
list {
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "explicit_paddings"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "padding"
|
||||
value {
|
||||
s: "SAME"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "strides"
|
||||
value {
|
||||
list {
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_cudnn_on_gpu"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/b"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: "\315\314\314=\315\314\314="
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/b/read"
|
||||
op: "Identity"
|
||||
input: "conv_net_2d/conv_2d_0/b"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@conv_net_2d/conv_2d_0/b"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d_1/conv_2d_0/BiasAdd"
|
||||
op: "BiasAdd"
|
||||
input: "conv_net_2d_1/conv_2d_0/convolution"
|
||||
input: "conv_net_2d/conv_2d_0/b/read"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "data_format"
|
||||
value {
|
||||
s: "NHWC"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d_1/Relu"
|
||||
op: "Relu"
|
||||
input: "conv_net_2d_1/conv_2d_0/BiasAdd"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "output_0"
|
||||
op: "Identity"
|
||||
input: "conv_net_2d_1/Relu"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
}
|
||||
|
||||
# CHECK: 'main' inputs:
|
||||
# CHECK-NEXT: name: 'input'
|
||||
# CHECK-NEXT: 'main' outputs:
|
||||
# CHECK-NEXT: name: 'output_0'
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s
|
||||
// RUN: tf-opt %s -tfl-legalize-tf --cse | FileCheck %s
|
||||
|
||||
func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
@ -196,7 +196,6 @@ func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
|
||||
|
||||
// CHECK-LABEL: shape
|
||||
// CHECK: "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
|
||||
// CHECK: %1 = "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
|
||||
}
|
||||
|
||||
func @fill(%arg0: tensor<3xi32>, %arg1: tensor<f32>) -> tensor<?x?x?xf32> {
|
||||
@ -719,9 +718,8 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
|
||||
}
|
||||
|
||||
@ -753,9 +751,8 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
|
||||
}
|
||||
|
||||
@ -1047,8 +1044,7 @@ func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) ->
|
||||
// CHECK-LABEL: matmul_transposed_a
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant unit
|
||||
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
@ -1359,10 +1355,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant unit
|
||||
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
|
||||
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant unit
|
||||
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>
|
||||
}
|
||||
@ -1568,3 +1561,27 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
|
||||
// CHECK-LABEL: add_with_int32_5d_inputs
|
||||
// CHECK: "tf.Add"(%arg0, %arg1)
|
||||
}
|
||||
|
||||
func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
%cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
%0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %0 : tensor<3x2xf32>
|
||||
// CHECK-LABEL: tranpose_int32_perm
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
%cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
|
||||
return %0 : tensor<3x2xf32>
|
||||
// CHECK-LABEL: tranpose_int64_perm
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> {
|
||||
%0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %0 : tensor<3x2xf32>
|
||||
// CHECK-LABEL: tranpose_arg
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
|
||||
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NHWC -tfl-prepare-tf | FileCheck --check-prefix=LAYOUT --dump-input=always %s
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
|
||||
@ -53,6 +54,15 @@ func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>, tensor<256x3
|
||||
// CHECK: %5 = "tf.DepthwiseConv2dNative"
|
||||
}
|
||||
|
||||
func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> {
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32>
|
||||
return %0 : tensor<256x16x30x30xf32>
|
||||
|
||||
// LAYOUT-LABEL: Conv2dNCHW
|
||||
// LAYOUT: "tfl.conv_2d"
|
||||
}
|
||||
|
||||
|
||||
func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
|
||||
// OK
|
||||
|
@ -38,6 +38,10 @@ CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// Data layout supported by TFLite.
|
||||
const char kTFLiteDataLayout[] = "NHWC";
|
||||
} // namespace
|
||||
|
||||
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
mlir::OpPassManager* pass_manager) {
|
||||
@ -170,6 +174,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
if (pass_config.shape_inference) {
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
}
|
||||
// Force layout supported by TFLite, this will transpose the data
|
||||
// to match 'kTFLiteDataLayout'
|
||||
mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
|
||||
layout_optimization_options.force_data_format = kTFLiteDataLayout;
|
||||
mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager,
|
||||
layout_optimization_options);
|
||||
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
|
||||
// the TFLite dialect.
|
||||
pass_manager->addPass(
|
||||
|
@ -27,6 +27,9 @@ def NonOpaqueElementsAttr : ElementsAttrBase<
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||
|
||||
def Int64ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isInteger(64)">, "Int 64 constant tensor">;
|
||||
|
||||
// Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr
|
||||
// with builder.
|
||||
class ExtractI32At<int i> : NativeCodeCall<
|
||||
@ -50,6 +53,10 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
|
||||
def ExtractSingleElementAsInt32 : NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
|
||||
|
||||
// Converts tensor with int64 to int32.
|
||||
def CreateCastToInt32 : NativeCodeCall<
|
||||
"CreateCastToInt32($0, $_loc, $_builder)">;
|
||||
|
||||
// Checks whether the given operation has static shapes and same shapes of all inputs.
|
||||
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
|
||||
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
|
||||
@ -208,8 +215,14 @@ def LegalizeSoftPlus : Pat<(TF_SoftplusOp F32Tensor:$arg0),
|
||||
def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims),
|
||||
(TFL_SqueezeOp $arg, $squeeze_dims)>;
|
||||
def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
||||
|
||||
def LegalizeTransposeInt64 : Pat<
|
||||
(TF_TransposeOp $arg, (ConstantOp Int64ElementsAttr:$perm)),
|
||||
(TFL_TransposeOp $arg, (CreateCastToInt32 $perm))>;
|
||||
|
||||
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
|
||||
(TFL_TransposeOp $arg, $perm)>;
|
||||
|
||||
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
|
||||
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
|
||||
|
||||
|
@ -49,6 +49,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -116,6 +117,17 @@ bool HasSameStaticShapes(Operation* op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Util that casts 'val' to Int32 by adding a cast Op.
|
||||
Value CreateCastToInt32(Attribute val, Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
|
||||
return rewriter.create<TF::CastOp>(loc, new_type,
|
||||
rewriter.create<TF::ConstOp>(loc, val),
|
||||
rewriter.getBoolAttr(false));
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
|
||||
|
||||
#define DECL_CONVERT_OP(tf_op) \
|
||||
|
Loading…
x
Reference in New Issue
Block a user