In tensor_name function, don't drop the suffix behind colon if it's not the first tensor. For example, it should make sure:
tensor:0 -> tensor tensor:1 -> tensor:1 tensor:2 -> tensor:2 PiperOrigin-RevId: 238447941
This commit is contained in:
parent
fa5e3a9682
commit
1020739e17
@ -214,7 +214,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
|
||||
|
||||
def tensor_name(x):
|
||||
return x.name.split(":")[0]
|
||||
"""Returns name of the input tensor."""
|
||||
parts = x.name.split(":")
|
||||
if len(parts) > 2:
|
||||
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
|
||||
len(parts) - 1))
|
||||
|
||||
# To be consistent with the tensor naming scheme in tensorflow, we need
|
||||
# drop the ':0' suffix for the first tensor.
|
||||
if len(parts) > 1 and parts[1] != "0":
|
||||
return x.name
|
||||
return parts[0]
|
||||
|
||||
|
||||
# Don't expose these for now.
|
||||
|
@ -55,6 +55,17 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
# with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
|
||||
# result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
|
||||
|
||||
def testTensorName(self):
|
||||
in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
|
||||
# out_tensors should have names: "split:0", "split:1", "split:2", "split:3".
|
||||
out_tensors = array_ops.split(
|
||||
value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
|
||||
expect_names = ["split", "split:1", "split:2", "split:3"]
|
||||
|
||||
for i in range(len(expect_names)):
|
||||
got_name = convert.tensor_name(out_tensors[i])
|
||||
self.assertEqual(got_name, expect_names[i])
|
||||
|
||||
def testQuantization(self):
|
||||
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
|
||||
dtype=dtypes.float32)
|
||||
|
@ -597,6 +597,35 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
def testMultipleOutputNodeNames(self):
|
||||
"""Tests converting a graph with an op that have multiple outputs."""
|
||||
input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
|
||||
out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
|
||||
sess = session.Session()
|
||||
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
|
||||
[out0, out1, out2, out3])
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(1, len(input_details))
|
||||
interpreter.set_tensor(input_details[0]['index'],
|
||||
np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
|
||||
interpreter.invoke()
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(4, len(output_details))
|
||||
self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
|
||||
self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
|
||||
self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
|
||||
self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user