Have AutoCastVariable subclass from Variable.

This allows AutoCastVariable to pass isinstance(..., tf.Variable) checks, fixing various small issues.

PiperOrigin-RevId: 269699969
This commit is contained in:
Reed Wanderman-Milne 2019-09-17 18:24:54 -07:00 committed by TensorFlower Gardener
parent 790ff20d98
commit 74c5253184
3 changed files with 231 additions and 30 deletions

View File

@ -22,11 +22,10 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.ops import variables
# TODO(reedwm): Make this subclass Variable.
class AutoCastVariable(trackable.Trackable):
class AutoCastVariable(variables.Variable):
"""Variable that will cast itself to a different dtype in applicable contexts.
This class wraps a floating-point tf.Variable. It emulates the variable
@ -67,14 +66,6 @@ class AutoCastVariable(trackable.Trackable):
'type: %s' % variable.dtype.name)
self._variable = variable
# Delegate to the underlying variable for checkpointing.
self._gather_saveables_for_checkpoint = (
self._variable._gather_saveables_for_checkpoint) # pylint: disable=protected-access
@property
def name(self):
  """The `name` attribute of the wrapped variable."""
  return self._variable.name
def _should_cast(self):
"""Returns True if this variable should be casted when accessed."""
g = ops.get_default_graph()
@ -108,31 +99,17 @@ class AutoCastVariable(trackable.Trackable):
def read_value(self):
  """Reads the wrapped variable, casting the result when casting is active.

  Returns:
    The result of `self._variable.read_value()`. It is returned unchanged
    when `self._should_cast()` is false, and cast to `self.dtype` otherwise.
  """
  val = self._variable.read_value()
  if not self._should_cast():
    return val
  return math_ops.cast(val, self.dtype)
def sparse_read(self, indices, name=None):
  """Reads the value of this variable sparsely, using `gather`.

  Args:
    indices: Forwarded unchanged to `self._variable.sparse_read`.
    name: Optional name, forwarded as the `name` keyword argument.

  Returns:
    The result of the wrapped variable's `sparse_read`. It is returned
    unchanged when `self._should_cast()` is false, and cast to `self.dtype`
    otherwise.
  """
  val = self._variable.sparse_read(indices, name=name)
  if not self._should_cast():
    return val
  return math_ops.cast(val, self.dtype)
def assign(self, value, use_locking=None, name=None, read_value=True):
  """Delegates `assign` to the wrapped variable.

  Args:
    value: Forwarded positionally to `self._variable.assign`.
    use_locking: Forwarded as a keyword argument.
    name: Forwarded as a keyword argument.
    read_value: Forwarded as a keyword argument.

  Returns:
    Whatever `self._variable.assign` returns.
  """
  return self._variable.assign(
      value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self, delta, use_locking=None, name=None, read_value=True):
  """Delegates `assign_add` to the wrapped variable.

  Args:
    delta: Forwarded positionally to `self._variable.assign_add`.
    use_locking: Forwarded as a keyword argument.
    name: Forwarded as a keyword argument.
    read_value: Forwarded as a keyword argument.

  Returns:
    Whatever `self._variable.assign_add` returns.
  """
  return self._variable.assign_add(
      delta, use_locking=use_locking, name=name, read_value=read_value)
def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
  """Delegates `assign_sub` to the wrapped variable.

  Args:
    delta: Forwarded positionally to `self._variable.assign_sub`.
    use_locking: Forwarded as a keyword argument.
    name: Forwarded as a keyword argument.
    read_value: Forwarded as a keyword argument.

  Returns:
    Whatever `self._variable.assign_sub` returns.
  """
  return self._variable.assign_sub(
      delta, use_locking=use_locking, name=name, read_value=read_value)
# TODO(reedwm): Support assigning variables with tf.compat.v1.assign(),
# var.scatter_add, etc.
def gather_nd(self, indices, name=None):
  """Gather slices of the variable into a Tensor.

  Args:
    indices: Forwarded unchanged to `self._variable.gather_nd`.
    name: Optional name, forwarded as the `name` keyword argument.

  Returns:
    The result of the wrapped variable's `gather_nd`, passed through
    `math_ops.cast(..., self.dtype)`.
  """
  val = self._variable.gather_nd(indices, name=name)
  # Unlike `sparse_read`, the cast is applied without consulting
  # `_should_cast()`; the target is whatever `self.dtype` currently reports.
  return math_ops.cast(val, self.dtype)
def __getattr__(self, name):
  """Falls back to the wrapped variable for attributes not found on `self`.

  Python only calls this hook after normal attribute lookup has failed, so
  attributes defined on this class take precedence over the wrapped ones.

  Args:
    name: Name of the attribute being looked up.

  Returns:
    `getattr(self._variable, name)`.
  """
  return getattr(self._variable, name)
@ -171,11 +148,129 @@ class AutoCastVariable(trackable.Trackable):
'dtype={v.dtype.name} true_dtype={v.true_dtype.name}>')
return repr_str.format(v=self)
# Method delegations: We delegate the following methods to self._variable.
# Each of these methods simply calls the same method on self._variable. The
# base Variable raises NotImplementedError for most of these, so we must
# override them.
#
# We do not define the following methods from Variable for the following
# reasons:
# * 'count_up_to': This method only applies to int variables, which cannot
# be wrapped with an AutoCastVariable.
# * 'experimental_ref': Instead we inherit the definition from Variable.
# If we defined it and delegated to the wrapped variable, the ref of an
# AutoCastVariable would be the same as the ref of the underlying variable,
# which would be strange as they are different Python objects.
# pylint: disable=multiple-statements
def set_shape(self, shape):
  """Delegates `set_shape` to the wrapped variable.

  Args:
    shape: Forwarded unchanged to `self._variable.set_shape`.

  Returns:
    Whatever `self._variable.set_shape(shape)` returns.
  """
  # `self._variable.set_shape` is already a bound method, so this wrapper must
  # not be passed again as an extra leading argument.
  return self._variable.set_shape(shape)
@property
def trainable(self):
  """The `trainable` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.trainable
