Fix metric_spec_test.
Change: 132092514
This commit is contained in:
parent
aa39d8264c
commit
f366a3bcfd
tensorflow/contrib/learn
@ -299,6 +299,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metric_spec_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/metric_spec_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "experiment_test",
|
||||
size = "small",
|
||||
|
@ -66,8 +66,9 @@ class MetricSpecTest(tf.test.TestCase):
|
||||
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
"MetricSpec without specified prediction requires "
|
||||
"predictions tensor or single element dict, got",
|
||||
"MetricSpec without specified prediction_key "
|
||||
"requires predictions tensor or single element "
|
||||
"dict, got",
|
||||
MetricSpec(metric_fn=test_metric,
|
||||
label_key="label1",
|
||||
weight_key="feature2").create_metric_ops,
|
||||
@ -79,7 +80,7 @@ class MetricSpecTest(tf.test.TestCase):
|
||||
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
"MetricSpec without specified label requires "
|
||||
"MetricSpec without specified label_key requires "
|
||||
"labels tensor or single element dict, got",
|
||||
MetricSpec(metric_fn=test_metric,
|
||||
prediction_key="pred1",
|
||||
@ -122,7 +123,7 @@ class MetricSpecTest(tf.test.TestCase):
|
||||
predictions = "pred1_tensor"
|
||||
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
"MetricSpec with specified prediction requires "
|
||||
"MetricSpec with prediction_key specified requires "
|
||||
"predictions dict, got",
|
||||
MetricSpec(metric_fn=test_metric,
|
||||
prediction_key="pred1",
|
||||
@ -136,7 +137,7 @@ class MetricSpecTest(tf.test.TestCase):
|
||||
predictions = {"pred1": "pred1_tensor", "pred2": "pred2_tensor"}
|
||||
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
"MetricSpec with specified label requires "
|
||||
"MetricSpec with label_key specified requires "
|
||||
"labels dict, got",
|
||||
MetricSpec(metric_fn=test_metric,
|
||||
prediction_key="pred1",
|
||||
|
Loading…
Reference in New Issue
Block a user