Add SavedModel CLI(Command-line Interface) tool.
Change: 153652492
This commit is contained in:
parent
bcb77007ad
commit
82da113f05
tensorflow
@ -17,6 +17,7 @@ py_library(
|
||||
":inspect_checkpoint",
|
||||
":optimize_for_inference",
|
||||
":print_selective_registration_header",
|
||||
":saved_model_cli",
|
||||
":strip_unused",
|
||||
],
|
||||
)
|
||||
@ -197,6 +198,28 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "saved_model_cli",
|
||||
srcs = ["saved_model_cli.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/saved_model:saved_model_py",
|
||||
"//tensorflow/python",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "saved_model_cli_test",
|
||||
srcs = ["saved_model_cli_test.py"],
|
||||
data = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saved_model_cli",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
635
tensorflow/python/tools/saved_model_cli.py
Normal file
635
tensorflow/python/tools/saved_model_cli.py
Normal file
@ -0,0 +1,635 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Command-line interface to inspect and execute a graph in a SavedModel.
|
||||
|
||||
If TensorFlow is installed on your system through pip, the 'saved_model_cli'
|
||||
binary can be invoked directly from command line.
|
||||
|
||||
At a high level, SavedModel CLI allows users to both inspect and execute
|
||||
computations on a MetaGraphDef in a SavedModel. These are done through `show`
|
||||
and `run` commands. Following is the usage of the two commands. SavedModel
|
||||
CLI will also display these information with -h option.
|
||||
|
||||
'show' command usage: saved_model_cli show [-h] --dir DIR [--tag_set TAG_SET]
|
||||
[--signature_def SIGNATURE_DEF_KEY]
|
||||
Examples:
|
||||
To show all available tag-sets in the SavedModel:
|
||||
$saved_model_cli show --dir /tmp/saved_model
|
||||
|
||||
To show all available SignatureDef keys in a MetaGraphDef specified by its
|
||||
tag-set:
|
||||
$saved_model_cli show --dir /tmp/saved_model --tag_set serve
|
||||
For a MetaGraphDef with multiple tags in the tag-set, all tags must be passed
|
||||
in, separated by ',':
|
||||
$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu
|
||||
|
||||
To show all inputs and outputs TensorInfo for a specific SignatureDef specified
|
||||
by the SignatureDef key in a MetaGraphDef:
|
||||
$saved_model_cli show --dir /tmp/saved_model --tag_set serve
|
||||
--signature_def serving_default
|
||||
Example output:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['input0'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
inputs['input1'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['output'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/regress
|
||||
|
||||
To show all available information in the SavedModel:
|
||||
$saved_model_cli show --dir /tmp/saved_model --all
|
||||
|
||||
'run' command usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET
|
||||
--signature_def SIGNATURE_DEF_KEY --inputs INPUTS
|
||||
[--outdir OUTDIR] [--overwrite]
|
||||
Examples:
|
||||
To run input tensors from files through a MetaGraphDef and save the output
|
||||
tensors to files:
|
||||
$saved_model_cli run --dir /tmp/saved_model --tag_set serve
|
||||
--signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
|
||||
--outdir /tmp/out
|
||||
|
||||
To build this tool from source, run:
|
||||
$bazel build tensorflow/python/tools:saved_model_cli
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.saved_model.python.saved_model import reader
|
||||
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.saved_model import loader
|
||||
|
||||
|
||||
def _show_tag_sets(saved_model_dir):
|
||||
"""Prints the tag-sets stored in SavedModel directory.
|
||||
|
||||
Prints all the tag-sets for MetaGraphs stored in SavedModel directory.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
"""
|
||||
tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
|
||||
print('The given SavedModel contains the following tag-sets:')
|
||||
for tag_set in sorted(tag_sets):
|
||||
print(', '.join(sorted(tag_set)))
|
||||
|
||||
|
||||
def _show_signature_def_map_keys(saved_model_dir, tag_set):
|
||||
"""Prints the keys for each SignatureDef in the SignatureDef map.
|
||||
|
||||
Prints the list of SignatureDef keys from the SignatureDef map specified by
|
||||
the given tag-set and SavedModel directory.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
tag_set: Group of tag(s) of the MetaGraphDef to get SignatureDef map from,
|
||||
in string format, separated by ','. For tag-set contains multiple tags,
|
||||
all tags must be passed in.
|
||||
"""
|
||||
signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
|
||||
print('The given SavedModel MetaGraphDef contains SignatureDefs with the '
|
||||
'following keys:')
|
||||
for signature_def_key in sorted(signature_def_map.keys()):
|
||||
print('SignatureDef key: \"%s\"' % signature_def_key)
|
||||
|
||||
|
||||
def _get_inputs_tensor_info_from_meta_graph_def(meta_graph_def,
|
||||
signature_def_key):
|
||||
"""Gets TensorInfo for all inputs of the SignatureDef.
|
||||
|
||||
Returns a dictionary that maps each input key to its TensorInfo for the given
|
||||
signature_def_key in the meta_graph_def
|
||||
|
||||
Args:
|
||||
meta_graph_def: MetaGraphDef protocol buffer with the SignatureDef map to
|
||||
look up SignatureDef key.
|
||||
signature_def_key: A SignatureDef key string.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps input tensor keys to TensorInfos.
|
||||
"""
|
||||
return signature_def_utils.get_signature_def_by_key(meta_graph_def,
|
||||
signature_def_key).inputs
|
||||
|
||||
|
||||
def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
|
||||
signature_def_key):
|
||||
"""Gets TensorInfos for all outputs of the SignatureDef.
|
||||
|
||||
Returns a dictionary that maps each output key to its TensorInfo for the given
|
||||
signature_def_key in the meta_graph_def.
|
||||
|
||||
Args:
|
||||
meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefmap to
|
||||
look up signature_def_key.
|
||||
signature_def_key: A SignatureDef key string.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps output tensor keys to TensorInfos.
|
||||
"""
|
||||
return signature_def_utils.get_signature_def_by_key(meta_graph_def,
|
||||
signature_def_key).outputs
|
||||
|
||||
|
||||
def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key):
|
||||
"""Prints input and output TensorInfos.
|
||||
|
||||
Prints the details of input and output TensorInfos for the SignatureDef mapped
|
||||
by the given signature_def_key.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated by
|
||||
','. For tag-set contains multiple tags, all tags must be passed in.
|
||||
signature_def_key: A SignatureDef key string.
|
||||
"""
|
||||
meta_graph_def = get_meta_graph_def(saved_model_dir, tag_set)
|
||||
inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
|
||||
meta_graph_def, signature_def_key)
|
||||
outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
|
||||
meta_graph_def, signature_def_key)
|
||||
|
||||
print('The given SavedModel SignatureDef contains the following input(s):')
|
||||
for input_key, input_tensor in sorted(inputs_tensor_info.items()):
|
||||
print('inputs[\'%s\'] tensor_info:' % input_key)
|
||||
_print_tensor_info(input_tensor)
|
||||
|
||||
print('The given SavedModel SignatureDef contains the following output(s):')
|
||||
for output_key, output_tensor in sorted(outputs_tensor_info.items()):
|
||||
print('outputs[\'%s\'] tensor_info:' % output_key)
|
||||
_print_tensor_info(output_tensor)
|
||||
|
||||
print('Method name is: %s' %
|
||||
meta_graph_def.signature_def[signature_def_key].method_name)
|
||||
|
||||
|
||||
def _print_tensor_info(tensor_info):
|
||||
"""Prints details of the given tensor_info.
|
||||
|
||||
Args:
|
||||
tensor_info: TensorInfo object to be printed.
|
||||
"""
|
||||
print(' dtype: ' + types_pb2.DataType.keys()[tensor_info.dtype])
|
||||
# Display shape as tuple.
|
||||
if tensor_info.tensor_shape.unknown_rank:
|
||||
shape = 'unknown_rank'
|
||||
else:
|
||||
dims = [str(dim.size) for dim in tensor_info.tensor_shape.dim]
|
||||
shape = ', '.join(dims)
|
||||
shape = '(' + shape + ')'
|
||||
print(' shape: ' + shape)
|
||||
|
||||
|
||||
def _show_all(saved_model_dir):
|
||||
"""Prints tag-set, SignatureDef and Inputs/Outputs information in SavedModel.
|
||||
|
||||
Prints all tag-set, SignatureDef and Inputs/Outputs information stored in
|
||||
SavedModel directory.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
"""
|
||||
tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
|
||||
for tag_set in sorted(tag_sets):
|
||||
tag_set = ', '.join(tag_set)
|
||||
print('\nMetaGraphDef with tag-set: \'' + tag_set +
|
||||
'\' contains the following SignatureDefs:')
|
||||
|
||||
signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
|
||||
for signature_def_key in sorted(signature_def_map.keys()):
|
||||
print('\nsignature_def[\'' + signature_def_key + '\']:')
|
||||
_show_inputs_outputs(saved_model_dir, tag_set, signature_def_key)
|
||||
|
||||
|
||||
def get_meta_graph_def(saved_model_dir, tag_set):
|
||||
"""Gets MetaGraphDef from SavedModel.
|
||||
|
||||
Returns the MetaGraphDef for the given tag-set and SavedModel directory.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect or execute.
|
||||
tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
|
||||
separated by ','. For tag-set contains multiple tags, all tags must be
|
||||
passed in.
|
||||
|
||||
Raises:
|
||||
RuntimeError: An error when the given tag-set does not exist in the
|
||||
SavedModel.
|
||||
|
||||
Returns:
|
||||
A MetaGraphDef corresponding to the tag-set.
|
||||
"""
|
||||
saved_model = reader.read_saved_model(saved_model_dir)
|
||||
set_of_tags = set(tag_set.split(','))
|
||||
for meta_graph_def in saved_model.meta_graphs:
|
||||
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
|
||||
return meta_graph_def
|
||||
|
||||
raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
|
||||
' could not be found in SavedModel')
|
||||
|
||||
|
||||
def get_signature_def_map(saved_model_dir, tag_set):
|
||||
"""Gets SignatureDef map from a MetaGraphDef in a SavedModel.
|
||||
|
||||
Returns the SignatureDef map for the given tag-set in the SavedModel
|
||||
directory.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect or execute.
|
||||
tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
|
||||
string format, separated by ','. For tag-set contains multiple tags, all
|
||||
tags must be passed in.
|
||||
|
||||
Returns:
|
||||
A SignatureDef map that maps from string keys to SignatureDefs.
|
||||
"""
|
||||
meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
|
||||
return meta_graph.signature_def
|
||||
|
||||
|
||||
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
|
||||
input_tensor_key_feed_dict, outdir,
|
||||
overwrite_flag):
|
||||
"""Runs SavedModel and fetch all outputs.
|
||||
|
||||
Runs the input dictionary through the MetaGraphDef within a SavedModel
|
||||
specified by the given tag_set and SignatureDef. Also save the outputs to file
|
||||
if outdir is not None.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to execute.
|
||||
tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
|
||||
string format, separated by ','. For tag-set contains multiple tags, all
|
||||
tags must be passed in.
|
||||
signature_def_key: A SignatureDef key string.
|
||||
input_tensor_key_feed_dict: A dictionary maps input keys to numpy ndarrays.
|
||||
outdir: A directory to save the outputs to. If the directory doesn't exist,
|
||||
it will be created.
|
||||
overwrite_flag: A boolean flag to allow overwrite output file if file with
|
||||
the same name exists.
|
||||
|
||||
Raises:
|
||||
RuntimeError: An error when output file already exists and overwrite is not
|
||||
enabled.
|
||||
"""
|
||||
# Get a list of output tensor names.
|
||||
meta_graph_def = get_meta_graph_def(saved_model_dir, tag_set)
|
||||
|
||||
# Re-create feed_dict based on input tensor name instead of key as session.run
|
||||
# uses tensor name.
|
||||
inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
|
||||
meta_graph_def, signature_def_key)
|
||||
inputs_feed_dict = {
|
||||
inputs_tensor_info[key].name: tensor
|
||||
for key, tensor in input_tensor_key_feed_dict.items()
|
||||
}
|
||||
# Get outputs
|
||||
outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
|
||||
meta_graph_def, signature_def_key)
|
||||
# Sort to preserve order because we need to go from value to key later.
|
||||
output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
|
||||
output_tensor_names_sorted = [
|
||||
outputs_tensor_info[tensor_key].name
|
||||
for tensor_key in output_tensor_keys_sorted
|
||||
]
|
||||
|
||||
with session.Session(graph=ops_lib.Graph()) as sess:
|
||||
loader.load(sess, tag_set.split(','), saved_model_dir)
|
||||
|
||||
outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
output_tensor_key = output_tensor_keys_sorted[i]
|
||||
print('Result for output key %s:\n%s' % (output_tensor_key, output))
|
||||
|
||||
# Only save if outdir is specified.
|
||||
if outdir:
|
||||
# Create directory if outdir does not exist
|
||||
if not os.path.isdir(outdir):
|
||||
os.makedirs(outdir)
|
||||
output_full_path = os.path.join(outdir, output_tensor_key + '.npy')
|
||||
|
||||
# If overwrite not enabled and file already exist, error out
|
||||
if not overwrite_flag and os.path.exists(output_full_path):
|
||||
raise RuntimeError(
|
||||
'Output file %s already exists. Add \"--overwrite\" to overwrite'
|
||||
' the existing output files.' % output_full_path)
|
||||
|
||||
np.save(output_full_path, output)
|
||||
print('Output %s is saved to %s' % (output_tensor_key,
|
||||
output_full_path))
|
||||
|
||||
|
||||
def preprocess_input_arg_string(inputs_str):
|
||||
"""Parses input arg into dictionary that maps input to file/variable tuple.
|
||||
|
||||
Parses input string in the format of, for example,
|
||||
"input1=filename1[variable_name1],input2=filename2" into a
|
||||
dictionary looks like
|
||||
{'input_key1': (filename1, variable_name1),
|
||||
'input_key2': (file2, None)}
|
||||
, which maps input keys to a tuple of file name and varaible name(None if
|
||||
empty).
|
||||
|
||||
Args:
|
||||
inputs_str: A string that specified where to load inputs. Each input is
|
||||
separated by comma.
|
||||
* If the command line arg for inputs is quoted and contains
|
||||
whitespace(s), all whitespaces will be ignored.
|
||||
* For each input key:
|
||||
'input=filename<[variable_name]>'
|
||||
* The "[variable_name]" key is optional. Will be set to None if not
|
||||
specified.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps input keys to a tuple of file name and varaible name.
|
||||
|
||||
Raises:
|
||||
RuntimeError: An error when the given input is in a bad format.
|
||||
"""
|
||||
input_dict = {}
|
||||
inputs_raw = inputs_str.split(',')
|
||||
for input_raw in filter(bool, inputs_raw): # skip empty strings
|
||||
# Remove quotes and whitespaces
|
||||
input_raw = input_raw.replace('"', '').replace('\'', '').replace(' ', '')
|
||||
|
||||
# Format of input=filename[variable_name]'
|
||||
match = re.match(r'^([\w\-]+)=([\w\-.\/]+)\[([\w\-]+)\]$', input_raw)
|
||||
if match:
|
||||
input_dict[match.group(1)] = (match.group(2), match.group(3))
|
||||
else:
|
||||
# Format of input=filename'
|
||||
match = re.match(r'^([\w\-]+)=([\w\-.\/]+)$', input_raw)
|
||||
if match:
|
||||
input_dict[match.group(1)] = (match.group(2), None)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Input \"%s\" format is incorrect. Please follow \"--inputs '
|
||||
'input_key=file_name[variable_name]\" or input_key=file_name' %
|
||||
input_raw)
|
||||
|
||||
return input_dict
|
||||
|
||||
|
||||
def load_inputs_from_input_arg_string(inputs_str):
|
||||
"""Parses input arg string and load inputs into a dictionary.
|
||||
|
||||
Parses input string in the format of, for example,
|
||||
"input1=filename1[variable_name1],input2=filename2" into a
|
||||
dictionary looks like
|
||||
{'input1:0': ndarray_saved_as_variable_name1_in_filename1 ,
|
||||
'input2:0': ndarray_saved_in_filename2}
|
||||
, which maps input keys to a numpy ndarray loaded from file. See Args section
|
||||
for more details on inputs format.
|
||||
|
||||
Args:
|
||||
inputs_str: A string that specified where to load inputs. Each input is
|
||||
separated by comma.
|
||||
* If the command line arg for inputs is quoted and contains
|
||||
whitespace(s), all whitespaces will be ignored.
|
||||
* For each input key:
|
||||
'input=filename[variable_name]'
|
||||
* File specified by 'filename' will be loaded using numpy.load. Inputs
|
||||
can be loaded from only .npy, .npz or pickle files.
|
||||
* The "[variable_name]" key is optional depending on the input file type
|
||||
as descripted in more details below.
|
||||
When loading from a npy file, which always contains a numpy ndarray, the
|
||||
content will be directly assigned to the specified input tensor. If a
|
||||
varaible_name is specified, it will be ignored and a warning will be
|
||||
issued.
|
||||
When loading from a npz zip file, user can specify which variable within
|
||||
the zip file to load for the input tensor inside the square brackets. If
|
||||
nothing is specified, this function will check that only one file is
|
||||
included in the zip and load it for the specified input tensor.
|
||||
When loading from a pickle file, if no variable_name is specified in the
|
||||
square brackets, whatever that is inside the pickle file will be passed
|
||||
to the specified input tensor, else SavedModel CLI will assume a
|
||||
dictionary is stored in the pickle file and the value corresponding to
|
||||
the variable_name will be used.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps input tensor keys to a numpy ndarray loaded from
|
||||
file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: An error when a key is specified, but the input file contains
|
||||
multiple numpy ndarrays, none of which matches the given key.
|
||||
RuntimeError: An error when no key is specified, but the input file contains
|
||||
more than one numpy ndarrays.
|
||||
"""
|
||||
tensor_key_feed_dict = {}
|
||||
|
||||
for input_tensor_key, (
|
||||
filename,
|
||||
variable_name) in preprocess_input_arg_string(inputs_str).items():
|
||||
# When a variable_name key is specified for the input file
|
||||
if variable_name:
|
||||
data = np.load(filename)
|
||||
|
||||
# if file contains a single ndarray, ignore the input name
|
||||
if isinstance(data, np.ndarray):
|
||||
warnings.warn(
|
||||
'Input file %s contains a single ndarray. Name key \"%s\" ignored.'
|
||||
% (filename, variable_name))
|
||||
tensor_key_feed_dict[input_tensor_key] = data
|
||||
else:
|
||||
if variable_name in data:
|
||||
tensor_key_feed_dict[input_tensor_key] = data[variable_name]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Input file %s does not contain variable with name \"%s\".' %
|
||||
(filename, variable_name))
|
||||
# When no key is specified for the input file.
|
||||
else:
|
||||
data = np.load(filename)
|
||||
# Check if npz file only contains a single numpy ndarray.
|
||||
if isinstance(data, np.lib.npyio.NpzFile):
|
||||
variable_name_list = data.files
|
||||
if len(variable_name_list) != 1:
|
||||
raise RuntimeError(
|
||||
'Input file %s contains more than one ndarrays. Please specify '
|
||||
'the name of ndarray to use.' % filename)
|
||||
tensor_key_feed_dict[input_tensor_key] = data[variable_name_list[0]]
|
||||
else:
|
||||
tensor_key_feed_dict[input_tensor_key] = data
|
||||
|
||||
return tensor_key_feed_dict
|
||||
|
||||
|
||||
def show(args):
|
||||
"""Function triggered by show command.
|
||||
|
||||
Args:
|
||||
args: A namespace parsed from command line.
|
||||
"""
|
||||
# If all tag is specified, display all information.
|
||||
if args.all:
|
||||
_show_all(args.dir)
|
||||
else:
|
||||
# If no tag is specified, display all tag_set, if no signaure_def key is
|
||||
# specified, display all SignatureDef keys, else show input output tensor
|
||||
# infomation corresponding to the given SignatureDef key
|
||||
if args.tag_set is None:
|
||||
_show_tag_sets(args.dir)
|
||||
else:
|
||||
if args.signature_def is None:
|
||||
_show_signature_def_map_keys(args.dir, args.tag_set)
|
||||
else:
|
||||
_show_inputs_outputs(args.dir, args.tag_set, args.signature_def)
|
||||
|
||||
|
||||
def run(args):
|
||||
"""Function triggered by run command.
|
||||
|
||||
Args:
|
||||
args: A namespace parsed from command line.
|
||||
"""
|
||||
tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs)
|
||||
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
|
||||
tensor_key_feed_dict, args.outdir,
|
||||
args.overwrite)
|
||||
|
||||
|
||||
def create_parser():
|
||||
"""Creates a parser that parse the command line arguments.
|
||||
|
||||
Returns:
|
||||
A namespace parsed from command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='saved_model_cli: Command-line interface for SavedModel')
|
||||
parser.add_argument('-v', '--version', action='version', version='0.1.0')
|
||||
|
||||
subparsers = parser.add_subparsers(
|
||||
title='commands', description='valid commands', help='additional help')
|
||||
|
||||
# show command
|
||||
show_msg = (
|
||||
'Usage examples:\n'
|
||||
'To show all tag-sets in a SavedModel:\n'
|
||||
'$saved_model_cli show --dir /tmp/saved_model\n'
|
||||
'To show all available SignatureDef keys in a '
|
||||
'MetaGraphDef specified by its tag-set:\n'
|
||||
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n'
|
||||
'For a MetaGraphDef with multiple tags in the tag-set, all tags must be '
|
||||
'passed in, separated by \',\':\n'
|
||||
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n'
|
||||
'To show all inputs and outputs TensorInfo for a specific'
|
||||
' SignatureDef specified by the SignatureDef key in a'
|
||||
' MetaGraph.\n'
|
||||
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve'
|
||||
'--signature_def serving_default\n\n'
|
||||
'To show all available information in the SavedModel\n:'
|
||||
'$saved_model_cli show --dir /tmp/saved_model --all')
|
||||
parser_show = subparsers.add_parser(
|
||||
'show',
|
||||
description=show_msg,
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser_show.add_argument(
|
||||
'--dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='directory containing the SavedModel to inspect')
|
||||
parser_show.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='if set, will output all infomation in given SavedModel')
|
||||
parser_show.add_argument(
|
||||
'--tag_set',
|
||||
type=str,
|
||||
default=None,
|
||||
help='tag-set of graph in SavedModel to show, separated by \',\'')
|
||||
parser_show.add_argument(
|
||||
'--signature_def',
|
||||
type=str,
|
||||
default=None,
|
||||
metavar='SIGNATURE_DEF_KEY',
|
||||
help='key of SignatureDef to display input(s) and output(s) for')
|
||||
parser_show.set_defaults(func=show)
|
||||
|
||||
# run command
|
||||
run_msg = ('Usage example:\n'
|
||||
'To run input tensors from files through a MetaGraphDef and save'
|
||||
' the output tensors to files:\n'
|
||||
'$saved_model_cli show --dir /tmp/saved_model --tag_set serve'
|
||||
'--signature_def serving_default '
|
||||
'--inputs x1=/tmp/124.npz[x],x2=/tmp/123.npy'
|
||||
'--outdir=/out\n\n'
|
||||
'For more information about input file format, please see:\n')
|
||||
parser_run = subparsers.add_parser(
|
||||
'run', description=run_msg, formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser_run.add_argument(
|
||||
'--dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='directory containing the SavedModel to execute')
|
||||
parser_run.add_argument(
|
||||
'--tag_set',
|
||||
type=str,
|
||||
required=True,
|
||||
help='tag-set of graph in SavedModel to load, separated by \',\'')
|
||||
parser_run.add_argument(
|
||||
'--signature_def',
|
||||
type=str,
|
||||
required=True,
|
||||
metavar='SIGNATURE_DEF_KEY',
|
||||
help='key of SignatureDef to run')
|
||||
msg = ('inputs in the format of \'input_key=filename[variable_name]\', '
|
||||
'separated by \',\'. Inputs can only be loaded from .npy, .npz or '
|
||||
'pickle files.')
|
||||
parser_run.add_argument('--inputs', type=str, required=True, help=msg)
|
||||
parser_run.add_argument(
|
||||
'--outdir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='if specified, output tensor(s) will be saved to given directory')
|
||||
parser_run.add_argument(
|
||||
'--overwrite',
|
||||
action='store_true',
|
||||
help='if set, output file will be overwritten if it already exists.')
|
||||
parser_run.set_defaults(func=run)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
367
tensorflow/python/tools/saved_model_cli_test.py
Normal file
367
tensorflow/python/tools/saved_model_cli_test.py
Normal file
@ -0,0 +1,367 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for SavedModelCLI tool.
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from six import StringIO
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tools import saved_model_cli
|
||||
|
||||
SAVED_MODEL_PATH = ('cc/saved_model/testdata/half_plus_two/00000123')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def captured_output():
|
||||
new_out, new_err = StringIO(), StringIO()
|
||||
old_out, old_err = sys.stdout, sys.stderr
|
||||
try:
|
||||
sys.stdout, sys.stderr = new_out, new_err
|
||||
yield sys.stdout, sys.stderr
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
|
||||
class SavedModelCLITestCase(test.TestCase):
|
||||
|
||||
def testShowCommandAll(self):
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
args = self.parser.parse_args(['show', '--dir', base_path, '--all'])
|
||||
with captured_output() as (out, err):
|
||||
saved_model_cli.show(args)
|
||||
output = out.getvalue().strip()
|
||||
# pylint: disable=line-too-long
|
||||
exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
|
||||
|
||||
signature_def['classify_x2_to_y3']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['inputs'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['scores'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/classify
|
||||
|
||||
signature_def['classify_x_to_y']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['inputs'] tensor_info:
|
||||
dtype: DT_STRING
|
||||
shape: unknown_rank
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['scores'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/classify
|
||||
|
||||
signature_def['regress_x2_to_y3']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['inputs'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['outputs'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/regress
|
||||
|
||||
signature_def['regress_x_to_y']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['inputs'] tensor_info:
|
||||
dtype: DT_STRING
|
||||
shape: unknown_rank
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['outputs'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/regress
|
||||
|
||||
signature_def['regress_x_to_y2']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['inputs'] tensor_info:
|
||||
dtype: DT_STRING
|
||||
shape: unknown_rank
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['outputs'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/regress
|
||||
|
||||
signature_def['serving_default']:
|
||||
The given SavedModel SignatureDef contains the following input(s):
|
||||
inputs['x'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
The given SavedModel SignatureDef contains the following output(s):
|
||||
outputs['y'] tensor_info:
|
||||
dtype: DT_FLOAT
|
||||
shape: (-1, 1)
|
||||
Method name is: tensorflow/serving/predict"""
|
||||
# pylint: enable=line-too-long
|
||||
self.assertMultiLineEqual(output, exp_out)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testShowCommandTags(self):
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
args = self.parser.parse_args(['show', '--dir', base_path])
|
||||
with captured_output() as (out, err):
|
||||
saved_model_cli.show(args)
|
||||
output = out.getvalue().strip()
|
||||
exp_out = 'The given SavedModel contains the following tag-sets:\nserve'
|
||||
self.assertMultiLineEqual(output, exp_out)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testShowCommandSignature(self):
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
args = self.parser.parse_args(
|
||||
['show', '--dir', base_path, '--tag_set', 'serve'])
|
||||
with captured_output() as (out, err):
|
||||
saved_model_cli.show(args)
|
||||
output = out.getvalue().strip()
|
||||
exp_header = ('The given SavedModel MetaGraphDef contains SignatureDefs '
|
||||
'with the following keys:')
|
||||
exp_start = 'SignatureDef key: '
|
||||
exp_keys = [
|
||||
'"classify_x2_to_y3"', '"classify_x_to_y"', '"regress_x2_to_y3"',
|
||||
'"regress_x_to_y"', '"regress_x_to_y2"', '"serving_default"'
|
||||
]
|
||||
# Order of signatures does not matter
|
||||
self.assertMultiLineEqual(
|
||||
output,
|
||||
'\n'.join([exp_header] + [exp_start + exp_key for exp_key in exp_keys]))
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testShowCommandErrorNoTagSet(self):
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
args = self.parser.parse_args(
|
||||
['show', '--dir', base_path, '--tag_set', 'badtagset'])
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.show(args)
|
||||
|
||||
def testShowCommandInputsOutputs(self):
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
args = self.parser.parse_args([
|
||||
'show', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
|
||||
'serving_default'
|
||||
])
|
||||
with captured_output() as (out, err):
|
||||
saved_model_cli.show(args)
|
||||
output = out.getvalue().strip()
|
||||
expected_output = (
|
||||
'The given SavedModel SignatureDef contains the following input(s):\n'
|
||||
'inputs[\'x\'] tensor_info:\n'
|
||||
' dtype: DT_FLOAT\n shape: (-1, 1)\n'
|
||||
'The given SavedModel SignatureDef contains the following output(s):\n'
|
||||
'outputs[\'y\'] tensor_info:\n'
|
||||
' dtype: DT_FLOAT\n shape: (-1, 1)\n'
|
||||
'Method name is: tensorflow/serving/predict')
|
||||
self.assertEqual(output, expected_output)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
def testInputPreProcessFormats(self):
|
||||
input_str = 'input1=/path/file.txt[ab3], input2=file2,,'
|
||||
input_dict = saved_model_cli.preprocess_input_arg_string(input_str)
|
||||
self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3'))
|
||||
self.assertTrue(input_dict['input2'] == ('file2', None))
|
||||
|
||||
def testInputPreProcessQuoteAndWhitespace(self):
|
||||
input_str = '\' input1 = file[v_1]\', input2=file ["sd"] '
|
||||
input_dict = saved_model_cli.preprocess_input_arg_string(input_str)
|
||||
self.assertTrue(input_dict['input1'] == ('file', 'v_1'))
|
||||
self.assertTrue(input_dict['input2'] == ('file', 'sd'))
|
||||
self.assertTrue(len(input_dict) == 2)
|
||||
|
||||
def testInputPreProcessErrorBadFormat(self):
|
||||
input_str = 'inputx=file[[v1]v2'
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.preprocess_input_arg_string(input_str)
|
||||
input_str = 'inputx:file'
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.preprocess_input_arg_string(input_str)
|
||||
input_str = 'inputx=file(v_1)'
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.preprocess_input_arg_string(input_str)
|
||||
|
||||
def testInputParserNPY(self):
|
||||
x0 = np.array([[1], [2]])
|
||||
x1 = np.array(range(6)).reshape(2, 3)
|
||||
input0_path = os.path.join(test.get_temp_dir(), 'input0.npy')
|
||||
input1_path = os.path.join(test.get_temp_dir(), 'input1.npy')
|
||||
np.save(input0_path, x0)
|
||||
np.save(input1_path, x1)
|
||||
input_str = 'x0=' + input0_path + '[x0],x1=' + input1_path
|
||||
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
|
||||
self.assertTrue(np.all(feed_dict['x0'] == x0))
|
||||
self.assertTrue(np.all(feed_dict['x1'] == x1))
|
||||
|
||||
def testInputParserNPZ(self):
|
||||
x0 = np.array([[1], [2]])
|
||||
input_path = os.path.join(test.get_temp_dir(), 'input.npz')
|
||||
np.savez(input_path, a=x0)
|
||||
input_str = 'x=' + input_path + '[a],y=' + input_path
|
||||
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
|
||||
self.assertTrue(np.all(feed_dict['x'] == x0))
|
||||
self.assertTrue(np.all(feed_dict['y'] == x0))
|
||||
|
||||
def testInputParserPickle(self):
|
||||
pkl0 = {'a': 5, 'b': np.array(range(4))}
|
||||
pkl1 = np.array([1])
|
||||
pkl2 = np.array([[1], [3]])
|
||||
input_path0 = os.path.join(test.get_temp_dir(), 'pickle0.pkl')
|
||||
input_path1 = os.path.join(test.get_temp_dir(), 'pickle1.pkl')
|
||||
input_path2 = os.path.join(test.get_temp_dir(), 'pickle2.pkl')
|
||||
with open(input_path0, 'wb') as f:
|
||||
pickle.dump(pkl0, f)
|
||||
with open(input_path1, 'wb') as f:
|
||||
pickle.dump(pkl1, f)
|
||||
with open(input_path2, 'wb') as f:
|
||||
pickle.dump(pkl2, f)
|
||||
input_str = 'x=' + input_path0 + '[b],y=' + input_path1 + '[c],'
|
||||
input_str += 'z=' + input_path2
|
||||
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
|
||||
self.assertTrue(np.all(feed_dict['x'] == pkl0['b']))
|
||||
self.assertTrue(np.all(feed_dict['y'] == pkl1))
|
||||
self.assertTrue(np.all(feed_dict['z'] == pkl2))
|
||||
|
||||
def testInputParserQuoteAndWhitespace(self):
|
||||
x0 = np.array([[1], [2]])
|
||||
x1 = np.array(range(6)).reshape(2, 3)
|
||||
input0_path = os.path.join(test.get_temp_dir(), 'input0.npy')
|
||||
input1_path = os.path.join(test.get_temp_dir(), 'input1.npy')
|
||||
np.save(input0_path, x0)
|
||||
np.save(input1_path, x1)
|
||||
input_str = '"x0=' + input0_path + '[x0] , x1 = ' + input1_path + '"'
|
||||
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str)
|
||||
self.assertTrue(np.all(feed_dict['x0'] == x0))
|
||||
self.assertTrue(np.all(feed_dict['x1'] == x1))
|
||||
|
||||
def testInputParserErrorNoName(self):
|
||||
x0 = np.array([[1], [2]])
|
||||
x1 = np.array(range(5))
|
||||
input_path = os.path.join(test.get_temp_dir(), 'input.npz')
|
||||
np.savez(input_path, a=x0, b=x1)
|
||||
input_str = 'x=' + input_path
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.load_inputs_from_input_arg_string(input_str)
|
||||
|
||||
def testInputParserErrorWrongName(self):
|
||||
x0 = np.array([[1], [2]])
|
||||
x1 = np.array(range(5))
|
||||
input_path = os.path.join(test.get_temp_dir(), 'input.npz')
|
||||
np.savez(input_path, a=x0, b=x1)
|
||||
input_str = 'x=' + input_path + '[c]'
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.load_inputs_from_input_arg_string(input_str)
|
||||
|
||||
def testRunCommandExistingOutdir(self):
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
x = np.array([[1], [2]])
|
||||
x_notused = np.zeros((6, 3))
|
||||
input_path = os.path.join(test.get_temp_dir(), 'testRunCommand_inputs.npz')
|
||||
np.savez(input_path, x0=x, x1=x_notused)
|
||||
output_file = os.path.join(test.get_temp_dir(), 'outputs.npy')
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
args = self.parser.parse_args([
|
||||
'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
|
||||
'regress_x2_to_y3', '--inputs', 'inputs=' + input_path + '[x0]',
|
||||
'--outdir',
|
||||
test.get_temp_dir()
|
||||
])
|
||||
saved_model_cli.run(args)
|
||||
y = np.load(output_file)
|
||||
y_exp = np.array([[3.5], [4.0]])
|
||||
self.assertTrue(np.allclose(y, y_exp))
|
||||
|
||||
def testRunCommandNewOutdir(self):
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
x = np.array([[1], [2]])
|
||||
x_notused = np.zeros((6, 3))
|
||||
input_path = os.path.join(test.get_temp_dir(),
|
||||
'testRunCommandNewOutdir_inputs.npz')
|
||||
output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
|
||||
if os.path.isdir(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
np.savez(input_path, x0=x, x1=x_notused)
|
||||
args = self.parser.parse_args([
|
||||
'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
|
||||
'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
|
||||
output_dir
|
||||
])
|
||||
saved_model_cli.run(args)
|
||||
y = np.load(os.path.join(output_dir, 'y.npy'))
|
||||
y_exp = np.array([[2.5], [3.0]])
|
||||
self.assertTrue(np.allclose(y, y_exp))
|
||||
|
||||
def testRunCommandOutOverwrite(self):
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
x = np.array([[1], [2]])
|
||||
x_notused = np.zeros((6, 3))
|
||||
input_path = os.path.join(test.get_temp_dir(),
|
||||
'testRunCommandOutOverwrite_inputs.npz')
|
||||
np.savez(input_path, x0=x, x1=x_notused)
|
||||
output_file = os.path.join(test.get_temp_dir(), 'y.npy')
|
||||
open(output_file, 'a').close()
|
||||
args = self.parser.parse_args([
|
||||
'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
|
||||
'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
|
||||
test.get_temp_dir(), '--overwrite'
|
||||
])
|
||||
saved_model_cli.run(args)
|
||||
y = np.load(output_file)
|
||||
y_exp = np.array([[2.5], [3.0]])
|
||||
self.assertTrue(np.allclose(y, y_exp))
|
||||
|
||||
def testRunCommandOutputFileExistError(self):
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
x = np.array([[1], [2]])
|
||||
x_notused = np.zeros((6, 3))
|
||||
input_path = os.path.join(test.get_temp_dir(),
|
||||
'testRunCommandOutOverwrite_inputs.npz')
|
||||
np.savez(input_path, x0=x, x1=x_notused)
|
||||
output_file = os.path.join(test.get_temp_dir(), 'y.npy')
|
||||
open(output_file, 'a').close()
|
||||
args = self.parser.parse_args([
|
||||
'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
|
||||
'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
|
||||
test.get_temp_dir()
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.run(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -59,6 +59,7 @@ else:
|
||||
# pylint: disable=line-too-long
|
||||
CONSOLE_SCRIPTS = [
|
||||
'tensorboard = tensorflow.tensorboard.tensorboard:main',
|
||||
'saved_model_cli = tensorflow.python.tools.saved_model_cli:main',
|
||||
]
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user