Internal change

PiperOrigin-RevId: 357610788
Change-Id: I8ac3355ab977d9dedc9657f622471aea5276d9c5
This commit is contained in:
A. Unique TensorFlower 2021-02-15 14:12:17 -08:00 committed by TensorFlower Gardener
parent 1ecfc7eff0
commit 5292800ba9
3 changed files with 52 additions and 8 deletions

View File

@ -24,7 +24,7 @@ from tensorflow.lite.tools import visualize
def get_ops_list(model_data):
"""Return a set of ops in the tflite model data."""
"""Returns a set of ops in the tflite model data."""
model = schema_fb.Model.GetRootAsModel(model_data, 0)
op_set = set()
@ -40,3 +40,18 @@ def get_ops_list(model_data):
else:
op_set.add(visualize.BuiltinCodeToName(builtin_code))
return op_set
def get_output_shapes(model_data):
"""Returns a list of output shapes in the tflite model data."""
model = schema_fb.Model.GetRootAsModel(model_data, 0)
output_shapes = []
for subgraph_idx in range(model.SubgraphsLength()):
subgraph = model.Subgraphs(subgraph_idx)
for output_idx in range(subgraph.OutputsLength()):
output_tensor_idx = subgraph.Outputs(output_idx)
output_tensor = subgraph.Tensors(output_tensor_idx)
output_shapes.append(output_tensor.ShapeAsNumpy().tolist())
return output_shapes

View File

@ -54,7 +54,8 @@ class TestModels(test_util.TensorFlowTestCase):
def _run(self,
flags_str,
should_succeed,
expected_ops_in_converted_model=None):
expected_ops_in_converted_model=None,
expected_output_shapes=None):
output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
tflite_bin = resource_loader.get_path_to_datafile('tflite_convert')
cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file,
@ -69,6 +70,9 @@ class TestModels(test_util.TensorFlowTestCase):
op_set = tflite_test_util.get_ops_list(content)
for opname in expected_ops_in_converted_model:
self.assertIn(opname, op_set)
if expected_output_shapes:
output_shapes = tflite_test_util.get_output_shapes(content)
self.assertEqual(output_shapes, expected_output_shapes)
os.remove(output_file)
else:
self.assertFalse(should_succeed)
@ -88,6 +92,17 @@ class TestModels(test_util.TensorFlowTestCase):
keras.models.save_model(model, keras_file)
return keras_file
def _getKerasFunctionalModelFile(self):
"""Returns a functional Keras model with output shapes [[1, 1], [1, 2]]."""
input_tensor = keras.layers.Input(shape=(1,))
output1 = keras.layers.Dense(1, name='b')(input_tensor)
output2 = keras.layers.Dense(2, name='a')(input_tensor)
model = keras.models.Model(inputs=input_tensor, outputs=[output1, output2])
keras_file = self._getFilepath('functional_model.h5')
keras.models.save_model(model, keras_file)
return keras_file
class TfLiteConvertV1Test(TestModels):
@ -482,6 +497,25 @@ class TfLiteConvertV2Test(TestModels):
self._run(flags_str, should_succeed=True)
os.remove(keras_file)
@test_util.run_v2_only
def testFunctionalKerasModel(self):
keras_file = self._getKerasFunctionalModelFile()
flags_str = '--keras_model_file={}'.format(keras_file)
self._run(flags_str, should_succeed=True,
expected_output_shapes=[[1, 1], [1, 2]])
os.remove(keras_file)
@test_util.run_v2_only
def testFunctionalKerasModelMLIR(self):
keras_file = self._getKerasFunctionalModelFile()
flags_str = (
'--keras_model_file={} --experimental_new_converter'.format(keras_file))
self._run(flags_str, should_succeed=True,
expected_output_shapes=[[1, 1], [1, 2]])
os.remove(keras_file)
def testMissingRequired(self):
self._run('--invalid_args', should_succeed=False)

View File

@ -183,11 +183,6 @@ def trace_model_call(model, input_signature=None):
model, inputs=inputs, build_graph=False, training=False, saving=True):
outputs = model(inputs, training=False)
# Outputs always has to be a flat dict.
output_names = model.output_names # Functional Model.
if output_names is None: # Subclassed Model.
output_names = create_pseudo_output_names(outputs)
outputs = nest.flatten(outputs)
return {name: output for name, output in zip(output_names, outputs)}
return outputs
return _wrapped_model