diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 3b72a60f3c6..448a4f9eb5f 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3 // ----- module { -func @inference_standard_lstm_7410(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} { +func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor @@ -165,7 +165,7 @@ func 
@inference_standard_lstm_7410(%arg0: tensor, %arg1: tensor, tensor, tensor, tensor, tensor } -// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} { +// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> @@ -181,7 +181,46 @@ func @inference_standard_lstm_7410(%arg0: tensor, %arg1: tensor, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit // CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, 
[[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( { -// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor -// CHECK: return [[VAL_21:%.*]] : tensor - +// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x?x10xf32> +// CHECK: return [[VAL_20]] : tensor<8x?x10xf32> +// CHECK: } +} + +// ----- + +module { +func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) ->
tensor<8x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> + %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32> + %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor) -> tensor<8x8x10xf32> + %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor) -> tensor<8x8x10xf32> + %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + return %5, %4, %5, %5, %6 : tensor, tensor<8x8x10xf32>, tensor, tensor, tensor +} + +// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> +// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64> +// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> +// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : 
tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) +// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_21:%.*]] = constant unit +// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( { +// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : 
tensor<3xi64> +// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_22]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32> +// CHECK: return [[VAL_24]] : tensor<8x8x10xf32> +// CHECK: } } diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index f7f77a53529..6d8bfab0e6c 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -70,14 +70,14 @@ Value CreateNoneValue(OpBuilder* builder, mlir::Location location) { builder->getUnitAttr()); } -Value Transpose2D(OpBuilder* builder, Value value_to_transpose, - RankedTensorType type, mlir::Location location) { +Value Transpose(OpBuilder* builder, Value value_to_transpose, + SmallVector perm, RankedTensorType original_type, + mlir::Location location) { // Create a constant op for transpose permutation. - SmallVector perm = {1, 0}; auto perm_op = CreateI64DenseConst(builder, perm, perm, location); // Create tensor type for the transpose result. - auto transpose_type = type; + auto transpose_type = original_type; auto transpose_shape = functional::map( [transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); }, perm); @@ -88,6 +88,13 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose, value_to_transpose, perm_op); } +Value Transpose2D(OpBuilder* builder, Value value_to_transpose, + RankedTensorType type, mlir::Location location) { + // Swap the two dimensions; Transpose() creates the permutation const op. + SmallVector perm = {1, 0}; + return Transpose(builder, value_to_transpose, perm, type, location); +} + ArrayRef GetRankedTensorShape(Value value) { return value.getType().cast().getShape(); } @@ -586,15 +593,30 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { Value recurrent_kernel = func_op.getArgument(4); Value bias = func_op.getArgument(5); - // Assume it's batch majored.
+ // TFL lstm only supports time-major inputs, so batch-major inputs and + // outputs are transposed around the op (driven by tf.time_major). + auto time_major_attr = func_op.getAttrOfType("tf.time_major"); + if (time_major_attr == nullptr) return failure(); + + bool time_majored = time_major_attr.getValue(); auto input_type = input.getType().dyn_cast_or_null(); if (!input_type) { func_op.emitError() << "Input type is not a ranked tensor type"; return failure(); } - int batch = input_type.getDimSize(0); - int time = input_type.getDimSize(1); + auto final_inputs = input; + auto final_input_type = input_type; + // Make input time-major. NOTE(review): confirm lstm consumes final_inputs. + if (!time_majored) { + SmallVector perm = {1, 0, 2}; + final_inputs = + Transpose(builder, final_inputs, perm, input_type, func_op.getLoc()); + final_input_type = final_inputs.getType().dyn_cast(); + } + + int batch = final_input_type.getDimSize(1); + int time = final_input_type.getDimSize(0); // Setup correct weights. RankedTensorType weight_type = @@ -672,7 +694,13 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0), builder->getStringAttr("FULL")); - builder->create(func_op.getLoc(), lstm.getResult()); + auto final_output = lstm.getResult(); + if (!time_majored) { + SmallVector perm = {1, 0, 2}; + final_output = + Transpose(builder, final_output, perm, result_type, func_op.getLoc()); + } + builder->create(func_op.getLoc(), final_output); return success(); }