Test Eager-mode in from_keras_model_file in 1.X.

PiperOrigin-RevId: 247218931
This commit is contained in:
Nupur Garg 2019-05-08 08:37:15 -07:00 committed by TensorFlower Gardener
parent c2d506cc04
commit 324820231d

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import os import os
import tempfile import tempfile
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
@ -27,6 +28,7 @@ from tensorflow.lite.python import lite_constants
from tensorflow.lite.python.interpreter import Interpreter from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -736,6 +738,84 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1] == output_details[0]['shape']).all()) self.assertTrue(([1] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization']) self.assertEqual((0., 0.), output_details[0]['quantization'])
def testInferenceInputOutputTypeFloatDefault(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
def testInferenceInputOutputTypeQuantizedUint8Default(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = array_ops.fake_quant_with_min_max_args(
in_tensor + in_tensor, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.uint8, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
def testReusingConverterWithDifferentPostTrainingQuantization(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = array_ops.fake_quant_with_min_max_args(
in_tensor + in_tensor, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.post_training_quantize = True
tflite_model = converter.convert()
self.assertTrue(tflite_model)
converter.post_training_quantize = False
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class FromFrozenGraphFile(test_util.TensorFlowTestCase): class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@ -1148,62 +1228,70 @@ class MyAddLayer(keras.layers.Layer):
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class FromKerasFile(test_util.TensorFlowTestCase): class FromKerasFile(test_util.TensorFlowTestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
keras.backend.clear_session() super(FromKerasFile, self).setUp()
self._keras_file = None
self._custom_objects = None
if not context.executing_eagerly():
keras.backend.clear_session()
def tearDown(self):
if self._keras_file:
os.remove(self._keras_file)
super(FromKerasFile, self).tearDown()
def _getSequentialModel(self, include_custom_layer=False): def _getSequentialModel(self, include_custom_layer=False):
with session.Session().as_default(): model = keras.models.Sequential()
model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(2, input_shape=(3,))) if include_custom_layer:
if include_custom_layer: model.add(MyAddLayer(1.0))
model.add(MyAddLayer(1.0)) model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) model.compile(
model.compile( loss=keras.losses.MSE,
loss=keras.losses.MSE, optimizer='sgd',
optimizer=keras.optimizers.RMSprop(), metrics=[keras.metrics.categorical_accuracy],
metrics=[keras.metrics.categorical_accuracy], sample_weight_mode='temporal')
sample_weight_mode='temporal') x = np.random.random((1, 3))
x = np.random.random((1, 3)) y = np.random.random((1, 3, 3))
y = np.random.random((1, 3, 3)) model.train_on_batch(x, y)
model.train_on_batch(x, y) model.predict(x)
model.predict(x)
try: try:
fd, keras_file = tempfile.mkstemp('.h5') fd, self._keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file) keras.models.save_model(model, self._keras_file)
finally: finally:
os.close(fd) os.close(fd)
if include_custom_layer: if include_custom_layer:
custom_objects = {'MyAddLayer': MyAddLayer} self._custom_objects = {'MyAddLayer': MyAddLayer}
return keras_file, custom_objects
return keras_file
def testSequentialModel(self): @parameterized.named_parameters(('_graph', context.graph_mode),
('_eager', context.eager_mode))
def testSequentialModel(self, test_context):
"""Test a Sequential tf.keras model with default inputs.""" """Test a Sequential tf.keras model with default inputs."""
keras_file = self._getSequentialModel() with test_context():
self._getSequentialModel()
converter = lite.TFLiteConverter.from_keras_model_file(keras_file) converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
# Check tensor details of converted model. # Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details)) self.assertLen(input_details, 1)
self.assertEqual('dense_input', input_details[0]['name']) self.assertEqual('dense_input', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype']) self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 3] == input_details[0]['shape']).all()) self.assertTrue(([1, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization']) self.assertEqual((0., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details)) self.assertLen(output_details, 1)
self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype']) self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization']) self.assertEqual((0., 0.), output_details[0]['quantization'])
@ -1214,22 +1302,22 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter.invoke() interpreter.invoke()
tflite_result = interpreter.get_tensor(output_details[0]['index']) tflite_result = interpreter.get_tensor(output_details[0]['index'])
keras_model = keras.models.load_model(keras_file) keras_model = keras.models.load_model(self._keras_file)
keras_result = keras_model.predict(input_data) keras_result = keras_model.predict(input_data)
np.testing.assert_almost_equal(tflite_result, keras_result, 5) np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
def testCustomLayer(self): @parameterized.named_parameters(('_graph', context.graph_mode),
('_eager', context.eager_mode))
def testCustomLayer(self, test_context):
"""Test a Sequential tf.keras model with default inputs.""" """Test a Sequential tf.keras model with default inputs."""
keras_file, custom_objects = self._getSequentialModel( with test_context():
include_custom_layer=True) self._getSequentialModel(include_custom_layer=True)
converter = lite.TFLiteConverter.from_keras_model_file( converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, custom_objects=custom_objects) self._keras_file, custom_objects=self._custom_objects)
tflite_model = converter.convert()
tflite_model = converter.convert() self.assertTrue(tflite_model)
self.assertTrue(tflite_model)
# Check tensor details of converted model. # Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
@ -1245,47 +1333,44 @@ class FromKerasFile(test_util.TensorFlowTestCase):
tflite_result = interpreter.get_tensor(output_details[0]['index']) tflite_result = interpreter.get_tensor(output_details[0]['index'])
keras_model = keras.models.load_model( keras_model = keras.models.load_model(
keras_file, custom_objects=custom_objects) self._keras_file, custom_objects=self._custom_objects)
keras_result = keras_model.predict(input_data) keras_result = keras_model.predict(input_data)
np.testing.assert_almost_equal(tflite_result, keras_result, 5) np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
def testSequentialModelInputArray(self): def testSequentialModelInputArray(self):
"""Test a Sequential tf.keras model testing input arrays argument.""" """Test a Sequential tf.keras model testing input arrays argument."""
keras_file = self._getSequentialModel() self._getSequentialModel()
# Invalid input array raises error. # Invalid input array raises error.
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
lite.TFLiteConverter.from_keras_model_file( lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['invalid-input']) self._keras_file, input_arrays=['invalid-input'])
self.assertEqual("Invalid tensors 'invalid-input' were found.", self.assertEqual("Invalid tensors 'invalid-input' were found.",
str(error.exception)) str(error.exception))
# Valid input array. # Valid input array.
converter = lite.TFLiteConverter.from_keras_model_file( converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['dense_input']) self._keras_file, input_arrays=['dense_input'])
tflite_model = converter.convert() tflite_model = converter.convert()
os.remove(keras_file)
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
def testSequentialModelInputShape(self): def testSequentialModelInputShape(self):
"""Test a Sequential tf.keras model testing input shapes argument.""" """Test a Sequential tf.keras model testing input shapes argument."""
keras_file = self._getSequentialModel() self._getSequentialModel()
# Passing in shape of invalid input array raises error. # Passing in shape of invalid input array raises error.
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
converter = lite.TFLiteConverter.from_keras_model_file( converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'invalid-input': [2, 3]}) self._keras_file, input_shapes={'invalid-input': [2, 3]})
self.assertEqual( self.assertEqual(
"Invalid tensor 'invalid-input' found in tensor shapes map.", "Invalid tensor 'invalid-input' found in tensor shapes map.",
str(error.exception)) str(error.exception))
# Passing in shape of valid input array. # Passing in shape of valid input array.
converter = lite.TFLiteConverter.from_keras_model_file( converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'dense_input': [2, 3]}) self._keras_file, input_shapes={'dense_input': [2, 3]})
tflite_model = converter.convert() tflite_model = converter.convert()
os.remove(keras_file)
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
# Check input shape from converted model. # Check input shape from converted model.
@ -1293,31 +1378,32 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details)) self.assertLen(input_details, 1)
self.assertEqual('dense_input', input_details[0]['name']) self.assertEqual('dense_input', input_details[0]['name'])
self.assertTrue(([2, 3] == input_details[0]['shape']).all()) self.assertTrue(([2, 3] == input_details[0]['shape']).all())
def testSequentialModelOutputArray(self): def testSequentialModelOutputArray(self):
"""Test a Sequential tf.keras model testing output arrays argument.""" """Test a Sequential tf.keras model testing output arrays argument."""
keras_file = self._getSequentialModel() self._getSequentialModel()
# Invalid output array raises error. # Invalid output array raises error.
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
lite.TFLiteConverter.from_keras_model_file( lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['invalid-output']) self._keras_file, output_arrays=['invalid-output'])
self.assertEqual("Invalid tensors 'invalid-output' were found.", self.assertEqual("Invalid tensors 'invalid-output' were found.",
str(error.exception)) str(error.exception))
# Valid output array. # Valid output array.
converter = lite.TFLiteConverter.from_keras_model_file( converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['time_distributed/Reshape_1']) self._keras_file, output_arrays=['time_distributed/Reshape_1'])
tflite_model = converter.convert() tflite_model = converter.convert()
os.remove(keras_file)
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
def testFunctionalModel(self): @parameterized.named_parameters(('_graph', context.graph_mode),
('_eager', context.eager_mode))
def testFunctionalModel(self, test_context):
"""Test a Functional tf.keras model with default inputs.""" """Test a Functional tf.keras model with default inputs."""
with session.Session().as_default(): with test_context():
inputs = keras.layers.Input(shape=(3,), name='input') inputs = keras.layers.Input(shape=(3,), name='input')
x = keras.layers.Dense(2)(inputs) x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x) output = keras.layers.Dense(3)(x)
@ -1325,38 +1411,37 @@ class FromKerasFile(test_util.TensorFlowTestCase):
model = keras.models.Model(inputs, output) model = keras.models.Model(inputs, output)
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(), optimizer='sgd',
metrics=[keras.metrics.categorical_accuracy]) metrics=[keras.metrics.categorical_accuracy])
x = np.random.random((1, 3)) x = np.random.random((1, 3))
y = np.random.random((1, 3)) y = np.random.random((1, 3))
model.train_on_batch(x, y) model.train_on_batch(x, y)
model.predict(x) model.predict(x)
fd, keras_file = tempfile.mkstemp('.h5') fd, self._keras_file = tempfile.mkstemp('.h5')
try: try:
keras.models.save_model(model, keras_file) keras.models.save_model(model, self._keras_file)
finally: finally:
os.close(fd) os.close(fd)
# Convert to TFLite model. # Convert to TFLite model.
converter = lite.TFLiteConverter.from_keras_model_file(keras_file) converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
# Check tensor details of converted model. # Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details)) self.assertLen(input_details, 1)
self.assertEqual('input', input_details[0]['name']) self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype']) self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 3] == input_details[0]['shape']).all()) self.assertTrue(([1, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization']) self.assertEqual((0., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details)) self.assertLen(output_details, 1)
self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype']) self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 3] == output_details[0]['shape']).all()) self.assertTrue(([1, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization']) self.assertEqual((0., 0.), output_details[0]['quantization'])
@ -1367,55 +1452,51 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter.invoke() interpreter.invoke()
tflite_result = interpreter.get_tensor(output_details[0]['index']) tflite_result = interpreter.get_tensor(output_details[0]['index'])
keras_model = keras.models.load_model(keras_file) keras_model = keras.models.load_model(self._keras_file)
keras_result = keras_model.predict(input_data) keras_result = keras_model.predict(input_data)
np.testing.assert_almost_equal(tflite_result, keras_result, 5) np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
def testFunctionalModelMultipleInputs(self): def testFunctionalModelMultipleInputs(self):
"""Test a Functional tf.keras model with multiple inputs and outputs.""" """Test a Functional tf.keras model with multiple inputs and outputs."""
with session.Session().as_default(): a = keras.layers.Input(shape=(3,), name='input_a')
a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b')
b = keras.layers.Input(shape=(3,), name='input_b') dense = keras.layers.Dense(4, name='dense')
dense = keras.layers.Dense(4, name='dense') c = dense(a)
c = dense(a) d = dense(b)
d = dense(b) e = keras.layers.Dropout(0.5, name='dropout')(c)
e = keras.layers.Dropout(0.5, name='dropout')(c)
model = keras.models.Model([a, b], [d, e]) model = keras.models.Model([a, b], [d, e])
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(), optimizer='sgd',
metrics=[keras.metrics.mae], metrics=[keras.metrics.mae],
loss_weights=[1., 0.5]) loss_weights=[1., 0.5])
input_a_np = np.random.random((10, 3)) input_a_np = np.random.random((10, 3))
input_b_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 3))
output_d_np = np.random.random((10, 4)) output_d_np = np.random.random((10, 4))
output_e_np = np.random.random((10, 4)) output_e_np = np.random.random((10, 4))
model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
model.predict([input_a_np, input_b_np], batch_size=5) model.predict([input_a_np, input_b_np], batch_size=5)
fd, keras_file = tempfile.mkstemp('.h5') fd, self._keras_file = tempfile.mkstemp('.h5')
try: try:
keras.models.save_model(model, keras_file) keras.models.save_model(model, self._keras_file)
finally: finally:
os.close(fd) os.close(fd)
# Convert to TFLite model. # Convert to TFLite model.
converter = lite.TFLiteConverter.from_keras_model_file(keras_file) converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
os.remove(keras_file)
# Check values from converted model. # Check values from converted model.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
self.assertEqual(2, len(input_details)) self.assertLen(input_details, 2)
self.assertEqual('input_a', input_details[0]['name']) self.assertEqual('input_a', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype']) self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 3] == input_details[0]['shape']).all()) self.assertTrue(([1, 3] == input_details[0]['shape']).all())
@ -1427,7 +1508,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
self.assertEqual((0., 0.), input_details[1]['quantization']) self.assertEqual((0., 0.), input_details[1]['quantization'])
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()
self.assertEqual(2, len(output_details)) self.assertLen(output_details, 2)
self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype']) self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 4] == output_details[0]['shape']).all()) self.assertTrue(([1, 4] == output_details[0]['shape']).all())
@ -1440,32 +1521,31 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalSequentialModel(self): def testFunctionalSequentialModel(self):
"""Test a Functional tf.keras model containing a Sequential model.""" """Test a Functional tf.keras model containing a Sequential model."""
with session.Session().as_default(): model = keras.models.Sequential()
model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) model = keras.models.Model(model.input, model.output)
model = keras.models.Model(model.input, model.output)
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(), optimizer='sgd',
metrics=[keras.metrics.categorical_accuracy], metrics=[keras.metrics.categorical_accuracy],
sample_weight_mode='temporal') sample_weight_mode='temporal')
x = np.random.random((1, 3)) x = np.random.random((1, 3))
y = np.random.random((1, 3, 3)) y = np.random.random((1, 3, 3))
model.train_on_batch(x, y) model.train_on_batch(x, y)
model.predict(x) model.predict(x)
model.predict(x) model.predict(x)
fd, keras_file = tempfile.mkstemp('.h5') fd, self._keras_file = tempfile.mkstemp('.h5')
try: try:
keras.models.save_model(model, keras_file) keras.models.save_model(model, self._keras_file)
finally: finally:
os.close(fd) os.close(fd)
# Convert to TFLite model. # Convert to TFLite model.
converter = lite.TFLiteConverter.from_keras_model_file(keras_file) converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
@ -1474,14 +1554,14 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details)) self.assertLen(input_details, 1)
self.assertEqual('dense_input', input_details[0]['name']) self.assertEqual('dense_input', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype']) self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 3] == input_details[0]['shape']).all()) self.assertTrue(([1, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization']) self.assertEqual((0., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details)) self.assertLen(output_details, 1)
self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype']) self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
@ -1493,17 +1573,16 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter.invoke() interpreter.invoke()
tflite_result = interpreter.get_tensor(output_details[0]['index']) tflite_result = interpreter.get_tensor(output_details[0]['index'])
keras_model = keras.models.load_model(keras_file) keras_model = keras.models.load_model(self._keras_file)
keras_result = keras_model.predict(input_data) keras_result = keras_model.predict(input_data)
np.testing.assert_almost_equal(tflite_result, keras_result, 5) np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
def testSequentialModelTocoConverter(self): def testSequentialModelTocoConverter(self):
"""Test a Sequential tf.keras model with deprecated TocoConverter.""" """Test a Sequential tf.keras model with deprecated TocoConverter."""
keras_file = self._getSequentialModel() self._getSequentialModel()
converter = lite.TocoConverter.from_keras_model_file(keras_file) converter = lite.TocoConverter.from_keras_model_file(self._keras_file)
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
@ -1511,84 +1590,6 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
def testInferenceInputOutputTypeFloatDefault(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
def testInferenceInputOutputTypeQuantizedUint8Default(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = array_ops.fake_quant_with_min_max_args(
in_tensor + in_tensor, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.uint8, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
def testReusingConverterWithDifferentPostTrainingQuantization(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = array_ops.fake_quant_with_min_max_args(
in_tensor + in_tensor, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.post_training_quantize = True
tflite_model = converter.convert()
self.assertTrue(tflite_model)
converter.post_training_quantize = False
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class GrapplerTest(test_util.TensorFlowTestCase): class GrapplerTest(test_util.TensorFlowTestCase):