Merge pull request #40493 from lgeiger:refactor-autocast-variable-tests
PiperOrigin-RevId: 317024448 Change-Id: Ie3d376a6922f77682d94835e0d9ca6f74331f442
This commit is contained in:
commit
56c01dc970
tensorflow/python/keras/mixed_precision/experimental
@ -144,9 +144,10 @@ py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
@ -23,14 +23,16 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -40,29 +42,10 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import gradient_descent as gradient_descent_v1
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
TESTCASES = ({
|
||||
'testcase_name': 'base',
|
||||
'distribute': False
|
||||
}, {
|
||||
'testcase_name': 'distribute',
|
||||
'distribute': True
|
||||
})
|
||||
|
||||
|
||||
def get_distribute_scope(distribute):
|
||||
|
||||
class DummyContextManager(object):
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
if distribute:
|
||||
return mirrored_strategy.MirroredStrategy(['cpu:0']).scope()
|
||||
else:
|
||||
return DummyContextManager()
|
||||
maybe_distribute = combinations.combine(distribution=[
|
||||
strategy_combinations.default_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_cpu_1_and_2
|
||||
])
|
||||
|
||||
|
||||
def get_var(val, dtype, name=None):
|
||||
@ -72,9 +55,13 @@ def get_var(val, dtype, name=None):
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_read(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
def setUp(self):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(3)
|
||||
super(AutoCastVariableTest, self).setUp()
|
||||
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_read(self, distribution):
|
||||
with distribution.scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.evaluate(x.initializer)
|
||||
@ -116,9 +103,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
|
||||
self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_read_nested_scopes(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_read_nested_scopes(self, distribution):
|
||||
with distribution.scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.evaluate(x.initializer)
|
||||
@ -136,9 +123,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(x.dtype, dtypes.float16)
|
||||
self.assertEqual(x.read_value().dtype, dtypes.float16)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_dtype_is_not_string(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_dtype_is_not_string(self, distribution):
|
||||
with distribution.scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.assertEqual(x.dtype, dtypes.float32)
|
||||
@ -153,13 +140,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(x.true_dtype, dtypes.float32)
|
||||
self.assertIsInstance(x.true_dtype, dtypes.DType)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_method_delegations(self, distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_method_delegations(self, distribution):
|
||||
# Test AutoCastVariable correctly delegates Variable methods to the
|
||||
# underlying variable.
|
||||
with self.test_session(), get_distribute_scope(distribute):
|
||||
with self.test_session(), distribution.scope():
|
||||
for read_dtype in (dtypes.float32, dtypes.float16):
|
||||
if distribute:
|
||||
if ds_context.has_strategy():
|
||||
# MirroredVariable.assign will (incorrectly) return a Mirrored value
|
||||
# instead of a MirroredVariable. So we cannot properly wrap it in an
|
||||
# AutoCastVariable.
|
||||
@ -183,14 +170,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(x.aggregation, x._variable.aggregation)
|
||||
self.assertEqual(self.evaluate(x.initialized_value()), 7)
|
||||
if not context.executing_eagerly():
|
||||
if not distribute:
|
||||
if not ds_context.has_strategy():
|
||||
# These functions are not supported for DistributedVariables
|
||||
x.load(9)
|
||||
self.assertEqual(x.eval(), 9)
|
||||
self.assertEqual(self.evaluate(x.initial_value), 7)
|
||||
self.assertEqual(x.op, x._variable.op)
|
||||
self.assertEqual(x.graph, x._variable.graph)
|
||||
if not distribute:
|
||||
if not ds_context.has_strategy():
|
||||
# These attributes are not supported for DistributedVariables
|
||||
self.assertIsNone(x.constraint)
|
||||
self.assertEqual(x.initializer, x._variable.initializer)
|
||||
@ -202,7 +189,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.get_shape(), ())
|
||||
|
||||
if not distribute:
|
||||
if not ds_context.has_strategy():
|
||||
# Test scatter_* methods. These are not supported for
|
||||
# DistributedVariables
|
||||
x = get_var([7, 8], dtypes.float32)
|
||||
@ -233,9 +220,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(
|
||||
evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_operator_overloads(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_operator_overloads(self, distribution):
|
||||
with distribution.scope():
|
||||
for read_dtype in (dtypes.float32, dtypes.float16):
|
||||
x = get_var(7., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
@ -280,9 +267,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
|
||||
self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_assign(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_assign(self, distribution):
|
||||
with distribution.scope():
|
||||
x = get_var(0., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.evaluate(x.initializer)
|
||||
@ -318,18 +305,20 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))
|
||||
|
||||
# Assign multiple times
|
||||
assign = x.assign(1.)
|
||||
self.assertAllClose(1., self.evaluate(assign))
|
||||
self.assertAllClose(0., self.evaluate(assign.assign(0.)))
|
||||
assign_add = x.assign_add(3.)
|
||||
self.assertAllClose(3., self.evaluate(assign_add))
|
||||
self.assertAllClose(3. * 3,
|
||||
self.evaluate(x.assign_add(3.).assign_add(3.)))
|
||||
self.assertAllClose(3. * 3, x)
|
||||
assign_sub = x.assign_sub(3.)
|
||||
self.assertAllClose(3. * 2, self.evaluate(assign_sub))
|
||||
self.assertAllClose(0.,
|
||||
self.evaluate(x.assign_sub(3.).assign_sub(3.)))
|
||||
# This currently only works if no strategy is used
|
||||
if not ds_context.has_strategy():
|
||||
assign = x.assign(1.)
|
||||
self.assertAllClose(1., self.evaluate(assign))
|
||||
self.assertAllClose(0., self.evaluate(assign.assign(0.)))
|
||||
assign_add = x.assign_add(3.)
|
||||
self.assertAllClose(3., self.evaluate(assign_add))
|
||||
self.assertAllClose(3. * 3,
|
||||
self.evaluate(x.assign_add(3.).assign_add(3.)))
|
||||
self.assertAllClose(3. * 3, x)
|
||||
assign_sub = x.assign_sub(3.)
|
||||
self.assertAllClose(3. * 2, self.evaluate(assign_sub))
|
||||
self.assertAllClose(0.,
|
||||
self.evaluate(x.assign_sub(3.).assign_sub(3.)))
|
||||
|
||||
# Assign with read_value=False
|
||||
self.assertIsNone(self.evaluate(x.assign(1., read_value=False)))
|
||||
@ -355,9 +344,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
# assign still expect float32 value even if in float16 scope
|
||||
run_and_check()
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_assign_stays_in_true_dtype(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_assign_stays_in_true_dtype(self, distribution):
|
||||
with distribution.scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.evaluate(x.initializer)
|
||||
@ -382,10 +371,10 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(1., self.evaluate(x.value()))
|
||||
self.assertEqual(1. + small_val, self.evaluate(x.value()))
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_checkpoint(self, distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_checkpoint(self, distribution):
|
||||
with self.test_session():
|
||||
with get_distribute_scope(distribute):
|
||||
with distribution.scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.evaluate(x.initializer)
|
||||
@ -398,9 +387,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
|
||||
self.assertEqual(self.evaluate(x), 123.)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_invalid_wrapped_variable(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
@combinations.generate(maybe_distribute)
|
||||
def test_invalid_wrapped_variable(self, distribution):
|
||||
with distribution.scope():
|
||||
# Wrap a non-variable
|
||||
with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
|
||||
x = constant_op.constant([1.], dtype=dtypes.float32)
|
||||
@ -443,7 +432,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
)
|
||||
|
||||
def test_repr_distributed(self):
|
||||
with get_distribute_scope(distribute=True):
|
||||
with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.assertRegexpMatches(
|
||||
|
Loading…
Reference in New Issue
Block a user