Clear up some disable_tfrt decorators.

PiperOrigin-RevId: 338558257
Change-Id: I21458d7faf0565d52d50bf48dc62fecd72cca26d
This commit is contained in:
Chuanhao Zhuge 2020-10-22 15:12:31 -07:00 committed by TensorFlower Gardener
parent c17952e2c5
commit ab53121fa8

View File

@ -103,7 +103,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
v0 = resource_variable_ops.ResourceVariable(1.0) v0 = resource_variable_ops.ResourceVariable(1.0)
self.assertAllEqual(v0.numpy(), 1.0) self.assertAllEqual(v0.numpy(), 1.0)
@test_util.disable_tfrt("b/169375363: error code support")
def testReadVariableDtypeMismatchEager(self): def testReadVariableDtypeMismatchEager(self):
with context.eager_mode(): with context.eager_mode():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
@ -199,7 +198,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
value, _ = sess.run([v, v.assign_add(1.0)]) value, _ = sess.run([v, v.assign_add(1.0)])
self.assertAllEqual(value, 0.0) self.assertAllEqual(value, 0.0)
@test_util.disable_tfrt("b/169375363: error code support")
def testAssignVariableDtypeMismatchEager(self): def testAssignVariableDtypeMismatchEager(self):
with context.eager_mode(): with context.eager_mode():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
@ -750,7 +748,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertEqual(v.handle.op.colocation_groups(), self.assertEqual(v.handle.op.colocation_groups(),
v.initializer.inputs[1].op.colocation_groups()) v.initializer.inputs[1].op.colocation_groups())
@test_util.disable_tfrt("b/169375363: error code support")
def testCountUpTo(self): def testCountUpTo(self):
with context.eager_mode(): with context.eager_mode():
v = resource_variable_ops.ResourceVariable(0, name="upto") v = resource_variable_ops.ResourceVariable(0, name="upto")
@ -758,7 +755,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
v.count_up_to(1) v.count_up_to(1)
@test_util.disable_tfrt("b/169375363: error code support")
def testCountUpToFunction(self): def testCountUpToFunction(self):
with context.eager_mode(): with context.eager_mode():
v = resource_variable_ops.ResourceVariable(0, name="upto") v = resource_variable_ops.ResourceVariable(0, name="upto")
@ -857,7 +853,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
variable_def=other_v_def) variable_def=other_v_def)
self.assertIsNotNone(other_v_prime._cached_value) self.assertIsNotNone(other_v_prime._cached_value)
@test_util.disable_tfrt("b/169375363: error code support")
def testVariableDefInitializedInstances(self): def testVariableDefInitializedInstances(self):
with ops.Graph().as_default(), self.cached_session(): with ops.Graph().as_default(), self.cached_session():
v_def = resource_variable_ops.ResourceVariable( v_def = resource_variable_ops.ResourceVariable(
@ -979,7 +974,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.evaluate(assign_without_read) self.evaluate(assign_without_read)
self.assertEqual(0.0, self.evaluate(v.value())) self.assertEqual(0.0, self.evaluate(v.value()))
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testDestroyResource(self): def testDestroyResource(self):
@ -1006,7 +1000,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
[assign], [assign],
feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)}) feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)})
@test_util.disable_tfrt("b/169375363: error code support")
def testAssignDifferentShapesEagerNotAllowed(self): def testAssignDifferentShapesEagerNotAllowed(self):
with context.eager_mode(): with context.eager_mode():
with variable_scope.variable_scope("foo"): with variable_scope.variable_scope("foo"):
@ -1069,7 +1062,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
.batch_scatter_update(batch_slices2), .batch_scatter_update(batch_slices2),
[[1, 3], [2, 3]]) [[1, 3], [2, 3]])
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testInitValueWrongShape(self): def testInitValueWrongShape(self):
with self.assertRaisesWithPredicateMatch( with self.assertRaisesWithPredicateMatch(
@ -1088,7 +1080,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertEqual(v.dtype, w.dtype) self.assertEqual(v.dtype, w.dtype)
# TODO(alive): get caching to work in eager mode. # TODO(alive): get caching to work in eager mode.
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testCachingDevice(self): def testCachingDevice(self):
with ops.device("/job:server/task:1"): with ops.device("/job:server/task:1"):
@ -1105,7 +1096,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = w.value().op.get_attr("_class") _ = w.value().op.get_attr("_class")
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testSharedName(self): def testSharedName(self):
with self.cached_session(): with self.cached_session():
@ -1164,7 +1154,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
v.initializer.run(feed_dict={v.initial_value: 3.0}) v.initializer.run(feed_dict={v.initial_value: 3.0})
self.assertEqual(3.0, v.value().eval()) self.assertEqual(3.0, v.value().eval())
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testControlFlowInitialization(self): def testControlFlowInitialization(self):
"""Expects an error if an initializer is in a control-flow scope.""" """Expects an error if an initializer is in a control-flow scope."""
@ -1252,7 +1241,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertEqual(1, v1.read_value().numpy()) self.assertEqual(1, v1.read_value().numpy())
self.assertEqual(2, v2.read_value().numpy()) self.assertEqual(2, v2.read_value().numpy())
@test_util.disable_tfrt("b/169375363: error code support")
def testDestruction(self): def testDestruction(self):
with context.eager_mode(): with context.eager_mode():
var = resource_variable_ops.ResourceVariable(initial_value=1.0, var = resource_variable_ops.ResourceVariable(initial_value=1.0,
@ -1340,7 +1328,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
state_ops.scatter_update(v, [1], [3]) state_ops.scatter_update(v, [1], [3])
self.assertAllEqual([1.0, 3.0], v.numpy()) self.assertAllEqual([1.0, 3.0], v.numpy())
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testScatterUpdateInvalidArgs(self): def testScatterUpdateInvalidArgs(self):
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
@ -1350,7 +1337,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
with self.assertRaisesRegex(Exception, r"shape.*2.*3"): with self.assertRaisesRegex(Exception, r"shape.*2.*3"):
state_ops.scatter_update(v, [0, 1], [0, 1, 2]) state_ops.scatter_update(v, [0, 1], [0, 1, 2])
@test_util.disable_tfrt("b/169375363: error code support")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testAssignIncompatibleShape(self): def testAssignIncompatibleShape(self):
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])