@property
def synchronization(self):
  """The `synchronization` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.synchronization
@property
def aggregation(self):
  """The `aggregation` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.aggregation
def eval(self, session=None):
  """Delegates `eval` to the wrapped variable.

  Args:
    session: Forwarded positionally to `self._variable.eval`.

  Returns:
    Whatever `self._variable.eval(session)` returns.
  """
  wrapped = self._variable
  return wrapped.eval(session)
def initialized_value(self):
  """Delegates `initialized_value` to the wrapped variable.

  Returns:
    Whatever `self._variable.initialized_value()` returns.
  """
  wrapped = self._variable
  return wrapped.initialized_value()
@property
def initial_value(self):
  """The `initial_value` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.initial_value
@property
def constraint(self):
  """The `constraint` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.constraint
def assign(self, value, use_locking=None, name=None, read_value=True):
  """Delegates `assign` to the wrapped variable.

  Args:
    value: First positional argument for `self._variable.assign`.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.
    read_value: Forwarded positionally, fourth.

  Returns:
    Whatever `self._variable.assign` returns.
  """
  return self._variable.assign(value, use_locking, name, read_value)
def assign_add(self, delta, use_locking=None, name=None, read_value=True):
  """Delegates `assign_add` to the wrapped variable.

  Args:
    delta: First positional argument for `self._variable.assign_add`.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.
    read_value: Forwarded positionally, fourth.

  Returns:
    Whatever `self._variable.assign_add` returns.
  """
  return self._variable.assign_add(delta, use_locking, name, read_value)
