Replace tf.estimator.inputs
with tf.compat.v1.estimator.inputs
PiperOrigin-RevId: 223118522
This commit is contained in:
parent
567c0692de
commit
95e808ba44
@ -76,12 +76,12 @@ def main(unused_argv):
|
||||
classifier = tf.estimator.Estimator(model_fn=my_model)
|
||||
|
||||
# Train.
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
|
||||
classifier.train(input_fn=train_input_fn, steps=1000)
|
||||
|
||||
# Predict.
|
||||
test_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||
predictions = classifier.predict(input_fn=test_input_fn)
|
||||
y_predicted = np.array(list(p['class'] for p in predictions))
|
||||
|
@ -73,12 +73,12 @@ def main(unused_argv):
|
||||
classifier = tf.estimator.Estimator(model_fn=my_model)
|
||||
|
||||
# Train.
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
|
||||
classifier.train(input_fn=train_input_fn, steps=1000)
|
||||
|
||||
# Predict.
|
||||
test_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||
predictions = classifier.predict(input_fn=test_input_fn)
|
||||
y_predicted = np.array(list(p['class'] for p in predictions))
|
||||
|
@ -134,7 +134,7 @@ def main(unused_argv):
|
||||
tensors=tensors_to_log, every_n_iter=50)
|
||||
|
||||
# Train the model
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={"x": train_data},
|
||||
y=train_labels,
|
||||
batch_size=100,
|
||||
@ -146,11 +146,8 @@ def main(unused_argv):
|
||||
hooks=[logging_hook])
|
||||
|
||||
# Evaluate the model and print results
|
||||
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
x={"x": eval_data},
|
||||
y=eval_labels,
|
||||
num_epochs=1,
|
||||
shuffle=False)
|
||||
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={"x": eval_data}, y=eval_labels, num_epochs=1, shuffle=False)
|
||||
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
|
||||
print(eval_results)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user