Print output of predictions in simple estimator example for DistributionStrategy

Resolved the TODO that asked to print meaningful results from the Estimator's predictions.
Collects all elements yielded by the generator in a list and prints that list instead of the generator object.
Results in the following output: `Prediction results: [{'logits': array([1.0162734], dtype=float32)}, ..., {'logits': array([1.0162734], dtype=float32)}]`.
This commit is contained in:
Joppe Geluykens 2018-08-18 13:46:56 +02:00 committed by GitHub
parent a6d62ffea3
commit 93e2a37a5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -78,9 +78,8 @@ def main(_):
predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
return predict_features
predictions = estimator.predict(input_fn=predict_input_fn)
# TODO(anjalsridhar): This returns a generator object, figure out how to get
# meaningful results here.
prediction_iterable = estimator.predict(input_fn=predict_input_fn)
predictions = [prediction_iterable.next() for _ in range(10)]
print("Prediction results: {}".format(predictions))