Fix dependencies and import statements for predictor module.
PiperOrigin-RevId: 161203536
This commit is contained in:
parent
aee58d0720
commit
81e81b796d
@ -51,6 +51,7 @@ py_library(
|
|||||||
"//tensorflow/contrib/ndlstm",
|
"//tensorflow/contrib/ndlstm",
|
||||||
"//tensorflow/contrib/nn:nn_py",
|
"//tensorflow/contrib/nn:nn_py",
|
||||||
"//tensorflow/contrib/opt:opt_py",
|
"//tensorflow/contrib/opt:opt_py",
|
||||||
|
"//tensorflow/contrib/predictor",
|
||||||
"//tensorflow/contrib/quantization:quantization_py",
|
"//tensorflow/contrib/quantization:quantization_py",
|
||||||
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
|
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
|
||||||
"//tensorflow/contrib/resampler:resampler_py",
|
"//tensorflow/contrib/resampler:resampler_py",
|
||||||
|
@ -50,6 +50,7 @@ from tensorflow.contrib import metrics
|
|||||||
from tensorflow.contrib import nccl
|
from tensorflow.contrib import nccl
|
||||||
from tensorflow.contrib import nn
|
from tensorflow.contrib import nn
|
||||||
from tensorflow.contrib import opt
|
from tensorflow.contrib import opt
|
||||||
|
from tensorflow.contrib import predictor
|
||||||
from tensorflow.contrib import quantization
|
from tensorflow.contrib import quantization
|
||||||
from tensorflow.contrib import resampler
|
from tensorflow.contrib import resampler
|
||||||
from tensorflow.contrib import rnn
|
from tensorflow.contrib import rnn
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# `Predictor` classes provide an interface for efficient, repeated inference.
|
# `Predictor` classes provide an interface for efficient, repeated inference.
|
||||||
|
|
||||||
package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"])
|
package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
@ -62,7 +62,6 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":base_predictor",
|
":base_predictor",
|
||||||
"//tensorflow/python/tools:saved_model_cli",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +19,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.predictor import from_contrib_estimator
|
from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator
|
||||||
from tensorflow.contrib.predictor import from_estimator
|
from tensorflow.contrib.predictor.predictor_factories import from_estimator
|
||||||
from tensorflow.contrib.predictor import from_saved_model
|
from tensorflow.contrib.predictor.predictor_factories import from_saved_model
|
||||||
|
@ -22,12 +22,12 @@ from __future__ import print_function
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from tensorflow.contrib.predictor import predictor
|
from tensorflow.contrib.predictor import predictor
|
||||||
|
from tensorflow.contrib.saved_model.python.saved_model import reader
|
||||||
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
|
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.saved_model import loader
|
from tensorflow.python.saved_model import loader
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.tools import saved_model_cli
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TAGS = 'serve'
|
DEFAULT_TAGS = 'serve'
|
||||||
@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve'
|
|||||||
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
|
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
|
||||||
|
|
||||||
|
|
||||||
|
def get_meta_graph_def(saved_model_dir, tags):
|
||||||
|
"""Gets `MetaGraphDef` from a directory containing a `SavedModel`.
|
||||||
|
|
||||||
|
Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
saved_model_dir: Directory containing the SavedModel.
|
||||||
|
tags: Comma separated list of tags used to identify the correct
|
||||||
|
`MetaGraphDef`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: An error when the given tags cannot be found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `MetaGraphDef` corresponding to the given tags.
|
||||||
|
"""
|
||||||
|
saved_model = reader.read_saved_model(saved_model_dir)
|
||||||
|
set_of_tags = set([tag.strip() for tag in tags.split(',')])
|
||||||
|
for meta_graph_def in saved_model.meta_graphs:
|
||||||
|
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
|
||||||
|
return meta_graph_def
|
||||||
|
raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
|
||||||
|
|
||||||
|
|
||||||
def _get_signature_def(signature_def_key, export_dir, tags):
|
def _get_signature_def(signature_def_key, export_dir, tags):
|
||||||
"""Construct a `SignatureDef` proto."""
|
"""Construct a `SignatureDef` proto."""
|
||||||
signature_def_key = (
|
signature_def_key = (
|
||||||
signature_def_key or
|
signature_def_key or
|
||||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
||||||
|
|
||||||
metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags)
|
metagraph_def = get_meta_graph_def(export_dir, tags)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
signature_def = signature_def_utils.get_signature_def_by_key(
|
signature_def = signature_def_utils.get_signature_def_by_key(
|
||||||
@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor):
|
|||||||
output_names: A dictionary mapping strings to `Tensor`s in the
|
output_names: A dictionary mapping strings to `Tensor`s in the
|
||||||
`SavedModel` that represent the output. The keys can be any string of
|
`SavedModel` that represent the output. The keys can be any string of
|
||||||
the user's choosing.
|
the user's choosing.
|
||||||
tags: Optional. Tags that will be used to retrieve the correct
|
tags: Optional. Comma separated list of tags that will be used to retrieve
|
||||||
`SignatureDef`. Defaults to `DEFAULT_TAGS`.
|
the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
|
||||||
graph: Optional. The Tensorflow `graph` in which prediction should be
|
graph: Optional. The Tensorflow `graph` in which prediction should be
|
||||||
done.
|
done.
|
||||||
Raises:
|
Raises:
|
||||||
|
Loading…
Reference in New Issue
Block a user