Fix dependencies and import statements for predictor module.

PiperOrigin-RevId: 161203536
This commit is contained in:
A. Unique TensorFlower 2017-07-07 08:20:20 -07:00 committed by TensorFlower Gardener
parent aee58d0720
commit 81e81b796d
5 changed files with 34 additions and 9 deletions

View File

@ -51,6 +51,7 @@ py_library(
"//tensorflow/contrib/ndlstm", "//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py", "//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/predictor",
"//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py", "//tensorflow/contrib/resampler:resampler_py",

View File

@ -50,6 +50,7 @@ from tensorflow.contrib import metrics
from tensorflow.contrib import nccl from tensorflow.contrib import nccl
from tensorflow.contrib import nn from tensorflow.contrib import nn
from tensorflow.contrib import opt from tensorflow.contrib import opt
from tensorflow.contrib import predictor
from tensorflow.contrib import quantization from tensorflow.contrib import quantization
from tensorflow.contrib import resampler from tensorflow.contrib import resampler
from tensorflow.contrib import rnn from tensorflow.contrib import rnn

View File

@ -1,6 +1,6 @@
# `Predictor` classes provide an interface for efficient, repeated inference. # `Predictor` classes provide an interface for efficient, repeated inference.
package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"]) package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
@ -62,7 +62,6 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":base_predictor", ":base_predictor",
"//tensorflow/python/tools:saved_model_cli",
], ],
) )

View File

@ -19,6 +19,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.predictor import from_contrib_estimator from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator
from tensorflow.contrib.predictor import from_estimator from tensorflow.contrib.predictor.predictor_factories import from_estimator
from tensorflow.contrib.predictor import from_saved_model from tensorflow.contrib.predictor.predictor_factories import from_saved_model

View File

@ -22,12 +22,12 @@ from __future__ import print_function
import logging import logging
from tensorflow.contrib.predictor import predictor from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_cli
DEFAULT_TAGS = 'serve' DEFAULT_TAGS = 'serve'
@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve'
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' _DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
def get_meta_graph_def(saved_model_dir, tags):
"""Gets `MetaGraphDef` from a directory containing a `SavedModel`.
Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.
Args:
saved_model_dir: Directory containing the SavedModel.
tags: Comma separated list of tags used to identify the correct
`MetaGraphDef`.
Raises:
ValueError: An error when the given tags cannot be found.
Returns:
A `MetaGraphDef` corresponding to the given tags.
"""
saved_model = reader.read_saved_model(saved_model_dir)
set_of_tags = set([tag.strip() for tag in tags.split(',')])
for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
return meta_graph_def
raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
def _get_signature_def(signature_def_key, export_dir, tags): def _get_signature_def(signature_def_key, export_dir, tags):
"""Construct a `SignatureDef` proto.""" """Construct a `SignatureDef` proto."""
signature_def_key = ( signature_def_key = (
signature_def_key or signature_def_key or
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) metagraph_def = get_meta_graph_def(export_dir, tags)
try: try:
signature_def = signature_def_utils.get_signature_def_by_key( signature_def = signature_def_utils.get_signature_def_by_key(
@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor):
output_names: A dictionary mapping strings to `Tensor`s in the output_names: A dictionary mapping strings to `Tensor`s in the
`SavedModel` that represent the output. The keys can be any string of `SavedModel` that represent the output. The keys can be any string of
the user's choosing. the user's choosing.
tags: Optional. Tags that will be used to retrieve the correct tags: Optional. Comma separated list of tags that will be used to retrieve
`SignatureDef`. Defaults to `DEFAULT_TAGS`. the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be graph: Optional. The Tensorflow `graph` in which prediction should be
done. done.
Raises: Raises: