Merge pull request #40493 from lgeiger:refactor-autocast-variable-tests
PiperOrigin-RevId: 317024448 Change-Id: Ie3d376a6922f77682d94835e0d9ca6f74331f442
This commit is contained in:
		
						commit
						56c01dc970
					
				| @ -144,9 +144,10 @@ py_test( | ||||
|         "//tensorflow/python:client_testlib", | ||||
|         "//tensorflow/python:framework", | ||||
|         "//tensorflow/python:platform_test", | ||||
|         "//tensorflow/python/distribute:combinations", | ||||
|         "//tensorflow/python/distribute:mirrored_strategy", | ||||
|         "//tensorflow/python/distribute:strategy_combinations", | ||||
|         "//tensorflow/python/eager:context", | ||||
|         "//tensorflow/python/keras:combinations", | ||||
|         "//tensorflow/python/keras/optimizer_v2", | ||||
|         "@absl_py//absl/testing:parameterized", | ||||
|     ], | ||||
|  | ||||
| @ -23,14 +23,16 @@ from absl.testing import parameterized | ||||
| import numpy as np | ||||
| 
 | ||||
| from tensorflow.python import tf2 | ||||
| from tensorflow.python.distribute import combinations | ||||
| from tensorflow.python.distribute import distribution_strategy_context as ds_context | ||||
| from tensorflow.python.distribute import mirrored_strategy | ||||
| from tensorflow.python.distribute import strategy_combinations | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import def_function | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import dtypes | ||||
| from tensorflow.python.framework import indexed_slices | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.keras import combinations | ||||
| from tensorflow.python.keras.mixed_precision.experimental import autocast_variable | ||||
| from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 | ||||
| from tensorflow.python.ops import array_ops | ||||
| @ -40,29 +42,10 @@ from tensorflow.python.platform import test | ||||
| from tensorflow.python.training import gradient_descent as gradient_descent_v1 | ||||
| from tensorflow.python.training.tracking import util as trackable_utils | ||||
| 
 | ||||
| TESTCASES = ({ | ||||
|     'testcase_name': 'base', | ||||
|     'distribute': False | ||||
| }, { | ||||
|     'testcase_name': 'distribute', | ||||
|     'distribute': True | ||||
| }) | ||||
| 
 | ||||
| 
 | ||||
| def get_distribute_scope(distribute): | ||||
| 
 | ||||
|   class DummyContextManager(object): | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|       pass | ||||
| 
 | ||||
|     def __exit__(self, *args): | ||||
|       pass | ||||
| 
 | ||||
|   if distribute: | ||||
|     return mirrored_strategy.MirroredStrategy(['cpu:0']).scope() | ||||
|   else: | ||||
|     return DummyContextManager() | ||||
| maybe_distribute = combinations.combine(distribution=[ | ||||
|     strategy_combinations.default_strategy, | ||||
|     strategy_combinations.mirrored_strategy_with_cpu_1_and_2 | ||||
| ]) | ||||
| 
 | ||||
| 
 | ||||
