diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 9fdff922529..4bb5e055eea 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -543,6 +543,7 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/tools/optimize:operator_property",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index f283f9719e4..e9b4500380b 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3691,8 +3691,12 @@ def LstmMandatoryInputsConstraint : PredOpTrait<
 def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
   "the optional peephole weights should all be specified or none",
-  And<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
-       TFL_TCopVTEtAreSameAt<9, 11, 16>]>>;
+  // Ignore input 9 (cell_to_input_weights) for LSTM with CIFG.
+  And<[
+    TFL_TCopVTEtAreSameAt<10, 11, 16>,
+    Or<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
+        And<[TypeIsPred<"input_to_input_weights", NoneType>,
+             TypeIsPred<"cell_to_input_weights", NoneType>]>]>]>>;
 
 def LstmProjectionWeightBiasConstraint : PredOpTrait<
   "either projection weight must be specified or both projection weight and "
@@ -3702,13 +3706,25 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait<
   "projection bias must not be specified",
   Or<[
     And<[TypeIsPred<"projection_weights", NoneType>,
         TypeIsPred<"projection_bias", NoneType>]>,
     Neg<TypeIsPred<"projection_weights", NoneType>>]>>;
 
-// TODO(b/137798843): Need to add two additional constraints for both LSTM and
+def LstmCifgInputConstraint : PredOpTrait<
+  "the cifg inputs should all be specified or none",
+  // If the LSTM has coupled input/forget gates (CIFG), inputs 1, 5, 9, 12,
+  // and 20 must all be none; otherwise inputs 1, 5, and 12 must not be none.
+  // Whether inputs 9 and 20 are present depends on the LSTM variant
+  // (peephole, layer norm).
+  Or<[
+    And<[TypeIsPred<"input_to_input_weights", NoneType>,
+         TypeIsPred<"recurrent_to_input_weights", NoneType>,
+         TypeIsPred<"cell_to_input_weights", NoneType>,
+         TypeIsPred<"input_gate_bias", NoneType>,
+         TypeIsPred<"input_layer_norm_coefficients", NoneType>]>,
+    Neg<Or<[TypeIsPred<"input_to_input_weights", NoneType>,
+            TypeIsPred<"recurrent_to_input_weights", NoneType>,
+            TypeIsPred<"input_gate_bias", NoneType>]>>]>>;
+
+// TODO(b/137798843): Need to add an additional constraint for both LSTM and
 // UnidirectionalSequenceLstm
-// For coupling of input and forget gates (cifg): if cifg is true,
-// tensor {1, 5, 9, 12, 20} are null; if cifg is
-// false, tensors {1, 5, 12} are not null; tensor {9} is not null if
-// additionally peephole = true; tensor {20} is not null if additionally layer
-// norm = true. For layer norm: if layer norm is false, tensor {20, 21, 22, 23}
+// For layer norm: if layer norm is false, tensors {20, 21, 22, 23}
 // are null; if layer norm is true, tensors {21, 22, 23} are not null; tensor
 // {20} is not null if additionally cifg = false.
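The TableGen predicate is dense, so as a reading aid, here is a minimal C++ sketch of the truth table that LstmCifgInputConstraint encodes. It is a hypothetical standalone helper written for this note (the function name and the main() driver are invented for illustration), not the verifier code that MLIR generates from the trait:

#include <iostream>

// Sketch of the rule encoded by LstmCifgInputConstraint. Each flag says
// whether the corresponding LSTM operand has NoneType (i.e. is absent).
// Operand numbers: 1 = input_to_input_weights, 5 = recurrent_to_input_weights,
// 9 = cell_to_input_weights, 12 = input_gate_bias,
// 20 = input_layer_norm_coefficients.
bool CifgInputsAreConsistent(bool none_1, bool none_5, bool none_9,
                             bool none_12, bool none_20) {
  // CIFG branch: all five input-gate operands must be absent.
  const bool all_absent = none_1 && none_5 && none_9 && none_12 && none_20;
  // Non-CIFG branch: operands 1, 5, and 12 must all be present; 9 and 20
  // stay optional because they depend on the peephole/layer-norm variants.
  const bool mandatory_present = !none_1 && !none_5 && !none_12;
  return all_absent || mandatory_present;
}

int main() {
  // A CIFG op (everything absent) and a full op (1, 5, 12 present) both pass.
  std::cout << CifgInputsAreConsistent(true, true, true, true, true)      // 1
            << CifgInputsAreConsistent(false, false, true, false, true)   // 1
            // Mixing the two, e.g. dropping only input 1, is rejected.
            << CifgInputsAreConsistent(true, false, false, false, false)  // 0
            << '\n';
}

Either branch passing satisfies the trait; mixed configurations, such as a missing input gate bias with the input-to-input weights still present, are exactly the malformed CIFG graphs the new constraint rejects.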
@@ -3759,6 +3775,7 @@ def TFL_LSTMOp :
     [LstmMandatoryInputsConstraint,
      LstmOptionalPeepholeWeightConstraint,
      LstmProjectionWeightBiasConstraint,
+     LstmCifgInputConstraint,
      LstmResultConstraint,
      TFL_OperandHasRank<2, 2>,  // input_to_forget_weights
      TFL_OperandHasRank<3, 2>,  // input_to_cell_weights
@@ -3883,6 +3900,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
     [LstmMandatoryInputsConstraint,
      LstmOptionalPeepholeWeightConstraint,
      LstmProjectionWeightBiasConstraint,
+     LstmCifgInputConstraint,
      LstmResultConstraint,
      TFL_OperandHasRankAtLeast<0, 2>,  // input
      TFL_OperandIsNoneOrHasRank<1, 2>,  // input_to_input_weights
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index b01edef4477..782d4ad70af 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -686,16 +686,37 @@ func @testUnidirectionalSequenceLstm(%arg0: tensor, %arg1: tensor
-func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor {
-  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
+func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor,
+  %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor,
+  %arg5: none, %arg6: tensor, %arg7: tensor, %arg8: tensor,
+  %arg9: none, %arg10: tensor, %arg11: tensor,
+  %arg12: none, %arg13: tensor, %arg14: tensor, %arg15: tensor,
+  %arg16: tensor, %arg17: tensor,
+  %arg18: tensor, %arg19: tensor,
+  %arg20: none, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor {
+  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0,
+  // CHECK-SAME: %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
+  // CHECK-SAME: %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} :
+  // CHECK-SAME: (tensor,
+  // CHECK-SAME: none, tensor, tensor, tensor, none, tensor, tensor, tensor,
+  // CHECK-SAME: none, tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor,
+  // CHECK-SAME: none, tensor, tensor, tensor) -> tensor
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
+    %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8,
+    %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
+    %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor,
+    none, tensor, tensor, tensor, none, tensor, tensor, tensor,
+    none, tensor, tensor, none, tensor, tensor, tensor,
+    tensor, tensor, tensor, tensor, none, tensor, tensor, tensor) -> tensor
   return %0 : tensor
 }
 
+// -----
+
 // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
-func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor {
-  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
+func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor {
+  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
   return %0 : tensor
 }
@@ -709,9 +730,8 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor,
 }
 
 // -----
+
 // CHECK-LABEL: testLstmIntermediates
-
-
 func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform>, %arg10: tensor<2048x!quant.uniform>, %arg11: tensor<2048x!quant.uniform>, %arg12: tensor<2048x!quant.uniform>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform>, %arg15: tensor<2048x!quant.uniform>, %arg16: tensor<2048x!quant.uniform>, %arg17: tensor<2048x!quant.uniform>, %arg18: tensor<2048x!quant.uniform>, %arg19: tensor<1x640x!quant.uniform>, %arg20: tensor<1x2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> {
   %cst = constant unit
   %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform>
@@ -744,7 +764,6 @@ func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform
-func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor {
+func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor,
+  %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor,
+  %arg5: none, %arg6: tensor, %arg7: tensor, %arg8: tensor,
+  %arg9: none, %arg10: tensor, %arg11: tensor,
+  %arg12: none, %arg13: tensor, %arg14: tensor, %arg15: tensor,
+  %arg16: tensor, %arg17: tensor,
+  %arg18: tensor, %arg19: tensor,
+  %arg20: none, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor {
   // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
-  // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor
+  // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} :
+  // CHECK-SAME: (tensor,
+  // CHECK-SAME: none, tensor, tensor, tensor, none, tensor, tensor, tensor,
+  // CHECK-SAME: none, tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor,
+  // CHECK-SAME: none, tensor, tensor, tensor) -> tensor
+  %0 = "tfl.lstm"(%arg0,
+    %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8,
+    %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
+    %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor,
+    none, tensor, tensor, tensor, none, tensor, tensor, tensor,
+    none, tensor, tensor, none, tensor, tensor, tensor,
+    tensor, tensor, tensor, tensor, none, tensor, tensor, tensor) -> tensor
   return %0 : tensor
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
index 6b5a59965e3..aac58c9c43e 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
@@ -115,6 +115,87 @@ func @QuantizeWithoutNorm(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {t
 // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>}
 }
 
+// CHECK-LABEL: QuantizeLstmCifg
+func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %none = constant unit
+  %input = "quant.stats"(%arg0) {layerStats = dense<[-1.2, 1.5]> : tensor<2xf32>} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+  %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
+  %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %23 = "tfl.lstm"(%input,
+    %none, %1, %2, %3,
+    %none, %5, %6, %7,
+    %none, %9, %10,
+    %none, %12, %13, %14,
+    %15, %16,
+    %recurrent_stats, %cell_stats,
+    %none, %20, %21, %22) ({}) {
+      cell_clip = 5.000000e+01 : f32,
+      effective_hidden_scale_intermediate = tensor>>,
+      fused_activation_function = "TANH",
+      input_to_cell_intermediate = tensor>>,
+      input_to_forget_intermediate = tensor>>,
+      input_to_output_intermediate = tensor>>,
+      proj_clip = 0.000000e+00 : f32, time_major = false} : (
+        tensor<1x5xf32>,
+        none, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
+        none, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
+        none, tensor<2xf32>, tensor<2xf32>,
+        none, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<4x2xf32>, tensor<4xf32>,
+        tensor<1x4xf32>, tensor<1x2xf32>,
+        none, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
+  %24 = "quant.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
+  return %24 : tensor<*xf32>
+
+// CHECK: %[[none:.*]] = constant unit
+// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform>) -> tensor<1x5xf32>
+// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>) -> tensor<2x5xf32>
+// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>) -> tensor<2x5xf32>
+// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>) -> tensor<2x5xf32>
+// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>) -> tensor<2x4xf32>
+// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>) -> tensor<2x4xf32>
+// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>) -> tensor<2x4xf32>
+// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_16:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>) -> tensor<4x2xf32>
+// CHECK-DAG: %[[input_17:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x!quant.uniform>) -> tensor<4xf32>
+// CHECK-DAG: %[[input_18:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32>
+// CHECK-DAG: %[[input_19:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
+// CHECK-DAG: %[[input_21:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_22:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>) -> tensor<2xf32>
+// CHECK-DAG: %[[input_23:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>) -> tensor<2xf32>
+
+// CHECK: %[[lstm:.*]] = "tfl.lstm"(%[[input_0]], %[[none]], %[[input_2]], %[[input_3]], %[[input_4]], %[[none]], %[[input_6]], %[[input_7]], %[[input_8]],
+// CHECK-SAME: %[[none]], %[[input_10]], %[[input_11]], %[[none]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
+// CHECK-SAME: %[[none]], %[[input_21]], %[[input_22]], %[[input_23]])
+// CHECK-NEXT: effective_hidden_scale_intermediate = tensor>
+// CHECK-SAME: input_to_cell_intermediate = tensor<0x!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
+// CHECK-SAME: input_to_forget_intermediate = tensor<0x!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
+// CHECK-SAME: input_to_output_intermediate = tensor<0x!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
+
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>}
+}
+
 // CHECK-LABEL: QuantizeUnidirectionalLstmFull
 func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
index ed92a3a178c..74c3e23a298 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include
 #include
+#include "absl/container/flat_hash_set.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/MathExtras.h"
@@ -104,6 +105,30 @@ LogicalResult GetLstmProperty(
       !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
   *op_property = operator_property::GetOperatorProperty(*lstm_variant);
+
+  // TODO(b/176258587) move this to operator_property.cc if this is needed in
+  // other components, too.
+  bool use_cifg =
+      op.input_to_input_weights().getType().template isa<NoneType>();
+  if (use_cifg) {
+    const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
+    const int cifg_non_intermediate = 0;
+    op_property->inputs.erase(
+        std::remove_if(
+            op_property->inputs.begin(), op_property->inputs.end(),
+            [&](std::pair<int, operator_property::TensorProperty> input) {
+              return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
+            }),
+        op_property->inputs.end());
+    op_property->intermediates.erase(
+        std::remove_if(op_property->intermediates.begin(),
+                       op_property->intermediates.end(),
+                       [&](std::pair<int, operator_property::TensorProperty>
+                               intermediate) {
+                         return intermediate.first == cifg_non_intermediate;
+                       }),
+        op_property->intermediates.end());
+  }
   return success();
 }
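The GetLstmProperty change above uses the standard erase-remove idiom to drop the quantization properties of operands that no longer exist once the input and forget gates are coupled. Here is a minimal self-contained sketch of that filtering step, assuming Abseil is available; TensorProperty here is a stand-in for operator_property::TensorProperty, and the sample data is fabricated for illustration:

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"

// Stand-in for operator_property::TensorProperty; the real struct carries
// per-tensor quantization parameters.
struct TensorProperty {};

int main() {
  // Indexed tensor properties, shaped like op_property->inputs.
  std::vector<std::pair<int, TensorProperty>> inputs = {
      {0, {}}, {1, {}}, {2, {}}, {5, {}}, {9, {}}, {12, {}}, {20, {}}};

  // With CIFG, inputs 1, 5, 9, 12, and 20 do not exist, so their
  // quantization properties must be dropped before downstream passes
  // try to annotate them.
  const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
  inputs.erase(
      std::remove_if(inputs.begin(), inputs.end(),
                     [&](const std::pair<int, TensorProperty>& input) {
                       return cifg_non_inputs.contains(input.first);
                     }),
      inputs.end());

  // Only the surviving operand indices remain: 0 2
  for (const auto& input : inputs) std::cout << input.first << ' ';
  std::cout << '\n';
}

Note that std::remove_if only partitions the vector and returns its new logical end; the surrounding erase call is what actually shrinks the container, which is why the two always appear paired in the patch.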