From 2e2715baa84720f786b38d1f9cb6887399020d6f Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Thu, 28 Dec 2017 13:57:18 -0800 Subject: [PATCH] Fix saved_model_cli _print_tensor_info for REF types. Fix #15611. PiperOrigin-RevId: 180292752 --- tensorflow/python/tools/BUILD | 1 + tensorflow/python/tools/saved_model_cli.py | 4 +++- tensorflow/python/tools/saved_model_cli_test.py | 11 ++++++++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 69586c6a477..63f16c53a29 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -251,6 +251,7 @@ py_test( tags = ["manual"], deps = [ ":saved_model_cli", + "//tensorflow/core:protos_all_py", ], ) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index cff2c186e38..b4f6fea7261 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -152,7 +152,9 @@ def _print_tensor_info(tensor_info): Args: tensor_info: TensorInfo object to be printed. """ - print(' dtype: ' + types_pb2.DataType.keys()[tensor_info.dtype]) + print(' dtype: ' + + {value: key + for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype]) # Display shape as tuple. if tensor_info.tensor_shape.unknown_rank: shape = 'unknown_rank' diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index a55cf168b23..0789e1e107c 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -28,6 +28,8 @@ import sys import numpy as np from six import StringIO +from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.platform import test from tensorflow.python.tools import saved_model_cli @@ -200,6 +202,14 @@ Method name is: tensorflow/serving/predict""" self.assertEqual(output, expected_output) self.assertEqual(err.getvalue().strip(), '') + def testPrintREFTypeTensor(self): + ref_tensor_info = meta_graph_pb2.TensorInfo() + ref_tensor_info.dtype = types_pb2.DT_FLOAT_REF + with captured_output() as (out, err): + saved_model_cli._print_tensor_info(ref_tensor_info) + self.assertTrue('DT_FLOAT_REF' in out.getvalue().strip()) + self.assertEqual(err.getvalue().strip(), '') + def testInputPreProcessFormats(self): input_str = 'input1=/path/file.txt[ab3];input2=file2' input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]' @@ -217,7 +227,6 @@ Method name is: tensorflow/serving/predict""" input_str = (r'inputx=C:\Program Files\data.npz[v:0];' r'input:0=c:\PROGRA~1\data.npy') input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str) - print(input_dict) self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz', 'v:0')) self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))