Internal cleanup: set up boilerplate for tests which require running inside a tf.function. The boilerplate helps keep the tests clean, although its mechanism may be tricky.
PiperOrigin-RevId: 322444035 Change-Id: I56bbb1e1eb2feb65e16b8171d4a1bd20e1b190a5
This commit is contained in:
parent
cdc531263a
commit
be88c5f8a7
@ -31,26 +31,22 @@ import six
|
||||
from tensorflow.python.autograph.operators import control_flow
|
||||
from tensorflow.python.autograph.operators import variables as variable_operators
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
from tensorflow.python.autograph.utils import testing
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ForLoopTest(test.TestCase):
|
||||
class ForLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
def body(i):
|
||||
@ -70,7 +66,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
self.assertEqual(s, (1234,))
|
||||
|
||||
def test_range_tensor(self):
|
||||
def body(i):
|
||||
@ -90,7 +86,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
self.assertEqual(s, (1234,))
|
||||
|
||||
def test_range_tensor_explicit_limit_delta(self):
|
||||
def body(i):
|
||||
@ -110,7 +106,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (-171207,))
|
||||
self.assertEqual(s, (-171207,))
|
||||
|
||||
def test_range_tensor_explicit_limit_negative_delta(self):
|
||||
def body(i):
|
||||
@ -130,7 +126,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
self.assertEqual(s, (171207,))
|
||||
|
||||
def test_range_tensor_random_delta(self):
|
||||
def body(i):
|
||||
@ -151,7 +147,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
self.assertEqual(s, (1234,))
|
||||
|
||||
def test_range_tensor_random_negative_delta(self):
|
||||
def body(i):
|
||||
@ -172,7 +168,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
self.assertEqual(s, (171207,))
|
||||
|
||||
def test_tensor_with_extra_test_object_vars(self):
|
||||
class MutableObject(object):
|
||||
@ -198,7 +194,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('state.field_1', 'state.field_2'),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate((state.field_1, state.field_2)), (6, 6))
|
||||
self.assertEqual((state.field_1, state.field_2), (6, 6))
|
||||
|
||||
def test_python(self):
|
||||
def body(i):
|
||||
@ -297,7 +293,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
self.assertEqual(s, (1234,))
|
||||
|
||||
def test_dataset_with_extra_test(self):
|
||||
def body(i):
|
||||
@ -317,7 +313,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate(s), (12,))
|
||||
self.assertEqual(s, (12,))
|
||||
|
||||
def test_dataset_with_extra_test_collection_vars(self):
|
||||
def body(i):
|
||||
@ -339,7 +335,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('l[0]', 's'),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate((l[0], s)), (3, 3))
|
||||
self.assertEqual((l[0], s), (3, 3))
|
||||
|
||||
def test_dataset_with_extra_test_iteration_limiting(self):
|
||||
def body(it):
|
||||
@ -360,100 +356,83 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('i',),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate(i), (3,))
|
||||
self.assertEqual(i, (3,))
|
||||
|
||||
def test_tf_dataset_no_loop_vars(self):
|
||||
def body(i):
|
||||
v.assign(v.read_value() * 10 + i)
|
||||
|
||||
v = variables.Variable(0, dtype=dtypes.int64)
|
||||
self.evaluate(v.initializer)
|
||||
v = self.variable('v', 0, dtypes.int64)
|
||||
|
||||
# tf.function required for the automatic control dependencies, and because
|
||||
# ops test for its presence.
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
control_flow.for_stmt(
|
||||
dataset_ops.Dataset.range(5),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
symbol_names=(),
|
||||
opts={})
|
||||
control_flow.for_stmt(
|
||||
dataset_ops.Dataset.range(5),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
symbol_names=(),
|
||||
opts={})
|
||||
|
||||
self.evaluate(test_fn())
|
||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
||||
self.assertEqual(v.read_value(), 1234)
|
||||
|
||||
def test_tf_iterator(self):
|
||||
# graph-mode iterators are only supported inside tf.function.
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
def body(i):
|
||||
nonlocal s
|
||||
s = s * 10 + i
|
||||
def body(i):
|
||||
nonlocal s
|
||||
s = s * 10 + i
|
||||
|
||||
def set_state(loop_vars):
|
||||
nonlocal s
|
||||
s, = loop_vars
|
||||
def set_state(loop_vars):
|
||||
nonlocal s
|
||||
s, = loop_vars
|
||||
|
||||
s = constant_op.constant(0, dtype=dtypes.int64)
|
||||
control_flow.for_stmt(
|
||||
iter(dataset_ops.Dataset.range(5)),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
return s
|
||||
self.assertAllEqual(test_fn(), 1234)
|
||||
s = constant_op.constant(0, dtype=dtypes.int64)
|
||||
control_flow.for_stmt(
|
||||
iter(dataset_ops.Dataset.range(5)),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, 1234)
|
||||
|
||||
def test_tf_iterator_shape_invariants(self):
|
||||
# graph-mode iterators are only supported inside tf.function.
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
def body(i):
|
||||
nonlocal s
|
||||
s = array_ops.concat([s, [i]], 0)
|
||||
def body(i):
|
||||
nonlocal s
|
||||
s = array_ops.concat([s, [i]], 0)
|
||||
|
||||
def set_state(loop_vars):
|
||||
nonlocal s
|
||||
s, = loop_vars
|
||||
def set_state(loop_vars):
|
||||
nonlocal s
|
||||
s, = loop_vars
|
||||
|
||||
s = constant_op.constant([], dtype=dtypes.int64)
|
||||
control_flow.for_stmt(
|
||||
iter(dataset_ops.Dataset.range(5)),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
|
||||
return s
|
||||
self.assertAllEqual(test_fn(), [0, 1, 2, 3, 4])
|
||||
s = constant_op.constant([], dtype=dtypes.int64)
|
||||
control_flow.for_stmt(
|
||||
iter(dataset_ops.Dataset.range(5)),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
|
||||
|
||||
self.assertAllEqual(s, [0, 1, 2, 3, 4])
|
||||
|
||||
def test_tf_iterator_no_loop_vars(self):
|
||||
def body(i):
|
||||
v.assign(v.read_value() * 10 + i)
|
||||
|
||||
v = variables.Variable(0, dtype=dtypes.int64)
|
||||
self.evaluate(v.initializer)
|
||||
v = self.variable('v', 0, dtypes.int64)
|
||||
|
||||
# tf.function required for the automatic control dependencies.
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
control_flow.for_stmt(
|
||||
iter(dataset_ops.Dataset.range(5)),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
symbol_names=(),
|
||||
opts={})
|
||||
control_flow.for_stmt(
|
||||
iter(dataset_ops.Dataset.range(5)),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
symbol_names=(),
|
||||
opts={})
|
||||
|
||||
self.evaluate(test_fn())
|
||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
||||
self.assertEqual(v.read_value(), 1234)
|
||||
|
||||
def test_tf_ragged_tensor(self):
|
||||
def body(i):
|
||||
@ -473,7 +452,8 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate(s), (123,))
|
||||
|
||||
self.assertEqual(s, (123,))
|
||||
|
||||
def test_tf_ragged_tensor_higher_dimensional(self):
|
||||
def body(i):
|
||||
@ -497,30 +477,26 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate(s), (12,))
|
||||
|
||||
self.assertEqual(s, (12,))
|
||||
|
||||
def test_tf_ragged_tensor_no_loop_vars(self):
|
||||
v = variables.Variable(0, dtype=dtypes.int32)
|
||||
self.evaluate(v.initializer)
|
||||
v = self.variable('v', 0, dtypes.int32)
|
||||
|
||||
def body(i):
|
||||
v.assign(v.read_value() * 10 + i[0])
|
||||
|
||||
# tf.function required for the automatic control dependencies.
|
||||
@def_function.function(autograph=False)
|
||||
def test_fn():
|
||||
control_flow.for_stmt(
|
||||
ragged_factory_ops.constant([[1], [2, 4], [3]]),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
symbol_names=(),
|
||||
opts={})
|
||||
control_flow.for_stmt(
|
||||
ragged_factory_ops.constant([[1], [2, 4], [3]]),
|
||||
extra_test=None,
|
||||
body=body,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
symbol_names=(),
|
||||
opts={})
|
||||
|
||||
self.evaluate(test_fn())
|
||||
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
|
||||
self.assertEqual(self.evaluate(v.read_value()), 123)
|
||||
self.assertEqual(v.read_value(), 123)
|
||||
|
||||
def _basic_loop(self, init_value, body_fn):
|
||||
def body(i):
|
||||
@ -561,8 +537,7 @@ class ForLoopTest(test.TestCase):
|
||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class WhileLoopTest(test.TestCase):
|
||||
class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
def body():
|
||||
@ -584,40 +559,36 @@ class WhileLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('i', 's'),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate((i, s)), (5, 1234))
|
||||
self.assertEqual(i, 5)
|
||||
self.assertEqual(s, 1234)
|
||||
|
||||
def test_tensor_with_side_effecting_condition(self):
|
||||
v = variables.Variable(0)
|
||||
v = self.variable('v', 0, dtypes.int32)
|
||||
|
||||
# tf.function required for the automatic control dependencies.
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
def cond():
|
||||
v.assign(v.read_value() * 10 + i)
|
||||
return i < n
|
||||
def cond():
|
||||
v.assign(v.read_value() * 10 + i)
|
||||
return i < n
|
||||
|
||||
def body():
|
||||
nonlocal i
|
||||
i += 1
|
||||
def body():
|
||||
nonlocal i
|
||||
i += 1
|
||||
|
||||
def set_state(loop_vars):
|
||||
nonlocal i
|
||||
i, = loop_vars
|
||||
def set_state(loop_vars):
|
||||
nonlocal i
|
||||
i, = loop_vars
|
||||
|
||||
i = 0
|
||||
n = constant_op.constant(5)
|
||||
control_flow.while_stmt(
|
||||
test=cond,
|
||||
body=body,
|
||||
get_state=lambda: (i,),
|
||||
set_state=set_state,
|
||||
symbol_names=('i',),
|
||||
opts={})
|
||||
return i
|
||||
i = 0
|
||||
n = constant_op.constant(5)
|
||||
control_flow.while_stmt(
|
||||
test=cond,
|
||||
body=body,
|
||||
get_state=lambda: (i,),
|
||||
set_state=set_state,
|
||||
symbol_names=('i',),
|
||||
opts={})
|
||||
|
||||
self.evaluate(v.initializer)
|
||||
self.assertEqual(self.evaluate(test_fn()), (5,))
|
||||
self.assertEqual(self.evaluate(v), (12345,))
|
||||
self.assertEqual(i, (5,))
|
||||
self.assertEqual(v, (12345,))
|
||||
|
||||
def test_tensor_with_python_state(self):
|
||||
class MutableObject(object):
|
||||
@ -642,7 +613,8 @@ class WhileLoopTest(test.TestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('i', 'state.field'),
|
||||
opts={})
|
||||
self.assertEqual(self.evaluate((i, state.field)), (5, 1234))
|
||||
self.assertEqual(i, 5)
|
||||
self.assertEqual(state.field, 1234)
|
||||
|
||||
def test_python(self):
|
||||
def body():
|
||||
@ -679,7 +651,7 @@ class WhileLoopTest(test.TestCase):
|
||||
symbol_names=('i', 's'),
|
||||
opts={})
|
||||
self.assertEqual(i, 5)
|
||||
self.assertEqual(self.evaluate(s), 1234)
|
||||
self.assertEqual(s, 1234)
|
||||
|
||||
def test_python_while_infinite(self):
|
||||
if not __debug__:
|
||||
@ -800,8 +772,7 @@ class WhileLoopTest(test.TestCase):
|
||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class IfStmtTest(test.TestCase):
|
||||
class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
|
||||
@ -829,8 +800,8 @@ class IfStmtTest(test.TestCase):
|
||||
nouts=1)
|
||||
return i
|
||||
|
||||
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
|
||||
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
|
||||
self.assertEqual(test_fn(constant_op.constant(True)), 1)
|
||||
self.assertEqual(test_fn(constant_op.constant(False)), -1)
|
||||
|
||||
def test_tensor_no_outputs(self):
|
||||
|
||||
@ -858,8 +829,8 @@ class IfStmtTest(test.TestCase):
|
||||
nouts=0)
|
||||
return i
|
||||
|
||||
self.assertEqual(None, test_fn(constant_op.constant(True)))
|
||||
self.assertEqual(None, test_fn(constant_op.constant(False)))
|
||||
self.assertIsNone(test_fn(constant_op.constant(True)))
|
||||
self.assertIsNone(test_fn(constant_op.constant(False)))
|
||||
|
||||
def test_tensor_multiple_returns(self):
|
||||
|
||||
@ -889,9 +860,8 @@ class IfStmtTest(test.TestCase):
|
||||
nouts=2)
|
||||
return i, j
|
||||
|
||||
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
|
||||
self.assertEqual((-1, -2),
|
||||
self.evaluate(test_fn(constant_op.constant(False))))
|
||||
self.assertEqual(test_fn(constant_op.constant(True)), (1, 2))
|
||||
self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2))
|
||||
|
||||
def test_python(self):
|
||||
|
||||
@ -915,8 +885,8 @@ class IfStmtTest(test.TestCase):
|
||||
nouts=1)
|
||||
return i
|
||||
|
||||
self.assertEqual(1, test_fn(True))
|
||||
self.assertEqual(-1, test_fn(False))
|
||||
self.assertEqual(test_fn(True), 1)
|
||||
self.assertEqual(test_fn(False), -1)
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
|
||||
@ -942,8 +912,8 @@ class IfStmtTest(test.TestCase):
|
||||
nouts=2)
|
||||
return i, j
|
||||
|
||||
self.assertEqual((1, 2), test_fn(True))
|
||||
self.assertEqual((-1, -2), test_fn(False))
|
||||
self.assertEqual(test_fn(True), (1, 2))
|
||||
self.assertEqual(test_fn(False), (-1, -2))
|
||||
|
||||
def _basic_cond(self, body_fn, else_fn):
|
||||
def body():
|
||||
@ -959,16 +929,14 @@ class IfStmtTest(test.TestCase):
|
||||
x, = cond_vars
|
||||
|
||||
x = 0
|
||||
# Eager cond had different semantics, we don't test those here.
|
||||
with func_graph.FuncGraph('tmp').as_default():
|
||||
control_flow.if_stmt(
|
||||
cond=constant_op.constant(True),
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (x,),
|
||||
set_state=set_state,
|
||||
symbol_names=('x',),
|
||||
nouts=1)
|
||||
control_flow.if_stmt(
|
||||
cond=constant_op.constant(True),
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (x,),
|
||||
set_state=set_state,
|
||||
symbol_names=('x',),
|
||||
nouts=1)
|
||||
return x
|
||||
|
||||
def test_tensor_none_output(self):
|
||||
|
@ -22,4 +22,3 @@ from tensorflow.python.autograph.utils.context_managers import control_dependenc
|
||||
from tensorflow.python.autograph.utils.misc import alias_tensors
|
||||
from tensorflow.python.autograph.utils.py_func import wrap_py_func
|
||||
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
|
||||
from tensorflow.python.autograph.utils.testing import fake_tf
|
||||
|
@ -18,20 +18,82 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import imp
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def fake_tf():
|
||||
"""Creates a fake module that looks like TensorFlow, for testing."""
|
||||
mod = imp.new_module('tensorflow')
|
||||
mod_contents = {}
|
||||
mod_contents.update(gen_math_ops.__dict__)
|
||||
mod_contents.update(math_ops.__dict__)
|
||||
mod_contents.update(ops.__dict__)
|
||||
mod_contents.update(mod.__dict__)
|
||||
mod.__dict__.update(mod_contents)
|
||||
return mod
|
||||
class AutoGraphTestCase(test.TestCase):
|
||||
"""Tests specialized for AutoGraph, which run as tf.functions.
|
||||
|
||||
These tests use a staged programming-like approach: most of the test code runs
|
||||
as-is inside a tf.function, but the assertions are lifted outside the
|
||||
function, and run with the corresponding function values instead.
|
||||
|
||||
For example, the test:
|
||||
|
||||
def test_foo(self):
|
||||
baz = bar();
|
||||
self.assertEqual(baz, value)
|
||||
|
||||
is equivalent to writing:
|
||||
|
||||
def test_foo(self):
|
||||
@tf.function
|
||||
def test_fn():
|
||||
baz = bar();
|
||||
return baz, value
|
||||
|
||||
baz_actual, value_actual = test_fn()
|
||||
self.assertEqual(baz_actual, value_actual)
|
||||
"""
|
||||
|
||||
def __new__(cls, *args):
|
||||
obj = super().__new__(cls)
|
||||
|
||||
for name in cls.__dict__:
|
||||
if not name.startswith(unittest.TestLoader.testMethodPrefix):
|
||||
continue
|
||||
m = getattr(obj, name)
|
||||
if callable(m):
|
||||
wrapper = obj._run_as_tf_function(m)
|
||||
setattr(obj, name, types.MethodType(wrapper, obj))
|
||||
|
||||
return obj
|
||||
|
||||
def _run_as_tf_function(self, fn):
|
||||
|
||||
def wrapper(self):
|
||||
@def_function.function(autograph=False) # Testing autograph itself.
|
||||
def fn_wrapper():
|
||||
self.assertions = []
|
||||
fn()
|
||||
targets = [args for _, args in self.assertions]
|
||||
return targets
|
||||
actuals = self.evaluate(fn_wrapper())
|
||||
for (_, args), value in zip(self.assertions, actuals):
|
||||
args[:] = value
|
||||
return wrapper
|
||||
|
||||
def variable(self, name, value, dtype):
|
||||
with ops.init_scope():
|
||||
if name not in self.variables:
|
||||
self.variables[name] = variables.Variable(value, dtype=dtype)
|
||||
self.evaluate(self.variables[name].initializer)
|
||||
return self.variables[name]
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.variables = {}
|
||||
|
||||
def tearDown(self):
|
||||
for fn, args in self.assertions:
|
||||
fn(*args)
|
||||
super().tearDown()
|
||||
|
||||
def assertEqual(self, *args):
|
||||
self.assertions.append((super().assertEqual, list(args)))
|
||||
|
Loading…
Reference in New Issue
Block a user