Add 'device' property to TPUMirroredVariable, so tf.train.init_from_checkpoint can be supported.
PiperOrigin-RevId: 215843249
This commit is contained in:
parent
83ff640fa5
commit
5608454c31
@ -571,6 +571,10 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
|
||||
ValueError("Device %s not found in %s (current device %s)" %
|
||||
(device, self._index.keys(), device_util.current())), e)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._get().device
|
||||
|
||||
# The arguments to update() are automatically unwrapped so the update()
|
||||
# function would normally see regular variables, not MirroredVariables.
|
||||
# However, the update function can still operate on wrapped MirroredVariables
|
||||
|
Loading…
Reference in New Issue
Block a user