Automated rollback of commit 8257891f37

PiperOrigin-RevId: 205466000
This commit is contained in:
Katherine Wu 2018-07-20 15:45:15 -07:00 committed by TensorFlower Gardener
parent 5e876a8c25
commit 6c528feaf8
10 changed files with 35 additions and 1017 deletions

View File

@ -28,7 +28,6 @@ py_library(
":multi_head",
":replicate_model_fn",
":rnn",
":saved_model_estimator",
"//tensorflow:tensorflow_py_no_contrib",
],
)
@ -466,43 +465,3 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "saved_model_estimator",
srcs = ["python/estimator/saved_model_estimator.py"],
deps = [
":export",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/saved_model",
],
)
py_test(
name = "saved_model_estimator_test",
size = "medium",
srcs = ["python/estimator/saved_model_estimator_test.py"],
srcs_version = "PY2AND3",
deps = [
":export",
":saved_model_estimator",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:model_fn",
],
)

View File

@ -33,8 +33,6 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import *
from tensorflow.contrib.estimator.python.estimator.multi_head import *
from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import *
from tensorflow.contrib.estimator.python.estimator.rnn import *
from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import *
from tensorflow.python.estimator.export.export import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
@ -72,9 +70,6 @@ _allowed_symbols = [
'stop_if_higher_hook',
'stop_if_no_increase_hook',
'stop_if_no_decrease_hook',
'build_raw_supervised_input_receiver_fn',
'build_supervised_input_receiver_fn_from_input_fn',
'SavedModelEstimator'
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@ -1,445 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class that creates an Estimator from a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export as export_lib
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util
class SavedModelEstimator(estimator_lib.Estimator):
"""Create an Estimator from a SavedModel.
Only SavedModels exported with
`tf.contrib.estimator.export_all_saved_models()` or
`tf.estimator.Estimator.export_savedmodel()` are supported for this class.
Example with `tf.estimator.DNNClassifier`:
**Step 1: Create and train DNNClassifier.**
```python
feature1 = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_list(
key='feature1', vocabulary_list=('green', 'yellow')), dimension=1)
feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0)
classifier = tf.estimator.DNNClassifier(
hidden_units=[4,2], feature_columns=[feature1, feature2])
def input_fn():
features = {'feature1': tf.constant(['green', 'green', 'yellow']),
'feature2': tf.constant([3.5, 4.2, 6.1])}
label = tf.constant([1., 0., 0.])
return tf.data.Dataset.from_tensors((features, label)).repeat()
classifier.train(input_fn=input_fn, steps=10)
```
**Step 2: Export classifier.**
First, build functions that specify the expected inputs.
```python
# During train and evaluation, both the features and labels should be defined.
supervised_input_receiver_fn = (
tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
{'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
tf.placeholder(dtype=tf.float32, shape=[None])))
# During predict mode, expect to receive a `tf.Example` proto, so a parsing
# function is used.
serving_input_receiver_fn = (
tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([feature1, feature2])))
```
Next, export the model as a SavedModel. A timestamped directory will be
created (for example `/tmp/export_all/1234567890`).
```python
# Option 1: Save all modes (train, eval, predict)
export_dir = tf.contrib.estimator.export_all_saved_models(
classifier, '/tmp/export_all',
{tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn,
tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn,
tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn})
# Option 2: Only export predict mode
export_dir = classifier.export_savedmodel(
'/tmp/export_predict', serving_input_receiver_fn)
```
**Step 3: Create a SavedModelEstimator from the exported SavedModel.**
```python
est = tf.contrib.estimator.SavedModelEstimator(export_dir)
# If all modes were exported, you can immediately evaluate and predict, or
# continue training. Otherwise only predict is available.
eval_results = est.evaluate(input_fn=input_fn, steps=1)
print(eval_results)
est.train(input_fn=input_fn, steps=20)
def predict_input_fn():
example = example_pb2.Example()
example.features.feature['feature1'].bytes_list.value.extend(['yellow'])
example.features.feature['feature2'].float_list.value.extend([1.])
return {'inputs':tf.constant([example.SerializeToString()])}
predictions = est.predict(predict_input_fn)
print(next(predictions))
```
"""
def __init__(self, saved_model_dir, model_dir=None):
"""Initialize a SavedModelEstimator.
The SavedModelEstimator loads its model function and variable values from
the graphs defined in the SavedModel. There is no option to pass in
`RunConfig` or `params` arguments, because the model function graph is
defined statically in the SavedModel.
Args:
saved_model_dir: Directory containing SavedModel protobuf and subfolders.
model_dir: Directory to save new checkpoints during training.
Raises:
NotImplementedError: If a DistributionStrategy is defined in the config.
Unless the SavedModelEstimator is subclassed, this shouldn't happen.
"""
checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access
vars_to_warm_start = [name for name, _ in
checkpoint_utils.list_variables(checkpoint)]
warm_start_settings = estimator_lib.WarmStartSettings(
ckpt_to_initialize_from=checkpoint,
vars_to_warm_start=vars_to_warm_start)
super(SavedModelEstimator, self).__init__(
model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
warm_start_from=warm_start_settings)
if self._distribution is not None:
raise NotImplementedError(
'SavedModelEstimator currently does not support '
'DistributionStrategy.')
self.saved_model_dir = saved_model_dir
self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
self._available_modes = self._extract_available_modes()
def _extract_available_modes(self):
"""Return list of modes found in SavedModel."""
available_modes = []
logging.info('Checking available modes for SavedModelEstimator.')
for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
model_fn_lib.ModeKeys.PREDICT]:
try:
self._get_meta_graph_def_for_mode(mode)
except RuntimeError:
logging.warning('%s mode not found in SavedModel.' % mode)
continue
if self._get_signature_def_for_mode(mode) is not None:
available_modes.append(mode)
logging.info('Available modes for Estimator: %s' % available_modes)
return available_modes
def _validate_mode(self, mode):
"""Make sure that mode can be run using the SavedModel."""
if mode not in self._available_modes:
raise RuntimeError('%s mode is not available in the SavedModel. Use '
'saved_model_cli to check that the Metagraph for this '
'mode has been exported.' % mode)
def _get_meta_graph_def_for_mode(self, mode):
tags = model_fn_lib.EXPORT_TAG_MAP[mode]
return self.saved_model_loader.get_meta_graph_def_from_tags(tags)
def _get_signature_def_for_mode(self, mode):
meta_graph_def = self._get_meta_graph_def_for_mode(mode)
sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
if mode == model_fn_lib.ModeKeys.PREDICT else mode)
if sig_def_key not in meta_graph_def.signature_def:
logging.warning('Metagraph for mode %s was found, but SignatureDef with'
' key \"%s\" is missing.' % (mode, sig_def_key))
return None
return meta_graph_def.signature_def[sig_def_key]
def _create_and_assert_global_step(self, graph):
# Do nothing here. The global step variable will be created/loaded from the
# SavedModel. If a global step variable were created here, the result
# will be two duplicate global step variables, causing issues during
# the warm-start phase.
# Due to the global variable being created in the model function, this may
# cause issues when running DistributionStrategy. Thus, DistributionStrategy
# is not yet supported with SavedModelEstimator.
pass
def _model_fn_from_saved_model(self, features, labels, mode):
"""Load a SavedModel graph and return an EstimatorSpec."""
# TODO(kathywu): Model function loads placeholders from the graph. Calling
# export_all_saved_models creates another placeholder for the inputs, on top
# of the original placeholders. There should be a way to avoid this.
self._validate_mode(mode)
g = ops.get_default_graph()
if training_util.get_global_step(g) is not None:
raise RuntimeError(
'Graph must not contain a global step tensor before the SavedModel is'
' loaded. Please make sure that the input function does not create a '
'global step.')
# Extract SignatureDef for information about the input and output tensors.
signature_def = self._get_signature_def_for_mode(mode)
# Generate input map for replacing the inputs in the SavedModel graph with
# the provided features and labels.
input_map = _generate_input_map(signature_def, features, labels)
# Create a list of the names of output tensors. When the graph is loaded,
# names of the output tensors may be remapped. This ensures that the correct
# tensors are returned in the EstimatorSpec.
output_tensor_names = [
value.name for value in six.itervalues(signature_def.outputs)]
# Load the graph. `output_tensors` contains output `Tensors` in the same
# same order as the `output_tensor_names` list.
tags = model_fn_lib.EXPORT_TAG_MAP[mode]
_, output_tensors = self.saved_model_loader.load_graph(
g, tags, input_map=input_map, return_elements=output_tensor_names)
# Create a scaffold from the MetaGraphDef that contains ops to initialize
# the graph. This should mirror the steps from _add_meta_graph_for_mode(),
# which creates a MetaGraphDef from the EstimatorSpec's scaffold.
scaffold = monitored_session.Scaffold(
local_init_op=loader_impl._get_legacy_init_op_tensor( # pylint: disable=protected-access
self._get_meta_graph_def_for_mode(mode)))
# Ensure that a global step tensor has been created.
global_step_tensor = training_util.get_global_step(g)
training_util.assert_global_step(global_step_tensor)
# Extract values to return in the EstimatorSpec.
output_map = dict(zip(output_tensor_names, output_tensors))
outputs = {key: output_map[value.name]
for key, value in six.iteritems(signature_def.outputs)}
loss, predictions, metrics = _validate_and_extract_outputs(
mode, outputs, signature_def.method_name)
train_op = ops.get_collection(constants.TRAIN_OP_KEY)
if len(train_op) > 1:
raise RuntimeError('Multiple ops found in the train_op collection.')
train_op = None if not train_op else train_op[0]
_clear_saved_model_collections()
return model_fn_lib.EstimatorSpec(
scaffold=scaffold,
mode=mode,
loss=loss,
train_op=train_op,
predictions=predictions,
eval_metric_ops=metrics)
def _clear_saved_model_collections():
"""Clear collections that are expected empty when exporting a SavedModel.
The SavedModel builder uses these collections to track ops necessary to
restore the graph state. These collections are expected to be empty before
MetaGraphs are added to the builder.
"""
del ops.get_collection_ref(constants.ASSETS_KEY)[:]
del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:]
del ops.get_collection_ref(constants.MAIN_OP_KEY)[:]
del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:]
def _generate_input_map(signature_def, features, labels):
"""Return dict mapping an input tensor name to a feature or label tensor.
Args:
signature_def: SignatureDef loaded from SavedModel
features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
`SparseTensor`, specifying the features to be passed to the model.
labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
`SparseTensor`, specifying the labels to be passed to the model. May be
`None`.
Returns:
dict mapping string names of inputs to features or labels tensors
Raises:
ValueError: if SignatureDef inputs are not completely mapped by the input
features and labels.
"""
# pylint: disable=protected-access
if not isinstance(features, dict):
features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features}
if labels is not None and not isinstance(labels, dict):
labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels}
# pylint: enable=protected-access
inputs = signature_def.inputs
input_map = {}
for key, tensor_info in six.iteritems(inputs):
input_name = tensor_info.name
if ':' in input_name:
input_name = input_name[:input_name.find(':')]
# When tensors are used as control inputs for operations, their names are
# prepended with a '^' character in the GraphDef. To handle possible control
# flow edge cases, control input names must be included in the input map.
control_dependency_name = '^' + input_name
if key in features:
_check_same_dtype_and_shape(features[key], tensor_info, key)
input_map[input_name] = input_map[control_dependency_name] = features[key]
elif labels is not None and key in labels:
_check_same_dtype_and_shape(labels[key], tensor_info, key)
input_map[input_name] = input_map[control_dependency_name] = labels[key]
else:
raise ValueError(
'Key \"%s\" not found in features or labels passed in to the model '
'function. All required keys: %s' % (key, inputs.keys()))
return input_map
def _check_same_dtype_and_shape(tensor, tensor_info, name):
"""Validate that tensor has the same properties as the TensorInfo proto.
Args:
tensor: a `Tensor` object.
tensor_info: a `TensorInfo` proto.
name: Name of the input (to identify Tensor if an error is raised).
Raises:
ValueError: If the tensor shape or dtype don't match the TensorInfo
"""
dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype))
shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape)
if dtype_error or shape_error:
msg = 'Tensor shape and/or dtype validation failed for input %s:' % name
if dtype_error:
msg += ('\n\tExpected dtype: %s, Got: %s'
% (dtypes.DType(tensor_info.dtype), tensor.dtype))
if shape_error:
msg += ('\n\tExpected shape: %s, Got: %s'
% (tensor_shape.TensorShape(tensor_info.tensor_shape),
tensor.shape))
raise ValueError(msg)
def _extract_eval_metrics(output_dict):
"""Return a eval metric dict extracted from the output_dict.
Eval metrics consist of a value tensor and an update op. Both must be in the
passed-in tensor dictionary for an eval metric to be added to the returned
dictionary.
Args:
output_dict: a dict that maps strings to tensors.
Returns:
dict mapping strings to (value, update_op) tuples.
"""
# pylint: disable=protected-access
metric_ops = {}
separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR
for key, tensor in six.iteritems(output_dict):
split_key = key.split(separator_char)
# The metric name may contain the separator character, so recreate its name.
metric_name = separator_char.join(split_key[:-1])
if split_key[0] == export_output._SupervisedOutput.METRICS_NAME:
# If the key ends with the value suffix, and there is a corresponding
# key ending with the update_op suffix, then add tensors to metrics dict.
if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX:
update_op = ''.join(
[metric_name, separator_char,
export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX])
if update_op in output_dict:
update_op_tensor = output_dict[update_op]
metric_ops[metric_name] = (tensor, update_op_tensor)
# pylint: enable=protected-access
return metric_ops
def _validate_and_extract_outputs(mode, output_dict, method_name):
"""Extract values from SignatureDef output dictionary.
Args:
mode: One of the modes enumerated in `tf.estimator.ModeKeys`.
output_dict: dict of string SignatureDef keys to `Tensor`.
method_name: Method name of the SignatureDef as a string.
Returns:
Tuple of (
loss: `Tensor` object,
predictions: dictionary mapping string keys to `Tensor` objects,
metrics: dictionary mapping string keys to a tuple of two `Tensor` objects
)
Raises:
RuntimeError: raised if SignatureDef has an invalid method name for the mode
"""
# pylint: disable=protected-access
loss, predictions, metrics = None, None, None
if mode == model_fn_lib.ModeKeys.PREDICT:
predictions = output_dict
else:
# Validate that the SignatureDef's method name matches the expected name for
# the given mode.
expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
if mode == model_fn_lib.ModeKeys.EVAL:
expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME
if method_name != expected_method_name:
raise RuntimeError(
'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t'
'Got: %s\nPlease ensure that the SavedModel was exported with '
'`tf.contrib.estimator.export_all_saved_models()`.' %
(mode, expected_method_name, method_name))
# Extract loss, metrics and predictions from the output dict.
loss = output_dict[export_output._SupervisedOutput.LOSS_NAME]
metrics = _extract_eval_metrics(output_dict)
predictions = {
key: value for key, value in six.iteritems(output_dict)
if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == (
export_output._SupervisedOutput.PREDICTIONS_NAME)}
# pylint: enable=protected-access
return loss, predictions, metrics

