diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 92172bbf104..135afb04e5d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -97,9 +97,6 @@ class BaseEstimator(sklearn.BaseEstimator): * _get_predict_ops `Estimator` implemented below is a good example of how to use this class. - - Parameters: - model_dir: Directory to save model parameters, graph and etc. """ __metaclass__ = abc.ABCMeta @@ -107,6 +104,12 @@ class BaseEstimator(sklearn.BaseEstimator): _Config = run_config.RunConfig # pylint: disable=invalid-name def __init__(self, model_dir=None, config=None): + """Initializes a BaseEstimator instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. + config: A RunConfig instance. + """ # Model directory. self._model_dir = model_dir if self._model_dir is None: @@ -621,29 +624,6 @@ class BaseEstimator(sklearn.BaseEstimator): class Estimator(BaseEstimator): """Estimator class is the basic TensorFlow model trainer/evaluator. - - Parameters: - model_fn: Model function, takes features and targets tensors or dicts of - tensors and returns predictions and loss tensors. - Supports next three signatures for the function: - * `(features, targets) -> (predictions, loss, train_op)` - * `(features, targets, mode) -> (predictions, loss, train_op)` - * `(features, targets, mode, params) -> - (predictions, loss, train_op)` - Where: - * `features` are single `Tensor` or `dict` of `Tensor`s - (depending on data passed to `fit`), - * `targets` are `Tensor` or - `dict` of `Tensor`s (for multi-head model). - * `mode` represents if this training, evaluation or prediction. - See `ModeKeys` for example keys. - * `params` is a `dict` of hyperparameters. Will receive what is - passed to Estimator in `params` parameter. This allows to - configure Estimators from hyper parameter tunning. - model_dir: Directory to save model parameters, graph and etc. - config: Configuration object. - params: `dict` of hyper parameters that will be passed into `model_fn`. - Keys are names of parameters, values are basic python types. """ def __init__(self, @@ -651,6 +631,34 @@ class Estimator(BaseEstimator): model_dir=None, config=None, params=None): + """Constructs an Estimator instance. + + Args: + model_fn: Model function, takes features and targets tensors or dicts of + tensors and returns predictions and loss tensors. + Supports next three signatures for the function: + * `(features, targets) -> (predictions, loss, train_op)` + * `(features, targets, mode) -> (predictions, loss, train_op)` + * `(features, targets, mode, params) -> + (predictions, loss, train_op)` + Where: + * `features` are single `Tensor` or `dict` of `Tensor`s + (depending on data passed to `fit`), + * `targets` are `Tensor` or + `dict` of `Tensor`s (for multi-head model). + * `mode` represents if this training, evaluation or + prediction. See `ModeKeys` for example keys. + * `params` is a `dict` of hyperparameters. Will receive what + is passed to Estimator in `params` parameter. This allows + to configure Estimators from hyper parameter tunning. + model_dir: Directory to save model parameters, graph and etc. + config: Configuration object. + params: `dict` of hyper parameters that will be passed into `model_fn`. + Keys are names of parameters, values are basic python types. + + Raises: + ValueError: parameters of `model_fn` don't match `params`. + """ super(Estimator, self).__init__(model_dir=model_dir, config=config) if model_fn is not None: # Check number of arguments of the given function matches requirements. diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md index 258cf38b9f2..f5fe4dfb7c8 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.learn.md +++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md @@ -21,14 +21,16 @@ Concrete implementation of this class should provide following functions: * _get_predict_ops `Estimator` implemented below is a good example of how to use this class. - -Parameters: - model_dir: Directory to save model parameters, graph and etc. - - - #### `tf.contrib.learn.BaseEstimator.__init__(model_dir=None, config=None)` {#BaseEstimator.__init__} +Initializes a BaseEstimator instance. +##### Parameters: + + +* `model_dir`: Directory to save model parameters, graph and etc. - - - @@ -270,16 +272,23 @@ component of a nested object. ### `class tf.contrib.learn.Estimator` {#Estimator} Estimator class is the basic TensorFlow model trainer/evaluator. +- - - -Parameters: - model_fn: Model function, takes features and targets tensors or dicts of +#### `tf.contrib.learn.Estimator.__init__(model_fn=None, model_dir=None, config=None, params=None)` {#Estimator.__init__} + +Constructs an Estimator instance. + +##### Args: + + +* `model_fn`: Model function, takes features and targets tensors or dicts of tensors and returns predictions and loss tensors. Supports next three signatures for the function: * `(features, targets) -> (predictions, loss, train_op)` * `(features, targets, mode) -> (predictions, loss, train_op)` * `(features, targets, mode, params) -> (predictions, loss, train_op)` - Where: +* `Where`: * `features` are single `Tensor` or `dict` of `Tensor`s (depending on data passed to `fit`), * `targets` are `Tensor` or @@ -289,15 +298,10 @@ Parameters: * `params` is a `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows to configure Estimators from hyper parameter tunning. - model_dir: Directory to save model parameters, graph and etc. - config: Configuration object. - params: `dict` of hyper parameters that will be passed into `model_fn`. +* `model_dir`: Directory to save model parameters, graph and etc. +* `config`: Configuration object. +* `params`: `dict` of hyper parameters that will be passed into `model_fn`. Keys are names of parameters, values are basic python types. -- - - - -#### `tf.contrib.learn.Estimator.__init__(model_fn=None, model_dir=None, config=None, params=None)` {#Estimator.__init__} - - - - -