Add a shape inference pass before legalization: prepare-tf may introduce ops whose tensors do not yet have shapes, so a shape inference pass is inserted to propagate shapes properly.
PiperOrigin-RevId: 312431443 Change-Id: Id1f4bc5d4b7df1bf2c84b69acd64e0bae567dd76
parent fe5ac4182b
commit e9c9cfaf0a

@@ -0,0 +1,101 @@
# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 2
        }
        dim {
          size: 5
        }
        dim {
          size: 3
        }
      }
    }
  }
}
node {
  name: "Placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 3
        }
        dim {
          size: 7
        }
      }
    }
  }
}
node {
  name: "MatMul"
  op: "BatchMatMulV2"
  input: "Placeholder"
  input: "Placeholder_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "adj_x"
    value {
      b: false
    }
  }
  attr {
    key: "adj_y"
    value {
      b: false
    }
  }
}
versions {
  producer: 175
}

# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK: %[[VAL_5:.*]] = constant unit
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32>
# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32>
# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32>
# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32>
# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32>
# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32>
# CHECK: }
@@ -162,6 +162,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
   pass_manager->addPass(
       mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  if (pass_config.shape_inference) {
+    // Add a shape inference pass to optimize away the unnecessary casts.
+    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
+  }
   pass_manager->addPass(
       mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
   pass_manager->addPass(mlir::TFL::CreateOptimizePass());
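For context, below is a minimal sketch of the resulting pass ordering around this change. The helper function name and the include paths are assumptions for illustration only; the pass-creation calls are exactly the ones in the hunk above, not a definitive reproduction of tf_tfl_passes.cc.

// Sketch only: mirrors the prepare/legalize slice of the TF-to-TFLite
// pipeline after this change. Helper name and header paths are assumed.
#include "mlir/Pass/PassManager.h"                                   // assumed path
#include "mlir/Transforms/Passes.h"                                  // assumed path, for createCanonicalizerPass
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"    // assumed path
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"         // assumed path
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"   // assumed path

// Hypothetical helper showing only the portion of the pipeline touched here.
void AddPrepareAndLegalizePasses(const mlir::TFL::PassConfig& pass_config,
                                 mlir::OpPassManager* pass_manager) {
  // prepare-tf (which also unfolds BatchMatMul when requested) may create ops
  // whose result tensors do not carry shapes yet.
  pass_manager->addPass(
      mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  if (pass_config.shape_inference) {
    // Re-run shape inference so the newly introduced ops get proper shapes
    // (and unnecessary casts are optimized away) before TFLite legalization.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }
  pass_manager->addPass(
      mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
  pass_manager->addPass(mlir::TFL::CreateOptimizePass());
}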