View File

@ -1,369 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SavedModelEstimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import shutil
import tempfile
from tensorflow.contrib.estimator.python.estimator import export as contrib_export
from tensorflow.contrib.estimator.python.estimator import saved_model_estimator
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training
def dummy_input_fn():
return dataset_ops.Dataset.from_tensors((
{'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)},
constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat()
def dummy_input_fn_features_only():
return dataset_ops.Dataset.from_tensors(
{'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat()
def dummy_supervised_receiver_fn():
feature_spec = {
'x': array_ops.placeholder(
dtype=dtypes.int64, shape=(2, 1), name='feature_x'),
}
label_spec = array_ops.placeholder(
dtype=dtypes.float32, shape=[2, 1], name='truth')
return export.build_raw_supervised_input_receiver_fn(
feature_spec, label_spec)
def dummy_serving_receiver_fn():
feature_spec = {'x': array_ops.placeholder(
dtype=dtypes.int64, shape=(2, 1), name='feature_x'),}
return export.build_raw_serving_input_receiver_fn(feature_spec)
def model_fn_diff_modes(features, labels, mode):
_, _ = features, labels
v = variables.Variable(21, name='some_var')
train_op = None
loss = constant_op.constant(104)
if mode == model_fn_lib.ModeKeys.TRAIN:
loss = constant_op.constant(105)
predictions = constant_op.constant([501])
train_op = control_flow_ops.group(
state_ops.assign_add(training.get_global_step(), 1),
state_ops.assign_add(v, 3))
elif mode == model_fn_lib.ModeKeys.EVAL:
loss = constant_op.constant(106)
predictions = constant_op.constant([502])
else:
loss = constant_op.constant(107)
predictions = constant_op.constant([503])
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
train_op=train_op,
eval_metric_ops={
'abs_err': metrics_lib.mean_absolute_error(
constant_op.constant(0), predictions)},
predictions=predictions)
class SavedModelEstimatorTest(test.TestCase):
def setUp(self):
self.tmpdirs = []
def tearDown(self):
for tmpdir in self.tmpdirs:
# gfile.DeleteRecursively fails in the windows cmake test, so use shutil.
shutil.rmtree(tmpdir, ignore_errors=True)
self.tmpdirs = []
def _get_tmp_dir(self):
tmpdir = tempfile.mkdtemp()
self.tmpdirs.append(tmpdir)
return tmpdir
def _export_estimator(self, train=True, evaluate=True, predict=True,
model_fn=model_fn_diff_modes):
est = estimator.Estimator(model_fn, self._get_tmp_dir())
est.train(input_fn=dummy_input_fn, steps=10)
input_receiver_fn_map = {}
if train:
input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = (
dummy_supervised_receiver_fn())
if evaluate:
input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = (
dummy_supervised_receiver_fn())
if predict:
input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = (
dummy_serving_receiver_fn())
export_base_path = self._get_tmp_dir()
export_dir = contrib_export.export_all_saved_models(
est, export_base_path, input_receiver_fn_map)
return export_dir
def test_load_all_modes(self):
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(), self._get_tmp_dir())
sme.train(input_fn=dummy_input_fn, steps=1)
sme.train(input_fn=dummy_input_fn, steps=2)
self.assertEqual(13, sme.get_variable_value('global_step'))
self.assertEqual(60, sme.get_variable_value('some_var'))
eval_results = sme.evaluate(dummy_input_fn, steps=5)
self.assertEqual(13, eval_results['global_step'])
self.assertEqual(106, eval_results['loss'])
self.assertEqual(502, eval_results['metrics/abs_err'])
predictions = next(sme.predict(dummy_input_fn_features_only))
self.assertDictEqual({'output': 503}, predictions)
def test_load_all_modes_no_train(self):
"""Ensure that all functions can be used without requiring a ckpt."""
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(), self._get_tmp_dir())
eval_results = sme.evaluate(dummy_input_fn, steps=5)
self.assertEqual(10, eval_results['global_step'])
self.assertEqual(106, eval_results['loss'])
self.assertEqual(502, eval_results['metrics/abs_err'])
predictions = next(sme.predict(dummy_input_fn_features_only))
self.assertDictEqual({'output': 503}, predictions)
def test_partial_exported_estimator(self):
sme1 = saved_model_estimator.SavedModelEstimator(
self._export_estimator(train=False, predict=False), self._get_tmp_dir())
sme1.evaluate(dummy_input_fn, steps=5)
with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'):
sme1.train(input_fn=dummy_input_fn, steps=1)
with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'):
next(sme1.predict(dummy_input_fn_features_only))
sme2 = saved_model_estimator.SavedModelEstimator(
self._export_estimator(evaluate=False), self._get_tmp_dir())
sme2.train(input_fn=dummy_input_fn, steps=1)
next(sme2.predict(dummy_input_fn_features_only))
with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'):
sme2.evaluate(dummy_input_fn, steps=5)
def test_with_incorrect_input(self):
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(), self._get_tmp_dir())
def bad_shape_input_fn():
return dataset_ops.Dataset.from_tensors((
{'x': constant_op.constant([1, 2], dtype=dtypes.int64)},
constant_op.constant([1, 2], dtype=dtypes.float32)))
with self.assertRaisesRegexp(ValueError, 'Expected shape'):
sme.train(bad_shape_input_fn, steps=1)
def bad_dtype_input_fn():
return dataset_ops.Dataset.from_tensors((
{'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)},
constant_op.constant([[1], [1]], dtype=dtypes.int64)))
with self.assertRaisesRegexp(ValueError, 'Expected dtype'):
sme.train(bad_dtype_input_fn, steps=1)
def test_input_fn_with_global_step(self):
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(), self._get_tmp_dir())
def bad_input_fn():
training.get_or_create_global_step()
return dataset_ops.Dataset.from_tensors((
{'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)},
constant_op.constant([[1], [1]], dtype=dtypes.float32)))
with self.assertRaisesRegexp(RuntimeError,
'Graph must not contain a global step tensor'):
sme.train(bad_input_fn, steps=1)
def test_re_export_saved_model_serving_only(self):
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(), self._get_tmp_dir())
sme.train(dummy_input_fn, steps=3)
self.assertEqual(13, sme.get_variable_value('global_step'))
self.assertEqual(60, sme.get_variable_value('some_var'))
predictions = next(sme.predict(dummy_input_fn_features_only))
self.assertDictEqual({'output': 503}, predictions)
# Export SavedModel, and test that the variable and prediction values are
# the same.
sme_export_dir = sme.export_savedmodel(
self._get_tmp_dir(), dummy_serving_receiver_fn())
sme2 = saved_model_estimator.SavedModelEstimator(
sme_export_dir, self._get_tmp_dir())
self.assertEqual(60, sme.get_variable_value('some_var'))
self.assertEqual(13, sme.get_variable_value('global_step'))
predictions = next(sme2.predict(dummy_input_fn_features_only))
self.assertDictEqual({'output': 503}, predictions)
def test_re_export_saved_model(self):
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(), self._get_tmp_dir())
self.assertDictEqual(
{'loss': 106, 'metrics/abs_err': 502, 'global_step': 10},
sme.evaluate(dummy_input_fn, steps=1))
sme.train(dummy_input_fn, steps=3)
self.assertDictEqual(
{'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
sme.evaluate(dummy_input_fn, steps=1))
self.assertEqual(60, sme.get_variable_value('some_var'))
predictions = next(sme.predict(dummy_input_fn_features_only))
self.assertDictEqual({'output': 503}, predictions)
# Export SavedModel for all modes
input_receiver_fn_map = {
model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(),
model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()}
sme_export_dir = contrib_export.export_all_saved_models(
sme, self._get_tmp_dir(), input_receiver_fn_map)
sme2 = saved_model_estimator.SavedModelEstimator(
sme_export_dir, self._get_tmp_dir())
self.assertDictEqual(
{'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
sme.evaluate(dummy_input_fn, steps=1))
self.assertEqual(60, sme.get_variable_value('some_var'))
sme.train(dummy_input_fn, steps=7)
self.assertEqual(20, sme.get_variable_value('global_step'))
predictions = next(sme2.predict(dummy_input_fn_features_only))
self.assertDictEqual({'output': 503}, predictions)
def test_load_saved_model_from_serving_only(self):
def model_fn(features, labels, mode):
_, _ = features, labels
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant([103]),
train_op=state_ops.assign_add(training.get_global_step(), 1),
predictions=constant_op.constant([502]),
export_outputs={'test': export_output.ClassificationOutput(
constant_op.constant([[32.]]))})
est = estimator.Estimator(model_fn, self._get_tmp_dir())
est.train(input_fn=dummy_input_fn, steps=10)
def serving_input_receiver_fn():
return export.ServingInputReceiver(
{'test-features': constant_op.constant([[1], [1]])},
array_ops.placeholder(dtype=dtypes.string))
export_dir = est.export_savedmodel(
self._get_tmp_dir(), serving_input_receiver_fn)
sme = saved_model_estimator.SavedModelEstimator(
export_dir, self._get_tmp_dir())
def input_fn():
return {'inputs': constant_op.constant('someinputstr')}
prediction = next(sme.predict(input_fn))
self.assertDictEqual({'scores': 32}, prediction)
def test_with_local_init_op(self):
def model_fn(features, labels, mode):
_, _ = features, labels
v = variables.Variable(21, name='some_var')
scaffold = monitored_session.Scaffold(
local_init_op=state_ops.assign_add(v, -3).op
)
return model_fn_lib.EstimatorSpec(
mode,
scaffold=scaffold,
train_op=state_ops.assign_add(training.get_global_step(), 1),
loss=array_ops.identity(v))
export_dir = self._export_estimator(predict=False, model_fn=model_fn)
sme = saved_model_estimator.SavedModelEstimator(
export_dir, self._get_tmp_dir())
eval_results1 = sme.evaluate(dummy_input_fn, steps=2)
self.assertEqual(15, eval_results1['loss'])
sme.train(dummy_input_fn, steps=1)
self.assertEqual(15, sme.get_variable_value('some_var'))
eval_results2 = sme.evaluate(dummy_input_fn, steps=5)
self.assertEqual(12, eval_results2['loss'])
def test_with_working_input_fn(self):
def model_fn(features, labels, mode):
loss = None
if labels is not None:
loss = labels[0][0] + labels[1][0]
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
train_op=state_ops.assign_add(training.get_global_step(), 1),
predictions={'features_0': array_ops.identity([features['x'][0][0]]),
'features_1': array_ops.identity([features['x'][1][0]])})
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(model_fn=model_fn), self._get_tmp_dir())
eval_results = sme.evaluate(dummy_input_fn, steps=1)
self.assertEqual(1, eval_results['loss'])
predictions = next(sme.predict(dummy_input_fn_features_only))
self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions)
def test_control_dependency(self):
# Control dependencies are saved with "^" appended to the start of the input
# name. The input map must include control dependencies as well.
def model_fn(features, labels, mode):
_ = labels
with ops.control_dependencies([features['x']]):
loss = features['x'][1][0]
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
train_op=state_ops.assign_add(training.get_global_step(), 1))
sme = saved_model_estimator.SavedModelEstimator(
self._export_estimator(train=False, predict=False, model_fn=model_fn),
self._get_tmp_dir())
sme.evaluate(dummy_input_fn, steps=1) # Should run without error
if __name__ == '__main__':
test.main()

