From d0a4066ffa4ca176e16ebe566d193ff77bda9a89 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Mon, 4 May 2020 22:38:02 -0700 Subject: [PATCH] Implement DistributedVariable.numpy() MirroredVariable.numpy() currently works, through DistributedDelegate, but it's better to implement it explicitly. PiperOrigin-RevId: 309884308 Change-Id: I314410ae37c6c3c39a7034138c3787b9bd045a75 --- tensorflow/python/distribute/values.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 2c31b24539c..4581cef0479 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -587,6 +587,13 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable): def value(self): return self._get_closest().value() + def numpy(self): + if context.executing_eagerly(): + return self.read_value().numpy() + else: + raise NotImplementedError( + "numpy() is only available when eager execution is enabled.") + def assign_sub(self, value, use_locking=False, name=None, read_value=True): assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) return self._update( @@ -1143,13 +1150,6 @@ class SyncOnReadVariable(DistributedVariable): # _get_closest() returns a Variable. return self._get_closest().value() - def numpy(self): - if context.executing_eagerly(): - return self.read_value().numpy() - else: - raise NotImplementedError( - "numpy() is only available when eager execution is enabled.") - def _get_cross_replica(self): if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: return self._primary