Implement DistributedVariable.numpy()
MirroredVariable.numpy() currently works, through DistributedDelegate, but it's better to implement it explicitly. PiperOrigin-RevId: 309884308 Change-Id: I314410ae37c6c3c39a7034138c3787b9bd045a75
This commit is contained in:
parent
3e8bbe399b
commit
d0a4066ffa
@ -587,6 +587,13 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
def value(self):
|
||||
return self._get_closest().value()
|
||||
|
||||
def numpy(self):
|
||||
if context.executing_eagerly():
|
||||
return self.read_value().numpy()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"numpy() is only available when eager execution is enabled.")
|
||||
|
||||
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
|
||||
return self._update(
|
||||
@ -1143,13 +1150,6 @@ class SyncOnReadVariable(DistributedVariable):
|
||||
# _get_closest() returns a Variable.
|
||||
return self._get_closest().value()
|
||||
|
||||
def numpy(self):
|
||||
if context.executing_eagerly():
|
||||
return self.read_value().numpy()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"numpy() is only available when eager execution is enabled.")
|
||||
|
||||
def _get_cross_replica(self):
|
||||
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
return self._primary
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user