Have AutoCastVariable.dtype refer to the variable dtype.

This gives us the flexibility to later remove AutoCastVariable and instead have a mechanism so that individual ops cast variables (and potentially other tensors) to the correct dtype. See the last paragraph of this section of the mixed precision RFC (8563574455/rfcs/20200929-keras-mixed-precision.md (op-based-autocasting-api)) for an example of how this could be done.
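For reference, a minimal sketch of the behavior after this change, mirroring the updated docstring and tests. The import path for the internal autocast_variable module is an assumption and may differ between TF versions:

import tensorflow as tf
from tensorflow.python.keras.mixed_precision import autocast_variable  # assumed internal path

# Wrap a float32 variable so that reads can be auto-cast.
v = autocast_variable.create_autocast_variable(tf.Variable(1.0, dtype=tf.float32))

with autocast_variable.enable_auto_cast_variables(tf.float16):
  print(v.dtype)               # float32: dtype now always reports the variable dtype
  print(v.true_dtype)          # float32: deprecated alias of dtype
  print(tf.identity(v).dtype)  # float16: reads are still cast to the autocast dtype
print(v.dtype)                 # float32 outside the context as well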

PiperOrigin-RevId: 337793570
Change-Id: I8e56f7d276117a9a81070ab0984369e8a4490eea
Reed Wanderman-Milne 2020-10-18 22:28:42 -07:00 committed by TensorFlower Gardener
parent 7ba60d5a29
commit 861f63a327
4 changed files with 41 additions and 37 deletions


@@ -51,9 +51,6 @@ class AutoCastVariable(variables.Variable, core.Tensor):
>>> with enable_auto_cast_variables(tf.float16):
... tf.identity(v).dtype
tf.float16
>>> with enable_auto_cast_variables(tf.float16):
... v.dtype # v.dtype also changes under the context manager
tf.float16
The purpose of this class is to allow Keras layers to create variables in
float32, and automatically cast them to float16 or bfloat16 when the layer is
@@ -82,38 +79,42 @@ class AutoCastVariable(variables.Variable, core.Tensor):
def _should_cast(self):
"""Returns True if this variable should be casted when accessed."""
autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
return autocast_dtype is not None and self.true_dtype != autocast_dtype
return autocast_dtype is not None and self.dtype != autocast_dtype
@property
def dtype(self):
"""The dtype this variable will be casted to when read."""
dtype = getattr(_autocast_dtype, 'dtype', None)
return dtype or self._variable.dtype
"""The dtype of the underlying variable, before any casts are done."""
return self._variable.dtype
@property
def true_dtype(self):
"""The dtype of the underlying variable, before any casts are done."""
"""Deprecated alias of `dtype`."""
return self._variable.dtype
@property
def _cast_dtype(self):
dtype = getattr(_autocast_dtype, 'dtype', None)
return dtype or self._variable.dtype
def value(self):
val = self._variable.value()
if not self._should_cast():
return val
return math_ops.cast(val, self.dtype)
return math_ops.cast(val, self._cast_dtype)
def read_value(self):
val = self._variable.read_value()
return math_ops.cast(val, self.dtype)
return math_ops.cast(val, self._cast_dtype)
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
val = self._variable.sparse_read(indices, name=name)
return math_ops.cast(val, self.dtype)
return math_ops.cast(val, self._cast_dtype)
def gather_nd(self, indices, name=None):
"""Gather slices of the variable into a Tensor."""
val = self._variable.gather_nd(indices, name=name)
return math_ops.cast(val, self.dtype)
return math_ops.cast(val, self._cast_dtype)
def __getattr__(self, name):
return getattr(self._variable, name)
@@ -124,13 +125,14 @@ class AutoCastVariable(variables.Variable, core.Tensor):
return ops.convert_to_tensor(self._variable, dtype, name, as_ref)
# TODO(reedwm): Support as_ref?
assert not as_ref
if dtype is not None and not dtype.is_compatible_with(self.dtype):
if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
raise ValueError(
'Incompatible type conversion requested to type {!r} for variable '
'of type {!r}'.format(dtype.name, self.dtype.name))
'Incompatible type conversion requested to type {!r} for '
'AutoCastVariable which is casted to type {!r}'.format(
dtype.name, self._cast_dtype.name))
val = ops.convert_to_tensor_v2_with_dispatch(
self._variable, dtype=self._variable.dtype, name=name)
return math_ops.cast(val, self.dtype)
return math_ops.cast(val, self._cast_dtype)
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@@ -139,13 +141,13 @@ class AutoCastVariable(variables.Variable, core.Tensor):
def __repr__(self):
if context.executing_eagerly() and not self._in_graph_mode:
repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
'dtype={v.dtype.name} true_dtype={v.true_dtype.name}, '
'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
'numpy={np_repr}>')
return repr_str.format(
v=self, np_repr=ops.numpy_text(self.read_value(), is_repr=True))
else:
repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
'dtype={v.dtype.name} true_dtype={v.true_dtype.name}>')
'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
return repr_str.format(v=self)
# Method delegations: We delegate the following methods to self._variable.
@@ -504,7 +506,8 @@ def create_autocast_variable(variable, op=None):
# pylint: disable=missing-format-attribute
return ('<AutoCastDistributedVariable dtype={v.dtype.name} '
'true_dtype={v.true_dtype.name} inner_variable={v._variable}>'
'dtype_to_cast_to={v._cast_dtype.name} '
'inner_variable={v._variable}>'
).format(v=self)
# pylint: enable=missing-format-attribute


@@ -77,7 +77,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
# within auto cast scope of different dtype
with autocast_variable.enable_auto_cast_variables(dtypes.float16):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.dtype, dtypes.float32)
self.assertEqual(x.value().dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16)
self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)
@@ -111,14 +111,11 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.evaluate(x.initializer)
with autocast_variable.enable_auto_cast_variables(dtypes.float16):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16)
with autocast_variable.enable_auto_cast_variables(dtypes.float32):
self.assertEqual(x.dtype, dtypes.float32)
self.assertEqual(x.read_value().dtype, dtypes.float32)
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16)
@ds_combinations.generate(maybe_distribute)
@@ -133,7 +130,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
dtype = dtypes.float16
with autocast_variable.enable_auto_cast_variables(dtype):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.dtype, dtypes.float32)
self.assertIsInstance(x.dtype, dtypes.DType)
self.assertEqual(x.true_dtype, dtypes.float32)
self.assertIsInstance(x.true_dtype, dtypes.DType)
@@ -153,7 +150,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
def evaluate(var):
self.assertIsInstance(var, autocast_variable.AutoCastVariable)
self.assertEqual(var.dtype, read_dtype)
self.assertEqual(array_ops.identity(var).dtype, read_dtype) # pylint: disable=cell-var-from-loop
return self.evaluate(var)
x = get_var(7., dtypes.float32)
@@ -415,13 +412,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.evaluate(x.initializer)
with autocast_variable.enable_auto_cast_variables(dtypes.float16):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)
# New threads should not see the modified value of the autocast dtype.
var_dtype = None
def f():
nonlocal var_dtype
var_dtype = x.dtype
var_dtype = x._cast_dtype
thread = threading.Thread(target=f)
thread.start()
thread.join()
@@ -465,24 +462,26 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
self.assertStartsWith(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, "
"numpy="
"<AutoCastVariable 'x:0' shape=() dtype=float32 "
"dtype_to_cast_to=float32, numpy="
)
with autocast_variable.enable_auto_cast_variables(dtypes.float16):
self.assertStartsWith(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float16 "
"true_dtype=float32, numpy="
"<AutoCastVariable 'x:0' shape=() dtype=float32 "
"dtype_to_cast_to=float16, numpy="
)
else:
self.assertEqual(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32>"
"<AutoCastVariable 'x:0' shape=() dtype=float32 "
"dtype_to_cast_to=float32>"
)
with autocast_variable.enable_auto_cast_variables(dtypes.float16):
self.assertEqual(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float16 true_dtype=float32>"
"<AutoCastVariable 'x:0' shape=() dtype=float32 "
"dtype_to_cast_to=float16>"
)
def test_repr_distributed(self):
@@ -494,12 +493,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
if use_policy:
self.assertRegex(
repr(x).replace('\n', ' '),
'<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
'<AutoCastDistributedVariable dtype=float32 '
'dtype_to_cast_to=float32 '
'inner_variable=DistributedVariable.*>')
else:
self.assertRegex(
repr(x).replace('\n', ' '),
'<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
'<AutoCastDistributedVariable dtype=float32 '
'dtype_to_cast_to=float32 '
'inner_variable=MirroredVariable.*>')
@parameterized.named_parameters(


@@ -231,7 +231,8 @@ class Policy(object):
>>> layer = MyLayer(dtype=policy)
>>> layer.build((2, 2))
>>> layer.x
<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, numpy=...>
<AutoCastVariable 'x:0' shape=() dtype=float32 dtype_to_cast_to=float32,
numpy=...>
>>> layer.y
<tf.Variable 'y:0' shape=() dtype=float32, numpy=...>


@@ -159,7 +159,6 @@ class MultiplyLayer(AssertTypeLayer):
def call(self, inputs):
self.assert_input_types(inputs)
assert inputs.dtype == self.v.dtype
return self._multiply(inputs, self.v)
def _multiply(self, x, y):