Automated g4 rollback of changelist 161218103

PiperOrigin-RevId: 161671226
This commit is contained in:
A. Unique TensorFlower 2017-07-12 09:01:22 -07:00 committed by TensorFlower Gardener
parent 786bf6cd65
commit 576c7b1ec8
6 changed files with 48 additions and 10 deletions

View File

@ -51,6 +51,7 @@ py_library(
"//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/predictor",
"//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py",

View File

@ -50,6 +50,7 @@ from tensorflow.contrib import metrics
from tensorflow.contrib import nccl
from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import predictor
from tensorflow.contrib import quantization
from tensorflow.contrib import resampler
from tensorflow.contrib import rnn

View File

@ -455,6 +455,7 @@ add_python_module("tensorflow/contrib/pi_examples")
add_python_module("tensorflow/contrib/pi_examples/camera")
add_python_module("tensorflow/contrib/pi_examples/label_image")
add_python_module("tensorflow/contrib/pi_examples/label_image/data")
add_python_module("tensorflow/contrib/predictor")
add_python_module("tensorflow/contrib/quantization")
add_python_module("tensorflow/contrib/quantization/python")
add_python_module("tensorflow/contrib/remote_fused_graph/pylib")

View File

@ -1,6 +1,6 @@
# `Predictor` classes provide an interface for efficient, repeated inference.
package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"])
package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
licenses(["notice"]) # Apache 2.0
@ -62,7 +62,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base_predictor",
"//tensorflow/python/tools:saved_model_cli",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:signature_constants",
"//tensorflow/python/saved_model:signature_def_utils",
],
)

View File

@ -13,12 +13,20 @@
# limitations under the License.
# ==============================================================================
"""Modules for `Predictor`s."""
"""Modules for `Predictor`s.
@@from_contrib_estimator
@@from_estimator
@@from_saved_model
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.predictor import from_contrib_estimator
from tensorflow.contrib.predictor import from_estimator
from tensorflow.contrib.predictor import from_saved_model
from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator
from tensorflow.contrib.predictor.predictor_factories import from_estimator
from tensorflow.contrib.predictor.predictor_factories import from_saved_model
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File

@ -22,12 +22,12 @@ from __future__ import print_function
import logging
from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_cli
DEFAULT_TAGS = 'serve'
@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve'
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
def get_meta_graph_def(saved_model_dir, tags):
"""Gets `MetaGraphDef` from a directory containing a `SavedModel`.
Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.
Args:
saved_model_dir: Directory containing the SavedModel.
tags: Comma separated list of tags used to identify the correct
`MetaGraphDef`.
Raises:
ValueError: An error when the given tags cannot be found.
Returns:
A `MetaGraphDef` corresponding to the given tags.
"""
saved_model = reader.read_saved_model(saved_model_dir)
set_of_tags = set([tag.strip() for tag in tags.split(',')])
for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
return meta_graph_def
raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
def _get_signature_def(signature_def_key, export_dir, tags):
"""Construct a `SignatureDef` proto."""
signature_def_key = (
signature_def_key or
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags)
metagraph_def = get_meta_graph_def(export_dir, tags)
try:
signature_def = signature_def_utils.get_signature_def_by_key(
@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor):
output_names: A dictionary mapping strings to `Tensor`s in the
`SavedModel` that represent the output. The keys can be any string of
the user's choosing.
tags: Optional. Tags that will be used to retrieve the correct
`SignatureDef`. Defaults to `DEFAULT_TAGS`.
tags: Optional. Comma separated list of tags that will be used to retrieve
the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
Raises: