Enable passing TFRT python tests.
PiperOrigin-RevId: 337418589 Change-Id: I2267e24d39aa367df75176108df95ddd81d7b968
This commit is contained in:
parent
8fc491bf78
commit
4a05ea9a74
@ -336,7 +336,6 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
defun=False,
|
defun=False,
|
||||||
execution_mode=context.ASYNC)
|
execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
|
||||||
def benchmark_eager_apply_with_defun(self):
|
def benchmark_eager_apply_with_defun(self):
|
||||||
self._benchmark_eager_apply(
|
self._benchmark_eager_apply(
|
||||||
'eager_apply_with_defun',
|
'eager_apply_with_defun',
|
||||||
@ -416,7 +415,6 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
resnet50_test_util.device_and_data_format(),
|
resnet50_test_util.device_and_data_format(),
|
||||||
defun=False)
|
defun=False)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
|
||||||
def benchmark_eager_train_datasets_with_defun(self):
|
def benchmark_eager_train_datasets_with_defun(self):
|
||||||
|
|
||||||
def make_iterator(tensors):
|
def make_iterator(tensors):
|
||||||
|
@ -923,19 +923,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
func = lambda: math_ops.reduce_logsumexp(x)
|
func = lambda: math_ops.reduce_logsumexp(x)
|
||||||
self._run(func, 3000, execution_mode=execution_mode)
|
self._run(func, 3000, execution_mode=execution_mode)
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.")
|
|
||||||
def benchmark_tf_reduce_logsumexp_CPU(self):
|
def benchmark_tf_reduce_logsumexp_CPU(self):
|
||||||
self._benchmark_tf_reduce_logsumexp()
|
self._benchmark_tf_reduce_logsumexp()
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.")
|
|
||||||
def benchmark_tf_reduce_logsumexp_CPU_async(self):
|
def benchmark_tf_reduce_logsumexp_CPU_async(self):
|
||||||
self._benchmark_tf_reduce_logsumexp(execution_mode=context.ASYNC)
|
self._benchmark_tf_reduce_logsumexp(execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.")
|
|
||||||
def benchmark_tf_reduce_logsumexp_GPU(self):
|
def benchmark_tf_reduce_logsumexp_GPU(self):
|
||||||
self._benchmark_tf_reduce_logsumexp(device=GPU)
|
self._benchmark_tf_reduce_logsumexp(device=GPU)
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.")
|
|
||||||
def benchmark_tf_reduce_logsumexp_GPU_async(self):
|
def benchmark_tf_reduce_logsumexp_GPU_async(self):
|
||||||
self._benchmark_tf_reduce_logsumexp(device=GPU,
|
self._benchmark_tf_reduce_logsumexp(device=GPU,
|
||||||
execution_mode=context.ASYNC)
|
execution_mode=context.ASYNC)
|
||||||
|
@ -153,7 +153,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
init_fn()
|
init_fn()
|
||||||
self.assertEqual(state[0].numpy(), 2.0)
|
self.assertEqual(state[0].numpy(), 2.0)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Error in native condition op.')
|
|
||||||
def testVariableInitializerNotConstant(self):
|
def testVariableInitializerNotConstant(self):
|
||||||
|
|
||||||
state = []
|
state = []
|
||||||
@ -385,7 +384,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
'defined in another function or code block'):
|
'defined in another function or code block'):
|
||||||
f(array_ops.zeros(shape=(8, 42, 3)))
|
f(array_ops.zeros(shape=(8, 42, 3)))
|
||||||
|
|
||||||
@test_util.disable_tfrt('b/169375363: error code support')
|
|
||||||
def testRuntimeErrorNotSticky(self):
|
def testRuntimeErrorNotSticky(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
|
@ -394,7 +394,6 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('Tensor', lambda: constant_op.constant(1.3+1j)),
|
('Tensor', lambda: constant_op.constant(1.3+1j)),
|
||||||
('Variable', lambda: resource_variable_ops.ResourceVariable(1.3+1j)))
|
('Variable', lambda: resource_variable_ops.ResourceVariable(1.3+1j)))
|
||||||
@test_util.disable_tfrt('cannot create complex tensor in TFRT.')
|
|
||||||
def testCastToPrimitiveTypesFrom(self, value_fn):
|
def testCastToPrimitiveTypesFrom(self, value_fn):
|
||||||
x = value_fn()
|
x = value_fn()
|
||||||
self.assertIsInstance(int(x), int)
|
self.assertIsInstance(int(x), int)
|
||||||
|
@ -326,7 +326,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(z, [False, False, False, True])
|
self.assertAllEqual(z, [False, False, False, True])
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169375363: error code support")
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testBitwiseAndErrors(self):
|
def testBitwiseAndErrors(self):
|
||||||
x_int = constant_op.constant(0)
|
x_int = constant_op.constant(0)
|
||||||
@ -368,7 +367,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(z, [False, True, True, True])
|
self.assertAllEqual(z, [False, True, True, True])
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169375363: error code support")
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testBitwiseOrErrors(self):
|
def testBitwiseOrErrors(self):
|
||||||
x_int = constant_op.constant(0)
|
x_int = constant_op.constant(0)
|
||||||
@ -410,7 +408,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(z, [False, True, True, False])
|
self.assertAllEqual(z, [False, True, True, False])
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169375363: error code support")
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testBitwiseXorErrors(self):
|
def testBitwiseXorErrors(self):
|
||||||
x_int = constant_op.constant(0)
|
x_int = constant_op.constant(0)
|
||||||
@ -450,7 +447,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(y, [True, False])
|
self.assertAllEqual(y, [True, False])
|
||||||
|
|
||||||
@test_util.disable_tfrt("b/169375363: error code support")
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testBitwiseNotErrors(self):
|
def testBitwiseNotErrors(self):
|
||||||
if context.executing_eagerly(): # :(
|
if context.executing_eagerly(): # :(
|
||||||
|
@ -203,8 +203,6 @@ class InitializersTest(test.TestCase):
|
|||||||
run_metadata=run_metadata)
|
run_metadata=run_metadata)
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
@test_util.disable_tfrt('b/165614506: Incorrect device name set in '
|
|
||||||
'tfrt::TensorHandle.')
|
|
||||||
def test_eager_orthogonal_gpu(self):
|
def test_eager_orthogonal_gpu(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
v = variable_scope.get_variable(
|
v = variable_scope.get_variable(
|
||||||
|
Loading…
Reference in New Issue
Block a user