Internal cleanup: set up boilerplate for tests which require running inside a tf.function. The boilerplate helps keep the tests clean, although its mechanism may be tricky.
PiperOrigin-RevId: 322444035 Change-Id: I56bbb1e1eb2feb65e16b8171d4a1bd20e1b190a5
This commit is contained in:
parent
cdc531263a
commit
be88c5f8a7
@ -31,26 +31,22 @@ import six
|
|||||||
from tensorflow.python.autograph.operators import control_flow
|
from tensorflow.python.autograph.operators import control_flow
|
||||||
from tensorflow.python.autograph.operators import variables as variable_operators
|
from tensorflow.python.autograph.operators import variables as variable_operators
|
||||||
from tensorflow.python.autograph.utils import ag_logging
|
from tensorflow.python.autograph.utils import ag_logging
|
||||||
|
from tensorflow.python.autograph.utils import testing
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import def_function
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import func_graph
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variables
|
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
class ForLoopTest(testing.AutoGraphTestCase):
|
||||||
class ForLoopTest(test.TestCase):
|
|
||||||
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -70,7 +66,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
self.assertEqual(s, (1234,))
|
||||||
|
|
||||||
def test_range_tensor(self):
|
def test_range_tensor(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -90,7 +86,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={'iterate_names': 'i'})
|
opts={'iterate_names': 'i'})
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
self.assertEqual(s, (1234,))
|
||||||
|
|
||||||
def test_range_tensor_explicit_limit_delta(self):
|
def test_range_tensor_explicit_limit_delta(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -110,7 +106,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={'iterate_names': 'i'})
|
opts={'iterate_names': 'i'})
|
||||||
self.assertEqual(self.evaluate(s), (-171207,))
|
self.assertEqual(s, (-171207,))
|
||||||
|
|
||||||
def test_range_tensor_explicit_limit_negative_delta(self):
|
def test_range_tensor_explicit_limit_negative_delta(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -130,7 +126,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={'iterate_names': 'i'})
|
opts={'iterate_names': 'i'})
|
||||||
self.assertEqual(self.evaluate(s), (171207,))
|
self.assertEqual(s, (171207,))
|
||||||
|
|
||||||
def test_range_tensor_random_delta(self):
|
def test_range_tensor_random_delta(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -151,7 +147,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={'iterate_names': 'i'})
|
opts={'iterate_names': 'i'})
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
self.assertEqual(s, (1234,))
|
||||||
|
|
||||||
def test_range_tensor_random_negative_delta(self):
|
def test_range_tensor_random_negative_delta(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -172,7 +168,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={'iterate_names': 'i'})
|
opts={'iterate_names': 'i'})
|
||||||
self.assertEqual(self.evaluate(s), (171207,))
|
self.assertEqual(s, (171207,))
|
||||||
|
|
||||||
def test_tensor_with_extra_test_object_vars(self):
|
def test_tensor_with_extra_test_object_vars(self):
|
||||||
class MutableObject(object):
|
class MutableObject(object):
|
||||||
@ -198,7 +194,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('state.field_1', 'state.field_2'),
|
symbol_names=('state.field_1', 'state.field_2'),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate((state.field_1, state.field_2)), (6, 6))
|
self.assertEqual((state.field_1, state.field_2), (6, 6))
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -297,7 +293,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
self.assertEqual(s, (1234,))
|
||||||
|
|
||||||
def test_dataset_with_extra_test(self):
|
def test_dataset_with_extra_test(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -317,7 +313,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate(s), (12,))
|
self.assertEqual(s, (12,))
|
||||||
|
|
||||||
def test_dataset_with_extra_test_collection_vars(self):
|
def test_dataset_with_extra_test_collection_vars(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -339,7 +335,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('l[0]', 's'),
|
symbol_names=('l[0]', 's'),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate((l[0], s)), (3, 3))
|
self.assertEqual((l[0], s), (3, 3))
|
||||||
|
|
||||||
def test_dataset_with_extra_test_iteration_limiting(self):
|
def test_dataset_with_extra_test_iteration_limiting(self):
|
||||||
def body(it):
|
def body(it):
|
||||||
@ -360,100 +356,83 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('i',),
|
symbol_names=('i',),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate(i), (3,))
|
self.assertEqual(i, (3,))
|
||||||
|
|
||||||
def test_tf_dataset_no_loop_vars(self):
|
def test_tf_dataset_no_loop_vars(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
v.assign(v.read_value() * 10 + i)
|
v.assign(v.read_value() * 10 + i)
|
||||||
|
|
||||||
v = variables.Variable(0, dtype=dtypes.int64)
|
v = self.variable('v', 0, dtypes.int64)
|
||||||
self.evaluate(v.initializer)
|
|
||||||
|
|
||||||
# tf.function required for the automatic control dependencies, and because
|
control_flow.for_stmt(
|
||||||
# ops test for its presence.
|
dataset_ops.Dataset.range(5),
|
||||||
@def_function.function
|
extra_test=None,
|
||||||
def test_fn():
|
body=body,
|
||||||
control_flow.for_stmt(
|
get_state=lambda: (),
|
||||||
dataset_ops.Dataset.range(5),
|
set_state=lambda _: None,
|
||||||
extra_test=None,
|
symbol_names=(),
|
||||||
body=body,
|
opts={})
|
||||||
get_state=lambda: (),
|
|
||||||
set_state=lambda _: None,
|
|
||||||
symbol_names=(),
|
|
||||||
opts={})
|
|
||||||
|
|
||||||
self.evaluate(test_fn())
|
self.assertEqual(v.read_value(), 1234)
|
||||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
|
||||||
|
|
||||||
def test_tf_iterator(self):
|
def test_tf_iterator(self):
|
||||||
# graph-mode iterators are only supported inside tf.function.
|
def body(i):
|
||||||
@def_function.function
|
nonlocal s
|
||||||
def test_fn():
|
s = s * 10 + i
|
||||||
def body(i):
|
|
||||||
nonlocal s
|
|
||||||
s = s * 10 + i
|
|
||||||
|
|
||||||
def set_state(loop_vars):
|
def set_state(loop_vars):
|
||||||
nonlocal s
|
nonlocal s
|
||||||
s, = loop_vars
|
s, = loop_vars
|
||||||
|
|
||||||
s = constant_op.constant(0, dtype=dtypes.int64)
|
s = constant_op.constant(0, dtype=dtypes.int64)
|
||||||
control_flow.for_stmt(
|
control_flow.for_stmt(
|
||||||
iter(dataset_ops.Dataset.range(5)),
|
iter(dataset_ops.Dataset.range(5)),
|
||||||
extra_test=None,
|
extra_test=None,
|
||||||
body=body,
|
body=body,
|
||||||
get_state=lambda: (s,),
|
get_state=lambda: (s,),
|
||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={})
|
opts={})
|
||||||
return s
|
|
||||||
self.assertAllEqual(test_fn(), 1234)
|
self.assertEqual(s, 1234)
|
||||||
|
|
||||||
def test_tf_iterator_shape_invariants(self):
|
def test_tf_iterator_shape_invariants(self):
|
||||||
# graph-mode iterators are only supported inside tf.function.
|
def body(i):
|
||||||
@def_function.function
|
nonlocal s
|
||||||
def test_fn():
|
s = array_ops.concat([s, [i]], 0)
|
||||||
def body(i):
|
|
||||||
nonlocal s
|
|
||||||
s = array_ops.concat([s, [i]], 0)
|
|
||||||
|
|
||||||
def set_state(loop_vars):
|
def set_state(loop_vars):
|
||||||
nonlocal s
|
nonlocal s
|
||||||
s, = loop_vars
|
s, = loop_vars
|
||||||
|
|
||||||
s = constant_op.constant([], dtype=dtypes.int64)
|
s = constant_op.constant([], dtype=dtypes.int64)
|
||||||
control_flow.for_stmt(
|
control_flow.for_stmt(
|
||||||
iter(dataset_ops.Dataset.range(5)),
|
iter(dataset_ops.Dataset.range(5)),
|
||||||
extra_test=None,
|
extra_test=None,
|
||||||
body=body,
|
body=body,
|
||||||
get_state=lambda: (s,),
|
get_state=lambda: (s,),
|
||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
|
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
|
||||||
return s
|
|
||||||
self.assertAllEqual(test_fn(), [0, 1, 2, 3, 4])
|
self.assertAllEqual(s, [0, 1, 2, 3, 4])
|
||||||
|
|
||||||
def test_tf_iterator_no_loop_vars(self):
|
def test_tf_iterator_no_loop_vars(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
v.assign(v.read_value() * 10 + i)
|
v.assign(v.read_value() * 10 + i)
|
||||||
|
|
||||||
v = variables.Variable(0, dtype=dtypes.int64)
|
v = self.variable('v', 0, dtypes.int64)
|
||||||
self.evaluate(v.initializer)
|
|
||||||
|
|
||||||
# tf.function required for the automatic control dependencies.
|
control_flow.for_stmt(
|
||||||
@def_function.function
|
iter(dataset_ops.Dataset.range(5)),
|
||||||
def test_fn():
|
extra_test=None,
|
||||||
control_flow.for_stmt(
|
body=body,
|
||||||
iter(dataset_ops.Dataset.range(5)),
|
get_state=lambda: (),
|
||||||
extra_test=None,
|
set_state=lambda _: None,
|
||||||
body=body,
|
symbol_names=(),
|
||||||
get_state=lambda: (),
|
opts={})
|
||||||
set_state=lambda _: None,
|
|
||||||
symbol_names=(),
|
|
||||||
opts={})
|
|
||||||
|
|
||||||
self.evaluate(test_fn())
|
self.assertEqual(v.read_value(), 1234)
|
||||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
|
||||||
|
|
||||||
def test_tf_ragged_tensor(self):
|
def test_tf_ragged_tensor(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -473,7 +452,8 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate(s), (123,))
|
|
||||||
|
self.assertEqual(s, (123,))
|
||||||
|
|
||||||
def test_tf_ragged_tensor_higher_dimensional(self):
|
def test_tf_ragged_tensor_higher_dimensional(self):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -497,30 +477,26 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('s',),
|
symbol_names=('s',),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate(s), (12,))
|
|
||||||
|
self.assertEqual(s, (12,))
|
||||||
|
|
||||||
def test_tf_ragged_tensor_no_loop_vars(self):
|
def test_tf_ragged_tensor_no_loop_vars(self):
|
||||||
v = variables.Variable(0, dtype=dtypes.int32)
|
v = self.variable('v', 0, dtypes.int32)
|
||||||
self.evaluate(v.initializer)
|
|
||||||
|
|
||||||
def body(i):
|
def body(i):
|
||||||
v.assign(v.read_value() * 10 + i[0])
|
v.assign(v.read_value() * 10 + i[0])
|
||||||
|
|
||||||
# tf.function required for the automatic control dependencies.
|
control_flow.for_stmt(
|
||||||
@def_function.function(autograph=False)
|
ragged_factory_ops.constant([[1], [2, 4], [3]]),
|
||||||
def test_fn():
|
extra_test=None,
|
||||||
control_flow.for_stmt(
|
body=body,
|
||||||
ragged_factory_ops.constant([[1], [2, 4], [3]]),
|
get_state=lambda: (),
|
||||||
extra_test=None,
|
set_state=lambda _: None,
|
||||||
body=body,
|
symbol_names=(),
|
||||||
get_state=lambda: (),
|
opts={})
|
||||||
set_state=lambda _: None,
|
|
||||||
symbol_names=(),
|
|
||||||
opts={})
|
|
||||||
|
|
||||||
self.evaluate(test_fn())
|
|
||||||
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
|
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
|
||||||
self.assertEqual(self.evaluate(v.read_value()), 123)
|
self.assertEqual(v.read_value(), 123)
|
||||||
|
|
||||||
def _basic_loop(self, init_value, body_fn):
|
def _basic_loop(self, init_value, body_fn):
|
||||||
def body(i):
|
def body(i):
|
||||||
@ -561,8 +537,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
class WhileLoopTest(testing.AutoGraphTestCase):
|
||||||
class WhileLoopTest(test.TestCase):
|
|
||||||
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
def body():
|
def body():
|
||||||
@ -584,40 +559,36 @@ class WhileLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('i', 's'),
|
symbol_names=('i', 's'),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate((i, s)), (5, 1234))
|
self.assertEqual(i, 5)
|
||||||
|
self.assertEqual(s, 1234)
|
||||||
|
|
||||||
def test_tensor_with_side_effecting_condition(self):
|
def test_tensor_with_side_effecting_condition(self):
|
||||||
v = variables.Variable(0)
|
v = self.variable('v', 0, dtypes.int32)
|
||||||
|
|
||||||
# tf.function required for the automatic control dependencies.
|
def cond():
|
||||||
@def_function.function
|
v.assign(v.read_value() * 10 + i)
|
||||||
def test_fn():
|
return i < n
|
||||||
def cond():
|
|
||||||
v.assign(v.read_value() * 10 + i)
|
|
||||||
return i < n
|
|
||||||
|
|
||||||
def body():
|
def body():
|
||||||
nonlocal i
|
nonlocal i
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
def set_state(loop_vars):
|
def set_state(loop_vars):
|
||||||
nonlocal i
|
nonlocal i
|
||||||
i, = loop_vars
|
i, = loop_vars
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
n = constant_op.constant(5)
|
n = constant_op.constant(5)
|
||||||
control_flow.while_stmt(
|
control_flow.while_stmt(
|
||||||
test=cond,
|
test=cond,
|
||||||
body=body,
|
body=body,
|
||||||
get_state=lambda: (i,),
|
get_state=lambda: (i,),
|
||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('i',),
|
symbol_names=('i',),
|
||||||
opts={})
|
opts={})
|
||||||
return i
|
|
||||||
|
|
||||||
self.evaluate(v.initializer)
|
self.assertEqual(i, (5,))
|
||||||
self.assertEqual(self.evaluate(test_fn()), (5,))
|
self.assertEqual(v, (12345,))
|
||||||
self.assertEqual(self.evaluate(v), (12345,))
|
|
||||||
|
|
||||||
def test_tensor_with_python_state(self):
|
def test_tensor_with_python_state(self):
|
||||||
class MutableObject(object):
|
class MutableObject(object):
|
||||||
@ -642,7 +613,8 @@ class WhileLoopTest(test.TestCase):
|
|||||||
set_state=set_state,
|
set_state=set_state,
|
||||||
symbol_names=('i', 'state.field'),
|
symbol_names=('i', 'state.field'),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(self.evaluate((i, state.field)), (5, 1234))
|
self.assertEqual(i, 5)
|
||||||
|
self.assertEqual(state.field, 1234)
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
def body():
|
def body():
|
||||||
@ -679,7 +651,7 @@ class WhileLoopTest(test.TestCase):
|
|||||||
symbol_names=('i', 's'),
|
symbol_names=('i', 's'),
|
||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(i, 5)
|
self.assertEqual(i, 5)
|
||||||
self.assertEqual(self.evaluate(s), 1234)
|
self.assertEqual(s, 1234)
|
||||||
|
|
||||||
def test_python_while_infinite(self):
|
def test_python_while_infinite(self):
|
||||||
if not __debug__:
|
if not __debug__:
|
||||||
@ -800,8 +772,7 @@ class WhileLoopTest(test.TestCase):
|
|||||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
class IfStmtTest(testing.AutoGraphTestCase):
|
||||||
class IfStmtTest(test.TestCase):
|
|
||||||
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
|
|
||||||
@ -829,8 +800,8 @@ class IfStmtTest(test.TestCase):
|
|||||||
nouts=1)
|
nouts=1)
|
||||||
return i
|
return i
|
||||||
|
|
||||||
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
|
self.assertEqual(test_fn(constant_op.constant(True)), 1)
|
||||||
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
|
self.assertEqual(test_fn(constant_op.constant(False)), -1)
|
||||||
|
|
||||||
def test_tensor_no_outputs(self):
|
def test_tensor_no_outputs(self):
|
||||||
|
|
||||||
@ -858,8 +829,8 @@ class IfStmtTest(test.TestCase):
|
|||||||
nouts=0)
|
nouts=0)
|
||||||
return i
|
return i
|
||||||
|
|
||||||
self.assertEqual(None, test_fn(constant_op.constant(True)))
|
self.assertIsNone(test_fn(constant_op.constant(True)))
|
||||||
self.assertEqual(None, test_fn(constant_op.constant(False)))
|
self.assertIsNone(test_fn(constant_op.constant(False)))
|
||||||
|
|
||||||
def test_tensor_multiple_returns(self):
|
def test_tensor_multiple_returns(self):
|
||||||
|
|
||||||
@ -889,9 +860,8 @@ class IfStmtTest(test.TestCase):
|
|||||||
nouts=2)
|
nouts=2)
|
||||||
return i, j
|
return i, j
|
||||||
|
|
||||||
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
|
self.assertEqual(test_fn(constant_op.constant(True)), (1, 2))
|
||||||
self.assertEqual((-1, -2),
|
self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2))
|
||||||
self.evaluate(test_fn(constant_op.constant(False))))
|
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
|
|
||||||
@ -915,8 +885,8 @@ class IfStmtTest(test.TestCase):
|
|||||||
nouts=1)
|
nouts=1)
|
||||||
return i
|
return i
|
||||||
|
|
||||||
self.assertEqual(1, test_fn(True))
|
self.assertEqual(test_fn(True), 1)
|
||||||
self.assertEqual(-1, test_fn(False))
|
self.assertEqual(test_fn(False), -1)
|
||||||
|
|
||||||
def test_python_multiple_returns(self):
|
def test_python_multiple_returns(self):
|
||||||
|
|
||||||
@ -942,8 +912,8 @@ class IfStmtTest(test.TestCase):
|
|||||||
nouts=2)
|
nouts=2)
|
||||||
return i, j
|
return i, j
|
||||||
|
|
||||||
self.assertEqual((1, 2), test_fn(True))
|
self.assertEqual(test_fn(True), (1, 2))
|
||||||
self.assertEqual((-1, -2), test_fn(False))
|
self.assertEqual(test_fn(False), (-1, -2))
|
||||||
|
|
||||||
def _basic_cond(self, body_fn, else_fn):
|
def _basic_cond(self, body_fn, else_fn):
|
||||||
def body():
|
def body():
|
||||||
@ -959,16 +929,14 @@ class IfStmtTest(test.TestCase):
|
|||||||
x, = cond_vars
|
x, = cond_vars
|
||||||
|
|
||||||
x = 0
|
x = 0
|
||||||
# Eager cond had different semantics, we don't test those here.
|
control_flow.if_stmt(
|
||||||
with func_graph.FuncGraph('tmp').as_default():
|
cond=constant_op.constant(True),
|
||||||
control_flow.if_stmt(
|
body=body,
|
||||||
cond=constant_op.constant(True),
|
orelse=orelse,
|
||||||
body=body,
|
get_state=lambda: (x,),
|
||||||
orelse=orelse,
|
set_state=set_state,
|
||||||
get_state=lambda: (x,),
|
symbol_names=('x',),
|
||||||
set_state=set_state,
|
nouts=1)
|
||||||
symbol_names=('x',),
|
|
||||||
nouts=1)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def test_tensor_none_output(self):
|
def test_tensor_none_output(self):
|
||||||
|
@ -22,4 +22,3 @@ from tensorflow.python.autograph.utils.context_managers import control_dependenc
|
|||||||
from tensorflow.python.autograph.utils.misc import alias_tensors
|
from tensorflow.python.autograph.utils.misc import alias_tensors
|
||||||
from tensorflow.python.autograph.utils.py_func import wrap_py_func
|
from tensorflow.python.autograph.utils.py_func import wrap_py_func
|
||||||
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
|
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
|
||||||
from tensorflow.python.autograph.utils.testing import fake_tf
|
|
||||||
|
@ -18,20 +18,82 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import imp
|
import types
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
def fake_tf():
|
class AutoGraphTestCase(test.TestCase):
|
||||||
"""Creates a fake module that looks like TensorFlow, for testing."""
|
"""Tests specialized for AutoGraph, which run as tf.functions.
|
||||||
mod = imp.new_module('tensorflow')
|
|
||||||
mod_contents = {}
|
These tests use a staged programming-like approach: most of the test code runs
|
||||||
mod_contents.update(gen_math_ops.__dict__)
|
as-is inside a tf.function, but the assertions are lifted outside the
|
||||||
mod_contents.update(math_ops.__dict__)
|
function, and run with the corresponding function values instead.
|
||||||
mod_contents.update(ops.__dict__)
|
|
||||||
mod_contents.update(mod.__dict__)
|
For example, the test:
|
||||||
mod.__dict__.update(mod_contents)
|
|
||||||
return mod
|
def test_foo(self):
|
||||||
|
baz = bar();
|
||||||
|
self.assertEqual(baz, value)
|
||||||
|
|
||||||
|
is equivalent to writing:
|
||||||
|
|
||||||
|
def test_foo(self):
|
||||||
|
@tf.function
|
||||||
|
def test_fn():
|
||||||
|
baz = bar();
|
||||||
|
return baz, value
|
||||||
|
|
||||||
|
baz_actual, value_actual = test_fn()
|
||||||
|
self.assertEqual(baz_actual, value_actual)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *args):
|
||||||
|
obj = super().__new__(cls)
|
||||||
|
|
||||||
|
for name in cls.__dict__:
|
||||||
|
if not name.startswith(unittest.TestLoader.testMethodPrefix):
|
||||||
|
continue
|
||||||
|
m = getattr(obj, name)
|
||||||
|
if callable(m):
|
||||||
|
wrapper = obj._run_as_tf_function(m)
|
||||||
|
setattr(obj, name, types.MethodType(wrapper, obj))
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def _run_as_tf_function(self, fn):
|
||||||
|
|
||||||
|
def wrapper(self):
|
||||||
|
@def_function.function(autograph=False) # Testing autograph itself.
|
||||||
|
def fn_wrapper():
|
||||||
|
self.assertions = []
|
||||||
|
fn()
|
||||||
|
targets = [args for _, args in self.assertions]
|
||||||
|
return targets
|
||||||
|
actuals = self.evaluate(fn_wrapper())
|
||||||
|
for (_, args), value in zip(self.assertions, actuals):
|
||||||
|
args[:] = value
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def variable(self, name, value, dtype):
|
||||||
|
with ops.init_scope():
|
||||||
|
if name not in self.variables:
|
||||||
|
self.variables[name] = variables.Variable(value, dtype=dtype)
|
||||||
|
self.evaluate(self.variables[name].initializer)
|
||||||
|
return self.variables[name]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.variables = {}
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
for fn, args in self.assertions:
|
||||||
|
fn(*args)
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
|
def assertEqual(self, *args):
|
||||||
|
self.assertions.append((super().assertEqual, list(args)))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user