Implement DistributedVariable.numpy()

MirroredVariable.numpy() currently works, through DistributedDelegate, but it's
better to implement it explicitly.

PiperOrigin-RevId: 309884308
Change-Id: I314410ae37c6c3c39a7034138c3787b9bd045a75
This commit is contained in:
Ran Chen 2020-05-04 22:38:02 -07:00 committed by TensorFlower Gardener
parent 3e8bbe399b
commit d0a4066ffa

View File

@ -587,6 +587,13 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
def value(self):
return self._get_closest().value()
def numpy(self):
if context.executing_eagerly():
return self.read_value().numpy()
else:
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
return self._update(
@ -1143,13 +1150,6 @@ class SyncOnReadVariable(DistributedVariable):
# _get_closest() returns a Variable.
return self._get_closest().value()
def numpy(self):
if context.executing_eagerly():
return self.read_value().numpy()
else:
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
def _get_cross_replica(self):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return self._primary