Print output of predictions in simple estimator example for DistributionStrategy
Resolved the TODO that asked to print meaningful results from the Estimator's predictions. Collects all elements yielded by the generator in a list and prints that list instead of the generator object. Results in the following output: `Prediction results: [{'logits': array([1.0162734], dtype=float32)}, ..., {'logits': array([1.0162734], dtype=float32)}]`.
This commit is contained in:
parent
a6d62ffea3
commit
93e2a37a5f
@ -78,9 +78,8 @@ def main(_):
|
||||
predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
|
||||
return predict_features
|
||||
|
||||
predictions = estimator.predict(input_fn=predict_input_fn)
|
||||
# TODO(anjalsridhar): This returns a generator object, figure out how to get
|
||||
# meaningful results here.
|
||||
prediction_iterable = estimator.predict(input_fn=predict_input_fn)
|
||||
predictions = [prediction_iterable.next() for _ in range(10)]
|
||||
print("Prediction results: {}".format(predictions))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user