parent
5e876a8c25
commit
6c528feaf8
@ -28,7 +28,6 @@ py_library(
|
||||
":multi_head",
|
||||
":replicate_model_fn",
|
||||
":rnn",
|
||||
":saved_model_estimator",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
@ -466,43 +465,3 @@ py_test(
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saved_model_estimator",
|
||||
srcs = ["python/estimator/saved_model_estimator.py"],
|
||||
deps = [
|
||||
":export",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/estimator",
|
||||
"//tensorflow/python/estimator:export",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
"//tensorflow/python/saved_model",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "saved_model_estimator_test",
|
||||
size = "medium",
|
||||
srcs = ["python/estimator/saved_model_estimator_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":export",
|
||||
":saved_model_estimator",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:metrics",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/estimator",
|
||||
"//tensorflow/python/estimator:export_export",
|
||||
"//tensorflow/python/estimator:export_output",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
],
|
||||
)
|
||||
|
@ -33,8 +33,6 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import *
|
||||
from tensorflow.contrib.estimator.python.estimator.multi_head import *
|
||||
from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import *
|
||||
from tensorflow.contrib.estimator.python.estimator.rnn import *
|
||||
from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import *
|
||||
from tensorflow.python.estimator.export.export import *
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
@ -72,9 +70,6 @@ _allowed_symbols = [
|
||||
'stop_if_higher_hook',
|
||||
'stop_if_no_increase_hook',
|
||||
'stop_if_no_decrease_hook',
|
||||
'build_raw_supervised_input_receiver_fn',
|
||||
'build_supervised_input_receiver_fn_from_input_fn',
|
||||
'SavedModelEstimator'
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
||||
|
@ -1,445 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Class that creates an Estimator from a SavedModel."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.export import export as export_lib
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
class SavedModelEstimator(estimator_lib.Estimator):
|
||||
"""Create an Estimator from a SavedModel.
|
||||
|
||||
Only SavedModels exported with
|
||||
`tf.contrib.estimator.export_all_saved_models()` or
|
||||
`tf.estimator.Estimator.export_savedmodel()` are supported for this class.
|
||||
|
||||
Example with `tf.estimator.DNNClassifier`:
|
||||
|
||||
**Step 1: Create and train DNNClassifier.**
|
||||
```python
|
||||
feature1 = tf.feature_column.embedding_column(
|
||||
tf.feature_column.categorical_column_with_vocabulary_list(
|
||||
key='feature1', vocabulary_list=('green', 'yellow')), dimension=1)
|
||||
feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0)
|
||||
|
||||
classifier = tf.estimator.DNNClassifier(
|
||||
hidden_units=[4,2], feature_columns=[feature1, feature2])
|
||||
|
||||
def input_fn():
|
||||
features = {'feature1': tf.constant(['green', 'green', 'yellow']),
|
||||
'feature2': tf.constant([3.5, 4.2, 6.1])}
|
||||
label = tf.constant([1., 0., 0.])
|
||||
return tf.data.Dataset.from_tensors((features, label)).repeat()
|
||||
|
||||
classifier.train(input_fn=input_fn, steps=10)
|
||||
```
|
||||
|
||||
**Step 2: Export classifier.**
|
||||
First, build functions that specify the expected inputs.
|
||||
```python
|
||||
# During train and evaluation, both the features and labels should be defined.
|
||||
supervised_input_receiver_fn = (
|
||||
tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
|
||||
{'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
|
||||
'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
|
||||
tf.placeholder(dtype=tf.float32, shape=[None])))
|
||||
|
||||
# During predict mode, expect to receive a `tf.Example` proto, so a parsing
|
||||
# function is used.
|
||||
serving_input_receiver_fn = (
|
||||
tf.estimator.export.build_parsing_serving_input_receiver_fn(
|
||||
tf.feature_column.make_parse_example_spec([feature1, feature2])))
|
||||
```
|
||||
|
||||
Next, export the model as a SavedModel. A timestamped directory will be
|
||||
created (for example `/tmp/export_all/1234567890`).
|
||||
```python
|
||||
# Option 1: Save all modes (train, eval, predict)
|
||||
export_dir = tf.contrib.estimator.export_all_saved_models(
|
||||
classifier, '/tmp/export_all',
|
||||
{tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn,
|
||||
tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn,
|
||||
tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn})
|
||||
|
||||
# Option 2: Only export predict mode
|
||||
export_dir = classifier.export_savedmodel(
|
||||
'/tmp/export_predict', serving_input_receiver_fn)
|
||||
```
|
||||
|
||||
**Step 3: Create a SavedModelEstimator from the exported SavedModel.**
|
||||
```python
|
||||
est = tf.contrib.estimator.SavedModelEstimator(export_dir)
|
||||
|
||||
# If all modes were exported, you can immediately evaluate and predict, or
|
||||
# continue training. Otherwise only predict is available.
|
||||
eval_results = est.evaluate(input_fn=input_fn, steps=1)
|
||||
print(eval_results)
|
||||
|
||||
est.train(input_fn=input_fn, steps=20)
|
||||
|
||||
def predict_input_fn():
|
||||
example = example_pb2.Example()
|
||||
example.features.feature['feature1'].bytes_list.value.extend(['yellow'])
|
||||
example.features.feature['feature2'].float_list.value.extend([1.])
|
||||
return {'inputs':tf.constant([example.SerializeToString()])}
|
||||
|
||||
predictions = est.predict(predict_input_fn)
|
||||
print(next(predictions))
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, saved_model_dir, model_dir=None):
|
||||
"""Initialize a SavedModelEstimator.
|
||||
|
||||
The SavedModelEstimator loads its model function and variable values from
|
||||
the graphs defined in the SavedModel. There is no option to pass in
|
||||
`RunConfig` or `params` arguments, because the model function graph is
|
||||
defined statically in the SavedModel.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing SavedModel protobuf and subfolders.
|
||||
model_dir: Directory to save new checkpoints during training.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If a DistributionStrategy is defined in the config.
|
||||
Unless the SavedModelEstimator is subclassed, this shouldn't happen.
|
||||
"""
|
||||
checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access
|
||||
vars_to_warm_start = [name for name, _ in
|
||||
checkpoint_utils.list_variables(checkpoint)]
|
||||
warm_start_settings = estimator_lib.WarmStartSettings(
|
||||
ckpt_to_initialize_from=checkpoint,
|
||||
vars_to_warm_start=vars_to_warm_start)
|
||||
|
||||
super(SavedModelEstimator, self).__init__(
|
||||
model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
|
||||
warm_start_from=warm_start_settings)
|
||||
if self._distribution is not None:
|
||||
raise NotImplementedError(
|
||||
'SavedModelEstimator currently does not support '
|
||||
'DistributionStrategy.')
|
||||
self.saved_model_dir = saved_model_dir
|
||||
self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
|
||||
self._available_modes = self._extract_available_modes()
|
||||
|
||||
def _extract_available_modes(self):
|
||||
"""Return list of modes found in SavedModel."""
|
||||
available_modes = []
|
||||
logging.info('Checking available modes for SavedModelEstimator.')
|
||||
for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
|
||||
model_fn_lib.ModeKeys.PREDICT]:
|
||||
try:
|
||||
self._get_meta_graph_def_for_mode(mode)
|
||||
except RuntimeError:
|
||||
logging.warning('%s mode not found in SavedModel.' % mode)
|
||||
continue
|
||||
|
||||
if self._get_signature_def_for_mode(mode) is not None:
|
||||
available_modes.append(mode)
|
||||
|
||||
logging.info('Available modes for Estimator: %s' % available_modes)
|
||||
return available_modes
|
||||
|
||||
def _validate_mode(self, mode):
|
||||
"""Make sure that mode can be run using the SavedModel."""
|
||||
if mode not in self._available_modes:
|
||||
raise RuntimeError('%s mode is not available in the SavedModel. Use '
|
||||
'saved_model_cli to check that the Metagraph for this '
|
||||
'mode has been exported.' % mode)
|
||||
|
||||
def _get_meta_graph_def_for_mode(self, mode):
|
||||
tags = model_fn_lib.EXPORT_TAG_MAP[mode]
|
||||
return self.saved_model_loader.get_meta_graph_def_from_tags(tags)
|
||||
|
||||
def _get_signature_def_for_mode(self, mode):
|
||||
meta_graph_def = self._get_meta_graph_def_for_mode(mode)
|
||||
sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
if mode == model_fn_lib.ModeKeys.PREDICT else mode)
|
||||
if sig_def_key not in meta_graph_def.signature_def:
|
||||
logging.warning('Metagraph for mode %s was found, but SignatureDef with'
|
||||
' key \"%s\" is missing.' % (mode, sig_def_key))
|
||||
return None
|
||||
return meta_graph_def.signature_def[sig_def_key]
|
||||
|
||||
def _create_and_assert_global_step(self, graph):
|
||||
# Do nothing here. The global step variable will be created/loaded from the
|
||||
# SavedModel. If a global step variable were created here, the result
|
||||
# will be two duplicate global step variables, causing issues during
|
||||
# the warm-start phase.
|
||||
# Due to the global variable being created in the model function, this may
|
||||
# cause issues when running DistributionStrategy. Thus, DistributionStrategy
|
||||
# is not yet supported with SavedModelEstimator.
|
||||
pass
|
||||
|
||||
def _model_fn_from_saved_model(self, features, labels, mode):
|
||||
"""Load a SavedModel graph and return an EstimatorSpec."""
|
||||
# TODO(kathywu): Model function loads placeholders from the graph. Calling
|
||||
# export_all_saved_models creates another placeholder for the inputs, on top
|
||||
# of the original placeholders. There should be a way to avoid this.
|
||||
self._validate_mode(mode)
|
||||
|
||||
g = ops.get_default_graph()
|
||||
if training_util.get_global_step(g) is not None:
|
||||
raise RuntimeError(
|
||||
'Graph must not contain a global step tensor before the SavedModel is'
|
||||
' loaded. Please make sure that the input function does not create a '
|
||||
'global step.')
|
||||
|
||||
# Extract SignatureDef for information about the input and output tensors.
|
||||
signature_def = self._get_signature_def_for_mode(mode)
|
||||
|
||||
# Generate input map for replacing the inputs in the SavedModel graph with
|
||||
# the provided features and labels.
|
||||
input_map = _generate_input_map(signature_def, features, labels)
|
||||
|
||||
# Create a list of the names of output tensors. When the graph is loaded,
|
||||
# names of the output tensors may be remapped. This ensures that the correct
|
||||
# tensors are returned in the EstimatorSpec.
|
||||
output_tensor_names = [
|
||||
value.name for value in six.itervalues(signature_def.outputs)]
|
||||
|
||||
# Load the graph. `output_tensors` contains output `Tensors` in the same
|
||||
# same order as the `output_tensor_names` list.
|
||||
tags = model_fn_lib.EXPORT_TAG_MAP[mode]
|
||||
_, output_tensors = self.saved_model_loader.load_graph(
|
||||
g, tags, input_map=input_map, return_elements=output_tensor_names)
|
||||
|
||||
# Create a scaffold from the MetaGraphDef that contains ops to initialize
|
||||
# the graph. This should mirror the steps from _add_meta_graph_for_mode(),
|
||||
# which creates a MetaGraphDef from the EstimatorSpec's scaffold.
|
||||
scaffold = monitored_session.Scaffold(
|
||||
local_init_op=loader_impl._get_legacy_init_op_tensor( # pylint: disable=protected-access
|
||||
self._get_meta_graph_def_for_mode(mode)))
|
||||
|
||||
# Ensure that a global step tensor has been created.
|
||||
global_step_tensor = training_util.get_global_step(g)
|
||||
training_util.assert_global_step(global_step_tensor)
|
||||
|
||||
# Extract values to return in the EstimatorSpec.
|
||||
output_map = dict(zip(output_tensor_names, output_tensors))
|
||||
outputs = {key: output_map[value.name]
|
||||
for key, value in six.iteritems(signature_def.outputs)}
|
||||
|
||||
loss, predictions, metrics = _validate_and_extract_outputs(
|
||||
mode, outputs, signature_def.method_name)
|
||||
|
||||
train_op = ops.get_collection(constants.TRAIN_OP_KEY)
|
||||
if len(train_op) > 1:
|
||||
raise RuntimeError('Multiple ops found in the train_op collection.')
|
||||
train_op = None if not train_op else train_op[0]
|
||||
|
||||
_clear_saved_model_collections()
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
scaffold=scaffold,
|
||||
mode=mode,
|
||||
loss=loss,
|
||||
train_op=train_op,
|
||||
predictions=predictions,
|
||||
eval_metric_ops=metrics)
|
||||
|
||||
|
||||
def _clear_saved_model_collections():
|
||||
"""Clear collections that are expected empty when exporting a SavedModel.
|
||||
|
||||
The SavedModel builder uses these collections to track ops necessary to
|
||||
restore the graph state. These collections are expected to be empty before
|
||||
MetaGraphs are added to the builder.
|
||||
"""
|
||||
del ops.get_collection_ref(constants.ASSETS_KEY)[:]
|
||||
del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:]
|
||||
del ops.get_collection_ref(constants.MAIN_OP_KEY)[:]
|
||||
del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:]
|
||||
|
||||
|
||||
def _generate_input_map(signature_def, features, labels):
|
||||
"""Return dict mapping an input tensor name to a feature or label tensor.
|
||||
|
||||
Args:
|
||||
signature_def: SignatureDef loaded from SavedModel
|
||||
features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
|
||||
`SparseTensor`, specifying the features to be passed to the model.
|
||||
labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
|
||||
`SparseTensor`, specifying the labels to be passed to the model. May be
|
||||
`None`.
|
||||
|
||||
Returns:
|
||||
dict mapping string names of inputs to features or labels tensors
|
||||
|
||||
Raises:
|
||||
ValueError: if SignatureDef inputs are not completely mapped by the input
|
||||
features and labels.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
if not isinstance(features, dict):
|
||||
features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features}
|
||||
if labels is not None and not isinstance(labels, dict):
|
||||
labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels}
|
||||
# pylint: enable=protected-access
|
||||
|
||||
inputs = signature_def.inputs
|
||||
input_map = {}
|
||||
for key, tensor_info in six.iteritems(inputs):
|
||||
input_name = tensor_info.name
|
||||
if ':' in input_name:
|
||||
input_name = input_name[:input_name.find(':')]
|
||||
|
||||
# When tensors are used as control inputs for operations, their names are
|
||||
# prepended with a '^' character in the GraphDef. To handle possible control
|
||||
# flow edge cases, control input names must be included in the input map.
|
||||
control_dependency_name = '^' + input_name
|
||||
|
||||
if key in features:
|
||||
_check_same_dtype_and_shape(features[key], tensor_info, key)
|
||||
input_map[input_name] = input_map[control_dependency_name] = features[key]
|
||||
elif labels is not None and key in labels:
|
||||
_check_same_dtype_and_shape(labels[key], tensor_info, key)
|
||||
input_map[input_name] = input_map[control_dependency_name] = labels[key]
|
||||
else:
|
||||
raise ValueError(
|
||||
'Key \"%s\" not found in features or labels passed in to the model '
|
||||
'function. All required keys: %s' % (key, inputs.keys()))
|
||||
|
||||
return input_map
|
||||
|
||||
|
||||
def _check_same_dtype_and_shape(tensor, tensor_info, name):
|
||||
"""Validate that tensor has the same properties as the TensorInfo proto.
|
||||
|
||||
Args:
|
||||
tensor: a `Tensor` object.
|
||||
tensor_info: a `TensorInfo` proto.
|
||||
name: Name of the input (to identify Tensor if an error is raised).
|
||||
|
||||
Raises:
|
||||
ValueError: If the tensor shape or dtype don't match the TensorInfo
|
||||
"""
|
||||
dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype))
|
||||
shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape)
|
||||
|
||||
if dtype_error or shape_error:
|
||||
msg = 'Tensor shape and/or dtype validation failed for input %s:' % name
|
||||
if dtype_error:
|
||||
msg += ('\n\tExpected dtype: %s, Got: %s'
|
||||
% (dtypes.DType(tensor_info.dtype), tensor.dtype))
|
||||
if shape_error:
|
||||
msg += ('\n\tExpected shape: %s, Got: %s'
|
||||
% (tensor_shape.TensorShape(tensor_info.tensor_shape),
|
||||
tensor.shape))
|
||||
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _extract_eval_metrics(output_dict):
|
||||
"""Return a eval metric dict extracted from the output_dict.
|
||||
|
||||
Eval metrics consist of a value tensor and an update op. Both must be in the
|
||||
passed-in tensor dictionary for an eval metric to be added to the returned
|
||||
dictionary.
|
||||
|
||||
Args:
|
||||
output_dict: a dict that maps strings to tensors.
|
||||
|
||||
Returns:
|
||||
dict mapping strings to (value, update_op) tuples.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
metric_ops = {}
|
||||
separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR
|
||||
|
||||
for key, tensor in six.iteritems(output_dict):
|
||||
split_key = key.split(separator_char)
|
||||
|
||||
# The metric name may contain the separator character, so recreate its name.
|
||||
metric_name = separator_char.join(split_key[:-1])
|
||||
|
||||
if split_key[0] == export_output._SupervisedOutput.METRICS_NAME:
|
||||
# If the key ends with the value suffix, and there is a corresponding
|
||||
# key ending with the update_op suffix, then add tensors to metrics dict.
|
||||
if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX:
|
||||
update_op = ''.join(
|
||||
[metric_name, separator_char,
|
||||
export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX])
|
||||
if update_op in output_dict:
|
||||
update_op_tensor = output_dict[update_op]
|
||||
metric_ops[metric_name] = (tensor, update_op_tensor)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
return metric_ops
|
||||
|
||||
|
||||
def _validate_and_extract_outputs(mode, output_dict, method_name):
|
||||
"""Extract values from SignatureDef output dictionary.
|
||||
|
||||
Args:
|
||||
mode: One of the modes enumerated in `tf.estimator.ModeKeys`.
|
||||
output_dict: dict of string SignatureDef keys to `Tensor`.
|
||||
method_name: Method name of the SignatureDef as a string.
|
||||
|
||||
Returns:
|
||||
Tuple of (
|
||||
loss: `Tensor` object,
|
||||
predictions: dictionary mapping string keys to `Tensor` objects,
|
||||
metrics: dictionary mapping string keys to a tuple of two `Tensor` objects
|
||||
)
|
||||
|
||||
Raises:
|
||||
RuntimeError: raised if SignatureDef has an invalid method name for the mode
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
loss, predictions, metrics = None, None, None
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
predictions = output_dict
|
||||
else:
|
||||
# Validate that the SignatureDef's method name matches the expected name for
|
||||
# the given mode.
|
||||
expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
|
||||
if mode == model_fn_lib.ModeKeys.EVAL:
|
||||
expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME
|
||||
if method_name != expected_method_name:
|
||||
raise RuntimeError(
|
||||
'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t'
|
||||
'Got: %s\nPlease ensure that the SavedModel was exported with '
|
||||
'`tf.contrib.estimator.export_all_saved_models()`.' %
|
||||
(mode, expected_method_name, method_name))
|
||||
|
||||
# Extract loss, metrics and predictions from the output dict.
|
||||
loss = output_dict[export_output._SupervisedOutput.LOSS_NAME]
|
||||
metrics = _extract_eval_metrics(output_dict)
|
||||
predictions = {
|
||||
key: value for key, value in six.iteritems(output_dict)
|
||||
if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == (
|
||||
export_output._SupervisedOutput.PREDICTIONS_NAME)}
|
||||
|
||||
# pylint: enable=protected-access
|
||||
return loss, predictions, metrics
|
@ -1,369 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for SavedModelEstimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from tensorflow.contrib.estimator.python.estimator import export as contrib_export
|
||||
from tensorflow.contrib.estimator.python.estimator import saved_model_estimator
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.estimator import estimator
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.export import export
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import metrics as metrics_lib
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import training
|
||||
|
||||
|
||||
def dummy_input_fn():
|
||||
return dataset_ops.Dataset.from_tensors((
|
||||
{'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)},
|
||||
constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat()
|
||||
|
||||
|
||||
def dummy_input_fn_features_only():
|
||||
return dataset_ops.Dataset.from_tensors(
|
||||
{'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat()
|
||||
|
||||
|
||||
def dummy_supervised_receiver_fn():
|
||||
feature_spec = {
|
||||
'x': array_ops.placeholder(
|
||||
dtype=dtypes.int64, shape=(2, 1), name='feature_x'),
|
||||
}
|
||||
label_spec = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[2, 1], name='truth')
|
||||
return export.build_raw_supervised_input_receiver_fn(
|
||||
feature_spec, label_spec)
|
||||
|
||||
|
||||
def dummy_serving_receiver_fn():
|
||||
feature_spec = {'x': array_ops.placeholder(
|
||||
dtype=dtypes.int64, shape=(2, 1), name='feature_x'),}
|
||||
return export.build_raw_serving_input_receiver_fn(feature_spec)
|
||||
|
||||
|
||||
def model_fn_diff_modes(features, labels, mode):
|
||||
_, _ = features, labels
|
||||
v = variables.Variable(21, name='some_var')
|
||||
train_op = None
|
||||
loss = constant_op.constant(104)
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
loss = constant_op.constant(105)
|
||||
predictions = constant_op.constant([501])
|
||||
train_op = control_flow_ops.group(
|
||||
state_ops.assign_add(training.get_global_step(), 1),
|
||||
state_ops.assign_add(v, 3))
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
loss = constant_op.constant(106)
|
||||
predictions = constant_op.constant([502])
|
||||
else:
|
||||
loss = constant_op.constant(107)
|
||||
predictions = constant_op.constant([503])
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
loss=loss,
|
||||
train_op=train_op,
|
||||
eval_metric_ops={
|
||||
'abs_err': metrics_lib.mean_absolute_error(
|
||||
constant_op.constant(0), predictions)},
|
||||
predictions=predictions)
|
||||
|
||||
|
||||
class SavedModelEstimatorTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirs = []
|
||||
|
||||
def tearDown(self):
|
||||
for tmpdir in self.tmpdirs:
|
||||
# gfile.DeleteRecursively fails in the windows cmake test, so use shutil.
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
self.tmpdirs = []
|
||||
|
||||
def _get_tmp_dir(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
self.tmpdirs.append(tmpdir)
|
||||
return tmpdir
|
||||
|
||||
def _export_estimator(self, train=True, evaluate=True, predict=True,
|
||||
model_fn=model_fn_diff_modes):
|
||||
est = estimator.Estimator(model_fn, self._get_tmp_dir())
|
||||
est.train(input_fn=dummy_input_fn, steps=10)
|
||||
|
||||
input_receiver_fn_map = {}
|
||||
if train:
|
||||
input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = (
|
||||
dummy_supervised_receiver_fn())
|
||||
if evaluate:
|
||||
input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = (
|
||||
dummy_supervised_receiver_fn())
|
||||
if predict:
|
||||
input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = (
|
||||
dummy_serving_receiver_fn())
|
||||
|
||||
export_base_path = self._get_tmp_dir()
|
||||
export_dir = contrib_export.export_all_saved_models(
|
||||
est, export_base_path, input_receiver_fn_map)
|
||||
return export_dir
|
||||
|
||||
def test_load_all_modes(self):
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(), self._get_tmp_dir())
|
||||
sme.train(input_fn=dummy_input_fn, steps=1)
|
||||
sme.train(input_fn=dummy_input_fn, steps=2)
|
||||
self.assertEqual(13, sme.get_variable_value('global_step'))
|
||||
self.assertEqual(60, sme.get_variable_value('some_var'))
|
||||
|
||||
eval_results = sme.evaluate(dummy_input_fn, steps=5)
|
||||
|
||||
self.assertEqual(13, eval_results['global_step'])
|
||||
self.assertEqual(106, eval_results['loss'])
|
||||
self.assertEqual(502, eval_results['metrics/abs_err'])
|
||||
|
||||
predictions = next(sme.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'output': 503}, predictions)
|
||||
|
||||
def test_load_all_modes_no_train(self):
|
||||
"""Ensure that all functions can be used without requiring a ckpt."""
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(), self._get_tmp_dir())
|
||||
eval_results = sme.evaluate(dummy_input_fn, steps=5)
|
||||
self.assertEqual(10, eval_results['global_step'])
|
||||
self.assertEqual(106, eval_results['loss'])
|
||||
self.assertEqual(502, eval_results['metrics/abs_err'])
|
||||
|
||||
predictions = next(sme.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'output': 503}, predictions)
|
||||
|
||||
def test_partial_exported_estimator(self):
|
||||
sme1 = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(train=False, predict=False), self._get_tmp_dir())
|
||||
sme1.evaluate(dummy_input_fn, steps=5)
|
||||
with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'):
|
||||
sme1.train(input_fn=dummy_input_fn, steps=1)
|
||||
with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'):
|
||||
next(sme1.predict(dummy_input_fn_features_only))
|
||||
|
||||
sme2 = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(evaluate=False), self._get_tmp_dir())
|
||||
sme2.train(input_fn=dummy_input_fn, steps=1)
|
||||
next(sme2.predict(dummy_input_fn_features_only))
|
||||
with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'):
|
||||
sme2.evaluate(dummy_input_fn, steps=5)
|
||||
|
||||
def test_with_incorrect_input(self):
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(), self._get_tmp_dir())
|
||||
|
||||
def bad_shape_input_fn():
|
||||
return dataset_ops.Dataset.from_tensors((
|
||||
{'x': constant_op.constant([1, 2], dtype=dtypes.int64)},
|
||||
constant_op.constant([1, 2], dtype=dtypes.float32)))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'Expected shape'):
|
||||
sme.train(bad_shape_input_fn, steps=1)
|
||||
|
||||
def bad_dtype_input_fn():
|
||||
return dataset_ops.Dataset.from_tensors((
|
||||
{'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)},
|
||||
constant_op.constant([[1], [1]], dtype=dtypes.int64)))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'Expected dtype'):
|
||||
sme.train(bad_dtype_input_fn, steps=1)
|
||||
|
||||
def test_input_fn_with_global_step(self):
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(), self._get_tmp_dir())
|
||||
|
||||
def bad_input_fn():
|
||||
training.get_or_create_global_step()
|
||||
return dataset_ops.Dataset.from_tensors((
|
||||
{'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)},
|
||||
constant_op.constant([[1], [1]], dtype=dtypes.float32)))
|
||||
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
'Graph must not contain a global step tensor'):
|
||||
sme.train(bad_input_fn, steps=1)
|
||||
|
||||
def test_re_export_saved_model_serving_only(self):
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(), self._get_tmp_dir())
|
||||
sme.train(dummy_input_fn, steps=3)
|
||||
self.assertEqual(13, sme.get_variable_value('global_step'))
|
||||
self.assertEqual(60, sme.get_variable_value('some_var'))
|
||||
|
||||
predictions = next(sme.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'output': 503}, predictions)
|
||||
|
||||
# Export SavedModel, and test that the variable and prediction values are
|
||||
# the same.
|
||||
sme_export_dir = sme.export_savedmodel(
|
||||
self._get_tmp_dir(), dummy_serving_receiver_fn())
|
||||
|
||||
sme2 = saved_model_estimator.SavedModelEstimator(
|
||||
sme_export_dir, self._get_tmp_dir())
|
||||
self.assertEqual(60, sme.get_variable_value('some_var'))
|
||||
self.assertEqual(13, sme.get_variable_value('global_step'))
|
||||
|
||||
predictions = next(sme2.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'output': 503}, predictions)
|
||||
|
||||
def test_re_export_saved_model(self):
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(), self._get_tmp_dir())
|
||||
self.assertDictEqual(
|
||||
{'loss': 106, 'metrics/abs_err': 502, 'global_step': 10},
|
||||
sme.evaluate(dummy_input_fn, steps=1))
|
||||
|
||||
sme.train(dummy_input_fn, steps=3)
|
||||
self.assertDictEqual(
|
||||
{'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
|
||||
sme.evaluate(dummy_input_fn, steps=1))
|
||||
self.assertEqual(60, sme.get_variable_value('some_var'))
|
||||
|
||||
predictions = next(sme.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'output': 503}, predictions)
|
||||
|
||||
# Export SavedModel for all modes
|
||||
input_receiver_fn_map = {
|
||||
model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(),
|
||||
model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(),
|
||||
model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()}
|
||||
sme_export_dir = contrib_export.export_all_saved_models(
|
||||
sme, self._get_tmp_dir(), input_receiver_fn_map)
|
||||
|
||||
sme2 = saved_model_estimator.SavedModelEstimator(
|
||||
sme_export_dir, self._get_tmp_dir())
|
||||
self.assertDictEqual(
|
||||
{'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
|
||||
sme.evaluate(dummy_input_fn, steps=1))
|
||||
self.assertEqual(60, sme.get_variable_value('some_var'))
|
||||
|
||||
sme.train(dummy_input_fn, steps=7)
|
||||
self.assertEqual(20, sme.get_variable_value('global_step'))
|
||||
|
||||
predictions = next(sme2.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'output': 503}, predictions)
|
||||
|
||||
def test_load_saved_model_from_serving_only(self):
|
||||
def model_fn(features, labels, mode):
|
||||
_, _ = features, labels
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
loss=constant_op.constant([103]),
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
||||
predictions=constant_op.constant([502]),
|
||||
export_outputs={'test': export_output.ClassificationOutput(
|
||||
constant_op.constant([[32.]]))})
|
||||
|
||||
est = estimator.Estimator(model_fn, self._get_tmp_dir())
|
||||
est.train(input_fn=dummy_input_fn, steps=10)
|
||||
|
||||
def serving_input_receiver_fn():
|
||||
return export.ServingInputReceiver(
|
||||
{'test-features': constant_op.constant([[1], [1]])},
|
||||
array_ops.placeholder(dtype=dtypes.string))
|
||||
|
||||
export_dir = est.export_savedmodel(
|
||||
self._get_tmp_dir(), serving_input_receiver_fn)
|
||||
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
export_dir, self._get_tmp_dir())
|
||||
|
||||
def input_fn():
|
||||
return {'inputs': constant_op.constant('someinputstr')}
|
||||
|
||||
prediction = next(sme.predict(input_fn))
|
||||
self.assertDictEqual({'scores': 32}, prediction)
|
||||
|
||||
def test_with_local_init_op(self):
|
||||
def model_fn(features, labels, mode):
|
||||
_, _ = features, labels
|
||||
v = variables.Variable(21, name='some_var')
|
||||
scaffold = monitored_session.Scaffold(
|
||||
local_init_op=state_ops.assign_add(v, -3).op
|
||||
)
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
scaffold=scaffold,
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
||||
loss=array_ops.identity(v))
|
||||
export_dir = self._export_estimator(predict=False, model_fn=model_fn)
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
export_dir, self._get_tmp_dir())
|
||||
|
||||
eval_results1 = sme.evaluate(dummy_input_fn, steps=2)
|
||||
self.assertEqual(15, eval_results1['loss'])
|
||||
|
||||
sme.train(dummy_input_fn, steps=1)
|
||||
self.assertEqual(15, sme.get_variable_value('some_var'))
|
||||
|
||||
eval_results2 = sme.evaluate(dummy_input_fn, steps=5)
|
||||
self.assertEqual(12, eval_results2['loss'])
|
||||
|
||||
def test_with_working_input_fn(self):
|
||||
def model_fn(features, labels, mode):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = labels[0][0] + labels[1][0]
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
loss=loss,
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
||||
predictions={'features_0': array_ops.identity([features['x'][0][0]]),
|
||||
'features_1': array_ops.identity([features['x'][1][0]])})
|
||||
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(model_fn=model_fn), self._get_tmp_dir())
|
||||
eval_results = sme.evaluate(dummy_input_fn, steps=1)
|
||||
self.assertEqual(1, eval_results['loss'])
|
||||
|
||||
predictions = next(sme.predict(dummy_input_fn_features_only))
|
||||
self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions)
|
||||
|
||||
def test_control_dependency(self):
|
||||
# Control dependencies are saved with "^" appended to the start of the input
|
||||
# name. The input map must include control dependencies as well.
|
||||
def model_fn(features, labels, mode):
|
||||
_ = labels
|
||||
with ops.control_dependencies([features['x']]):
|
||||
loss = features['x'][1][0]
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
loss=loss,
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1))
|
||||
sme = saved_model_estimator.SavedModelEstimator(
|
||||
self._export_estimator(train=False, predict=False, model_fn=model_fn),
|
||||
self._get_tmp_dir())
|
||||
sme.evaluate(dummy_input_fn, steps=1) # Should run without error
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -568,14 +568,13 @@ class Estimator(object):
|
||||
def _assert_members_are_not_overridden(self):
|
||||
"""Asserts members of `Estimator` are not overridden."""
|
||||
allowed_overrides = set([
|
||||
'_call_input_fn', '_call_model_fn',
|
||||
'_call_input_fn', '_create_global_step',
|
||||
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
|
||||
'_create_global_step', '_create_and_assert_global_step',
|
||||
'_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
|
||||
'_estimator_api_names_v1', '_estimator_api_constants',
|
||||
'_estimator_api_constants_v1',
|
||||
'_validate_features_in_predict_input',
|
||||
'_add_meta_graph_for_mode'
|
||||
'_call_model_fn', '_add_meta_graph_for_mode'
|
||||
])
|
||||
estimator_members = set([m for m in Estimator.__dict__.keys()
|
||||
if not m.startswith('__')])
|
||||
@ -902,10 +901,9 @@ class Estimator(object):
|
||||
|
||||
with tf_session.Session(config=self._session_config) as session:
|
||||
|
||||
if estimator_spec.scaffold.local_init_op is not None:
|
||||
local_init_op = estimator_spec.scaffold.local_init_op
|
||||
else:
|
||||
local_init_op = monitored_session.Scaffold.default_local_init_op()
|
||||
local_init_op = (
|
||||
estimator_spec.scaffold.local_init_op or
|
||||
monitored_session.Scaffold.default_local_init_op())
|
||||
|
||||
# This saver will be used both for restoring variables now,
|
||||
# and in saving out the metagraph below. This ensures that any
|
||||
@ -1156,15 +1154,14 @@ class Estimator(object):
|
||||
worker_hooks = []
|
||||
with ops.Graph().as_default() as g, g.device(self._device_fn):
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
self._create_and_assert_global_step(g)
|
||||
global_step_tensor = self._create_and_assert_global_step(g)
|
||||
training_util._get_or_create_global_step_read() # pylint: disable=protected-access
|
||||
features, labels, input_hooks = (
|
||||
self._get_features_and_labels_from_input_fn(
|
||||
input_fn, model_fn_lib.ModeKeys.TRAIN))
|
||||
worker_hooks.extend(input_hooks)
|
||||
estimator_spec = self._call_model_fn(
|
||||
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
|
||||
global_step_tensor = training_util.get_global_step(g)
|
||||
training_util._get_or_create_global_step_read() # pylint: disable=protected-access
|
||||
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
|
||||
hooks, global_step_tensor,
|
||||
saving_listeners)
|
||||
@ -1367,8 +1364,10 @@ class Estimator(object):
|
||||
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
|
||||
global_step_tensor, saving_listeners):
|
||||
"""Train a model with the given Estimator Spec."""
|
||||
self._maybe_warm_start(self.latest_checkpoint())
|
||||
|
||||
if self._warm_start_settings:
|
||||
logging.info('Warm-starting with WarmStartSettings: %s' %
|
||||
(self._warm_start_settings,))
|
||||
warm_starting_util.warm_start(*self._warm_start_settings)
|
||||
# Check if the user created a loss summary, and add one if they didn't.
|
||||
# We assume here that the summary is called 'loss'. If it is not, we will
|
||||
# make another one with the name 'loss' to ensure it shows up in the right
|
||||
@ -1449,13 +1448,13 @@ class Estimator(object):
|
||||
def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
|
||||
"""Builds the graph and related hooks to run evaluation."""
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
self._create_and_assert_global_step(ops.get_default_graph())
|
||||
global_step_tensor = self._create_and_assert_global_step(
|
||||
ops.get_default_graph())
|
||||
features, labels, input_hooks = (
|
||||
self._get_features_and_labels_from_input_fn(input_fn,
|
||||
model_fn_lib.ModeKeys.EVAL))
|
||||
estimator_spec = self._call_model_fn(
|
||||
features, labels, model_fn_lib.ModeKeys.EVAL, self.config)
|
||||
global_step_tensor = training_util.get_global_step(ops.get_default_graph())
|
||||
|
||||
# Call to warm_start has to be after model_fn is called.
|
||||
self._maybe_warm_start(checkpoint_path)
|
||||
@ -1481,21 +1480,7 @@ class Estimator(object):
|
||||
all_hooks.extend(hooks)
|
||||
all_hooks.extend(list(estimator_spec.evaluation_hooks or []))
|
||||
|
||||
# New local variables have been added, so update the estimator spec's
|
||||
# local init op if it was defined.
|
||||
scaffold = estimator_spec.scaffold
|
||||
if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op:
|
||||
# Ensure that eval step has been created before updating local init op.
|
||||
evaluation._get_or_create_eval_step() # pylint: disable=protected-access
|
||||
|
||||
scaffold = monitored_session.Scaffold(
|
||||
local_init_op=control_flow_ops.group(
|
||||
estimator_spec.scaffold.local_init_op,
|
||||
monitored_session.Scaffold.default_local_init_op()),
|
||||
copy_from_scaffold=scaffold
|
||||
)
|
||||
|
||||
return scaffold, update_op, eval_dict, all_hooks
|
||||
return estimator_spec.scaffold, update_op, eval_dict, all_hooks
|
||||
|
||||
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
|
||||
all_hooks, output_dir):
|
||||
@ -1926,19 +1911,6 @@ class WarmStartSettings(
|
||||
)
|
||||
|
||||
|
||||
def _get_saved_model_ckpt(saved_model_dir):
|
||||
"""Return path to variables checkpoint in a SavedModel directory."""
|
||||
if not gfile.Exists(
|
||||
os.path.join(compat.as_bytes(saved_model_dir),
|
||||
compat.as_bytes('variables/variables.index'))):
|
||||
raise ValueError('Directory provided has an invalid SavedModel format: %s'
|
||||
% saved_model_dir)
|
||||
return os.path.join(
|
||||
compat.as_bytes(saved_model_dir),
|
||||
compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
|
||||
constants.VARIABLES_FILENAME)))
|
||||
|
||||
|
||||
def _get_default_warm_start_settings(warm_start_from):
|
||||
"""Returns default WarmStartSettings.
|
||||
|
||||
@ -1962,8 +1934,10 @@ def _get_default_warm_start_settings(warm_start_from):
|
||||
if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
|
||||
compat.as_bytes('variables/variables.index'))):
|
||||
logging.info('Warm-starting from a SavedModel')
|
||||
return WarmStartSettings(
|
||||
ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
|
||||
return WarmStartSettings(ckpt_to_initialize_from=os.path.join(
|
||||
compat.as_bytes(warm_start_from),
|
||||
compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
|
||||
constants.VARIABLES_FILENAME))))
|
||||
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
|
||||
elif isinstance(warm_start_from, WarmStartSettings):
|
||||
return warm_start_from
|
||||
|
@ -205,7 +205,7 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
|
||||
for input_src, input_dst in input_map.items():
|
||||
input_src = compat.as_str(input_src)
|
||||
if input_src.startswith('^'):
|
||||
src_name = compat.as_str(input_src[1:])
|
||||
src_name = compat.as_bytes(input_src[1:])
|
||||
dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access
|
||||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(
|
||||
options, src_name, dst_op)
|
||||
|
@ -696,67 +696,6 @@ def import_scoped_meta_graph(meta_graph_or_file,
|
||||
Raises:
|
||||
ValueError: If the graph_def contains unbound inputs.
|
||||
"""
|
||||
return import_scoped_meta_graph_with_return_elements(
|
||||
meta_graph_or_file, clear_devices, graph, import_scope, input_map,
|
||||
unbound_inputs_col_name, restore_collections_predicate)[0]
|
||||
|
||||
|
||||
def import_scoped_meta_graph_with_return_elements(
|
||||
meta_graph_or_file,
|
||||
clear_devices=False,
|
||||
graph=None,
|
||||
import_scope=None,
|
||||
input_map=None,
|
||||
unbound_inputs_col_name="unbound_inputs",
|
||||
restore_collections_predicate=(lambda key: True),
|
||||
return_elements=None):
|
||||
"""Imports graph from `MetaGraphDef` and returns vars and return elements.
|
||||
|
||||
This function takes a `MetaGraphDef` protocol buffer as input. If
|
||||
the argument is a file containing a `MetaGraphDef` protocol buffer ,
|
||||
it constructs a protocol buffer from the file content. The function
|
||||
then adds all the nodes from the `graph_def` field to the
|
||||
current graph, recreates the desired collections, and returns a dictionary of
|
||||
all the Variables imported into the name scope.
|
||||
|
||||
In combination with `export_scoped_meta_graph()`, this function can be used to
|
||||
|
||||
* Serialize a graph along with other Python objects such as `QueueRunner`,
|
||||
`Variable` into a `MetaGraphDef`.
|
||||
|
||||
* Restart training from a saved graph and checkpoints.
|
||||
|
||||
* Run inference from a saved graph and checkpoints.
|
||||
|
||||
Args:
|
||||
meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
|
||||
the path) containing a `MetaGraphDef`.
|
||||
clear_devices: Boolean which controls whether to clear device information
|
||||
from graph_def. Default false.
|
||||
graph: The `Graph` to import into. If `None`, use the default graph.
|
||||
import_scope: Optional `string`. Name scope into which to import the
|
||||
subgraph. If `None`, the graph is imported to the root name scope.
|
||||
input_map: A dictionary mapping input names (as strings) in `graph_def` to
|
||||
`Tensor` objects. The values of the named input tensors in the imported
|
||||
graph will be re-mapped to the respective `Tensor` values.
|
||||
unbound_inputs_col_name: Collection name for looking up unbound inputs.
|
||||
restore_collections_predicate: a predicate on collection names. A collection
|
||||
named c (i.e whose key is c) will be restored iff
|
||||
1) `restore_collections_predicate(c)` is True, and
|
||||
2) `c != unbound_inputs_col_name`.
|
||||
return_elements: A list of strings containing operation names in the
|
||||
`MetaGraphDef` that will be returned as `Operation` objects; and/or
|
||||
tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
|
||||
|
||||
Returns:
|
||||
A tuple of (
|
||||
dictionary of all the `Variables` imported into the name scope,
|
||||
list of `Operation` or `Tensor` objects from the `return_elements` list).
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph_def contains unbound inputs.
|
||||
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise ValueError("Exporting/importing meta graphs is not supported when "
|
||||
"eager execution is enabled.")
|
||||
@ -798,12 +737,11 @@ def import_scoped_meta_graph_with_return_elements(
|
||||
scope_to_prepend_to_names = graph.unique_name(
|
||||
import_scope or "", mark_as_used=False)
|
||||
|
||||
imported_return_elements = importer.import_graph_def(
|
||||
importer.import_graph_def(
|
||||
input_graph_def,
|
||||
name=(import_scope or scope_to_prepend_to_names),
|
||||
input_map=input_map,
|
||||
producer_op_list=producer_op_list,
|
||||
return_elements=return_elements)
|
||||
producer_op_list=producer_op_list)
|
||||
|
||||
# Restores all the other collections.
|
||||
variable_objects = {}
|
||||
@ -868,7 +806,7 @@ def import_scoped_meta_graph_with_return_elements(
|
||||
for v in variables:
|
||||
var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
|
||||
|
||||
return var_list, imported_return_elements
|
||||
return var_list
|
||||
|
||||
|
||||
def export_scoped_meta_graph(filename=None,
|
||||
|
@ -284,15 +284,12 @@ class SavedModelLoader(object):
|
||||
**saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
|
||||
|
||||
Returns:
|
||||
A tuple of
|
||||
* Saver defined by the MetaGraph, which can be used to restore the
|
||||
variable values.
|
||||
* List of `Operation`/`Tensor` objects returned from
|
||||
`tf.import_graph_def` (may be `None`).
|
||||
Saver defined by the MetaGraph, which can be used to restore the variable
|
||||
values.
|
||||
"""
|
||||
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
|
||||
with graph.as_default():
|
||||
return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access
|
||||
return tf_saver.import_meta_graph(
|
||||
meta_graph_def, import_scope=import_scope, **saver_kwargs)
|
||||
|
||||
def restore_variables(self, sess, saver, import_scope=None):
|
||||
@ -364,8 +361,8 @@ class SavedModelLoader(object):
|
||||
`MetagraphDef` proto of the graph that was loaded.
|
||||
"""
|
||||
with sess.graph.as_default():
|
||||
saver, _ = self.load_graph(sess.graph, tags, import_scope,
|
||||
**saver_kwargs)
|
||||
saver = self.load_graph(sess.graph, tags, import_scope,
|
||||
**saver_kwargs)
|
||||
self.restore_variables(sess, saver, import_scope)
|
||||
self.run_init_ops(sess, tags, import_scope)
|
||||
return self.get_meta_graph_def_from_tags(tags)
|
||||
|
@ -111,8 +111,7 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
def test_load_with_import_scope(self):
|
||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
saver, _ = loader.load_graph(
|
||||
sess.graph, ["foo_graph"], import_scope="baz")
|
||||
saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
|
||||
|
||||
# The default saver should not work when the import scope is set.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
@ -150,7 +149,7 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
def test_run_init_op(self):
|
||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||
graph = ops.Graph()
|
||||
saver, _ = loader.load_graph(graph, ["foo_graph"])
|
||||
saver = loader.load_graph(graph, ["foo_graph"])
|
||||
with self.test_session(graph=graph) as sess:
|
||||
loader.restore_variables(sess, saver)
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||
@ -204,7 +203,7 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
|
||||
loader = loader_impl.SavedModelLoader(path)
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
|
||||
saver = loader.load_graph(sess.graph, ["foo_graph"])
|
||||
self.assertFalse(variables._all_saveable_objects())
|
||||
self.assertIsNotNone(saver)
|
||||
|
||||
@ -213,18 +212,6 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
||||
|
||||
def test_load_saved_model_graph_with_return_elements(self):
|
||||
"""Ensure that the correct elements are returned."""
|
||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||
graph = ops.Graph()
|
||||
_, ret = loader.load_graph(graph, ["foo_graph"],
|
||||
return_elements=["y:0", "x:0"])
|
||||
|
||||
self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0])
|
||||
self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "not found in graph"):
|
||||
loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1928,14 +1928,6 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
|
||||
execution is enabled.
|
||||
@end_compatibility
|
||||
""" # pylint: disable=g-doc-exception
|
||||
return _import_meta_graph_with_return_elements(
|
||||
meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
|
||||
|
||||
|
||||
def _import_meta_graph_with_return_elements(
|
||||
meta_graph_or_file, clear_devices=False, import_scope=None,
|
||||
return_elements=None, **kwargs):
|
||||
"""Import MetaGraph, and return both a saver and returned elements."""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("Exporting/importing meta graphs is not supported when "
|
||||
"eager execution is enabled. No graph exists when eager "
|
||||
@ -1945,22 +1937,12 @@ def _import_meta_graph_with_return_elements(
|
||||
else:
|
||||
meta_graph_def = meta_graph_or_file
|
||||
|
||||
imported_vars, imported_return_elements = (
|
||||
meta_graph.import_scoped_meta_graph_with_return_elements(
|
||||
meta_graph_def,
|
||||
clear_devices=clear_devices,
|
||||
import_scope=import_scope,
|
||||
return_elements=return_elements,
|
||||
**kwargs))
|
||||
imported_vars = meta_graph.import_scoped_meta_graph(
|
||||
meta_graph_def,
|
||||
clear_devices=clear_devices,
|
||||
import_scope=import_scope,
|
||||
**kwargs)
|
||||
|
||||
saver = _create_saver_from_imported_meta_graph(
|
||||
meta_graph_def, import_scope, imported_vars)
|
||||
return saver, imported_return_elements
|
||||
|
||||
|
||||
def _create_saver_from_imported_meta_graph(
|
||||
meta_graph_def, import_scope, imported_vars):
|
||||
"""Return a saver for restoring variable values to an imported MetaGraph."""
|
||||
if meta_graph_def.HasField("saver_def"):
|
||||
# Infer the scope that is prepended by `import_scoped_meta_graph`.
|
||||
scope = import_scope
|
||||
|
Loading…
Reference in New Issue
Block a user