Use the function required by PackedDistributedVariable to return the primary variable.
PiperOrigin-RevId: 329414194 Change-Id: Ib2a046c60a869f87d143c1713b216e105c303749
This commit is contained in:
parent
a9a692d221
commit
e004bc40fd
@ -1324,7 +1324,7 @@ class OnReadPolicy(VariablePolicy):
|
||||
|
||||
def _get_cross_replica(self, var):
|
||||
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
return var._primary # pylint: disable=protected-access
|
||||
return var._get_replica(0) # pylint: disable=protected-access
|
||||
|
||||
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
|
||||
return var.distribute_strategy.reduce(
|
||||
|
Loading…
x
Reference in New Issue
Block a user