* Add input function for training and testing Estimator is decoupled from Scikit Learn interface by moving into separate class SKCompat. Arguments x, y and batch_size are only available in the SKCompat class, Estimator will only accept input_fn * remove extra comma
This commit is contained in:
parent
b0ab95c7af
commit
d0042ed637
@ -134,12 +134,22 @@ def main(unused_argv):
|
|||||||
|
|
||||||
# Instantiate Estimator
|
# Instantiate Estimator
|
||||||
nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params)
|
nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params)
|
||||||
|
|
||||||
|
def get_train_inputs():
|
||||||
|
x = tf.constant(training_set.data)
|
||||||
|
y = tf.constant(training_set.target)
|
||||||
|
return x, y
|
||||||
|
|
||||||
# Fit
|
# Fit
|
||||||
nn.fit(x=training_set.data, y=training_set.target, steps=5000)
|
nn.fit(input_fn=get_train_inputs, steps=5000)
|
||||||
|
|
||||||
# Score accuracy
|
# Score accuracy
|
||||||
ev = nn.evaluate(x=test_set.data, y=test_set.target, steps=1)
|
def get_test_inputs():
|
||||||
|
x = tf.constant(test_set.data)
|
||||||
|
y = tf.constant(test_set.target)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
ev = nn.evaluate(input_fn=get_test_inputs, steps=1)
|
||||||
print("Loss: %s" % ev["loss"])
|
print("Loss: %s" % ev["loss"])
|
||||||
print("Root Mean Squared Error: %s" % ev["rmse"])
|
print("Root Mean Squared Error: %s" % ev["rmse"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user