diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 7ada956124b..c9065b0a898 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -51,6 +51,7 @@ py_library( "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index f70b810ec69..513e657a333 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -50,6 +50,7 @@ from tensorflow.contrib import metrics from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt +from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import resampler from tensorflow.contrib import rnn diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index c4b46551c12..e298fd3cb23 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -1,6 +1,6 @@ # `Predictor` classes provide an interface for efficient, repeated inference. -package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"]) +package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"]) licenses(["notice"]) # Apache 2.0 @@ -62,7 +62,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_predictor", - "//tensorflow/python/tools:saved_model_cli", ], ) diff --git a/tensorflow/contrib/predictor/__init__.py b/tensorflow/contrib/predictor/__init__.py index d270c3f7983..e0a2152b371 100644 --- a/tensorflow/contrib/predictor/__init__.py +++ b/tensorflow/contrib/predictor/__init__.py @@ -19,6 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.predictor import from_contrib_estimator -from tensorflow.contrib.predictor import from_estimator -from tensorflow.contrib.predictor import from_saved_model +from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator +from tensorflow.contrib.predictor.predictor_factories import from_estimator +from tensorflow.contrib.predictor.predictor_factories import from_saved_model diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py index ab2bafa0c86..0dbca0f8136 100644 --- a/tensorflow/contrib/predictor/saved_model_predictor.py +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -22,12 +22,12 @@ from __future__ import print_function import logging from tensorflow.contrib.predictor import predictor +from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import signature_constants -from tensorflow.python.tools import saved_model_cli DEFAULT_TAGS = 'serve' @@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve' _DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' +def get_meta_graph_def(saved_model_dir, tags): + """Gets `MetaGraphDef` from a directory containing a `SavedModel`. + + Returns the `MetaGraphDef` for the given tag-set and SavedModel directory. + + Args: + saved_model_dir: Directory containing the SavedModel. + tags: Comma separated list of tags used to identify the correct + `MetaGraphDef`. + + Raises: + ValueError: An error when the given tags cannot be found. + + Returns: + A `MetaGraphDef` corresponding to the given tags. + """ + saved_model = reader.read_saved_model(saved_model_dir) + set_of_tags = set([tag.strip() for tag in tags.split(',')]) + for meta_graph_def in saved_model.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set_of_tags: + return meta_graph_def + raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags)) + + def _get_signature_def(signature_def_key, export_dir, tags): """Construct a `SignatureDef` proto.""" signature_def_key = ( signature_def_key or signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) - metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) + metagraph_def = get_meta_graph_def(export_dir, tags) try: signature_def = signature_def_utils.get_signature_def_by_key( @@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor): output_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` that represent the output. The keys can be any string of the user's choosing. - tags: Optional. Tags that will be used to retrieve the correct - `SignatureDef`. Defaults to `DEFAULT_TAGS`. + tags: Optional. Comma separated list of tags that will be used to retrieve + the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. Raises: