Make TPUMirroredVariable a subclass of MirroredVariable.
PiperOrigin-RevId: 252370482
Parent: f070a4b01c
Commit: 8da25cbfe7
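The commit's core move: `TPUMirroredVariable` previously hand-copied the `MirroredVariable` machinery because `DistributedDelegate` did not work under `tpu.replicate()`; the hunks below instead make it a subclass and route the TPU-specific behavior through a handful of overrides. A toy sketch (not TF code) of the resulting class relationship, and why a single `isinstance` check suffices afterwards:

# Toy sketch of the class hierarchy this commit establishes (not TF code).
class DistributedVariable(object): pass
class MirroredVariable(DistributedVariable): pass
class TPUMirroredVariable(MirroredVariable): pass  # after this commit

# `is_distributed_variable` later in this commit can now drop its
# TPU-specific clause: one isinstance check covers the TPU case too.
print(isinstance(TPUMirroredVariable(), DistributedVariable))  # True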
@@ -485,16 +485,16 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, self._device_map, value, destinations)
 
-    devices = cross_device_ops_lib.get_devices_from(destinations)
-    if len(devices) != 1:
-      raise ValueError("Multiple devices are not supported for TPUStrategy")
-
+    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
     # Always performs the reduction on the TPU host.
     with ops.device(self._host_device):
       output = math_ops.add_n(value.values)
       if reduce_op == reduce_util.ReduceOp.MEAN:
         output *= (1. / len(value.values))
 
+    devices = cross_device_ops_lib.get_devices_from(destinations)
+
+    if len(devices) == 1:
       # If necessary, copy to requested destination.
       dest_canonical = device_util.canonicalize(devices[0])
       host_canonical = device_util.canonicalize(self._host_device)
@@ -502,6 +502,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       if dest_canonical != host_canonical:
         with ops.device(dest_canonical):
           output = array_ops.identity(output)
+    else:
+      output = cross_device_ops_lib.simple_broadcast(output, destinations)
 
     return output
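The reduction in `_reduce_to` now always runs on the TPU host first; only afterwards is the result either copied to a single requested destination or broadcast to several, instead of raising for the multi-device case as before. A minimal pure-Python sketch of that new control flow; the helpers and device strings here are toy stand-ins, not the TensorFlow ops:

# Pure-Python sketch of the new _reduce_to control flow (toy stand-ins:
# sum() for math_ops.add_n, tuples/lists for device placement and broadcast).
def reduce_on_host(values, devices, host, mean=False):
  output = sum(values)                       # reduce on the host
  if mean:
    output *= 1.0 / len(values)
  if len(devices) == 1:
    if devices[0] != host:                   # copy only if needed
      output = ("copied_to_" + devices[0], output)
  else:
    output = [(d, output) for d in devices]  # broadcast to all destinations
  return output

print(reduce_on_host([1.0, 2.0, 3.0], ["cpu:0", "cpu:1"], "cpu:0", mean=True))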
@@ -388,63 +388,73 @@ class DistributedDelegate(DistributedValues):
     # __getattr__ and @property. See b/120402273.
     return getattr(self.get(), name)
 
+  def _get_as_operand(self):
+    """Returns the value for operations for the current device.
+
+    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
+    value type within a replica context. They can, however, return a value that
+    can be used by the operations below.
+    """
+    return self.get()
+
   # pylint: disable=multiple-statements
-  def __add__(self, o): return self.get() + o
-  def __radd__(self, o): return o + self.get()
-  def __sub__(self, o): return self.get() - o
-  def __rsub__(self, o): return o - self.get()
-  def __mul__(self, o): return self.get() * o
-  def __rmul__(self, o): return o * self.get()
-  def __truediv__(self, o): return self.get() / o
-  def __rtruediv__(self, o): return o / self.get()
+  def __add__(self, o): return self._get_as_operand() + o
+  def __radd__(self, o): return o + self._get_as_operand()
+  def __sub__(self, o): return self._get_as_operand() - o
+  def __rsub__(self, o): return o - self._get_as_operand()
+  def __mul__(self, o): return self._get_as_operand() * o
+  def __rmul__(self, o): return o * self._get_as_operand()
+  def __truediv__(self, o): return self._get_as_operand() / o
+  def __rtruediv__(self, o): return o / self._get_as_operand()
 
   def __floordiv__(self, o):
-    return self.get() // o
+    return self._get_as_operand() // o
 
-  def __rfloordiv__(self, o): return o // self.get()
-  def __mod__(self, o): return self.get() % o
-  def __rmod__(self, o): return o % self.get()
-  def __lt__(self, o): return self.get() < o
-  def __le__(self, o): return self.get() <= o
-  def __gt__(self, o): return self.get() > o
-  def __ge__(self, o): return self.get() >= o
-  def __and__(self, o): return self.get() & o
-  def __rand__(self, o): return o & self.get()
-  def __or__(self, o): return self.get() | o
-  def __ror__(self, o): return o | self.get()
-  def __xor__(self, o): return self.get() ^ o
-  def __rxor__(self, o): return o ^ self.get()
-  def __getitem__(self, o): return self.get()[o]
-  def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo)
-  def __rpow__(self, o): return pow(o, self.get())
-  def __invert__(self): return ~self.get()
-  def __neg__(self): return -self.get()
-  def __abs__(self): return abs(self.get())
+  def __rfloordiv__(self, o): return o // self._get_as_operand()
+  def __mod__(self, o): return self._get_as_operand() % o
+  def __rmod__(self, o): return o % self._get_as_operand()
+  def __lt__(self, o): return self._get_as_operand() < o
+  def __le__(self, o): return self._get_as_operand() <= o
+  def __gt__(self, o): return self._get_as_operand() > o
+  def __ge__(self, o): return self._get_as_operand() >= o
+  def __and__(self, o): return self._get_as_operand() & o
+  def __rand__(self, o): return o & self._get_as_operand()
+  def __or__(self, o): return self._get_as_operand() | o
+  def __ror__(self, o): return o | self._get_as_operand()
+  def __xor__(self, o): return self._get_as_operand() ^ o
+  def __rxor__(self, o): return o ^ self._get_as_operand()
+  def __getitem__(self, o): return self._get_as_operand()[o]
+  def __pow__(self, o, modulo=None):
+    return pow(self._get_as_operand(), o, modulo)
+  def __rpow__(self, o): return pow(o, self._get_as_operand())
+  def __invert__(self): return ~self._get_as_operand()
+  def __neg__(self): return -self._get_as_operand()
+  def __abs__(self): return abs(self._get_as_operand())
 
   def __div__(self, o):
     try:
-      return self.get().__div__(o)
+      return self._get_as_operand().__div__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
 
   def __rdiv__(self, o):
     try:
-      return self.get().__rdiv__(o)
+      return self._get_as_operand().__rdiv__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
 
   def __matmul__(self, o):
     try:
-      return self.get().__matmul__(o)
+      return self._get_as_operand().__matmul__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
 
   def __rmatmul__(self, o):
     try:
-      return self.get().__rmatmul__(o)
+      return self._get_as_operand().__rmatmul__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
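Every operator overload above now reads its operand through a single `_get_as_operand` hook, so a subclass can change what the operators see by overriding one method instead of re-implementing each overload. A self-contained sketch of that template-method pattern (toy classes, not the TF ones):

# Toy sketch of the hook pattern introduced above: the base class defines the
# operators once; a subclass only overrides _get_as_operand.
class Delegate(object):
  def __init__(self, value): self._value = value
  def get(self): return self._value
  def _get_as_operand(self): return self.get()
  def __add__(self, o): return self._get_as_operand() + o

class ReadValueDelegate(Delegate):
  # Mirrors TPUMirroredVariable, which returns read_value() rather than get().
  def _get_as_operand(self): return self.get() + 0  # stand-in for a read op

print(Delegate(2) + 3)           # 5
print(ReadValueDelegate(2) + 3)  # 5, via the overridden hook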
@@ -920,15 +930,19 @@ class MirroredVariable(DistributedVariable, Mirrored):
       return _MirroredSaveable(self, self.primary, name)
     return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
 
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
     # Try to avoid assignments to and other mutations of MirroredVariable
     # state except through a DistributionStrategy.extended.update() call.
     assert not as_ref
     return ops.internal_convert_to_tensor(
-        var.get(), dtype=dtype, name=name, as_ref=as_ref)
+        self.get(), dtype=dtype, name=name, as_ref=as_ref)
 
 
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
 
 
 ops.register_tensor_conversion_function(MirroredVariable,
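The conversion body moves from a module-level closure into a `_dense_var_to_tensor` method, so subclasses such as `TPUMirroredVariable` can override it while the registered conversion function only delegates. A toy sketch with a plain dict standing in for TensorFlow's tensor-conversion registry:

# Toy sketch of the delegation pattern (a dict stands in for the registry).
_conversion_registry = {}

class Var(object):
  def _dense_var_to_tensor(self): return "base tensor"

class TpuVar(Var):
  def _dense_var_to_tensor(self): return "tpu tensor"

_conversion_registry[Var] = lambda v: v._dense_var_to_tensor()

# Registered once for the base class, the function picks up overrides for free.
print(_conversion_registry[Var](TpuVar()))  # "tpu tensor"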
@@ -947,33 +961,17 @@ def _enclosing_tpu_context():
 
 def is_distributed_variable(v):
   """Determine if a variable is ds variable or TPU mirrored variable."""
-  return (isinstance(v, DistributedVariable)
-          or isinstance(v, TPUMirroredVariable))
+  return isinstance(v, DistributedVariable)
 
 
-# TODO(jhseu): Deduplicate code. We copy code because we don't want to
-# inherit from DistributedDelegate. DistributedDelegate will not work in a
-# tpu.replicate() because it assumes that you're in a device context where you
-# can operate on a single version of the variable, but a tpu.replicate()
-# operates on all variables and is replicated during a rewrite pass.
-class TPUMirroredVariable(variables_lib.Variable):
+class TPUMirroredVariable(MirroredVariable):
   """Holds a map from device to TPU variables whose values are kept in sync."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
-    assert isinstance(device_map, DeviceMap)
-    self._distribute_strategy = strategy
-    self._device_map = device_map
-    self._values = tuple(values)
-    if logical_device is None:
-      logical_device = device_map.logical_device_from_values(self._values)
-    self._logical_device = logical_device
-
-    # Use a weakref to make it easy to map from the contained values
-    # to the container without introducing a reference cycle.
-    for v in self._values:
-      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
-    self._common_name = self.primary.name.split(":")[0]
+    super(TPUMirroredVariable, self).__init__(
+        strategy=strategy, device_map=device_map, values=values,
+        aggregation=aggregation, logical_device=logical_device)
 
     # Handle id is needed for get_replicated_var_handle to cache the variables
     # correctly since in eager mode different variables can have the same name.
@@ -982,28 +980,21 @@ class TPUMirroredVariable(variables_lib.Variable):
     else:
       self._handle_id = self._common_name
 
-    self._aggregation = aggregation
-    # Needed for GradientTape
-    self._trainable = self.primary.trainable
-    # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
-    # initializer is composed of the initializers of the components variables.
-    # However, in some cases, such as when restoring from a checkpoint, we may
-    # set the _initializer_op property on the entire `TPUMirroredVariable`.
-    self._initializer_op = None
-
-  def _get(self, device=None):
-    """Returns the value for the current device or raises a ValueError."""
-    if device is None:
-      replica_context = distribution_strategy_context.get_replica_context()
-      if replica_context:
-        return self._device_map.select_for_current_replica(
-            self._values, replica_context)
-      else:
-        device = distribute_lib.get_update_device()
-        if device is None:
-          return self._get_cross_replica()
-    device = device_util.canonicalize(device)
-    return self._device_map.select_for_device(self._values, device)
+  def get(self, device=None):
+    if (_enclosing_tpu_context() is None) or (device is not None):
+      return super(TPUMirroredVariable, self).get(device=device)
+    else:
+      raise NotImplementedError(
+          "`TPUMirroredVariable.get()` is not supported within a TPU context.")
+
+  def _get_as_operand(self):
+    return self.read_value()
+
+  def _get_closest(self):
+    if _enclosing_tpu_context() is None:
+      return super(TPUMirroredVariable, self)._get_closest()
+    else:
+      return self.primary
 
   def numpy(self):
     if context.executing_eagerly():
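Within a TPU context there is no single per-device component to select, so `get()` now defers to the base-class implementation only outside that context and raises otherwise, while `_get_closest()` falls back to the primary component. A toy sketch of that guard-then-delegate shape (the TPU-context probe is faked with a module-level flag):

# Toy sketch of the guard in get() above (not TF code).
_IN_TPU_CONTEXT = False

class Base(object):
  def get(self): return "component variable"

class TpuVar(Base):
  def get(self):
    if not _IN_TPU_CONTEXT:
      return super(TpuVar, self).get()
    raise NotImplementedError("get() is not supported within a TPU context.")

print(TpuVar().get())  # delegates to Base outside a TPU context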
@@ -1011,174 +1002,20 @@ class TPUMirroredVariable(variables_lib.Variable):
       raise NotImplementedError(
           "numpy() is only available when eager execution is enabled.")
 
-  def initialized_value(self):
-    return self.primary.initialized_value()
-
-  @property
-  def initial_value(self):
-    return self.primary.initial_value
-
-  @property
-  def primary(self):
-    """Returns a representative component."""
-    return self._values[0]
-
-  @property
-  def devices(self):
-    return self._device_map.logical_to_actual_devices(self._logical_device)
-
-  @property
-  def logical_device(self):
-    return self._logical_device
-
-  @property
-  def device_map(self):
-    return self._device_map
-
-  # TODO(josh11b): Replace experimental_local_results with this?
-  @property
-  def values(self):
-    return self._values
-
-  @property
-  def distribute_strategy(self):
-    return self._distribute_strategy
-
-  # pylint: disable=multiple-statements
-  def __add__(self, o): return self.read_value() + o
-  def __radd__(self, o): return o + self.read_value()
-  def __sub__(self, o): return self.read_value() - o
-  def __rsub__(self, o): return o - self.read_value()
-  def __mul__(self, o): return self.read_value() * o
-  def __rmul__(self, o): return o * self.read_value()
-  def __truediv__(self, o): return self.read_value() / o
-  def __rtruediv__(self, o): return o / self.read_value()
-  def __floordiv__(self, o): return self.read_value() // o
-  def __rfloordiv__(self, o): return o // self.read_value()
-  def __mod__(self, o): return self.read_value() % o
-  def __rmod__(self, o): return o % self.read_value()
-  def __lt__(self, o): return self.read_value() < o
-  def __le__(self, o): return self.read_value() <= o
-  def __gt__(self, o): return self.read_value() > o
-  def __ge__(self, o): return self.read_value() >= o
-  def __and__(self, o): return self.read_value() & o
-  def __rand__(self, o): return o & self.read_value()
-  def __or__(self, o): return self.read_value() | o
-  def __ror__(self, o): return o | self.read_value()
-  def __xor__(self, o): return self.read_value() ^ o
-  def __rxor__(self, o): return o ^ self.read_value()
-  def __getitem__(self, o): return self.read_value()[o]
-  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
-  def __rpow__(self, o): return pow(o, self.read_value())
-  def __invert__(self): return ~self.read_value()
-  def __neg__(self): return -self.read_value()
-  def __abs__(self): return abs(self.read_value())
-
-  def __div__(self, o):
-    try:
-      return self.read_value().__div__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rdiv__(self, o):
-    try:
-      return self.read_value().__rdiv__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __matmul__(self, o):
-    try:
-      return self.read_value().__matmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rmatmul__(self, o):
-    try:
-      return self.read_value().__rmatmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __str__(self):
-    devices = self.devices
-    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
-                           for i in range(len(devices)))
-    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
-
-  def __repr__(self):
-    devices = self.devices
-    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
-                            for i in range(len(devices)))
-    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
-
   @property
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = _enclosing_tpu_context()
-    if tpu_context is not None:
+    if tpu_context is None:
+      return self._get_closest().handle
+    else:
       return tpu_context.get_replicated_var_handle(
           self._handle_id, self._values)
 
-    device = distribute_lib.get_update_device()
-    if device is None:
-      return self.primary.handle
-    return self._get(device=device).handle
-
   @property
   def device(self):
     return self.handle.device
 
-  def eval(self, session=None):
-    return self.primary.eval(session)
-
-  # The arguments to update() are automatically unwrapped so the update()
-  # function would normally see regular variables, not MirroredVariables.
-  # However, the update function can still operate on wrapped MirroredVariables
-  # through object members, captured arguments, etc. This is more likely in an
-  # update_non_slot() function (like OptimizerV2._finish), which can
-  # update several non-slot variables in one call.
-  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
-        if _enclosing_tpu_context() is not None:
-          return self._distribute_strategy.extended.update(
-              self, f, args=args, kwargs=kwargs)
-
-        update_device = distribute_lib.get_update_device()
-        # We are calling update on the mirrored variable in cross replica
-        # context.
-        if update_device is not None:
-          # We are calling an assign function on the mirrored variable in cross
-          # replica context.
-          v = self._get(device=update_device)
-          return f(v, *args, **kwargs)
-
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        _assert_replica_context(self._distribute_strategy)
-        # We are calling an assign function on the mirrored variable in replica
-        # context.
-        # We reduce the value we want to assign/add/sub. More details about how
-        # we handle the different use cases can be found in the _reduce method.
-        # We call the function on each of the mirrored variables with the
-        # reduced value.
-        if self._aggregation == vs.VariableAggregation.NONE:
-          raise ValueError(_aggregation_error_msg.format(
-              variable_type="TPUMirroredVariable"))
-
-        def merge_fn(strategy, value, *other_args, **other_kwargs):
-          v = _apply_aggregation(strategy, value, self._aggregation, self)
-          return strategy.extended.update(
-              self, f, args=(v,) + other_args, kwargs=other_kwargs)
-
-        return distribution_strategy_context.get_replica_context().merge_call(
-            merge_fn, args=args, kwargs=kwargs)
-
   @contextlib.contextmanager
   def _handle_graph(self, handle):
     # Note: might have an eager tensor but not be executing eagerly when
@@ -1190,10 +1027,6 @@ class TPUMirroredVariable(variables_lib.Variable):
       with handle.graph.as_default():
         yield
 
-  @property
-  def trainable(self):
-    return self._trainable
-
   def _read_variable_op(self, parent_op=None):
     if self.trainable:
       tape.variable_accessed(self)
@@ -1208,156 +1041,64 @@ class TPUMirroredVariable(variables_lib.Variable):
   def read_value(self):
     return self._read_variable_op()
 
-  def assign_sub(self, *args, **kwargs):
-    def assign_sub_fn(var, delta, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_sub_variable_op(
-            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+  def _assign_func(self, *args, **kwargs):
+    with _enter_or_assert_strategy(self._distribute_strategy):
+      if (distribution_strategy_context.in_cross_replica_context()
+          and (_enclosing_tpu_context() is not None)):
+        f = kwargs.pop("f")
+        return self._distribute_strategy.extended.update(
+            self, f, args=args, kwargs=kwargs)
+      else:
+        return super(TPUMirroredVariable, self)._assign_func(*args, **kwargs)
+
+  def _make_raw_assign_fn(self, raw_assign_fn):
+    def assign_fn(var, value, *args, **kwargs):
+      del args
+      name = kwargs.pop("name", None)
+      read_value = kwargs.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = raw_assign_fn(
+            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+            name=name)
+      return self._read_variable_op(parent_op=op) if read_value else op
+    return assign_fn
+
+  def assign_sub(self, *args, **kwargs):
+    assign_sub_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_sub_variable_op)
     return self._assign_func(f=assign_sub_fn, *args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    def assign_add_fn(var, delta, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_add_variable_op(
-            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+    assign_add_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_add_variable_op)
     return self._assign_func(f=assign_add_fn, *args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    def assign_fn(var, value, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_variable_op(
-            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+    assign_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_variable_op)
     return self._assign_func(f=assign_fn, *args, **kwargs)
 
-  @property
-  def aggregation(self):
-    return self._aggregation
-
   @property
   def constraint(self):
     return self.primary.constraint
 
-  @property
-  def initializer(self):
-    if self._initializer_op:
-      init_op = self._initializer_op
-    else:
-      init_op = control_flow_ops.group(tuple(
-          v.initializer for v in self._values))
-    return init_op
-
-  @property
-  def graph(self):
-    return self.primary.graph
-
-  @property
-  def _shared_name(self):
-    return self._common_name
-
-  @property
-  def _unique_id(self):
-    return self.primary._unique_id  # pylint: disable=protected-access
-
-  @property
-  def name(self):
-    return self.primary.name
-
-  @property
-  def dtype(self):
-    return self.primary.dtype
-
-  @property
-  def shape(self):
-    return self.primary.shape
-
-  def get_shape(self):
-    return self.primary.get_shape()
-
-  def to_proto(self, export_scope=None):
-    return self.primary.to_proto(export_scope=export_scope)
-
-  def _get_cross_replica(self):
-    device = device_util.canonicalize(device_util.current())
-    replica = self._device_map.replica_for_device(device)
-    if replica is None:
-      return self.primary
-    return self._values[replica]
-
   def _as_graph_element(self):
-    # pylint: disable=protected-access
     if _enclosing_tpu_context() is None:
-      if distribution_strategy_context.in_cross_replica_context():
-        return self.primary._as_graph_element()
-      return self._get()._as_graph_element()
+      return super(TPUMirroredVariable, self)._as_graph_element()  # pylint: disable=protected-access
+    else:
       return None
 
-  def _gather_saveables_for_checkpoint(self):
-    """Overrides Trackable method.
-
-    This allows both name-based and object-based save and restore of
-    MirroredVariables.
-
-    Returns:
-      A dictionary mapping attribute names to `SaveableObject` factories.
-    """
-    def _saveable_factory(name=self._common_name):
-      return _MirroredSaveable(self, self.primary, name)
-    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
-
-  def _should_act_as_resource_variable(self):
-    """Pass resource_variable_ops.is_resource_variable check."""
-    pass
-
   # Needed to pass ResourceVariable checks.
   @property
   def op(self):
     return self.primary.op
 
-  # pylint: disable=protected-access
-  @property
-  def _save_slice_info(self):
-    return self.primary._save_slice_info
-
-  def _get_save_slice_info(self):
-    return self.primary._get_save_slice_info()
-
-  def _set_save_slice_info(self, save_slice_info):
-    return self.primary._set_save_slice_info(save_slice_info)
-  # pylint: enable=protected-access
-
-  @property
-  def _in_graph_mode(self):
-    return self.primary._in_graph_mode  # pylint: disable=protected-access
-
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
     if _enclosing_tpu_context() is None:
-      return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+      return super(TPUMirroredVariable, self)._dense_var_to_tensor(
+          dtype, name, as_ref)
     # pylint: enable=protected-access
     if dtype is not None and dtype != self.dtype:
       return math_ops.cast(self.read_value(), dtype)
@@ -1366,40 +1107,6 @@ class TPUMirroredVariable(variables_lib.Variable):
     else:
       return self.read_value()
 
-  def is_initialized(self, name=None):
-    """Identifies if all the component variables are initialized.
-
-    Args:
-      name: Name of the final `logical_and` op.
-
-    Returns:
-      The op that evaluates to True or False depending on if all the
-      component variables are initialized.
-    """
-    # TODO(jhseu): Do we need TPU context implementation?
-
-    result = self.primary.is_initialized()
-    # We iterate through the list of values except the last one to allow us to
-    # name the final `logical_and` op the same name that is passed by the user
-    # to the `is_initialized` op. For distributed variables, the
-    # `is_initialized` op is a `logical_and` op.
-    for v in self._values[1:-1]:
-      result = math_ops.logical_and(result, v.is_initialized())
-    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
-                                  name=name)
-    return result
-
-
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
-  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
-
-
-ops.register_tensor_conversion_function(TPUMirroredVariable,
-                                        _tensor_conversion_tpu_mirrored)
-ops.register_dense_tensor_like_type(TPUMirroredVariable)
-
-
 class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""
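The three assign methods previously each carried a near-identical closure; `_make_raw_assign_fn` now builds that closure from the raw resource op. A toy sketch of the factory, with plain lambdas standing in for the generated resource-variable ops:

# Toy sketch of the _make_raw_assign_fn factory above (not TF code).
def _make_raw_assign_fn(raw_op):
  def assign_fn(state, value):
    return raw_op(state, value)  # the only part that differs per method
  return assign_fn

assign = _make_raw_assign_fn(lambda s, v: v)          # assign_variable_op
assign_add = _make_raw_assign_fn(lambda s, v: s + v)  # assign_add_variable_op
assign_sub = _make_raw_assign_fn(lambda s, v: s - v)  # assign_sub_variable_op

print(assign(10, 4), assign_add(10, 4), assign_sub(10, 4))  # 4 14 6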
@@ -996,7 +996,7 @@ def _var_key(var):
 
   # pylint: disable=protected-access
   # Get the distributed variable if it exists.
-  if getattr(var, "_distributed_container", None) is not None:
+  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
   if var._in_graph_mode:
     return var._shared_name
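One nuance in the `_var_key` change: `hasattr` is true even when the attribute exists but is `None`, whereas the old `getattr(..., None) is not None` form guarded against that; this is presumably safe here because `_distributed_container` is only ever set to a live weakref. The difference in plain Python:

# hasattr vs. getattr-with-default when the attribute is present but None.
class V(object):
  _distributed_container = None

print(hasattr(V(), "_distributed_container"))                    # True
print(getattr(V(), "_distributed_container", None) is not None)  # False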