diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
new file mode 100644
index 00000000000..096033e37cb
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
@@ -0,0 +1,101 @@
+# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s
+
+node {
+  name: "Placeholder"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 2
+        }
+        dim {
+          size: 5
+        }
+        dim {
+          size: 3
+        }
+      }
+    }
+  }
+}
+node {
+  name: "Placeholder_1"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 3
+        }
+        dim {
+          size: 7
+        }
+      }
+    }
+  }
+}
+node {
+  name: "MatMul"
+  op: "BatchMatMulV2"
+  input: "Placeholder"
+  input: "Placeholder_1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "adj_x"
+    value {
+      b: false
+    }
+  }
+  attr {
+    key: "adj_y"
+    value {
+      b: false
+    }
+  }
+}
+versions {
+  producer: 175
+}
+
+# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
+# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
+# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
+# CHECK: %[[VAL_5:.*]] = constant unit
+# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
+# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
+# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
+# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
+# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
+# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
+# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
+# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
+# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32>
+# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32>
+# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32>
+# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32>
+# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
+# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
+# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32>
"tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> +# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32> +# CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index d3f1a430642..40420eee697 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -162,6 +162,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass( mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul)); pass_manager->addNestedPass(mlir::createCanonicalizerPass()); + if (pass_config.shape_inference) { + // Add a shape inference pass to optimize away the unnecessary casts. + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } pass_manager->addPass( mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); pass_manager->addPass(mlir::TFL::CreateOptimizePass());