| def get_var(val, dtype, name=None): | ||||
| @ -72,9 +55,13 @@ def get_var(val, dtype, name=None): | ||||
| @combinations.generate(combinations.combine(mode=['graph', 'eager'])) | ||||
| class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_read(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   def setUp(self): | ||||
|     strategy_combinations.set_virtual_cpus_to_at_least(3) | ||||
|     super(AutoCastVariableTest, self).setUp() | ||||
| 
 | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_read(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       x = get_var(1., dtypes.float32) | ||||
|       x = autocast_variable.create_autocast_variable(x) | ||||
|       self.evaluate(x.initializer) | ||||
| @ -116,9 +103,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|       self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16) | ||||
|       self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_read_nested_scopes(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_read_nested_scopes(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       x = get_var(1., dtypes.float32) | ||||
|       x = autocast_variable.create_autocast_variable(x) | ||||
|       self.evaluate(x.initializer) | ||||
| @ -136,9 +123,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|         self.assertEqual(x.dtype, dtypes.float16) | ||||
|         self.assertEqual(x.read_value().dtype, dtypes.float16) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_dtype_is_not_string(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_dtype_is_not_string(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       x = get_var(1., dtypes.float32) | ||||
|       x = autocast_variable.create_autocast_variable(x) | ||||
|       self.assertEqual(x.dtype, dtypes.float32) | ||||
| @ -153,13 +140,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|         self.assertEqual(x.true_dtype, dtypes.float32) | ||||
|         self.assertIsInstance(x.true_dtype, dtypes.DType) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_method_delegations(self, distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_method_delegations(self, distribution): | ||||
|     # Test AutoCastVariable correctly delegates Variable methods to the | ||||
|     # underlying variable. | ||||
|     with self.test_session(), get_distribute_scope(distribute): | ||||
|     with self.test_session(), distribution.scope(): | ||||
|       for read_dtype in (dtypes.float32, dtypes.float16): | ||||
|         if distribute: | ||||
|         if ds_context.has_strategy(): | ||||
|           # MirroredVariable.assign will (incorrectly) return a Mirrored value | ||||
|           # instead of a MirroredVariable. So we cannot properly wrap it in an | ||||
|           # AutoCastVariable. | ||||
| @ -183,14 +170,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|           self.assertEqual(x.aggregation, x._variable.aggregation) | ||||
|           self.assertEqual(self.evaluate(x.initialized_value()), 7) | ||||
|           if not context.executing_eagerly(): | ||||
|             if not distribute: | ||||
|             if not ds_context.has_strategy(): | ||||
|               # These functions are not supported for DistributedVariables | ||||
|               x.load(9) | ||||
|               self.assertEqual(x.eval(), 9) | ||||
|             self.assertEqual(self.evaluate(x.initial_value), 7) | ||||
|             self.assertEqual(x.op, x._variable.op) | ||||
|             self.assertEqual(x.graph, x._variable.graph) | ||||
|           if not distribute: | ||||
|           if not ds_context.has_strategy(): | ||||
|             # These attributes are not supported for DistributedVariables | ||||
|             self.assertIsNone(x.constraint) | ||||
|             self.assertEqual(x.initializer, x._variable.initializer) | ||||
| @ -202,7 +189,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|           self.assertEqual(x.shape, ()) | ||||
|           self.assertEqual(x.get_shape(), ()) | ||||
| 
 | ||||
|         if not distribute: | ||||
|         if not ds_context.has_strategy(): | ||||
|           # Test scatter_* methods. These are not supported for | ||||
|           # DistributedVariables | ||||
|           x = get_var([7, 8], dtypes.float32) | ||||
| @ -233,9 +220,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|             self.assertAllEqual( | ||||
|                 evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2]) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_operator_overloads(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_operator_overloads(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       for read_dtype in (dtypes.float32, dtypes.float16): | ||||
|         x = get_var(7., dtypes.float32) | ||||
|         x = autocast_variable.create_autocast_variable(x) | ||||
| @ -280,9 +267,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|             self.assertAllEqual(x == [7., 8., 10.], [True, True, False]) | ||||
|             self.assertAllEqual(x != [7., 8., 10.], [False, False, True]) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_assign(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_assign(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       x = get_var(0., dtypes.float32) | ||||
|       x = autocast_variable.create_autocast_variable(x) | ||||
|       self.evaluate(x.initializer) | ||||
| @ -318,18 +305,20 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|         self.assertAllClose(3., self.evaluate(x.assign_sub(3.))) | ||||
| 
 | ||||
|         # Assign multiple times | ||||
|         assign = x.assign(1.) | ||||
|         self.assertAllClose(1., self.evaluate(assign)) | ||||
|         self.assertAllClose(0., self.evaluate(assign.assign(0.))) | ||||
|         assign_add = x.assign_add(3.) | ||||
|         self.assertAllClose(3., self.evaluate(assign_add)) | ||||
|         self.assertAllClose(3. * 3, | ||||
|                             self.evaluate(x.assign_add(3.).assign_add(3.))) | ||||
|         self.assertAllClose(3. * 3, x) | ||||
|         assign_sub = x.assign_sub(3.) | ||||
|         self.assertAllClose(3. * 2, self.evaluate(assign_sub)) | ||||
|         self.assertAllClose(0., | ||||
|                             self.evaluate(x.assign_sub(3.).assign_sub(3.))) | ||||
|         # This currently only works if no strategy is used | ||||
|         if not ds_context.has_strategy(): | ||||
|           assign = x.assign(1.) | ||||
|           self.assertAllClose(1., self.evaluate(assign)) | ||||
|           self.assertAllClose(0., self.evaluate(assign.assign(0.))) | ||||
|           assign_add = x.assign_add(3.) | ||||
|           self.assertAllClose(3., self.evaluate(assign_add)) | ||||
|           self.assertAllClose(3. * 3, | ||||
|                               self.evaluate(x.assign_add(3.).assign_add(3.))) | ||||
|           self.assertAllClose(3. * 3, x) | ||||
|           assign_sub = x.assign_sub(3.) | ||||
|           self.assertAllClose(3. * 2, self.evaluate(assign_sub)) | ||||
|           self.assertAllClose(0., | ||||
|                               self.evaluate(x.assign_sub(3.).assign_sub(3.))) | ||||
| 
 | ||||
|         # Assign with read_value=False | ||||
|         self.assertIsNone(self.evaluate(x.assign(1., read_value=False))) | ||||
| @ -355,9 +344,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|         # assign still expect float32 value even if in float16 scope | ||||
|         run_and_check() | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_assign_stays_in_true_dtype(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_assign_stays_in_true_dtype(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       x = get_var(1., dtypes.float32) | ||||
|       x = autocast_variable.create_autocast_variable(x) | ||||
|       self.evaluate(x.initializer) | ||||
| @ -382,10 +371,10 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|         self.assertEqual(1., self.evaluate(x.value())) | ||||
|       self.assertEqual(1. + small_val, self.evaluate(x.value())) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_checkpoint(self, distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_checkpoint(self, distribution): | ||||
|     with self.test_session(): | ||||
|       with get_distribute_scope(distribute): | ||||
|       with distribution.scope(): | ||||
|         x = get_var(1., dtypes.float32) | ||||
|         x = autocast_variable.create_autocast_variable(x) | ||||
|       self.evaluate(x.initializer) | ||||
| @ -398,9 +387,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|       checkpoint.restore(save_path).assert_consumed().run_restore_ops() | ||||
|       self.assertEqual(self.evaluate(x), 123.) | ||||
| 
 | ||||
|   @parameterized.named_parameters(*TESTCASES) | ||||
|   def test_invalid_wrapped_variable(self, distribute): | ||||
|     with get_distribute_scope(distribute): | ||||
|   @combinations.generate(maybe_distribute) | ||||
|   def test_invalid_wrapped_variable(self, distribution): | ||||
|     with distribution.scope(): | ||||
|       # Wrap a non-variable | ||||
|       with self.assertRaisesRegexp(ValueError, 'variable must be of type'): | ||||
|         x = constant_op.constant([1.], dtype=dtypes.float32) | ||||
| @ -443,7 +432,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): | ||||
|         ) | ||||
| 
 | ||||
|   def test_repr_distributed(self): | ||||
|     with get_distribute_scope(distribute=True): | ||||
|     with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope(): | ||||
|       x = get_var(1., dtypes.float32) | ||||
|       x = autocast_variable.create_autocast_variable(x) | ||||
|       self.assertRegexpMatches( | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user