From 74c52531846cc10a63fb244966ab6bfd000af747 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Tue, 17 Sep 2019 18:24:54 -0700
Subject: [PATCH] Have AutoCastVariable subclass from Variable.

This allows AutoCastVariable to pass isinstance(..., tf.Variable) checks,
fixing various small issues.

PiperOrigin-RevId: 269699969
---
 .../experimental/autocast_variable.py         | 155 ++++++++++++++----
 .../experimental/autocast_variable_test.py    |  98 +++++++++++
 .../experimental/keras_test.py                |   8 +
 3 files changed, 231 insertions(+), 30 deletions(-)

diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index 59a0e08cba1..35c65ac43f4 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -22,11 +22,10 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.ops import variables
 
 
-# TODO(reedwm) Make this subclass AutoCastVariable.
-class AutoCastVariable(trackable.Trackable):
+class AutoCastVariable(variables.Variable):
   """Variable that will cast itself to a different dtype in applicable contexts.
 
   This class wraps a floating-point tf.Variable. It emulates the variable
@@ -67,14 +66,6 @@
                        'type: %s' % variable.dtype.name)
     self._variable = variable
 
-    # Delegate to the underlying variable for checkpointing.
-    self._gather_saveables_for_checkpoint = (
-        self._variable._gather_saveables_for_checkpoint)  # pylint: disable=protected-access
-
-  @property
-  def name(self):
-    return self._variable.name
-
   def _should_cast(self):
     """Returns True if this variable should be casted when accessed."""
     g = ops.get_default_graph()
@@ -108,31 +99,17 @@
 
   def read_value(self):
     val = self._variable.read_value()
-    if not self._should_cast():
-      return val
     return math_ops.cast(val, self.dtype)
 
   def sparse_read(self, indices, name=None):
     """Reads the value of this variable sparsely, using `gather`."""
     val = self._variable.sparse_read(indices, name=name)
-    if not self._should_cast():
-      return val
     return math_ops.cast(val, self.dtype)
 
-  def assign(self, value, use_locking=None, name=None, read_value=True):
-    return self._variable.assign(
-        value, use_locking=use_locking, name=name, read_value=read_value)
-
-  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
-    return self._variable.assign_add(
-        delta, use_locking=use_locking, name=name, read_value=read_value)
-
-  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
-    return self._variable.assign_sub(
-        delta, use_locking=use_locking, name=name, read_value=read_value)
-
-  # TODO(reedwm): Support assigning variables with tf.compat.v1.assign(),
-  # var.scatter_add, etc.
+  def gather_nd(self, indices, name=None):
+    """Gather slices of the variable into a Tensor."""
+    val = self._variable.gather_nd(indices, name=name)
+    return math_ops.cast(val, self.dtype)
 
   def __getattr__(self, name):
     return getattr(self._variable, name)
@@ -171,11 +148,129 @@ class AutoCastVariable(trackable.Trackable):
                 'dtype={v.dtype.name} true_dtype={v.true_dtype.name}>')
     return repr_str.format(v=self)
 
+  # Method delegations: We delegate the following methods to self._variable.
+  # Each of these methods simply calls the same method on self._variable. The
+  # base Variable raises NotImplementedError for most of these, so we must
+  # override them.
+  #
+  # We do not define the following methods from Variable for the following
+  # reasons:
+  # * 'count_up_to': This method only applies to int variables, which cannot
+  #   be wrapped with an AutoCastVariable.
+  # * 'experimental_ref': Instead we inherit the definition from Variable.
+  #   If we defined it and delegated to self._variable, the ref of an
+  #   AutoCastVariable would be the same as the ref of the underlying
+  #   variable, which would be strange as they are different Python objects.
+
+  # pylint: disable=multiple-statements
+  def set_shape(self, shape): return self._variable.set_shape(shape)
+
+  @property
+  def trainable(self): return self._variable.trainable
+
+  @property
+  def synchronization(self): return self._variable.synchronization
+
+  @property
+  def aggregation(self): return self._variable.aggregation
+
+  def eval(self, session=None): return self._variable.eval(session)
+
+  def initialized_value(self): return self._variable.initialized_value()
+
+  @property
+  def initial_value(self): return self._variable.initial_value
+
+  @property
+  def constraint(self): return self._variable.constraint
+
+  def assign(self, value, use_locking=None, name=None, read_value=True):
+    return self._variable.assign(value, use_locking, name, read_value)
+
+  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
+    return self._variable.assign_add(delta, use_locking, name, read_value)
+
+  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
+    return self._variable.assign_sub(delta, use_locking, name, read_value)
+
+  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_sub(sparse_delta, use_locking, name)
+
+  def scatter_add(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_add(sparse_delta, use_locking, name)
+
+  def scatter_max(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_max(sparse_delta, use_locking, name)
+
+  def scatter_min(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_min(sparse_delta, use_locking, name)
+
+  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_mul(sparse_delta, use_locking, name)
+
+  def scatter_div(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_div(sparse_delta, use_locking, name)
+
+  def scatter_update(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.scatter_update(sparse_delta, use_locking, name)
+
+  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
+    return self._variable.batch_scatter_update(sparse_delta, use_locking, name)
+
+  def scatter_nd_sub(self, indices, updates, name=None):
+    return self._variable.scatter_nd_sub(indices, updates, name)
+
+  def scatter_nd_add(self, indices, updates, name=None):
+    return self._variable.scatter_nd_add(indices, updates, name)
+
+  def scatter_nd_update(self, indices, updates, name=None):
+    return self._variable.scatter_nd_update(indices, updates, name)
+
+  def load(self, value, session=None):
+    return self._variable.load(value, session)
+
+  @property
+  def name(self): return self._variable.name
+
+  @property
+  def _shared_name(self): return self._variable._shared_name  # pylint:disable=protected-access
+
+  @property
+  def initializer(self): return self._variable.initializer
+
+  @property
+  def device(self): return self._variable.device
+
+  @property
+  def op(self): return self._variable.op
+
+  @property
+  def graph(self): return self._variable.graph
+
+  @property
+  def shape(self): return self._variable.shape
+
+  def get_shape(self): return self._variable.get_shape()
+
+  def _gather_saveables_for_checkpoint(self):
+    # By delegating this method to the wrapped variable, checkpoints with
+    # AutoCastVariables are identical to checkpoints with normal variables.
+    # Therefore models checkpointed with AutoCastVariables can be restored on
+    # models with normal variables, and vice versa.
+    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
+
+  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
+  # to_proto().
+  def to_proto(self, export_scope=None):
+    return self._variable.to_proto(export_scope)
+
+  def from_proto(self, variable_def, import_scope=None):
+    return self._variable.from_proto(variable_def, import_scope)
+
   # Operator overloads:
   # Note we only overload operators that support floating-point types, as
   # non-float variables cannot be wrapped with an AutoCastVariable.
 
-  # pylint: disable=multiple-statements
   def __add__(self, o): return self.value() + o
   def __radd__(self, o): return o + self.value()
   def __sub__(self, o): return self.value() - o
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
index a9fdcfcc219..80b8adfea3c 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
@@ -22,14 +22,17 @@ import os
 
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python import tf2
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training.tracking import util as trackable_utils
@@ -102,6 +105,21 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(x.read_value().dtype, dtypes.float32)
     self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
 
+  def test_sparse_reads(self):
+    x = get_var([1., 2], dtypes.float32)
+    # DistributedVariables do not support sparse_read or gather_nd, so we pass
+    # distribute=False
+    x = get_autocast_var(x, distribute=False)
+    self.evaluate(x.initializer)
+
+    self.assertEqual(x.sparse_read([0]).dtype, dtypes.float32)
+    self.assertEqual(x.gather_nd([0]).dtype, dtypes.float32)
+
+    with ops.get_default_graph()._enable_auto_casting_variables(
+        dtypes.float16):
+      self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
+      self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
+
   @parameterized.named_parameters(*TESTCASES)
   def test_read_nested_scopes(self, distribute):
     with get_distribute_scope(distribute):
@@ -138,6 +156,75 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(x.true_dtype, dtypes.float32)
     self.assertIsInstance(x.true_dtype, dtypes.DType)
 
+  @parameterized.named_parameters(*TESTCASES)
+  def test_method_delegations(self, distribute):
+    # Test AutoCastVariable correctly delegates Variable methods to the
+    # underlying variable.
+    with get_distribute_scope(distribute):
+      evaluate = self.evaluate
+      for read_dtype in (dtypes.float32, dtypes.float16):
+        x = get_var(7., dtypes.float32)
+        x = get_autocast_var(x, distribute)
+        with ops.get_default_graph()._enable_auto_casting_variables(
+            read_dtype):
+          evaluate(x.initializer)
+          self.assertEqual(evaluate(x.value()), 7)
+          self.assertEqual(evaluate(x.read_value()), 7)
+          self.assertTrue(x.trainable)
+          self.assertEqual(x.synchronization, x._variable.synchronization)
+          self.assertEqual(x.aggregation, x._variable.aggregation)
+          self.assertEqual(evaluate(x.initialized_value()), 7)
+          if not context.executing_eagerly():
+            if not distribute:
+              # These functions are not supported for DistributedVariables
+              x.load(9)
+              self.assertEqual(x.eval(), 9)
+            self.assertEqual(evaluate(x.initial_value), 7)
+            self.assertEqual(x.op, x._variable.op)
+            self.assertEqual(x.graph, x._variable.graph)
+          if not distribute:
+            # These attributes are not supported for DistributedVariables
+            self.assertIsNone(x.constraint)
+            self.assertEqual(x.initializer, x._variable.initializer)
+          self.assertEqual(evaluate(x.assign(8)), 8)
+          self.assertEqual(evaluate(x.assign_add(2)), 10)
+          self.assertEqual(evaluate(x.assign_sub(3)), 7)
+          self.assertEqual(x.name, x._variable.name)
+          self.assertEqual(x.device, x._variable.device)
+          self.assertEqual(x.shape, ())
+          self.assertEqual(x.get_shape(), ())
+
+        if not distribute:
+          # Test scatter_* methods. These are not supported for
+          # DistributedVariables
+          x = get_var([7, 8], dtypes.float32)
+          x = get_autocast_var(x, distribute)
+          with ops.get_default_graph()._enable_auto_casting_variables(
+              read_dtype):
+            evaluate(x.initializer)
+            self.assertAllEqual(evaluate(x.value()), [7, 8])
+
+            def slices(val, index):
+              return indexed_slices.IndexedSlices(
+                  values=constant_op.constant(val, dtype=dtypes.float32),
+                  indices=constant_op.constant(index, dtype=dtypes.int32),
+                  dense_shape=constant_op.constant([2], dtype=dtypes.int32))
+
+            self.assertAllEqual(evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
+            self.assertAllEqual(evaluate(x.scatter_add(slices(1., 0))), [7, 8])
+            self.assertAllEqual(evaluate(x.scatter_max(slices(9., 1))), [7, 9])
+            self.assertAllEqual(evaluate(x.scatter_min(slices(8., 1))), [7, 8])
+            self.assertAllEqual(evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
+            self.assertAllEqual(evaluate(x.scatter_div(slices(2., 1))), [7, 8])
+            self.assertAllEqual(
+                evaluate(x.scatter_update(slices(4., 1))), [7, 4])
+            self.assertAllEqual(
+                evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])), [6, 2])
+            self.assertAllEqual(
+                evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])), [7, 4])
+            self.assertAllEqual(
+                evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
+
   @parameterized.named_parameters(*TESTCASES)
   def test_operator_overloads(self, distribute):
     with get_distribute_scope(distribute):
@@ -181,6 +268,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     x = get_autocast_var(x, distribute)
     self.evaluate(x.initializer)
     self.assertEqual(self.evaluate(x[1]), 8)
+    if tf2.enabled() and context.executing_eagerly():
+      self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
+      self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
 
   @parameterized.named_parameters(*TESTCASES)
   def test_assign(self, distribute):
@@ -214,10 +304,18 @@
       self.evaluate(x.assign_sub(v2))
 
       # Assign Python floats
+      self.assertAllClose(0., self.evaluate(x.assign(0.)))
       self.assertAllClose(3.14, self.evaluate(x.assign(3.14)))
       self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(3.14)))
       self.assertAllClose(3.14, self.evaluate(x.assign_sub(3.14)))
+      # Use the tf.assign functions instead of the var.assign methods.
+      self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
+      self.assertAllClose(3.14, self.evaluate(state_ops.assign(x, 3.14)))
+      self.assertAllClose(3.14 * 2,
+                          self.evaluate(state_ops.assign_add(x, 3.14)))
+      self.assertAllClose(3.14, self.evaluate(state_ops.assign_sub(x, 3.14)))
+
     run_and_check()
     # reset x
     self.evaluate(x.assign(0.))
diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
index bbea1a75a9b..1ddae81eb20 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -429,6 +429,14 @@ class KerasLayerTest(keras_parameterized.TestCase):
     self._test_checkpointing_layer_weights(
         strategy_fn, mixed_prec_when_saving=False, mixed_prec_when_loading=True)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_delete_variable(self):
+    layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
+    layer.x = layer.add_weight('x')
+    self.assertEqual(layer.trainable_weights, [layer.x])
+    del layer.x
+    self.assertEqual(layer.trainable_weights, [])
+
 
 class KerasModelTest(keras_parameterized.TestCase):
   """Test mixed precision with Keras models."""
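Note (not part of the patch): the snippet below is a minimal sketch of the behavior the commit message describes, namely that an AutoCastVariable now passes isinstance(..., tf.Variable) checks. It uses the internal modules the new tests import and assumes eager execution with this change applied.

    # Illustrative sketch only; assumes this patch is applied and eager mode.
    from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variables

    v = resource_variable_ops.ResourceVariable(1.)  # float32 variable to wrap
    ac = autocast_variable.AutoCastVariable(v)

    # Before this change AutoCastVariable subclassed only Trackable, so this
    # check failed; now that it subclasses variables.Variable, it passes.
    assert isinstance(ac, variables.Variable)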