def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
  """Delegates `assign_sub` to the wrapped variable.

  Args:
    delta: First positional argument for `self._variable.assign_sub`.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.
    read_value: Forwarded positionally, fourth.

  Returns:
    Whatever `self._variable.assign_sub` returns.
  """
  return self._variable.assign_sub(delta, use_locking, name, read_value)
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_sub` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_sub` returns.
  """
  return self._variable.scatter_sub(sparse_delta, use_locking, name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_add` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_add` returns.
  """
  return self._variable.scatter_add(sparse_delta, use_locking, name)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_max` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_max` returns.
  """
  return self._variable.scatter_max(sparse_delta, use_locking, name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_min` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_min` returns.
  """
  return self._variable.scatter_min(sparse_delta, use_locking, name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_mul` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_mul` returns.
  """
  return self._variable.scatter_mul(sparse_delta, use_locking, name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_div` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_div` returns.
  """
  return self._variable.scatter_div(sparse_delta, use_locking, name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Delegates `scatter_update` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_update` returns.
  """
  return self._variable.scatter_update(sparse_delta, use_locking, name)
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Delegates `batch_scatter_update` to the wrapped variable.

  Args:
    sparse_delta: First positional argument for the wrapped method.
    use_locking: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.batch_scatter_update` returns.
  """
  return self._variable.batch_scatter_update(sparse_delta, use_locking, name)
def scatter_nd_sub(self, indices, updates, name=None):
  """Delegates `scatter_nd_sub` to the wrapped variable.

  Args:
    indices: First positional argument for the wrapped method.
    updates: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_nd_sub` returns.
  """
  return self._variable.scatter_nd_sub(indices, updates, name)
def scatter_nd_add(self, indices, updates, name=None):
  """Delegates `scatter_nd_add` to the wrapped variable.

  Args:
    indices: First positional argument for the wrapped method.
    updates: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_nd_add` returns.
  """
  return self._variable.scatter_nd_add(indices, updates, name)
def scatter_nd_update(self, indices, updates, name=None):
  """Delegates `scatter_nd_update` to the wrapped variable.

  Args:
    indices: First positional argument for the wrapped method.
    updates: Forwarded positionally, second.
    name: Forwarded positionally, third.

  Returns:
    Whatever `self._variable.scatter_nd_update` returns.
  """
  return self._variable.scatter_nd_update(indices, updates, name)
def load(self, value, session=None):
  """Delegates `load` to the wrapped variable.

  Args:
    value: First positional argument for `self._variable.load`.
    session: Forwarded positionally, second.

  Returns:
    Whatever `self._variable.load` returns.
  """
  return self._variable.load(value, session)
@property
def name(self):
  """The `name` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.name
@property
def _shared_name(self):
  """The `_shared_name` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped._shared_name  # pylint:disable=protected-access
@property
def initializer(self):
  """The `initializer` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.initializer
@property
def device(self):
  """The `device` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.device
@property
def op(self):
  """The `op` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.op
@property
def graph(self):
  """The `graph` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.graph
@property
def shape(self):
  """The `shape` attribute of the wrapped variable."""
  wrapped = self._variable
  return wrapped.shape
def get_shape(self):
  """Delegates `get_shape` to the wrapped variable.

  Returns:
    Whatever `self._variable.get_shape()` returns.
  """
  wrapped = self._variable
  return wrapped.get_shape()
def _gather_saveables_for_checkpoint(self):
  """Returns the checkpoint saveables reported by the wrapped variable.

  Returns:
    The result of `self._variable._gather_saveables_for_checkpoint()`.
  """
  # By delegating this method to the wrapped variable, checkpoints with
  # AutoCastVariables are identical to checkpoints with normal variables.
  # Therefore models checkpointed with AutoCastVariables can be restored on
  # models with normal variables, and vice versa.
  return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
# TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
# to_proto().
def to_proto(self, export_scope=None):
  """Delegates `to_proto` to the wrapped variable.

  Args:
    export_scope: Forwarded positionally to `self._variable.to_proto`.

  Returns:
    Whatever `self._variable.to_proto(export_scope)` returns.
  """
  return self._variable.to_proto(export_scope)
def from_proto(self, variable_def, import_scope=None):
  """Delegates `from_proto` to the wrapped variable.

  Args:
    variable_def: First positional argument for `self._variable.from_proto`.
    import_scope: Forwarded positionally, second.

  Returns:
    Whatever `self._variable.from_proto` returns.
  """
  return self._variable.from_proto(variable_def, import_scope)
# Operator overloads:
# Note we only overload operators that support floating-point types, as
# non-float variables cannot be wrapped with an AutoCastVariable.
# pylint: disable=multiple-statements
def __add__(self, o):
  """Returns `self.value() + o`."""
  lhs = self.value()
  return lhs + o
def __radd__(self, o):
  """Returns `o + self.value()` (reflected addition)."""
  rhs = self.value()
  return o + rhs
def __sub__(self, o):
  """Returns `self.value() - o`."""
  lhs = self.value()
  return lhs - o

View File

@ -22,14 +22,17 @@ import os
from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils
@ -102,6 +105,21 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.read_value().dtype, dtypes.float32)
self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
def test_sparse_reads(self):
  """Checks the dtype returned by `sparse_read` and `gather_nd`.

  Outside an auto-casting scope both reads are expected to be float32; inside
  a float16 auto-casting scope both are expected to be float16.
  """
  x = get_var([1., 2], dtypes.float32)
  # DistributedVariables do not support sparse_read or gather_nd, so we pass
  # distribute=False
  x = get_autocast_var(x, distribute=False)
  self.evaluate(x.initializer)

  self.assertEqual(x.sparse_read([0]).dtype, dtypes.float32)
  self.assertEqual(x.gather_nd([0]).dtype, dtypes.float32)

  with ops.get_default_graph()._enable_auto_casting_variables(
      dtypes.float16):
    self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
    self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
@parameterized.named_parameters(*TESTCASES)
def test_read_nested_scopes(self, distribute):
with get_distribute_scope(distribute):
@ -138,6 +156,75 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.true_dtype, dtypes.float32)
self.assertIsInstance(x.true_dtype, dtypes.DType)
# NOTE(review): leading indentation is missing in this view, so the nesting of
# the `with`/`for`/`if` blocks below cannot be read off the text; confirm it
# against the real file before editing this test.
@parameterized.named_parameters(*TESTCASES)
def test_method_delegations(self, distribute):
# Test AutoCastVariable correctly delegates Variable methods to the
# underlying variable.
with get_distribute_scope(distribute):
evaluate = self.evaluate
# Every check is repeated for a float32 and a float16 read dtype.
for read_dtype in (dtypes.float32, dtypes.float16):
x = get_var(7., dtypes.float32)
x = get_autocast_var(x, distribute)
with ops.get_default_graph()._enable_auto_casting_variables(
read_dtype):
evaluate(x.initializer)
self.assertEqual(evaluate(x.value()), 7)
self.assertEqual(evaluate(x.read_value()), 7)
self.assertTrue(x.trainable)
self.assertEqual(x.synchronization, x._variable.synchronization)
self.assertEqual(x.aggregation, x._variable.aggregation)
self.assertEqual(evaluate(x.initialized_value()), 7)
if not context.executing_eagerly():
if not distribute:
# These functions are not supported for DistributedVariables
x.load(9)
self.assertEqual(x.eval(), 9)
self.assertEqual(evaluate(x.initial_value), 7)
self.assertEqual(x.op, x._variable.op)
self.assertEqual(x.graph, x._variable.graph)
if not distribute:
# These attributes are not supported for DistributedVariables
self.assertIsNone(x.constraint)
self.assertEqual(x.initializer, x._variable.initializer)
# The assign calls run in sequence on the same variable, so each expected
# value builds on the previous one (8, then 8 + 2, then 10 - 3).
self.assertEqual(evaluate(x.assign(8)), 8)
self.assertEqual(evaluate(x.assign_add(2)), 10)
self.assertEqual(evaluate(x.assign_sub(3)), 7)
self.assertEqual(x.name, x._variable.name)
self.assertEqual(x.device, x._variable.device)
self.assertEqual(x.shape, ())
self.assertEqual(x.get_shape(), ())
if not distribute:
# Test scatter_* methods. These are not supported for
# DistributedVariables
x = get_var([7, 8], dtypes.float32)
x = get_autocast_var(x, distribute)
with ops.get_default_graph()._enable_auto_casting_variables(
read_dtype):
evaluate(x.initializer)
self.assertAllEqual(evaluate(x.value()), [7, 8])
# Helper: wraps one value/index pair in an IndexedSlices with dense_shape [2].
def slices(val, index):
return indexed_slices.IndexedSlices(
values=constant_op.constant(val, dtype=dtypes.float32),
indices=constant_op.constant(index, dtype=dtypes.int32),
dense_shape=constant_op.constant([2], dtype=dtypes.int32))
# The scatter calls below also run in sequence; each expected result is the
# variable's state after all preceding updates.
self.assertAllEqual(evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
self.assertAllEqual(evaluate(x.scatter_add(slices(1., 0))), [7, 8])
self.assertAllEqual(evaluate(x.scatter_max(slices(9., 1))), [7, 9])
self.assertAllEqual(evaluate(x.scatter_min(slices(8., 1))), [7, 8])
self.assertAllEqual(evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
self.assertAllEqual(evaluate(x.scatter_div(slices(2., 1))), [7, 8])
self.assertAllEqual(
evaluate(x.scatter_update(slices(4., 1))), [7, 4])
self.assertAllEqual(
evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])), [6, 2])
self.assertAllEqual(
evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])), [7, 4])
self.assertAllEqual(
evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
@parameterized.named_parameters(*TESTCASES)
def test_operator_overloads(self, distribute):
with get_distribute_scope(distribute):
@ -181,6 +268,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
x = get_autocast_var(x, distribute)
self.evaluate(x.initializer)
self.assertEqual(self.evaluate(x[1]), 8)
if tf2.enabled() and context.executing_eagerly():
self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
@parameterized.named_parameters(*TESTCASES)
def test_assign(self, distribute):
@ -214,10 +304,18 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.evaluate(x.assign_sub(v2))
# Assign Python floats
self.assertAllClose(0., self.evaluate(x.assign(0.)))
self.assertAllClose(3.14, self.evaluate(x.assign(3.14)))
self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(3.14)))
self.assertAllClose(3.14, self.evaluate(x.assign_sub(3.14)))
# Use the tf.assign functions instead of the var.assign methods.
self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
self.assertAllClose(3.14, self.evaluate(state_ops.assign(x, 3.14)))
self.assertAllClose(3.14 * 2,
self.evaluate(state_ops.assign_add(x, 3.14)))
self.assertAllClose(3.14, self.evaluate(state_ops.assign_sub(x, 3.14)))
run_and_check()
# reset x
self.evaluate(x.assign(0.))

View File

@ -429,6 +429,14 @@ class KerasLayerTest(keras_parameterized.TestCase):
self._test_checkpointing_layer_weights(
strategy_fn, mixed_prec_when_saving=False, mixed_prec_when_loading=True)
@test_util.run_in_graph_and_eager_modes
def test_delete_variable(self):
  """Checks that deleting a weight attribute untracks it from the layer.

  A weight created by `add_weight` on a 'mixed_float16' layer is stored as
  `layer.x`; after `del layer.x` the layer's `trainable_weights` is expected
  to be empty.
  """
  layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
  layer.x = layer.add_weight('x')
  self.assertEqual(layer.trainable_weights, [layer.x])

  del layer.x
  self.assertEqual(layer.trainable_weights, [])
class KerasModelTest(keras_parameterized.TestCase):
"""Test mixed precision with Keras models."""