Automated g4 rollback of changelist 161218103

PiperOrigin-RevId: 161671226
This commit is contained in:
A. Unique TensorFlower 2017-07-12 09:01:22 -07:00 committed by TensorFlower Gardener
parent 786bf6cd65
commit 576c7b1ec8
6 changed files with 48 additions and 10 deletions

View File

@ -51,6 +51,7 @@ py_library(
"//tensorflow/contrib/ndlstm", "//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py", "//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/predictor",
"//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py", "//tensorflow/contrib/resampler:resampler_py",

View File

@ -50,6 +50,7 @@ from tensorflow.contrib import metrics
from tensorflow.contrib import nccl from tensorflow.contrib import nccl
from tensorflow.contrib import nn from tensorflow.contrib import nn
from tensorflow.contrib import opt from tensorflow.contrib import opt
from tensorflow.contrib import predictor
from tensorflow.contrib import quantization from tensorflow.contrib import quantization
from tensorflow.contrib import resampler from tensorflow.contrib import resampler
from tensorflow.contrib import rnn from tensorflow.contrib import rnn

View File

@ -455,6 +455,7 @@ add_python_module("tensorflow/contrib/pi_examples")
add_python_module("tensorflow/contrib/pi_examples/camera") add_python_module("tensorflow/contrib/pi_examples/camera")
add_python_module("tensorflow/contrib/pi_examples/label_image") add_python_module("tensorflow/contrib/pi_examples/label_image")
add_python_module("tensorflow/contrib/pi_examples/label_image/data") add_python_module("tensorflow/contrib/pi_examples/label_image/data")
add_python_module("tensorflow/contrib/predictor")
add_python_module("tensorflow/contrib/quantization") add_python_module("tensorflow/contrib/quantization")
add_python_module("tensorflow/contrib/quantization/python") add_python_module("tensorflow/contrib/quantization/python")
add_python_module("tensorflow/contrib/remote_fused_graph/pylib") add_python_module("tensorflow/contrib/remote_fused_graph/pylib")

View File

@ -1,6 +1,6 @@
# `Predictor` classes provide an interface for efficient, repeated inference. # `Predictor` classes provide an interface for efficient, repeated inference.
package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"]) package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
@ -62,7 +62,10 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":base_predictor", ":base_predictor",
"//tensorflow/python/tools:saved_model_cli", "//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:signature_constants",
"//tensorflow/python/saved_model:signature_def_utils",
], ],
) )

View File

@ -13,12 +13,20 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Modules for `Predictor`s.""" """Modules for `Predictor`s.
@@from_contrib_estimator
@@from_estimator
@@from_saved_model
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.predictor import from_contrib_estimator from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator
from tensorflow.contrib.predictor import from_estimator from tensorflow.contrib.predictor.predictor_factories import from_estimator
from tensorflow.contrib.predictor import from_saved_model from tensorflow.contrib.predictor.predictor_factories import from_saved_model
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File

@ -22,12 +22,12 @@ from __future__ import print_function
import logging import logging
from tensorflow.contrib.predictor import predictor from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_cli
DEFAULT_TAGS = 'serve' DEFAULT_TAGS = 'serve'
@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve'
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' _DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
def get_meta_graph_def(saved_model_dir, tags):
"""Gets `MetaGraphDef` from a directory containing a `SavedModel`.
Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.
Args:
saved_model_dir: Directory containing the SavedModel.
tags: Comma separated list of tags used to identify the correct
`MetaGraphDef`.
Raises:
ValueError: An error when the given tags cannot be found.
Returns:
A `MetaGraphDef` corresponding to the given tags.
"""
saved_model = reader.read_saved_model(saved_model_dir)
set_of_tags = set([tag.strip() for tag in tags.split(',')])
for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
return meta_graph_def
raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
def _get_signature_def(signature_def_key, export_dir, tags): def _get_signature_def(signature_def_key, export_dir, tags):
"""Construct a `SignatureDef` proto.""" """Construct a `SignatureDef` proto."""
signature_def_key = ( signature_def_key = (
signature_def_key or signature_def_key or
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) metagraph_def = get_meta_graph_def(export_dir, tags)
try: try:
signature_def = signature_def_utils.get_signature_def_by_key( signature_def = signature_def_utils.get_signature_def_by_key(
@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor):
output_names: A dictionary mapping strings to `Tensor`s in the output_names: A dictionary mapping strings to `Tensor`s in the
`SavedModel` that represent the output. The keys can be any string of `SavedModel` that represent the output. The keys can be any string of
the user's choosing. the user's choosing.
tags: Optional. Tags that will be used to retrieve the correct tags: Optional. Comma separated list of tags that will be used to retrieve
`SignatureDef`. Defaults to `DEFAULT_TAGS`. the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be graph: Optional. The Tensorflow `graph` in which prediction should be
done. done.
Raises: Raises: