STT-tensorflow/tensorflow/lite/python/lite_v2_test.py
Taehee Jeong 5ef02d3f57 Fix lite_tests for new quantizer
An MLIR quantizer test path is added for the following tests.

* testQuantizeFloat16
* testIntegerQuantizationWithUnsupportedOps

PiperOrigin-RevId: 347932208
Change-Id: Iefed1daab7e364621c5f9f6d65a234b0eb85bf44
2020-12-16 17:58:55 -08:00

1637 lines
64 KiB
Python

# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from six.moves import range
from six.moves import zip
import tensorflow as tf
from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import InterpreterWithCustomOps
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model.loader_impl import parse_saved_model
from tensorflow.python.saved_model.save import save
from tensorflow.python.training.tracking import tracking
class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
  """Tests for `lite.TFLiteConverterV2.from_concrete_functions`.

  Covers float conversion, argument validation, and post-training /
  training-time quantization (including the MLIR quantizer path selected via
  the private `_experimental_new_quantizer` flag).
  """

  @test_util.run_v2_only
  def testTypeInvalid(self):
    """Passing a `tf.function` instead of a ConcreteFunction is rejected."""
    root = self._getSimpleVariableModel()
    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_concrete_functions([root.f])
    self.assertIn('call get_concrete_function', str(error.exception))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_v2_only
  def testFloat(self, enable_mlir_converter):
    """Float conversion reproduces the TF result.

    Args:
      enable_mlir_converter: Value assigned to `experimental_new_converter`.
    """
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()

    # Check output value from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
                                  ('_UINT8InputOutput', dtypes.uint8),
                                  ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidFloat(self, inference_input_output_type):
    """Non-float I/O types are rejected when no quantization is requested.

    Args:
      inference_input_output_type: The (invalid) input/output dtype.
    """
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    with self.assertRaises(ValueError) as error:
      converter.inference_input_type = inference_input_output_type
      converter.inference_output_type = inference_input_output_type
      converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))

  @test_util.run_v2_only
  def testScalarInput(self):
    """A rank-0 (scalar) input converts and evaluates correctly."""
    root = self._getSimpleVariableModel()
    input_data = tf.constant(1., shape=[])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Convert a single model in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.add.get_concrete_function(input_data)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.add(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testConvertMultipleFunctions(self):
    """Convert multiple functions in a multi-functional model."""
    root = self._getMultiFunctionModel()
    input_data = tf.constant(1., shape=[1])
    add_func = root.add.get_concrete_function(input_data)
    sub_func = root.sub.get_concrete_function(input_data)

    # Try converting multiple functions.
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [add_func, sub_func])
    with self.assertRaises(ValueError) as error:
      _ = converter.convert()
    self.assertIn('can only convert a single ConcreteFunction',
                  str(error.exception))

  def _getIntegerQuantizeModel(self):
    """Builds a conv2d + relu model and its calibration data generator.

    Returns:
      A tuple `(concrete_function, calibration_gen)`. `calibration_gen` yields
      five single-element lists holding float32 samples of shape [1, 5, 5, 3]
      drawn uniformly from [-1, 1) (numpy RNG seeded with 0).
    """
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
    def func(inp):
      conv = tf.nn.conv2d(
          inp, tf.ones([3, 3, 3, 16]), strides=[1, 1, 1, 1], padding='SAME')
      output = tf.nn.relu(conv, name='output')
      return output

    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    root.f = func
    to_save = root.f.get_concrete_function()
    return (to_save, calibration_gen)

  @parameterized.named_parameters(
      ('EnableMlirQuantizer', True),  # enable mlir quantizer
      ('DisableMlirQuantizer', False))  # disable mlir quantizer
  def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer):
    """Calibrated quantization keeps float I/O and shrinks the model.

    Args:
      mlir_quantizer: Value for the private `_experimental_new_quantizer`.
    """
    func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_converter._experimental_new_quantizer = mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
                                  ('_UINT8InputOutput', dtypes.uint8),
                                  ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidPostTrainingDynamicRangeQuantization(
      self, inference_input_output_type):
    """Dynamic-range quantization (no representative data) needs float I/O.

    Args:
      inference_input_output_type: The (invalid) input/output dtype.
    """
    func, _ = self._getIntegerQuantizeModel()

    # Convert float model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = inference_input_output_type
      quantized_converter.inference_output_type = inference_input_output_type
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
  def testIntegerQuantization(self, is_int_only, is_int16_quantize,
                              inference_input_output_type):
    """Full-integer quantization honours the requested I/O dtype.

    Args:
      is_int_only: If True, only integer op sets are allowed (no
        TFLITE_BUILTINS float fallback).
      is_int16_quantize: If True, use the 16-bit activation / 8-bit weight
        op set.
      inference_input_output_type: Requested input and output dtype.
    """
    func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.\
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.\
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(tflite_model))

  @parameterized.named_parameters(
      ('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
  def testInvalidIntegerQuantization(self, is_int16_quantize,
                                     inference_input_output_type):
    """An I/O dtype incompatible with the quantization mode is rejected.

    Args:
      is_int16_quantize: If True, use the 16-bit activation / 8-bit weight
        op set.
      inference_input_output_type: The (invalid) input/output dtype.
    """
    func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int16_quantize:
      quantized_converter.target_spec.supported_ops = [
          lite.OpsSet.\
          EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
          lite.OpsSet.TFLITE_BUILTINS
      ]
    with self.assertRaises(ValueError) as error:
      # Use the parameterized dtype so each named case tests what it says.
      quantized_converter.inference_input_type = inference_input_output_type
      quantized_converter.inference_output_type = inference_input_output_type
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        "must be in ['tf.float32', 'tf.int16'].", str(error.exception))

  def testCalibrateAndQuantizeBuiltinInt16(self):
    """Calibrate-only conversion followed by `mlir_quantize` to int16."""
    func, calibration_gen = self._getIntegerQuantizeModel()

    # Convert float model.
    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    # TODO(b/156309549): We should add INT16 to the builtin types.
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.representative_dataset = calibration_gen
    # Only calibrate here; quantization happens in mlir_quantize below.
    converter._experimental_calibrate_only = True
    calibrated_tflite = converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)

    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def _getTrainingTimeQuantizedModel(self):
    """Returns a Keras model with a fake-quantized dense layer (2 -> 3)."""

    class QLinear(tf.keras.layers.Layer):
      """Matmul layer whose input, weights and output are fake-quantized."""

      def __init__(self, units=3, **kwargs):
        super(QLinear, self).__init__(**kwargs)
        self.units = units

      def build(self, input_shape):
        self.w = self.add_weight(
            'weight',
            shape=(input_shape[-1], self.units),
            initializer='random_normal',
            trainable=True)
        # Fixed [-6, 6] quantization range shared by all fake-quant ops.
        self.min_var = self.add_weight(
            'min',
            initializer=tf.keras.initializers.Constant(-6.0),
            trainable=False)
        self.max_var = self.add_weight(
            'max',
            initializer=tf.keras.initializers.Constant(6.0),
            trainable=False)

      def call(self, inputs):
        x = tf.quantization.fake_quant_with_min_max_vars(
            inputs, self.min_var, self.max_var)

        w_fq = tf.quantization.fake_quant_with_min_max_vars(
            self.w, self.min_var, self.max_var)
        x = tf.matmul(x, w_fq)

        x = tf.quantization.fake_quant_with_min_max_vars(
            x, self.min_var, self.max_var)

        return x

    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

  @parameterized.named_parameters(
      ('_DefaultFLOAT32InputOutput', dtypes.float32),
      ('_INT8InputOutput', dtypes.int8), ('_UINT8InputOutput', dtypes.uint8))
  @test_util.run_v2_only
  def testTrainingTimeQuantization(self, inference_input_output_type):
    """Fake-quant (QAT-style) models convert with the requested I/O dtype.

    Args:
      inference_input_output_type: Requested input and output dtype.
    """
    model = self._getTrainingTimeQuantizedModel()

    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testNewQuantizer(self):
    """Test the model quantized by the new quantizer.

    Outputs of the default and the new (MLIR) quantizer must agree within
    0.1 on random inputs.
    """
    func, calibration_gen = self._getIntegerQuantizeModel()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    quantized_converter.representative_dataset = calibration_gen

    # default quantizer
    quantized_converter._experimental_new_quantizer = False
    old_tflite = quantized_converter.convert()

    # new quantizer
    quantized_converter._experimental_new_quantizer = True
    new_tflite = quantized_converter.convert()

    for _ in range(5):
      input_data = tf.constant(
          np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
      old_value = self._evaluateTFLiteModel(old_tflite, [input_data])
      new_value = self._evaluateTFLiteModel(new_tflite, [input_data])
      self.assertAllClose(old_value, new_value, atol=1e-01)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_v2_only
  def testEmbeddings(self, enable_mlir_converter):
    """Test model with embeddings.

    Args:
      enable_mlir_converter: Value assigned to `experimental_new_converter`.
    """
    vocab_size = 2000
    embedding_dim = 300
    num_ids = 20
    # Draw ids across the whole vocabulary. Casting `random_sample` output
    # (floats in [0, 1)) to int32 would give all zeros and only ever gather
    # the first embedding row.
    input_data = tf.constant(
        np.random.randint(0, vocab_size, size=(num_ids,)).astype(np.int32))

    class EmbeddingModel(tf.keras.Model):
      """Gathers rows of a randomly initialized embedding table."""

      def __init__(self):
        super(EmbeddingModel, self).__init__()
        self.shared_weights = self.add_weight(
            'weights',
            shape=(vocab_size, embedding_dim),
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(
                mean=0.0, stddev=embedding_dim**(-0.5)))

      @tf.function(
          input_signature=[tf.TensorSpec(shape=(num_ids,), dtype=tf.int32)])
      def func(self, x):
        return tf.gather(self.shared_weights, x)

    # Building the model.
    root = EmbeddingModel()
    concrete_func = root.func.get_concrete_function()

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.func(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertAllClose(expected_value.numpy(), actual_value[0], atol=1e-05)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test a concrete function has debug info captured."""
    root = tracking.AutoTrackable()
    root.v1 = tf.Variable(3.)
    root.f = tf.function(lambda x: root.v1 * x)
    input_data = tf.constant(1., shape=[1])
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)

  def _getIntegerQuantizationModelWithFlexOp(self):
    """Builds a model mixing builtin ops with conv3d/erf, plus calibration.

    Returns:
      A tuple `(concrete_function, calibration_gen)`. The generator yields
      five float32 samples of shape [3, 3, 3, 3, 3] (numpy RNG seeded with 0).
    """
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
    ])
    def func(inp):
      tanh = tf.math.tanh(inp)
      # Flex delegate will merge the consecutive conv3d and erf ops into one
      # Delegate node.
      conv3d = tf.nn.conv3d(
          tanh,
          tf.ones([3, 3, 3, 3, 3]),
          strides=[1, 1, 1, 1, 1],
          padding='SAME')
      erf = tf.math.erf(conv3d)
      output = tf.math.tanh(erf)
      return output

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
        ]

    root.f = func
    return (root.f.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
  @test_util.run_v2_only
  def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize,
                                        inference_input_output_type):
    """Integer quantization with SELECT_TF_OPS keeps the requested I/O dtype.

    Args:
      is_int_only: If True, TFLITE_BUILTINS is not added to supported ops.
      is_int16_quantize: If True, use the 16-bit activation / 8-bit weight
        op set.
      inference_input_output_type: Requested input and output dtype.
    """
    func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.\
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.SELECT_TF_OPS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.\
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS,
            lite.OpsSet.SELECT_TF_OPS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

  def _getIntegerQuantizationModelWithUnsupportedOps(self):
    """Builds a two-input model that contains non-quantizable ceil ops.

    Returns:
      A tuple `(concrete_function, calibration_gen)`. The function returns
      `(ceil(ceil(a) + tanh(b)), tanh(b))`; the generator yields five pairs
      of float32 samples of shape [3] (numpy RNG seeded with 0).
    """
    np.random.seed(0)

    root = tracking.AutoTrackable()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3], dtype=tf.float32),
        tf.TensorSpec(shape=[3], dtype=tf.float32)
    ])
    def func(a, b):
      # The ceil kernel supports neither int8 nor int16 types.
      left = tf.math.ceil(a)
      right = tf.nn.tanh(b)
      add = tf.math.add(left, right)
      # The ceil kernel supports neither int8 nor int16 types.
      output = tf.math.ceil(add)
      return (output, right)

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
        ]

    root.f = func
    return (root.f.get_concrete_function(), calibration_gen)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  @test_util.run_v2_only
  def testIntegerQuantizationWithUnsupportedOps(self,
                                                is_int_only,
                                                is_int16_quantize,
                                                inference_input_output_type,
                                                enable_mlir_quantizer=False):
    """Checks I/O dtypes when some ops (ceil) cannot be quantized.

    Args:
      is_int_only: If True, use an integer op set plus TFLITE_BUILTINS.
      is_int16_quantize: If True, use the 16-bit activation / 8-bit weight
        op set.
      inference_input_output_type: Requested input and output dtype.
      enable_mlir_quantizer: Value for the private
        `_experimental_new_quantizer` flag.
    """
    func, calib_gen = self._getIntegerQuantizationModelWithUnsupportedOps()

    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calib_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.\
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.\
            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter._experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Input 0 / output 0 touch the ceil ops. With the default quantizer they
    # are expected to fall back to float32; with the MLIR quantizer they are
    # expected to use the requested dtype. Compare against numpy dtypes, as
    # the interpreter reports them.
    expected_ceil_dtype = (
        expected_dtype if enable_mlir_quantizer else np.float32)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(input_details[1]['dtype'], expected_dtype)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(output_details[1]['dtype'], expected_dtype)
class FromSavedModelTest(lite_v2_test_util.ModelTest):
def _createV1SavedModel(self, shape):
  """Writes a TF1-style SavedModel with `simple_save` and returns its path.

  The graph computes `z = x + y * variable_node`. The placeholder exported as
  `x` is named 'inputB' and the one exported as `y` is named 'inputA'.

  Args:
    shape: Shape shared by both float32 placeholders.

  Returns:
    The SavedModel directory inside the test's temp dir.
  """
  export_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
  with tf.Graph().as_default(), tf.compat.v1.Session() as session:
    placeholder_b = tf.compat.v1.placeholder(
        shape=shape, dtype=tf.float32, name='inputB')
    placeholder_a = tf.compat.v1.placeholder(
        shape=shape, dtype=tf.float32, name='inputA')
    scale = tf.Variable(1.0, name='variable_node')
    result = placeholder_b + placeholder_a * scale
    signature_inputs = {'x': placeholder_b, 'y': placeholder_a}
    signature_outputs = {'z': result}
    session.run(tf.compat.v1.variables_initializer([scale]))
    saved_model.simple_save(session, export_dir, signature_inputs,
                            signature_outputs)
  return export_dir
@test_util.run_v2_only
def testV1SimpleModel(self):
  """Converts a TF1 SavedModel and checks its input/output tensor details.

  Inputs are expected ordered by name ('inputA' before 'inputB'), float32,
  unquantized and of shape [1, 16, 16, 3]; the single output is 'add'.
  """
  with tf.Graph().as_default():
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    # assertTrue(expected, actual) would treat `actual` as the failure message
    # and always pass; compare the shapes element-wise instead.
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])
@test_util.run_v2_only
def testTF1HubFormattedModel(self):
  """Converts a V1 SavedModel whose schema version has been zeroed."""
  model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

  # TF1 Hub models are V1 SavedModels that omit the schema version setting;
  # mimic that by rewriting saved_model.pb with the version set to 0.
  proto = parse_saved_model(model_dir)
  proto.saved_model_schema_version = 0
  pb_path = os.path.join(model_dir, 'saved_model.pb')
  with file_io.FileIO(pb_path, 'wb') as pb_file:
    pb_file.write(proto.SerializeToString())

  # Conversion must still yield a non-empty model.
  hub_converter = lite.TFLiteConverterV2.from_saved_model(model_dir)
  converted = hub_converter.convert()
  self.assertTrue(converted)
def _createV1ModelWithHashTableInitializer(self):
  """Creates a v1 saved model with hash table initializers.

  The graph looks up a scalar string input in a static table
  {'a': 1, 'b': 2, 'c': 3, 'd': 4} (default -1), and the table is filled by
  the SavedModel's main_op (`tables_initializer`).

  Eager execution is disabled while the graph is built. It is always
  re-enabled, and the default graph reset, in a `finally` clause, so a
  failure here cannot leave later tests running in graph mode.

  Returns:
    Path of the written SavedModel directory.
  """
  tf.compat.v1.disable_eager_execution()
  try:
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'savedmodel_with_hashtable')

    table_initializer = tf.lookup.KeyValueTensorInitializer(
        keys=['a', 'b', 'c', 'd'],
        values=[1, 2, 3, 4],
        key_dtype=tf.string,
        value_dtype=tf.int64)
    table = tf.lookup.StaticHashTable(
        table_initializer, default_value=tf.constant(-1, dtype=tf.int64))

    x = tf.compat.v1.placeholder(tf.string, shape=(), name='input')
    y = table.lookup(x)

    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)

    signature_def_map = {
        'serving_default':
            (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'x': tensor_info_x},
                outputs={'y': tensor_info_y},
                method_name='some_function'))
    }
    init_op = tf.compat.v1.tables_initializer()
    assets_collection = None

    # Use the session as a context manager so it is closed after saving.
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.initializers.global_variables())

      builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(
          saved_model_dir)
      builder.add_meta_graph_and_variables(
          sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
          signature_def_map,
          main_op=init_op,
          assets_collection=assets_collection,
          strip_default_attrs=True)
      builder.save()
  finally:
    # Restore TF v2 behavior. This runs after the session context has exited,
    # so the default graph is no longer nested when it is reset.
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
  return saved_model_dir
@test_util.run_v2_only
def testModelWithHashTableInitializer(self):
  """Converts a V1 model whose hash table is filled by a session initializer.

  The converted model is invoked several times, and every invocation must
  return the same lookups, so the initializer graph is expected to run once.
  """
  model_dir = self._createV1ModelWithHashTableInitializer()

  # Hashtable ops are supplied as custom ops to the interpreter below.
  hashtable_converter = lite.TFLiteConverterV2.from_saved_model(model_dir)
  hashtable_converter.allow_custom_ops = True
  converted = hashtable_converter.convert()

  interpreter = InterpreterWithCustomOps(
      model_content=converted,
      custom_op_registerers=[hashtable_ops_registerer.HashtableOpsRegisterer])
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  keys = np.array(['a', 'b', 'c', 'z'], dtype=np.string_)
  interpreter.resize_tensor_input(input_index, [4], strict=False)
  interpreter.allocate_tensors()
  interpreter.set_tensor(input_index, keys)

  # 'z' is absent from the table, so it maps to the default value -1.
  for _ in range(3):
    interpreter.invoke()
    self.assertEqual([1, 2, 3, -1],
                     list(interpreter.get_tensor(output_index)))
@test_util.run_v2_only
def testConstModel(self):
  """Round-trips a stateless `2 * x` function through a SavedModel.

  The converted model must reproduce the tf.function's result, i.e. the
  function body must have been inlined into the TFLite graph.
  """
  sample = tf.constant(1., shape=[1])
  module = tracking.AutoTrackable()
  module.f = tf.function(lambda x: 2. * x)
  signature_fn = module.f.get_concrete_function(sample)

  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(module, export_dir, signature_fn)

  # Conversion from the saved directory.
  tflite_model = lite.TFLiteConverterV2.from_saved_model(export_dir).convert()

  # The TFLite output must equal a direct call of the tf.function.
  reference = module.f(sample)
  self.assertEqual(reference.numpy(),
                   self._evaluateTFLiteModel(tflite_model, [sample]))
@test_util.run_v2_only
def testVariableModel(self):
  """Round-trips a model holding tf.Variables through a SavedModel."""
  module = self._getSimpleVariableModel()
  sample = tf.constant(1., shape=[1])
  signature_fn = module.f.get_concrete_function(sample)

  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(module, export_dir, signature_fn)

  # Conversion from the saved directory.
  tflite_model = lite.TFLiteConverterV2.from_saved_model(export_dir).convert()

  # The TFLite output must equal a direct call of the tf.function.
  reference = module.f(sample)
  self.assertEqual(reference.numpy(),
                   self._evaluateTFLiteModel(tflite_model, [sample]))
@test_util.run_v2_only
def testSignatures(self):
  """Checks how `from_saved_model` treats its `signature_keys` argument.

  An unknown key raises ValueError; an empty key list is accepted and the
  resulting model matches the TF output.
  """
  module = self._getSimpleVariableModel()
  sample = tf.constant(1., shape=[1])
  signature_fn = module.f.get_concrete_function(sample)
  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(module, export_dir, signature_fn)

  # Unknown signature key.
  with self.assertRaises(ValueError) as ctx:
    _ = lite.TFLiteConverterV2.from_saved_model(
        export_dir, signature_keys=['INVALID'])
  self.assertIn("Invalid signature key 'INVALID'", str(ctx.exception))

  # Empty signature key list.
  empty_keys_converter = lite.TFLiteConverterV2.from_saved_model(
      export_dir, signature_keys=[])
  tflite_model = empty_keys_converter.convert()

  reference = module.f(sample)
  self.assertEqual(reference.numpy(),
                   self._evaluateTFLiteModel(tflite_model, [sample]))
@test_util.run_v2_only
def testSignatureDefs(self):
  """Converts a SavedModel signature and inspects it via the SignatureDef API.

  The `mul_add` signature is run by name, and both its output and the
  structure returned by `Interpreter.get_signature_list()` are checked.
  """
  root = self._getMultiFunctionModel()
  input_data_0 = tf.constant(1., shape=[1])
  input_data_1 = tf.constant(3., shape=[1])
  concrete_mul_add = root.mul_add.get_concrete_function(
      input_data_1, input_data_0)
  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, export_dir, {'mul_add': concrete_mul_add})

  converter = lite.TFLiteConverterV2.from_saved_model(
      export_dir, signature_keys=['mul_add'])
  tflite_model = converter.convert()

  reference = root.mul_add(input_data_1, input_data_0)
  signature_list = Interpreter(
      model_content=tflite_model).get_signature_list()
  outputs = self._evaluateTFLiteModelUsingSignatureDef(
      tflite_model, 'mul_add', {'y': input_data_0, 'x': input_data_1})

  # Output produced by running the signature.
  self.assertEqual(list(outputs.keys()), ['output_0'])
  self.assertEqual(reference.numpy(), outputs['output_0'])

  # Structure of the reported SignatureDefs.
  self.assertEqual(len(signature_list), 1)
  self.assertEqual(list(signature_list.keys()), ['mul_add'])
  self.assertEqual(len(signature_list.values()), 1)
  mul_add_def = signature_list['mul_add']
  self.assertEqual(list(mul_add_def.keys()), ['inputs', 'outputs'])
  self.assertCountEqual(mul_add_def['inputs'], ['x', 'y'])
  self.assertEqual(list(mul_add_def['outputs']), ['output_0'])
@test_util.run_v2_only
def testSignatureDefsWithDefaultValue(self):
  """Test SignatureDef conversion when the signature key is left as default.

  Converts the `mul_add` signature like a normal SignatureDef test, but passes
  `None` as the signature key to `_evaluateTFLiteModelUsingSignatureDef` to
  exercise its default behavior. The converted model has exactly one
  signature, `mul_add`.
  """
  root = self._getMultiFunctionModel()
  input_data_0 = tf.constant(1., shape=[1])
  input_data_1 = tf.constant(3., shape=[1])
  mul_add_func = root.mul_add.get_concrete_function(input_data_1,
                                                    input_data_0)
  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, save_dir, {'mul_add': mul_add_func})

  converter = lite.TFLiteConverterV2.from_saved_model(
      save_dir, signature_keys=['mul_add'])
  tflite_model = converter.convert()

  # Check values from converted model.
  expected_value = root.mul_add(input_data_1, input_data_0)
  interpreter = Interpreter(model_content=tflite_model)
  signature_defs = interpreter.get_signature_list()
  # `None` instead of 'mul_add': the helper must fall back to its default.
  results = self._evaluateTFLiteModelUsingSignatureDef(
      tflite_model, None, {
          'y': input_data_0,
          'x': input_data_1
      })
  self.assertEqual(list(results.keys()), ['output_0'])
  self.assertEqual(expected_value.numpy(), results['output_0'])

  # Verify the SignatureDef structure returned is as expected.
  self.assertEqual(len(signature_defs), 1)
  self.assertEqual(list(signature_defs.keys()), ['mul_add'])
  self.assertEqual(len(signature_defs.values()), 1)
  self.assertEqual(
      list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
  self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
  self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
@test_util.run_v2_only
def testMultipleFunctionModel(self):
  """Rejects a SavedModel exporting several signatures without a key choice.

  Both `add` and `sub` are saved as signatures. `from_saved_model` is called
  without `signature_keys` and must raise `ValueError`.
  """
  root = self._getMultiFunctionModel()
  input_data = tf.constant(1., shape=[1])
  signatures = {
      'add': root.add.get_concrete_function(input_data),
      'sub': root.sub.get_concrete_function(input_data),
  }
  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, export_dir, signatures)

  with self.assertRaises(ValueError) as ctx:
    _ = lite.TFLiteConverterV2.from_saved_model(export_dir)
  self.assertIn('Only support a single signature key.', str(ctx.exception))
@test_util.run_v2_only
def testNoConcreteFunctionModel(self):
  """Expects `ValueError` for a multi-function model saved without signatures.

  The model is exported with no explicit signatures argument, and converting
  it without `signature_keys` must fail with the single-signature error.
  """
  root = self._getMultiFunctionModel()
  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, export_dir)

  with self.assertRaises(ValueError) as ctx:
    _ = lite.TFLiteConverterV2.from_saved_model(export_dir)
  self.assertIn('Only support a single signature key.', str(ctx.exception))
@test_util.run_v2_only
def testKerasSequentialModel(self):
  """Test a simple sequential tf.Keras model exported as a SavedModel.

  Keras and TFLite compute the prediction independently in floating point, so
  the results are compared with a tolerance rather than exact equality.
  """
  input_data = tf.constant(1., shape=[1, 1])
  x = np.array([[1.], [2.]])
  y = np.array([[2.], [4.]])
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='sgd', loss='mean_squared_error')
  model.fit(x, y, epochs=1)

  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(model, save_dir)

  # Convert model.
  converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
  tflite_model = converter.convert()

  # Check values from converted model. `_evaluateTFLiteModel` returns one
  # array per model output; this model has a single output.
  expected_value = model.predict(input_data)
  actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
  self.assertAllClose(expected_value, actual_value, atol=1e-05)
@test_util.run_v2_only
def testGraphDebugInfo(self):
  """Checks that debug info saved with a SavedModel reaches the converter."""
  input_data = tf.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.f = tf.function(lambda x: 2. * x)
  signature = root.f.get_concrete_function(input_data)
  # Ask the SavedModel writer to persist graph debug info.
  debug_options = save_options.SaveOptions(save_debug_info=True)
  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, export_dir, signature, debug_options)

  converter = lite.TFLiteConverterV2.from_saved_model(export_dir)
  converter.convert()
  self._assertValidDebugInfo(converter._debug_info)  # pylint: disable=protected-access
@test_util.run_v2_only
def testFallbackPath(self):
  """Converts a V1 SavedModel using the old converter.

  `experimental_new_converter` is turned off, and the test only checks that a
  non-empty model is produced.
  """
  v1_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])
  converter = lite.TFLiteConverterV2.from_saved_model(v1_model_dir)
  # Force the old converter path.
  converter.experimental_new_converter = False
  self.assertTrue(converter.convert())
@test_util.run_v2_only
def testNonStatefulConvLSTM2D(self):
  """Test saved model with non stateful ConvLSTM2D keras layer."""
  # Create a Keras model with a single non-stateful ConvLSTM2D layer.
  model = tf.keras.Sequential([
      tf.keras.layers.ConvLSTM2D(
          32, (3, 3),
          padding='same',
          return_sequences=True,
          stateful=False,
          batch_input_shape=(1, 1, 10, 10, 1))
  ])
  model.compile()

  # Export the keras model to saved model.
  saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_lstm_2d')
  model.save(saved_model_dir, save_format='tf', include_optimizer=False)

  # Use the same converter entry point as the rest of this file
  # (`tf.lite.TFLiteConverter` is an alias of it under TF2).
  converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
  # Allow select TF ops alongside builtins; presumably some ops in this model
  # have no builtin TFLite kernel.
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
  ]
  tflite_model = converter.convert()
  self.assertTrue(tflite_model)
def _createUnknownInputShapeModel(self):
  """Builds a TF1-style SavedModel whose single input has no static shape.

  The graph computes `output = input + input` on a float32 placeholder whose
  shape is `tf.TensorShape(None)`.

  Returns:
    Path of the SavedModel directory.
  """
  export_dir = os.path.join(self.get_temp_dir(), 'unknown_input_shape')
  with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    placeholder = tf.compat.v1.placeholder(
        shape=tf.TensorShape(None), dtype=tf.float32, name='input')
    doubled = placeholder + placeholder
    saved_model.simple_save(
        sess, export_dir, {'input': placeholder}, {'output': doubled})
  return export_dir
@test_util.run_v2_only
def testUnknownInputShapeModel(self):
  """Test a SavedModel with an unknown input shape.

  The converted model's input is resized to [3] with a non-strict resize
  before allocation, and the doubled values are checked.
  """
  saved_model_dir = self._createUnknownInputShapeModel()
  # Same converter entry point as the rest of this file
  # (`tf.lite.TFLiteConverter` is an alias of it under TF2).
  converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
  tflite_model = converter.convert()
  self.assertTrue(tflite_model)

  # Check values from converted model.
  interpreter = Interpreter(model_content=tflite_model)
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  input_data = np.array([1., 2., 3.], dtype=np.float32)
  # Non-strict resize: presumably required because the input has no known
  # shape signature to validate the new shape against.
  interpreter.resize_tensor_input(
      input_details[0]['index'], [3], strict=False)
  interpreter.allocate_tensors()
  interpreter.set_tensor(input_details[0]['index'], input_data)
  interpreter.invoke()
  actual_value = interpreter.get_tensor(output_details[0]['index'])
  self.assertEqual([2., 4., 6.], list(actual_value))
class FromKerasModelTest(lite_v2_test_util.ModelTest):
  """Tests for `TFLiteConverterV2.from_keras_model`."""

  @test_util.run_v2_only
  def testSequentialModel(self):
    """Test a simple sequential tf.Keras model."""
    input_data = tf.constant(1., shape=[1, 1])

    # Create a simple Keras model.
    x = np.array([[1.], [2.]])
    y = np.array([[2.], [4.]])
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(units=1, input_shape=[1])
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()

    # Check values from converted model. Keras and TFLite compute in floating
    # point independently, so compare with a tolerance rather than exactly.
    # `_evaluateTFLiteModel` returns one array per model output.
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(expected_value, actual_value, atol=1e-05)

  @test_util.run_v2_only
  def testSequentialMultiInputOutputModel(self):
    """Test a tf.Keras model with multiple inputs and outputs."""
    left_input_data = tf.constant(1., shape=[1, 3])
    right_input_data = tf.constant(1., shape=[1, 3])

    # Create a simple Keras model.
    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    output_c_np = np.random.random((10, 3))
    output_d_np = np.random.random((10, 2))

    input_a = tf.keras.layers.Input(shape=(3,), name='input_a')
    input_b = tf.keras.layers.Input(shape=(3,), name='input_b')

    # One Dense layer shared by both inputs.
    dense = tf.keras.layers.Dense(8, name='dense_1')
    interm_a = dense(input_a)
    interm_b = dense(input_b)
    merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')

    output_c = tf.keras.layers.Dense(
        3, activation='softmax', name='dense_2')(
            merged)
    output_d = tf.keras.layers.Dense(
        2, activation='softmax', name='dense_3')(
            merged)

    model = tf.keras.models.Model(
        inputs=[input_a, input_b], outputs=[output_c, output_d])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit([input_a_np, input_b_np], [output_c_np, output_d_np], epochs=1)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()

    # Check values from converted model, one output at a time.
    input_data = [left_input_data, right_input_data]
    expected_value = model.predict(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, input_data)
    for tf_result, tflite_result in zip(expected_value, actual_value):
      self.assertAllClose(tf_result, tflite_result, atol=1e-05)

  @test_util.run_v2_only
  def testGraphDebugInfo(self):
    """Test a tf.Keras model has debug info captured."""
    # Create a simple Keras model.
    x = [-1, 0, 1, 2, 3, 4]
    y = [-3, -1, 1, 3, 5, 7]
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    converter = lite.TFLiteConverterV2.from_keras_model(model)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)  # pylint: disable=protected-access

  @test_util.run_v2_only
  def testKerasFallbackPath(self):
    """Test keras model which failed when exporting to the saved model."""
    # `(20,)` makes the one-dimensional sample shape explicit; `(20)` is just
    # the int 20, which produces the same shape.
    input_data = tf.constant(
        np.array(np.random.random_sample((20,)), dtype=np.float32))

    class Model(tf.keras.Model):

      def __init__(self):
        super(Model, self).__init__()
        # A None name will cause a failure in exporting to a saved model.
        self.shared_weights = self.add_weight(
            name=None,
            shape=(20, 1),
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(
                mean=0.0, stddev=300**(-0.5)))

      def call(self, x):
        return tf.add(self.shared_weights, x)

    # Building the model.
    model = Model()
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(input_data, input_data, epochs=1)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
class ControlFlowTest(lite_v2_test_util.ModelTest):
  """Conversion tests for models that contain control-flow constructs."""

  @test_util.run_v2_only
  def testCond(self):
    """Converts a `tf.cond` whose branches close over a variable."""
    x_value = tf.constant([1., 2.], shape=[1, 2])
    b_value = tf.constant(True)
    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.bool)
    ])
    def model(x, b):
      return tf.cond(
          b,
          true_fn=lambda: tf.matmul(x, weights),
          false_fn=lambda: tf.add(x, weights))

    concrete_func = model.get_concrete_function()
    tflite_model = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func]).convert()

    reference = concrete_func(x=x_value, b=b_value)
    converted = self._evaluateTFLiteModel(tflite_model, [x_value, b_value])[0]
    self.assertAllClose(reference, converted)

  @test_util.run_v2_only
  def testConverterErrorOnControlFlowV1Ops(self):
    """Expects `ConverterError` for the control_flow_v1 test SavedModel."""
    model_path = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1_saved_model')
    converter = lite.TFLiteConverterV2.from_saved_model(model_path)
    with self.assertRaises(convert.ConverterError) as ctx:
      converter.convert()
    self.assertIn(
        'Failed to functionalize Control Flow V1 ops. Consider using Control '
        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
        'tf/compat/v1/enable_control_flow_v2.', str(ctx.exception))

  @test_util.run_v2_only
  def testStaticRnn(self):
    """Converts `static_rnn` over an LSTM cell and compares the outputs."""
    input_data = tf.constant(
        np.array(np.random.random_sample((3, 10)), dtype=np.float32))
    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[3, 10], dtype=tf.float32)])
    def model(x):
      # Split the [3, 10] input into three [1, 10] steps.
      steps = tf.split(x, 3, 0)
      return tf.compat.v1.nn.static_rnn(
          cell, steps, dtype=tf.float32, sequence_length=[1])

    concrete_func = model.get_concrete_function()
    tflite_model = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func]).convert()

    # Index 0 selects the per-step outputs (static_rnn returns outputs, state).
    reference_steps = concrete_func(input_data)[0]
    converted_steps = self._evaluateTFLiteModel(tflite_model, [input_data])
    for reference, converted in zip(reference_steps, converted_steps):
      self.assertAllClose(reference, converted)

  @test_util.run_v2_only
  def testWhileLoop(self):
    """Converts a `tf.while_loop` that adds a variable until sum >= 100."""
    input_data = tf.constant([1., 2., 3., 4.], shape=[2, 2])
    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[2, 2], dtype=tf.float32)])
    def model(x):
      return tf.while_loop(
          lambda v: tf.reduce_sum(v) < 100,
          lambda v: tf.add(v, weights),
          [x])

    concrete_func = model.get_concrete_function()
    tflite_model = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func]).convert()

    reference = concrete_func(input_data)[0]
    converted = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
    self.assertAllClose(reference, converted)

  @test_util.run_v2_only
  def testDynamicRnn(self):
    """Converts `dynamic_rnn` over an LSTM cell and compares every result."""
    input_data = tf.constant(
        np.array(np.random.random_sample((3, 10, 10)), dtype=np.float32))
    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[3, 10, 10], dtype=tf.float32)])
    def model(x):
      return tf.compat.v1.nn.dynamic_rnn(cell, x, dtype=tf.float32)

    concrete_func = model.get_concrete_function()
    tflite_model = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func]).convert()

    reference_results = concrete_func(input_data)
    converted_results = self._evaluateTFLiteModel(tflite_model, [input_data])
    for reference, converted in zip(reference_results, converted_results):
      # Non-tensor results (presumably the LSTM state tuple) compare via `.c`.
      if not isinstance(reference, ops.EagerTensor):
        reference = reference.c
      self.assertAllClose(reference, converted)

  @parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
                                  ('SimpleRNN', tf.keras.layers.SimpleRNN),
                                  ('GRU', tf.keras.layers.GRU))
  @test_util.run_v2_only
  def testKerasRNN(self, rnn_layer):
    """Converts a Keras recurrent layer behind a batch-size-1 Input."""
    # This relies on TFLiteConverter to rewrite unknown batch size to 1. The
    # model will fail if resizing the input to non-1 batch size.
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    recurrent = rnn_layer(units=10, input_shape=(10, 10))
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'),
        recurrent,
    ])

    tflite_model = lite.TFLiteConverterV2.from_keras_model(model).convert()
    converted = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    reference = model.predict(input_data)
    self.assertAllClose(reference, converted, atol=1e-05)

  @parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
                                  ('SimpleRNN', tf.keras.layers.SimpleRNN),
                                  ('GRU', tf.keras.layers.GRU))
  @test_util.run_v2_only
  def testKerasRNNMultiBatches(self, rnn_layer):
    """Converts a Keras recurrent layer with a fixed batch size of 4."""
    input_data = tf.constant(
        np.array(np.random.random_sample((4, 10, 10)), dtype=np.float32))
    # Specify a fixed batch size(4) for the test model.
    inputs = tf.keras.layers.Input(batch_shape=(4, 10, 10))
    outputs = rnn_layer(units=10, input_shape=(10, 10))(inputs)
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])

    tflite_model = lite.TFLiteConverterV2.from_keras_model(model).convert()
    converted = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    reference = model.predict(input_data)
    self.assertAllClose(reference, converted, atol=1e-05)

  @test_util.run_v2_only
  def testKerasBidirectionalRNNReturnSequence(self):
    """Converts a Bidirectional LSTM that returns full sequences."""
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'),
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=10, return_sequences=True),
            input_shape=(10, 10)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(5),
        tf.keras.layers.Activation('softmax'),
    ])

    tflite_model = lite.TFLiteConverterV2.from_keras_model(model).convert()
    converted = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    reference = model.predict(input_data)
    self.assertAllClose(reference, converted, atol=1e-05)

  @test_util.run_v2_only
  def testKerasBidirectionalRNN(self):
    """Converts a Bidirectional LSTM that returns only its last output."""
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'),
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=10)),
        tf.keras.layers.Dense(5),
        tf.keras.layers.Activation('softmax'),
    ])

    tflite_model = lite.TFLiteConverterV2.from_keras_model(model).convert()
    converted = self._evaluateTFLiteModel(tflite_model, [input_data])[0]

    reference = model.predict(input_data)
    self.assertAllClose(reference, converted, atol=1e-05)
class GrapplerTest(lite_v2_test_util.ModelTest):
  """Conversion tests exercising graph-level optimizations."""

  @test_util.run_v2_only
  def testConstantFolding(self):
    """Checks a folded `tf.broadcast_to`, with and without hybrid quantization."""
    # Constant folding handles the tf.broadcast_to operation which was not
    # supported by the TFLite at the time this test was added.
    input_data = tf.constant([1., 2., 3., 4., 5., 6., 7., 8., 9.], shape=[3, 3])

    @tf.function
    def func(x):
      broadcast = tf.broadcast_to(tf.constant([1., 2., 3.]), [3, 3])
      return tf.matmul(x, broadcast)

    root = tracking.AutoTrackable()
    root.f = func
    concrete_func = root.f.get_concrete_function(input_data)

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    float_model = converter.convert()
    reference = root.f(input_data)

    # Float conversion.
    converted = self._evaluateTFLiteModel(float_model, [input_data])[0]
    self.assertAllClose(reference, converted)

    # Enable hybrid quantization, same result
    converter.optimizations = [lite.Optimize.DEFAULT]
    hybrid_model = converter.convert()
    converted = self._evaluateTFLiteModel(hybrid_model, [input_data])[0]
    self.assertAllClose(reference, converted)
class UnknownShapes(lite_v2_test_util.ModelTest):
  """Conversion tests for models whose inputs have unknown dimensions."""

  @test_util.run_v2_only
  def testMatMul(self):
    """Test a MatMul whose operand shapes depend on a dynamic batch size."""
    input_data = tf.constant(
        np.array(np.random.random_sample((10, 4)), dtype=np.float32))

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def model(in_tensor):
      shape = tf.shape(in_tensor)
      fill = tf.transpose(tf.fill(shape, 1.))
      return tf.matmul(fill, in_tensor)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()

    # Check values from converted model. The `input_shapes` pair presumably
    # maps the [-1, 4] shape signature to the concrete [10, 4] input.
    expected_value = concrete_func(input_data)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])[0]
    self.assertAllClose(expected_value, actual_value, atol=1e-06)

  def _getIntegerQuantizeModelWithUnknownShapes(self):
    """Builds a MatMul model with an unknown batch size plus calibration data.

    Returns:
      A `(concrete_func, calibration_gen)` pair. `calibration_gen` yields five
      float32 samples of shape [batch, 33] for each batch size 5, 10 and 15.
    """
    np.random.seed(0)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
    def model(input_tensor):
      """Define a model with tf.MatMul and unknown shapes."""
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      const_tensor = tf.constant(
          np.random.uniform(low=-10., high=10., size=[33, 33]),
          shape=[33, 33],
          dtype=tf.float32,
          name='inputB')

      shape = tf.shape(input_tensor)
      fill = tf.transpose(tf.fill(shape, 1.))
      mult = tf.matmul(fill, input_tensor)
      return tf.matmul(mult, const_tensor)

    root = tracking.AutoTrackable()
    root.f = model
    concrete_func = root.f.get_concrete_function()

    def calibration_gen():
      for batch in range(5, 20, 5):
        for _ in range(5):
          yield [np.random.uniform(-1, 1, size=(batch, 33)).astype(np.float32)]

    return concrete_func, calibration_gen

  @test_util.run_v2_only
  def testMatMulQuantize(self):
    """Weight-quantizes the unknown-shape MatMul model (no calibration)."""
    concrete_func, _ = self._getIntegerQuantizeModelWithUnknownShapes()
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func])
    float_tflite_model = float_converter.convert()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_tflite_model = quantized_converter.convert()

    # The default input and output types should be float.
    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
    quantized_interpreter.allocate_tensors()
    input_details = quantized_interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testMatMulCalibrateAndQuantize(self):
    """Quantizes the unknown-shape MatMul model using calibration data."""
    concrete_func, calibration_gen = (
        self._getIntegerQuantizeModelWithUnknownShapes())
    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func])
    float_tflite_model = float_converter.convert()

    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()

    # The default input and output types should be float.
    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
    quantized_interpreter.allocate_tensors()
    input_details = quantized_interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

    # Ensure that the calibrated, quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @test_util.run_v2_only
  def testBatchMatMul(self):
    """Test a batched MatMul whose batch dimension is unknown."""
    input_data_1 = tf.constant(
        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
    input_data_2 = tf.constant(
        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32),
        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32)
    ])
    def model(in_tensor_1, in_tensor_2):
      return tf.matmul(in_tensor_1, in_tensor_2)

    concrete_func = model.get_concrete_function()

    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data_1, input_data_2)
    actual_value = self._evaluateTFLiteModel(
        tflite_model, [input_data_1, input_data_2],
        input_shapes=[([-1, 256, 256], [1, 256, 256])])[0]
    # NOTE(review): atol=4 is very loose; presumably chosen because each
    # element is a 256-term dot product — confirm before tightening.
    self.assertAllClose(expected_value, actual_value, atol=4)

  @test_util.run_v2_only
  def testSizeInvalid(self):
    """Test that the old converter rejects None beyond the 1st dimension."""

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, 16, 3], dtype=tf.float32)
    ])
    def model(in_tensor):
      return in_tensor + in_tensor

    concrete_func = model.get_concrete_function()

    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
    # invoke shape checking code.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'in_tensor\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))
if __name__ == '__main__':
  # Hand control to the `test` module's runner to execute the cases above.
  test.main()