tutorials/monitors/iris_monitor.py fixes (#8927)
* Iris_Monitors.py validation metric prediction key update to "classes" * Iris_monitors.py import removed code to import MetricSpec * iris_monitors.py repetitive code removed code should not have been duplicated
This commit is contained in:
parent
a336b06d29
commit
7ab36077ef
@ -21,7 +21,6 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
@ -41,18 +40,15 @@ def main(unused_argv):
|
||||
"accuracy":
|
||||
tf.contrib.learn.MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||
prediction_key=
|
||||
tf.contrib.learn.PredictionKey.CLASSES),
|
||||
prediction_key="classes"),
|
||||
"precision":
|
||||
tf.contrib.learn.MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||
prediction_key=
|
||||
tf.contrib.learn.PredictionKey.CLASSES),
|
||||
prediction_key="classes"),
|
||||
"recall":
|
||||
tf.contrib.learn.MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_recall,
|
||||
prediction_key=
|
||||
tf.contrib.learn.PredictionKey.CLASSES)
|
||||
prediction_key="classes")
|
||||
}
|
||||
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
|
||||
test_set.data,
|
||||
@ -66,26 +62,6 @@ def main(unused_argv):
|
||||
# Specify that all features have real-value data
|
||||
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
|
||||
|
||||
validation_metrics = {
|
||||
"accuracy": MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||
prediction_key="classes"),
|
||||
"recall": MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_recall,
|
||||
prediction_key="classes"),
|
||||
"precision": MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||
prediction_key="classes")
|
||||
}
|
||||
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
|
||||
test_set.data,
|
||||
test_set.target,
|
||||
every_n_steps=50,
|
||||
metrics=validation_metrics,
|
||||
early_stopping_metric="loss",
|
||||
early_stopping_metric_minimize=True,
|
||||
early_stopping_rounds=200)
|
||||
|
||||
# Build 3 layer DNN with 10, 20, 10 units respectively.
|
||||
classifier = tf.contrib.learn.DNNClassifier(
|
||||
feature_columns=feature_columns,
|
||||
|
Loading…
Reference in New Issue
Block a user