diff --git a/tensorflow/docs_src/extend/estimators.md b/tensorflow/docs_src/extend/estimators.md index a18da1708d4..5265e5889be 100644 --- a/tensorflow/docs_src/extend/estimators.md +++ b/tensorflow/docs_src/extend/estimators.md @@ -578,7 +578,12 @@ def model_fn(features, labels, mode, params): # Reshape output layer to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) - predictions_dict = {"ages": predictions} + + # Provide an estimator spec for `ModeKeys.PREDICT`. + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec( + mode=mode, + predictions={"ages": predictions}) # Calculate loss using mean squared error loss = tf.losses.mean_squared_error(labels, predictions) @@ -594,9 +599,9 @@ def model_fn(features, labels, mode, params): train_op = optimizer.minimize( loss=loss, global_step=tf.train.get_global_step()) + # Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes. return tf.estimator.EstimatorSpec( mode=mode, - predictions=predictions_dict, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops) diff --git a/tensorflow/examples/tutorials/estimators/abalone.py b/tensorflow/examples/tutorials/estimators/abalone.py index 4765d5dabf4..737b3ee5d6a 100644 --- a/tensorflow/examples/tutorials/estimators/abalone.py +++ b/tensorflow/examples/tutorials/estimators/abalone.py @@ -87,25 +87,30 @@ def model_fn(features, labels, mode, params): # Reshape output layer to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) - predictions_dict = {"ages": predictions} + + # Provide an estimator spec for `ModeKeys.PREDICT`. + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec( + mode=mode, + predictions={"ages": predictions}) # Calculate loss using mean squared error loss = tf.losses.mean_squared_error(labels, predictions) + optimizer = tf.train.GradientDescentOptimizer( + learning_rate=params["learning_rate"]) + train_op = optimizer.minimize( + loss=loss, global_step=tf.train.get_global_step()) + # Calculate root mean squared error as additional eval metric eval_metric_ops = { "rmse": tf.metrics.root_mean_squared_error( tf.cast(labels, tf.float64), predictions) } - optimizer = tf.train.GradientDescentOptimizer( - learning_rate=params["learning_rate"]) - train_op = optimizer.minimize( - loss=loss, global_step=tf.train.get_global_step()) - + # Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes. return tf.estimator.EstimatorSpec( mode=mode, - predictions=predictions_dict, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)