Use correct variable _device attribute in Keras optimizer_v2.

PiperOrigin-RevId: 302069884
Change-Id: I32ff43f146c6f60d462d2713908c3cf258ace3de
This commit is contained in:
Ken Franko 2020-03-20 11:47:16 -07:00 committed by TensorFlower Gardener
parent 0ad3c881ff
commit d0e21cd468

View File

@ -699,8 +699,10 @@ class OptimizerV2(trackable.Trackable):
def _prepare(self, var_list):
keys = set()
for var in var_list:
var_devices = (getattr(var, "devices", None) or # Distributed
[var.device]) # Regular
if isinstance(var, ds_values.DistributedValues):
var_devices = var._devices # pylint: disable=protected-access
else:
var_devices = [var.device]
var_dtype = var.dtype.base_dtype
for var_device in var_devices:
keys.add((var_device, var_dtype))