From 93e2a37a5f656a71045de705d8f1cad545f295ab Mon Sep 17 00:00:00 2001 From: Joppe Geluykens Date: Sat, 18 Aug 2018 13:46:56 +0200 Subject: [PATCH 1/2] Print output of predictions in simple estimator example for DistributionStrategy Resolved the TODO that asked to print meaningful results from the Estimator's predictions. Collects all elements yielded by the generator in a list and prints that list instead of the generator object. Results in the following output: `Prediction results: [{'logits': array([1.0162734], dtype=float32)}, ..., {'logits': array([1.0162734], dtype=float32)}]`. --- .../distribute/python/examples/simple_estimator_example.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 44a69ed23a4..3f1dae332ff 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -78,9 +78,8 @@ def main(_): predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) return predict_features - predictions = estimator.predict(input_fn=predict_input_fn) - # TODO(anjalsridhar): This returns a generator object, figure out how to get - # meaningful results here. + prediction_iterable = estimator.predict(input_fn=predict_input_fn) + predictions = [prediction_iterable.next() for _ in range(10)] print("Prediction results: {}".format(predictions)) From eba4b9ff1eb8b88b38a32689155eff5c267cf6e9 Mon Sep 17 00:00:00 2001 From: Joppe Geluykens Date: Thu, 13 Sep 2018 00:21:39 +0200 Subject: [PATCH 2/2] Add comment clarifying items in `predictions` list This way, it becomes clear what the result of the `print()` call will look like. --- .../distribute/python/examples/simple_estimator_example.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 3f1dae332ff..e48d09a0a95 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -79,6 +79,8 @@ def main(_): return predict_features prediction_iterable = estimator.predict(input_fn=predict_input_fn) + # Create a list containing each of the prediction dictionaries that map + # the key 'logits' to an array of model outputs. predictions = [prediction_iterable.next() for _ in range(10)] print("Prediction results: {}".format(predictions))