Handle LSTM with cifg in MLIR quantizer.

* Added CIFG constraints to MLIR op definition.
* modified ops.mlir.test to check cifg constraint.

Currently, this logic is not merged to operator_property.cc to keep it simple, and also calibration step relied on size of `OperatorProperty::intermediates` to allocate necessary intermediate tensors.

PiperOrigin-RevId: 350266407
Change-Id: I1f1ea9318743054bf2db9687915cc7a73c62cdb1
This commit is contained in:
Taehee Jeong 2021-01-05 19:19:42 -08:00 committed by TensorFlower Gardener
parent 60ba158560
commit d57f956517
5 changed files with 181 additions and 20 deletions

View File

@ -543,6 +543,7 @@ cc_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/optimize:operator_property", "//tensorflow/lite/tools/optimize:operator_property",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",

View File

@ -3691,8 +3691,12 @@ def LstmMandatoryInputsConstraint : PredOpTrait<
def LstmOptionalPeepholeWeightConstraint : PredOpTrait< def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
"the optional peephole weights should all be specified or none", "the optional peephole weights should all be specified or none",
And<[TFL_TCopVTEtAreSameAt<9, 10, 16>, // Ignore input 9 (cell_to_input_weights) for LSTM with CIFG.
TFL_TCopVTEtAreSameAt<9, 11, 16>]>>; And<[
TFL_TCopVTEtAreSameAt<10, 11, 16>,
Or<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
And<[TypeIsPred<"input_to_input_weights", NoneType>,
TypeIsPred<"cell_to_input_weights", NoneType>]>]>]>>;
def LstmProjectionWeightBiasConstraint : PredOpTrait< def LstmProjectionWeightBiasConstraint : PredOpTrait<
"either projection weight must be specified or both projection weight and " "either projection weight must be specified or both projection weight and "
@ -3702,13 +3706,25 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait<
TypeIsPred<"projection_bias", NoneType>]>, TypeIsPred<"projection_bias", NoneType>]>,
Neg<TypeIsPred<"projection_weights", NoneType>>]>>; Neg<TypeIsPred<"projection_weights", NoneType>>]>>;
// TODO(b/137798843): Need to add two additional constraints for both LSTM and def LstmCifgInputConstraint : PredOpTrait<
"the cifg inputs should all be specified or none",
// If LSTM has combined input/forget gate, input 1, 5, 9, 12, 20 are all none
// or 1, 5, 12 should not be none. Inputs 9 and 20 depend on LSTM's variants.
Or<[
And<[TypeIsPred<"input_to_input_weights", NoneType>,
TypeIsPred<"recurrent_to_input_weights", NoneType>,
TypeIsPred<"cell_to_input_weights", NoneType>,
TypeIsPred<"input_gate_bias", NoneType>,
TypeIsPred<"input_layer_norm_coefficients", NoneType>]>,
Neg<Or<[
TypeIsPred<"input_to_input_weights", NoneType>,
TypeIsPred<"recurrent_to_input_weights", NoneType>,
TypeIsPred<"input_gate_bias", NoneType>]>>]>>;
// TODO(b/137798843): Need to add an additional constraint for both LSTM and
// UnidirectionalSequenceLstm // UnidirectionalSequenceLstm
// For coupling of input and forget gates (cifg): if cifg is true, // For layer norm: if layer norm is false, tensor {20, 21, 22, 23}
// tensor {1, 5, 9, 12, 20} are null; if cifg is
// false, tensors {1, 5, 12} are not null; tensor {9} is not null if
// additionally peephole = true; tensor {20} is not null if additionally layer
// norm = true. For layer norm: if layer norm is false, tensor {20, 21, 22, 23}
// are null; if layer norm is true, tensors {21, 22, 23} are not null; tensor // are null; if layer norm is true, tensors {21, 22, 23} are not null; tensor
// {20} is not null if additionally cifg = false. // {20} is not null if additionally cifg = false.
@ -3759,6 +3775,7 @@ def TFL_LSTMOp :
[LstmMandatoryInputsConstraint, [LstmMandatoryInputsConstraint,
LstmOptionalPeepholeWeightConstraint, LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint, LstmProjectionWeightBiasConstraint,
LstmCifgInputConstraint,
LstmResultConstraint, LstmResultConstraint,
TFL_OperandHasRank<2, 2>, // input_to_forget_weights TFL_OperandHasRank<2, 2>, // input_to_forget_weights
TFL_OperandHasRank<3, 2>, // input_to_cell_weights TFL_OperandHasRank<3, 2>, // input_to_cell_weights
@ -3883,6 +3900,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
[LstmMandatoryInputsConstraint, [LstmMandatoryInputsConstraint,
LstmOptionalPeepholeWeightConstraint, LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint, LstmProjectionWeightBiasConstraint,
LstmCifgInputConstraint,
LstmResultConstraint, LstmResultConstraint,
TFL_OperandHasRankAtLeast<0, 2>, // input TFL_OperandHasRankAtLeast<0, 2>, // input
TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights

View File

@ -686,16 +686,37 @@ func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<?
// ----- // -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr // CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>,
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>,
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %arg5: none, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>,
%arg9: none, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>,
%arg12: none, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>,
%arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>,
%arg18: tensor<? x f32>, %arg19: tensor<? x f32>,
%arg20: none, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0,
// CHECK-SAME: %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
// CHECK-SAME: %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} :
// CHECK-SAME: (tensor<?x?xf32>,
// CHECK-SAME: none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
// CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
// CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
%arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8,
%arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
%arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>,
none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
@ -709,9 +730,8 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
} }
// ----- // -----
// CHECK-LABEL: testLstmIntermediates // CHECK-LABEL: testLstmIntermediates
func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> { func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit %cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
@ -744,7 +764,6 @@ func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.0372480
// CHECK: return %[[RES1]] // CHECK: return %[[RES1]]
} }
// ----- // -----
// CHECK-LABEL: testLstm // CHECK-LABEL: testLstm
@ -774,10 +793,27 @@ func @testQuantizedBasicLstm(%arg0: tensor<1x384x!quant.uniform<u8:f32, 7.812500
// ----- // -----
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr // CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>,
%arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>,
%arg5: none, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>,
%arg9: none, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>,
%arg12: none, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>,
%arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>,
%arg18: tensor<? x f32>, %arg19: tensor<? x f32>,
%arg20: none, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} :
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> // CHECK-SAME: (tensor<?xf32>,
// CHECK-SAME: none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
// CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
// CHECK-SAME: none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0,
%arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8,
%arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15,
%arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>,
none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
none, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }

