diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 68f902acae7..781f0b6f76d 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -1324,7 +1324,7 @@ class OnReadPolicy(VariablePolicy): def _get_cross_replica(self, var): if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return var._primary # pylint: disable=protected-access + return var._get_replica(0) # pylint: disable=protected-access with ds_context.enter_or_assert_strategy(var.distribute_strategy): return var.distribute_strategy.reduce(