Use correct variable _device attribute in Keras optimizer_v2.
PiperOrigin-RevId: 302069884 Change-Id: I32ff43f146c6f60d462d2713908c3cf258ace3de
This commit is contained in:
parent
0ad3c881ff
commit
d0e21cd468
@ -699,8 +699,10 @@ class OptimizerV2(trackable.Trackable):
|
||||
def _prepare(self, var_list):
|
||||
keys = set()
|
||||
for var in var_list:
|
||||
var_devices = (getattr(var, "devices", None) or # Distributed
|
||||
[var.device]) # Regular
|
||||
if isinstance(var, ds_values.DistributedValues):
|
||||
var_devices = var._devices # pylint: disable=protected-access
|
||||
else:
|
||||
var_devices = [var.device]
|
||||
var_dtype = var.dtype.base_dtype
|
||||
for var_device in var_devices:
|
||||
keys.add((var_device, var_dtype))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user