Simplified estimator logic by using MonitoredSession.

Removed graph_actions usage.
Change: 144126485
Mustafa Ispir 2017-01-10 14:12:02 -08:00 committed by TensorFlower Gardener
parent 3e59f0540e
commit 61a6797c4f
2 changed files with 318 additions and 347 deletions
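For context, a minimal sketch of the MonitoredTrainingSession-based training loop this commit moves to, replacing graph_actions._monitored_train. It is illustrative only and not part of the diff: the toy linear model, the 'w' variable, and the step limit are assumptions; the hooks mirror the NanTensorHook / LoggingTensorHook / StopAtStepHook wiring added in estimator.py below.

import tensorflow as tf

def sketch_train(model_dir, max_steps=1000):
  # Illustrative graph: a one-weight linear model fit to y = 3x.
  g = tf.Graph()
  with g.as_default():
    global_step = tf.contrib.framework.create_global_step(g)
    x = tf.random_normal([8, 1])
    y = 3.0 * x
    w = tf.get_variable('w', [1, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)
    hooks = [
        tf.train.NanTensorHook(loss),  # stands in for fail_on_nan_loss
        tf.train.LoggingTensorHook({'loss': loss, 'step': global_step},
                                   every_n_iter=100),
        tf.train.StopAtStepHook(last_step=max_steps),  # replaces steps/max_steps args
    ]
    # Init, checkpointing, and summaries are handled by the session itself.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=model_dir,
                                           hooks=hooks) as mon_sess:
      loss_value = None
      while not mon_sess.should_stop():
        _, loss_value = mon_sess.run([train_op, loss])
  return loss_value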

tensorflow/contrib/factorization/python/ops/gmm.py

@@ -21,20 +21,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.contrib import framework
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops.control_flow_ops import with_dependencies
from tensorflow.python.platform import tf_logging as logging
def _streaming_sum(scalar_tensor):
@@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
return sum_metric, sum_update
class GMM(estimator.Estimator, TransformerMixin):
class GMM(estimator_lib.Estimator, TransformerMixin):
"""GMM clustering."""
SCORES = 'scores'
ASSIGNMENTS = 'assignments'
@@ -116,7 +124,8 @@ class GMM(estimator.Estimator, TransformerMixin):
self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
self._num_clusters,
self.batch_size)
self._train_model(
_legacy_train_model( # pylint: disable=protected-access
self,
input_fn=self._data_feeder.input_builder,
feed_fn=self._data_feeder.get_feed_dict_fn(),
steps=steps or self.steps,
@@ -218,3 +227,90 @@ class GMM(estimator.Estimator, TransformerMixin):
self._covariance_type,
self._params)
return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
def _legacy_train_model(estimator,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True,
max_steps=None):
"""Legacy train function of Estimator."""
if hasattr(estimator.config, 'execution_mode'):
if estimator.config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(
estimator.config.training_worker_max_startup_secs,
estimator.config.task_id *
estimator.config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
estimator.config.task_id)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access
with ops.Graph().as_default() as g, g.device(device_fn):
random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
global_step = framework.create_global_step(g)
features, labels = input_fn()
estimator._check_inputs(features, labels) # pylint: disable=protected-access
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
# (train_op, loss) tuple. The following else-statement code covers these
# cases, but will soon be deleted after the subclasses are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
train_op = train_ops.train_op
loss_op = train_ops.loss
if estimator.config.is_chief:
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
else:
hooks = train_ops.training_hooks
else: # Legacy signature
if len(train_ops) != 2:
raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
train_ops))
train_op = train_ops[0]
loss_op = train_ops[1]
hooks = []
hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
return graph_actions._monitored_train( # pylint: disable=protected-access
graph=g,
output_dir=estimator.model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=estimator.config.is_chief,
supervisor_master=estimator.config.master,
supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
supervisor_save_summaries_steps=estimator.config.save_summary_steps,
keep_checkpoint_max=estimator.config.keep_checkpoint_max,
keep_checkpoint_every_n_hours=(
estimator.config.keep_checkpoint_every_n_hours),
feed_fn=feed_fn,
steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
hooks=hooks,
max_steps=max_steps)

tensorflow/contrib/learn/python/learn/estimators/estimator.py

@@ -22,10 +22,8 @@ from __future__ import print_function
import abc
import copy
import inspect
import itertools
import os
import tempfile
import time
import numpy as np
import six
@@ -39,10 +37,8 @@ from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
@@ -58,7 +54,6 @@ from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
@@ -92,6 +87,25 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = (
' est = Estimator(...) -> est = SKCompat(Estimator(...))')
def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
"""Verifies validity of co-existance of input arguments."""
if input_fn is None:
if x is None:
raise ValueError('Either x or input_fn must be provided.')
if contrib_framework.is_tensor(x) or (y is not None and
contrib_framework.is_tensor(y)):
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
if feed_fn is not None:
raise ValueError('Can not provide both feed_fn and x or y.')
else:
if (x is not None) or (y is not None):
raise ValueError('Can not provide both input_fn and x or y.')
if batch_size is not None:
raise ValueError('Can not provide both input_fn and batch_size.')
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
"""Make inputs into input and feed functions.
@@ -110,29 +124,17 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
Raises:
ValueError: Only one of `(x & y)` or `input_fn` must be provided.
"""
if input_fn is None:
if x is None:
raise ValueError('Either x or input_fn must be provided.')
if contrib_framework.is_tensor(x) or (y is not None and
contrib_framework.is_tensor(y)):
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
if feed_fn is not None:
raise ValueError('Can not provide both feed_fn and x or y.')
df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
batch_size=batch_size,
shuffle=shuffle,
epochs=epochs)
return df.input_builder, df.get_feed_dict_fn()
if (x is not None) or (y is not None):
raise ValueError('Can not provide both input_fn and x or y.')
if batch_size is not None:
raise ValueError('Can not provide both input_fn and batch_size.')
return input_fn, feed_fn
_verify_input_args(x, y, input_fn, feed_fn, batch_size)
if input_fn is not None:
return input_fn, feed_fn
df = data_feeder.setup_train_data_feeder(
x,
y,
n_classes=None,
batch_size=batch_size,
shuffle=shuffle,
epochs=epochs)
return df.input_builder, df.get_feed_dict_fn()
def infer_real_valued_columns_from_input_fn(input_fn):
@@ -311,9 +313,8 @@ def _write_dict_to_summary(output_dir,
dictionary: the `dict` to be written to summary file.
current_global_step: `int`, the current global step.
"""
logging.info(
'Saving dict for global step %d: %s' %
(current_global_step, _dict_to_str(dictionary)))
logging.info('Saving dict for global step %d: %s', current_global_step,
_dict_to_str(dictionary))
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
summary_proto = summary_pb2.Summary()
for key in dictionary:
@@ -404,15 +405,24 @@ class BaseEstimator(
"""
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
_verify_input_args(x, y, input_fn, None, batch_size)
if x is not None:
return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
input_fn, feed_fn = _get_input_fn(x, y, input_fn, feed_fn=None,
batch_size=batch_size, shuffle=True,
epochs=None)
loss = self._train_model(input_fn=input_fn,
feed_fn=feed_fn,
steps=steps,
monitors=monitors,
max_steps=max_steps)
if max_steps is not None:
try:
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
if max_steps <= start_step:
logging.info('Skipping training since max_steps has already saved.')
return None
except: # pylint: disable=bare-except
pass
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
if steps is not None or max_steps is not None:
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
loss = self._train_model(input_fn=input_fn, hooks=hooks)
logging.info('Loss for final step: %s.', loss)
return self
@@ -485,9 +495,10 @@ class BaseEstimator(
`input_fn` or `feed_fn` is provided.
Or if `metrics` is not `None` or `dict`.
"""
input_fn, feed_fn = _get_input_fn(x, y, input_fn=input_fn,
feed_fn=feed_fn, batch_size=batch_size,
shuffle=False, epochs=1)
_verify_input_args(x, y, input_fn, feed_fn, batch_size)
if x is not None:
return SKCompat(self).score(x, y, batch_size, steps, metrics)
if metrics is not None and not isinstance(metrics, dict):
raise ValueError('Metrics argument should be None or dict. '
'Got %s.' % metrics)
@@ -537,11 +548,15 @@ class BaseEstimator(
Raises:
ValueError: If x and input_fn are both provided or both `None`.
"""
input_fn, feed_fn = _get_input_fn(
x, None, input_fn=input_fn, feed_fn=None, batch_size=batch_size,
shuffle=False, epochs=1)
_verify_input_args(x, None, input_fn, None, batch_size)
if x is not None and not as_iterable:
return SKCompat(self).predict(x, batch_size)
input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size)
return self._infer_model(
input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
input_fn=input_fn,
feed_fn=feed_fn,
outputs=outputs,
as_iterable=as_iterable)
def get_variable_value(self, name):
@@ -728,91 +743,6 @@ class BaseEstimator(
self._labels_info = tensor_signature.create_signatures(labels)
logging.debug('Setting labels info to %s', str(self._labels_info))
def _train_model(self,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True,
max_steps=None):
# TODO(wicke): Remove this once Model and associated code are gone.
if hasattr(self._config, 'execution_mode'):
if self._config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(
self._config.training_worker_max_startup_secs,
self._config.task_id *
self._config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
self._config.task_id)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or self._device_fn
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
# (train_op, loss) tuple. The following else-statement code covers these
# cases, but will soon be deleted after the subclasses are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
train_op = train_ops.train_op
loss_op = train_ops.loss
if self.config.is_chief:
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
else:
hooks = train_ops.training_hooks
else: # Legacy signature
if len(train_ops) != 2:
raise ValueError('Expected a tuple of train_op and loss, got {}'.
format(train_ops))
train_op = train_ops[0]
loss_op = train_ops[1]
hooks = []
hooks += monitor_lib.replace_monitors_with_hooks(monitors, self)
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
return graph_actions._monitored_train( # pylint: disable=protected-access
graph=g,
output_dir=self._model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=self.config.is_chief,
supervisor_master=self._config.master,
supervisor_save_model_secs=self._config.save_checkpoints_secs,
supervisor_save_model_steps=self._config.save_checkpoints_steps,
supervisor_save_summaries_steps=self._config.save_summary_steps,
keep_checkpoint_max=self._config.keep_checkpoint_max,
keep_checkpoint_every_n_hours=self._config.keep_checkpoint_every_n_hours,
feed_fn=feed_fn,
steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
hooks=hooks,
max_steps=max_steps)
def _extract_metric_update_ops(self, eval_dict):
"""Separate update operations from metric value operations."""
update_ops = []
@@ -915,8 +845,12 @@ class BaseEstimator(
return result[0]
return result
def _infer_model(
self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
def _infer_model(self,
input_fn,
feed_fn=None,
outputs=None,
as_iterable=True,
iterate_batches=False):
# Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
@@ -927,103 +861,152 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
contrib_framework.create_global_step(g)
features = self._get_features_from_input_fn(input_fn)
# The default return type of _get_predict_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_predict_ops returns a
# `predictions` Tensor or dict or Tensors. The following else-statement
# code covers these cases, but will soon be deleted after the subclasses
# are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
infer_ops = self._get_predict_ops(features)
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
predictions = infer_ops.predictions
else: # Legacy signature
predictions = infer_ops
# If predictions is single output - wrap it into dict, and remember to
# return not a dict.
return_dict = isinstance(predictions, dict)
if not return_dict:
predictions = {'predictions': predictions}
# Filter what to run predictions on, if outputs provided.
if outputs:
existing_keys = predictions.keys()
predictions = {
key: value
for key, value in six.iteritems(predictions) if key in outputs
}
if not predictions:
raise ValueError('Expected to run at least one output from %s, '
'provided %s.' % (existing_keys, outputs))
if as_iterable:
return self._infer_model_as_iterable(
checkpoint_path, predictions, feed_fn, return_dict)
infer_ops = self._call_legacy_get_predict_ops(features)
predictions = self._filter_predictions(infer_ops.predictions, outputs)
mon_sess = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path))
if not as_iterable:
with mon_sess:
if not mon_sess.should_stop():
return mon_sess.run(predictions, feed_fn() if feed_fn else None)
else:
return self._infer_model_single(
checkpoint_path, predictions, feed_fn, return_dict)
return self._predict_generator(mon_sess, predictions, feed_fn,
iterate_batches)
def _infer_model_single(
self, checkpoint_path, predictions, feed_fn, return_dict):
if feed_fn is None:
preds = graph_actions.infer(checkpoint_path, predictions)
else:
def _feed_fn():
while True:
yield feed_fn()
outputs = graph_actions.run_feeds(
output_dict=predictions,
feed_dicts=_feed_fn(),
restore_checkpoint_path=checkpoint_path)
preds = {
key: np.concatenate([output[key] for output in outputs], axis=0)
for key in predictions}
return preds if return_dict else preds['predictions']
def _infer_model_as_iterable(
self, checkpoint_path, predictions, feed_fn, return_dict):
if feed_fn is None:
# If there are no queue_runners, the input `predictions` is a
# constant, and we should stop after the first epoch. If,
# instead, there are queue_runners, eventually they should throw
# an `OutOfRangeError`.
graph = contrib_ops.get_graph_from_inputs(predictions.values())
if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
feed_dicts = itertools.repeat(None)
else:
feed_dicts = [None]
else:
def _feed_fn():
while True:
yield feed_fn()
feed_dicts = _feed_fn()
try:
for output_batch in graph_actions.run_feeds_iter(
output_dict=predictions,
feed_dicts=feed_dicts,
restore_checkpoint_path=checkpoint_path):
# Unpack batches into individual predictions
if return_dict:
first_tensor = list(output_batch.values())[0]
def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches):
with mon_sess:
while not mon_sess.should_stop():
preds = mon_sess.run(predictions, feed_fn() if feed_fn else None)
if iterate_batches:
yield preds
elif not isinstance(predictions, dict):
for pred in preds:
yield pred
else:
first_tensor = list(preds.values())[0]
if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
batch_length = first_tensor.dense_shape[0]
else:
batch_length = first_tensor.shape[0]
for i in range(batch_length):
yield {key: value[i] for key, value in six.iteritems(output_batch)}
else:
for pred in output_batch['predictions']:
yield pred
yield {key: value[i] for key, value in six.iteritems(preds)}
if self._is_input_constant(feed_fn, mon_sess.graph):
return
except errors.OutOfRangeError:
# We fall out of the above loop naturally if feed_fn raises StopIteration,
# or we catch an OutOfRangeError if we've reached the end of inputs.
logging.info('Reached end of inputs for predict_iter.')
def _is_input_constant(self, feed_fn, graph):
# If there are no queue_runners, the input `predictions` is a
# constant, and we should stop after the first epoch. If,
# instead, there are queue_runners, eventually they should throw
# an `OutOfRangeError`.
if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
return False
# data_feeder uses feed_fn to generate `OutOfRangeError`.
if feed_fn is not None:
return False
return True
def _filter_predictions(self, predictions, outputs):
if not outputs:
return predictions
if not isinstance(predictions, dict):
raise ValueError(
'outputs argument is not valid in case of non-dict predictions.')
existing_keys = predictions.keys()
predictions = {
key: value
for key, value in six.iteritems(predictions) if key in outputs
}
if not predictions:
raise ValueError('Expected to run at least one output from %s, '
'provided %s.' % (existing_keys, outputs))
return predictions
def _train_model(self, input_fn, hooks):
all_hooks = []
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
{
'loss': model_fn_ops.loss,
'step': global_step
},
every_n_iter=100)
])
all_hooks.extend(hooks)
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
ops.add_to_collection(
ops.GraphKeys.SAVERS,
saver.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
defer_build=True))
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
saver_hook_exists = any([
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
model_fn_ops.training_chief_hooks)
])
if not saver_hook_exists:
chief_hooks = [
basic_session_run_hooks.CheckpointSaverHook(
self._model_dir,
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=scaffold)
]
with monitored_session.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
scaffold=scaffold,
hooks=all_hooks + model_fn_ops.training_hooks,
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
config=None) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
summary_io.SummaryWriterCache.clear()
return loss
def _call_legacy_get_predict_ops(self, features):
# The default return type of _get_predict_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_predict_ops returns a
# `predictions` Tensor or dict or Tensors. The following else-statement
# code covers these cases, but will soon be deleted after the subclasses
# are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
infer_ops = self._get_predict_ops(features)
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
return infer_ops
return model_fn_lib.ModelFnOps(
mode=model_fn_lib.ModeKeys.INFER, predictions=infer_ops)
def _call_legacy_get_train_ops(self, features, labels):
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
return train_ops
return model_fn_lib.ModelFnOps(
mode=model_fn_lib.ModeKeys.TRAIN,
predictions=None,
loss=train_ops[1],
train_op=train_ops[0])
def _identity_feature_engineering_fn(features, labels):
@@ -1177,17 +1160,6 @@ class Estimator(BaseEstimator):
"""
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
# TODO(ispir): delete this function after converting all legacy usages.
def _call_legacy_get_train_ops(self, features, labels):
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
return train_ops
return model_fn_lib.ModelFnOps(
mode=model_fn_lib.ModeKeys.TRAIN,
predictions=None,
loss=train_ops[1],
train_op=train_ops[0])
def _get_eval_ops(self, features, labels, metrics):
"""Method that builds model graph and returns evaluation ops.
@@ -1343,114 +1315,6 @@ class Estimator(BaseEstimator):
return export_dir
@deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y',
'batch_size')
def fit(self,
x=None,
y=None,
input_fn=None,
steps=None,
batch_size=None,
monitors=None,
max_steps=None):
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""See `Trainable`.
Raises:
ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
ValueError: If both `steps` and `max_steps` are not `None`.
"""
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
if max_steps is not None:
try:
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
if max_steps <= start_step:
logging.info('Skipping training since max_steps has already saved.')
return None
except: # pylint: disable=bare-except
pass
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
if steps is not None or max_steps is not None:
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
input_fn, feed_fn = _get_input_fn(
x,
y,
input_fn,
feed_fn=None,
batch_size=batch_size,
shuffle=True,
epochs=None)
if feed_fn:
hooks.append(_FeedFnHook(feed_fn))
loss = self._train_model_v2(input_fn=input_fn, hooks=hooks)
logging.info('Loss for final step: %s.', loss)
return self
def _train_model_v2(self, input_fn, hooks):
all_hooks = []
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
{
'loss': model_fn_ops.loss,
'step': global_step
},
every_n_iter=100)
])
all_hooks.extend(hooks)
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
ops.add_to_collection(
ops.GraphKeys.SAVERS,
saver.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
defer_build=True))
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
saver_hook_exists = any([
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
model_fn_ops.training_chief_hooks)
])
if not saver_hook_exists:
chief_hooks = [
basic_session_run_hooks.CheckpointSaverHook(
self._model_dir,
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=scaffold)
]
with monitored_session.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
scaffold=scaffold,
hooks=all_hooks + model_fn_ops.training_hooks,
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
config=None) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
summary_io.SummaryWriterCache.clear()
return loss
class _FeedFnHook(session_run_hook.SessionRunHook):
"""Runs feed_fn and sets the feed_dict accordingly."""
@@ -1509,6 +1373,17 @@ class SKCompat(sklearn.BaseEstimator):
input_fn, feed_fn = _get_input_fn(
x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
shuffle=False, epochs=1)
return self._estimator._infer_model(
input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
as_iterable=False)
results = list(
self._estimator._infer_model(
input_fn=input_fn,
feed_fn=feed_fn,
outputs=outputs,
as_iterable=True,
iterate_batches=True))
if not isinstance(results[0], dict):
return np.concatenate([output for output in results], axis=0)
return {
key: np.concatenate(
[output[key] for output in results], axis=0)
for key in results[0]
}
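For comparison, a minimal sketch of the checkpoint-restoring prediction path that _infer_model now follows: a MonitoredSession built from a ChiefSessionCreator replaces graph_actions.infer and graph_actions.run_feeds. The toy graph and the 'w' variable are assumptions (they pair with the training sketch above); stopping after a single pass mirrors _is_input_constant for graphs with no queue runners and no feed_fn.

import tensorflow as tf

def sketch_predict(checkpoint_dir):
  ckpt = tf.train.latest_checkpoint(checkpoint_dir)
  g = tf.Graph()
  with g.as_default():
    x = tf.random_normal([4, 1])
    w = tf.get_variable('w', [1, 1])  # restored from the checkpoint
    predictions = {'predictions': tf.matmul(x, w)}
    session_creator = tf.train.ChiefSessionCreator(
        checkpoint_filename_with_path=ckpt)
    results = []
    with tf.train.MonitoredSession(session_creator=session_creator) as mon_sess:
      while not mon_sess.should_stop():
        results.append(mon_sess.run(predictions))
        break  # constant input, no queue runners: stop after one pass
  return results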