In tensor_name function, don't drop the suffix behind colon if it's not the first tensor. For example, it should make sure:
tensor:0 -> tensor tensor:1 -> tensor:1 tensor:2 -> tensor:2 PiperOrigin-RevId: 238447941
This commit is contained in:
parent
fa5e3a9682
commit
1020739e17
@ -214,7 +214,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
|||||||
|
|
||||||
|
|
||||||
def tensor_name(x):
|
def tensor_name(x):
|
||||||
return x.name.split(":")[0]
|
"""Returns name of the input tensor."""
|
||||||
|
parts = x.name.split(":")
|
||||||
|
if len(parts) > 2:
|
||||||
|
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
|
||||||
|
len(parts) - 1))
|
||||||
|
|
||||||
|
# To be consistent with the tensor naming scheme in tensorflow, we need
|
||||||
|
# drop the ':0' suffix for the first tensor.
|
||||||
|
if len(parts) > 1 and parts[1] != "0":
|
||||||
|
return x.name
|
||||||
|
return parts[0]
|
||||||
|
|
||||||
|
|
||||||
# Don't expose these for now.
|
# Don't expose these for now.
|
||||||
|
@ -55,6 +55,17 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
|||||||
# with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
|
# with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
|
||||||
# result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
|
# result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
|
||||||
|
|
||||||
|
def testTensorName(self):
|
||||||
|
in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
|
||||||
|
# out_tensors should have names: "split:0", "split:1", "split:2", "split:3".
|
||||||
|
out_tensors = array_ops.split(
|
||||||
|
value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
|
||||||
|
expect_names = ["split", "split:1", "split:2", "split:3"]
|
||||||
|
|
||||||
|
for i in range(len(expect_names)):
|
||||||
|
got_name = convert.tensor_name(out_tensors[i])
|
||||||
|
self.assertEqual(got_name, expect_names[i])
|
||||||
|
|
||||||
def testQuantization(self):
|
def testQuantization(self):
|
||||||
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
|
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
|
@ -597,6 +597,35 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
|||||||
interpreter = Interpreter(model_content=tflite_model)
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
def testMultipleOutputNodeNames(self):
|
||||||
|
"""Tests converting a graph with an op that have multiple outputs."""
|
||||||
|
input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
|
||||||
|
out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
|
||||||
|
sess = session.Session()
|
||||||
|
|
||||||
|
# Convert model and ensure model is not None.
|
||||||
|
converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
|
||||||
|
[out0, out1, out2, out3])
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
|
# Check values from converted model.
|
||||||
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
self.assertEqual(1, len(input_details))
|
||||||
|
interpreter.set_tensor(input_details[0]['index'],
|
||||||
|
np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
self.assertEqual(4, len(output_details))
|
||||||
|
self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
|
||||||
|
self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
|
||||||
|
self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
|
||||||
|
self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user