From 1020739e1724346ef0088186f0253defaad3edb6 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Thu, 14 Mar 2019 08:33:16 -0700 Subject: [PATCH] In tensor_name function, don't drop the suffix behind colon if it's not the first tensor. For example, it should make sure: tensor:0 -> tensor tensor:1 -> tensor:1 tensor:2 -> tensor:2 PiperOrigin-RevId: 238447941 --- tensorflow/lite/python/convert.py | 12 ++++++++++- tensorflow/lite/python/convert_test.py | 11 ++++++++++ tensorflow/lite/python/lite_test.py | 29 ++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index b933ef1ea04..c3f15816e25 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -214,7 +214,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): def tensor_name(x): - return x.name.split(":")[0] + """Returns name of the input tensor.""" + parts = x.name.split(":") + if len(parts) > 2: + raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( + len(parts) - 1)) + + # To be consistent with the tensor naming scheme in tensorflow, we need + # drop the ':0' suffix for the first tensor. + if len(parts) > 1 and parts[1] != "0": + return x.name + return parts[0] # Don't expose these for now. diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py index aecf346207e..12d8d494c1f 100644 --- a/tensorflow/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -55,6 +55,17 @@ class ConvertTest(test_util.TensorFlowTestCase): # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): # result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) + def testTensorName(self): + in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) + # out_tensors should have names: "split:0", "split:1", "split:2", "split:3". + out_tensors = array_ops.split( + value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0) + expect_names = ["split", "split:1", "split:2", "split:3"] + + for i in range(len(expect_names)): + got_name = convert.tensor_name(out_tensors[i]) + self.assertEqual(got_name, expect_names[i]) + def testQuantization(self): in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.float32) diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 4bf9e0c679d..14d08ec70a6 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -597,6 +597,35 @@ class FromSessionTest(test_util.TensorFlowTestCase): interpreter = Interpreter(model_content=tflite_model) interpreter.allocate_tensors() + def testMultipleOutputNodeNames(self): + """Tests converting a graph with an op that have multiple outputs.""" + input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) + out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0) + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [input_tensor], + [out0, out1, out2, out3]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + interpreter.set_tensor(input_details[0]['index'], + np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + interpreter.invoke() + + output_details = interpreter.get_output_details() + self.assertEqual(4, len(output_details)) + self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index'])) + self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index'])) + self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index'])) + self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index'])) + @test_util.run_v1_only('b/120545219') class FromFrozenGraphFile(test_util.TensorFlowTestCase):