View File

@ -568,14 +568,13 @@ class Estimator(object):
def _assert_members_are_not_overridden(self):
"""Asserts members of `Estimator` are not overridden."""
allowed_overrides = set([
'_call_input_fn', '_call_model_fn',
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
'_create_global_step', '_create_and_assert_global_step',
'_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
'_estimator_api_names_v1', '_estimator_api_constants',
'_estimator_api_constants_v1',
'_validate_features_in_predict_input',
'_add_meta_graph_for_mode'
'_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
@ -902,10 +901,9 @@ class Estimator(object):
with tf_session.Session(config=self._session_config) as session:
if estimator_spec.scaffold.local_init_op is not None:
local_init_op = estimator_spec.scaffold.local_init_op
else:
local_init_op = monitored_session.Scaffold.default_local_init_op()
local_init_op = (
estimator_spec.scaffold.local_init_op or
monitored_session.Scaffold.default_local_init_op())
# This saver will be used both for restoring variables now,
# and in saving out the metagraph below. This ensures that any
@ -1156,15 +1154,14 @@ class Estimator(object):
worker_hooks = []
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
self._create_and_assert_global_step(g)
global_step_tensor = self._create_and_assert_global_step(g)
training_util._get_or_create_global_step_read() # pylint: disable=protected-access
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(
input_fn, model_fn_lib.ModeKeys.TRAIN))
worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
global_step_tensor = training_util.get_global_step(g)
training_util._get_or_create_global_step_read() # pylint: disable=protected-access
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
hooks, global_step_tensor,
saving_listeners)
@ -1367,8 +1364,10 @@ class Estimator(object):
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
global_step_tensor, saving_listeners):
"""Train a model with the given Estimator Spec."""
self._maybe_warm_start(self.latest_checkpoint())
if self._warm_start_settings:
logging.info('Warm-starting with WarmStartSettings: %s' %
(self._warm_start_settings,))
warm_starting_util.warm_start(*self._warm_start_settings)
# Check if the user created a loss summary, and add one if they didn't.
# We assume here that the summary is called 'loss'. If it is not, we will
# make another one with the name 'loss' to ensure it shows up in the right
@ -1449,13 +1448,13 @@ class Estimator(object):
def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
"""Builds the graph and related hooks to run evaluation."""
random_seed.set_random_seed(self._config.tf_random_seed)
self._create_and_assert_global_step(ops.get_default_graph())
global_step_tensor = self._create_and_assert_global_step(
ops.get_default_graph())
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(input_fn,
model_fn_lib.ModeKeys.EVAL))
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, self.config)
global_step_tensor = training_util.get_global_step(ops.get_default_graph())
# Call to warm_start has to be after model_fn is called.
self._maybe_warm_start(checkpoint_path)
@ -1481,21 +1480,7 @@ class Estimator(object):
all_hooks.extend(hooks)
all_hooks.extend(list(estimator_spec.evaluation_hooks or []))
# New local variables have been added, so update the estimator spec's
# local init op if it was defined.
scaffold = estimator_spec.scaffold
if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op:
# Ensure that eval step has been created before updating local init op.
evaluation._get_or_create_eval_step() # pylint: disable=protected-access
scaffold = monitored_session.Scaffold(
local_init_op=control_flow_ops.group(
estimator_spec.scaffold.local_init_op,
monitored_session.Scaffold.default_local_init_op()),
copy_from_scaffold=scaffold
)
return scaffold, update_op, eval_dict, all_hooks
return estimator_spec.scaffold, update_op, eval_dict, all_hooks
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
@ -1926,19 +1911,6 @@ class WarmStartSettings(
)
def _get_saved_model_ckpt(saved_model_dir):
"""Return path to variables checkpoint in a SavedModel directory."""
if not gfile.Exists(
os.path.join(compat.as_bytes(saved_model_dir),
compat.as_bytes('variables/variables.index'))):
raise ValueError('Directory provided has an invalid SavedModel format: %s'
% saved_model_dir)
return os.path.join(
compat.as_bytes(saved_model_dir),
compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
constants.VARIABLES_FILENAME)))
def _get_default_warm_start_settings(warm_start_from):
"""Returns default WarmStartSettings.
@ -1962,8 +1934,10 @@ def _get_default_warm_start_settings(warm_start_from):
if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
compat.as_bytes('variables/variables.index'))):
logging.info('Warm-starting from a SavedModel')
return WarmStartSettings(
ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
return WarmStartSettings(ckpt_to_initialize_from=os.path.join(
compat.as_bytes(warm_start_from),
compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
constants.VARIABLES_FILENAME))))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from

View File

@ -205,7 +205,7 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
for input_src, input_dst in input_map.items():
input_src = compat.as_str(input_src)
if input_src.startswith('^'):
src_name = compat.as_str(input_src[1:])
src_name = compat.as_bytes(input_src[1:])
dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access
c_api.TF_ImportGraphDefOptionsRemapControlDependency(
options, src_name, dst_op)

View File

@ -696,67 +696,6 @@ def import_scoped_meta_graph(meta_graph_or_file,
Raises:
ValueError: If the graph_def contains unbound inputs.
"""
return import_scoped_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices, graph, import_scope, input_map,
unbound_inputs_col_name, restore_collections_predicate)[0]
def import_scoped_meta_graph_with_return_elements(
meta_graph_or_file,
clear_devices=False,
graph=None,
import_scope=None,
input_map=None,
unbound_inputs_col_name="unbound_inputs",
restore_collections_predicate=(lambda key: True),
return_elements=None):
"""Imports graph from `MetaGraphDef` and returns vars and return elements.
This function takes a `MetaGraphDef` protocol buffer as input. If
the argument is a file containing a `MetaGraphDef` protocol buffer ,
it constructs a protocol buffer from the file content. The function
then adds all the nodes from the `graph_def` field to the
current graph, recreates the desired collections, and returns a dictionary of
all the Variables imported into the name scope.
In combination with `export_scoped_meta_graph()`, this function can be used to
* Serialize a graph along with other Python objects such as `QueueRunner`,
`Variable` into a `MetaGraphDef`.
* Restart training from a saved graph and checkpoints.
* Run inference from a saved graph and checkpoints.
Args:
meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
the path) containing a `MetaGraphDef`.
clear_devices: Boolean which controls whether to clear device information
from graph_def. Default false.
graph: The `Graph` to import into. If `None`, use the default graph.
import_scope: Optional `string`. Name scope into which to import the
subgraph. If `None`, the graph is imported to the root name scope.
input_map: A dictionary mapping input names (as strings) in `graph_def` to
`Tensor` objects. The values of the named input tensors in the imported
graph will be re-mapped to the respective `Tensor` values.
unbound_inputs_col_name: Collection name for looking up unbound inputs.
restore_collections_predicate: a predicate on collection names. A collection
named c (i.e whose key is c) will be restored iff
1) `restore_collections_predicate(c)` is True, and
2) `c != unbound_inputs_col_name`.
return_elements: A list of strings containing operation names in the
`MetaGraphDef` that will be returned as `Operation` objects; and/or
tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
Returns:
A tuple of (
dictionary of all the `Variables` imported into the name scope,
list of `Operation` or `Tensor` objects from the `return_elements` list).
Raises:
ValueError: If the graph_def contains unbound inputs.
"""
if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled.")
@ -798,12 +737,11 @@ def import_scoped_meta_graph_with_return_elements(
scope_to_prepend_to_names = graph.unique_name(
import_scope or "", mark_as_used=False)
imported_return_elements = importer.import_graph_def(
importer.import_graph_def(
input_graph_def,
name=(import_scope or scope_to_prepend_to_names),
input_map=input_map,
producer_op_list=producer_op_list,
return_elements=return_elements)
producer_op_list=producer_op_list)
# Restores all the other collections.
variable_objects = {}
@ -868,7 +806,7 @@ def import_scoped_meta_graph_with_return_elements(
for v in variables:
var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
return var_list, imported_return_elements
return var_list
def export_scoped_meta_graph(filename=None,

View File

@ -284,15 +284,12 @@ class SavedModelLoader(object):
**saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
Returns:
A tuple of
* Saver defined by the MetaGraph, which can be used to restore the
variable values.
* List of `Operation`/`Tensor` objects returned from
`tf.import_graph_def` (may be `None`).
Saver defined by the MetaGraph, which can be used to restore the variable
values.
"""
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
with graph.as_default():
return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access
return tf_saver.import_meta_graph(
meta_graph_def, import_scope=import_scope, **saver_kwargs)
def restore_variables(self, sess, saver, import_scope=None):
@ -364,8 +361,8 @@ class SavedModelLoader(object):
`MetagraphDef` proto of the graph that was loaded.
"""
with sess.graph.as_default():
saver, _ = self.load_graph(sess.graph, tags, import_scope,
**saver_kwargs)
saver = self.load_graph(sess.graph, tags, import_scope,
**saver_kwargs)
self.restore_variables(sess, saver, import_scope)
self.run_init_ops(sess, tags, import_scope)
return self.get_meta_graph_def_from_tags(tags)

View File

@ -111,8 +111,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_with_import_scope(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.test_session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph(
sess.graph, ["foo_graph"], import_scope="baz")
saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
# The default saver should not work when the import scope is set.
with self.assertRaises(errors.NotFoundError):
@ -150,7 +149,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_run_init_op(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
graph = ops.Graph()
saver, _ = loader.load_graph(graph, ["foo_graph"])
saver = loader.load_graph(graph, ["foo_graph"])
with self.test_session(graph=graph) as sess:
loader.restore_variables(sess, saver)
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
@ -204,7 +203,7 @@ class SavedModelLoaderTest(test.TestCase):
loader = loader_impl.SavedModelLoader(path)
with self.test_session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
saver = loader.load_graph(sess.graph, ["foo_graph"])
self.assertFalse(variables._all_saveable_objects())
self.assertIsNotNone(saver)
@ -213,18 +212,6 @@ class SavedModelLoaderTest(test.TestCase):
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
def test_load_saved_model_graph_with_return_elements(self):
"""Ensure that the correct elements are returned."""
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
graph = ops.Graph()
_, ret = loader.load_graph(graph, ["foo_graph"],
return_elements=["y:0", "x:0"])
self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0])
self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1])
with self.assertRaisesRegexp(ValueError, "not found in graph"):
loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
if __name__ == "__main__":
test.main()

View File

@ -1928,14 +1928,6 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
execution is enabled.
@end_compatibility
""" # pylint: disable=g-doc-exception
return _import_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
def _import_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices=False, import_scope=None,
return_elements=None, **kwargs):
"""Import MetaGraph, and return both a saver and returned elements."""
if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled. No graph exists when eager "
@ -1945,22 +1937,12 @@ def _import_meta_graph_with_return_elements(
else:
meta_graph_def = meta_graph_or_file
imported_vars, imported_return_elements = (
meta_graph.import_scoped_meta_graph_with_return_elements(
meta_graph_def,
clear_devices=clear_devices,
import_scope=import_scope,
return_elements=return_elements,
**kwargs))
imported_vars = meta_graph.import_scoped_meta_graph(
meta_graph_def,
clear_devices=clear_devices,
import_scope=import_scope,
**kwargs)
saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars)
return saver, imported_return_elements
def _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars):
"""Return a saver for restoring variable values to an imported MetaGraph."""
if meta_graph_def.HasField("saver_def"):
# Infer the scope that is prepended by `import_scoped_meta_graph`.
scope = import_scope