Have AutoCastVariable.dtype refer to the variable dtype.
This allows us the flexibility to later remove AutoCastVariable and instead have a mechanism so that individual ops will cast variables (and potentially other tensors) to the correct dtype. See the last paragraph of this section of the mixed precision RFC (rfcs/20200929-keras-mixed-precision.md#op-based-autocasting-api) for an example of how this could be done.
PiperOrigin-RevId: 337793570
Change-Id: I8e56f7d276117a9a81070ab0984369e8a4490eea
parent 7ba60d5a29
commit 861f63a327
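To make the behavior change concrete, here is a minimal, self-contained sketch. It is not code from this commit: `ToyAutoCastVariable` is a stand-in for the real AutoCastVariable, reduced to the two dtype properties, with a thread-local `_autocast_dtype` mirroring the module global used in the diff below.

```python
import threading

# Thread-local holder for the active autocast dtype, mirroring the
# `_autocast_dtype` module global in the diff below.
_autocast_dtype = threading.local()

class ToyAutoCastVariable:
  """Stand-in for AutoCastVariable, reduced to the two dtype properties."""

  def __init__(self, variable_dtype):
    self._variable_dtype = variable_dtype  # dtype of the wrapped variable

  @property
  def dtype(self):
    # After this change: always the underlying variable dtype.
    return self._variable_dtype

  @property
  def _cast_dtype(self):
    # The dtype reads are cast to: the autocast dtype if a scope is
    # active, otherwise the variable dtype.
    dtype = getattr(_autocast_dtype, 'dtype', None)
    return dtype or self._variable_dtype

v = ToyAutoCastVariable('float32')
print(v.dtype, v._cast_dtype)      # float32 float32
_autocast_dtype.dtype = 'float16'  # emulates enable_auto_cast_variables
print(v.dtype, v._cast_dtype)      # float32 float16 (dtype no longer changes)
```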
@@ -51,9 +51,6 @@ class AutoCastVariable(variables.Variable, core.Tensor):
   >>> with enable_auto_cast_variables(tf.float16):
   ...   tf.identity(v).dtype
   tf.float16
-  >>> with enable_auto_cast_variables(tf.float16):
-  ...   v.dtype  # v.dtype also changes under the context manager
-  tf.float16
 
   The purpose of this class is to allow Keras layers to create variables in
   float32, and automatically cast them to float16 or bfloat16 when the layer is
@@ -82,38 +79,42 @@ class AutoCastVariable(variables.Variable, core.Tensor):
   def _should_cast(self):
     """Returns True if this variable should be casted when accessed."""
     autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
-    return autocast_dtype is not None and self.true_dtype != autocast_dtype
+    return autocast_dtype is not None and self.dtype != autocast_dtype
 
   @property
   def dtype(self):
-    """The dtype this variable will be casted to when read."""
-    dtype = getattr(_autocast_dtype, 'dtype', None)
-    return dtype or self._variable.dtype
+    """The dtype of the underlying variable, before any casts are done."""
+    return self._variable.dtype
 
   @property
   def true_dtype(self):
-    """The dtype of the underlying variable, before any casts are done."""
+    """Deprecated alias of `dtype`."""
     return self._variable.dtype
 
+  @property
+  def _cast_dtype(self):
+    dtype = getattr(_autocast_dtype, 'dtype', None)
+    return dtype or self._variable.dtype
+
   def value(self):
     val = self._variable.value()
     if not self._should_cast():
       return val
-    return math_ops.cast(val, self.dtype)
+    return math_ops.cast(val, self._cast_dtype)
 
   def read_value(self):
     val = self._variable.read_value()
-    return math_ops.cast(val, self.dtype)
+    return math_ops.cast(val, self._cast_dtype)
 
   def sparse_read(self, indices, name=None):
     """Reads the value of this variable sparsely, using `gather`."""
     val = self._variable.sparse_read(indices, name=name)
-    return math_ops.cast(val, self.dtype)
+    return math_ops.cast(val, self._cast_dtype)
 
   def gather_nd(self, indices, name=None):
     """Gather slices of the variable into a Tensor."""
     val = self._variable.gather_nd(indices, name=name)
-    return math_ops.cast(val, self.dtype)
+    return math_ops.cast(val, self._cast_dtype)
 
   def __getattr__(self, name):
     return getattr(self._variable, name)
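A hedged usage sketch of the new split, based on the doctest above. The module import path is an assumption (it matches the TF 2.3-era layout and may differ by version); `create_autocast_variable` and `enable_auto_cast_variables` both appear in this diff.

```python
import tensorflow as tf
# Assumed module path for this era of the codebase; may differ by version.
from tensorflow.python.keras.mixed_precision.experimental import (
    autocast_variable)

v = autocast_variable.create_autocast_variable(tf.Variable(1.0))
with autocast_variable.enable_auto_cast_variables(tf.float16):
  print(v.dtype)               # float32: the underlying variable dtype
  print(v.true_dtype)          # float32: now a deprecated alias of dtype
  print(tf.identity(v).dtype)  # float16: reads are cast inside the scope
```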
@@ -124,13 +125,14 @@ class AutoCastVariable(variables.Variable, core.Tensor):
       return ops.convert_to_tensor(self._variable, dtype, name, as_ref)
     # TODO(reedwm): Support as_ref?
     assert not as_ref
-    if dtype is not None and not dtype.is_compatible_with(self.dtype):
+    if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
       raise ValueError(
-          'Incompatible type conversion requested to type {!r} for variable '
-          'of type {!r}'.format(dtype.name, self.dtype.name))
+          'Incompatible type conversion requested to type {!r} for '
+          'AutoCastVariable which is casted to type {!r}'.format(
+              dtype.name, self._cast_dtype.name))
     val = ops.convert_to_tensor_v2_with_dispatch(
         self._variable, dtype=self._variable.dtype, name=name)
-    return math_ops.cast(val, self.dtype)
+    return math_ops.cast(val, self._cast_dtype)
 
   def _should_act_as_resource_variable(self):
     """Pass resource_variable_ops.is_resource_variable check."""
@@ -139,13 +141,13 @@ class AutoCastVariable(variables.Variable, core.Tensor):
   def __repr__(self):
     if context.executing_eagerly() and not self._in_graph_mode:
       repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
-                  'dtype={v.dtype.name} true_dtype={v.true_dtype.name}, '
+                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
                   'numpy={np_repr}>')
       return repr_str.format(
           v=self, np_repr=ops.numpy_text(self.read_value(), is_repr=True))
     else:
       repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
-                  'dtype={v.dtype.name} true_dtype={v.true_dtype.name}>')
+                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
       return repr_str.format(v=self)
 
   # Method delegations: We delegate the following methods to self._variable.
@@ -504,7 +506,8 @@ def create_autocast_variable(variable, op=None):
 
       # pylint: disable=missing-format-attribute
       return ('<AutoCastDistributedVariable dtype={v.dtype.name} '
-              'true_dtype={v.true_dtype.name} inner_variable={v._variable}>'
+              'dtype_to_cast_to={v._cast_dtype.name} '
+              'inner_variable={v._variable}>'
              ).format(v=self)
       # pylint: enable=missing-format-attribute
 
@@ -77,7 +77,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
 
       # within auto cast scope of different dtype
       with autocast_variable.enable_auto_cast_variables(dtypes.float16):
-        self.assertEqual(x.dtype, dtypes.float16)
+        self.assertEqual(x.dtype, dtypes.float32)
         self.assertEqual(x.value().dtype, dtypes.float16)
         self.assertEqual(x.read_value().dtype, dtypes.float16)
         self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)
@@ -111,14 +111,11 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
       self.evaluate(x.initializer)
 
       with autocast_variable.enable_auto_cast_variables(dtypes.float16):
-        self.assertEqual(x.dtype, dtypes.float16)
         self.assertEqual(x.read_value().dtype, dtypes.float16)
 
         with autocast_variable.enable_auto_cast_variables(dtypes.float32):
-          self.assertEqual(x.dtype, dtypes.float32)
           self.assertEqual(x.read_value().dtype, dtypes.float32)
 
-        self.assertEqual(x.dtype, dtypes.float16)
         self.assertEqual(x.read_value().dtype, dtypes.float16)
 
   @ds_combinations.generate(maybe_distribute)
@@ -133,7 +130,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
 
       dtype = dtypes.float16
       with autocast_variable.enable_auto_cast_variables(dtype):
-        self.assertEqual(x.dtype, dtypes.float16)
+        self.assertEqual(x.dtype, dtypes.float32)
         self.assertIsInstance(x.dtype, dtypes.DType)
         self.assertEqual(x.true_dtype, dtypes.float32)
         self.assertIsInstance(x.true_dtype, dtypes.DType)
@@ -153,7 +150,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
 
     def evaluate(var):
       self.assertIsInstance(var, autocast_variable.AutoCastVariable)
-      self.assertEqual(var.dtype, read_dtype)
+      self.assertEqual(array_ops.identity(var).dtype, read_dtype)  # pylint: disable=cell-var-from-loop
       return self.evaluate(var)
 
     x = get_var(7., dtypes.float32)
@@ -415,13 +412,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     self.evaluate(x.initializer)
 
     with autocast_variable.enable_auto_cast_variables(dtypes.float16):
-      self.assertEqual(x.dtype, dtypes.float16)
       self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)
 
       # New threads should not see the modified value of the autocast dtype.
       var_dtype = None
       def f():
         nonlocal var_dtype
-        var_dtype = x.dtype
+        var_dtype = x._cast_dtype
       thread = threading.Thread(target=f)
       thread.start()
       thread.join()
@@ -465,24 +462,26 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     if context.executing_eagerly():
       self.assertStartsWith(
           repr(x),
-          "<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, "
-          "numpy="
+          "<AutoCastVariable 'x:0' shape=() dtype=float32 "
+          "dtype_to_cast_to=float32, numpy="
       )
       with autocast_variable.enable_auto_cast_variables(dtypes.float16):
         self.assertStartsWith(
             repr(x),
-            "<AutoCastVariable 'x:0' shape=() dtype=float16 "
-            "true_dtype=float32, numpy="
+            "<AutoCastVariable 'x:0' shape=() dtype=float32 "
+            "dtype_to_cast_to=float16, numpy="
         )
     else:
       self.assertEqual(
           repr(x),
-          "<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32>"
+          "<AutoCastVariable 'x:0' shape=() dtype=float32 "
+          "dtype_to_cast_to=float32>"
       )
       with autocast_variable.enable_auto_cast_variables(dtypes.float16):
         self.assertEqual(
             repr(x),
-            "<AutoCastVariable 'x:0' shape=() dtype=float16 true_dtype=float32>"
+            "<AutoCastVariable 'x:0' shape=() dtype=float32 "
+            "dtype_to_cast_to=float16>"
         )
 
   def test_repr_distributed(self):
@@ -494,12 +493,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     if use_policy:
       self.assertRegex(
           repr(x).replace('\n', ' '),
-          '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
+          '<AutoCastDistributedVariable dtype=float32 '
+          'dtype_to_cast_to=float32 '
          'inner_variable=DistributedVariable.*>')
     else:
       self.assertRegex(
           repr(x).replace('\n', ' '),
-          '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
+          '<AutoCastDistributedVariable dtype=float32 '
+          'dtype_to_cast_to=float32 '
           'inner_variable=MirroredVariable.*>')
 
   @parameterized.named_parameters(
@@ -231,7 +231,8 @@ class Policy(object):
   >>> layer = MyLayer(dtype=policy)
   >>> layer.build((2, 2))
   >>> layer.x
-  <AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, numpy=...>
+  <AutoCastVariable 'x:0' shape=() dtype=float32 dtype_to_cast_to=float32,
+  numpy=...>
   >>> layer.y
   <tf.Variable 'y:0' shape=() dtype=float32, numpy=...>
@@ -159,7 +159,6 @@ class MultiplyLayer(AssertTypeLayer):
 
   def call(self, inputs):
     self.assert_input_types(inputs)
-    assert inputs.dtype == self.v.dtype
     return self._multiply(inputs, self.v)
 
   def _multiply(self, x, y):
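As a closing note, the commit message points at an op-based autocasting mechanism as the eventual replacement for AutoCastVariable. Below is a hypothetical illustration of that idea, not code from this commit or the RFC; `cast_op_inputs` and `matmul_fp16` are invented names for the sketch.

```python
import functools
import tensorflow as tf

def cast_op_inputs(op_fn, dtype):
  """Hypothetical wrapper: the op, not the variable, performs the cast."""
  @functools.wraps(op_fn)
  def wrapper(*args, **kwargs):
    # Cast any variable arguments to the compute dtype before calling the op.
    args = [tf.cast(a, dtype) if isinstance(a, tf.Variable) else a
            for a in args]
    return op_fn(*args, **kwargs)
  return wrapper

# Usage: a float16 matmul that accepts float32 variables directly.
matmul_fp16 = cast_op_inputs(tf.matmul, tf.float16)
w = tf.Variable(tf.ones([2, 2]))        # float32 variable
x = tf.ones([2, 2], dtype=tf.float16)   # float16 activation
y = matmul_fp16(w, x)                   # w is cast to float16 by the op
```

Under such a scheme, ops would cast variables (and potentially other tensors) to the correct dtype themselves, and no variable wrapper would be needed.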