Merge pull request from lgeiger:refactor-autocast-variable-tests

PiperOrigin-RevId: 317024448
Change-Id: Ie3d376a6922f77682d94835e0d9ca6f74331f442
This commit is contained in:
TensorFlower Gardener 2020-06-17 20:51:41 -07:00
commit 56c01dc970
2 changed files with 59 additions and 69 deletions
tensorflow/python/keras/mixed_precision/experimental

View File

@ -144,9 +144,10 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras:combinations",
"//tensorflow/python/keras/optimizer_v2",
"@absl_py//absl/testing:parameterized",
],

View File

@ -23,14 +23,16 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.keras import combinations
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.ops import array_ops
@ -40,29 +42,10 @@ from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent as gradient_descent_v1
from tensorflow.python.training.tracking import util as trackable_utils
TESTCASES = ({
'testcase_name': 'base',
'distribute': False
}, {
'testcase_name': 'distribute',
'distribute': True
})
def get_distribute_scope(distribute):
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
if distribute:
return mirrored_strategy.MirroredStrategy(['cpu:0']).scope()
else:
return DummyContextManager()
maybe_distribute = combinations.combine(distribution=[
strategy_combinations.default_strategy,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2
])
def get_var(val, dtype, name=None):
@ -72,9 +55,13 @@ def get_var(val, dtype, name=None):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(*TESTCASES)
def test_read(self, distribute):
with get_distribute_scope(distribute):
def setUp(self):
strategy_combinations.set_virtual_cpus_to_at_least(3)
super(AutoCastVariableTest, self).setUp()
@combinations.generate(maybe_distribute)
def test_read(self, distribution):
with distribution.scope():
x = get_var(1., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.evaluate(x.initializer)
@ -116,9 +103,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
@parameterized.named_parameters(*TESTCASES)
def test_read_nested_scopes(self, distribute):
with get_distribute_scope(distribute):
@combinations.generate(maybe_distribute)
def test_read_nested_scopes(self, distribution):
with distribution.scope():
x = get_var(1., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.evaluate(x.initializer)
@ -136,9 +123,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16)
@parameterized.named_parameters(*TESTCASES)
def test_dtype_is_not_string(self, distribute):
with get_distribute_scope(distribute):
@combinations.generate(maybe_distribute)
def test_dtype_is_not_string(self, distribution):
with distribution.scope():
x = get_var(1., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.assertEqual(x.dtype, dtypes.float32)
@ -153,13 +140,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.true_dtype, dtypes.float32)
self.assertIsInstance(x.true_dtype, dtypes.DType)
@parameterized.named_parameters(*TESTCASES)
def test_method_delegations(self, distribute):
@combinations.generate(maybe_distribute)
def test_method_delegations(self, distribution):
# Test AutoCastVariable correctly delegates Variable methods to the
# underlying variable.
with self.test_session(), get_distribute_scope(distribute):
with self.test_session(), distribution.scope():
for read_dtype in (dtypes.float32, dtypes.float16):
if distribute:
if ds_context.has_strategy():
# MirroredVariable.assign will (incorrectly) return a Mirrored value
# instead of a MirroredVariable. So we cannot properly wrap it in an
# AutoCastVariable.
@ -183,14 +170,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.aggregation, x._variable.aggregation)
self.assertEqual(self.evaluate(x.initialized_value()), 7)
if not context.executing_eagerly():
if not distribute:
if not ds_context.has_strategy():
# These functions are not supported for DistributedVariables
x.load(9)
self.assertEqual(x.eval(), 9)
self.assertEqual(self.evaluate(x.initial_value), 7)
self.assertEqual(x.op, x._variable.op)
self.assertEqual(x.graph, x._variable.graph)
if not distribute:
if not ds_context.has_strategy():
# These attributes are not supported for DistributedVariables
self.assertIsNone(x.constraint)
self.assertEqual(x.initializer, x._variable.initializer)
@ -202,7 +189,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.shape, ())
self.assertEqual(x.get_shape(), ())
if not distribute:
if not ds_context.has_strategy():
# Test scatter_* methods. These are not supported for
# DistributedVariables
x = get_var([7, 8], dtypes.float32)
@ -233,9 +220,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(
evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
@parameterized.named_parameters(*TESTCASES)
def test_operator_overloads(self, distribute):
with get_distribute_scope(distribute):
@combinations.generate(maybe_distribute)
def test_operator_overloads(self, distribution):
with distribution.scope():
for read_dtype in (dtypes.float32, dtypes.float16):
x = get_var(7., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
@ -280,9 +267,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
@parameterized.named_parameters(*TESTCASES)
def test_assign(self, distribute):
with get_distribute_scope(distribute):
@combinations.generate(maybe_distribute)
def test_assign(self, distribution):
with distribution.scope():
x = get_var(0., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.evaluate(x.initializer)
@ -318,18 +305,20 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))
# Assign multiple times
assign = x.assign(1.)
self.assertAllClose(1., self.evaluate(assign))
self.assertAllClose(0., self.evaluate(assign.assign(0.)))
assign_add = x.assign_add(3.)
self.assertAllClose(3., self.evaluate(assign_add))
self.assertAllClose(3. * 3,
self.evaluate(x.assign_add(3.).assign_add(3.)))
self.assertAllClose(3. * 3, x)
assign_sub = x.assign_sub(3.)
self.assertAllClose(3. * 2, self.evaluate(assign_sub))
self.assertAllClose(0.,
self.evaluate(x.assign_sub(3.).assign_sub(3.)))
# This currently only works if no strategy is used
if not ds_context.has_strategy():
assign = x.assign(1.)
self.assertAllClose(1., self.evaluate(assign))
self.assertAllClose(0., self.evaluate(assign.assign(0.)))
assign_add = x.assign_add(3.)
self.assertAllClose(3., self.evaluate(assign_add))
self.assertAllClose(3. * 3,
self.evaluate(x.assign_add(3.).assign_add(3.)))
self.assertAllClose(3. * 3, x)
assign_sub = x.assign_sub(3.)
self.assertAllClose(3. * 2, self.evaluate(assign_sub))
self.assertAllClose(0.,
self.evaluate(x.assign_sub(3.).assign_sub(3.)))
# Assign with read_value=False
self.assertIsNone(self.evaluate(x.assign(1., read_value=False)))
@ -355,9 +344,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
# assign still expect float32 value even if in float16 scope
run_and_check()
@parameterized.named_parameters(*TESTCASES)
def test_assign_stays_in_true_dtype(self, distribute):
with get_distribute_scope(distribute):
@combinations.generate(maybe_distribute)
def test_assign_stays_in_true_dtype(self, distribution):
with distribution.scope():
x = get_var(1., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.evaluate(x.initializer)
@ -382,10 +371,10 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1., self.evaluate(x.value()))
self.assertEqual(1. + small_val, self.evaluate(x.value()))
@parameterized.named_parameters(*TESTCASES)
def test_checkpoint(self, distribute):
@combinations.generate(maybe_distribute)
def test_checkpoint(self, distribution):
with self.test_session():
with get_distribute_scope(distribute):
with distribution.scope():
x = get_var(1., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.evaluate(x.initializer)
@ -398,9 +387,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
self.assertEqual(self.evaluate(x), 123.)
@parameterized.named_parameters(*TESTCASES)
def test_invalid_wrapped_variable(self, distribute):
with get_distribute_scope(distribute):
@combinations.generate(maybe_distribute)
def test_invalid_wrapped_variable(self, distribution):
with distribution.scope():
# Wrap a non-variable
with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
x = constant_op.constant([1.], dtype=dtypes.float32)
@ -443,7 +432,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
)
def test_repr_distributed(self):
with get_distribute_scope(distribute=True):
with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
x = get_var(1., dtypes.float32)
x = autocast_variable.create_autocast_variable(x)
self.assertRegexpMatches(