Remove duplicate symbols and usages for utility functions.

PiperOrigin-RevId: 325087800
Change-Id: Ic15366260a57ef22da72c587c45d78e5094d8a41
This commit is contained in:
Anjali Sridhar 2020-08-05 13:21:17 -07:00 committed by TensorFlower Gardener
parent b2f5d100d1
commit 96cbf43548
2 changed files with 2 additions and 19 deletions

View File

@ -691,9 +691,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def read_var(self, replica_local_var):
"""Read the aggregate value of a replica-local variable."""
# pylint: disable=protected-access
if values._is_sync_on_read(replica_local_var):
if distribute_utils.is_sync_on_read(replica_local_var):
return replica_local_var._get_cross_replica()
assert values._is_mirrored(replica_local_var)
assert distribute_utils.is_mirrored(replica_local_var)
return array_ops.identity(replica_local_var._get())
# pylint: enable=protected-access

View File

@ -1526,22 +1526,5 @@ class OnWritePolicy(AutoPolicy):
return _on_write_update_replica(var, update_fn, value, **kwargs)
# Utility functions
# Return True if the Value is Mirrored or the Variable is replicated and kept in
# sync.
def _is_mirrored(val):
if isinstance(val, DistributedVariable):
if val._policy: # pylint: disable=protected-access
return val._policy._is_mirrored() # pylint: disable=protected-access
return isinstance(val, Mirrored)
def _is_sync_on_read(val):
if isinstance(val, DistributedVariable):
if val._policy: # pylint: disable=protected-access
return not val._policy._is_mirrored() # pylint: disable=protected-access
return not isinstance(val, Mirrored)
def _in_update_replica():
return distribute_lib.get_update_replica_id() is not None