Simplified estimator logic by using MonitoredSession.

Removed graph_actions usage.
Change: 144126485
This commit is contained in:
Mustafa Ispir 2017-01-10 14:12:02 -08:00 committed by TensorFlower Gardener
parent 3e59f0540e
commit 61a6797c4f
2 changed files with 318 additions and 347 deletions
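The change swaps the contrib graph_actions training/inference helpers for the core tf.train.MonitoredSession machinery: stopping, NaN checks, and logging become hooks, and training is an explicit run loop. A minimal, self-contained sketch of that pattern follows (TF 1.x API; the toy single-variable model, learning rate, and step counts are illustrative only and are not part of this commit):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  global_step = tf.contrib.framework.get_or_create_global_step()
  weight = tf.Variable(5.0, name='weight')
  loss = tf.square(weight - 3.0)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=global_step)

  hooks = [
      tf.train.StopAtStepHook(num_steps=100),  # replaces steps/max_steps plumbing
      tf.train.NanTensorHook(loss),  # replaces fail_on_nan_loss
      tf.train.LoggingTensorHook({'loss': loss, 'step': global_step},
                                 every_n_iter=10),  # replaces log_every_steps
  ]
  # MonitoredTrainingSession wires up variable initialization and the hooks;
  # pass checkpoint_dir=... to also get checkpoint and summary saving.
  with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    while not sess.should_stop():
      _, last_loss = sess.run([train_op, loss])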

View File

@ -21,20 +21,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.contrib import framework
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops.control_flow_ops import with_dependencies
from tensorflow.python.platform import tf_logging as logging
def _streaming_sum(scalar_tensor):
@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
  return sum_metric, sum_update


class GMM(estimator.Estimator, TransformerMixin):
class GMM(estimator_lib.Estimator, TransformerMixin):
  """GMM clustering."""
  SCORES = 'scores'
  ASSIGNMENTS = 'assignments'
@ -116,7 +124,8 @@ class GMM(estimator.Estimator, TransformerMixin):
    self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
                                                            self._num_clusters,
                                                            self.batch_size)
    self._train_model(
    _legacy_train_model(  # pylint: disable=protected-access
        self,
        input_fn=self._data_feeder.input_builder,
        feed_fn=self._data_feeder.get_feed_dict_fn(),
        steps=steps or self.steps,
@ -218,3 +227,90 @@ class GMM(estimator.Estimator, TransformerMixin):
                          self._covariance_type,
                          self._params)
    return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
def _legacy_train_model(estimator,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True,
max_steps=None):
"""Legacy train function of Estimator."""
if hasattr(estimator.config, 'execution_mode'):
if estimator.config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(
estimator.config.training_worker_max_startup_secs,
estimator.config.task_id *
estimator.config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
estimator.config.task_id)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access
with ops.Graph().as_default() as g, g.device(device_fn):
random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
global_step = framework.create_global_step(g)
features, labels = input_fn()
estimator._check_inputs(features, labels) # pylint: disable=protected-access
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
# (train_op, loss) tuple. The following else-statement code covers these
# cases, but will soon be deleted after the subclasses are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
train_op = train_ops.train_op
loss_op = train_ops.loss
if estimator.config.is_chief:
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
else:
hooks = train_ops.training_hooks
else: # Legacy signature
if len(train_ops) != 2:
raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
train_ops))
train_op = train_ops[0]
loss_op = train_ops[1]
hooks = []
hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
return graph_actions._monitored_train( # pylint: disable=protected-access
graph=g,
output_dir=estimator.model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=estimator.config.is_chief,
supervisor_master=estimator.config.master,
supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
supervisor_save_summaries_steps=estimator.config.save_summary_steps,
keep_checkpoint_max=estimator.config.keep_checkpoint_max,
keep_checkpoint_every_n_hours=(
estimator.config.keep_checkpoint_every_n_hours),
feed_fn=feed_fn,
steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
hooks=hooks,
max_steps=max_steps)

View File

@ -22,10 +22,8 @@ from __future__ import print_function
import abc
import copy
import inspect
import itertools
import os
import tempfile
import time
import numpy as np
import six
@ -39,10 +37,8 @@ from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
@ -58,7 +54,6 @@ from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
@ -92,6 +87,25 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = (
    ' est = Estimator(...) -> est = SKCompat(Estimator(...))')
def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
"""Verifies validity of co-existance of input arguments."""
if input_fn is None:
if x is None:
raise ValueError('Either x or input_fn must be provided.')
if contrib_framework.is_tensor(x) or (y is not None and
contrib_framework.is_tensor(y)):
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
if feed_fn is not None:
raise ValueError('Can not provide both feed_fn and x or y.')
else:
if (x is not None) or (y is not None):
raise ValueError('Can not provide both input_fn and x or y.')
if batch_size is not None:
raise ValueError('Can not provide both input_fn and batch_size.')
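As a quick illustration of the new up-front validation (a hedged usage sketch; the call sites below are hypothetical and assume the _verify_input_args helper above is in scope):

import numpy as np

x = np.zeros((10, 3), dtype=np.float32)

# numpy inputs without an input_fn pass the check
_verify_input_args(x=x, y=None, input_fn=None, feed_fn=None, batch_size=32)

# an input_fn alone also passes
_verify_input_args(x=None, y=None, input_fn=lambda: None, feed_fn=None,
                   batch_size=None)

# mixing the two styles is rejected before any graph is built
try:
  _verify_input_args(x=x, y=None, input_fn=lambda: None, feed_fn=None,
                     batch_size=None)
except ValueError as e:
  print(e)  # Can not provide both input_fn and x or y.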
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
  """Make inputs into input and feed functions.
@ -110,29 +124,17 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
  Raises:
    ValueError: Only one of `(x & y)` or `input_fn` must be provided.
  """
  if input_fn is None:
    if x is None:
      raise ValueError('Either x or input_fn must be provided.')

    if contrib_framework.is_tensor(x) or (y is not None and
                                          contrib_framework.is_tensor(y)):
      raise ValueError('Inputs cannot be tensors. Please provide input_fn.')

    if feed_fn is not None:
      raise ValueError('Can not provide both feed_fn and x or y.')

    df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             epochs=epochs)
    return df.input_builder, df.get_feed_dict_fn()

  if (x is not None) or (y is not None):
    raise ValueError('Can not provide both input_fn and x or y.')
  if batch_size is not None:
    raise ValueError('Can not provide both input_fn and batch_size.')

  return input_fn, feed_fn
  _verify_input_args(x, y, input_fn, feed_fn, batch_size)
  if input_fn is not None:
    return input_fn, feed_fn
  df = data_feeder.setup_train_data_feeder(
      x,
      y,
      n_classes=None,
      batch_size=batch_size,
      shuffle=shuffle,
      epochs=epochs)
  return df.input_builder, df.get_feed_dict_fn()
def infer_real_valued_columns_from_input_fn(input_fn): def infer_real_valued_columns_from_input_fn(input_fn):
@ -311,9 +313,8 @@ def _write_dict_to_summary(output_dir,
    dictionary: the `dict` to be written to summary file.
    current_global_step: `int`, the current global step.
  """
  logging.info(
      'Saving dict for global step %d: %s' %
      (current_global_step, _dict_to_str(dictionary)))
  logging.info('Saving dict for global step %d: %s', current_global_step,
               _dict_to_str(dictionary))
  summary_writer = summary_io.SummaryWriterCache.get(output_dir)
  summary_proto = summary_pb2.Summary()
  for key in dictionary:
@ -404,15 +405,24 @@ class BaseEstimator(
""" """
if (steps is not None) and (max_steps is not None): if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.') raise ValueError('Can not provide both steps and max_steps.')
_verify_input_args(x, y, input_fn, None, batch_size)
if x is not None:
return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
input_fn, feed_fn = _get_input_fn(x, y, input_fn, feed_fn=None, if max_steps is not None:
batch_size=batch_size, shuffle=True, try:
epochs=None) start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
loss = self._train_model(input_fn=input_fn, if max_steps <= start_step:
feed_fn=feed_fn, logging.info('Skipping training since max_steps has already saved.')
steps=steps, return None
monitors=monitors, except: # pylint: disable=bare-except
max_steps=max_steps) pass
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
if steps is not None or max_steps is not None:
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
loss = self._train_model(input_fn=input_fn, hooks=hooks)
logging.info('Loss for final step: %s.', loss) logging.info('Loss for final step: %s.', loss)
return self return self
@ -485,9 +495,10 @@ class BaseEstimator(
          `input_fn` or `feed_fn` is provided.
          Or if `metrics` is not `None` or `dict`.
    """
    input_fn, feed_fn = _get_input_fn(x, y, input_fn=input_fn,
                                      feed_fn=feed_fn, batch_size=batch_size,
                                      shuffle=False, epochs=1)
    _verify_input_args(x, y, input_fn, feed_fn, batch_size)
    if x is not None:
      return SKCompat(self).score(x, y, batch_size, steps, metrics)

    if metrics is not None and not isinstance(metrics, dict):
      raise ValueError('Metrics argument should be None or dict. '
                       'Got %s.' % metrics)
@ -537,11 +548,15 @@ class BaseEstimator(
    Raises:
      ValueError: If x and input_fn are both provided or both `None`.
    """
    input_fn, feed_fn = _get_input_fn(
        x, None, input_fn=input_fn, feed_fn=None, batch_size=batch_size,
        shuffle=False, epochs=1)
    return self._infer_model(
        input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
        as_iterable=as_iterable)
    _verify_input_args(x, None, input_fn, None, batch_size)
    if x is not None and not as_iterable:
      return SKCompat(self).predict(x, batch_size)

    input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size)
    return self._infer_model(
        input_fn=input_fn,
        feed_fn=feed_fn,
        outputs=outputs,
        as_iterable=as_iterable)
  def get_variable_value(self, name):
@ -728,91 +743,6 @@ class BaseEstimator(
    self._labels_info = tensor_signature.create_signatures(labels)
    logging.debug('Setting labels info to %s', str(self._labels_info))
def _train_model(self,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True,
max_steps=None):
# TODO(wicke): Remove this once Model and associated code are gone.
if hasattr(self._config, 'execution_mode'):
if self._config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(
self._config.training_worker_max_startup_secs,
self._config.task_id *
self._config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
self._config.task_id)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or self._device_fn
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
# (train_op, loss) tuple. The following else-statement code covers these
# cases, but will soon be deleted after the subclasses are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
train_op = train_ops.train_op
loss_op = train_ops.loss
if self.config.is_chief:
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
else:
hooks = train_ops.training_hooks
else: # Legacy signature
if len(train_ops) != 2:
raise ValueError('Expected a tuple of train_op and loss, got {}'.
format(train_ops))
train_op = train_ops[0]
loss_op = train_ops[1]
hooks = []
hooks += monitor_lib.replace_monitors_with_hooks(monitors, self)
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
return graph_actions._monitored_train( # pylint: disable=protected-access
graph=g,
output_dir=self._model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=self.config.is_chief,
supervisor_master=self._config.master,
supervisor_save_model_secs=self._config.save_checkpoints_secs,
supervisor_save_model_steps=self._config.save_checkpoints_steps,
supervisor_save_summaries_steps=self._config.save_summary_steps,
keep_checkpoint_max=self._config.keep_checkpoint_max,
keep_checkpoint_every_n_hours=self._config.keep_checkpoint_every_n_hours,
feed_fn=feed_fn,
steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
hooks=hooks,
max_steps=max_steps)
  def _extract_metric_update_ops(self, eval_dict):
    """Separate update operations from metric value operations."""
    update_ops = []
@ -915,8 +845,12 @@ class BaseEstimator(
      return result[0]
    return result

  def _infer_model(
      self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
  def _infer_model(self,
                   input_fn,
                   feed_fn=None,
                   outputs=None,
                   as_iterable=True,
                   iterate_batches=False):
    # Check that model has been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
@ -927,103 +861,152 @@ class BaseEstimator(
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)

      # The default return type of _get_predict_ops is ModelFnOps. But there are
      # some subclasses of tf.contrib.learn.Estimator which override this
      # method and use the legacy signature, namely _get_predict_ops returns a
      # `predictions` Tensor or dict or Tensors. The following else-statement
      # code covers these cases, but will soon be deleted after the subclasses
      # are updated.
      # TODO(b/32664904): Update subclasses and delete the else-statement.
      infer_ops = self._get_predict_ops(features)
      if isinstance(infer_ops, model_fn_lib.ModelFnOps):  # Default signature
        predictions = infer_ops.predictions
      else:  # Legacy signature
        predictions = infer_ops

      # If predictions is single output - wrap it into dict, and remember to
      # return not a dict.
      return_dict = isinstance(predictions, dict)
      if not return_dict:
        predictions = {'predictions': predictions}

      # Filter what to run predictions on, if outputs provided.
      if outputs:
        existing_keys = predictions.keys()
        predictions = {
            key: value
            for key, value in six.iteritems(predictions) if key in outputs
        }
        if not predictions:
          raise ValueError('Expected to run at least one output from %s, '
                           'provided %s.' % (existing_keys, outputs))

      if as_iterable:
        return self._infer_model_as_iterable(
            checkpoint_path, predictions, feed_fn, return_dict)
      else:
        return self._infer_model_single(
            checkpoint_path, predictions, feed_fn, return_dict)

  def _infer_model_single(
      self, checkpoint_path, predictions, feed_fn, return_dict):
    if feed_fn is None:
      preds = graph_actions.infer(checkpoint_path, predictions)
    else:
      def _feed_fn():
        while True:
          yield feed_fn()

      outputs = graph_actions.run_feeds(
          output_dict=predictions,
          feed_dicts=_feed_fn(),
          restore_checkpoint_path=checkpoint_path)
      preds = {
          key: np.concatenate([output[key] for output in outputs], axis=0)
          for key in predictions}

    return preds if return_dict else preds['predictions']

  def _infer_model_as_iterable(
      self, checkpoint_path, predictions, feed_fn, return_dict):
    if feed_fn is None:
      # If there are no queue_runners, the input `predictions` is a
      # constant, and we should stop after the first epoch. If,
      # instead, there are queue_runners, eventually they should throw
      # an `OutOfRangeError`.
      graph = contrib_ops.get_graph_from_inputs(predictions.values())
      if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        feed_dicts = itertools.repeat(None)
      else:
        feed_dicts = [None]
    else:
      def _feed_fn():
        while True:
          yield feed_fn()
      feed_dicts = _feed_fn()

    try:
      for output_batch in graph_actions.run_feeds_iter(
          output_dict=predictions,
          feed_dicts=feed_dicts,
          restore_checkpoint_path=checkpoint_path):
        # Unpack batches into individual predictions
        if return_dict:
          first_tensor = list(output_batch.values())[0]
          if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
            batch_length = first_tensor.dense_shape[0]
          else:
            batch_length = first_tensor.shape[0]
          for i in range(batch_length):
            yield {key: value[i] for key, value in six.iteritems(output_batch)}
        else:
          for pred in output_batch['predictions']:
            yield pred
    except errors.OutOfRangeError:
      # We fall out of the above loop naturally if feed_fn raises StopIteration,
      # or we catch an OutOfRangeError if we've reached the end of inputs.
      logging.info('Reached end of inputs for predict_iter.')

      infer_ops = self._call_legacy_get_predict_ops(features)
      predictions = self._filter_predictions(infer_ops.predictions, outputs)
      mon_sess = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path))
      if not as_iterable:
        with mon_sess:
          if not mon_sess.should_stop():
            return mon_sess.run(predictions, feed_fn() if feed_fn else None)
      else:
        return self._predict_generator(mon_sess, predictions, feed_fn,
                                       iterate_batches)

  def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches):
    with mon_sess:
      while not mon_sess.should_stop():
        preds = mon_sess.run(predictions, feed_fn() if feed_fn else None)
        if iterate_batches:
          yield preds
        elif not isinstance(predictions, dict):
          for pred in preds:
            yield pred
        else:
          first_tensor = list(preds.values())[0]
          if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
            batch_length = first_tensor.dense_shape[0]
          else:
            batch_length = first_tensor.shape[0]
          for i in range(batch_length):
            yield {key: value[i] for key, value in six.iteritems(preds)}
        if self._is_input_constant(feed_fn, mon_sess.graph):
          return

  def _is_input_constant(self, feed_fn, graph):
    # If there are no queue_runners, the input `predictions` is a
    # constant, and we should stop after the first epoch. If,
    # instead, there are queue_runners, eventually they should throw
    # an `OutOfRangeError`.
    if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      return False
    # data_feeder uses feed_fn to generate `OutOfRangeError`.
    if feed_fn is not None:
      return False
    return True
def _filter_predictions(self, predictions, outputs):
if not outputs:
return predictions
if not isinstance(predictions, dict):
raise ValueError(
'outputs argument is not valid in case of non-dict predictions.')
existing_keys = predictions.keys()
predictions = {
key: value
for key, value in six.iteritems(predictions) if key in outputs
}
if not predictions:
raise ValueError('Expected to run at least one output from %s, '
'provided %s.' % (existing_keys, outputs))
return predictions
def _train_model(self, input_fn, hooks):
all_hooks = []
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
{
'loss': model_fn_ops.loss,
'step': global_step
},
every_n_iter=100)
])
all_hooks.extend(hooks)
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
ops.add_to_collection(
ops.GraphKeys.SAVERS,
saver.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
defer_build=True))
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
saver_hook_exists = any([
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
model_fn_ops.training_chief_hooks)
])
if not saver_hook_exists:
chief_hooks = [
basic_session_run_hooks.CheckpointSaverHook(
self._model_dir,
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=scaffold)
]
with monitored_session.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
scaffold=scaffold,
hooks=all_hooks + model_fn_ops.training_hooks,
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
config=None) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
summary_io.SummaryWriterCache.clear()
return loss
def _call_legacy_get_predict_ops(self, features):
# The default return type of _get_predict_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_predict_ops returns a
# `predictions` Tensor or dict or Tensors. The following else-statement
# code covers these cases, but will soon be deleted after the subclasses
# are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
infer_ops = self._get_predict_ops(features)
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
return infer_ops
return model_fn_lib.ModelFnOps(
mode=model_fn_lib.ModeKeys.INFER, predictions=infer_ops)
def _call_legacy_get_train_ops(self, features, labels):
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
return train_ops
return model_fn_lib.ModelFnOps(
mode=model_fn_lib.ModeKeys.TRAIN,
predictions=None,
loss=train_ops[1],
train_op=train_ops[0])
def _identity_feature_engineering_fn(features, labels):
@ -1177,17 +1160,6 @@ class Estimator(BaseEstimator):
    """
    return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
# TODO(ispir): delete this function after converting all legacy usages.
def _call_legacy_get_train_ops(self, features, labels):
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
return train_ops
return model_fn_lib.ModelFnOps(
mode=model_fn_lib.ModeKeys.TRAIN,
predictions=None,
loss=train_ops[1],
train_op=train_ops[0])
  def _get_eval_ops(self, features, labels, metrics):
    """Method that builds model graph and returns evaluation ops.
@ -1343,114 +1315,6 @@ class Estimator(BaseEstimator):
    return export_dir
@deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y',
'batch_size')
def fit(self,
x=None,
y=None,
input_fn=None,
steps=None,
batch_size=None,
monitors=None,
max_steps=None):
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""See `Trainable`.
Raises:
ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
ValueError: If both `steps` and `max_steps` are not `None`.
"""
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
if max_steps is not None:
try:
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
if max_steps <= start_step:
logging.info('Skipping training since max_steps has already saved.')
return None
except: # pylint: disable=bare-except
pass
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
if steps is not None or max_steps is not None:
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
input_fn, feed_fn = _get_input_fn(
x,
y,
input_fn,
feed_fn=None,
batch_size=batch_size,
shuffle=True,
epochs=None)
if feed_fn:
hooks.append(_FeedFnHook(feed_fn))
loss = self._train_model_v2(input_fn=input_fn, hooks=hooks)
logging.info('Loss for final step: %s.', loss)
return self
def _train_model_v2(self, input_fn, hooks):
all_hooks = []
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
{
'loss': model_fn_ops.loss,
'step': global_step
},
every_n_iter=100)
])
all_hooks.extend(hooks)
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
ops.add_to_collection(
ops.GraphKeys.SAVERS,
saver.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
defer_build=True))
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
saver_hook_exists = any([
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
model_fn_ops.training_chief_hooks)
])
if not saver_hook_exists:
chief_hooks = [
basic_session_run_hooks.CheckpointSaverHook(
self._model_dir,
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=scaffold)
]
with monitored_session.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
scaffold=scaffold,
hooks=all_hooks + model_fn_ops.training_hooks,
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
config=None) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
summary_io.SummaryWriterCache.clear()
return loss
class _FeedFnHook(session_run_hook.SessionRunHook):
  """Runs feed_fn and sets the feed_dict accordingly."""
@ -1509,6 +1373,17 @@ class SKCompat(sklearn.BaseEstimator):
    input_fn, feed_fn = _get_input_fn(
        x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
        shuffle=False, epochs=1)
    return self._estimator._infer_model(
        input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
        as_iterable=False)
    results = list(
        self._estimator._infer_model(
            input_fn=input_fn,
            feed_fn=feed_fn,
            outputs=outputs,
            as_iterable=True,
            iterate_batches=True))
    if not isinstance(results[0], dict):
      return np.concatenate([output for output in results], axis=0)
    return {
        key: np.concatenate(
            [output[key] for output in results], axis=0)
        for key in results[0]
    }
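For reference, the stitching that the new SKCompat.predict performs on the per-batch results can be seen in isolation with plain NumPy (a hedged, standalone illustration; the sample batches below are invented and not from the commit):

import numpy as np

# Two batches of dict-valued predictions, as _infer_model(..., iterate_batches=True)
# would yield them.
results = [
    {'classes': np.array([0, 1]),
     'probabilities': np.array([[0.9, 0.1], [0.2, 0.8]])},
    {'classes': np.array([1]),
     'probabilities': np.array([[0.3, 0.7]])},
]

if not isinstance(results[0], dict):
  merged = np.concatenate([output for output in results], axis=0)
else:
  merged = {
      key: np.concatenate([output[key] for output in results], axis=0)
      for key in results[0]
  }

print(merged['classes'])              # [0 1 1]
print(merged['probabilities'].shape)  # (3, 2)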