diff --git a/tensorflow/lite/python/test_util.py b/tensorflow/lite/python/test_util.py index da9453b547f..3da1e80fc22 100644 --- a/tensorflow/lite/python/test_util.py +++ b/tensorflow/lite/python/test_util.py @@ -24,7 +24,7 @@ from tensorflow.lite.tools import visualize def get_ops_list(model_data): - """Return a set of ops in the tflite model data.""" + """Returns a set of ops in the tflite model data.""" model = schema_fb.Model.GetRootAsModel(model_data, 0) op_set = set() @@ -40,3 +40,18 @@ def get_ops_list(model_data): else: op_set.add(visualize.BuiltinCodeToName(builtin_code)) return op_set + + +def get_output_shapes(model_data): + """Returns a list of output shapes in the tflite model data.""" + model = schema_fb.Model.GetRootAsModel(model_data, 0) + + output_shapes = [] + for subgraph_idx in range(model.SubgraphsLength()): + subgraph = model.Subgraphs(subgraph_idx) + for output_idx in range(subgraph.OutputsLength()): + output_tensor_idx = subgraph.Outputs(output_idx) + output_tensor = subgraph.Tensors(output_tensor_idx) + output_shapes.append(output_tensor.ShapeAsNumpy().tolist()) + + return output_shapes diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py index 7c66e31d62f..71810cf75d2 100644 --- a/tensorflow/lite/python/tflite_convert_test.py +++ b/tensorflow/lite/python/tflite_convert_test.py @@ -54,7 +54,8 @@ class TestModels(test_util.TensorFlowTestCase): def _run(self, flags_str, should_succeed, - expected_ops_in_converted_model=None): + expected_ops_in_converted_model=None, + expected_output_shapes=None): output_file = os.path.join(self.get_temp_dir(), 'model.tflite') tflite_bin = resource_loader.get_path_to_datafile('tflite_convert') cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file, @@ -69,6 +70,9 @@ class TestModels(test_util.TensorFlowTestCase): op_set = tflite_test_util.get_ops_list(content) for opname in expected_ops_in_converted_model: self.assertIn(opname, op_set) + if expected_output_shapes: + output_shapes = tflite_test_util.get_output_shapes(content) + self.assertEqual(output_shapes, expected_output_shapes) os.remove(output_file) else: self.assertFalse(should_succeed) @@ -88,6 +92,17 @@ class TestModels(test_util.TensorFlowTestCase): keras.models.save_model(model, keras_file) return keras_file + def _getKerasFunctionalModelFile(self): + """Returns a functional Keras model with output shapes [[1, 1], [1, 2]].""" + input_tensor = keras.layers.Input(shape=(1,)) + output1 = keras.layers.Dense(1, name='b')(input_tensor) + output2 = keras.layers.Dense(2, name='a')(input_tensor) + model = keras.models.Model(inputs=input_tensor, outputs=[output1, output2]) + + keras_file = self._getFilepath('functional_model.h5') + keras.models.save_model(model, keras_file) + return keras_file + class TfLiteConvertV1Test(TestModels): @@ -482,6 +497,25 @@ class TfLiteConvertV2Test(TestModels): self._run(flags_str, should_succeed=True) os.remove(keras_file) + @test_util.run_v2_only + def testFunctionalKerasModel(self): + keras_file = self._getKerasFunctionalModelFile() + + flags_str = '--keras_model_file={}'.format(keras_file) + self._run(flags_str, should_succeed=True, + expected_output_shapes=[[1, 1], [1, 2]]) + os.remove(keras_file) + + @test_util.run_v2_only + def testFunctionalKerasModelMLIR(self): + keras_file = self._getKerasFunctionalModelFile() + + flags_str = ( + '--keras_model_file={} --experimental_new_converter'.format(keras_file)) + self._run(flags_str, should_succeed=True, + expected_output_shapes=[[1, 1], [1, 2]]) + os.remove(keras_file) + def testMissingRequired(self): self._run('--invalid_args', should_succeed=False) diff --git a/tensorflow/lite/python/tflite_keras_util.py b/tensorflow/lite/python/tflite_keras_util.py index c9f5b40c1e0..21f88731eaa 100644 --- a/tensorflow/lite/python/tflite_keras_util.py +++ b/tensorflow/lite/python/tflite_keras_util.py @@ -183,11 +183,6 @@ def trace_model_call(model, input_signature=None): model, inputs=inputs, build_graph=False, training=False, saving=True): outputs = model(inputs, training=False) - # Outputs always has to be a flat dict. - output_names = model.output_names # Functional Model. - if output_names is None: # Subclassed Model. - output_names = create_pseudo_output_names(outputs) - outputs = nest.flatten(outputs) - return {name: output for name, output in zip(output_names, outputs)} + return outputs return _wrapped_model