Fix saved_model_cli _print_tensor_info for REF types.
Fix #15611. PiperOrigin-RevId: 180292752
This commit is contained in:
parent
711b10c280
commit
2e2715baa8
@ -251,6 +251,7 @@ py_test(
|
|||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
deps = [
|
deps = [
|
||||||
":saved_model_cli",
|
":saved_model_cli",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -152,7 +152,9 @@ def _print_tensor_info(tensor_info):
|
|||||||
Args:
|
Args:
|
||||||
tensor_info: TensorInfo object to be printed.
|
tensor_info: TensorInfo object to be printed.
|
||||||
"""
|
"""
|
||||||
print(' dtype: ' + types_pb2.DataType.keys()[tensor_info.dtype])
|
print(' dtype: ' +
|
||||||
|
{value: key
|
||||||
|
for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype])
|
||||||
# Display shape as tuple.
|
# Display shape as tuple.
|
||||||
if tensor_info.tensor_shape.unknown_rank:
|
if tensor_info.tensor_shape.unknown_rank:
|
||||||
shape = 'unknown_rank'
|
shape = 'unknown_rank'
|
||||||
|
@ -28,6 +28,8 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six import StringIO
|
from six import StringIO
|
||||||
|
|
||||||
|
from tensorflow.core.framework import types_pb2
|
||||||
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.tools import saved_model_cli
|
from tensorflow.python.tools import saved_model_cli
|
||||||
@ -200,6 +202,14 @@ Method name is: tensorflow/serving/predict"""
|
|||||||
self.assertEqual(output, expected_output)
|
self.assertEqual(output, expected_output)
|
||||||
self.assertEqual(err.getvalue().strip(), '')
|
self.assertEqual(err.getvalue().strip(), '')
|
||||||
|
|
||||||
|
def testPrintREFTypeTensor(self):
|
||||||
|
ref_tensor_info = meta_graph_pb2.TensorInfo()
|
||||||
|
ref_tensor_info.dtype = types_pb2.DT_FLOAT_REF
|
||||||
|
with captured_output() as (out, err):
|
||||||
|
saved_model_cli._print_tensor_info(ref_tensor_info)
|
||||||
|
self.assertTrue('DT_FLOAT_REF' in out.getvalue().strip())
|
||||||
|
self.assertEqual(err.getvalue().strip(), '')
|
||||||
|
|
||||||
def testInputPreProcessFormats(self):
|
def testInputPreProcessFormats(self):
|
||||||
input_str = 'input1=/path/file.txt[ab3];input2=file2'
|
input_str = 'input1=/path/file.txt[ab3];input2=file2'
|
||||||
input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]'
|
input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]'
|
||||||
@ -217,7 +227,6 @@ Method name is: tensorflow/serving/predict"""
|
|||||||
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
|
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
|
||||||
r'input:0=c:\PROGRA~1\data.npy')
|
r'input:0=c:\PROGRA~1\data.npy')
|
||||||
input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
|
input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
|
||||||
print(input_dict)
|
|
||||||
self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz',
|
self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz',
|
||||||
'v:0'))
|
'v:0'))
|
||||||
self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))
|
self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))
|
||||||
|
Loading…
Reference in New Issue
Block a user