In tensor_name function, don't drop the suffix behind colon if it's not the first tensor. For example, it should make sure:

tensor:0 -> tensor
tensor:1 -> tensor:1
tensor:2 -> tensor:2

PiperOrigin-RevId: 238447941
This commit is contained in:
Haoliang Zhang 2019-03-14 08:33:16 -07:00 committed by TensorFlower Gardener
parent fa5e3a9682
commit 1020739e17
3 changed files with 51 additions and 1 deletions

View File

@ -214,7 +214,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
def tensor_name(x):
return x.name.split(":")[0]
"""Returns name of the input tensor."""
parts = x.name.split(":")
if len(parts) > 2:
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
len(parts) - 1))
# To be consistent with the tensor naming scheme in tensorflow, we need
# drop the ':0' suffix for the first tensor.
if len(parts) > 1 and parts[1] != "0":
return x.name
return parts[0]
# Don't expose these for now.

View File

@ -55,6 +55,17 @@ class ConvertTest(test_util.TensorFlowTestCase):
# with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
# result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
def testTensorName(self):
in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
# out_tensors should have names: "split:0", "split:1", "split:2", "split:3".
out_tensors = array_ops.split(
value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
expect_names = ["split", "split:1", "split:2", "split:3"]
for i in range(len(expect_names)):
got_name = convert.tensor_name(out_tensors[i])
self.assertEqual(got_name, expect_names[i])
def testQuantization(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
dtype=dtypes.float32)

View File

@ -597,6 +597,35 @@ class FromSessionTest(test_util.TensorFlowTestCase):
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
def testMultipleOutputNodeNames(self):
"""Tests converting a graph with an op that have multiple outputs."""
input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
[out0, out1, out2, out3])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
interpreter.set_tensor(input_details[0]['index'],
np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
interpreter.invoke()
output_details = interpreter.get_output_details()
self.assertEqual(4, len(output_details))
self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
@test_util.run_v1_only('b/120545219')
class FromFrozenGraphFile(test_util.TensorFlowTestCase):