Apply run_deprecated_v1 to entire test class
Test classes having operations in the setUp function as well as individual tests annotated with run_deprecated_v1 need to have the whole class annotated with run_deprecated_v1 to ensure the setUp function as well as the test function is run in graph mode. PiperOrigin-RevId: 225964901
This commit is contained in:
parent
15fa7c49e2
commit
33bb4fe143
@ -105,7 +105,7 @@ class TimeToReadableStrTest(test_util.TensorFlowTestCase):
|
||||
cli_shared.time_to_readable_str(100, force_time_unit="ks")
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -119,7 +119,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
def tearDown(self):
|
||||
ops.reset_default_graph()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleFetchNoFeeds(self):
|
||||
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
|
||||
|
||||
@ -183,7 +182,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
||||
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTwoFetchesListNoFeeds(self):
|
||||
fetches = [self.const_a, self.const_b]
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||
@ -200,7 +198,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNestedListAsFetches(self):
|
||||
fetches = [self.const_c, [self.const_a, self.const_b]]
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||
@ -214,7 +211,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNestedDictAsFetches(self):
|
||||
fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||
@ -232,7 +228,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTwoFetchesAsTupleNoFeeds(self):
|
||||
fetches = (self.const_a, self.const_b)
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||
@ -249,7 +244,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTwoFetchesAsNamedTupleNoFeeds(self):
|
||||
fetches_namedtuple = namedtuple("fetches", "x y")
|
||||
fetches = fetches_namedtuple(self.const_b, self.const_c)
|
||||
@ -267,7 +261,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWithFeedDict(self):
|
||||
feed_dict = {
|
||||
self.const_a: 10.0,
|
||||
@ -291,7 +284,6 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
feed_dict)
|
||||
self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTensorFilters(self):
|
||||
feed_dict = {self.const_a: 10.0}
|
||||
tensor_filters = {
|
||||
@ -322,20 +314,18 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
command_set.add(annot[2].content)
|
||||
self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGetRunShortDescriptionWorksForTensorFeedKey(self):
|
||||
short_description = cli_shared.get_run_short_description(
|
||||
1, self.const_a, {self.const_a: 42.0})
|
||||
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
|
||||
short_description = cli_shared.get_run_short_description(
|
||||
1, self.const_a, {u"foo": 42.0})
|
||||
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class GetErrorIntroTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training import monitored_session
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@classmethod
|
||||
|
@ -141,7 +141,7 @@ class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
|
||||
return framework.OnRunEndResponse()
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _no_rewrite_session_config(self):
|
||||
|
@ -115,7 +115,7 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
|
||||
self.assertIn("No node-device colocations", summary)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -197,7 +197,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class InputNodesTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -235,7 +235,7 @@ class InputNodesTest(test.TestCase):
|
||||
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class InterpolateDeviceSummaryTest(test.TestCase):
|
||||
|
||||
def _fancy_device_function(self, unused_op):
|
||||
@ -279,7 +279,7 @@ class InterpolateDeviceSummaryTest(test.TestCase):
|
||||
self.assertRegexpMatches(result, expected_re)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class InterpolateColocationSummaryTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -304,13 +304,11 @@ class InterpolateColocationSummaryTest(test.TestCase):
|
||||
|
||||
self.graph = node_three.graph
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNodeThreeHasColocationInterpolation(self):
|
||||
message = "{{colocation_node Three_with_one}}"
|
||||
result = error_interpolation.interpolate(message, self.graph)
|
||||
self.assertIn("colocate_with(One)", result)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
|
||||
message = "{{colocation_node Four_with_three}}"
|
||||
result = error_interpolation.interpolate(message, self.graph)
|
||||
@ -319,14 +317,12 @@ class InterpolateColocationSummaryTest(test.TestCase):
|
||||
"One", result,
|
||||
"Node One should not appear in Four_with_three's summary:\n%s" % result)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
|
||||
message = "{{colocation_node Five_with_one_with_two}}"
|
||||
result = error_interpolation.interpolate(message, self.graph)
|
||||
self.assertIn("colocate_with(One)", result)
|
||||
self.assertIn("colocate_with(Two)", result)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testColocationInterpolationForNodeLackingColocation(self):
|
||||
message = "{{colocation_node One}}"
|
||||
result = error_interpolation.interpolate(message, self.graph)
|
||||
|
@ -1052,7 +1052,16 @@ def run_deprecated_v1(func=None):
|
||||
|
||||
def decorator(f):
|
||||
if tf_inspect.isclass(f):
|
||||
raise ValueError("`run_deprecated_v1` only supports test methods.")
|
||||
setup = f.__dict__.get("setUp")
|
||||
if setup is not None:
|
||||
setattr(f, "setUp", decorator(setup))
|
||||
|
||||
for name, value in f.__dict__.copy().items():
|
||||
if (callable(value) and
|
||||
name.startswith(unittest.TestLoader.testMethodPrefix)):
|
||||
setattr(f, name, decorator(value))
|
||||
|
||||
return f
|
||||
|
||||
def decorated(self, *args, **kwargs):
|
||||
if tf2.enabled():
|
||||
|
@ -31,7 +31,7 @@ from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class Base64OpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training import saver
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class QuantileOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def create_resource(self, name, eps, max_elements, num_streams=1):
|
||||
@ -82,7 +83,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
||||
self.max_elements = 1 << 16
|
||||
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testBasicQuantileBucketsSingleResource(self):
|
||||
with self.cached_session() as sess:
|
||||
quantile_accumulator_handle = self.create_resource("floats", self.eps,
|
||||
@ -107,7 +107,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
|
||||
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testBasicQuantileBucketsMultipleResources(self):
|
||||
with self.cached_session() as sess:
|
||||
quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
|
||||
@ -142,7 +141,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
|
||||
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testSaveRestoreAfterFlush(self):
|
||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||
@ -175,7 +173,6 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
|
||||
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testSaveRestoreBeforeFlush(self):
|
||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import momentum as momentum_lib
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class AbsoluteDifferenceLossTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -51,26 +52,22 @@ class AbsoluteDifferenceLossTest(test.TestCase):
|
||||
losses.absolute_difference(
|
||||
self._predictions, self._predictions, weights=None)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testAllCorrectNoLossWeight(self):
|
||||
loss = losses.absolute_difference(self._predictions, self._predictions)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLoss(self):
|
||||
loss = losses.absolute_difference(self._labels, self._predictions)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(5.5, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.absolute_difference(self._labels, self._predictions, weights)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithScalarTensorWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.absolute_difference(self._labels, self._predictions,
|
||||
@ -148,7 +145,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||
[0.0, 0.0, 10.0]])
|
||||
@ -158,7 +155,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
loss = losses.softmax_cross_entropy(labels, logits, weights)
|
||||
self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithScalarTensorWeight(self):
|
||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||
[0.0, 0.0, 10.0]])
|
||||
@ -311,7 +308,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||
[0.0, 0.0, 10.0]])
|
||||
@ -321,7 +318,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
|
||||
self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithScalarTensorWeight(self):
|
||||
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
|
||||
[0.0, 0.0, 10.0]])
|
||||
@ -654,6 +651,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
|
||||
3)
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class LogLossTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -677,13 +675,11 @@ class LogLossTest(test.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
losses.log_loss(self._labels, self._labels, weights=None)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testAllCorrectNoLossWeight(self):
|
||||
loss = losses.log_loss(self._labels, self._labels)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testAllCorrectNoLossWeightWithPlaceholder(self):
|
||||
tf_predictions = array_ops.placeholder(
|
||||
dtypes.float32, shape=self._np_labels.shape)
|
||||
@ -692,14 +688,12 @@ class LogLossTest(test.TestCase):
|
||||
self.assertAlmostEqual(
|
||||
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLoss(self):
|
||||
loss = losses.log_loss(self._labels, self._predictions)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
|
||||
self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.log_loss(self._labels, self._predictions, weights)
|
||||
@ -707,7 +701,6 @@ class LogLossTest(test.TestCase):
|
||||
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
||||
self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithScalarTensorWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.log_loss(self._labels, self._predictions,
|
||||
@ -716,7 +709,6 @@ class LogLossTest(test.TestCase):
|
||||
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
||||
self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
|
||||
tf_predictions = array_ops.placeholder(
|
||||
dtypes.float32, shape=self._np_predictions.shape)
|
||||
@ -728,7 +720,6 @@ class LogLossTest(test.TestCase):
|
||||
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
|
||||
loss, 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
|
||||
tf_predictions = array_ops.placeholder(dtypes.float32, shape=[None, None])
|
||||
weights = 2.3
|
||||
@ -788,7 +779,6 @@ class LogLossTest(test.TestCase):
|
||||
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0,
|
||||
self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
|
||||
weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
|
||||
expected_losses = np.multiply(self._expected_losses, weights)
|
||||
@ -816,7 +806,6 @@ class LogLossTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(-np.sum(expected_losses), self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
|
||||
weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
|
||||
expected_losses = np.multiply(self._expected_losses, weights)
|
||||
@ -934,6 +923,7 @@ class HuberLossTest(test.TestCase):
|
||||
self.assertAllClose(expected, self.evaluate(loss), atol=1e-5)
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class MeanSquaredErrorTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -955,26 +945,26 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
losses.mean_squared_error(predictions=constant_op.constant(0),
|
||||
labels=constant_op.constant(0)).eval())
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrectNoLossWeight(self):
|
||||
loss = losses.mean_squared_error(self._predictions, self._predictions)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLoss(self):
|
||||
loss = losses.mean_squared_error(self._labels, self._predictions)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(49.5, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
|
||||
with self.cached_session():
|
||||
self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithScalarTensorWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.mean_squared_error(self._labels, self._predictions,
|
||||
@ -1013,6 +1003,7 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1068,12 +1059,10 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
self.assertAlmostEqual(
|
||||
expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testAllCorrectNoLossWeight(self):
|
||||
self._test_valid_weights(
|
||||
self._labels, self._labels, expected_loss=0.0)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLoss(self):
|
||||
self._test_valid_weights(
|
||||
self._labels, self._predictions,
|
||||
@ -1104,7 +1093,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
np_grad = self.evaluate(grad)
|
||||
self.assertFalse(np.isnan(np_grad).any())
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
weight = 2.3
|
||||
self._test_valid_weights(
|
||||
@ -1112,7 +1100,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
expected_loss=weight * np.sum(self._expected_losses),
|
||||
weights=weight)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonZeroLossWithScalarTensorWeight(self):
|
||||
weights = 2.3
|
||||
loss = losses.mean_pairwise_squared_error(
|
||||
@ -1123,12 +1110,10 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
|
||||
self.evaluate(loss), 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithScalarZeroWeight(self):
|
||||
self._test_valid_weights(
|
||||
self._labels, self._predictions, expected_loss=0.0, weights=0.0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test3d(self):
|
||||
labels = np.array([
|
||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||
@ -1140,7 +1125,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
])
|
||||
self._test_valid_weights(labels, predictions, expected_loss=137.5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test3dWeightedScalar(self):
|
||||
labels = np.array([
|
||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||
@ -1179,7 +1163,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
weights_placeholder: weights,
|
||||
})
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testInvalid3dWeighted2x0(self):
|
||||
labels = np.array([
|
||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||
@ -1192,7 +1175,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
self._test_invalid_weights(
|
||||
labels, predictions, weights=np.asarray((1.2, 3.4)))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test3dWeighted2x3x3(self):
|
||||
labels = np.array([
|
||||
[[1, 9, 2], [12, 11, 10], [9, 8, 7]],
|
||||
@ -1209,7 +1191,6 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
expected_loss=9 * 137.5,
|
||||
weights=np.ones((2, 3, 3)))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testLossWithAllZeroBatchSpecificWeights(self):
|
||||
self._test_valid_weights(
|
||||
self._labels, self._predictions, expected_loss=0.0,
|
||||
@ -1251,6 +1232,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
|
||||
self.assertAlmostEqual(loss0 + loss1, loss0_1, 5)
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class CosineDistanceLossTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1329,7 +1311,6 @@ class CosineDistanceLossTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.assertEqual(3.0 / 4.0, self.evaluate(loss))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
|
||||
tf_predictions = array_ops.placeholder(
|
||||
dtypes.float32, shape=self._labels.shape)
|
||||
|
@ -1122,7 +1122,7 @@ class StepCounterHookTest(test.TestCase):
|
||||
self.assertGreater(summary_value.simple_value, 0)
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@test_util.run_deprecated_v1
|
||||
class SummarySaverHookTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1404,7 +1404,7 @@ class FinalOpsHookTest(test.TestCase):
|
||||
hook.final_ops_values.tolist())
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@test_util.run_deprecated_v1
|
||||
class ResourceSummarySaverHookTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -33,7 +33,7 @@ from tensorflow.python.summary.writer import writer
|
||||
from tensorflow.python.training import tensorboard_logging
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_deprecated_v1
|
||||
class EventLoggingTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -87,7 +87,6 @@ class EventLoggingTest(test.TestCase):
|
||||
(event_pb2.LogMessage.ERROR, "format")])
|
||||
self.assertEqual(2, self.logged_message_count)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testVerbosity(self):
|
||||
tensorboard_logging.set_summary_writer(self._sw)
|
||||
tensorboard_logging.set_verbosity(tensorboard_logging.ERROR)
|
||||
@ -115,7 +114,6 @@ class EventLoggingTest(test.TestCase):
|
||||
tensorboard_logging.warn("this should work")
|
||||
self.assertEqual(1, self.logged_message_count)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testSummaryWriterFailsAfterClear(self):
|
||||
tensorboard_logging._clear_summary_writer()
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
Loading…
Reference in New Issue
Block a user