diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 44a69ed23a4..3f1dae332ff 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -78,9 +78,8 @@ def main(_): predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) return predict_features - predictions = estimator.predict(input_fn=predict_input_fn) - # TODO(anjalsridhar): This returns a generator object, figure out how to get - # meaningful results here. + prediction_iterable = estimator.predict(input_fn=predict_input_fn) + predictions = [prediction_iterable.next() for _ in range(10)] print("Prediction results: {}".format(predictions))