Have AutoCastVariable subclass from Variable.
This allows AutoCastVariable to pass isinstance(..., tf.Variable) checks, fixing various small issues.

PiperOrigin-RevId: 269699969
parent 790ff20d98
commit 74c5253184
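For illustration only (not part of the commit), a minimal sketch of the kind of check this change makes pass, using the TF 2.0-era experimental mixed-precision API; the policy import path below is an assumption and differs in later releases:

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

# Layers built under a 'mixed_float16' policy wrap their float32 weights in
# AutoCastVariables; with this change those wrappers are also tf.Variables.
layer = tf.keras.layers.Dense(1, dtype=mixed_precision.Policy('mixed_float16'))
layer.build((None, 4))
for weight in layer.weights:
  assert isinstance(weight, tf.Variable)  # failed before this change, passes after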
@@ -22,11 +22,10 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.ops import variables


# TODO(reedwm) Make this subclass Variable.
class AutoCastVariable(trackable.Trackable):
class AutoCastVariable(variables.Variable):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point tf.Variable. It emulates the variable
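For illustration only (not part of the diff), a minimal sketch of the casting behavior the docstring describes, using the private _enable_auto_casting_variables hook exercised by the tests below; constructing AutoCastVariable directly like this is an assumption based on the wrapped-variable pattern shown in this file:

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.ops import variables

v = variables.Variable(1.0, dtype=dtypes.float32)
ac = autocast_variable.AutoCastVariable(v)  # wraps the float32 variable

print(ac.read_value().dtype)  # float32: no auto-cast scope is active
with ops.get_default_graph()._enable_auto_casting_variables(dtypes.float16):
  print(ac.read_value().dtype)  # float16: reads are cast inside the scope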
@@ -67,14 +66,6 @@ class AutoCastVariable(trackable.Trackable):
                       'type: %s' % variable.dtype.name)
    self._variable = variable

    # Delegate to the underlying variable for checkpointing.
    self._gather_saveables_for_checkpoint = (
        self._variable._gather_saveables_for_checkpoint)  # pylint: disable=protected-access

  @property
  def name(self):
    return self._variable.name

  def _should_cast(self):
    """Returns True if this variable should be casted when accessed."""
    g = ops.get_default_graph()
@@ -108,31 +99,17 @@ class AutoCastVariable(trackable.Trackable):

  def read_value(self):
    val = self._variable.read_value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self.dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    if not self._should_cast():
      return val
    return math_ops.cast(val, self.dtype)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._variable.assign(
        value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._variable.assign_add(
        delta, use_locking=use_locking, name=name, read_value=read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._variable.assign_sub(
        delta, use_locking=use_locking, name=name, read_value=read_value)

  # TODO(reedwm): Support assigning variables with tf.compat.v1.assign(),
  # var.scatter_add, etc.
  def gather_nd(self, indices, name=None):
    """Gather slices of the variable into a Tensor."""
    val = self._variable.gather_nd(indices, name=name)
    return math_ops.cast(val, self.dtype)

  def __getattr__(self, name):
    return getattr(self._variable, name)
@@ -171,11 +148,129 @@ class AutoCastVariable(trackable.Trackable):
                'dtype={v.dtype.name} true_dtype={v.true_dtype.name}>')
    return repr_str.format(v=self)

  # Method delegations: We delegate the following methods to self._variable.
  # Each of these methods simply calls the same method on self._variable. The
  # base Variable raises NotImplementedError for most of these, so we must
  # override them.
  #
  # We do not define the following methods from Variable for the following
  # reasons:
  # * 'count_up_to': This method only applies to int variables, which cannot
  #   be wrapped with an AutoCastVariable.
  # * 'experimental_ref': Instead we inherit the definition from Variable.
  #   If we defined and delegated to Variable, the ref of an AutoCastVariable
  #   would be the same as the ref of the underlying variable, which would be
  #   strange as they are different Python objects.

  # pylint: disable=multiple-statements
  def set_shape(self, shape): return self._variable.set_shape(shape)

  @property
  def trainable(self): return self._variable.trainable

  @property
  def synchronization(self): return self._variable.synchronization

  @property
  def aggregation(self): return self._variable.aggregation

  def eval(self, session=None): return self._variable.eval(session)

  def initialized_value(self): return self._variable.initialized_value()

  @property
  def initial_value(self): return self._variable.initial_value

  @property
  def constraint(self): return self._variable.constraint

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._variable.assign(value, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._variable.assign_add(delta, use_locking, name, read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._variable.assign_sub(delta, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_add(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_max(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_min(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_div(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._variable.scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._variable.batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    return self._variable.scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    return self._variable.scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    return self._variable.scatter_nd_update(indices, updates, name)

  def load(self, value, session=None):
    return self._variable.load(value, session)

  @property
  def name(self): return self._variable.name

  @property
  def _shared_name(self): return self._variable._shared_name  # pylint:disable=protected-access

  @property
  def initializer(self): return self._variable.initializer

  @property
  def device(self): return self._variable.device

  @property
  def op(self): return self._variable.op

  @property
  def graph(self): return self._variable.graph

  @property
  def shape(self): return self._variable.shape

  def get_shape(self): return self._variable.get_shape()

  def _gather_saveables_for_checkpoint(self):
    # By delegating this method to the wrapped variable, checkpoints with
    # AutoCastVariables are identical to checkpoints with normal variables.
    # Therefore models checkpointed with AutoCastVariables can be restored on
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
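For illustration only (not part of the diff), a rough sketch of the checkpoint interchangeability the comment above describes; build_model is a hypothetical helper that builds the same architecture with or without a mixed-precision policy:

import tensorflow as tf

mixed_model = build_model(use_mixed_precision=True)   # weights are AutoCastVariables
float_model = build_model(use_mixed_precision=False)  # weights are plain tf.Variables

# Because AutoCastVariable delegates its checkpoint saveables to the wrapped
# variable, both models read and write identical checkpoints.
path = tf.train.Checkpoint(model=mixed_model).save('/tmp/ckpt')
tf.train.Checkpoint(model=float_model).restore(path)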
  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
  # to_proto().
  def to_proto(self, export_scope=None):
    return self._variable.to_proto(export_scope)

  def from_proto(self, variable_def, import_scope=None):
    return self._variable.from_proto(variable_def, import_scope)

  # Operator overloads:
  # Note we only overload operators that support floating-point types, as
  # non-float variables cannot be wrapped with an AutoCastVariable.

  # pylint: disable=multiple-statements
  def __add__(self, o): return self.value() + o
  def __radd__(self, o): return o + self.value()
  def __sub__(self, o): return self.value() - o
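For illustration only (not part of the diff): because the operator overloads above go through self.value(), arithmetic on an AutoCastVariable is computed in the cast dtype inside an auto-cast scope. A small sketch, reusing ac, ops and dtypes from the earlier sketch:

print((ac + 1.0).dtype)  # float32 outside a scope
with ops.get_default_graph()._enable_auto_casting_variables(dtypes.float16):
  print((ac + 1.0).dtype)  # float16: __add__ uses the cast value()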
@@ -22,14 +22,17 @@ import os
from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils
@@ -102,6 +105,21 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
      self.assertEqual(x.read_value().dtype, dtypes.float32)
      self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)

  def test_sparse_reads(self):
    x = get_var([1., 2], dtypes.float32)
    # DistributedVariables do not support sparse_read or gather_nd, so we pass
    # distribute=False
    x = get_autocast_var(x, distribute=False)
    self.evaluate(x.initializer)

    self.assertEqual(x.sparse_read([0]).dtype, dtypes.float32)
    self.assertEqual(x.gather_nd([0]).dtype, dtypes.float32)

    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
      self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)

  @parameterized.named_parameters(*TESTCASES)
  def test_read_nested_scopes(self, distribute):
    with get_distribute_scope(distribute):
@@ -138,6 +156,75 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
      self.assertEqual(x.true_dtype, dtypes.float32)
      self.assertIsInstance(x.true_dtype, dtypes.DType)

  @parameterized.named_parameters(*TESTCASES)
  def test_method_delegations(self, distribute):
    # Test AutoCastVariable correctly delegates Variable methods to the
    # underlying variable.
    with get_distribute_scope(distribute):
      evaluate = self.evaluate
      for read_dtype in (dtypes.float32, dtypes.float16):
        x = get_var(7., dtypes.float32)
        x = get_autocast_var(x, distribute)
        with ops.get_default_graph()._enable_auto_casting_variables(
            read_dtype):
          evaluate(x.initializer)
          self.assertEqual(evaluate(x.value()), 7)
          self.assertEqual(evaluate(x.read_value()), 7)
          self.assertTrue(x.trainable)
          self.assertEqual(x.synchronization, x._variable.synchronization)
          self.assertEqual(x.aggregation, x._variable.aggregation)
          self.assertEqual(evaluate(x.initialized_value()), 7)
          if not context.executing_eagerly():
            if not distribute:
              # These functions are not supported for DistributedVariables
              x.load(9)
              self.assertEqual(x.eval(), 9)
            self.assertEqual(evaluate(x.initial_value), 7)
            self.assertEqual(x.op, x._variable.op)
            self.assertEqual(x.graph, x._variable.graph)
          if not distribute:
            # These attributes are not supported for DistributedVariables
            self.assertIsNone(x.constraint)
            self.assertEqual(x.initializer, x._variable.initializer)
          self.assertEqual(evaluate(x.assign(8)), 8)
          self.assertEqual(evaluate(x.assign_add(2)), 10)
          self.assertEqual(evaluate(x.assign_sub(3)), 7)
          self.assertEqual(x.name, x._variable.name)
          self.assertEqual(x.device, x._variable.device)
          self.assertEqual(x.shape, ())
          self.assertEqual(x.get_shape(), ())

        if not distribute:
          # Test scatter_* methods. These are not supported for
          # DistributedVariables
          x = get_var([7, 8], dtypes.float32)
          x = get_autocast_var(x, distribute)
          with ops.get_default_graph()._enable_auto_casting_variables(
              read_dtype):
            evaluate(x.initializer)
            self.assertAllEqual(evaluate(x.value()), [7, 8])

            def slices(val, index):
              return indexed_slices.IndexedSlices(
                  values=constant_op.constant(val, dtype=dtypes.float32),
                  indices=constant_op.constant(index, dtype=dtypes.int32),
                  dense_shape=constant_op.constant([2], dtype=dtypes.int32))

            self.assertAllEqual(evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
            self.assertAllEqual(evaluate(x.scatter_add(slices(1., 0))), [7, 8])
            self.assertAllEqual(evaluate(x.scatter_max(slices(9., 1))), [7, 9])
            self.assertAllEqual(evaluate(x.scatter_min(slices(8., 1))), [7, 8])
            self.assertAllEqual(evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
            self.assertAllEqual(evaluate(x.scatter_div(slices(2., 1))), [7, 8])
            self.assertAllEqual(
                evaluate(x.scatter_update(slices(4., 1))), [7, 4])
            self.assertAllEqual(
                evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])), [6, 2])
            self.assertAllEqual(
                evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])), [7, 4])
            self.assertAllEqual(
                evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])

  @parameterized.named_parameters(*TESTCASES)
  def test_operator_overloads(self, distribute):
    with get_distribute_scope(distribute):
@@ -181,6 +268,9 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
      x = get_autocast_var(x, distribute)
      self.evaluate(x.initializer)
      self.assertEqual(self.evaluate(x[1]), 8)
      if tf2.enabled() and context.executing_eagerly():
        self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
        self.assertAllEqual(x != [7., 8., 10.], [False, False, True])

  @parameterized.named_parameters(*TESTCASES)
  def test_assign(self, distribute):
@@ -214,10 +304,18 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
        self.evaluate(x.assign_sub(v2))

        # Assign Python floats
        self.assertAllClose(0., self.evaluate(x.assign(0.)))
        self.assertAllClose(3.14, self.evaluate(x.assign(3.14)))
        self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(3.14)))
        self.assertAllClose(3.14, self.evaluate(x.assign_sub(3.14)))

        # Use the tf.assign functions instead of the var.assign methods.
        self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
        self.assertAllClose(3.14, self.evaluate(state_ops.assign(x, 3.14)))
        self.assertAllClose(3.14 * 2,
                            self.evaluate(state_ops.assign_add(x, 3.14)))
        self.assertAllClose(3.14, self.evaluate(state_ops.assign_sub(x, 3.14)))

      run_and_check()
      # reset x
      self.evaluate(x.assign(0.))
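For illustration only (not part of the diff): the state_ops assertions above cover what the removed TODO asked for, assigning through the tf.assign-style functions rather than the var.assign methods. A minimal sketch of the public-API form, reusing ac from the earlier sketch; whether a given assign path accepts the wrapper also depends on the isinstance(..., tf.Variable) check this commit fixes:

import tensorflow as tf

# tf.compat.v1.assign is the public alias of state_ops.assign; the assignment
# is forwarded to the underlying float32 variable.
tf.compat.v1.assign(ac, 3.14)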
@@ -429,6 +429,14 @@ class KerasLayerTest(keras_parameterized.TestCase):
    self._test_checkpointing_layer_weights(
        strategy_fn, mixed_prec_when_saving=False, mixed_prec_when_loading=True)

  @test_util.run_in_graph_and_eager_modes
  def test_delete_variable(self):
    layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
    layer.x = layer.add_weight('x')
    self.assertEqual(layer.trainable_weights, [layer.x])
    del layer.x
    self.assertEqual(layer.trainable_weights, [])


class KerasModelTest(keras_parameterized.TestCase):
  """Test mixed precision with Keras models."""