Remove duplicate symbols and usages for utility functions.
PiperOrigin-RevId: 325087800 Change-Id: Ic15366260a57ef22da72c587c45d78e5094d8a41
This commit is contained in:
parent
b2f5d100d1
commit
96cbf43548
tensorflow/python/distribute
@ -691,9 +691,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
def read_var(self, replica_local_var):
|
def read_var(self, replica_local_var):
|
||||||
"""Read the aggregate value of a replica-local variable."""
|
"""Read the aggregate value of a replica-local variable."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if values._is_sync_on_read(replica_local_var):
|
if distribute_utils.is_sync_on_read(replica_local_var):
|
||||||
return replica_local_var._get_cross_replica()
|
return replica_local_var._get_cross_replica()
|
||||||
assert values._is_mirrored(replica_local_var)
|
assert distribute_utils.is_mirrored(replica_local_var)
|
||||||
return array_ops.identity(replica_local_var._get())
|
return array_ops.identity(replica_local_var._get())
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
@ -1526,22 +1526,5 @@ class OnWritePolicy(AutoPolicy):
|
|||||||
return _on_write_update_replica(var, update_fn, value, **kwargs)
|
return _on_write_update_replica(var, update_fn, value, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Utility functions
|
|
||||||
# Return True if the Value is Mirrored or the Variable is replicated and kept in
|
|
||||||
# sync.
|
|
||||||
def _is_mirrored(val):
|
|
||||||
if isinstance(val, DistributedVariable):
|
|
||||||
if val._policy: # pylint: disable=protected-access
|
|
||||||
return val._policy._is_mirrored() # pylint: disable=protected-access
|
|
||||||
return isinstance(val, Mirrored)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sync_on_read(val):
|
|
||||||
if isinstance(val, DistributedVariable):
|
|
||||||
if val._policy: # pylint: disable=protected-access
|
|
||||||
return not val._policy._is_mirrored() # pylint: disable=protected-access
|
|
||||||
return not isinstance(val, Mirrored)
|
|
||||||
|
|
||||||
|
|
||||||
def _in_update_replica():
|
def _in_update_replica():
|
||||||
return distribute_lib.get_update_replica_id() is not None
|
return distribute_lib.get_update_replica_id() is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user