Simplified estimator logic by MonitoredSession.
Removed graph_action usage. Change: 144126485
This commit is contained in:
parent
3e59f0540e
commit
61a6797c4f
@ -21,20 +21,28 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib import framework
|
from tensorflow.contrib import framework
|
||||||
from tensorflow.contrib.factorization.python.ops import gmm_ops
|
from tensorflow.contrib.factorization.python.ops import gmm_ops
|
||||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||||
from tensorflow.contrib.framework.python.ops import variables
|
from tensorflow.contrib.framework.python.ops import variables
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn import graph_actions
|
||||||
|
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
||||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||||
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import random_seed as random_seed_lib
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
def _streaming_sum(scalar_tensor):
|
def _streaming_sum(scalar_tensor):
|
||||||
@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
|
|||||||
return sum_metric, sum_update
|
return sum_metric, sum_update
|
||||||
|
|
||||||
|
|
||||||
class GMM(estimator.Estimator, TransformerMixin):
|
class GMM(estimator_lib.Estimator, TransformerMixin):
|
||||||
"""GMM clustering."""
|
"""GMM clustering."""
|
||||||
SCORES = 'scores'
|
SCORES = 'scores'
|
||||||
ASSIGNMENTS = 'assignments'
|
ASSIGNMENTS = 'assignments'
|
||||||
@ -116,7 +124,8 @@ class GMM(estimator.Estimator, TransformerMixin):
|
|||||||
self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
|
self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
|
||||||
self._num_clusters,
|
self._num_clusters,
|
||||||
self.batch_size)
|
self.batch_size)
|
||||||
self._train_model(
|
_legacy_train_model( # pylint: disable=protected-access
|
||||||
|
self,
|
||||||
input_fn=self._data_feeder.input_builder,
|
input_fn=self._data_feeder.input_builder,
|
||||||
feed_fn=self._data_feeder.get_feed_dict_fn(),
|
feed_fn=self._data_feeder.get_feed_dict_fn(),
|
||||||
steps=steps or self.steps,
|
steps=steps or self.steps,
|
||||||
@ -218,3 +227,90 @@ class GMM(estimator.Estimator, TransformerMixin):
|
|||||||
self._covariance_type,
|
self._covariance_type,
|
||||||
self._params)
|
self._params)
|
||||||
return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
|
return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
|
||||||
|
def _legacy_train_model(estimator,
|
||||||
|
input_fn,
|
||||||
|
steps,
|
||||||
|
feed_fn=None,
|
||||||
|
init_op=None,
|
||||||
|
init_feed_fn=None,
|
||||||
|
init_fn=None,
|
||||||
|
device_fn=None,
|
||||||
|
monitors=None,
|
||||||
|
log_every_steps=100,
|
||||||
|
fail_on_nan_loss=True,
|
||||||
|
max_steps=None):
|
||||||
|
"""Legacy train function of Estimator."""
|
||||||
|
if hasattr(estimator.config, 'execution_mode'):
|
||||||
|
if estimator.config.execution_mode not in ('all', 'train'):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stagger startup of worker sessions based on task id.
|
||||||
|
sleep_secs = min(
|
||||||
|
estimator.config.training_worker_max_startup_secs,
|
||||||
|
estimator.config.task_id *
|
||||||
|
estimator.config.training_worker_session_startup_stagger_secs)
|
||||||
|
if sleep_secs:
|
||||||
|
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
|
||||||
|
estimator.config.task_id)
|
||||||
|
time.sleep(sleep_secs)
|
||||||
|
|
||||||
|
# Device allocation
|
||||||
|
device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access
|
||||||
|
|
||||||
|
with ops.Graph().as_default() as g, g.device(device_fn):
|
||||||
|
random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
|
||||||
|
global_step = framework.create_global_step(g)
|
||||||
|
features, labels = input_fn()
|
||||||
|
estimator._check_inputs(features, labels) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
# The default return type of _get_train_ops is ModelFnOps. But there are
|
||||||
|
# some subclasses of tf.contrib.learn.Estimator which override this
|
||||||
|
# method and use the legacy signature, namely _get_train_ops returns a
|
||||||
|
# (train_op, loss) tuple. The following else-statement code covers these
|
||||||
|
# cases, but will soon be deleted after the subclasses are updated.
|
||||||
|
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
||||||
|
train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access
|
||||||
|
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||||
|
train_op = train_ops.train_op
|
||||||
|
loss_op = train_ops.loss
|
||||||
|
if estimator.config.is_chief:
|
||||||
|
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
|
||||||
|
else:
|
||||||
|
hooks = train_ops.training_hooks
|
||||||
|
else: # Legacy signature
|
||||||
|
if len(train_ops) != 2:
|
||||||
|
raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
|
||||||
|
train_ops))
|
||||||
|
train_op = train_ops[0]
|
||||||
|
loss_op = train_ops[1]
|
||||||
|
hooks = []
|
||||||
|
|
||||||
|
hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
|
||||||
|
|
||||||
|
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
|
||||||
|
return graph_actions._monitored_train( # pylint: disable=protected-access
|
||||||
|
graph=g,
|
||||||
|
output_dir=estimator.model_dir,
|
||||||
|
train_op=train_op,
|
||||||
|
loss_op=loss_op,
|
||||||
|
global_step_tensor=global_step,
|
||||||
|
init_op=init_op,
|
||||||
|
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
|
||||||
|
init_fn=init_fn,
|
||||||
|
log_every_steps=log_every_steps,
|
||||||
|
supervisor_is_chief=estimator.config.is_chief,
|
||||||
|
supervisor_master=estimator.config.master,
|
||||||
|
supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
|
||||||
|
supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
|
||||||
|
supervisor_save_summaries_steps=estimator.config.save_summary_steps,
|
||||||
|
keep_checkpoint_max=estimator.config.keep_checkpoint_max,
|
||||||
|
keep_checkpoint_every_n_hours=(
|
||||||
|
estimator.config.keep_checkpoint_every_n_hours),
|
||||||
|
feed_fn=feed_fn,
|
||||||
|
steps=steps,
|
||||||
|
fail_on_nan_loss=fail_on_nan_loss,
|
||||||
|
hooks=hooks,
|
||||||
|
max_steps=max_steps)
|
||||||
|
@ -22,10 +22,8 @@ from __future__ import print_function
|
|||||||
import abc
|
import abc
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import itertools
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
@ -39,10 +37,8 @@ from tensorflow.contrib.framework import deprecated_args
|
|||||||
from tensorflow.contrib.framework import list_variables
|
from tensorflow.contrib.framework import list_variables
|
||||||
from tensorflow.contrib.framework import load_variable
|
from tensorflow.contrib.framework import load_variable
|
||||||
from tensorflow.contrib.framework.python.framework import experimental
|
from tensorflow.contrib.framework.python.framework import experimental
|
||||||
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
|
|
||||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||||
from tensorflow.contrib.learn.python.learn import evaluable
|
from tensorflow.contrib.learn.python.learn import evaluable
|
||||||
from tensorflow.contrib.learn.python.learn import graph_actions
|
|
||||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||||
from tensorflow.contrib.learn.python.learn import trainable
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
@ -58,7 +54,6 @@ from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
|
|||||||
from tensorflow.contrib.training.python.training import evaluation
|
from tensorflow.contrib.training.python.training import evaluation
|
||||||
from tensorflow.core.framework import summary_pb2
|
from tensorflow.core.framework import summary_pb2
|
||||||
from tensorflow.python.client import session as tf_session
|
from tensorflow.python.client import session as tf_session
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -92,6 +87,25 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = (
|
|||||||
' est = Estimator(...) -> est = SKCompat(Estimator(...))')
|
' est = Estimator(...) -> est = SKCompat(Estimator(...))')
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
|
||||||
|
"""Verifies validity of co-existance of input arguments."""
|
||||||
|
if input_fn is None:
|
||||||
|
if x is None:
|
||||||
|
raise ValueError('Either x or input_fn must be provided.')
|
||||||
|
|
||||||
|
if contrib_framework.is_tensor(x) or (y is not None and
|
||||||
|
contrib_framework.is_tensor(y)):
|
||||||
|
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
|
||||||
|
|
||||||
|
if feed_fn is not None:
|
||||||
|
raise ValueError('Can not provide both feed_fn and x or y.')
|
||||||
|
else:
|
||||||
|
if (x is not None) or (y is not None):
|
||||||
|
raise ValueError('Can not provide both input_fn and x or y.')
|
||||||
|
if batch_size is not None:
|
||||||
|
raise ValueError('Can not provide both input_fn and batch_size.')
|
||||||
|
|
||||||
|
|
||||||
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
|
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
|
||||||
"""Make inputs into input and feed functions.
|
"""Make inputs into input and feed functions.
|
||||||
|
|
||||||
@ -110,29 +124,17 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: Only one of `(x & y)` or `input_fn` must be provided.
|
ValueError: Only one of `(x & y)` or `input_fn` must be provided.
|
||||||
"""
|
"""
|
||||||
if input_fn is None:
|
_verify_input_args(x, y, input_fn, feed_fn, batch_size)
|
||||||
if x is None:
|
if input_fn is not None:
|
||||||
raise ValueError('Either x or input_fn must be provided.')
|
return input_fn, feed_fn
|
||||||
|
df = data_feeder.setup_train_data_feeder(
|
||||||
if contrib_framework.is_tensor(x) or (y is not None and
|
x,
|
||||||
contrib_framework.is_tensor(y)):
|
y,
|
||||||
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
|
n_classes=None,
|
||||||
|
batch_size=batch_size,
|
||||||
if feed_fn is not None:
|
shuffle=shuffle,
|
||||||
raise ValueError('Can not provide both feed_fn and x or y.')
|
epochs=epochs)
|
||||||
|
return df.input_builder, df.get_feed_dict_fn()
|
||||||
df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=shuffle,
|
|
||||||
epochs=epochs)
|
|
||||||
return df.input_builder, df.get_feed_dict_fn()
|
|
||||||
|
|
||||||
if (x is not None) or (y is not None):
|
|
||||||
raise ValueError('Can not provide both input_fn and x or y.')
|
|
||||||
if batch_size is not None:
|
|
||||||
raise ValueError('Can not provide both input_fn and batch_size.')
|
|
||||||
|
|
||||||
return input_fn, feed_fn
|
|
||||||
|
|
||||||
|
|
||||||
def infer_real_valued_columns_from_input_fn(input_fn):
|
def infer_real_valued_columns_from_input_fn(input_fn):
|
||||||
@ -311,9 +313,8 @@ def _write_dict_to_summary(output_dir,
|
|||||||
dictionary: the `dict` to be written to summary file.
|
dictionary: the `dict` to be written to summary file.
|
||||||
current_global_step: `int`, the current global step.
|
current_global_step: `int`, the current global step.
|
||||||
"""
|
"""
|
||||||
logging.info(
|
logging.info('Saving dict for global step %d: %s', current_global_step,
|
||||||
'Saving dict for global step %d: %s' %
|
_dict_to_str(dictionary))
|
||||||
(current_global_step, _dict_to_str(dictionary)))
|
|
||||||
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
|
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
|
||||||
summary_proto = summary_pb2.Summary()
|
summary_proto = summary_pb2.Summary()
|
||||||
for key in dictionary:
|
for key in dictionary:
|
||||||
@ -404,15 +405,24 @@ class BaseEstimator(
|
|||||||
"""
|
"""
|
||||||
if (steps is not None) and (max_steps is not None):
|
if (steps is not None) and (max_steps is not None):
|
||||||
raise ValueError('Can not provide both steps and max_steps.')
|
raise ValueError('Can not provide both steps and max_steps.')
|
||||||
|
_verify_input_args(x, y, input_fn, None, batch_size)
|
||||||
|
if x is not None:
|
||||||
|
return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
|
||||||
|
|
||||||
input_fn, feed_fn = _get_input_fn(x, y, input_fn, feed_fn=None,
|
if max_steps is not None:
|
||||||
batch_size=batch_size, shuffle=True,
|
try:
|
||||||
epochs=None)
|
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
|
||||||
loss = self._train_model(input_fn=input_fn,
|
if max_steps <= start_step:
|
||||||
feed_fn=feed_fn,
|
logging.info('Skipping training since max_steps has already saved.')
|
||||||
steps=steps,
|
return None
|
||||||
monitors=monitors,
|
except: # pylint: disable=bare-except
|
||||||
max_steps=max_steps)
|
pass
|
||||||
|
|
||||||
|
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
|
||||||
|
if steps is not None or max_steps is not None:
|
||||||
|
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
|
||||||
|
|
||||||
|
loss = self._train_model(input_fn=input_fn, hooks=hooks)
|
||||||
logging.info('Loss for final step: %s.', loss)
|
logging.info('Loss for final step: %s.', loss)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -485,9 +495,10 @@ class BaseEstimator(
|
|||||||
`input_fn` or `feed_fn` is provided.
|
`input_fn` or `feed_fn` is provided.
|
||||||
Or if `metrics` is not `None` or `dict`.
|
Or if `metrics` is not `None` or `dict`.
|
||||||
"""
|
"""
|
||||||
input_fn, feed_fn = _get_input_fn(x, y, input_fn=input_fn,
|
_verify_input_args(x, y, input_fn, feed_fn, batch_size)
|
||||||
feed_fn=feed_fn, batch_size=batch_size,
|
if x is not None:
|
||||||
shuffle=False, epochs=1)
|
return SKCompat(self).score(x, y, batch_size, steps, metrics)
|
||||||
|
|
||||||
if metrics is not None and not isinstance(metrics, dict):
|
if metrics is not None and not isinstance(metrics, dict):
|
||||||
raise ValueError('Metrics argument should be None or dict. '
|
raise ValueError('Metrics argument should be None or dict. '
|
||||||
'Got %s.' % metrics)
|
'Got %s.' % metrics)
|
||||||
@ -537,11 +548,15 @@ class BaseEstimator(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If x and input_fn are both provided or both `None`.
|
ValueError: If x and input_fn are both provided or both `None`.
|
||||||
"""
|
"""
|
||||||
input_fn, feed_fn = _get_input_fn(
|
_verify_input_args(x, None, input_fn, None, batch_size)
|
||||||
x, None, input_fn=input_fn, feed_fn=None, batch_size=batch_size,
|
if x is not None and not as_iterable:
|
||||||
shuffle=False, epochs=1)
|
return SKCompat(self).predict(x, batch_size)
|
||||||
|
|
||||||
|
input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size)
|
||||||
return self._infer_model(
|
return self._infer_model(
|
||||||
input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
|
input_fn=input_fn,
|
||||||
|
feed_fn=feed_fn,
|
||||||
|
outputs=outputs,
|
||||||
as_iterable=as_iterable)
|
as_iterable=as_iterable)
|
||||||
|
|
||||||
def get_variable_value(self, name):
|
def get_variable_value(self, name):
|
||||||
@ -728,91 +743,6 @@ class BaseEstimator(
|
|||||||
self._labels_info = tensor_signature.create_signatures(labels)
|
self._labels_info = tensor_signature.create_signatures(labels)
|
||||||
logging.debug('Setting labels info to %s', str(self._labels_info))
|
logging.debug('Setting labels info to %s', str(self._labels_info))
|
||||||
|
|
||||||
def _train_model(self,
|
|
||||||
input_fn,
|
|
||||||
steps,
|
|
||||||
feed_fn=None,
|
|
||||||
init_op=None,
|
|
||||||
init_feed_fn=None,
|
|
||||||
init_fn=None,
|
|
||||||
device_fn=None,
|
|
||||||
monitors=None,
|
|
||||||
log_every_steps=100,
|
|
||||||
fail_on_nan_loss=True,
|
|
||||||
max_steps=None):
|
|
||||||
# TODO(wicke): Remove this once Model and associated code are gone.
|
|
||||||
if hasattr(self._config, 'execution_mode'):
|
|
||||||
if self._config.execution_mode not in ('all', 'train'):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Stagger startup of worker sessions based on task id.
|
|
||||||
sleep_secs = min(
|
|
||||||
self._config.training_worker_max_startup_secs,
|
|
||||||
self._config.task_id *
|
|
||||||
self._config.training_worker_session_startup_stagger_secs)
|
|
||||||
if sleep_secs:
|
|
||||||
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
|
|
||||||
self._config.task_id)
|
|
||||||
time.sleep(sleep_secs)
|
|
||||||
|
|
||||||
# Device allocation
|
|
||||||
device_fn = device_fn or self._device_fn
|
|
||||||
|
|
||||||
self._graph = ops.Graph()
|
|
||||||
with self._graph.as_default() as g, g.device(device_fn):
|
|
||||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
|
||||||
global_step = contrib_framework.create_global_step(g)
|
|
||||||
features, labels = input_fn()
|
|
||||||
self._check_inputs(features, labels)
|
|
||||||
|
|
||||||
# The default return type of _get_train_ops is ModelFnOps. But there are
|
|
||||||
# some subclasses of tf.contrib.learn.Estimator which override this
|
|
||||||
# method and use the legacy signature, namely _get_train_ops returns a
|
|
||||||
# (train_op, loss) tuple. The following else-statement code covers these
|
|
||||||
# cases, but will soon be deleted after the subclasses are updated.
|
|
||||||
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
|
||||||
train_ops = self._get_train_ops(features, labels)
|
|
||||||
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
|
||||||
train_op = train_ops.train_op
|
|
||||||
loss_op = train_ops.loss
|
|
||||||
if self.config.is_chief:
|
|
||||||
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
|
|
||||||
else:
|
|
||||||
hooks = train_ops.training_hooks
|
|
||||||
else: # Legacy signature
|
|
||||||
if len(train_ops) != 2:
|
|
||||||
raise ValueError('Expected a tuple of train_op and loss, got {}'.
|
|
||||||
format(train_ops))
|
|
||||||
train_op = train_ops[0]
|
|
||||||
loss_op = train_ops[1]
|
|
||||||
hooks = []
|
|
||||||
|
|
||||||
hooks += monitor_lib.replace_monitors_with_hooks(monitors, self)
|
|
||||||
|
|
||||||
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
|
|
||||||
return graph_actions._monitored_train( # pylint: disable=protected-access
|
|
||||||
graph=g,
|
|
||||||
output_dir=self._model_dir,
|
|
||||||
train_op=train_op,
|
|
||||||
loss_op=loss_op,
|
|
||||||
global_step_tensor=global_step,
|
|
||||||
init_op=init_op,
|
|
||||||
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
|
|
||||||
init_fn=init_fn,
|
|
||||||
log_every_steps=log_every_steps,
|
|
||||||
supervisor_is_chief=self.config.is_chief,
|
|
||||||
supervisor_master=self._config.master,
|
|
||||||
supervisor_save_model_secs=self._config.save_checkpoints_secs,
|
|
||||||
supervisor_save_model_steps=self._config.save_checkpoints_steps,
|
|
||||||
supervisor_save_summaries_steps=self._config.save_summary_steps,
|
|
||||||
keep_checkpoint_max=self._config.keep_checkpoint_max,
|
|
||||||
keep_checkpoint_every_n_hours=self._config.keep_checkpoint_every_n_hours,
|
|
||||||
feed_fn=feed_fn,
|
|
||||||
steps=steps,
|
|
||||||
fail_on_nan_loss=fail_on_nan_loss,
|
|
||||||
hooks=hooks,
|
|
||||||
max_steps=max_steps)
|
|
||||||
|
|
||||||
def _extract_metric_update_ops(self, eval_dict):
|
def _extract_metric_update_ops(self, eval_dict):
|
||||||
"""Separate update operations from metric value operations."""
|
"""Separate update operations from metric value operations."""
|
||||||
update_ops = []
|
update_ops = []
|
||||||
@ -915,8 +845,12 @@ class BaseEstimator(
|
|||||||
return result[0]
|
return result[0]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _infer_model(
|
def _infer_model(self,
|
||||||
self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
|
input_fn,
|
||||||
|
feed_fn=None,
|
||||||
|
outputs=None,
|
||||||
|
as_iterable=True,
|
||||||
|
iterate_batches=False):
|
||||||
# Check that model has been trained.
|
# Check that model has been trained.
|
||||||
checkpoint_path = saver.latest_checkpoint(self._model_dir)
|
checkpoint_path = saver.latest_checkpoint(self._model_dir)
|
||||||
if not checkpoint_path:
|
if not checkpoint_path:
|
||||||
@ -927,103 +861,152 @@ class BaseEstimator(
|
|||||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||||
contrib_framework.create_global_step(g)
|
contrib_framework.create_global_step(g)
|
||||||
features = self._get_features_from_input_fn(input_fn)
|
features = self._get_features_from_input_fn(input_fn)
|
||||||
|
infer_ops = self._call_legacy_get_predict_ops(features)
|
||||||
# The default return type of _get_predict_ops is ModelFnOps. But there are
|
predictions = self._filter_predictions(infer_ops.predictions, outputs)
|
||||||
# some subclasses of tf.contrib.learn.Estimator which override this
|
mon_sess = monitored_session.MonitoredSession(
|
||||||
# method and use the legacy signature, namely _get_predict_ops returns a
|
session_creator=monitored_session.ChiefSessionCreator(
|
||||||
# `predictions` Tensor or dict or Tensors. The following else-statement
|
checkpoint_filename_with_path=checkpoint_path))
|
||||||
# code covers these cases, but will soon be deleted after the subclasses
|
if not as_iterable:
|
||||||
# are updated.
|
with mon_sess:
|
||||||
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
if not mon_sess.should_stop():
|
||||||
infer_ops = self._get_predict_ops(features)
|
return mon_sess.run(predictions, feed_fn() if feed_fn else None)
|
||||||
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
|
|
||||||
predictions = infer_ops.predictions
|
|
||||||
else: # Legacy signature
|
|
||||||
predictions = infer_ops
|
|
||||||
|
|
||||||
# If predictions is single output - wrap it into dict, and remember to
|
|
||||||
# return not a dict.
|
|
||||||
return_dict = isinstance(predictions, dict)
|
|
||||||
if not return_dict:
|
|
||||||
predictions = {'predictions': predictions}
|
|
||||||
|
|
||||||
# Filter what to run predictions on, if outputs provided.
|
|
||||||
if outputs:
|
|
||||||
existing_keys = predictions.keys()
|
|
||||||
predictions = {
|
|
||||||
key: value
|
|
||||||
for key, value in six.iteritems(predictions) if key in outputs
|
|
||||||
}
|
|
||||||
if not predictions:
|
|
||||||
raise ValueError('Expected to run at least one output from %s, '
|
|
||||||
'provided %s.' % (existing_keys, outputs))
|
|
||||||
|
|
||||||
if as_iterable:
|
|
||||||
return self._infer_model_as_iterable(
|
|
||||||
checkpoint_path, predictions, feed_fn, return_dict)
|
|
||||||
else:
|
else:
|
||||||
return self._infer_model_single(
|
return self._predict_generator(mon_sess, predictions, feed_fn,
|
||||||
checkpoint_path, predictions, feed_fn, return_dict)
|
iterate_batches)
|
||||||
|
|
||||||
def _infer_model_single(
|
def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches):
|
||||||
self, checkpoint_path, predictions, feed_fn, return_dict):
|
with mon_sess:
|
||||||
if feed_fn is None:
|
while not mon_sess.should_stop():
|
||||||
preds = graph_actions.infer(checkpoint_path, predictions)
|
preds = mon_sess.run(predictions, feed_fn() if feed_fn else None)
|
||||||
else:
|
if iterate_batches:
|
||||||
def _feed_fn():
|
yield preds
|
||||||
while True:
|
elif not isinstance(predictions, dict):
|
||||||
yield feed_fn()
|
for pred in preds:
|
||||||
|
yield pred
|
||||||
outputs = graph_actions.run_feeds(
|
else:
|
||||||
output_dict=predictions,
|
first_tensor = list(preds.values())[0]
|
||||||
feed_dicts=_feed_fn(),
|
|
||||||
restore_checkpoint_path=checkpoint_path)
|
|
||||||
preds = {
|
|
||||||
key: np.concatenate([output[key] for output in outputs], axis=0)
|
|
||||||
for key in predictions}
|
|
||||||
|
|
||||||
return preds if return_dict else preds['predictions']
|
|
||||||
|
|
||||||
def _infer_model_as_iterable(
|
|
||||||
self, checkpoint_path, predictions, feed_fn, return_dict):
|
|
||||||
if feed_fn is None:
|
|
||||||
# If there are no queue_runners, the input `predictions` is a
|
|
||||||
# constant, and we should stop after the first epoch. If,
|
|
||||||
# instead, there are queue_runners, eventually they should throw
|
|
||||||
# an `OutOfRangeError`.
|
|
||||||
graph = contrib_ops.get_graph_from_inputs(predictions.values())
|
|
||||||
if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
|
||||||
feed_dicts = itertools.repeat(None)
|
|
||||||
else:
|
|
||||||
feed_dicts = [None]
|
|
||||||
else:
|
|
||||||
def _feed_fn():
|
|
||||||
while True:
|
|
||||||
yield feed_fn()
|
|
||||||
feed_dicts = _feed_fn()
|
|
||||||
|
|
||||||
try:
|
|
||||||
for output_batch in graph_actions.run_feeds_iter(
|
|
||||||
output_dict=predictions,
|
|
||||||
feed_dicts=feed_dicts,
|
|
||||||
restore_checkpoint_path=checkpoint_path):
|
|
||||||
# Unpack batches into individual predictions
|
|
||||||
if return_dict:
|
|
||||||
first_tensor = list(output_batch.values())[0]
|
|
||||||
if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
|
if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
|
||||||
batch_length = first_tensor.dense_shape[0]
|
batch_length = first_tensor.dense_shape[0]
|
||||||
else:
|
else:
|
||||||
batch_length = first_tensor.shape[0]
|
batch_length = first_tensor.shape[0]
|
||||||
for i in range(batch_length):
|
for i in range(batch_length):
|
||||||
yield {key: value[i] for key, value in six.iteritems(output_batch)}
|
yield {key: value[i] for key, value in six.iteritems(preds)}
|
||||||
else:
|
if self._is_input_constant(feed_fn, mon_sess.graph):
|
||||||
for pred in output_batch['predictions']:
|
return
|
||||||
yield pred
|
|
||||||
|
|
||||||
except errors.OutOfRangeError:
|
def _is_input_constant(self, feed_fn, graph):
|
||||||
# We fall out of the above loop naturally if feed_fn raises StopIteration,
|
# If there are no queue_runners, the input `predictions` is a
|
||||||
# or we catch an OutOfRangeError if we've reached the end of inputs.
|
# constant, and we should stop after the first epoch. If,
|
||||||
logging.info('Reached end of inputs for predict_iter.')
|
# instead, there are queue_runners, eventually they should throw
|
||||||
|
# an `OutOfRangeError`.
|
||||||
|
if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
||||||
|
return False
|
||||||
|
# data_feeder uses feed_fn to generate `OutOfRangeError`.
|
||||||
|
if feed_fn is not None:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _filter_predictions(self, predictions, outputs):
|
||||||
|
if not outputs:
|
||||||
|
return predictions
|
||||||
|
if not isinstance(predictions, dict):
|
||||||
|
raise ValueError(
|
||||||
|
'outputs argument is not valid in case of non-dict predictions.')
|
||||||
|
existing_keys = predictions.keys()
|
||||||
|
predictions = {
|
||||||
|
key: value
|
||||||
|
for key, value in six.iteritems(predictions) if key in outputs
|
||||||
|
}
|
||||||
|
if not predictions:
|
||||||
|
raise ValueError('Expected to run at least one output from %s, '
|
||||||
|
'provided %s.' % (existing_keys, outputs))
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
def _train_model(self, input_fn, hooks):
|
||||||
|
all_hooks = []
|
||||||
|
self._graph = ops.Graph()
|
||||||
|
with self._graph.as_default() as g, g.device(self._device_fn):
|
||||||
|
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||||
|
global_step = contrib_framework.create_global_step(g)
|
||||||
|
features, labels = input_fn()
|
||||||
|
self._check_inputs(features, labels)
|
||||||
|
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
|
||||||
|
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
|
||||||
|
all_hooks.extend([
|
||||||
|
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
|
||||||
|
basic_session_run_hooks.LoggingTensorHook(
|
||||||
|
{
|
||||||
|
'loss': model_fn_ops.loss,
|
||||||
|
'step': global_step
|
||||||
|
},
|
||||||
|
every_n_iter=100)
|
||||||
|
])
|
||||||
|
all_hooks.extend(hooks)
|
||||||
|
|
||||||
|
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
|
||||||
|
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
|
||||||
|
ops.add_to_collection(
|
||||||
|
ops.GraphKeys.SAVERS,
|
||||||
|
saver.Saver(
|
||||||
|
sharded=True,
|
||||||
|
max_to_keep=self._config.keep_checkpoint_max,
|
||||||
|
defer_build=True))
|
||||||
|
|
||||||
|
chief_hooks = []
|
||||||
|
if (self._config.save_checkpoints_secs or
|
||||||
|
self._config.save_checkpoints_steps):
|
||||||
|
saver_hook_exists = any([
|
||||||
|
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
|
||||||
|
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
|
||||||
|
model_fn_ops.training_chief_hooks)
|
||||||
|
])
|
||||||
|
if not saver_hook_exists:
|
||||||
|
chief_hooks = [
|
||||||
|
basic_session_run_hooks.CheckpointSaverHook(
|
||||||
|
self._model_dir,
|
||||||
|
save_secs=self._config.save_checkpoints_secs,
|
||||||
|
save_steps=self._config.save_checkpoints_steps,
|
||||||
|
scaffold=scaffold)
|
||||||
|
]
|
||||||
|
with monitored_session.MonitoredTrainingSession(
|
||||||
|
master=self._config.master,
|
||||||
|
is_chief=self._config.is_chief,
|
||||||
|
checkpoint_dir=self._model_dir,
|
||||||
|
scaffold=scaffold,
|
||||||
|
hooks=all_hooks + model_fn_ops.training_hooks,
|
||||||
|
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
|
||||||
|
save_checkpoint_secs=0, # Saving is handled by a hook.
|
||||||
|
save_summaries_steps=self._config.save_summary_steps,
|
||||||
|
config=None) as mon_sess:
|
||||||
|
loss = None
|
||||||
|
while not mon_sess.should_stop():
|
||||||
|
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
|
||||||
|
summary_io.SummaryWriterCache.clear()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _call_legacy_get_predict_ops(self, features):
|
||||||
|
# The default return type of _get_predict_ops is ModelFnOps. But there are
|
||||||
|
# some subclasses of tf.contrib.learn.Estimator which override this
|
||||||
|
# method and use the legacy signature, namely _get_predict_ops returns a
|
||||||
|
# `predictions` Tensor or dict or Tensors. The following else-statement
|
||||||
|
# code covers these cases, but will soon be deleted after the subclasses
|
||||||
|
# are updated.
|
||||||
|
# TODO(b/32664904): Update subclasses and delete the else-statement.
|
||||||
|
infer_ops = self._get_predict_ops(features)
|
||||||
|
if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||||
|
return infer_ops
|
||||||
|
return model_fn_lib.ModelFnOps(
|
||||||
|
mode=model_fn_lib.ModeKeys.INFER, predictions=infer_ops)
|
||||||
|
|
||||||
|
def _call_legacy_get_train_ops(self, features, labels):
|
||||||
|
train_ops = self._get_train_ops(features, labels)
|
||||||
|
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
||||||
|
return train_ops
|
||||||
|
return model_fn_lib.ModelFnOps(
|
||||||
|
mode=model_fn_lib.ModeKeys.TRAIN,
|
||||||
|
predictions=None,
|
||||||
|
loss=train_ops[1],
|
||||||
|
train_op=train_ops[0])
|
||||||
|
|
||||||
|
|
||||||
def _identity_feature_engineering_fn(features, labels):
|
def _identity_feature_engineering_fn(features, labels):
|
||||||
@ -1177,17 +1160,6 @@ class Estimator(BaseEstimator):
|
|||||||
"""
|
"""
|
||||||
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
|
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
|
||||||
|
|
||||||
# TODO(ispir): delete this function after converting all legacy usages.
|
|
||||||
def _call_legacy_get_train_ops(self, features, labels):
|
|
||||||
train_ops = self._get_train_ops(features, labels)
|
|
||||||
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
|
|
||||||
return train_ops
|
|
||||||
return model_fn_lib.ModelFnOps(
|
|
||||||
mode=model_fn_lib.ModeKeys.TRAIN,
|
|
||||||
predictions=None,
|
|
||||||
loss=train_ops[1],
|
|
||||||
train_op=train_ops[0])
|
|
||||||
|
|
||||||
def _get_eval_ops(self, features, labels, metrics):
|
def _get_eval_ops(self, features, labels, metrics):
|
||||||
"""Method that builds model graph and returns evaluation ops.
|
"""Method that builds model graph and returns evaluation ops.
|
||||||
|
|
||||||
@ -1343,114 +1315,6 @@ class Estimator(BaseEstimator):
|
|||||||
|
|
||||||
return export_dir
|
return export_dir
|
||||||
|
|
||||||
@deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y',
|
|
||||||
'batch_size')
|
|
||||||
def fit(self,
|
|
||||||
x=None,
|
|
||||||
y=None,
|
|
||||||
input_fn=None,
|
|
||||||
steps=None,
|
|
||||||
batch_size=None,
|
|
||||||
monitors=None,
|
|
||||||
max_steps=None):
|
|
||||||
# pylint: disable=g-doc-args,g-doc-return-or-yield
|
|
||||||
"""See `Trainable`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
|
|
||||||
ValueError: If both `steps` and `max_steps` are not `None`.
|
|
||||||
"""
|
|
||||||
if (steps is not None) and (max_steps is not None):
|
|
||||||
raise ValueError('Can not provide both steps and max_steps.')
|
|
||||||
if max_steps is not None:
|
|
||||||
try:
|
|
||||||
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
|
|
||||||
if max_steps <= start_step:
|
|
||||||
logging.info('Skipping training since max_steps has already saved.')
|
|
||||||
return None
|
|
||||||
except: # pylint: disable=bare-except
|
|
||||||
pass
|
|
||||||
|
|
||||||
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
|
|
||||||
if steps is not None or max_steps is not None:
|
|
||||||
hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
|
|
||||||
|
|
||||||
input_fn, feed_fn = _get_input_fn(
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
input_fn,
|
|
||||||
feed_fn=None,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
epochs=None)
|
|
||||||
if feed_fn:
|
|
||||||
hooks.append(_FeedFnHook(feed_fn))
|
|
||||||
loss = self._train_model_v2(input_fn=input_fn, hooks=hooks)
|
|
||||||
logging.info('Loss for final step: %s.', loss)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def _train_model_v2(self, input_fn, hooks):
|
|
||||||
all_hooks = []
|
|
||||||
self._graph = ops.Graph()
|
|
||||||
with self._graph.as_default() as g, g.device(self._device_fn):
|
|
||||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
|
||||||
global_step = contrib_framework.create_global_step(g)
|
|
||||||
features, labels = input_fn()
|
|
||||||
self._check_inputs(features, labels)
|
|
||||||
model_fn_ops = self._call_legacy_get_train_ops(features, labels)
|
|
||||||
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
|
|
||||||
all_hooks.extend([
|
|
||||||
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
|
|
||||||
basic_session_run_hooks.LoggingTensorHook(
|
|
||||||
{
|
|
||||||
'loss': model_fn_ops.loss,
|
|
||||||
'step': global_step
|
|
||||||
},
|
|
||||||
every_n_iter=100)
|
|
||||||
])
|
|
||||||
all_hooks.extend(hooks)
|
|
||||||
|
|
||||||
scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
|
|
||||||
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
|
|
||||||
ops.add_to_collection(
|
|
||||||
ops.GraphKeys.SAVERS,
|
|
||||||
saver.Saver(
|
|
||||||
sharded=True,
|
|
||||||
max_to_keep=self._config.keep_checkpoint_max,
|
|
||||||
defer_build=True))
|
|
||||||
|
|
||||||
chief_hooks = []
|
|
||||||
if (self._config.save_checkpoints_secs or
|
|
||||||
self._config.save_checkpoints_steps):
|
|
||||||
saver_hook_exists = any([
|
|
||||||
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
|
|
||||||
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
|
|
||||||
model_fn_ops.training_chief_hooks)
|
|
||||||
])
|
|
||||||
if not saver_hook_exists:
|
|
||||||
chief_hooks = [
|
|
||||||
basic_session_run_hooks.CheckpointSaverHook(
|
|
||||||
self._model_dir,
|
|
||||||
save_secs=self._config.save_checkpoints_secs,
|
|
||||||
save_steps=self._config.save_checkpoints_steps,
|
|
||||||
scaffold=scaffold)
|
|
||||||
]
|
|
||||||
with monitored_session.MonitoredTrainingSession(
|
|
||||||
master=self._config.master,
|
|
||||||
is_chief=self._config.is_chief,
|
|
||||||
checkpoint_dir=self._model_dir,
|
|
||||||
scaffold=scaffold,
|
|
||||||
hooks=all_hooks + model_fn_ops.training_hooks,
|
|
||||||
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
|
|
||||||
save_checkpoint_secs=0, # Saving is handled by a hook.
|
|
||||||
save_summaries_steps=self._config.save_summary_steps,
|
|
||||||
config=None) as mon_sess:
|
|
||||||
loss = None
|
|
||||||
while not mon_sess.should_stop():
|
|
||||||
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
|
|
||||||
summary_io.SummaryWriterCache.clear()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class _FeedFnHook(session_run_hook.SessionRunHook):
|
class _FeedFnHook(session_run_hook.SessionRunHook):
|
||||||
"""Runs feed_fn and sets the feed_dict accordingly."""
|
"""Runs feed_fn and sets the feed_dict accordingly."""
|
||||||
@ -1509,6 +1373,17 @@ class SKCompat(sklearn.BaseEstimator):
|
|||||||
input_fn, feed_fn = _get_input_fn(
|
input_fn, feed_fn = _get_input_fn(
|
||||||
x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
|
x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
|
||||||
shuffle=False, epochs=1)
|
shuffle=False, epochs=1)
|
||||||
return self._estimator._infer_model(
|
results = list(
|
||||||
input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
|
self._estimator._infer_model(
|
||||||
as_iterable=False)
|
input_fn=input_fn,
|
||||||
|
feed_fn=feed_fn,
|
||||||
|
outputs=outputs,
|
||||||
|
as_iterable=True,
|
||||||
|
iterate_batches=True))
|
||||||
|
if not isinstance(results[0], dict):
|
||||||
|
return np.concatenate([output for output in results], axis=0)
|
||||||
|
return {
|
||||||
|
key: np.concatenate(
|
||||||
|
[output[key] for output in results], axis=0)
|
||||||
|
for key in results[0]
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user