Merge pull request #40493 from lgeiger:refactor-autocast-variable-tests
PiperOrigin-RevId: 317024448 Change-Id: Ie3d376a6922f77682d94835e0d9ca6f74331f442
This commit is contained in:
commit
56c01dc970
@ -144,9 +144,10 @@ py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python/distribute:combinations",
|
||||||
"//tensorflow/python/distribute:mirrored_strategy",
|
"//tensorflow/python/distribute:mirrored_strategy",
|
||||||
|
"//tensorflow/python/distribute:strategy_combinations",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/keras:combinations",
|
|
||||||
"//tensorflow/python/keras/optimizer_v2",
|
"//tensorflow/python/keras/optimizer_v2",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
|
@ -23,14 +23,16 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
|
from tensorflow.python.distribute import combinations
|
||||||
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
from tensorflow.python.distribute import mirrored_strategy
|
from tensorflow.python.distribute import mirrored_strategy
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import combinations
|
|
||||||
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
|
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
|
||||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
|
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -40,29 +42,10 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.training import gradient_descent as gradient_descent_v1
|
from tensorflow.python.training import gradient_descent as gradient_descent_v1
|
||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
|
|
||||||
TESTCASES = ({
|
maybe_distribute = combinations.combine(distribution=[
|
||||||
'testcase_name': 'base',
|
strategy_combinations.default_strategy,
|
||||||
'distribute': False
|
strategy_combinations.mirrored_strategy_with_cpu_1_and_2
|
||||||
}, {
|
])
|
||||||
'testcase_name': 'distribute',
|
|
||||||
'distribute': True
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def get_distribute_scope(distribute):
|
|
||||||
|
|
||||||
class DummyContextManager(object):
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if distribute:
|
|
||||||
return mirrored_strategy.MirroredStrategy(['cpu:0']).scope()
|
|
||||||
else:
|
|
||||||
return DummyContextManager()
|
|
||||||
|
|
||||||
|
|
||||||
def get_var(val, dtype, name=None):
|
def get_var(val, dtype, name=None):
|
||||||
@ -72,9 +55,13 @@ def get_var(val, dtype, name=None):
|
|||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
def setUp(self):
|
||||||
def test_read(self, distribute):
|
strategy_combinations.set_virtual_cpus_to_at_least(3)
|
||||||
with get_distribute_scope(distribute):
|
super(AutoCastVariableTest, self).setUp()
|
||||||
|
|
||||||
|
@combinations.generate(maybe_distribute)
|
||||||
|
def test_read(self, distribution):
|
||||||
|
with distribution.scope():
|
||||||
x = get_var(1., dtypes.float32)
|
x = get_var(1., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.evaluate(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
@ -116,9 +103,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
|
self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
|
||||||
self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
|
self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_read_nested_scopes(self, distribute):
|
def test_read_nested_scopes(self, distribution):
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
x = get_var(1., dtypes.float32)
|
x = get_var(1., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.evaluate(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
@ -136,9 +123,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(x.dtype, dtypes.float16)
|
self.assertEqual(x.dtype, dtypes.float16)
|
||||||
self.assertEqual(x.read_value().dtype, dtypes.float16)
|
self.assertEqual(x.read_value().dtype, dtypes.float16)
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_dtype_is_not_string(self, distribute):
|
def test_dtype_is_not_string(self, distribution):
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
x = get_var(1., dtypes.float32)
|
x = get_var(1., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.assertEqual(x.dtype, dtypes.float32)
|
self.assertEqual(x.dtype, dtypes.float32)
|
||||||
@ -153,13 +140,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(x.true_dtype, dtypes.float32)
|
self.assertEqual(x.true_dtype, dtypes.float32)
|
||||||
self.assertIsInstance(x.true_dtype, dtypes.DType)
|
self.assertIsInstance(x.true_dtype, dtypes.DType)
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_method_delegations(self, distribute):
|
def test_method_delegations(self, distribution):
|
||||||
# Test AutoCastVariable correctly delegates Variable methods to the
|
# Test AutoCastVariable correctly delegates Variable methods to the
|
||||||
# underlying variable.
|
# underlying variable.
|
||||||
with self.test_session(), get_distribute_scope(distribute):
|
with self.test_session(), distribution.scope():
|
||||||
for read_dtype in (dtypes.float32, dtypes.float16):
|
for read_dtype in (dtypes.float32, dtypes.float16):
|
||||||
if distribute:
|
if ds_context.has_strategy():
|
||||||
# MirroredVariable.assign will (incorrectly) return a Mirrored value
|
# MirroredVariable.assign will (incorrectly) return a Mirrored value
|
||||||
# instead of a MirroredVariable. So we cannot properly wrap it in an
|
# instead of a MirroredVariable. So we cannot properly wrap it in an
|
||||||
# AutoCastVariable.
|
# AutoCastVariable.
|
||||||
@ -183,14 +170,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(x.aggregation, x._variable.aggregation)
|
self.assertEqual(x.aggregation, x._variable.aggregation)
|
||||||
self.assertEqual(self.evaluate(x.initialized_value()), 7)
|
self.assertEqual(self.evaluate(x.initialized_value()), 7)
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
if not distribute:
|
if not ds_context.has_strategy():
|
||||||
# These functions are not supported for DistributedVariables
|
# These functions are not supported for DistributedVariables
|
||||||
x.load(9)
|
x.load(9)
|
||||||
self.assertEqual(x.eval(), 9)
|
self.assertEqual(x.eval(), 9)
|
||||||
self.assertEqual(self.evaluate(x.initial_value), 7)
|
self.assertEqual(self.evaluate(x.initial_value), 7)
|
||||||
self.assertEqual(x.op, x._variable.op)
|
self.assertEqual(x.op, x._variable.op)
|
||||||
self.assertEqual(x.graph, x._variable.graph)
|
self.assertEqual(x.graph, x._variable.graph)
|
||||||
if not distribute:
|
if not ds_context.has_strategy():
|
||||||
# These attributes are not supported for DistributedVariables
|
# These attributes are not supported for DistributedVariables
|
||||||
self.assertIsNone(x.constraint)
|
self.assertIsNone(x.constraint)
|
||||||
self.assertEqual(x.initializer, x._variable.initializer)
|
self.assertEqual(x.initializer, x._variable.initializer)
|
||||||
@ -202,7 +189,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(x.shape, ())
|
self.assertEqual(x.shape, ())
|
||||||
self.assertEqual(x.get_shape(), ())
|
self.assertEqual(x.get_shape(), ())
|
||||||
|
|
||||||
if not distribute:
|
if not ds_context.has_strategy():
|
||||||
# Test scatter_* methods. These are not supported for
|
# Test scatter_* methods. These are not supported for
|
||||||
# DistributedVariables
|
# DistributedVariables
|
||||||
x = get_var([7, 8], dtypes.float32)
|
x = get_var([7, 8], dtypes.float32)
|
||||||
@ -233,9 +220,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
|
evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_operator_overloads(self, distribute):
|
def test_operator_overloads(self, distribution):
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
for read_dtype in (dtypes.float32, dtypes.float16):
|
for read_dtype in (dtypes.float32, dtypes.float16):
|
||||||
x = get_var(7., dtypes.float32)
|
x = get_var(7., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
@ -280,9 +267,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
|
self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
|
||||||
self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
|
self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_assign(self, distribute):
|
def test_assign(self, distribution):
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
x = get_var(0., dtypes.float32)
|
x = get_var(0., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.evaluate(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
@ -318,18 +305,20 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))
|
self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))
|
||||||
|
|
||||||
# Assign multiple times
|
# Assign multiple times
|
||||||
assign = x.assign(1.)
|
# This currently only works if no strategy is used
|
||||||
self.assertAllClose(1., self.evaluate(assign))
|
if not ds_context.has_strategy():
|
||||||
self.assertAllClose(0., self.evaluate(assign.assign(0.)))
|
assign = x.assign(1.)
|
||||||
assign_add = x.assign_add(3.)
|
self.assertAllClose(1., self.evaluate(assign))
|
||||||
self.assertAllClose(3., self.evaluate(assign_add))
|
self.assertAllClose(0., self.evaluate(assign.assign(0.)))
|
||||||
self.assertAllClose(3. * 3,
|
assign_add = x.assign_add(3.)
|
||||||
self.evaluate(x.assign_add(3.).assign_add(3.)))
|
self.assertAllClose(3., self.evaluate(assign_add))
|
||||||
self.assertAllClose(3. * 3, x)
|
self.assertAllClose(3. * 3,
|
||||||
assign_sub = x.assign_sub(3.)
|
self.evaluate(x.assign_add(3.).assign_add(3.)))
|
||||||
self.assertAllClose(3. * 2, self.evaluate(assign_sub))
|
self.assertAllClose(3. * 3, x)
|
||||||
self.assertAllClose(0.,
|
assign_sub = x.assign_sub(3.)
|
||||||
self.evaluate(x.assign_sub(3.).assign_sub(3.)))
|
self.assertAllClose(3. * 2, self.evaluate(assign_sub))
|
||||||
|
self.assertAllClose(0.,
|
||||||
|
self.evaluate(x.assign_sub(3.).assign_sub(3.)))
|
||||||
|
|
||||||
# Assign with read_value=False
|
# Assign with read_value=False
|
||||||
self.assertIsNone(self.evaluate(x.assign(1., read_value=False)))
|
self.assertIsNone(self.evaluate(x.assign(1., read_value=False)))
|
||||||
@ -355,9 +344,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
# assign still expect float32 value even if in float16 scope
|
# assign still expect float32 value even if in float16 scope
|
||||||
run_and_check()
|
run_and_check()
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_assign_stays_in_true_dtype(self, distribute):
|
def test_assign_stays_in_true_dtype(self, distribution):
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
x = get_var(1., dtypes.float32)
|
x = get_var(1., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.evaluate(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
@ -382,10 +371,10 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(1., self.evaluate(x.value()))
|
self.assertEqual(1., self.evaluate(x.value()))
|
||||||
self.assertEqual(1. + small_val, self.evaluate(x.value()))
|
self.assertEqual(1. + small_val, self.evaluate(x.value()))
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_checkpoint(self, distribute):
|
def test_checkpoint(self, distribution):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
x = get_var(1., dtypes.float32)
|
x = get_var(1., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.evaluate(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
@ -398,9 +387,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
|
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
|
||||||
self.assertEqual(self.evaluate(x), 123.)
|
self.assertEqual(self.evaluate(x), 123.)
|
||||||
|
|
||||||
@parameterized.named_parameters(*TESTCASES)
|
@combinations.generate(maybe_distribute)
|
||||||
def test_invalid_wrapped_variable(self, distribute):
|
def test_invalid_wrapped_variable(self, distribution):
|
||||||
with get_distribute_scope(distribute):
|
with distribution.scope():
|
||||||
# Wrap a non-variable
|
# Wrap a non-variable
|
||||||
with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
|
with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
|
||||||
x = constant_op.constant([1.], dtype=dtypes.float32)
|
x = constant_op.constant([1.], dtype=dtypes.float32)
|
||||||
@ -443,7 +432,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_repr_distributed(self):
|
def test_repr_distributed(self):
|
||||||
with get_distribute_scope(distribute=True):
|
with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
|
||||||
x = get_var(1., dtypes.float32)
|
x = get_var(1., dtypes.float32)
|
||||||
x = autocast_variable.create_autocast_variable(x)
|
x = autocast_variable.create_autocast_variable(x)
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user