From 609a60b44bbf934b31a1dce4f0aa84e731b83c35 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 16 Jun 2020 00:41:09 +0100 Subject: [PATCH 1/3] Refactor AutoCastVariable tests to rely on strategy_combinations --- .../keras/mixed_precision/experimental/BUILD | 3 +- .../experimental/autocast_variable_test.py | 130 +++++++++--------- 2 files changed, 64 insertions(+), 69 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index 024b093c469..4060e455f84 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -144,9 +144,10 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/eager:context", - "//tensorflow/python/keras:combinations", "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py index 78041973cc1..95957f5634e 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py @@ -17,20 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import os from absl.testing import parameterized import numpy as np from tensorflow.python import tf2 +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops -from tensorflow.python.keras import combinations from tensorflow.python.keras.mixed_precision.experimental import autocast_variable from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.ops import array_ops @@ -40,30 +43,17 @@ from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent as gradient_descent_v1 from tensorflow.python.training.tracking import util as trackable_utils -TESTCASES = ({ - 'testcase_name': 'base', - 'distribute': False -}, { - 'testcase_name': 'distribute', - 'distribute': True -}) - - -def get_distribute_scope(distribute): - - class DummyContextManager(object): - - def __enter__(self): - pass - - def __exit__(self, *args): - pass - - if distribute: - return mirrored_strategy.MirroredStrategy(['cpu:0']).scope() - else: - return DummyContextManager() +class DummyStrategy(object): + @contextlib.contextmanager + def scope(self): + yield +maybe_distribute = combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Dummy", lambda: DummyStrategy(), required_gpus=None), + strategy_combinations.mirrored_strategy_with_cpu_1_and_2 + ]) def get_var(val, dtype, name=None): return variables.VariableV1(val, use_resource=True, dtype=dtype, name=name) @@ -71,10 +61,13 @@ def get_var(val, dtype, name=None): @combinations.generate(combinations.combine(mode=['graph', 'eager'])) class AutoCastVariableTest(test.TestCase, parameterized.TestCase): + def setUp(self): + strategy_combinations.set_virtual_cpus_to_at_least(3) + super(AutoCastVariableTest, self).setUp() - @parameterized.named_parameters(*TESTCASES) - def test_read(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_read(self, distribution): + with distribution.scope(): x = get_var(1., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.evaluate(x.initializer) @@ -116,9 +109,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16) self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16) - @parameterized.named_parameters(*TESTCASES) - def test_read_nested_scopes(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_read_nested_scopes(self, distribution): + with distribution.scope(): x = get_var(1., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.evaluate(x.initializer) @@ -136,9 +129,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(x.dtype, dtypes.float16) self.assertEqual(x.read_value().dtype, dtypes.float16) - @parameterized.named_parameters(*TESTCASES) - def test_dtype_is_not_string(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_dtype_is_not_string(self, distribution): + with distribution.scope(): x = get_var(1., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.assertEqual(x.dtype, dtypes.float32) @@ -153,13 +146,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(x.true_dtype, dtypes.float32) self.assertIsInstance(x.true_dtype, dtypes.DType) - @parameterized.named_parameters(*TESTCASES) - def test_method_delegations(self, distribute): + @combinations.generate(maybe_distribute) + def test_method_delegations(self, distribution): # Test AutoCastVariable correctly delegates Variable methods to the # underlying variable. - with self.test_session(), get_distribute_scope(distribute): + with self.test_session(), distribution.scope(): for read_dtype in (dtypes.float32, dtypes.float16): - if distribute: + if ds_context.has_strategy(): # MirroredVariable.assign will (incorrectly) return a Mirrored value # instead of a MirroredVariable. So we cannot properly wrap it in an # AutoCastVariable. @@ -183,14 +176,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(x.aggregation, x._variable.aggregation) self.assertEqual(self.evaluate(x.initialized_value()), 7) if not context.executing_eagerly(): - if not distribute: + if not ds_context.has_strategy(): # These functions are not supported for DistributedVariables x.load(9) self.assertEqual(x.eval(), 9) self.assertEqual(self.evaluate(x.initial_value), 7) self.assertEqual(x.op, x._variable.op) self.assertEqual(x.graph, x._variable.graph) - if not distribute: + if not ds_context.has_strategy(): # These attributes are not supported for DistributedVariables self.assertIsNone(x.constraint) self.assertEqual(x.initializer, x._variable.initializer) @@ -202,7 +195,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(x.shape, ()) self.assertEqual(x.get_shape(), ()) - if not distribute: + if not ds_context.has_strategy(): # Test scatter_* methods. These are not supported for # DistributedVariables x = get_var([7, 8], dtypes.float32) @@ -233,9 +226,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertAllEqual( evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2]) - @parameterized.named_parameters(*TESTCASES) - def test_operator_overloads(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_operator_overloads(self, distribution): + with distribution.scope(): for read_dtype in (dtypes.float32, dtypes.float16): x = get_var(7., dtypes.float32) x = autocast_variable.create_autocast_variable(x) @@ -280,9 +273,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(x == [7., 8., 10.], [True, True, False]) self.assertAllEqual(x != [7., 8., 10.], [False, False, True]) - @parameterized.named_parameters(*TESTCASES) - def test_assign(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_assign(self, distribution): + with distribution.scope(): x = get_var(0., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.evaluate(x.initializer) @@ -318,18 +311,19 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertAllClose(3., self.evaluate(x.assign_sub(3.))) # Assign multiple times - assign = x.assign(1.) - self.assertAllClose(1., self.evaluate(assign)) - self.assertAllClose(0., self.evaluate(assign.assign(0.))) - assign_add = x.assign_add(3.) - self.assertAllClose(3., self.evaluate(assign_add)) - self.assertAllClose(3. * 3, - self.evaluate(x.assign_add(3.).assign_add(3.))) - self.assertAllClose(3. * 3, x) - assign_sub = x.assign_sub(3.) - self.assertAllClose(3. * 2, self.evaluate(assign_sub)) - self.assertAllClose(0., - self.evaluate(x.assign_sub(3.).assign_sub(3.))) + if not ds_context.has_strategy(): + assign = x.assign(1.) + self.assertAllClose(1., self.evaluate(assign)) + self.assertAllClose(0., self.evaluate(assign.assign(0.))) + assign_add = x.assign_add(3.) + self.assertAllClose(3., self.evaluate(assign_add)) + self.assertAllClose(3. * 3, + self.evaluate(x.assign_add(3.).assign_add(3.))) + self.assertAllClose(3. * 3, x) + assign_sub = x.assign_sub(3.) + self.assertAllClose(3. * 2, self.evaluate(assign_sub)) + self.assertAllClose(0., + self.evaluate(x.assign_sub(3.).assign_sub(3.))) # Assign with read_value=False self.assertIsNone(self.evaluate(x.assign(1., read_value=False))) @@ -355,9 +349,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): # assign still expect float32 value even if in float16 scope run_and_check() - @parameterized.named_parameters(*TESTCASES) - def test_assign_stays_in_true_dtype(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_assign_stays_in_true_dtype(self, distribution): + with distribution.scope(): x = get_var(1., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.evaluate(x.initializer) @@ -382,10 +376,10 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(1., self.evaluate(x.value())) self.assertEqual(1. + small_val, self.evaluate(x.value())) - @parameterized.named_parameters(*TESTCASES) - def test_checkpoint(self, distribute): + @combinations.generate(maybe_distribute) + def test_checkpoint(self, distribution): with self.test_session(): - with get_distribute_scope(distribute): + with distribution.scope(): x = get_var(1., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.evaluate(x.initializer) @@ -398,9 +392,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): checkpoint.restore(save_path).assert_consumed().run_restore_ops() self.assertEqual(self.evaluate(x), 123.) - @parameterized.named_parameters(*TESTCASES) - def test_invalid_wrapped_variable(self, distribute): - with get_distribute_scope(distribute): + @combinations.generate(maybe_distribute) + def test_invalid_wrapped_variable(self, distribution): + with distribution.scope(): # Wrap a non-variable with self.assertRaisesRegexp(ValueError, 'variable must be of type'): x = constant_op.constant([1.], dtype=dtypes.float32) @@ -443,7 +437,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): ) def test_repr_distributed(self): - with get_distribute_scope(distribute=True): + with mirrored_strategy.MirroredStrategy(["/cpu:1", "/cpu:2"]).scope(): x = get_var(1., dtypes.float32) x = autocast_variable.create_autocast_variable(x) self.assertRegexpMatches( From fe6580b4d85b12a2ce7b1a529b70fdbfeedd899e Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 16 Jun 2020 02:16:09 +0100 Subject: [PATCH 2/3] Add comment about ignoring distributed multi assignment --- .../keras/mixed_precision/experimental/autocast_variable_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py index 95957f5634e..14f26cdf953 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py @@ -311,6 +311,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.assertAllClose(3., self.evaluate(x.assign_sub(3.))) # Assign multiple times + # This currently only works if no strategy is used if not ds_context.has_strategy(): assign = x.assign(1.) self.assertAllClose(1., self.evaluate(assign)) From da2046b8a3b1bb79c77bf258aa8a52887bc3703a Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 16 Jun 2020 02:25:18 +0100 Subject: [PATCH 3/3] Use default_strategy instead of dummy scope --- .../experimental/autocast_variable_test.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py index 14f26cdf953..c45015b644e 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import os from absl.testing import parameterized @@ -43,15 +42,9 @@ from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent as gradient_descent_v1 from tensorflow.python.training.tracking import util as trackable_utils -class DummyStrategy(object): - @contextlib.contextmanager - def scope(self): - yield - maybe_distribute = combinations.combine( distribution=[ - combinations.NamedDistribution( - "Dummy", lambda: DummyStrategy(), required_gpus=None), + strategy_combinations.default_strategy, strategy_combinations.mirrored_strategy_with_cpu_1_and_2 ])