Add hybrid quantization support for BatchMatMul during conversion.
PiperOrigin-RevId: 344098609
Change-Id: Ib00f38d3bb7403fbd9d21169301c0a080e632d17
parent 256556c132
commit f70160322a
@@ -49,6 +49,7 @@
         *   Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
     *   Added support for saved model's session initializer through
         `TFLiteConverter.from_saved_model`.
+    *   Added dynamic range quantization support for the BatchMatMul op.
 
 *   TF Core:
     *   Corrected higher-order gradients of control flow constructs (`tf.cond`,

@@ -213,6 +213,10 @@ typedef struct {
 typedef struct {
   bool adj_x;
   bool adj_y;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true and the weights are quantized, then non-constant inputs
+  // are quantized at evaluation time with asymmetric quantization.
+  bool asymmetric_quantize_inputs;
 } TfLiteBatchMatMulParams;
 
 typedef struct {

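The comment above is the whole contract, and the flag's effect is easiest to see outside the kernel. Below is a minimal, self-contained sketch (QuantizeInput and its signature are illustrative, not a TFLite API) of the difference between the two evaluation-time schemes: asymmetric quantization derives both scale and zero point from the observed input range, while the legacy symmetric scheme pins the zero point at 0.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper: quantize a non-empty float input tensor to int8 at
// evaluation time, the way a hybrid kernel would before the integer matmul.
void QuantizeInput(const std::vector<float>& input, bool asymmetric,
                   std::vector<int8_t>& quantized, float& scale,
                   int32_t& zero_point) {
  const auto [min_it, max_it] =
      std::minmax_element(input.begin(), input.end());
  const float min_val = std::min(*min_it, 0.0f);  // range must contain 0
  const float max_val = std::max(*max_it, 0.0f);
  if (asymmetric) {
    // Spend the full [-128, 127] range on [min_val, max_val].
    scale = (max_val - min_val) / 255.0f;
    zero_point = static_cast<int32_t>(
        std::round(-128.0f - min_val / std::max(scale, 1e-9f)));
  } else {
    // Symmetric: zero point fixed at 0, range limited by the max magnitude.
    scale = std::max(std::abs(min_val), max_val) / 127.0f;
    zero_point = 0;
  }
  scale = std::max(scale, 1e-9f);  // guard against an all-zero input
  quantized.clear();
  for (float v : input) {
    const int32_t q =
        static_cast<int32_t>(std::round(v / scale)) + zero_point;
    quantized.push_back(static_cast<int8_t>(std::clamp(q, -128, 127)));
  }
}
```

The asymmetric scheme matters most for skewed activations: post-ReLU values are all non-negative, so a symmetric quantizer would waste half of the int8 range on values that never occur.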
@@ -990,6 +990,10 @@ table SegmentSumOptions {
 table BatchMatMulOptions {
   adj_x:bool;
   adj_y:bool;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true, then a weights-only op will use asymmetric quantization
+  // for its inputs.
+  asymmetric_quantize_inputs: bool;
 }
 
 table CumsumOptions {

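Because the new field is appended to the end of BatchMatMulOptions, models serialized before this change read back the schema default (false) and keep the symmetric behavior, so the change is backward compatible. A small sketch of reading the option through the generated FlatBuffers API (the helper is illustrative; the accessor itself is the same one GetOpSignature uses further down):

```cpp
#include "tensorflow/lite/schema/schema_generated.h"

// Illustrative helper, not part of TFLite: true only when the operator
// carries BatchMatMulOptions and the new flag is set. Older models fall
// back to the default, false.
bool UsesAsymmetricHybridInputs(const tflite::Operator* op) {
  const auto* options = op->builtin_options_as_BatchMatMulOptions();
  return options != nullptr && options->asymmetric_quantize_inputs();
}
```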
@@ -90,6 +90,7 @@ std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
   } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
              builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
     return {1};
   } else if (builtin_op_code == BuiltinOperator_SVDF) {
@@ -145,6 +146,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
       return custom_op_info->second.is_hybrid;
     }
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
              builtin_op_code == BuiltinOperator_RNN ||
@@ -255,6 +257,10 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
           op->builtin_options.AsFullyConnectedOptions()
               ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
           break;
+        case BuiltinOperator_BATCH_MATMUL:
+          op->builtin_options.AsBatchMatMulOptions()
+              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
+          break;
         case BuiltinOperator_LSTM:
           op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
               use_updated_hybrid_scheme;

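These three hooks make the weight-quantization pass treat BatchMatMul like FULLY_CONNECTED: input index 1 is the quantizable weight tensor, the op qualifies for hybrid evaluation, and the recorded option tells the kernel which input-quantization scheme to apply at runtime. For intuition, here is a self-contained sketch of the arithmetic a hybrid kernel then performs for one 2-D matmul slice, assuming per-tensor scales (the real kernel in tensorflow/lite/kernels/batch_matmul.cc is considerably more involved):

```cpp
#include <cstdint>
#include <vector>

// Sketch of hybrid evaluation: qx is the int8 input produced at eval time
// with (input_scale, input_zero_point), qw holds pre-quantized int8 weights
// with a single weight_scale, and the int32 accumulator is rescaled to
// float at the end.
std::vector<float> HybridMatMul(const std::vector<int8_t>& qx,  // [m, k]
                                const std::vector<int8_t>& qw,  // [k, n]
                                int m, int k, int n,
                                float input_scale, int32_t input_zero_point,
                                float weight_scale) {
  std::vector<float> out(m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      int32_t weight_col_sum = 0;  // needed for the zero-point correction
      for (int p = 0; p < k; ++p) {
        acc += static_cast<int32_t>(qx[i * k + p]) * qw[p * n + j];
        weight_col_sum += qw[p * n + j];
      }
      // x ~= input_scale * (qx - zero_point), so subtract
      // zero_point * sum(qw) once per dot product.
      acc -= input_zero_point * weight_col_sum;
      out[i * n + j] = static_cast<float>(acc) * input_scale * weight_scale;
    }
  }
  return out;
}
```

With the symmetric scheme the zero point is 0 and the correction term vanishes; the asymmetric scheme pays for its tighter input range with exactly this extra column-sum bookkeeping.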
@@ -557,8 +557,25 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
-    case BuiltinOperator_CONCATENATION:
+    case BuiltinOperator_BATCH_MATMUL:
+      // In case of int16 inputs, the version is 3.
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+          op_sig.input_types.at(1) == TensorType_INT8 &&
+          op_sig.output_types.at(0) == TensorType_FLOAT32) {
+        if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
+          // This is to use the updated quantization scheme.
+          return 4;
+        }
+      }
+      return 1;
+
+    case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
@@ -792,6 +809,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
     } break;
 
+    case BuiltinOperator_BATCH_MATMUL: {
+      auto batch_matmul_option = op->builtin_options_as_BatchMatMulOptions();
+      op_sig.options.input_quantization.asymmetric_quantize_inputs =
+          batch_matmul_option->asymmetric_quantize_inputs();
+    } break;
+
     default:
       break;
   }

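The version computed here is serialized into the model and matched against kernel registrations at load time, so a runtime must advertise version 4 before it can resolve models converted with the new hybrid scheme. A sketch of the corresponding registration, assuming the usual min/max-version overload of MutableOpResolver::AddBuiltin (whether register.cc is bumped in this same change is not visible in this diff):

```cpp
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: advertise BatchMatMul versions 1 through 4 so a model declaring
// the version-4 op (float input, int8 weights, asymmetric hybrid) resolves
// to a kernel at load time.
void RegisterBatchMatMul(tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_BATCH_MATMUL,
                       tflite::ops::builtin::Register_BATCH_MATMUL(),
                       /* min_version = */ 1, /* max_version = */ 4);
}
```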
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/schema/schema_generated.h"
 
 namespace tflite {
 
@@ -779,4 +780,49 @@ TEST(OpVersionTest, VersioningAbsTest) {
   };
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }
+TEST(OpVersionTest, VersioningBatchMatMulTest) {
+  // Default.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_FLOAT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // int8 input is version 2.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  // int16 input is version 3.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT16, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // Symmetric hybrid quantized input is version 1.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // Asymmetric hybrid quantized input is version 4.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  fake_op_sig.options.input_quantization.asymmetric_quantize_inputs = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
+}
 }  // namespace tflite