Internal change
PiperOrigin-RevId: 357610788 Change-Id: I8ac3355ab977d9dedc9657f622471aea5276d9c5
This commit is contained in:
parent
1ecfc7eff0
commit
5292800ba9
@ -24,7 +24,7 @@ from tensorflow.lite.tools import visualize
|
||||
|
||||
|
||||
def get_ops_list(model_data):
|
||||
"""Return a set of ops in the tflite model data."""
|
||||
"""Returns a set of ops in the tflite model data."""
|
||||
model = schema_fb.Model.GetRootAsModel(model_data, 0)
|
||||
op_set = set()
|
||||
|
||||
@ -40,3 +40,18 @@ def get_ops_list(model_data):
|
||||
else:
|
||||
op_set.add(visualize.BuiltinCodeToName(builtin_code))
|
||||
return op_set
|
||||
|
||||
|
||||
def get_output_shapes(model_data):
  """Returns a list of output shapes in the tflite model data."""
  model = schema_fb.Model.GetRootAsModel(model_data, 0)

  shapes = []
  for sg_index in range(model.SubgraphsLength()):
    graph = model.Subgraphs(sg_index)
    for out_index in range(graph.OutputsLength()):
      # Resolve the output's tensor index to the tensor itself, then record
      # its shape as a plain Python list.
      tensor = graph.Tensors(graph.Outputs(out_index))
      shapes.append(tensor.ShapeAsNumpy().tolist())

  return shapes
|
||||
|
@ -54,7 +54,8 @@ class TestModels(test_util.TensorFlowTestCase):
|
||||
def _run(self,
|
||||
flags_str,
|
||||
should_succeed,
|
||||
expected_ops_in_converted_model=None):
|
||||
expected_ops_in_converted_model=None,
|
||||
expected_output_shapes=None):
|
||||
output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
tflite_bin = resource_loader.get_path_to_datafile('tflite_convert')
|
||||
cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file,
|
||||
@ -69,6 +70,9 @@ class TestModels(test_util.TensorFlowTestCase):
|
||||
op_set = tflite_test_util.get_ops_list(content)
|
||||
for opname in expected_ops_in_converted_model:
|
||||
self.assertIn(opname, op_set)
|
||||
if expected_output_shapes:
|
||||
output_shapes = tflite_test_util.get_output_shapes(content)
|
||||
self.assertEqual(output_shapes, expected_output_shapes)
|
||||
os.remove(output_file)
|
||||
else:
|
||||
self.assertFalse(should_succeed)
|
||||
@ -88,6 +92,17 @@ class TestModels(test_util.TensorFlowTestCase):
|
||||
keras.models.save_model(model, keras_file)
|
||||
return keras_file
|
||||
|
||||
def _getKerasFunctionalModelFile(self):
|
||||
"""Returns a functional Keras model with output shapes [[1, 1], [1, 2]]."""
|
||||
input_tensor = keras.layers.Input(shape=(1,))
|
||||
output1 = keras.layers.Dense(1, name='b')(input_tensor)
|
||||
output2 = keras.layers.Dense(2, name='a')(input_tensor)
|
||||
model = keras.models.Model(inputs=input_tensor, outputs=[output1, output2])
|
||||
|
||||
keras_file = self._getFilepath('functional_model.h5')
|
||||
keras.models.save_model(model, keras_file)
|
||||
return keras_file
|
||||
|
||||
|
||||
class TfLiteConvertV1Test(TestModels):
|
||||
|
||||
@ -482,6 +497,25 @@ class TfLiteConvertV2Test(TestModels):
|
||||
self._run(flags_str, should_succeed=True)
|
||||
os.remove(keras_file)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testFunctionalKerasModel(self):
|
||||
keras_file = self._getKerasFunctionalModelFile()
|
||||
|
||||
flags_str = '--keras_model_file={}'.format(keras_file)
|
||||
self._run(flags_str, should_succeed=True,
|
||||
expected_output_shapes=[[1, 1], [1, 2]])
|
||||
os.remove(keras_file)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testFunctionalKerasModelMLIR(self):
|
||||
keras_file = self._getKerasFunctionalModelFile()
|
||||
|
||||
flags_str = (
|
||||
'--keras_model_file={} --experimental_new_converter'.format(keras_file))
|
||||
self._run(flags_str, should_succeed=True,
|
||||
expected_output_shapes=[[1, 1], [1, 2]])
|
||||
os.remove(keras_file)
|
||||
|
||||
  def testMissingRequired(self):
    """Checks that conversion fails when required flags are missing/invalid."""
    self._run('--invalid_args', should_succeed=False)
|
||||
|
||||
|
@ -183,11 +183,6 @@ def trace_model_call(model, input_signature=None):
|
||||
model, inputs=inputs, build_graph=False, training=False, saving=True):
|
||||
outputs = model(inputs, training=False)
|
||||
|
||||
# Outputs always has to be a flat dict.
|
||||
output_names = model.output_names # Functional Model.
|
||||
if output_names is None: # Subclassed Model.
|
||||
output_names = create_pseudo_output_names(outputs)
|
||||
outputs = nest.flatten(outputs)
|
||||
return {name: output for name, output in zip(output_names, outputs)}
|
||||
return outputs
|
||||
|
||||
return _wrapped_model
|
||||
|
Loading…
x
Reference in New Issue
Block a user