diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 92172bbf104..135afb04e5d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -97,9 +97,6 @@ class BaseEstimator(sklearn.BaseEstimator):
* _get_predict_ops
`Estimator` implemented below is a good example of how to use this class.
-
- Parameters:
- model_dir: Directory to save model parameters, graph and etc.
"""
__metaclass__ = abc.ABCMeta
@@ -107,6 +104,12 @@ class BaseEstimator(sklearn.BaseEstimator):
_Config = run_config.RunConfig # pylint: disable=invalid-name
def __init__(self, model_dir=None, config=None):
+ """Initializes a BaseEstimator instance.
+
+ Args:
+ model_dir: Directory to save model parameters, graph and etc.
+ config: A RunConfig instance.
+ """
# Model directory.
self._model_dir = model_dir
if self._model_dir is None:
@@ -621,29 +624,6 @@ class BaseEstimator(sklearn.BaseEstimator):
class Estimator(BaseEstimator):
"""Estimator class is the basic TensorFlow model trainer/evaluator.
-
- Parameters:
- model_fn: Model function, takes features and targets tensors or dicts of
- tensors and returns predictions and loss tensors.
- Supports next three signatures for the function:
- * `(features, targets) -> (predictions, loss, train_op)`
- * `(features, targets, mode) -> (predictions, loss, train_op)`
- * `(features, targets, mode, params) ->
- (predictions, loss, train_op)`
- Where:
- * `features` are single `Tensor` or `dict` of `Tensor`s
- (depending on data passed to `fit`),
- * `targets` are `Tensor` or
- `dict` of `Tensor`s (for multi-head model).
- * `mode` represents if this training, evaluation or prediction.
- See `ModeKeys` for example keys.
- * `params` is a `dict` of hyperparameters. Will receive what is
- passed to Estimator in `params` parameter. This allows to
- configure Estimators from hyper parameter tunning.
- model_dir: Directory to save model parameters, graph and etc.
- config: Configuration object.
- params: `dict` of hyper parameters that will be passed into `model_fn`.
- Keys are names of parameters, values are basic python types.
"""
def __init__(self,
@@ -651,6 +631,34 @@ class Estimator(BaseEstimator):
model_dir=None,
config=None,
params=None):
+ """Constructs an Estimator instance.
+
+ Args:
+ model_fn: Model function, takes features and targets tensors or dicts of
+ tensors and returns predictions and loss tensors.
+ Supports next three signatures for the function:
+ * `(features, targets) -> (predictions, loss, train_op)`
+ * `(features, targets, mode) -> (predictions, loss, train_op)`
+ * `(features, targets, mode, params) ->
+ (predictions, loss, train_op)`
+ Where:
+ * `features` are single `Tensor` or `dict` of `Tensor`s
+ (depending on data passed to `fit`),
+ * `targets` are `Tensor` or
+ `dict` of `Tensor`s (for multi-head model).
+ * `mode` represents if this training, evaluation or
+ prediction. See `ModeKeys` for example keys.
+ * `params` is a `dict` of hyperparameters. Will receive what
+ is passed to Estimator in `params` parameter. This allows
+ to configure Estimators from hyper parameter tunning.
+ model_dir: Directory to save model parameters, graph and etc.
+ config: Configuration object.
+ params: `dict` of hyper parameters that will be passed into `model_fn`.
+ Keys are names of parameters, values are basic python types.
+
+ Raises:
+ ValueError: parameters of `model_fn` don't match `params`.
+ """
super(Estimator, self).__init__(model_dir=model_dir, config=config)
if model_fn is not None:
# Check number of arguments of the given function matches requirements.
diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md
index 258cf38b9f2..f5fe4dfb7c8 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.learn.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md
@@ -21,14 +21,16 @@ Concrete implementation of this class should provide following functions:
* _get_predict_ops
`Estimator` implemented below is a good example of how to use this class.
-
-Parameters:
- model_dir: Directory to save model parameters, graph and etc.
- - -
#### `tf.contrib.learn.BaseEstimator.__init__(model_dir=None, config=None)` {#BaseEstimator.__init__}
+Initializes a BaseEstimator instance.
+##### Parameters:
+
+
+* `model_dir`: Directory to save model parameters, graph and etc.
- - -
@@ -270,16 +272,23 @@ component of a nested object.
### `class tf.contrib.learn.Estimator` {#Estimator}
Estimator class is the basic TensorFlow model trainer/evaluator.
+- - -
-Parameters:
- model_fn: Model function, takes features and targets tensors or dicts of
+#### `tf.contrib.learn.Estimator.__init__(model_fn=None, model_dir=None, config=None, params=None)` {#Estimator.__init__}
+
+Constructs an Estimator instance.
+
+##### Args:
+
+
+* `model_fn`: Model function, takes features and targets tensors or dicts of
tensors and returns predictions and loss tensors.
Supports next three signatures for the function:
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) ->
(predictions, loss, train_op)`
- Where:
+* `Where`:
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
@@ -289,15 +298,10 @@ Parameters:
* `params` is a `dict` of hyperparameters. Will receive what is
passed to Estimator in `params` parameter. This allows to
configure Estimators from hyper parameter tunning.
- model_dir: Directory to save model parameters, graph and etc.
- config: Configuration object.
- params: `dict` of hyper parameters that will be passed into `model_fn`.
+* `model_dir`: Directory to save model parameters, graph and etc.
+* `config`: Configuration object.
+* `params`: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
-- - -
-
-#### `tf.contrib.learn.Estimator.__init__(model_fn=None, model_dir=None, config=None, params=None)` {#Estimator.__init__}
-
-
- - -