Use the function required by PackedDistributedVariable to return the primary variable.

PiperOrigin-RevId: 329414194
Change-Id: Ib2a046c60a869f87d143c1713b216e105c303749
This commit is contained in:
Anjali Sridhar 2020-08-31 17:47:47 -07:00 committed by TensorFlower Gardener
parent a9a692d221
commit e004bc40fd

View File

@ -1324,7 +1324,7 @@ class OnReadPolicy(VariablePolicy):
def _get_cross_replica(self, var):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return var._primary # pylint: disable=protected-access
return var._get_replica(0) # pylint: disable=protected-access
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
return var.distribute_strategy.reduce(