Simplified estimator logic by MonitoredSession.
Removed graph_action usage. Change: 144126485
This commit is contained in:
parent
3e59f0540e
commit
61a6797c4f
@ -21,20 +21,28 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib.factorization.python.ops import gmm_ops
|
||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn import graph_actions
|
||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed as random_seed_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
def _streaming_sum(scalar_tensor):
|
||||
@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
|
||||
return sum_metric, sum_update
|
||||
|
||||
|
||||
class GMM(estimator.Estimator, TransformerMixin):
|
||||
class GMM(estimator_lib.Estimator, TransformerMixin):
|
||||
"""GMM clustering."""
|
||||
SCORES = 'scores'
|
||||
ASSIGNMENTS = 'assignments'
|
||||
@ -116,7 +124,8 @@ class GMM(estimator.Estimator, TransformerMixin):
|
||||
self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
|
||||
self._num_clusters,
|
||||
self.batch_size)
|
||||
self._train_model(
|
||||
_legacy_train_model( # pylint: disable=protected-access
|
||||
self,
|
||||
input_fn=self._data_feeder.input_builder,
|
||||
feed_fn=self._data_feeder.get_feed_dict_fn(),
|
||||
steps=steps or self.steps,
|
||||
@ -218,3 +227,90 @@ class GMM(estimator.Estimator, TransformerMixin):
|
||||
self._covariance_type,
|
||||
self._params)
|
||||
return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
|
||||
|
||||
|
||||
# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
|
||||
def _legacy_train_model(estimator,
|
||||
input_fn,
|
||||
steps,
|
||||
feed_fn=None,
|
||||
init_op=None,
|
||||
init_feed_fn=None,
|
||||
init_fn=None,
|
||||
device_fn=None,
|
||||
monitors=None,
|
||||
log_every_steps=100,
|
||||
fail_on_nan_loss=True,
|
||||
max_steps=None):
|
||||
"""Legacy train function of Estimator."""
|
||||
if hasattr(estimator.config, 'execution_mode'):
|
||||
if estimator.config.execution_mode not in ('all', 'train'):
|
||||
return
|
||||
|
||||
# Stagger startup of worker sessions based on task id.
|
||||
sleep_secs = min(
|
||||
estimator.config.training_worker_max_startup_secs,
|
||||
estimator.config.task_id *
|
||||
estimator.config.training_worker_session_startup_stagger_secs)
|
||||
if sleep_secs:
|
||||
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
|
||||
estimator.config.task_id)
|
||||
time.sleep(sleep_secs)
|
||||
|
||||
# Device allocation
|
||||
device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access
|
||||
|
||||
with ops.Graph().as_default() as g, g.device(device_fn):
|
||||
random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
|
||||
global_step = framework.create_global_step(g)
|
||||
features, labels = input_fn()
|
||||
estimator._check_inputs(features, labels) # pylint: disable=protected-access
|
||||
|
||||
# The default return type of _get_train_ops is ModelFnOps. But there are
|
||||
# some subclasses of tf.contrib.learn.Estimator which override this
|
||||
# method and use the legacy signature, namely _get_train_ops returns a
|
||||
# (train_op, loss) tuple. The following else-statement code covers these
|
||||
# cases, but will soon be deleted after the subclasses are updated.
|
||||
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
||||
train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access
|
||||
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||
train_op = train_ops.train_op
|
||||
loss_op = train_ops.loss
|
||||
if estimator.config.is_chief:
|
||||
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
|
||||
else:
|
||||
hooks = train_ops.training_hooks
|
||||
else: # Legacy signature
|
||||
if len(train_ops) != 2:
|
||||
raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
|
||||
train_ops))
|
||||
train_op = train_ops[0]
|
||||
loss_op = train_ops[1]
|
||||
hooks = []
|
||||
|
||||
hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
|
||||
|
||||
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
|
||||
return graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
graph=g,
|
||||
output_dir=estimator.model_dir,
|
||||
train_op=train_op,
|
||||
loss_op=loss_op,
|
||||
global_step_tensor=global_step,
|
||||
init_op=init_op,
|
||||
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
|
||||
init_fn=init_fn,
|
||||
log_every_steps=log_every_steps,
|
||||
supervisor_is_chief=estimator.config.is_chief,
|
||||
supervisor_master=estimator.config.master,
|
||||
supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
|
||||
supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
|
||||
supervisor_save_summaries_steps=estimator.config.save_summary_steps,
|
||||
keep_checkpoint_max=estimator.config.keep_checkpoint_max,
|
||||
keep_checkpoint_every_n_hours=(
|
||||
estimator.config.keep_checkpoint_every_n_hours),
|
||||
feed_fn=feed_fn,
|
||||
steps=steps,
|
||||
fail_on_nan_loss=fail_on_nan_loss,
|
||||
hooks=hooks,
|
||||
max_steps=max_steps)
|
||||
|
@ -22,10 +22,8 @@ from __future__ import print_function
|
||||
import abc
|
||||
import copy
|
||||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -39,10 +37,8 @@ from tensorflow.contrib.framework import deprecated_args
|
||||
from tensorflow.contrib.framework import list_variables
|
||||
from tensorflow.contrib.framework import load_variable
|
||||
from tensorflow.contrib.framework.python.framework import experimental
|
||||
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import graph_actions
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
@ -58,7 +54,6 @@ from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
|
||||
from tensorflow.contrib.training.python.training import evaluation
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -92,6 +87,25 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = (
|
||||
' est = Estimator(...) -> est = SKCompat(Estimator(...))')
|
||||
|
||||
|
||||
def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
|
||||
"""Verifies validity of co-existance of input arguments."""
|
||||
if input_fn is None:
|
||||
if x is None:
|
||||
raise ValueError('Either x or input_fn must be provided.')
|
||||
|
||||
if contrib_framework.is_tensor(x) or (y is not None and
|
||||
contrib_framework.is_tensor(y)):
|
||||
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
|
||||
|
||||
if feed_fn is not None:
|
||||
raise ValueError('Can not provide both feed_fn and x or y.')
|
||||
else:
|
||||
if (x is not None) or (y is not None):
|
||||
raise ValueError('Can not provide both input_fn and x or y.')
|
||||
if batch_size is not None:
|
||||
raise ValueError('Can not provide both input_fn and batch_size.')
|
||||
|
||||
|
||||
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
|
||||
"""Make inputs into input and feed functions.
|
||||
|
||||
@ -110,29 +124,17 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
|
||||
Raises:
|
||||
ValueError: Only one of `(x & y)` or `input_fn` must be provided.
|
||||
"""
|
||||
if input_fn is None:
|
||||
if x is None:
|
||||
raise ValueError('Either x or input_fn must be provided.')
|
||||
|
||||
if contrib_framework.is_tensor(x) or (y is not None and
|
||||
contrib_framework.is_tensor(y)):
|
||||
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
|
||||
|
||||
if feed_fn is not None:
|
||||
raise ValueError('Can not provide both feed_fn and x or y.')
|
||||
|
||||
df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
epochs=epochs)
|
||||
return df.input_builder, df.get_feed_dict_fn()
|
||||
|
||||
if (x is not None) or (y is not None):
|
||||
raise ValueError('Can not provide both input_fn and x or y.')
|
||||
if batch_size is not None:
|
||||
raise ValueError('Can not provide both input_fn and batch_size.')
|
||||
|
||||
return input_fn, feed_fn
|
||||
_verify_input_args(x, y, input_fn, feed_fn, batch_size)
|
||||
if input_fn is not None:
|
||||
return input_fn, feed_fn
|
||||
df = data_feeder.setup_train_data_feeder(
|
||||
x,
|
||||
y,
|
||||
n_classes=None,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
epochs=epochs)
|
||||
return df.input_builder, df.get_feed_dict_fn()
|
||||
|
||||
|
||||
def infer_real_valued_columns_from_input_fn(input_fn):
|
||||
@ -311,9 +313,8 @@ def _write_dict_to_summary(output_dir,
|
||||
dictionary: the `dict` to be written to summary file.
|
||||
current_global_step: `int`, the current global step.
|
||||
"""
|
||||
logging.info(
|
||||
'Saving dict for global step %d: %s' %
|
||||
(current_global_step, _dict_to_str(dictionary)))
|
||||
logging.info('Saving dict for global step %d: %s', current_global_step,
|
||||
_dict_to_str(dictionary))
|
||||
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
|
||||
summary_proto = summary_pb2.Summary()
|
||||
for key in dictionary:
|
||||
@ -404,15 +405,24 @@ class BaseEstimator(
|
||||
"""
|
||||
if (steps is not None) and (max_steps is not None):
|
||||
raise ValueError('Can not provide both steps and max_steps.')
|
||||
_verify_input_args(x, y, input_fn, None, batch_size)
|
||||
if x is not None:
|
||||
return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
|
||||
|
||||
input_fn, feed_fn = _get_input_fn(x, y, input_fn, feed_fn=None,
|
||||
batch_size=batch_size, shuffle=True,
|
||||
epochs=None)
|
||||
loss = self._train_model(input_fn=input_fn,
|
||||
feed_fn=feed_fn,
|
||||
steps=steps,
|
||||
monitors=monitors,
|
||||
max_steps=max_steps)
|
||||
if max_steps is not None:
|
||||
try:
|
||||
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
|
||||
if max_steps <= start_step:
|
||||
logging.info('Skipping training since max_steps has already saved.')
|
||||
return None
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
|
||||
if steps is not None or max_steps is not None:
|
||||
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
|
||||
|
||||
loss = self._train_model(input_fn=input_fn, hooks=hooks)
|
||||
logging.info('Loss for final step: %s.', loss)
|
||||
return self
|
||||
|
||||
@ -485,9 +495,10 @@ class BaseEstimator(
|
||||
`input_fn` or `feed_fn` is provided.
|
||||
Or if `metrics` is not `None` or `dict`.
|
||||
"""
|
||||
input_fn, feed_fn = _get_input_fn(x, y, input_fn=input_fn,
|
||||
feed_fn=feed_fn, batch_size=batch_size,
|
||||
shuffle=False, epochs=1)
|
||||
_verify_input_args(x, y, input_fn, feed_fn, batch_size)
|
||||
if x is not None:
|
||||
return SKCompat(self).score(x, y, batch_size, steps, metrics)
|
||||
|
||||
if metrics is not None and not isinstance(metrics, dict):
|
||||
raise ValueError('Metrics argument should be None or dict. '
|
||||
'Got %s.' % metrics)
|
||||
@ -537,11 +548,15 @@ class BaseEstimator(
|
||||
Raises:
|
||||
ValueError: If x and input_fn are both provided or both `None`.
|
||||
"""
|
||||
input_fn, feed_fn = _get_input_fn(
|
||||
x, None, input_fn=input_fn, feed_fn=None, batch_size=batch_size,
|
||||
shuffle=False, epochs=1)
|
||||
_verify_input_args(x, None, input_fn, None, batch_size)
|
||||
if x is not None and not as_iterable:
|
||||
return SKCompat(self).predict(x, batch_size)
|
||||
|
||||
input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size)
|
||||
return self._infer_model(
|
||||
input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
|
||||
input_fn=input_fn,
|
||||
feed_fn=feed_fn,
|
||||
outputs=outputs,
|
||||
as_iterable=as_iterable)
|
||||
|
||||
def get_variable_value(self, name):
|
||||
@ -728,91 +743,6 @@ class BaseEstimator(
|
||||
self._labels_info = tensor_signature.create_signatures(labels)
|
||||
logging.debug('Setting labels info to %s', str(self._labels_info))
|
||||
|
||||
def _train_model(self,
|
||||
input_fn,
|
||||
steps,
|
||||
feed_fn=None,
|
||||
init_op=None,
|
||||
init_feed_fn=None,
|
||||
init_fn=None,
|
||||
device_fn=None,
|
||||
monitors=None,
|
||||
log_every_steps=100,
|
||||
fail_on_nan_loss=True,
|
||||
max_steps=None):
|
||||
# TODO(wicke): Remove this once Model and associated code are gone.
|
||||
if hasattr(self._config, 'execution_mode'):
|
||||
if self._config.execution_mode not in ('all', 'train'):
|
||||
return
|
||||
|
||||
# Stagger startup of worker sessions based on task id.
|
||||
sleep_secs = min(
|
||||
self._config.training_worker_max_startup_secs,
|
||||
self._config.task_id *
|
||||
self._config.training_worker_session_startup_stagger_secs)
|
||||
if sleep_secs:
|
||||
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
|
||||
self._config.task_id)
|
||||
time.sleep(sleep_secs)
|
||||
|
||||
# Device allocation
|
||||
device_fn = device_fn or self._device_fn
|
||||
|
||||
self._graph = ops.Graph()
|
||||
with self._graph.as_default() as g, g.device(device_fn):
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
global_step = contrib_framework.create_global_step(g)
|
||||
features, labels = input_fn()
|
||||
self._check_inputs(features, labels)
|
||||
|
||||
# The default return type of _get_train_ops is ModelFnOps. But there are
|
||||
# some subclasses of tf.contrib.learn.Estimator which override this
|
||||
# method and use the legacy signature, namely _get_train_ops returns a
|
||||
# (train_op, loss) tuple. The following else-statement code covers these
|
||||
# cases, but will soon be deleted after the subclasses are updated.
|
||||
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
||||
train_ops = self._get_train_ops(features, labels)
|
||||
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||
train_op = train_ops.train_op
|
||||
loss_op = train_ops.loss
|
||||
if self.config.is_chief:
|
||||
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
|
||||
else:
|
||||
hooks = train_ops.training_hooks
|
||||
else: # Legacy signature
|
||||
if len(train_ops) != 2:
|
||||
raise ValueError('Expected a tuple of train_op and loss, got {}'.
|
||||
format(train_ops))
|
||||
train_op = train_ops[0]
|
||||
loss_op = train_ops[1]
|
||||
hooks = []
|
||||
|
||||
hooks += monitor_lib.replace_monitors_with_hooks(monitors, self)
|
||||
|
||||
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
|
||||
return graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
graph=g,
|
||||
output_dir=self._model_dir,
|
||||
train_op=train_op,
|
||||
loss_op=loss_op,
|
||||
global_step_tensor=global_step,
|
||||
init_op=init_op,
|
||||
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
|
||||
init_fn=init_fn,
|
||||
log_every_steps=log_every_steps,
|
||||
supervisor_is_chief=self.config.is_chief,
|
||||
supervisor_master=self._config.master,
|
||||
supervisor_save_model_secs=self._config.save_checkpoints_secs,
|
||||
supervisor_save_model_steps=self._config.save_checkpoints_steps,
|
||||
supervisor_save_summaries_steps=self._config.save_summary_steps,
|
||||
keep_checkpoint_max=self._config.keep_checkpoint_max,
|
||||
keep_checkpoint_every_n_hours=self._config.keep_checkpoint_every_n_hours,
|
||||
feed_fn=feed_fn,
|
||||
steps=steps,
|
||||
fail_on_nan_loss=fail_on_nan_loss,
|
||||
hooks=hooks,
|
||||
max_steps=max_steps)
|
||||
|
||||
def _extract_metric_update_ops(self, eval_dict):
|
||||
"""Separate update operations from metric value operations."""
|
||||
update_ops = []
|
||||
@ -915,8 +845,12 @@ class BaseEstimator(
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
def _infer_model(
|
||||
self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
|
||||
def _infer_model(self,
|
||||
input_fn,
|
||||
feed_fn=None,
|
||||
outputs=None,
|
||||
as_iterable=True,
|
||||
iterate_batches=False):
|
||||
# Check that model has been trained.
|
||||
checkpoint_path = saver.latest_checkpoint(self._model_dir)
|
||||
if not checkpoint_path:
|
||||
@ -927,103 +861,152 @@ class BaseEstimator(
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
contrib_framework.create_global_step(g)
|
||||
features = self._get_features_from_input_fn(input_fn)
|
||||
|
||||
# The default return type of _get_predict_ops is ModelFnOps. But there are
|
||||
# some subclasses of tf.contrib.learn.Estimator which override this
|
||||
# method and use the legacy signature, namely _get_predict_ops returns a
|
||||
# `predictions` Tensor or dict or Tensors. The following else-statement
|
||||
# code covers these cases, but will soon be deleted after the subclasses
|
||||
# are updated.
|
||||
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
||||
infer_ops = self._get_predict_ops(features)
|
||||
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||
predictions = infer_ops.predictions
|
||||
else: # Legacy signature
|
||||
predictions = infer_ops
|
||||
|
||||
# If predictions is single output - wrap it into dict, and remember to
|
||||
# return not a dict.
|
||||
return_dict = isinstance(predictions, dict)
|
||||
if not return_dict:
|
||||
predictions = {'predictions': predictions}
|
||||
|
||||
# Filter what to run predictions on, if outputs provided.
|
||||
if outputs:
|
||||
existing_keys = predictions.keys()
|
||||
predictions = {
|
||||
key: value
|
||||
for key, value in six.iteritems(predictions) if key in outputs
|
||||
}
|
||||
if not predictions:
|
||||
raise ValueError('Expected to run at least one output from %s, '
|
||||
'provided %s.' % (existing_keys, outputs))
|
||||
|
||||
if as_iterable:
|
||||
return self._infer_model_as_iterable(
|
||||
checkpoint_path, predictions, feed_fn, return_dict)
|
||||
infer_ops = self._call_legacy_get_predict_ops(features)
|
||||
predictions = self._filter_predictions(infer_ops.predictions, outputs)
|
||||
mon_sess = monitored_session.MonitoredSession(
|
||||
session_creator=monitored_session.ChiefSessionCreator(
|
||||
checkpoint_filename_with_path=checkpoint_path))
|
||||
if not as_iterable:
|
||||
with mon_sess:
|
||||
if not mon_sess.should_stop():
|
||||
return mon_sess.run(predictions, feed_fn() if feed_fn else None)
|
||||
else:
|
||||
return self._infer_model_single(
|
||||
checkpoint_path, predictions, feed_fn, return_dict)
|
||||
return self._predict_generator(mon_sess, predictions, feed_fn,
|
||||
iterate_batches)
|
||||
|
||||
def _infer_model_single(
|
||||
self, checkpoint_path, predictions, feed_fn, return_dict):
|
||||
if feed_fn is None:
|
||||
preds = graph_actions.infer(checkpoint_path, predictions)
|
||||
else:
|
||||
def _feed_fn():
|
||||
while True:
|
||||
yield feed_fn()
|
||||
|
||||
outputs = graph_actions.run_feeds(
|
||||
output_dict=predictions,
|
||||
feed_dicts=_feed_fn(),
|
||||
restore_checkpoint_path=checkpoint_path)
|
||||
preds = {
|
||||
key: np.concatenate([output[key] for output in outputs], axis=0)
|
||||
for key in predictions}
|
||||
|
||||
return preds if return_dict else preds['predictions']
|
||||
|
||||
def _infer_model_as_iterable(
|
||||
self, checkpoint_path, predictions, feed_fn, return_dict):
|
||||
if feed_fn is None:
|
||||
# If there are no queue_runners, the input `predictions` is a
|
||||
# constant, and we should stop after the first epoch. If,
|
||||
# instead, there are queue_runners, eventually they should throw
|
||||
# an `OutOfRangeError`.
|
||||
graph = contrib_ops.get_graph_from_inputs(predictions.values())
|
||||
if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
||||
feed_dicts = itertools.repeat(None)
|
||||
else:
|
||||
feed_dicts = [None]
|
||||
else:
|
||||
def _feed_fn():
|
||||
while True:
|
||||
yield feed_fn()
|
||||
feed_dicts = _feed_fn()
|
||||
|
||||
try:
|
||||
for output_batch in graph_actions.run_feeds_iter(
|
||||
output_dict=predictions,
|
||||
feed_dicts=feed_dicts,
|
||||
restore_checkpoint_path=checkpoint_path):
|
||||
# Unpack batches into individual predictions
|
||||
if return_dict:
|
||||
first_tensor = list(output_batch.values())[0]
|
||||
def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches):
|
||||
with mon_sess:
|
||||
while not mon_sess.should_stop():
|
||||
preds = mon_sess.run(predictions, feed_fn() if feed_fn else None)
|
||||
if iterate_batches:
|
||||
yield preds
|
||||
elif not isinstance(predictions, dict):
|
||||
for pred in preds:
|
||||
yield pred
|
||||
else:
|
||||
first_tensor = list(preds.values())[0]
|
||||
if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
|
||||
batch_length = first_tensor.dense_shape[0]
|
||||
else:
|
||||
batch_length = first_tensor.shape[0]
|
||||
for i in range(batch_length):
|
||||
yield {key: value[i] for key, value in six.iteritems(output_batch)}
|
||||
else:
|
||||
for pred in output_batch['predictions']:
|
||||
yield pred
|
||||
yield {key: value[i] for key, value in six.iteritems(preds)}
|
||||
if self._is_input_constant(feed_fn, mon_sess.graph):
|
||||
return
|
||||
|
||||
except errors.OutOfRangeError:
|
||||
# We fall out of the above loop naturally if feed_fn raises StopIteration,
|
||||
# or we catch an OutOfRangeError if we've reached the end of inputs.
|
||||
logging.info('Reached end of inputs for predict_iter.')
|
||||
def _is_input_constant(self, feed_fn, graph):
|
||||
# If there are no queue_runners, the input `predictions` is a
|
||||
# constant, and we should stop after the first epoch. If,
|
||||
# instead, there are queue_runners, eventually they should throw
|
||||
# an `OutOfRangeError`.
|
||||
if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
||||
return False
|
||||
# data_feeder uses feed_fn to generate `OutOfRangeError`.
|
||||
if feed_fn is not None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _filter_predictions(self, predictions, outputs):
|
||||
if not outputs:
|
||||
return predictions
|
||||
if not isinstance(predictions, dict):
|
||||
raise ValueError(
|
||||
'outputs argument is not valid in case of non-dict predictions.')
|
||||
existing_keys = predictions.keys()
|
||||
predictions = {
|
||||
key: value
|
||||
for key, value in six.iteritems(predictions) if key in outputs
|
||||
}
|
||||
if not predictions:
|
||||
raise ValueError('Expected to run at least one output from %s, '
|
||||
'provided %s.' % (existing_keys, outputs))
|
||||
return predictions
|
||||
|
||||
def _train_model(self, input_fn, hooks):
|
||||
all_hooks = []
|
||||
self._graph = ops.Graph()
|
||||
with self._graph.as_default() as g, g.device(self._device_fn):
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
global_step = contrib_framework.create_global_step(g)
|
||||
features, labels = input_fn()
|
||||
self._check_inputs(features, labels)
|
||||
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
|
||||
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
|
||||
all_hooks.extend([
|
||||
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
|
||||
basic_session_run_hooks.LoggingTensorHook(
|
||||
{
|
||||
'loss': model_fn_ops.loss,
|
||||
'step': global_step
|
||||
},
|
||||
every_n_iter=100)
|
||||
])
|
||||
all_hooks.extend(hooks)
|
||||
|
||||
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
|
||||
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
|
||||
ops.add_to_collection(
|
||||
ops.GraphKeys.SAVERS,
|
||||
saver.Saver(
|
||||
sharded=True,
|
||||
max_to_keep=self._config.keep_checkpoint_max,
|
||||
defer_build=True))
|
||||
|
||||
chief_hooks = []
|
||||
if (self._config.save_checkpoints_secs or
|
||||
self._config.save_checkpoints_steps):
|
||||
saver_hook_exists = any([
|
||||
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
|
||||
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
|
||||
model_fn_ops.training_chief_hooks)
|
||||
])
|
||||
if not saver_hook_exists:
|
||||
chief_hooks = [
|
||||
basic_session_run_hooks.CheckpointSaverHook(
|
||||
self._model_dir,
|
||||
save_secs=self._config.save_checkpoints_secs,
|
||||
save_steps=self._config.save_checkpoints_steps,
|
||||
scaffold=scaffold)
|
||||
]
|
||||
with monitored_session.MonitoredTrainingSession(
|
||||
master=self._config.master,
|
||||
is_chief=self._config.is_chief,
|
||||
checkpoint_dir=self._model_dir,
|
||||
scaffold=scaffold,
|
||||
hooks=all_hooks + model_fn_ops.training_hooks,
|
||||
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
|
||||
save_checkpoint_secs=0, # Saving is handled by a hook.
|
||||
save_summaries_steps=self._config.save_summary_steps,
|
||||
config=None) as mon_sess:
|
||||
loss = None
|
||||
while not mon_sess.should_stop():
|
||||
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
|
||||
summary_io.SummaryWriterCache.clear()
|
||||
return loss
|
||||
|
||||
def _call_legacy_get_predict_ops(self, features):
|
||||
# The default return type of _get_predict_ops is ModelFnOps. But there are
|
||||
# some subclasses of tf.contrib.learn.Estimator which override this
|
||||
# method and use the legacy signature, namely _get_predict_ops returns a
|
||||
# `predictions` Tensor or dict or Tensors. The following else-statement
|
||||
# code covers these cases, but will soon be deleted after the subclasses
|
||||
# are updated.
|
||||
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
||||
infer_ops = self._get_predict_ops(features)
|
||||
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||
return infer_ops
|
||||
return model_fn_lib.ModelFnOps(
|
||||
mode=model_fn_lib.ModeKeys.INFER, predictions=infer_ops)
|
||||
|
||||
def _call_legacy_get_train_ops(self, features, labels):
|
||||
train_ops = self._get_train_ops(features, labels)
|
||||
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||
return train_ops
|
||||
return model_fn_lib.ModelFnOps(
|
||||
mode=model_fn_lib.ModeKeys.TRAIN,
|
||||
predictions=None,
|
||||
loss=train_ops[1],
|
||||
train_op=train_ops[0])
|
||||
|
||||
|
||||
def _identity_feature_engineering_fn(features, labels):
|
||||
@ -1177,17 +1160,6 @@ class Estimator(BaseEstimator):
|
||||
"""
|
||||
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
|
||||
|
||||
# TODO(ispir): delete this function after converting all legacy usages.
|
||||
def _call_legacy_get_train_ops(self, features, labels):
|
||||
train_ops = self._get_train_ops(features, labels)
|
||||
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||
return train_ops
|
||||
return model_fn_lib.ModelFnOps(
|
||||
mode=model_fn_lib.ModeKeys.TRAIN,
|
||||
predictions=None,
|
||||
loss=train_ops[1],
|
||||
train_op=train_ops[0])
|
||||
|
||||
def _get_eval_ops(self, features, labels, metrics):
|
||||
"""Method that builds model graph and returns evaluation ops.
|
||||
|
||||
@ -1343,114 +1315,6 @@ class Estimator(BaseEstimator):
|
||||
|
||||
return export_dir
|
||||
|
||||
@deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y',
|
||||
'batch_size')
|
||||
def fit(self,
|
||||
x=None,
|
||||
y=None,
|
||||
input_fn=None,
|
||||
steps=None,
|
||||
batch_size=None,
|
||||
monitors=None,
|
||||
max_steps=None):
|
||||
# pylint: disable=g-doc-args,g-doc-return-or-yield
|
||||
"""See `Trainable`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
|
||||
ValueError: If both `steps` and `max_steps` are not `None`.
|
||||
"""
|
||||
if (steps is not None) and (max_steps is not None):
|
||||
raise ValueError('Can not provide both steps and max_steps.')
|
||||
if max_steps is not None:
|
||||
try:
|
||||
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
|
||||
if max_steps <= start_step:
|
||||
logging.info('Skipping training since max_steps has already saved.')
|
||||
return None
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
|
||||
if steps is not None or max_steps is not None:
|
||||
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
|
||||
|
||||
input_fn, feed_fn = _get_input_fn(
|
||||
x,
|
||||
y,
|
||||
input_fn,
|
||||
feed_fn=None,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
epochs=None)
|
||||
if feed_fn:
|
||||
hooks.append(_FeedFnHook(feed_fn))
|
||||
loss = self._train_model_v2(input_fn=input_fn, hooks=hooks)
|
||||
logging.info('Loss for final step: %s.', loss)
|
||||
return self
|
||||
|
||||
def _train_model_v2(self, input_fn, hooks):
|
||||
all_hooks = []
|
||||
self._graph = ops.Graph()
|
||||
with self._graph.as_default() as g, g.device(self._device_fn):
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
global_step = contrib_framework.create_global_step(g)
|
||||
features, labels = input_fn()
|
||||
self._check_inputs(features, labels)
|
||||
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
|
||||
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
|
||||
all_hooks.extend([
|
||||
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
|
||||
basic_session_run_hooks.LoggingTensorHook(
|
||||
{
|
||||
'loss': model_fn_ops.loss,
|
||||
'step': global_step
|
||||
},
|
||||
every_n_iter=100)
|
||||
])
|
||||
all_hooks.extend(hooks)
|
||||
|
||||
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
|
||||
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
|
||||
ops.add_to_collection(
|
||||
ops.GraphKeys.SAVERS,
|
||||
saver.Saver(
|
||||
sharded=True,
|
||||
max_to_keep=self._config.keep_checkpoint_max,
|
||||
defer_build=True))
|
||||
|
||||
chief_hooks = []
|
||||
if (self._config.save_checkpoints_secs or
|
||||
self._config.save_checkpoints_steps):
|
||||
saver_hook_exists = any([
|
||||
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
|
||||
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
|
||||
model_fn_ops.training_chief_hooks)
|
||||
])
|
||||
if not saver_hook_exists:
|
||||
chief_hooks = [
|
||||
basic_session_run_hooks.CheckpointSaverHook(
|
||||
self._model_dir,
|
||||
save_secs=self._config.save_checkpoints_secs,
|
||||
save_steps=self._config.save_checkpoints_steps,
|
||||
scaffold=scaffold)
|
||||
]
|
||||
with monitored_session.MonitoredTrainingSession(
|
||||
master=self._config.master,
|
||||
is_chief=self._config.is_chief,
|
||||
checkpoint_dir=self._model_dir,
|
||||
scaffold=scaffold,
|
||||
hooks=all_hooks + model_fn_ops.training_hooks,
|
||||
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
|
||||
save_checkpoint_secs=0, # Saving is handled by a hook.
|
||||
save_summaries_steps=self._config.save_summary_steps,
|
||||
config=None) as mon_sess:
|
||||
loss = None
|
||||
while not mon_sess.should_stop():
|
||||
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
|
||||
summary_io.SummaryWriterCache.clear()
|
||||
return loss
|
||||
|
||||
|
||||
class _FeedFnHook(session_run_hook.SessionRunHook):
|
||||
"""Runs feed_fn and sets the feed_dict accordingly."""
|
||||
@ -1509,6 +1373,17 @@ class SKCompat(sklearn.BaseEstimator):
|
||||
input_fn, feed_fn = _get_input_fn(
|
||||
x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
|
||||
shuffle=False, epochs=1)
|
||||
return self._estimator._infer_model(
|
||||
input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
|
||||
as_iterable=False)
|
||||
results = list(
|
||||
self._estimator._infer_model(
|
||||
input_fn=input_fn,
|
||||
feed_fn=feed_fn,
|
||||
outputs=outputs,
|
||||
as_iterable=True,
|
||||
iterate_batches=True))
|
||||
if not isinstance(results[0], dict):
|
||||
return np.concatenate([output for output in results], axis=0)
|
||||
return {
|
||||
key: np.concatenate(
|
||||
[output[key] for output in results], axis=0)
|
||||
for key in results[0]
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user