View File

@ -115,6 +115,87 @@ func @QuantizeWithoutNorm(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {t
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>} // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
} }
// CHECK-LABEL: QuantizeLstmCifg
func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
%none = constant unit
%input = "quant.stats"(%arg0) {layerStats = dense<[-1.2, 1.5]> : tensor<2xf32>} : (tensor<1x5xf32>) -> tensor<1x5xf32>
%1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
%10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
%12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
%13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
%14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
%15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
%16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
%recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
%recurrent_stats = "quant.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
%cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
%21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
%22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
%23 = "tfl.lstm"(%input,
%none, %1, %2, %3,
%none, %5, %6, %7,
%none, %9, %10,
%none, %12, %13, %14,
%15, %16,
%recurrent_stats, %cell_stats,
%none, %20, %21, %22) ({}) {
cell_clip = 5.000000e+01 : f32,
effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
fused_activation_function = "TANH",
input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>,
input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>,
input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>,
proj_clip = 0.000000e+00 : f32,time_major = false} : (
tensor<1x5xf32>,
none, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
none, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
none, tensor<2xf32>, tensor<2xf32>,
none, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
tensor<4x2xf32>, tensor<4xf32>,
tensor<1x4xf32>, tensor<1x2xf32>,
none, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
%24 = "quant.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
return %24 : tensor<*xf32>
// CHECK: %[[none:.*]] = constant unit
// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform<i8:f32, 0.010588235481112611:-15>>) -> tensor<1x5xf32>
// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.018341723389512912>>) -> tensor<2x5xf32>
// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.011170119751156785>>) -> tensor<2x5xf32>
// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform<i8<-127:127>:f32, 0.017216451524749515>>) -> tensor<2x5xf32>
// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.019049501794529713>>) -> tensor<2x4xf32>
// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.010094007169167826>>) -> tensor<2x4xf32>
// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x4x!quant.uniform<i8<-127:127>:f32, 0.018637238525030179>>) -> tensor<2x4xf32>
// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.0780334190922573E-5>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 4.8612512878442185E-5>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6601474818224132E-8>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 5.0222583101003261E-8>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 2.6725777405118232E-8>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_16:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x2x!quant.uniform<i8<-127:127>:f32, 0.013545581674951268>>) -> tensor<4x2xf32>
// CHECK-DAG: %[[input_17:.*]] = "tfl.dequantize"({{.*}}) : (tensor<4x!quant.uniform<i32:f32, 5.3119928137063791E-5>>) -> tensor<4xf32>
// CHECK-DAG: %[[input_18:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i8:f32, 0.015686274509803921:-1>>) -> tensor<1x4xf32>
// CHECK-DAG: %[[input_19:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x2xf32>
// CHECK-DAG: %[[input_21:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.7239910213861512E-5>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_22:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 5.1427925095427339E-5>>) -> tensor<2xf32>
// CHECK-DAG: %[[input_23:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i16<-32767:32767>:f32, 2.736719606284107E-5>>) -> tensor<2xf32>
// CHECK: %[[lstm:.*]] = "tfl.lstm"(%[[input_0]], %[[none]], %[[input_2]], %[[input_3]], %[[input_4]], %[[none]], %[[input_6]], %[[input_7]], %[[input_8]],
// CHECK-SAME: %[[none]], %[[input_10]], %[[input_11]], %[[none]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
// CHECK-SAME: %[[none]], %[[input_21]], %[[input_22]], %[[input_23]])
// CHECK-NEXT: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>>
// CHECK-SAME: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
}
// CHECK-LABEL: QuantizeUnidirectionalLstmFull // CHECK-LABEL: QuantizeUnidirectionalLstmFull
func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} { func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h" #include "llvm/Support/MathExtras.h"
@ -104,6 +105,30 @@ LogicalResult GetLstmProperty(
!op.forget_layer_norm_coefficients().getType().template isa<NoneType>(); !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
*op_property = operator_property::GetOperatorProperty(*lstm_variant); *op_property = operator_property::GetOperatorProperty(*lstm_variant);
// TODO(b/176258587) move this to operator_property.cc if this is needed in
// other components, too.
bool use_cifg =
op.input_to_input_weights().getType().template isa<NoneType>();
if (use_cifg) {
const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
const int cifg_non_intermediate = 0;
op_property->inputs.erase(
std::remove_if(
op_property->inputs.begin(), op_property->inputs.end(),
[&](std::pair<int, operator_property::TensorProperty> input) {
return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
}),
op_property->inputs.end());
op_property->intermediates.erase(
std::remove_if(op_property->intermediates.begin(),
op_property->intermediates.end(),
[&](std::pair<int, operator_property::TensorProperty>
intermediate) {
return intermediate.first == cifg_non_intermediate;
}),
op_property->intermediates.end());
}
return success(); return success();
} }