Fix dependencies and import statements for predictor module.

PiperOrigin-RevId: 161203536
This commit is contained in:
A. Unique TensorFlower 2017-07-07 08:20:20 -07:00 committed by TensorFlower Gardener
parent aee58d0720
commit 81e81b796d
5 changed files with 34 additions and 9 deletions

View File

@ -51,6 +51,7 @@ py_library(
"//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/predictor",
"//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py",

View File

@ -50,6 +50,7 @@ from tensorflow.contrib import metrics
from tensorflow.contrib import nccl
from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import predictor
from tensorflow.contrib import quantization
from tensorflow.contrib import resampler
from tensorflow.contrib import rnn

View File

@ -1,6 +1,6 @@
# `Predictor` classes provide an interface for efficient, repeated inference.
package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"])
package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
licenses(["notice"]) # Apache 2.0
@ -62,7 +62,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base_predictor",
"//tensorflow/python/tools:saved_model_cli",
],
)

View File

@ -19,6 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.predictor import from_contrib_estimator
from tensorflow.contrib.predictor import from_estimator
from tensorflow.contrib.predictor import from_saved_model
from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator
from tensorflow.contrib.predictor.predictor_factories import from_estimator
from tensorflow.contrib.predictor.predictor_factories import from_saved_model

View File

@ -22,12 +22,12 @@ from __future__ import print_function
import logging
from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_cli
DEFAULT_TAGS = 'serve'
@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve'
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
def get_meta_graph_def(saved_model_dir, tags):
"""Gets `MetaGraphDef` from a directory containing a `SavedModel`.
Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.
Args:
saved_model_dir: Directory containing the SavedModel.
tags: Comma separated list of tags used to identify the correct
`MetaGraphDef`.
Raises:
ValueError: An error when the given tags cannot be found.
Returns:
A `MetaGraphDef` corresponding to the given tags.
"""
saved_model = reader.read_saved_model(saved_model_dir)
set_of_tags = set([tag.strip() for tag in tags.split(',')])
for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
return meta_graph_def
raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
def _get_signature_def(signature_def_key, export_dir, tags):
"""Construct a `SignatureDef` proto."""
signature_def_key = (
signature_def_key or
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags)
metagraph_def = get_meta_graph_def(export_dir, tags)
try:
signature_def = signature_def_utils.get_signature_def_by_key(
@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor):
output_names: A dictionary mapping strings to `Tensor`s in the
`SavedModel` that represent the output. The keys can be any string of
the user's choosing.
tags: Optional. Tags that will be used to retrieve the correct
`SignatureDef`. Defaults to `DEFAULT_TAGS`.
tags: Optional. Comma separated list of tags that will be used to retrieve
the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
Raises: