Apply run_deprecated_v1 to entire test class
Test classes having operations in the setUp function as well as individual tests annotated with run_deprecated_v1 need to have the whole class annotated with run_deprecated_v1 to ensure the setUp function as well as the test function is run in graph mode. PiperOrigin-RevId: 225964901
This commit is contained in:
parent
15fa7c49e2
commit
33bb4fe143
tensorflow/python
debug
framework
kernel_tests
training
@ -105,7 +105,7 @@ class TimeToReadableStrTest(test_util.TensorFlowTestCase):
|
|||||||
cli_shared.time_to_readable_str(100, force_time_unit="ks")
|
cli_shared.time_to_readable_str(100, force_time_unit="ks")
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -119,7 +119,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testSingleFetchNoFeeds(self):
|
def testSingleFetchNoFeeds(self):
|
||||||
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
|
||||||
|
|
||||||
@ -183,7 +182,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
||||||
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testTwoFetchesListNoFeeds(self):
|
def testTwoFetchesListNoFeeds(self):
|
||||||
fetches = [self.const_a, self.const_b]
|
fetches = [self.const_a, self.const_b]
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -200,7 +198,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testNestedListAsFetches(self):
|
def testNestedListAsFetches(self):
|
||||||
fetches = [self.const_c, [self.const_a, self.const_b]]
|
fetches = [self.const_c, [self.const_a, self.const_b]]
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -214,7 +211,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testNestedDictAsFetches(self):
|
def testNestedDictAsFetches(self):
|
||||||
fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
|
fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -232,7 +228,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testTwoFetchesAsTupleNoFeeds(self):
|
def testTwoFetchesAsTupleNoFeeds(self):
|
||||||
fetches = (self.const_a, self.const_b)
|
fetches = (self.const_a, self.const_b)
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -249,7 +244,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testTwoFetchesAsNamedTupleNoFeeds(self):
|
def testTwoFetchesAsNamedTupleNoFeeds(self):
|
||||||
fetches_namedtuple = namedtuple("fetches", "x y")
|
fetches_namedtuple = namedtuple("fetches", "x y")
|
||||||
fetches = fetches_namedtuple(self.const_b, self.const_c)
|
fetches = fetches_namedtuple(self.const_b, self.const_c)
|
||||||
@ -267,7 +261,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testWithFeedDict(self):
|
def testWithFeedDict(self):
|
||||||
feed_dict = {
|
feed_dict = {
|
||||||
self.const_a: 10.0,
|
self.const_a: 10.0,
|
||||||
@ -291,7 +284,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
feed_dict)
|
feed_dict)
|
||||||
self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)
|
self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testTensorFilters(self):
|
def testTensorFilters(self):
|
||||||
feed_dict = {self.const_a: 10.0}
|
feed_dict = {self.const_a: 10.0}
|
||||||
tensor_filters = {
|
tensor_filters = {
|
||||||
@ -322,20 +314,18 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
command_set.add(annot[2].content)
|
command_set.add(annot[2].content)
|
||||||
self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)
|
self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGetRunShortDescriptionWorksForTensorFeedKey(self):
|
def testGetRunShortDescriptionWorksForTensorFeedKey(self):
|
||||||
short_description = cli_shared.get_run_short_description(
|
short_description = cli_shared.get_run_short_description(
|
||||||
1, self.const_a, {self.const_a: 42.0})
|
1, self.const_a, {self.const_a: 42.0})
|
||||||
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)
|
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
|
def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
|
||||||
short_description = cli_shared.get_run_short_description(
|
short_description = cli_shared.get_run_short_description(
|
||||||
1, self.const_a, {u"foo": 42.0})
|
1, self.const_a, {u"foo": 42.0})
|
||||||
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)
|
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class GetErrorIntroTest(test_util.TensorFlowTestCase):
|
class GetErrorIntroTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest
|
|||||||
from tensorflow.python.training import gradient_descent
|
from tensorflow.python.training import gradient_descent
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
|
|||||||
from tensorflow.python.training import monitored_session
|
from tensorflow.python.training import monitored_session
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
|
class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -141,7 +141,7 @@ class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
|
|||||||
return framework.OnRunEndResponse()
|
return framework.OnRunEndResponse()
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def _no_rewrite_session_config(self):
|
def _no_rewrite_session_config(self):
|
||||||
|
@ -115,7 +115,7 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
|
|||||||
self.assertIn("No node-device colocations", summary)
|
self.assertIn("No node-device colocations", summary)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -197,7 +197,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
|||||||
self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
|
self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class InputNodesTest(test.TestCase):
|
class InputNodesTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -235,7 +235,7 @@ class InputNodesTest(test.TestCase):
|
|||||||
self.assertRegexpMatches(interpolated_string, expected_regex)
|
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class InterpolateDeviceSummaryTest(test.TestCase):
|
class InterpolateDeviceSummaryTest(test.TestCase):
|
||||||
|
|
||||||
def _fancy_device_function(self, unused_op):
|
def _fancy_device_function(self, unused_op):
|
||||||
@ -279,7 +279,7 @@ class InterpolateDeviceSummaryTest(test.TestCase):
|
|||||||
self.assertRegexpMatches(result, expected_re)
|
self.assertRegexpMatches(result, expected_re)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class InterpolateColocationSummaryTest(test.TestCase):
|
class InterpolateColocationSummaryTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -304,13 +304,11 @@ class InterpolateColocationSummaryTest(test.TestCase):
|
|||||||
|
|
||||||
self.graph = node_three.graph
|
self.graph = node_three.graph
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNodeThreeHasColocationInterpolation(self):
|
def testNodeThreeHasColocationInterpolation(self):
|
||||||
message = "{{colocation_node Three_with_one}}"
|
message = "{{colocation_node Three_with_one}}"
|
||||||
result = error_interpolation.interpolate(message, self.graph)
|
result = error_interpolation.interpolate(message, self.graph)
|
||||||
self.assertIn("colocate_with(One)", result)
|
self.assertIn("colocate_with(One)", result)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
|
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
|
||||||
message = "{{colocation_node Four_with_three}}"
|
message = "{{colocation_node Four_with_three}}"
|
||||||
result = error_interpolation.interpolate(message, self.graph)
|
result = error_interpolation.interpolate(message, self.graph)
|
||||||
@ -319,14 +317,12 @@ class InterpolateColocationSummaryTest(test.TestCase):
|
|||||||
"One", result,
|
"One", result,
|
||||||
"Node One should not appear in Four_with_three's summary:\n%s" % result)
|
"Node One should not appear in Four_with_three's summary:\n%s" % result)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
|
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
|
||||||
message = "{{colocation_node Five_with_one_with_two}}"
|
message = "{{colocation_node Five_with_one_with_two}}"
|
||||||
result = error_interpolation.interpolate(message, self.graph)
|
result = error_interpolation.interpolate(message, self.graph)
|
||||||
self.assertIn("colocate_with(One)", result)
|
self.assertIn("colocate_with(One)", result)
|
||||||
self.assertIn("colocate_with(Two)", result)
|
self.assertIn("colocate_with(Two)", result)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testColocationInterpolationForNodeLackingColocation(self):
|
def testColocationInterpolationForNodeLackingColocation(self):
|
||||||
message = "{{colocation_node One}}"
|
message = "{{colocation_node One}}"
|
||||||
result = error_interpolation.interpolate(message, self.graph)
|
result = error_interpolation.interpolate(message, self.graph)
|
||||||
|
@ -1052,7 +1052,16 @@ def run_deprecated_v1(func=None):
|
|||||||
|
|
||||||
def decorator(f):
|
def decorator(f):
|
||||||
if tf_inspect.isclass(f):
|
if tf_inspect.isclass(f):
|
||||||
raise ValueError("`run_deprecated_v1` only supports test methods.")
|
setup = f.__dict__.get("setUp")
|
||||||
|
if setup is not None:
|
||||||
|
setattr(f, "setUp", decorator(setup))
|
||||||
|
|
||||||
|
for name, value in f.__dict__.copy().items():
|
||||||
|
if (callable(value) and
|
||||||
|
name.startswith(unittest.TestLoader.testMethodPrefix)):
|
||||||
|
setattr(f, name, decorator(value))
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
def decorated(self, *args, **kwargs):
|
def decorated(self, *args, **kwargs):
|
||||||
if tf2.enabled():
|
if tf2.enabled():
|
||||||
|
@ -31,7 +31,7 @@ from tensorflow.python.ops import string_ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class Base64OpsTest(test_util.TensorFlowTestCase):
|
class Base64OpsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.platform import googletest
|
|||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
class QuantileOpsTest(test_util.TensorFlowTestCase):
|
class QuantileOpsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def create_resource(self, name, eps, max_elements, num_streams=1):
|
def create_resource(self, name, eps, max_elements, num_streams=1):
|
||||||
@ -82,7 +83,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.max_elements = 1 << 16
|
self.max_elements = 1 << 16
|
||||||
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
|
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testBasicQuantileBucketsSingleResource(self):
|
def testBasicQuantileBucketsSingleResource(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
quantile_accumulator_handle = self.create_resource("floats", self.eps,
|
quantile_accumulator_handle = self.create_resource("floats", self.eps,
|
||||||
@ -107,7 +107,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
|
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
|
||||||
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testBasicQuantileBucketsMultipleResources(self):
|
def testBasicQuantileBucketsMultipleResources(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
|
quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
|
||||||
@ -142,7 +141,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
|
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
|
||||||
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testSaveRestoreAfterFlush(self):
|
def testSaveRestoreAfterFlush(self):
|
||||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||||
@ -175,7 +173,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
|
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
|
||||||
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
|
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testSaveRestoreBeforeFlush(self):
|
def testSaveRestoreBeforeFlush(self):
|
||||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.training import momentum as momentum_lib
|
from tensorflow.python.training import momentum as momentum_lib
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
class AbsoluteDifferenceLossTest(test.TestCase):
|
class AbsoluteDifferenceLossTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -51,26 +52,22 @@ class AbsoluteDifferenceLossTest(test.TestCase):
|
|||||||
losses.absolute_difference(
|
losses.absolute_difference(
|
||||||
self._predictions, self._predictions, weights=None)
|
self._predictions, self._predictions, weights=None)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testAllCorrectNoLossWeight(self):
|
def testAllCorrectNoLossWeight(self):
|
||||||
loss = losses.absolute_difference(self._predictions, self._predictions)
|
loss = losses.absolute_difference(self._predictions, self._predictions)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLoss(self):
|
def testNonZeroLoss(self):
|
||||||
loss = losses.absolute_difference(self._labels, self._predictions)
|
loss = losses.absolute_difference(self._labels, self._predictions)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(5.5, self.evaluate(loss), 3)
|
self.assertAlmostEqual(5.5, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithPythonScalarWeight(self):
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.absolute_difference(self._labels, self._predictions, weights)
|
loss = losses.absolute_difference(self._labels, self._predictions, weights)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3)
|
self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithScalarTensorWeight(self):
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.absolute_difference(self._labels, self._predictions,
|
loss = losses.absolute_difference(self._labels, self._predictions,
|
||||||
@ -148,7 +145,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
|||||||
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
|
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
|
||||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLossWithPythonScalarWeight(self):
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||||
[0.0, 0.0, 10.0]])
|
[0.0, 0.0, 10.0]])
|
||||||
@ -158,7 +155,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
|||||||
loss = losses.softmax_cross_entropy(labels, logits, weights)
|
loss = losses.softmax_cross_entropy(labels, logits, weights)
|
||||||
self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)
|
self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLossWithScalarTensorWeight(self):
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||||
[0.0, 0.0, 10.0]])
|
[0.0, 0.0, 10.0]])
|
||||||
@ -311,7 +308,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
|||||||
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLossWithPythonScalarWeight(self):
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||||
[0.0, 0.0, 10.0]])
|
[0.0, 0.0, 10.0]])
|
||||||
@ -321,7 +318,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
|||||||
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
|
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
|
||||||
self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)
|
self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLossWithScalarTensorWeight(self):
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||||
[0.0, 0.0, 10.0]])
|
[0.0, 0.0, 10.0]])
|
||||||
@ -654,6 +651,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
|
|||||||
3)
|
3)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
class LogLossTest(test.TestCase):
|
class LogLossTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -677,13 +675,11 @@ class LogLossTest(test.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
losses.log_loss(self._labels, self._labels, weights=None)
|
losses.log_loss(self._labels, self._labels, weights=None)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testAllCorrectNoLossWeight(self):
|
def testAllCorrectNoLossWeight(self):
|
||||||
loss = losses.log_loss(self._labels, self._labels)
|
loss = losses.log_loss(self._labels, self._labels)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testAllCorrectNoLossWeightWithPlaceholder(self):
|
def testAllCorrectNoLossWeightWithPlaceholder(self):
|
||||||
tf_predictions = array_ops.placeholder(
|
tf_predictions = array_ops.placeholder(
|
||||||
dtypes.float32, shape=self._np_labels.shape)
|
dtypes.float32, shape=self._np_labels.shape)
|
||||||
@ -692,14 +688,12 @@ class LogLossTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
|
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLoss(self):
|
def testNonZeroLoss(self):
|
||||||
loss = losses.log_loss(self._labels, self._predictions)
|
loss = losses.log_loss(self._labels, self._predictions)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
|
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
|
||||||
self.evaluate(loss), 3)
|
self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithPythonScalarWeight(self):
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.log_loss(self._labels, self._predictions, weights)
|
loss = losses.log_loss(self._labels, self._predictions, weights)
|
||||||
@ -707,7 +701,6 @@ class LogLossTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
||||||
self.evaluate(loss), 3)
|
self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithScalarTensorWeight(self):
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.log_loss(self._labels, self._predictions,
|
loss = losses.log_loss(self._labels, self._predictions,
|
||||||
@ -716,7 +709,6 @@ class LogLossTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
||||||
self.evaluate(loss), 3)
|
self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
|
def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
|
||||||
tf_predictions = array_ops.placeholder(
|
tf_predictions = array_ops.placeholder(
|
||||||
dtypes.float32, shape=self._np_predictions.shape)
|
dtypes.float32, shape=self._np_predictions.shape)
|
||||||
@ -728,7 +720,6 @@ class LogLossTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
||||||
loss, 3)
|
loss, 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
|
def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
|
||||||
tf_predictions = array_ops.placeholder(dtypes.float32, shape=[None, None])
|
tf_predictions = array_ops.placeholder(dtypes.float32, shape=[None, None])
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
@ -788,7 +779,6 @@ class LogLossTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0,
|
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0,
|
||||||
self.evaluate(loss), 3)
|
self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
|
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
|
||||||
weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
|
weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
|
||||||
expected_losses = np.multiply(self._expected_losses, weights)
|
expected_losses = np.multiply(self._expected_losses, weights)
|
||||||
@ -816,7 +806,6 @@ class LogLossTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(-np.sum(expected_losses), self.evaluate(loss), 3)
|
self.assertAlmostEqual(-np.sum(expected_losses), self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
|
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
|
||||||
weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
|
weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
|
||||||
expected_losses = np.multiply(self._expected_losses, weights)
|
expected_losses = np.multiply(self._expected_losses, weights)
|
||||||
@ -934,6 +923,7 @@ class HuberLossTest(test.TestCase):
|
|||||||
self.assertAllClose(expected, self.evaluate(loss), atol=1e-5)
|
self.assertAllClose(expected, self.evaluate(loss), atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
class MeanSquaredErrorTest(test.TestCase):
|
class MeanSquaredErrorTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -955,26 +945,26 @@ class MeanSquaredErrorTest(test.TestCase):
|
|||||||
losses.mean_squared_error(predictions=constant_op.constant(0),
|
losses.mean_squared_error(predictions=constant_op.constant(0),
|
||||||
labels=constant_op.constant(0)).eval())
|
labels=constant_op.constant(0)).eval())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testAllCorrectNoLossWeight(self):
|
def testAllCorrectNoLossWeight(self):
|
||||||
loss = losses.mean_squared_error(self._predictions, self._predictions)
|
loss = losses.mean_squared_error(self._predictions, self._predictions)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLoss(self):
|
def testNonZeroLoss(self):
|
||||||
loss = losses.mean_squared_error(self._labels, self._predictions)
|
loss = losses.mean_squared_error(self._labels, self._predictions)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(49.5, self.evaluate(loss), 3)
|
self.assertAlmostEqual(49.5, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLossWithPythonScalarWeight(self):
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
|
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3)
|
self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testNonZeroLossWithScalarTensorWeight(self):
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.mean_squared_error(self._labels, self._predictions,
|
loss = losses.mean_squared_error(self._labels, self._predictions,
|
||||||
@ -1013,6 +1003,7 @@ class MeanSquaredErrorTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
class MeanPairwiseSquaredErrorTest(test.TestCase):
|
class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -1068,12 +1059,10 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3)
|
expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testAllCorrectNoLossWeight(self):
|
def testAllCorrectNoLossWeight(self):
|
||||||
self._test_valid_weights(
|
self._test_valid_weights(
|
||||||
self._labels, self._labels, expected_loss=0.0)
|
self._labels, self._labels, expected_loss=0.0)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLoss(self):
|
def testNonZeroLoss(self):
|
||||||
self._test_valid_weights(
|
self._test_valid_weights(
|
||||||
self._labels, self._predictions,
|
self._labels, self._predictions,
|
||||||
@ -1104,7 +1093,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
np_grad = self.evaluate(grad)
|
np_grad = self.evaluate(grad)
|
||||||
self.assertFalse(np.isnan(np_grad).any())
|
self.assertFalse(np.isnan(np_grad).any())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithPythonScalarWeight(self):
|
def testNonZeroLossWithPythonScalarWeight(self):
|
||||||
weight = 2.3
|
weight = 2.3
|
||||||
self._test_valid_weights(
|
self._test_valid_weights(
|
||||||
@ -1112,7 +1100,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
expected_loss=weight * np.sum(self._expected_losses),
|
expected_loss=weight * np.sum(self._expected_losses),
|
||||||
weights=weight)
|
weights=weight)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testNonZeroLossWithScalarTensorWeight(self):
|
def testNonZeroLossWithScalarTensorWeight(self):
|
||||||
weights = 2.3
|
weights = 2.3
|
||||||
loss = losses.mean_pairwise_squared_error(
|
loss = losses.mean_pairwise_squared_error(
|
||||||
@ -1123,12 +1110,10 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
|
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
|
||||||
self.evaluate(loss), 3)
|
self.evaluate(loss), 3)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testNonZeroLossWithScalarZeroWeight(self):
|
def testNonZeroLossWithScalarZeroWeight(self):
|
||||||
self._test_valid_weights(
|
self._test_valid_weights(
|
||||||
self._labels, self._predictions, expected_loss=0.0, weights=0.0)
|
self._labels, self._predictions, expected_loss=0.0, weights=0.0)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test3d(self):
|
def test3d(self):
|
||||||
labels = np.array([
|
labels = np.array([
|
||||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||||
@ -1140,7 +1125,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
])
|
])
|
||||||
self._test_valid_weights(labels, predictions, expected_loss=137.5)
|
self._test_valid_weights(labels, predictions, expected_loss=137.5)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test3dWeightedScalar(self):
|
def test3dWeightedScalar(self):
|
||||||
labels = np.array([
|
labels = np.array([
|
||||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||||
@ -1179,7 +1163,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
weights_placeholder: weights,
|
weights_placeholder: weights,
|
||||||
})
|
})
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testInvalid3dWeighted2x0(self):
|
def testInvalid3dWeighted2x0(self):
|
||||||
labels = np.array([
|
labels = np.array([
|
||||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||||
@ -1192,7 +1175,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
self._test_invalid_weights(
|
self._test_invalid_weights(
|
||||||
labels, predictions, weights=np.asarray((1.2, 3.4)))
|
labels, predictions, weights=np.asarray((1.2, 3.4)))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test3dWeighted2x3x3(self):
|
def test3dWeighted2x3x3(self):
|
||||||
labels = np.array([
|
labels = np.array([
|
||||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||||
@ -1209,7 +1191,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
expected_loss=9 * 137.5,
|
expected_loss=9 * 137.5,
|
||||||
weights=np.ones((2, 3, 3)))
|
weights=np.ones((2, 3, 3)))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testLossWithAllZeroBatchSpecificWeights(self):
|
def testLossWithAllZeroBatchSpecificWeights(self):
|
||||||
self._test_valid_weights(
|
self._test_valid_weights(
|
||||||
self._labels, self._predictions, expected_loss=0.0,
|
self._labels, self._predictions, expected_loss=0.0,
|
||||||
@ -1251,6 +1232,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(loss0 + loss1, loss0_1, 5)
|
self.assertAlmostEqual(loss0 + loss1, loss0_1, 5)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
class CosineDistanceLossTest(test.TestCase):
|
class CosineDistanceLossTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -1329,7 +1311,6 @@ class CosineDistanceLossTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertEqual(3.0 / 4.0, self.evaluate(loss))
|
self.assertEqual(3.0 / 4.0, self.evaluate(loss))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
|
def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
|
||||||
tf_predictions = array_ops.placeholder(
|
tf_predictions = array_ops.placeholder(
|
||||||
dtypes.float32, shape=self._labels.shape)
|
dtypes.float32, shape=self._labels.shape)
|
||||||
|
@ -1122,7 +1122,7 @@ class StepCounterHookTest(test.TestCase):
|
|||||||
self.assertGreater(summary_value.simple_value, 0)
|
self.assertGreater(summary_value.simple_value, 0)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_deprecated_v1
|
||||||
class SummarySaverHookTest(test.TestCase):
|
class SummarySaverHookTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -1404,7 +1404,7 @@ class FinalOpsHookTest(test.TestCase):
|
|||||||
hook.final_ops_values.tolist())
|
hook.final_ops_values.tolist())
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_deprecated_v1
|
||||||
class ResourceSummarySaverHookTest(test.TestCase):
|
class ResourceSummarySaverHookTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -33,7 +33,7 @@ from tensorflow.python.summary.writer import writer
|
|||||||
from tensorflow.python.training import tensorboard_logging
|
from tensorflow.python.training import tensorboard_logging
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
class EventLoggingTest(test.TestCase):
|
class EventLoggingTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -87,7 +87,6 @@ class EventLoggingTest(test.TestCase):
|
|||||||
(event_pb2.LogMessage.ERROR, "format")])
|
(event_pb2.LogMessage.ERROR, "format")])
|
||||||
self.assertEqual(2, self.logged_message_count)
|
self.assertEqual(2, self.logged_message_count)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testVerbosity(self):
|
def testVerbosity(self):
|
||||||
tensorboard_logging.set_summary_writer(self._sw)
|
tensorboard_logging.set_summary_writer(self._sw)
|
||||||
tensorboard_logging.set_verbosity(tensorboard_logging.ERROR)
|
tensorboard_logging.set_verbosity(tensorboard_logging.ERROR)
|
||||||
@ -115,7 +114,6 @@ class EventLoggingTest(test.TestCase):
|
|||||||
tensorboard_logging.warn("this should work")
|
tensorboard_logging.warn("this should work")
|
||||||
self.assertEqual(1, self.logged_message_count)
|
self.assertEqual(1, self.logged_message_count)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testSummaryWriterFailsAfterClear(self):
|
def testSummaryWriterFailsAfterClear(self):
|
||||||
tensorboard_logging._clear_summary_writer()
|
tensorboard_logging._clear_summary_writer()
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
|
Loading…
Reference in New Issue
Block a user