From d0042ed637c5808d5de161353baf5209e46a5985 Mon Sep 17 00:00:00 2001 From: wuhaixutab Date: Thu, 4 May 2017 12:23:13 +0800 Subject: [PATCH] Add input function for training and testing (#9617) (#9650) * Add input function for training and testing Estimator is decoupled from Scikit Learn interface by moving into separate class SKCompat. Arguments x, y and batch_size are only available in the SKCompat class, Estimator will only accept input_fn * remove extra comma --- .../examples/tutorials/estimators/abalone.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorflow/examples/tutorials/estimators/abalone.py b/tensorflow/examples/tutorials/estimators/abalone.py index 932ce8a8b25..3c0ea2e4090 100644 --- a/tensorflow/examples/tutorials/estimators/abalone.py +++ b/tensorflow/examples/tutorials/estimators/abalone.py @@ -134,12 +134,22 @@ def main(unused_argv): # Instantiate Estimator nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params) - + + def get_train_inputs(): + x = tf.constant(training_set.data) + y = tf.constant(training_set.target) + return x, y + # Fit - nn.fit(x=training_set.data, y=training_set.target, steps=5000) + nn.fit(input_fn=get_train_inputs, steps=5000) # Score accuracy - ev = nn.evaluate(x=test_set.data, y=test_set.target, steps=1) + def get_test_inputs(): + x = tf.constant(test_set.data) + y = tf.constant(test_set.target) + return x, y + + ev = nn.evaluate(input_fn=get_test_inputs, steps=1) print("Loss: %s" % ev["loss"]) print("Root Mean Squared Error: %s" % ev["rmse"])