diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index dd7e9a34551..86450d4bbd2 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -21,20 +21,28 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn import graph_actions +from tensorflow.contrib.learn.python.learn import monitors as monitor_lib +from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib +from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin from tensorflow.contrib.learn.python.learn.learn_io import data_feeder from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed as random_seed_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops.control_flow_ops import with_dependencies +from tensorflow.python.platform import tf_logging as logging def _streaming_sum(scalar_tensor): @@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor): return sum_metric, sum_update -class GMM(estimator.Estimator, TransformerMixin): +class GMM(estimator_lib.Estimator, TransformerMixin): """GMM clustering.""" SCORES = 'scores' ASSIGNMENTS = 'assignments' @@ -116,7 +124,8 @@ class GMM(estimator.Estimator, TransformerMixin): self._data_feeder = data_feeder.setup_train_data_feeder(x, None, self._num_clusters, self.batch_size) - self._train_model( + _legacy_train_model( # pylint: disable=protected-access + self, input_fn=self._data_feeder.input_builder, feed_fn=self._data_feeder.get_feed_dict_fn(), steps=steps or self.steps, @@ -218,3 +227,90 @@ class GMM(estimator.Estimator, TransformerMixin): self._covariance_type, self._params) return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))} + + +# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator. +def _legacy_train_model(estimator, + input_fn, + steps, + feed_fn=None, + init_op=None, + init_feed_fn=None, + init_fn=None, + device_fn=None, + monitors=None, + log_every_steps=100, + fail_on_nan_loss=True, + max_steps=None): + """Legacy train function of Estimator.""" + if hasattr(estimator.config, 'execution_mode'): + if estimator.config.execution_mode not in ('all', 'train'): + return + + # Stagger startup of worker sessions based on task id. + sleep_secs = min( + estimator.config.training_worker_max_startup_secs, + estimator.config.task_id * + estimator.config.training_worker_session_startup_stagger_secs) + if sleep_secs: + logging.info('Waiting %d secs before starting task %d.', sleep_secs, + estimator.config.task_id) + time.sleep(sleep_secs) + + # Device allocation + device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access + + with ops.Graph().as_default() as g, g.device(device_fn): + random_seed_lib.set_random_seed(estimator.config.tf_random_seed) + global_step = framework.create_global_step(g) + features, labels = input_fn() + estimator._check_inputs(features, labels) # pylint: disable=protected-access + + # The default return type of _get_train_ops is ModelFnOps. But there are + # some subclasses of tf.contrib.learn.Estimator which override this + # method and use the legacy signature, namely _get_train_ops returns a + # (train_op, loss) tuple. The following else-statement code covers these + # cases, but will soon be deleted after the subclasses are updated. + # TODO(b/32664904): Update subclasses and delete the else-statement. + train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access + if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature + train_op = train_ops.train_op + loss_op = train_ops.loss + if estimator.config.is_chief: + hooks = train_ops.training_chief_hooks + train_ops.training_hooks + else: + hooks = train_ops.training_hooks + else: # Legacy signature + if len(train_ops) != 2: + raise ValueError('Expected a tuple of train_op and loss, got {}'.format( + train_ops)) + train_op = train_ops[0] + loss_op = train_ops[1] + hooks = [] + + hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator) + + ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op) + return graph_actions._monitored_train( # pylint: disable=protected-access + graph=g, + output_dir=estimator.model_dir, + train_op=train_op, + loss_op=loss_op, + global_step_tensor=global_step, + init_op=init_op, + init_feed_dict=init_feed_fn() if init_feed_fn is not None else None, + init_fn=init_fn, + log_every_steps=log_every_steps, + supervisor_is_chief=estimator.config.is_chief, + supervisor_master=estimator.config.master, + supervisor_save_model_secs=estimator.config.save_checkpoints_secs, + supervisor_save_model_steps=estimator.config.save_checkpoints_steps, + supervisor_save_summaries_steps=estimator.config.save_summary_steps, + keep_checkpoint_max=estimator.config.keep_checkpoint_max, + keep_checkpoint_every_n_hours=( + estimator.config.keep_checkpoint_every_n_hours), + feed_fn=feed_fn, + steps=steps, + fail_on_nan_loss=fail_on_nan_loss, + hooks=hooks, + max_steps=max_steps) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index dce10e7b0f9..467d31c3317 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -22,10 +22,8 @@ from __future__ import print_function import abc import copy import inspect -import itertools import os import tempfile -import time import numpy as np import six @@ -39,10 +37,8 @@ from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables from tensorflow.contrib.framework import load_variable from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.framework.python.ops import ops as contrib_ops from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import evaluable -from tensorflow.contrib.learn.python.learn import graph_actions from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn import monitors as monitor_lib from tensorflow.contrib.learn.python.learn import trainable @@ -58,7 +54,6 @@ from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.contrib.training.python.training import evaluation from tensorflow.core.framework import summary_pb2 from tensorflow.python.client import session as tf_session -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor @@ -92,6 +87,25 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = ( ' est = Estimator(...) -> est = SKCompat(Estimator(...))') +def _verify_input_args(x, y, input_fn, feed_fn, batch_size): + """Verifies validity of co-existance of input arguments.""" + if input_fn is None: + if x is None: + raise ValueError('Either x or input_fn must be provided.') + + if contrib_framework.is_tensor(x) or (y is not None and + contrib_framework.is_tensor(y)): + raise ValueError('Inputs cannot be tensors. Please provide input_fn.') + + if feed_fn is not None: + raise ValueError('Can not provide both feed_fn and x or y.') + else: + if (x is not None) or (y is not None): + raise ValueError('Can not provide both input_fn and x or y.') + if batch_size is not None: + raise ValueError('Can not provide both input_fn and batch_size.') + + def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): """Make inputs into input and feed functions. @@ -110,29 +124,17 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): Raises: ValueError: Only one of `(x & y)` or `input_fn` must be provided. """ - if input_fn is None: - if x is None: - raise ValueError('Either x or input_fn must be provided.') - - if contrib_framework.is_tensor(x) or (y is not None and - contrib_framework.is_tensor(y)): - raise ValueError('Inputs cannot be tensors. Please provide input_fn.') - - if feed_fn is not None: - raise ValueError('Can not provide both feed_fn and x or y.') - - df = data_feeder.setup_train_data_feeder(x, y, n_classes=None, - batch_size=batch_size, - shuffle=shuffle, - epochs=epochs) - return df.input_builder, df.get_feed_dict_fn() - - if (x is not None) or (y is not None): - raise ValueError('Can not provide both input_fn and x or y.') - if batch_size is not None: - raise ValueError('Can not provide both input_fn and batch_size.') - - return input_fn, feed_fn + _verify_input_args(x, y, input_fn, feed_fn, batch_size) + if input_fn is not None: + return input_fn, feed_fn + df = data_feeder.setup_train_data_feeder( + x, + y, + n_classes=None, + batch_size=batch_size, + shuffle=shuffle, + epochs=epochs) + return df.input_builder, df.get_feed_dict_fn() def infer_real_valued_columns_from_input_fn(input_fn): @@ -311,9 +313,8 @@ def _write_dict_to_summary(output_dir, dictionary: the `dict` to be written to summary file. current_global_step: `int`, the current global step. """ - logging.info( - 'Saving dict for global step %d: %s' % - (current_global_step, _dict_to_str(dictionary))) + logging.info('Saving dict for global step %d: %s', current_global_step, + _dict_to_str(dictionary)) summary_writer = summary_io.SummaryWriterCache.get(output_dir) summary_proto = summary_pb2.Summary() for key in dictionary: @@ -404,15 +405,24 @@ class BaseEstimator( """ if (steps is not None) and (max_steps is not None): raise ValueError('Can not provide both steps and max_steps.') + _verify_input_args(x, y, input_fn, None, batch_size) + if x is not None: + return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors) - input_fn, feed_fn = _get_input_fn(x, y, input_fn, feed_fn=None, - batch_size=batch_size, shuffle=True, - epochs=None) - loss = self._train_model(input_fn=input_fn, - feed_fn=feed_fn, - steps=steps, - monitors=monitors, - max_steps=max_steps) + if max_steps is not None: + try: + start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP) + if max_steps <= start_step: + logging.info('Skipping training since max_steps has already saved.') + return None + except: # pylint: disable=bare-except + pass + + hooks = monitor_lib.replace_monitors_with_hooks(monitors, self) + if steps is not None or max_steps is not None: + hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps)) + + loss = self._train_model(input_fn=input_fn, hooks=hooks) logging.info('Loss for final step: %s.', loss) return self @@ -485,9 +495,10 @@ class BaseEstimator( `input_fn` or `feed_fn` is provided. Or if `metrics` is not `None` or `dict`. """ - input_fn, feed_fn = _get_input_fn(x, y, input_fn=input_fn, - feed_fn=feed_fn, batch_size=batch_size, - shuffle=False, epochs=1) + _verify_input_args(x, y, input_fn, feed_fn, batch_size) + if x is not None: + return SKCompat(self).score(x, y, batch_size, steps, metrics) + if metrics is not None and not isinstance(metrics, dict): raise ValueError('Metrics argument should be None or dict. ' 'Got %s.' % metrics) @@ -537,11 +548,15 @@ class BaseEstimator( Raises: ValueError: If x and input_fn are both provided or both `None`. """ - input_fn, feed_fn = _get_input_fn( - x, None, input_fn=input_fn, feed_fn=None, batch_size=batch_size, - shuffle=False, epochs=1) + _verify_input_args(x, None, input_fn, None, batch_size) + if x is not None and not as_iterable: + return SKCompat(self).predict(x, batch_size) + + input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size) return self._infer_model( - input_fn=input_fn, feed_fn=feed_fn, outputs=outputs, + input_fn=input_fn, + feed_fn=feed_fn, + outputs=outputs, as_iterable=as_iterable) def get_variable_value(self, name): @@ -728,91 +743,6 @@ class BaseEstimator( self._labels_info = tensor_signature.create_signatures(labels) logging.debug('Setting labels info to %s', str(self._labels_info)) - def _train_model(self, - input_fn, - steps, - feed_fn=None, - init_op=None, - init_feed_fn=None, - init_fn=None, - device_fn=None, - monitors=None, - log_every_steps=100, - fail_on_nan_loss=True, - max_steps=None): - # TODO(wicke): Remove this once Model and associated code are gone. - if hasattr(self._config, 'execution_mode'): - if self._config.execution_mode not in ('all', 'train'): - return - - # Stagger startup of worker sessions based on task id. - sleep_secs = min( - self._config.training_worker_max_startup_secs, - self._config.task_id * - self._config.training_worker_session_startup_stagger_secs) - if sleep_secs: - logging.info('Waiting %d secs before starting task %d.', sleep_secs, - self._config.task_id) - time.sleep(sleep_secs) - - # Device allocation - device_fn = device_fn or self._device_fn - - self._graph = ops.Graph() - with self._graph.as_default() as g, g.device(device_fn): - random_seed.set_random_seed(self._config.tf_random_seed) - global_step = contrib_framework.create_global_step(g) - features, labels = input_fn() - self._check_inputs(features, labels) - - # The default return type of _get_train_ops is ModelFnOps. But there are - # some subclasses of tf.contrib.learn.Estimator which override this - # method and use the legacy signature, namely _get_train_ops returns a - # (train_op, loss) tuple. The following else-statement code covers these - # cases, but will soon be deleted after the subclasses are updated. - # TODO(b/32664904): Update subclasses and delete the else-statement. - train_ops = self._get_train_ops(features, labels) - if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature - train_op = train_ops.train_op - loss_op = train_ops.loss - if self.config.is_chief: - hooks = train_ops.training_chief_hooks + train_ops.training_hooks - else: - hooks = train_ops.training_hooks - else: # Legacy signature - if len(train_ops) != 2: - raise ValueError('Expected a tuple of train_op and loss, got {}'. - format(train_ops)) - train_op = train_ops[0] - loss_op = train_ops[1] - hooks = [] - - hooks += monitor_lib.replace_monitors_with_hooks(monitors, self) - - ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op) - return graph_actions._monitored_train( # pylint: disable=protected-access - graph=g, - output_dir=self._model_dir, - train_op=train_op, - loss_op=loss_op, - global_step_tensor=global_step, - init_op=init_op, - init_feed_dict=init_feed_fn() if init_feed_fn is not None else None, - init_fn=init_fn, - log_every_steps=log_every_steps, - supervisor_is_chief=self.config.is_chief, - supervisor_master=self._config.master, - supervisor_save_model_secs=self._config.save_checkpoints_secs, - supervisor_save_model_steps=self._config.save_checkpoints_steps, - supervisor_save_summaries_steps=self._config.save_summary_steps, - keep_checkpoint_max=self._config.keep_checkpoint_max, - keep_checkpoint_every_n_hours=self._config.keep_checkpoint_every_n_hours, - feed_fn=feed_fn, - steps=steps, - fail_on_nan_loss=fail_on_nan_loss, - hooks=hooks, - max_steps=max_steps) - def _extract_metric_update_ops(self, eval_dict): """Separate update operations from metric value operations.""" update_ops = [] @@ -915,8 +845,12 @@ class BaseEstimator( return result[0] return result - def _infer_model( - self, input_fn, feed_fn=None, outputs=None, as_iterable=True): + def _infer_model(self, + input_fn, + feed_fn=None, + outputs=None, + as_iterable=True, + iterate_batches=False): # Check that model has been trained. checkpoint_path = saver.latest_checkpoint(self._model_dir) if not checkpoint_path: @@ -927,103 +861,152 @@ class BaseEstimator( random_seed.set_random_seed(self._config.tf_random_seed) contrib_framework.create_global_step(g) features = self._get_features_from_input_fn(input_fn) - - # The default return type of _get_predict_ops is ModelFnOps. But there are - # some subclasses of tf.contrib.learn.Estimator which override this - # method and use the legacy signature, namely _get_predict_ops returns a - # `predictions` Tensor or dict or Tensors. The following else-statement - # code covers these cases, but will soon be deleted after the subclasses - # are updated. - # TODO(b/32664904): Update subclasses and delete the else-statement. - infer_ops = self._get_predict_ops(features) - if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature - predictions = infer_ops.predictions - else: # Legacy signature - predictions = infer_ops - - # If predictions is single output - wrap it into dict, and remember to - # return not a dict. - return_dict = isinstance(predictions, dict) - if not return_dict: - predictions = {'predictions': predictions} - - # Filter what to run predictions on, if outputs provided. - if outputs: - existing_keys = predictions.keys() - predictions = { - key: value - for key, value in six.iteritems(predictions) if key in outputs - } - if not predictions: - raise ValueError('Expected to run at least one output from %s, ' - 'provided %s.' % (existing_keys, outputs)) - - if as_iterable: - return self._infer_model_as_iterable( - checkpoint_path, predictions, feed_fn, return_dict) + infer_ops = self._call_legacy_get_predict_ops(features) + predictions = self._filter_predictions(infer_ops.predictions, outputs) + mon_sess = monitored_session.MonitoredSession( + session_creator=monitored_session.ChiefSessionCreator( + checkpoint_filename_with_path=checkpoint_path)) + if not as_iterable: + with mon_sess: + if not mon_sess.should_stop(): + return mon_sess.run(predictions, feed_fn() if feed_fn else None) else: - return self._infer_model_single( - checkpoint_path, predictions, feed_fn, return_dict) + return self._predict_generator(mon_sess, predictions, feed_fn, + iterate_batches) - def _infer_model_single( - self, checkpoint_path, predictions, feed_fn, return_dict): - if feed_fn is None: - preds = graph_actions.infer(checkpoint_path, predictions) - else: - def _feed_fn(): - while True: - yield feed_fn() - - outputs = graph_actions.run_feeds( - output_dict=predictions, - feed_dicts=_feed_fn(), - restore_checkpoint_path=checkpoint_path) - preds = { - key: np.concatenate([output[key] for output in outputs], axis=0) - for key in predictions} - - return preds if return_dict else preds['predictions'] - - def _infer_model_as_iterable( - self, checkpoint_path, predictions, feed_fn, return_dict): - if feed_fn is None: - # If there are no queue_runners, the input `predictions` is a - # constant, and we should stop after the first epoch. If, - # instead, there are queue_runners, eventually they should throw - # an `OutOfRangeError`. - graph = contrib_ops.get_graph_from_inputs(predictions.values()) - if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS): - feed_dicts = itertools.repeat(None) - else: - feed_dicts = [None] - else: - def _feed_fn(): - while True: - yield feed_fn() - feed_dicts = _feed_fn() - - try: - for output_batch in graph_actions.run_feeds_iter( - output_dict=predictions, - feed_dicts=feed_dicts, - restore_checkpoint_path=checkpoint_path): - # Unpack batches into individual predictions - if return_dict: - first_tensor = list(output_batch.values())[0] + def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches): + with mon_sess: + while not mon_sess.should_stop(): + preds = mon_sess.run(predictions, feed_fn() if feed_fn else None) + if iterate_batches: + yield preds + elif not isinstance(predictions, dict): + for pred in preds: + yield pred + else: + first_tensor = list(preds.values())[0] if isinstance(first_tensor, sparse_tensor.SparseTensorValue): batch_length = first_tensor.dense_shape[0] else: batch_length = first_tensor.shape[0] for i in range(batch_length): - yield {key: value[i] for key, value in six.iteritems(output_batch)} - else: - for pred in output_batch['predictions']: - yield pred + yield {key: value[i] for key, value in six.iteritems(preds)} + if self._is_input_constant(feed_fn, mon_sess.graph): + return - except errors.OutOfRangeError: - # We fall out of the above loop naturally if feed_fn raises StopIteration, - # or we catch an OutOfRangeError if we've reached the end of inputs. - logging.info('Reached end of inputs for predict_iter.') + def _is_input_constant(self, feed_fn, graph): + # If there are no queue_runners, the input `predictions` is a + # constant, and we should stop after the first epoch. If, + # instead, there are queue_runners, eventually they should throw + # an `OutOfRangeError`. + if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS): + return False + # data_feeder uses feed_fn to generate `OutOfRangeError`. + if feed_fn is not None: + return False + return True + + def _filter_predictions(self, predictions, outputs): + if not outputs: + return predictions + if not isinstance(predictions, dict): + raise ValueError( + 'outputs argument is not valid in case of non-dict predictions.') + existing_keys = predictions.keys() + predictions = { + key: value + for key, value in six.iteritems(predictions) if key in outputs + } + if not predictions: + raise ValueError('Expected to run at least one output from %s, ' + 'provided %s.' % (existing_keys, outputs)) + return predictions + + def _train_model(self, input_fn, hooks): + all_hooks = [] + self._graph = ops.Graph() + with self._graph.as_default() as g, g.device(self._device_fn): + random_seed.set_random_seed(self._config.tf_random_seed) + global_step = contrib_framework.create_global_step(g) + features, labels = input_fn() + self._check_inputs(features, labels) + model_fn_ops = self._call_legacy_get_train_ops(features, labels) + ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) + all_hooks.extend([ + basic_session_run_hooks.NanTensorHook(model_fn_ops.loss), + basic_session_run_hooks.LoggingTensorHook( + { + 'loss': model_fn_ops.loss, + 'step': global_step + }, + every_n_iter=100) + ]) + all_hooks.extend(hooks) + + scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold() + if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): + ops.add_to_collection( + ops.GraphKeys.SAVERS, + saver.Saver( + sharded=True, + max_to_keep=self._config.keep_checkpoint_max, + defer_build=True)) + + chief_hooks = [] + if (self._config.save_checkpoints_secs or + self._config.save_checkpoints_steps): + saver_hook_exists = any([ + isinstance(h, basic_session_run_hooks.CheckpointSaverHook) + for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks + + model_fn_ops.training_chief_hooks) + ]) + if not saver_hook_exists: + chief_hooks = [ + basic_session_run_hooks.CheckpointSaverHook( + self._model_dir, + save_secs=self._config.save_checkpoints_secs, + save_steps=self._config.save_checkpoints_steps, + scaffold=scaffold) + ] + with monitored_session.MonitoredTrainingSession( + master=self._config.master, + is_chief=self._config.is_chief, + checkpoint_dir=self._model_dir, + scaffold=scaffold, + hooks=all_hooks + model_fn_ops.training_hooks, + chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, + save_checkpoint_secs=0, # Saving is handled by a hook. + save_summaries_steps=self._config.save_summary_steps, + config=None) as mon_sess: + loss = None + while not mon_sess.should_stop(): + _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) + summary_io.SummaryWriterCache.clear() + return loss + + def _call_legacy_get_predict_ops(self, features): + # The default return type of _get_predict_ops is ModelFnOps. But there are + # some subclasses of tf.contrib.learn.Estimator which override this + # method and use the legacy signature, namely _get_predict_ops returns a + # `predictions` Tensor or dict or Tensors. The following else-statement + # code covers these cases, but will soon be deleted after the subclasses + # are updated. + # TODO(b/32664904): Update subclasses and delete the else-statement. + infer_ops = self._get_predict_ops(features) + if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature + return infer_ops + return model_fn_lib.ModelFnOps( + mode=model_fn_lib.ModeKeys.INFER, predictions=infer_ops) + + def _call_legacy_get_train_ops(self, features, labels): + train_ops = self._get_train_ops(features, labels) + if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature + return train_ops + return model_fn_lib.ModelFnOps( + mode=model_fn_lib.ModeKeys.TRAIN, + predictions=None, + loss=train_ops[1], + train_op=train_ops[0]) def _identity_feature_engineering_fn(features, labels): @@ -1177,17 +1160,6 @@ class Estimator(BaseEstimator): """ return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN) - # TODO(ispir): delete this function after converting all legacy usages. - def _call_legacy_get_train_ops(self, features, labels): - train_ops = self._get_train_ops(features, labels) - if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature - return train_ops - return model_fn_lib.ModelFnOps( - mode=model_fn_lib.ModeKeys.TRAIN, - predictions=None, - loss=train_ops[1], - train_op=train_ops[0]) - def _get_eval_ops(self, features, labels, metrics): """Method that builds model graph and returns evaluation ops. @@ -1343,114 +1315,6 @@ class Estimator(BaseEstimator): return export_dir - @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', - 'batch_size') - def fit(self, - x=None, - y=None, - input_fn=None, - steps=None, - batch_size=None, - monitors=None, - max_steps=None): - # pylint: disable=g-doc-args,g-doc-return-or-yield - """See `Trainable`. - - Raises: - ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`. - ValueError: If both `steps` and `max_steps` are not `None`. - """ - if (steps is not None) and (max_steps is not None): - raise ValueError('Can not provide both steps and max_steps.') - if max_steps is not None: - try: - start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP) - if max_steps <= start_step: - logging.info('Skipping training since max_steps has already saved.') - return None - except: # pylint: disable=bare-except - pass - - hooks = monitor_lib.replace_monitors_with_hooks(monitors, self) - if steps is not None or max_steps is not None: - hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps)) - - input_fn, feed_fn = _get_input_fn( - x, - y, - input_fn, - feed_fn=None, - batch_size=batch_size, - shuffle=True, - epochs=None) - if feed_fn: - hooks.append(_FeedFnHook(feed_fn)) - loss = self._train_model_v2(input_fn=input_fn, hooks=hooks) - logging.info('Loss for final step: %s.', loss) - return self - - def _train_model_v2(self, input_fn, hooks): - all_hooks = [] - self._graph = ops.Graph() - with self._graph.as_default() as g, g.device(self._device_fn): - random_seed.set_random_seed(self._config.tf_random_seed) - global_step = contrib_framework.create_global_step(g) - features, labels = input_fn() - self._check_inputs(features, labels) - model_fn_ops = self._call_legacy_get_train_ops(features, labels) - ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) - all_hooks.extend([ - basic_session_run_hooks.NanTensorHook(model_fn_ops.loss), - basic_session_run_hooks.LoggingTensorHook( - { - 'loss': model_fn_ops.loss, - 'step': global_step - }, - every_n_iter=100) - ]) - all_hooks.extend(hooks) - - scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold() - if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): - ops.add_to_collection( - ops.GraphKeys.SAVERS, - saver.Saver( - sharded=True, - max_to_keep=self._config.keep_checkpoint_max, - defer_build=True)) - - chief_hooks = [] - if (self._config.save_checkpoints_secs or - self._config.save_checkpoints_steps): - saver_hook_exists = any([ - isinstance(h, basic_session_run_hooks.CheckpointSaverHook) - for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks + - model_fn_ops.training_chief_hooks) - ]) - if not saver_hook_exists: - chief_hooks = [ - basic_session_run_hooks.CheckpointSaverHook( - self._model_dir, - save_secs=self._config.save_checkpoints_secs, - save_steps=self._config.save_checkpoints_steps, - scaffold=scaffold) - ] - with monitored_session.MonitoredTrainingSession( - master=self._config.master, - is_chief=self._config.is_chief, - checkpoint_dir=self._model_dir, - scaffold=scaffold, - hooks=all_hooks + model_fn_ops.training_hooks, - chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, - save_checkpoint_secs=0, # Saving is handled by a hook. - save_summaries_steps=self._config.save_summary_steps, - config=None) as mon_sess: - loss = None - while not mon_sess.should_stop(): - _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) - summary_io.SummaryWriterCache.clear() - return loss - class _FeedFnHook(session_run_hook.SessionRunHook): """Runs feed_fn and sets the feed_dict accordingly.""" @@ -1509,6 +1373,17 @@ class SKCompat(sklearn.BaseEstimator): input_fn, feed_fn = _get_input_fn( x, None, input_fn=None, feed_fn=None, batch_size=batch_size, shuffle=False, epochs=1) - return self._estimator._infer_model( - input_fn=input_fn, feed_fn=feed_fn, outputs=outputs, - as_iterable=False) + results = list( + self._estimator._infer_model( + input_fn=input_fn, + feed_fn=feed_fn, + outputs=outputs, + as_iterable=True, + iterate_batches=True)) + if not isinstance(results[0], dict): + return np.concatenate([output for output in results], axis=0) + return { + key: np.concatenate( + [output[key] for output in results], axis=0) + for key in results[0] + }