Fix saved_model_cli when trying to print pure ConcreteFunction arguments.
PiperOrigin-RevId: 286384508 Change-Id: I67a3997ba67a7b474a9e1157f6f3fb12dbe84007
This commit is contained in:
parent
6ce4169f68
commit
de5705ecbe
@ -31,12 +31,14 @@ import sys
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from six import integer_types
|
||||
import six
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import meta_graph as meta_graph_lib
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -182,14 +184,25 @@ def _show_defined_functions(saved_model_dir):
|
||||
functions = sorted(functions.items(), key=lambda x: x[0])
|
||||
for name, function in functions:
|
||||
print('\n Function Name: \'%s\'' % name)
|
||||
concrete_functions = \
|
||||
function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access
|
||||
concrete_functions = []
|
||||
if isinstance(function, defun.ConcreteFunction):
|
||||
concrete_functions.append(function)
|
||||
if isinstance(function, def_function.Function):
|
||||
concrete_functions.extend(
|
||||
function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access
|
||||
concrete_functions = sorted(concrete_functions, key=lambda x: x.name)
|
||||
for index, concrete_function in enumerate(concrete_functions, 1):
|
||||
args, kwargs = concrete_function.structured_input_signature
|
||||
print(' Option #%d' % index)
|
||||
print(' Callable with:')
|
||||
_print_args(args, indent=4)
|
||||
args, kwargs = None, None
|
||||
if concrete_function.structured_input_signature:
|
||||
args, kwargs = concrete_function.structured_input_signature
|
||||
elif concrete_function._arg_keywords: # pylint: disable=protected-access
|
||||
# For pure ConcreteFunctions we might have nothing better than
|
||||
# _arg_keywords.
|
||||
args = concrete_function._arg_keywords # pylint: disable=protected-access
|
||||
if args:
|
||||
print(' Option #%d' % index)
|
||||
print(' Callable with:')
|
||||
_print_args(args, indent=4)
|
||||
if kwargs:
|
||||
_print_args(kwargs, 'Named Argument', indent=4)
|
||||
|
||||
@ -215,7 +228,9 @@ def _print_args(arguments, argument_type='Argument', indent=0):
|
||||
for index, element in enumerate(arguments, 1):
|
||||
if indent == 4:
|
||||
in_print('%s #%d' % (argument_type, index))
|
||||
if isinstance(element, tensor_spec.TensorSpec):
|
||||
if isinstance(element, six.string_types):
|
||||
in_print(' %s' % element)
|
||||
elif isinstance(element, tensor_spec.TensorSpec):
|
||||
print((indent + 1) * ' ' + '%s: %s' % (element.name, repr(element)))
|
||||
elif (isinstance(element, collections.Iterable) and
|
||||
not isinstance(element, dict)):
|
||||
@ -567,7 +582,7 @@ def _create_example_string(example_dict):
|
||||
elif isinstance(feature_list[0], str):
|
||||
example.features.feature[feature_name].bytes_list.value.extend(
|
||||
feature_list)
|
||||
elif isinstance(feature_list[0], integer_types):
|
||||
elif isinstance(feature_list[0], six.integer_types):
|
||||
example.features.feature[feature_name].int64_list.value.extend(
|
||||
feature_list)
|
||||
else:
|
||||
|
@ -148,7 +148,7 @@ signature_def['serving_default']:
|
||||
self.assertMultiLineEqual(output, exp_out)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testShowAllWithConcreteFunctions(self):
|
||||
def testShowAllWithFunctions(self):
|
||||
|
||||
class DummyModel(tracking.AutoTrackable):
|
||||
"""Model with callable polymorphic functions specified."""
|
||||
@ -237,6 +237,73 @@ Defined Functions:
|
||||
self.assertMultiLineEqual(output, exp_out)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testShowAllWithPureConcreteFunction(self):
|
||||
|
||||
class DummyModel(tracking.AutoTrackable):
|
||||
"""Model with a callable concrete function."""
|
||||
|
||||
def __init__(self):
|
||||
function = def_function.function(
|
||||
self.multiply,
|
||||
input_signature=[
|
||||
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
|
||||
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
|
||||
])
|
||||
self.pure_concrete_function = function.get_concrete_function()
|
||||
super(DummyModel, self).__init__()
|
||||
|
||||
def multiply(self, a, b):
|
||||
return a * b
|
||||
|
||||
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
|
||||
dummy_model = DummyModel()
|
||||
save.save(dummy_model, saved_model_dir)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
args = self.parser.parse_args(['show', '--dir', saved_model_dir, '--all'])
|
||||
with captured_output() as (out, err):
|
||||
saved_model_cli.show(args)
|
||||
output = out.getvalue().strip()
|
||||
exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
|
||||
|
||||
signature_def['__saved_model_init_op']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['__saved_model_init_op'] tensor_info:
|
||||
dtype: DT_INVALID
|
||||
shape: unknown_rank
|
||||
name: NoOp
|
||||
Method name is:
|
||||
|
||||
signature_def['serving_default']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['a'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: ()
|
||||
name: serving_default_a:0
|
||||
inputs['b'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: ()
|
||||
name: serving_default_b:0
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['output_0'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: ()
|
||||
name: PartitionedCall:0
|
||||
Method name is: tensorflow/serving/predict
|
||||
|
||||
Defined Functions:
|
||||
Function Name: 'pure_concrete_function'
|
||||
Option #1
|
||||
Callable with:
|
||||
Argument #1
|
||||
a: TensorSpec(shape=(), dtype=tf.float32, name='a')
|
||||
Argument #2
|
||||
b: TensorSpec(shape=(), dtype=tf.float32, name='b')
|
||||
""".strip() # pylint: enable=line-too-long
|
||||
self.maxDiff = None # Produce a useful error msg if the comparison fails
|
||||
self.assertMultiLineEqual(output, exp_out)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testShowCommandTags(self):
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
|
Loading…
x
Reference in New Issue
Block a user